"""LangGraph workflow for Opus Orchestrator. Implements proper state machine with: - Structured state across all stages - Checkpointing for resumability - Branching for iteration loops - Cross-validation between stages """ from typing import Any, Optional from pydantic import BaseModel, Field from enum import Enum class Stage(str, Enum): """Snowflake method stages.""" SEED = "seed" ONE_SENTENCE = "one_sentence" ONE_PARAGRAPH = "one_paragraph" CHARACTER_SHEETS = "character_sheets" FOUR_PAGE_OUTLINE = "four_page_outline" CHARACTER_CHARTS = "character_charts" SCENE_LIST = "scene_list" SCENE_DESCRIPTIONS = "scene_descriptions" STYLE_GUIDE = "style_guide" BLUEPRINT = "blueprint" WRITING = "writing" COMPLETE = "complete" # Structured output schemas for each stage class Character(BaseModel): """A character in the story.""" name: str role: str # protagonist, antagonist, mentor, ally, etc. age: Optional[int] = None description: str want: str # external goal need: str # internal growth fear: str secret: Optional[str] = None arc: str # how they change class PlotBeat(BaseModel): """A beat in the story structure.""" name: str description: str chapter: Optional[int] = None characters_involved: list[str] = Field(default_factory=list) location: Optional[str] = None class ChapterPlan(BaseModel): """Plan for a single chapter.""" chapter_number: int title: str summary: str word_count_target: int beats: list[str] = Field(default_factory=list) pov_character: Optional[str] = None key_events: list[str] = Field(default_factory=list) class PreWriting(BaseModel): """Complete pre-writing output.""" # Stage 1 one_sentence: str = "" # Stage 2 one_paragraph: str = "" act_1_setup: str = "" act_2_confrontation: str = "" act_3_resolution: str = "" # Stage 3 characters: list[Character] = Field(default_factory=list) # Stage 4 outline_sections: list[str] = Field(default_factory=list) # Stage 5 character_details: dict[str, str] = Field(default_factory=dict) # Stage 6 scene_list: list[PlotBeat] = Field(default_factory=list) chapter_plans: list[ChapterPlan] = Field(default_factory=list) # Stage 7 scene_descriptions: dict[str, str] = Field(default_factory=dict) # Framework info framework_used: str = "snowflake" class WritingState(BaseModel): """State during writing phase.""" current_chapter: int = 0 chapters_written: dict[int, str] = Field(default_factory=dict) chapters_critiqued: dict[int, float] = Field(default_factory=dict) iteration_counts: dict[int, int] = Field(default_factory=dict) approved_chapters: list[int] = Field(default_factory=list) class OpusState(BaseModel): """Complete state for LangGraph workflow.""" # Metadata stage: Stage = Stage.SEED framework: str = "snowflake" genre: str = "general" target_word_count: int = 80000 # Input seed_concept: str = "" # Pre-writing outputs (structured) prewriting: PreWriting = Field(default_factory=PreWriting) # Style style_guide: str = "" # Writing writing: WritingState = Field(default_factory=WritingState) # Final output manuscript: str = "" total_word_count: int = 0 # Validation validation_errors: list[str] = Field(default_factory=list) warnings: list[str] = Field(default_factory=list) # Progress progress: float = 0.0 # Checkpoint last_updated: str = "" # Validation functions def validate_one_sentence(state: OpusState) -> list[str]: """Validate one sentence output.""" errors = [] text = state.prewriting.one_sentence if len(text) < 20: errors.append("One sentence too short") if len(text) > 200: errors.append("One sentence too long (should be ~1 sentence)") if not any(c in text for c in [',', '.', '!', '?']): errors.append("One sentence needs proper punctuation") return errors def validate_one_paragraph(state: OpusState) -> list[str]: """Validate one paragraph output.""" errors = [] text = state.prewriting.one_paragraph if len(text) < 100: errors.append("One paragraph too short") if len(text) > 2000: errors.append("One paragraph too long") # Check it mentions the main character from Stage 1 if state.prewriting.one_sentence: # Extract name from sentence first_word = state.prewriting.one_sentence.split()[0:3] if first_word: # Just warn if not found pass return errors def validate_character_sheets(state: OpusState) -> list[str]: """Validate character sheets.""" errors = [] characters = state.prewriting.characters if len(characters) < 1: errors.append("No characters defined") return errors # Check protagonist exists has_protagonist = any(c.role.lower() == "protagonist" for c in characters) if not has_protagonist: errors.append("No protagonist defined") # Check each character has required fields for char in characters: if not char.name: errors.append(f"Character missing name: {char}") if not char.want: errors.append(f"Character {char.name} missing want") if not char.need: errors.append(f"Character {char.name} missing need") return errors def validate_scene_list(state: OpusState) -> list[str]: """Validate scene list.""" errors = [] scenes = state.prewriting.scene_list if len(scenes) < 5: errors.append(f"Too few scenes: {len(scenes)} (should be 10+)") # Check chapter plans align with scene count if state.prewriting.chapter_plans: total_scenes = sum(len(cp.beats) for cp in state.prewriting.chapter_plans) if total_scenes < len(scenes) * 0.5: errors.append("Chapter plans don't cover enough scenes") return errors def validate_prewriting_complete(state: OpusState) -> list[str]: """Validate all pre-writing is complete.""" errors = [] if not state.prewriting.one_sentence: errors.append("Stage 1 incomplete: one sentence missing") if not state.prewriting.one_paragraph: errors.append("Stage 2 incomplete: one paragraph missing") if not state.prewriting.characters: errors.append("Stage 3 incomplete: no characters") if not state.prewriting.outline_sections: errors.append("Stage 4 incomplete: no outline") if not state.prewriting.scene_list: errors.append("Stage 6 incomplete: no scene list") return errors # State management def create_initial_state( seed_concept: str, framework: str = "snowflake", genre: str = "general", target_word_count: int = 80000, ) -> OpusState: """Create initial state.""" return OpusState( seed_concept=seed_concept, framework=framework, genre=genre, target_word_count=target_word_count, stage=Stage.SEED, ) def get_progress(stage: Stage) -> float: """Get progress percentage for stage.""" progress_map = { Stage.SEED: 0.0, Stage.ONE_SENTENCE: 0.05, Stage.ONE_PARAGRAPH: 0.10, Stage.CHARACTER_SHEETS: 0.20, Stage.FOUR_PAGE_OUTLINE: 0.30, Stage.CHARACTER_CHARTS: 0.35, Stage.SCENE_LIST: 0.40, Stage.SCENE_DESCRIPTIONS: 0.45, Stage.STYLE_GUIDE: 0.50, Stage.BLUEPRINT: 0.55, Stage.WRITING: 0.80, Stage.COMPLETE: 1.0, } return progress_map.get(stage, 0.0)