diff --git a/backend/pipeline/llm_client.py b/backend/pipeline/llm_client.py index 1c26f95..dc53772 100644 --- a/backend/pipeline/llm_client.py +++ b/backend/pipeline/llm_client.py @@ -59,14 +59,14 @@ def strip_think_tags(text: str) -> str: # Stage-specific output multipliers: estimated output tokens as a ratio of input tokens. # These are empirically tuned based on observed pipeline behavior. _STAGE_OUTPUT_RATIOS: dict[str, float] = { - "stage2_segmentation": 0.3, # Compact topic groups — much smaller than input - "stage3_extraction": 1.2, # Detailed moments with summaries — can exceed input - "stage4_classification": 0.15, # Index + category + tags per moment — very compact - "stage5_synthesis": 1.5, # Full prose technique pages — heaviest output + "stage2_segmentation": 0.6, # Compact topic groups — smaller than input + "stage3_extraction": 2.0, # Detailed moments with summaries — can well exceed input + "stage4_classification": 0.5, # Index + category + tags per moment — small but varies + "stage5_synthesis": 2.5, # Full prose technique pages — heaviest output } # Minimum floor so we never send a trivially small max_tokens -_MIN_MAX_TOKENS = 2048 +_MIN_MAX_TOKENS = 4096 def estimate_tokens(text: str) -> int: @@ -111,8 +111,8 @@ def estimate_max_tokens( ratio = _STAGE_OUTPUT_RATIOS.get(stage or "", 1.0) estimated_output = int(input_tokens * ratio) - # Add a 20% buffer for JSON overhead and variability - estimated_output = int(estimated_output * 1.2) + # Add a 50% buffer for JSON overhead and variability + estimated_output = int(estimated_output * 1.5) # Clamp to [_MIN_MAX_TOKENS, hard_limit] result = max(_MIN_MAX_TOKENS, min(estimated_output, hard_limit)) diff --git a/backend/pipeline/stages.py b/backend/pipeline/stages.py index c8ff245..598576e 100644 --- a/backend/pipeline/stages.py +++ b/backend/pipeline/stages.py @@ -263,6 +263,7 @@ def _safe_parse_llm_response( user_prompt: str, modality: str = "chat", model_override: str | None = None, + max_tokens: int | None = None, ): """Parse LLM response with one retry on failure. @@ -284,6 +285,7 @@ def _safe_parse_llm_response( retry_raw = llm.complete( system_prompt, nudge_prompt, response_model=model_cls, modality=modality, model_override=model_override, + max_tokens=max_tokens, ) return llm.parse_response(retry_raw, model_cls) @@ -340,7 +342,7 @@ def stage2_segmentation(self, video_id: str, run_id: str | None = None) -> str: raw = llm.complete(system_prompt, user_prompt, response_model=SegmentationResult, on_complete=_make_llm_callback(video_id, "stage2_segmentation", system_prompt=system_prompt, user_prompt=user_prompt, run_id=run_id), modality=modality, model_override=model_override, max_tokens=max_tokens) result = _safe_parse_llm_response(raw, SegmentationResult, llm, system_prompt, user_prompt, - modality=modality, model_override=model_override) + modality=modality, model_override=model_override, max_tokens=max_tokens) # Update topic_label on each segment row seg_by_index = {s.segment_index: s for s in segments} @@ -432,7 +434,7 @@ def stage3_extraction(self, video_id: str, run_id: str | None = None) -> str: raw = llm.complete(system_prompt, user_prompt, response_model=ExtractionResult, on_complete=_make_llm_callback(video_id, "stage3_extraction", system_prompt=system_prompt, user_prompt=user_prompt, run_id=run_id, context_label=topic_label), modality=modality, model_override=model_override, max_tokens=max_tokens) result = _safe_parse_llm_response(raw, ExtractionResult, llm, system_prompt, user_prompt, - modality=modality, model_override=model_override) + modality=modality, model_override=model_override, max_tokens=max_tokens) # Create KeyMoment rows for moment in result.moments: @@ -541,7 +543,7 @@ def stage4_classification(self, video_id: str, run_id: str | None = None) -> str raw = llm.complete(system_prompt, user_prompt, response_model=ClassificationResult, on_complete=_make_llm_callback(video_id, "stage4_classification", system_prompt=system_prompt, user_prompt=user_prompt, run_id=run_id), modality=modality, model_override=model_override, max_tokens=max_tokens) result = _safe_parse_llm_response(raw, ClassificationResult, llm, system_prompt, user_prompt, - modality=modality, model_override=model_override) + modality=modality, model_override=model_override, max_tokens=max_tokens) # Apply content_type overrides and prepare classification data for stage 5 classification_data = [] @@ -786,7 +788,7 @@ def stage5_synthesis(self, video_id: str, run_id: str | None = None) -> str: raw = llm.complete(system_prompt, user_prompt, response_model=SynthesisResult, on_complete=_make_llm_callback(video_id, "stage5_synthesis", system_prompt=system_prompt, user_prompt=user_prompt, run_id=run_id, context_label=category), modality=modality, model_override=model_override, max_tokens=max_tokens) result = _safe_parse_llm_response(raw, SynthesisResult, llm, system_prompt, user_prompt, - modality=modality, model_override=model_override) + modality=modality, model_override=model_override, max_tokens=max_tokens) # Load prior pages from this video (snapshot taken before pipeline reset) prior_page_ids = _load_prior_pages(video_id)