diff --git a/backend/pipeline/llm_client.py b/backend/pipeline/llm_client.py index dc53772..5ec49e1 100644 --- a/backend/pipeline/llm_client.py +++ b/backend/pipeline/llm_client.py @@ -28,6 +28,38 @@ logger = logging.getLogger(__name__) T = TypeVar("T", bound=BaseModel) + +# ── LLM Response wrapper ───────────────────────────────────────────────────── + +class LLMResponse(str): + """String subclass that carries LLM response metadata. + + Backward-compatible with all code that treats the response as a plain + string, but callers that know about it can inspect finish_reason and + the truncated property. + """ + finish_reason: str | None + prompt_tokens: int | None + completion_tokens: int | None + + def __new__( + cls, + text: str, + finish_reason: str | None = None, + prompt_tokens: int | None = None, + completion_tokens: int | None = None, + ): + obj = super().__new__(cls, text) + obj.finish_reason = finish_reason + obj.prompt_tokens = prompt_tokens + obj.completion_tokens = completion_tokens + return obj + + @property + def truncated(self) -> bool: + """True if the model hit its token limit before finishing.""" + return self.finish_reason == "length" + # ── Think-tag stripping ────────────────────────────────────────────────────── _THINK_PATTERN = re.compile(r"<think>.*?</think>", re.DOTALL) @@ -149,7 +181,7 @@ class LLMClient: model_override: str | None = None, on_complete: "Callable | None" = None, max_tokens: int | None = None, - ) -> str: + ) -> "LLMResponse": """Send a chat completion request, falling back on connection/timeout errors. Parameters ---------- @@ -174,8 +206,8 @@ Returns ------- - str - Raw completion text from the model (think tags stripped if thinking). + LLMResponse + Raw completion text (str subclass) with finish_reason metadata. 
""" kwargs: dict = {} effective_system = system_prompt @@ -230,6 +262,7 @@ class LLMClient: ) if modality == "thinking": raw = strip_think_tags(raw) + finish = response.choices[0].finish_reason if response.choices else None if on_complete is not None: try: on_complete( @@ -238,11 +271,16 @@ class LLMClient: completion_tokens=usage.completion_tokens if usage else None, total_tokens=usage.total_tokens if usage else None, content=raw, - finish_reason=response.choices[0].finish_reason if response.choices else None, + finish_reason=finish, ) except Exception as cb_exc: logger.warning("on_complete callback failed: %s", cb_exc) - return raw + return LLMResponse( + raw, + finish_reason=finish, + prompt_tokens=usage.prompt_tokens if usage else None, + completion_tokens=usage.completion_tokens if usage else None, + ) except (openai.APIConnectionError, openai.APITimeoutError) as exc: logger.warning( @@ -270,6 +308,7 @@ class LLMClient: ) if modality == "thinking": raw = strip_think_tags(raw) + finish = response.choices[0].finish_reason if response.choices else None if on_complete is not None: try: on_complete( @@ -278,12 +317,17 @@ class LLMClient: completion_tokens=usage.completion_tokens if usage else None, total_tokens=usage.total_tokens if usage else None, content=raw, - finish_reason=response.choices[0].finish_reason if response.choices else None, + finish_reason=finish, is_fallback=True, ) except Exception as cb_exc: logger.warning("on_complete callback failed: %s", cb_exc) - return raw + return LLMResponse( + raw, + finish_reason=finish, + prompt_tokens=usage.prompt_tokens if usage else None, + completion_tokens=usage.completion_tokens if usage else None, + ) except (openai.APIConnectionError, openai.APITimeoutError, openai.APIError) as exc: logger.error( diff --git a/backend/pipeline/stages.py b/backend/pipeline/stages.py index 598576e..c623767 100644 --- a/backend/pipeline/stages.py +++ b/backend/pipeline/stages.py @@ -36,7 +36,7 @@ from models import ( 
TranscriptSegment, ) from pipeline.embedding_client import EmbeddingClient -from pipeline.llm_client import LLMClient, estimate_max_tokens +from pipeline.llm_client import LLMClient, LLMResponse, estimate_max_tokens from pipeline.qdrant_client import QdrantManager from pipeline.schemas import ( ClassificationResult, @@ -49,6 +49,11 @@ from worker import celery_app logger = logging.getLogger(__name__) +class LLMTruncationError(RuntimeError): + """Raised when the LLM response was truncated (finish_reason=length).""" + pass + + # ── Error status helper ────────────────────────────────────────────────────── def _set_error_status(video_id: str, stage_name: str, error: Exception) -> None: @@ -256,7 +261,7 @@ def _format_taxonomy_for_prompt(tags_data: dict) -> str: def _safe_parse_llm_response( - raw: str, + raw, model_cls, llm: LLMClient, system_prompt: str, @@ -265,14 +270,37 @@ def _safe_parse_llm_response( model_override: str | None = None, max_tokens: int | None = None, ): - """Parse LLM response with one retry on failure. + """Parse LLM response with truncation detection and one retry on failure. - On malformed response: log the raw text, retry once with a JSON nudge, - then raise on second failure. + If the response was truncated (finish_reason=length), raises + LLMTruncationError immediately — retrying with a JSON nudge would only + make things worse by adding tokens to an already-too-large prompt. + + For non-truncation parse failures: retry once with a JSON nudge, then + raise on second failure. """ + # Check for truncation before attempting parse + is_truncated = isinstance(raw, LLMResponse) and raw.truncated + if is_truncated: + logger.warning( + "LLM response truncated (finish=length) for %s. " + "prompt_tokens=%s, completion_tokens=%s. 
Will not retry with nudge.", + model_cls.__name__, + getattr(raw, "prompt_tokens", "?"), + getattr(raw, "completion_tokens", "?"), + ) + try: return llm.parse_response(raw, model_cls) except (ValidationError, ValueError, json.JSONDecodeError) as exc: + if is_truncated: + raise LLMTruncationError( + f"LLM output truncated for {model_cls.__name__}: " + f"prompt_tokens={getattr(raw, 'prompt_tokens', '?')}, " + f"completion_tokens={getattr(raw, 'completion_tokens', '?')}. " + f"Response too large for model context window." + ) from exc + logger.warning( "First parse attempt failed for %s (%s). Retrying with JSON nudge. " "Raw response (first 500 chars): %.500s", @@ -479,6 +507,66 @@ def stage3_extraction(self, video_id: str, run_id: str | None = None) -> str: # ── Stage 4: Classification ───────────────────────────────────────────────── +# Maximum moments per classification batch. Keeps each LLM call well within +# context window limits. Batches are classified independently and merged. +_STAGE4_BATCH_SIZE = 20 + + +def _classify_moment_batch( + moments_batch: list, + batch_offset: int, + taxonomy_text: str, + system_prompt: str, + llm: LLMClient, + model_override: str | None, + modality: str, + hard_limit: int, + video_id: str, + run_id: str | None, +) -> ClassificationResult: + """Classify a single batch of moments. 
Raises on failure.""" + moments_lines = [] + for i, m in enumerate(moments_batch): + moments_lines.append( + f"[{i}] Title: {m.title}\n" + f" Summary: {m.summary}\n" + f" Content type: {m.content_type.value}\n" + f" Plugins: {', '.join(m.plugins) if m.plugins else 'none'}" + ) + moments_text = "\n\n".join(moments_lines) + + user_prompt = ( + f"\n{taxonomy_text}\n\n\n" + f"\n{moments_text}\n" + ) + + max_tokens = estimate_max_tokens( + system_prompt, user_prompt, + stage="stage4_classification", hard_limit=hard_limit, + ) + batch_label = f"batch {batch_offset // _STAGE4_BATCH_SIZE + 1} (moments {batch_offset}-{batch_offset + len(moments_batch) - 1})" + logger.info( + "Stage 4 classifying %s, max_tokens=%d", + batch_label, max_tokens, + ) + raw = llm.complete( + system_prompt, user_prompt, + response_model=ClassificationResult, + on_complete=_make_llm_callback( + video_id, "stage4_classification", + system_prompt=system_prompt, user_prompt=user_prompt, + run_id=run_id, context_label=batch_label, + ), + modality=modality, model_override=model_override, + max_tokens=max_tokens, + ) + return _safe_parse_llm_response( + raw, ClassificationResult, llm, system_prompt, user_prompt, + modality=modality, model_override=model_override, + max_tokens=max_tokens, + ) + + @celery_app.task(bind=True, max_retries=3, default_retry_delay=30) def stage4_classification(self, video_id: str, run_id: str | None = None) -> str: """Classify key moments against the canonical tag taxonomy. @@ -487,6 +575,9 @@ def stage4_classification(self, video_id: str, run_id: str | None = None) -> str canonical taxonomy, and stores classification results in Redis for stage 5 consumption. Updates content_type if the classifier overrides it. + For large moment sets, automatically batches into groups of + _STAGE4_BATCH_SIZE to stay within model context window limits. + Stage 4 does NOT change processing_status. Returns the video_id for chain compatibility. 
@@ -510,7 +601,6 @@ def stage4_classification(self, video_id: str, run_id: str | None = None) -> str if not moments: logger.info("Stage 4: No moments found for video_id=%s, skipping.", video_id) - # Store empty classification data _store_classification_data(video_id, []) return video_id @@ -518,38 +608,29 @@ def stage4_classification(self, video_id: str, run_id: str | None = None) -> str tags_data = _load_canonical_tags() taxonomy_text = _format_taxonomy_for_prompt(tags_data) - # Build moments text for the LLM - moments_lines = [] - for i, m in enumerate(moments): - moments_lines.append( - f"[{i}] Title: {m.title}\n" - f" Summary: {m.summary}\n" - f" Content type: {m.content_type.value}\n" - f" Plugins: {', '.join(m.plugins) if m.plugins else 'none'}" - ) - moments_text = "\n\n".join(moments_lines) - system_prompt = _load_prompt("stage4_classification.txt") - user_prompt = ( - f"\n{taxonomy_text}\n\n\n" - f"\n{moments_text}\n" - ) - llm = _get_llm_client() model_override, modality = _get_stage_config(4) hard_limit = get_settings().llm_max_tokens_hard_limit - max_tokens = estimate_max_tokens(system_prompt, user_prompt, stage="stage4_classification", hard_limit=hard_limit) - logger.info("Stage 4 using model=%s, modality=%s, max_tokens=%d", model_override or "default", modality, max_tokens) - raw = llm.complete(system_prompt, user_prompt, response_model=ClassificationResult, on_complete=_make_llm_callback(video_id, "stage4_classification", system_prompt=system_prompt, user_prompt=user_prompt, run_id=run_id), - modality=modality, model_override=model_override, max_tokens=max_tokens) - result = _safe_parse_llm_response(raw, ClassificationResult, llm, system_prompt, user_prompt, - modality=modality, model_override=model_override, max_tokens=max_tokens) + + # Batch moments for classification + all_classifications = [] + for batch_start in range(0, len(moments), _STAGE4_BATCH_SIZE): + batch = moments[batch_start:batch_start + _STAGE4_BATCH_SIZE] + result = 
_classify_moment_batch( + batch, batch_start, taxonomy_text, system_prompt, + llm, model_override, modality, hard_limit, + video_id, run_id, + ) + # Reindex: batch uses 0-based indices, remap to global indices + for cls in result.classifications: + cls.moment_index += batch_start + all_classifications.extend(result.classifications) # Apply content_type overrides and prepare classification data for stage 5 classification_data = [] - moment_ids = [str(m.id) for m in moments] - for cls in result.classifications: + for cls in all_classifications: if 0 <= cls.moment_index < len(moments): moment = moments[cls.moment_index] @@ -572,10 +653,12 @@ def stage4_classification(self, video_id: str, run_id: str | None = None) -> str _store_classification_data(video_id, classification_data) elapsed = time.monotonic() - start + num_batches = (len(moments) + _STAGE4_BATCH_SIZE - 1) // _STAGE4_BATCH_SIZE _emit_event(video_id, "stage4_classification", "complete", run_id=run_id) logger.info( - "Stage 4 (classification) completed for video_id=%s in %.1fs — %d moments classified", - video_id, elapsed, len(classification_data), + "Stage 4 (classification) completed for video_id=%s in %.1fs — " + "%d moments classified in %d batch(es)", + video_id, elapsed, len(classification_data), num_batches, ) return video_id @@ -1110,6 +1193,77 @@ def _load_prior_pages(video_id: str) -> list[str]: return [] return json.loads(raw) + +# ── Stage completion detection for auto-resume ─────────────────────────────── + +# Ordered list of pipeline stages for resumability logic +_PIPELINE_STAGES = [ + "stage2_segmentation", + "stage3_extraction", + "stage4_classification", + "stage5_synthesis", + "stage6_embed_and_index", +] + +_STAGE_TASKS = { + "stage2_segmentation": stage2_segmentation, + "stage3_extraction": stage3_extraction, + "stage4_classification": stage4_classification, + "stage5_synthesis": stage5_synthesis, + "stage6_embed_and_index": stage6_embed_and_index, +} + + +def 
_get_last_completed_stage(video_id: str) -> str | None: + """Find the last stage that completed successfully for this video. + + Queries pipeline_events for the most recent run, looking for 'complete' + events. Returns the stage name (e.g. 'stage3_extraction') or None if + no stages have completed. + """ + session = _get_sync_session() + try: + # Find the most recent run for this video + from models import PipelineRun + latest_run = session.execute( + select(PipelineRun) + .where(PipelineRun.video_id == video_id) + .order_by(PipelineRun.started_at.desc()) + .limit(1) + ).scalar_one_or_none() + + if latest_run is None: + return None + + # Get all 'complete' events from that run + completed_events = session.execute( + select(PipelineEvent.stage) + .where( + PipelineEvent.run_id == str(latest_run.id), + PipelineEvent.event_type == "complete", + ) + ).scalars().all() + + completed_set = set(completed_events) + + # Walk backwards through the ordered stages to find the last completed one + last_completed = None + for stage_name in _PIPELINE_STAGES: + if stage_name in completed_set: + last_completed = stage_name + else: + break # Stop at first gap — stages must be sequential + + if last_completed: + logger.info( + "Auto-resume: video_id=%s last completed stage=%s (run_id=%s)", + video_id, last_completed, latest_run.id, + ) + return last_completed + finally: + session.close() + + # ── Orchestrator ───────────────────────────────────────────────────────────── @celery_app.task @@ -1175,13 +1329,13 @@ def _finish_run(run_id: str, status: str, error_stage: str | None = None) -> Non @celery_app.task def run_pipeline(video_id: str, trigger: str = "manual") -> str: - """Orchestrate the full pipeline (stages 2-5) with resumability. + """Orchestrate the full pipeline (stages 2-6) with auto-resume. - Checks the current processing_status of the video and chains only the - stages that still need to run. 
For example: - - queued → stages 2, 3, 4, 5 - - processing/error → re-run full pipeline - - complete → no-op + For error/processing status, queries pipeline_events to find the last + stage that completed successfully and resumes from the next stage. + This avoids re-running expensive LLM stages that already succeeded. + + For clean_reprocess trigger, always starts from stage 2. Returns the video_id. """ @@ -1204,6 +1358,13 @@ def run_pipeline(video_id: str, trigger: str = "manual") -> str: finally: session.close() + if status == ProcessingStatus.complete: + logger.info( + "run_pipeline: video_id=%s already at status=%s, nothing to do.", + video_id, status.value, + ) + return video_id + # Snapshot prior technique pages before pipeline resets key_moments _snapshot_prior_pages(video_id) @@ -1211,40 +1372,39 @@ def run_pipeline(video_id: str, trigger: str = "manual") -> str: run_id = _create_run(video_id, trigger) logger.info("run_pipeline: created run_id=%s for video_id=%s (trigger=%s)", run_id, video_id, trigger) - # Build the chain based on current status - stages = [] - if status in (ProcessingStatus.not_started, ProcessingStatus.queued): - stages = [ - stage2_segmentation.s(video_id, run_id=run_id), - stage3_extraction.s(run_id=run_id), # receives video_id from previous - stage4_classification.s(run_id=run_id), - stage5_synthesis.s(run_id=run_id), - stage6_embed_and_index.s(run_id=run_id), - ] - elif status == ProcessingStatus.processing: - stages = [ - stage2_segmentation.s(video_id, run_id=run_id), - stage3_extraction.s(run_id=run_id), - stage4_classification.s(run_id=run_id), - stage5_synthesis.s(run_id=run_id), - stage6_embed_and_index.s(run_id=run_id), - ] - elif status == ProcessingStatus.error: - stages = [ - stage2_segmentation.s(video_id, run_id=run_id), - stage3_extraction.s(run_id=run_id), - stage4_classification.s(run_id=run_id), - stage5_synthesis.s(run_id=run_id), - stage6_embed_and_index.s(run_id=run_id), - ] - elif status == 
ProcessingStatus.complete: - logger.info( - "run_pipeline: video_id=%s already at status=%s, nothing to do.", - video_id, status.value, - ) - return video_id + # Determine which stages to run + resume_from_idx = 0 # Default: start from stage 2 - if stages: + if trigger != "clean_reprocess" and status in (ProcessingStatus.processing, ProcessingStatus.error): + # Try to resume from where we left off + last_completed = _get_last_completed_stage(video_id) + if last_completed and last_completed in _PIPELINE_STAGES: + completed_idx = _PIPELINE_STAGES.index(last_completed) + resume_from_idx = completed_idx + 1 + if resume_from_idx >= len(_PIPELINE_STAGES): + logger.info( + "run_pipeline: all stages already completed for video_id=%s", + video_id, + ) + return video_id + + stages_to_run = _PIPELINE_STAGES[resume_from_idx:] + logger.info( + "run_pipeline: video_id=%s will run stages: %s (resume_from_idx=%d)", + video_id, stages_to_run, resume_from_idx, + ) + + # Build the Celery chain — first stage gets video_id as arg, + # subsequent stages receive it from the previous stage's return value + celery_sigs = [] + for i, stage_name in enumerate(stages_to_run): + task_func = _STAGE_TASKS[stage_name] + if i == 0: + celery_sigs.append(task_func.s(video_id, run_id=run_id)) + else: + celery_sigs.append(task_func.s(run_id=run_id)) + + if celery_sigs: # Mark as processing before dispatching session = _get_sync_session() try: @@ -1256,12 +1416,12 @@ def run_pipeline(video_id: str, trigger: str = "manual") -> str: finally: session.close() - pipeline = celery_chain(*stages) + pipeline = celery_chain(*celery_sigs) error_cb = mark_pipeline_error.s(video_id, run_id=run_id) pipeline.apply_async(link_error=error_cb) logger.info( - "run_pipeline: dispatched %d stages for video_id=%s (run_id=%s)", - len(stages), video_id, run_id, + "run_pipeline: dispatched %d stages for video_id=%s (run_id=%s, starting at %s)", + len(celery_sigs), video_id, run_id, stages_to_run[0], ) return video_id