diff --git a/backend/pipeline/llm_client.py b/backend/pipeline/llm_client.py
index dc53772..5ec49e1 100644
--- a/backend/pipeline/llm_client.py
+++ b/backend/pipeline/llm_client.py
@@ -28,6 +28,38 @@ logger = logging.getLogger(__name__)
T = TypeVar("T", bound=BaseModel)
+
+# ── LLM Response wrapper ─────────────────────────────────────────────────────
+
+class LLMResponse(str):
+ """String subclass that carries LLM response metadata.
+
+ Backward-compatible with all code that treats the response as a plain
+ string, but callers that know about it can inspect finish_reason and
+ the truncated property.
+ """
+ finish_reason: str | None
+ prompt_tokens: int | None
+ completion_tokens: int | None
+
+ def __new__(
+ cls,
+ text: str,
+ finish_reason: str | None = None,
+ prompt_tokens: int | None = None,
+ completion_tokens: int | None = None,
+ ):
+ obj = super().__new__(cls, text)
+ obj.finish_reason = finish_reason
+ obj.prompt_tokens = prompt_tokens
+ obj.completion_tokens = completion_tokens
+ return obj
+
+ @property
+ def truncated(self) -> bool:
+ """True if the model hit its token limit before finishing."""
+ return self.finish_reason == "length"
+
# ── Think-tag stripping ──────────────────────────────────────────────────────
_THINK_PATTERN = re.compile(r"<think>.*?</think>", re.DOTALL)
@@ -149,7 +181,7 @@ class LLMClient:
model_override: str | None = None,
on_complete: "Callable | None" = None,
max_tokens: int | None = None,
- ) -> str:
+ ) -> "LLMResponse":
"""Send a chat completion request, falling back on connection/timeout errors.
Parameters
@@ -174,8 +206,8 @@ class LLMClient:
Returns
-------
- str
- Raw completion text from the model (think tags stripped if thinking).
+ LLMResponse
+ Raw completion text (str subclass) with finish_reason metadata.
"""
kwargs: dict = {}
effective_system = system_prompt
@@ -230,6 +262,7 @@ class LLMClient:
)
if modality == "thinking":
raw = strip_think_tags(raw)
+ finish = response.choices[0].finish_reason if response.choices else None
if on_complete is not None:
try:
on_complete(
@@ -238,11 +271,16 @@ class LLMClient:
completion_tokens=usage.completion_tokens if usage else None,
total_tokens=usage.total_tokens if usage else None,
content=raw,
- finish_reason=response.choices[0].finish_reason if response.choices else None,
+ finish_reason=finish,
)
except Exception as cb_exc:
logger.warning("on_complete callback failed: %s", cb_exc)
- return raw
+ return LLMResponse(
+ raw,
+ finish_reason=finish,
+ prompt_tokens=usage.prompt_tokens if usage else None,
+ completion_tokens=usage.completion_tokens if usage else None,
+ )
except (openai.APIConnectionError, openai.APITimeoutError) as exc:
logger.warning(
@@ -270,6 +308,7 @@ class LLMClient:
)
if modality == "thinking":
raw = strip_think_tags(raw)
+ finish = response.choices[0].finish_reason if response.choices else None
if on_complete is not None:
try:
on_complete(
@@ -278,12 +317,17 @@ class LLMClient:
completion_tokens=usage.completion_tokens if usage else None,
total_tokens=usage.total_tokens if usage else None,
content=raw,
- finish_reason=response.choices[0].finish_reason if response.choices else None,
+ finish_reason=finish,
is_fallback=True,
)
except Exception as cb_exc:
logger.warning("on_complete callback failed: %s", cb_exc)
- return raw
+ return LLMResponse(
+ raw,
+ finish_reason=finish,
+ prompt_tokens=usage.prompt_tokens if usage else None,
+ completion_tokens=usage.completion_tokens if usage else None,
+ )
except (openai.APIConnectionError, openai.APITimeoutError, openai.APIError) as exc:
logger.error(
diff --git a/backend/pipeline/stages.py b/backend/pipeline/stages.py
index 598576e..c623767 100644
--- a/backend/pipeline/stages.py
+++ b/backend/pipeline/stages.py
@@ -36,7 +36,7 @@ from models import (
TranscriptSegment,
)
from pipeline.embedding_client import EmbeddingClient
-from pipeline.llm_client import LLMClient, estimate_max_tokens
+from pipeline.llm_client import LLMClient, LLMResponse, estimate_max_tokens
from pipeline.qdrant_client import QdrantManager
from pipeline.schemas import (
ClassificationResult,
@@ -49,6 +49,11 @@ from worker import celery_app
logger = logging.getLogger(__name__)
+class LLMTruncationError(RuntimeError):
+ """Raised when the LLM response was truncated (finish_reason=length)."""
+ pass
+
+
# ── Error status helper ──────────────────────────────────────────────────────
def _set_error_status(video_id: str, stage_name: str, error: Exception) -> None:
@@ -256,7 +261,7 @@ def _format_taxonomy_for_prompt(tags_data: dict) -> str:
def _safe_parse_llm_response(
- raw: str,
+ raw,
model_cls,
llm: LLMClient,
system_prompt: str,
@@ -265,14 +270,37 @@ def _safe_parse_llm_response(
model_override: str | None = None,
max_tokens: int | None = None,
):
- """Parse LLM response with one retry on failure.
+ """Parse LLM response with truncation detection and one retry on failure.
- On malformed response: log the raw text, retry once with a JSON nudge,
- then raise on second failure.
+ If the response was truncated (finish_reason=length), raises
+ LLMTruncationError immediately — retrying with a JSON nudge would only
+ make things worse by adding tokens to an already-too-large prompt.
+
+ For non-truncation parse failures: retry once with a JSON nudge, then
+ raise on second failure.
"""
+ # Check for truncation before attempting parse
+ is_truncated = isinstance(raw, LLMResponse) and raw.truncated
+ if is_truncated:
+ logger.warning(
+ "LLM response truncated (finish=length) for %s. "
+ "prompt_tokens=%s, completion_tokens=%s. Will not retry with nudge.",
+ model_cls.__name__,
+ getattr(raw, "prompt_tokens", "?"),
+ getattr(raw, "completion_tokens", "?"),
+ )
+
try:
return llm.parse_response(raw, model_cls)
except (ValidationError, ValueError, json.JSONDecodeError) as exc:
+ if is_truncated:
+ raise LLMTruncationError(
+ f"LLM output truncated for {model_cls.__name__}: "
+ f"prompt_tokens={getattr(raw, 'prompt_tokens', '?')}, "
+ f"completion_tokens={getattr(raw, 'completion_tokens', '?')}. "
+ f"Response too large for model context window."
+ ) from exc
+
logger.warning(
"First parse attempt failed for %s (%s). Retrying with JSON nudge. "
"Raw response (first 500 chars): %.500s",
@@ -479,6 +507,66 @@ def stage3_extraction(self, video_id: str, run_id: str | None = None) -> str:
# ── Stage 4: Classification ─────────────────────────────────────────────────
+# Maximum moments per classification batch. Keeps each LLM call well within
+# context window limits. Batches are classified independently and merged.
+_STAGE4_BATCH_SIZE = 20
+
+
+def _classify_moment_batch(
+ moments_batch: list,
+ batch_offset: int,
+ taxonomy_text: str,
+ system_prompt: str,
+ llm: LLMClient,
+ model_override: str | None,
+ modality: str,
+ hard_limit: int,
+ video_id: str,
+ run_id: str | None,
+) -> ClassificationResult:
+ """Classify a single batch of moments. Raises on failure."""
+ moments_lines = []
+ for i, m in enumerate(moments_batch):
+ moments_lines.append(
+ f"[{i}] Title: {m.title}\n"
+ f" Summary: {m.summary}\n"
+ f" Content type: {m.content_type.value}\n"
+ f" Plugins: {', '.join(m.plugins) if m.plugins else 'none'}"
+ )
+ moments_text = "\n\n".join(moments_lines)
+
+ user_prompt = (
+ f"<taxonomy>\n{taxonomy_text}\n</taxonomy>\n\n"
+ f"<moments>\n{moments_text}\n</moments>"
+ )
+
+ max_tokens = estimate_max_tokens(
+ system_prompt, user_prompt,
+ stage="stage4_classification", hard_limit=hard_limit,
+ )
+ batch_label = f"batch {batch_offset // _STAGE4_BATCH_SIZE + 1} (moments {batch_offset}-{batch_offset + len(moments_batch) - 1})"
+ logger.info(
+ "Stage 4 classifying %s, max_tokens=%d",
+ batch_label, max_tokens,
+ )
+ raw = llm.complete(
+ system_prompt, user_prompt,
+ response_model=ClassificationResult,
+ on_complete=_make_llm_callback(
+ video_id, "stage4_classification",
+ system_prompt=system_prompt, user_prompt=user_prompt,
+ run_id=run_id, context_label=batch_label,
+ ),
+ modality=modality, model_override=model_override,
+ max_tokens=max_tokens,
+ )
+ return _safe_parse_llm_response(
+ raw, ClassificationResult, llm, system_prompt, user_prompt,
+ modality=modality, model_override=model_override,
+ max_tokens=max_tokens,
+ )
+
+
@celery_app.task(bind=True, max_retries=3, default_retry_delay=30)
def stage4_classification(self, video_id: str, run_id: str | None = None) -> str:
"""Classify key moments against the canonical tag taxonomy.
@@ -487,6 +575,9 @@ def stage4_classification(self, video_id: str, run_id: str | None = None) -> str
canonical taxonomy, and stores classification results in Redis for
stage 5 consumption. Updates content_type if the classifier overrides it.
+ For large moment sets, automatically batches into groups of
+ _STAGE4_BATCH_SIZE to stay within model context window limits.
+
Stage 4 does NOT change processing_status.
Returns the video_id for chain compatibility.
@@ -510,7 +601,6 @@ def stage4_classification(self, video_id: str, run_id: str | None = None) -> str
if not moments:
logger.info("Stage 4: No moments found for video_id=%s, skipping.", video_id)
- # Store empty classification data
_store_classification_data(video_id, [])
return video_id
@@ -518,38 +608,29 @@ def stage4_classification(self, video_id: str, run_id: str | None = None) -> str
tags_data = _load_canonical_tags()
taxonomy_text = _format_taxonomy_for_prompt(tags_data)
- # Build moments text for the LLM
- moments_lines = []
- for i, m in enumerate(moments):
- moments_lines.append(
- f"[{i}] Title: {m.title}\n"
- f" Summary: {m.summary}\n"
- f" Content type: {m.content_type.value}\n"
- f" Plugins: {', '.join(m.plugins) if m.plugins else 'none'}"
- )
- moments_text = "\n\n".join(moments_lines)
-
system_prompt = _load_prompt("stage4_classification.txt")
- user_prompt = (
- f"<taxonomy>\n{taxonomy_text}\n</taxonomy>\n\n"
- f"<moments>\n{moments_text}\n</moments>"
- )
-
llm = _get_llm_client()
model_override, modality = _get_stage_config(4)
hard_limit = get_settings().llm_max_tokens_hard_limit
- max_tokens = estimate_max_tokens(system_prompt, user_prompt, stage="stage4_classification", hard_limit=hard_limit)
- logger.info("Stage 4 using model=%s, modality=%s, max_tokens=%d", model_override or "default", modality, max_tokens)
- raw = llm.complete(system_prompt, user_prompt, response_model=ClassificationResult, on_complete=_make_llm_callback(video_id, "stage4_classification", system_prompt=system_prompt, user_prompt=user_prompt, run_id=run_id),
- modality=modality, model_override=model_override, max_tokens=max_tokens)
- result = _safe_parse_llm_response(raw, ClassificationResult, llm, system_prompt, user_prompt,
- modality=modality, model_override=model_override, max_tokens=max_tokens)
+
+ # Batch moments for classification
+ all_classifications = []
+ for batch_start in range(0, len(moments), _STAGE4_BATCH_SIZE):
+ batch = moments[batch_start:batch_start + _STAGE4_BATCH_SIZE]
+ result = _classify_moment_batch(
+ batch, batch_start, taxonomy_text, system_prompt,
+ llm, model_override, modality, hard_limit,
+ video_id, run_id,
+ )
+ # Reindex: batch uses 0-based indices, remap to global indices
+ for cls in result.classifications:
+ cls.moment_index += batch_start
+ all_classifications.extend(result.classifications)
# Apply content_type overrides and prepare classification data for stage 5
classification_data = []
- moment_ids = [str(m.id) for m in moments]
- for cls in result.classifications:
+ for cls in all_classifications:
if 0 <= cls.moment_index < len(moments):
moment = moments[cls.moment_index]
@@ -572,10 +653,12 @@ def stage4_classification(self, video_id: str, run_id: str | None = None) -> str
_store_classification_data(video_id, classification_data)
elapsed = time.monotonic() - start
+ num_batches = (len(moments) + _STAGE4_BATCH_SIZE - 1) // _STAGE4_BATCH_SIZE
_emit_event(video_id, "stage4_classification", "complete", run_id=run_id)
logger.info(
- "Stage 4 (classification) completed for video_id=%s in %.1fs — %d moments classified",
- video_id, elapsed, len(classification_data),
+ "Stage 4 (classification) completed for video_id=%s in %.1fs — "
+ "%d moments classified in %d batch(es)",
+ video_id, elapsed, len(classification_data), num_batches,
)
return video_id
@@ -1110,6 +1193,77 @@ def _load_prior_pages(video_id: str) -> list[str]:
return []
return json.loads(raw)
+
+# ── Stage completion detection for auto-resume ───────────────────────────────
+
+# Ordered list of pipeline stages for resumability logic
+_PIPELINE_STAGES = [
+ "stage2_segmentation",
+ "stage3_extraction",
+ "stage4_classification",
+ "stage5_synthesis",
+ "stage6_embed_and_index",
+]
+
+_STAGE_TASKS = {
+ "stage2_segmentation": stage2_segmentation,
+ "stage3_extraction": stage3_extraction,
+ "stage4_classification": stage4_classification,
+ "stage5_synthesis": stage5_synthesis,
+ "stage6_embed_and_index": stage6_embed_and_index,
+}
+
+
+def _get_last_completed_stage(video_id: str) -> str | None:
+ """Find the last stage that completed successfully for this video.
+
+ Queries pipeline_events for the most recent run, looking for 'complete'
+ events. Returns the stage name (e.g. 'stage3_extraction') or None if
+ no stages have completed.
+ """
+ session = _get_sync_session()
+ try:
+ # Find the most recent run for this video
+ from models import PipelineRun
+ latest_run = session.execute(
+ select(PipelineRun)
+ .where(PipelineRun.video_id == video_id)
+ .order_by(PipelineRun.started_at.desc())
+ .limit(1)
+ ).scalar_one_or_none()
+
+ if latest_run is None:
+ return None
+
+ # Get all 'complete' events from that run
+ completed_events = session.execute(
+ select(PipelineEvent.stage)
+ .where(
+ PipelineEvent.run_id == str(latest_run.id),
+ PipelineEvent.event_type == "complete",
+ )
+ ).scalars().all()
+
+ completed_set = set(completed_events)
+
+ # Walk backwards through the ordered stages to find the last completed one
+ last_completed = None
+ for stage_name in _PIPELINE_STAGES:
+ if stage_name in completed_set:
+ last_completed = stage_name
+ else:
+ break # Stop at first gap — stages must be sequential
+
+ if last_completed:
+ logger.info(
+ "Auto-resume: video_id=%s last completed stage=%s (run_id=%s)",
+ video_id, last_completed, latest_run.id,
+ )
+ return last_completed
+ finally:
+ session.close()
+
+
# ── Orchestrator ─────────────────────────────────────────────────────────────
@celery_app.task
@@ -1175,13 +1329,13 @@ def _finish_run(run_id: str, status: str, error_stage: str | None = None) -> Non
@celery_app.task
def run_pipeline(video_id: str, trigger: str = "manual") -> str:
- """Orchestrate the full pipeline (stages 2-5) with resumability.
+ """Orchestrate the full pipeline (stages 2-6) with auto-resume.
- Checks the current processing_status of the video and chains only the
- stages that still need to run. For example:
- - queued → stages 2, 3, 4, 5
- - processing/error → re-run full pipeline
- - complete → no-op
+ For error/processing status, queries pipeline_events to find the last
+ stage that completed successfully and resumes from the next stage.
+ This avoids re-running expensive LLM stages that already succeeded.
+
+ For clean_reprocess trigger, always starts from stage 2.
Returns the video_id.
"""
@@ -1204,6 +1358,13 @@ def run_pipeline(video_id: str, trigger: str = "manual") -> str:
finally:
session.close()
+ if status == ProcessingStatus.complete:
+ logger.info(
+ "run_pipeline: video_id=%s already at status=%s, nothing to do.",
+ video_id, status.value,
+ )
+ return video_id
+
# Snapshot prior technique pages before pipeline resets key_moments
_snapshot_prior_pages(video_id)
@@ -1211,40 +1372,39 @@ def run_pipeline(video_id: str, trigger: str = "manual") -> str:
run_id = _create_run(video_id, trigger)
logger.info("run_pipeline: created run_id=%s for video_id=%s (trigger=%s)", run_id, video_id, trigger)
- # Build the chain based on current status
- stages = []
- if status in (ProcessingStatus.not_started, ProcessingStatus.queued):
- stages = [
- stage2_segmentation.s(video_id, run_id=run_id),
- stage3_extraction.s(run_id=run_id), # receives video_id from previous
- stage4_classification.s(run_id=run_id),
- stage5_synthesis.s(run_id=run_id),
- stage6_embed_and_index.s(run_id=run_id),
- ]
- elif status == ProcessingStatus.processing:
- stages = [
- stage2_segmentation.s(video_id, run_id=run_id),
- stage3_extraction.s(run_id=run_id),
- stage4_classification.s(run_id=run_id),
- stage5_synthesis.s(run_id=run_id),
- stage6_embed_and_index.s(run_id=run_id),
- ]
- elif status == ProcessingStatus.error:
- stages = [
- stage2_segmentation.s(video_id, run_id=run_id),
- stage3_extraction.s(run_id=run_id),
- stage4_classification.s(run_id=run_id),
- stage5_synthesis.s(run_id=run_id),
- stage6_embed_and_index.s(run_id=run_id),
- ]
- elif status == ProcessingStatus.complete:
- logger.info(
- "run_pipeline: video_id=%s already at status=%s, nothing to do.",
- video_id, status.value,
- )
- return video_id
+ # Determine which stages to run
+ resume_from_idx = 0 # Default: start from stage 2
- if stages:
+ if trigger != "clean_reprocess" and status in (ProcessingStatus.processing, ProcessingStatus.error):
+ # Try to resume from where we left off
+ last_completed = _get_last_completed_stage(video_id)
+ if last_completed and last_completed in _PIPELINE_STAGES:
+ completed_idx = _PIPELINE_STAGES.index(last_completed)
+ resume_from_idx = completed_idx + 1
+ if resume_from_idx >= len(_PIPELINE_STAGES):
+ logger.info(
+ "run_pipeline: all stages already completed for video_id=%s",
+ video_id,
+ )
+ return video_id
+
+ stages_to_run = _PIPELINE_STAGES[resume_from_idx:]
+ logger.info(
+ "run_pipeline: video_id=%s will run stages: %s (resume_from_idx=%d)",
+ video_id, stages_to_run, resume_from_idx,
+ )
+
+ # Build the Celery chain — first stage gets video_id as arg,
+ # subsequent stages receive it from the previous stage's return value
+ celery_sigs = []
+ for i, stage_name in enumerate(stages_to_run):
+ task_func = _STAGE_TASKS[stage_name]
+ if i == 0:
+ celery_sigs.append(task_func.s(video_id, run_id=run_id))
+ else:
+ celery_sigs.append(task_func.s(run_id=run_id))
+
+ if celery_sigs:
# Mark as processing before dispatching
session = _get_sync_session()
try:
@@ -1256,12 +1416,12 @@ def run_pipeline(video_id: str, trigger: str = "manual") -> str:
finally:
session.close()
- pipeline = celery_chain(*stages)
+ pipeline = celery_chain(*celery_sigs)
error_cb = mark_pipeline_error.s(video_id, run_id=run_id)
pipeline.apply_async(link_error=error_cb)
logger.info(
- "run_pipeline: dispatched %d stages for video_id=%s (run_id=%s)",
- len(stages), video_id, run_id,
+ "run_pipeline: dispatched %d stages for video_id=%s (run_id=%s, starting at %s)",
+ len(celery_sigs), video_id, run_id, stages_to_run[0],
)
return video_id