feat: Truncation detection, batched classification, and pipeline auto-resume

Three resilience improvements to the pipeline:

1. LLMResponse(str) subclass carries finish_reason metadata from the LLM.
   _safe_parse_llm_response detects truncation (finish=length) and raises
   LLMTruncationError instead of wastefully retrying with a JSON nudge
   that makes the prompt even longer.

2. Stage 4 classification now batches moments (20 per call) instead of
   sending all moments in a single LLM call. Prevents context window
   overflow for videos with many moments. Batch results are merged with
   reindexed moment_index values.

3. run_pipeline auto-resumes from the last completed stage on error/retry
   instead of always restarting from stage 2. Queries pipeline_events for
   the most recent run to find completed stages. clean_reprocess trigger
   still forces a full restart.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
jlightner 2026-03-31 17:48:19 -05:00
parent 5984129e25
commit e80094dc05
2 changed files with 285 additions and 81 deletions

View file

@ -28,6 +28,38 @@ logger = logging.getLogger(__name__)
T = TypeVar("T", bound=BaseModel)
# ── LLM Response wrapper ─────────────────────────────────────────────────────
class LLMResponse(str):
"""String subclass that carries LLM response metadata.
Backward-compatible with all code that treats the response as a plain
string, but callers that know about it can inspect finish_reason and
the truncated property.
"""
finish_reason: str | None
prompt_tokens: int | None
completion_tokens: int | None
def __new__(
cls,
text: str,
finish_reason: str | None = None,
prompt_tokens: int | None = None,
completion_tokens: int | None = None,
):
obj = super().__new__(cls, text)
obj.finish_reason = finish_reason
obj.prompt_tokens = prompt_tokens
obj.completion_tokens = completion_tokens
return obj
@property
def truncated(self) -> bool:
"""True if the model hit its token limit before finishing."""
return self.finish_reason == "length"
# ── Think-tag stripping ──────────────────────────────────────────────────────
_THINK_PATTERN = re.compile(r"<think>.*?</think>", re.DOTALL)
@ -149,7 +181,7 @@ class LLMClient:
model_override: str | None = None,
on_complete: "Callable | None" = None,
max_tokens: int | None = None,
) -> str:
) -> "LLMResponse":
"""Send a chat completion request, falling back on connection/timeout errors.
Parameters
@ -174,8 +206,8 @@ class LLMClient:
Returns
-------
str
Raw completion text from the model (think tags stripped if thinking).
LLMResponse
Raw completion text (str subclass) with finish_reason metadata.
"""
kwargs: dict = {}
effective_system = system_prompt
@ -230,6 +262,7 @@ class LLMClient:
)
if modality == "thinking":
raw = strip_think_tags(raw)
finish = response.choices[0].finish_reason if response.choices else None
if on_complete is not None:
try:
on_complete(
@ -238,11 +271,16 @@ class LLMClient:
completion_tokens=usage.completion_tokens if usage else None,
total_tokens=usage.total_tokens if usage else None,
content=raw,
finish_reason=response.choices[0].finish_reason if response.choices else None,
finish_reason=finish,
)
except Exception as cb_exc:
logger.warning("on_complete callback failed: %s", cb_exc)
return raw
return LLMResponse(
raw,
finish_reason=finish,
prompt_tokens=usage.prompt_tokens if usage else None,
completion_tokens=usage.completion_tokens if usage else None,
)
except (openai.APIConnectionError, openai.APITimeoutError) as exc:
logger.warning(
@ -270,6 +308,7 @@ class LLMClient:
)
if modality == "thinking":
raw = strip_think_tags(raw)
finish = response.choices[0].finish_reason if response.choices else None
if on_complete is not None:
try:
on_complete(
@ -278,12 +317,17 @@ class LLMClient:
completion_tokens=usage.completion_tokens if usage else None,
total_tokens=usage.total_tokens if usage else None,
content=raw,
finish_reason=response.choices[0].finish_reason if response.choices else None,
finish_reason=finish,
is_fallback=True,
)
except Exception as cb_exc:
logger.warning("on_complete callback failed: %s", cb_exc)
return raw
return LLMResponse(
raw,
finish_reason=finish,
prompt_tokens=usage.prompt_tokens if usage else None,
completion_tokens=usage.completion_tokens if usage else None,
)
except (openai.APIConnectionError, openai.APITimeoutError, openai.APIError) as exc:
logger.error(

View file

@ -36,7 +36,7 @@ from models import (
TranscriptSegment,
)
from pipeline.embedding_client import EmbeddingClient
from pipeline.llm_client import LLMClient, estimate_max_tokens
from pipeline.llm_client import LLMClient, LLMResponse, estimate_max_tokens
from pipeline.qdrant_client import QdrantManager
from pipeline.schemas import (
ClassificationResult,
@ -49,6 +49,11 @@ from worker import celery_app
logger = logging.getLogger(__name__)
class LLMTruncationError(RuntimeError):
"""Raised when the LLM response was truncated (finish_reason=length)."""
pass
# ── Error status helper ──────────────────────────────────────────────────────
def _set_error_status(video_id: str, stage_name: str, error: Exception) -> None:
@ -256,7 +261,7 @@ def _format_taxonomy_for_prompt(tags_data: dict) -> str:
def _safe_parse_llm_response(
raw: str,
raw,
model_cls,
llm: LLMClient,
system_prompt: str,
@ -265,14 +270,37 @@ def _safe_parse_llm_response(
model_override: str | None = None,
max_tokens: int | None = None,
):
"""Parse LLM response with one retry on failure.
"""Parse LLM response with truncation detection and one retry on failure.
On malformed response: log the raw text, retry once with a JSON nudge,
then raise on second failure.
If the response was truncated (finish_reason=length), raises
LLMTruncationError immediately retrying with a JSON nudge would only
make things worse by adding tokens to an already-too-large prompt.
For non-truncation parse failures: retry once with a JSON nudge, then
raise on second failure.
"""
# Check for truncation before attempting parse
is_truncated = isinstance(raw, LLMResponse) and raw.truncated
if is_truncated:
logger.warning(
"LLM response truncated (finish=length) for %s. "
"prompt_tokens=%s, completion_tokens=%s. Will not retry with nudge.",
model_cls.__name__,
getattr(raw, "prompt_tokens", "?"),
getattr(raw, "completion_tokens", "?"),
)
try:
return llm.parse_response(raw, model_cls)
except (ValidationError, ValueError, json.JSONDecodeError) as exc:
if is_truncated:
raise LLMTruncationError(
f"LLM output truncated for {model_cls.__name__}: "
f"prompt_tokens={getattr(raw, 'prompt_tokens', '?')}, "
f"completion_tokens={getattr(raw, 'completion_tokens', '?')}. "
f"Response too large for model context window."
) from exc
logger.warning(
"First parse attempt failed for %s (%s). Retrying with JSON nudge. "
"Raw response (first 500 chars): %.500s",
@ -479,6 +507,66 @@ def stage3_extraction(self, video_id: str, run_id: str | None = None) -> str:
# ── Stage 4: Classification ─────────────────────────────────────────────────
# Maximum moments per classification batch. Keeps each LLM call well within
# context window limits. Batches are classified independently and merged.
_STAGE4_BATCH_SIZE = 20
def _classify_moment_batch(
moments_batch: list,
batch_offset: int,
taxonomy_text: str,
system_prompt: str,
llm: LLMClient,
model_override: str | None,
modality: str,
hard_limit: int,
video_id: str,
run_id: str | None,
) -> ClassificationResult:
"""Classify a single batch of moments. Raises on failure."""
moments_lines = []
for i, m in enumerate(moments_batch):
moments_lines.append(
f"[{i}] Title: {m.title}\n"
f" Summary: {m.summary}\n"
f" Content type: {m.content_type.value}\n"
f" Plugins: {', '.join(m.plugins) if m.plugins else 'none'}"
)
moments_text = "\n\n".join(moments_lines)
user_prompt = (
f"<taxonomy>\n{taxonomy_text}\n</taxonomy>\n\n"
f"<moments>\n{moments_text}\n</moments>"
)
max_tokens = estimate_max_tokens(
system_prompt, user_prompt,
stage="stage4_classification", hard_limit=hard_limit,
)
batch_label = f"batch {batch_offset // _STAGE4_BATCH_SIZE + 1} (moments {batch_offset}-{batch_offset + len(moments_batch) - 1})"
logger.info(
"Stage 4 classifying %s, max_tokens=%d",
batch_label, max_tokens,
)
raw = llm.complete(
system_prompt, user_prompt,
response_model=ClassificationResult,
on_complete=_make_llm_callback(
video_id, "stage4_classification",
system_prompt=system_prompt, user_prompt=user_prompt,
run_id=run_id, context_label=batch_label,
),
modality=modality, model_override=model_override,
max_tokens=max_tokens,
)
return _safe_parse_llm_response(
raw, ClassificationResult, llm, system_prompt, user_prompt,
modality=modality, model_override=model_override,
max_tokens=max_tokens,
)
@celery_app.task(bind=True, max_retries=3, default_retry_delay=30)
def stage4_classification(self, video_id: str, run_id: str | None = None) -> str:
"""Classify key moments against the canonical tag taxonomy.
@ -487,6 +575,9 @@ def stage4_classification(self, video_id: str, run_id: str | None = None) -> str
canonical taxonomy, and stores classification results in Redis for
stage 5 consumption. Updates content_type if the classifier overrides it.
For large moment sets, automatically batches into groups of
_STAGE4_BATCH_SIZE to stay within model context window limits.
Stage 4 does NOT change processing_status.
Returns the video_id for chain compatibility.
@ -510,7 +601,6 @@ def stage4_classification(self, video_id: str, run_id: str | None = None) -> str
if not moments:
logger.info("Stage 4: No moments found for video_id=%s, skipping.", video_id)
# Store empty classification data
_store_classification_data(video_id, [])
return video_id
@ -518,38 +608,29 @@ def stage4_classification(self, video_id: str, run_id: str | None = None) -> str
tags_data = _load_canonical_tags()
taxonomy_text = _format_taxonomy_for_prompt(tags_data)
# Build moments text for the LLM
moments_lines = []
for i, m in enumerate(moments):
moments_lines.append(
f"[{i}] Title: {m.title}\n"
f" Summary: {m.summary}\n"
f" Content type: {m.content_type.value}\n"
f" Plugins: {', '.join(m.plugins) if m.plugins else 'none'}"
)
moments_text = "\n\n".join(moments_lines)
system_prompt = _load_prompt("stage4_classification.txt")
user_prompt = (
f"<taxonomy>\n{taxonomy_text}\n</taxonomy>\n\n"
f"<moments>\n{moments_text}\n</moments>"
)
llm = _get_llm_client()
model_override, modality = _get_stage_config(4)
hard_limit = get_settings().llm_max_tokens_hard_limit
max_tokens = estimate_max_tokens(system_prompt, user_prompt, stage="stage4_classification", hard_limit=hard_limit)
logger.info("Stage 4 using model=%s, modality=%s, max_tokens=%d", model_override or "default", modality, max_tokens)
raw = llm.complete(system_prompt, user_prompt, response_model=ClassificationResult, on_complete=_make_llm_callback(video_id, "stage4_classification", system_prompt=system_prompt, user_prompt=user_prompt, run_id=run_id),
modality=modality, model_override=model_override, max_tokens=max_tokens)
result = _safe_parse_llm_response(raw, ClassificationResult, llm, system_prompt, user_prompt,
modality=modality, model_override=model_override, max_tokens=max_tokens)
# Batch moments for classification
all_classifications = []
for batch_start in range(0, len(moments), _STAGE4_BATCH_SIZE):
batch = moments[batch_start:batch_start + _STAGE4_BATCH_SIZE]
result = _classify_moment_batch(
batch, batch_start, taxonomy_text, system_prompt,
llm, model_override, modality, hard_limit,
video_id, run_id,
)
# Reindex: batch uses 0-based indices, remap to global indices
for cls in result.classifications:
cls.moment_index += batch_start
all_classifications.extend(result.classifications)
# Apply content_type overrides and prepare classification data for stage 5
classification_data = []
moment_ids = [str(m.id) for m in moments]
for cls in result.classifications:
for cls in all_classifications:
if 0 <= cls.moment_index < len(moments):
moment = moments[cls.moment_index]
@ -572,10 +653,12 @@ def stage4_classification(self, video_id: str, run_id: str | None = None) -> str
_store_classification_data(video_id, classification_data)
elapsed = time.monotonic() - start
num_batches = (len(moments) + _STAGE4_BATCH_SIZE - 1) // _STAGE4_BATCH_SIZE
_emit_event(video_id, "stage4_classification", "complete", run_id=run_id)
logger.info(
"Stage 4 (classification) completed for video_id=%s in %.1fs — %d moments classified",
video_id, elapsed, len(classification_data),
"Stage 4 (classification) completed for video_id=%s in %.1fs — "
"%d moments classified in %d batch(es)",
video_id, elapsed, len(classification_data), num_batches,
)
return video_id
@ -1110,6 +1193,77 @@ def _load_prior_pages(video_id: str) -> list[str]:
return []
return json.loads(raw)
# ── Stage completion detection for auto-resume ───────────────────────────────
# Ordered list of pipeline stages for resumability logic
_PIPELINE_STAGES = [
"stage2_segmentation",
"stage3_extraction",
"stage4_classification",
"stage5_synthesis",
"stage6_embed_and_index",
]
_STAGE_TASKS = {
"stage2_segmentation": stage2_segmentation,
"stage3_extraction": stage3_extraction,
"stage4_classification": stage4_classification,
"stage5_synthesis": stage5_synthesis,
"stage6_embed_and_index": stage6_embed_and_index,
}
def _get_last_completed_stage(video_id: str) -> str | None:
"""Find the last stage that completed successfully for this video.
Queries pipeline_events for the most recent run, looking for 'complete'
events. Returns the stage name (e.g. 'stage3_extraction') or None if
no stages have completed.
"""
session = _get_sync_session()
try:
# Find the most recent run for this video
from models import PipelineRun
latest_run = session.execute(
select(PipelineRun)
.where(PipelineRun.video_id == video_id)
.order_by(PipelineRun.started_at.desc())
.limit(1)
).scalar_one_or_none()
if latest_run is None:
return None
# Get all 'complete' events from that run
completed_events = session.execute(
select(PipelineEvent.stage)
.where(
PipelineEvent.run_id == str(latest_run.id),
PipelineEvent.event_type == "complete",
)
).scalars().all()
completed_set = set(completed_events)
# Walk backwards through the ordered stages to find the last completed one
last_completed = None
for stage_name in _PIPELINE_STAGES:
if stage_name in completed_set:
last_completed = stage_name
else:
break # Stop at first gap — stages must be sequential
if last_completed:
logger.info(
"Auto-resume: video_id=%s last completed stage=%s (run_id=%s)",
video_id, last_completed, latest_run.id,
)
return last_completed
finally:
session.close()
# ── Orchestrator ─────────────────────────────────────────────────────────────
@celery_app.task
@ -1175,13 +1329,13 @@ def _finish_run(run_id: str, status: str, error_stage: str | None = None) -> Non
@celery_app.task
def run_pipeline(video_id: str, trigger: str = "manual") -> str:
"""Orchestrate the full pipeline (stages 2-5) with resumability.
"""Orchestrate the full pipeline (stages 2-6) with auto-resume.
Checks the current processing_status of the video and chains only the
stages that still need to run. For example:
- queued stages 2, 3, 4, 5
- processing/error re-run full pipeline
- complete no-op
For error/processing status, queries pipeline_events to find the last
stage that completed successfully and resumes from the next stage.
This avoids re-running expensive LLM stages that already succeeded.
For clean_reprocess trigger, always starts from stage 2.
Returns the video_id.
"""
@ -1204,6 +1358,13 @@ def run_pipeline(video_id: str, trigger: str = "manual") -> str:
finally:
session.close()
if status == ProcessingStatus.complete:
logger.info(
"run_pipeline: video_id=%s already at status=%s, nothing to do.",
video_id, status.value,
)
return video_id
# Snapshot prior technique pages before pipeline resets key_moments
_snapshot_prior_pages(video_id)
@ -1211,40 +1372,39 @@ def run_pipeline(video_id: str, trigger: str = "manual") -> str:
run_id = _create_run(video_id, trigger)
logger.info("run_pipeline: created run_id=%s for video_id=%s (trigger=%s)", run_id, video_id, trigger)
# Build the chain based on current status
stages = []
if status in (ProcessingStatus.not_started, ProcessingStatus.queued):
stages = [
stage2_segmentation.s(video_id, run_id=run_id),
stage3_extraction.s(run_id=run_id), # receives video_id from previous
stage4_classification.s(run_id=run_id),
stage5_synthesis.s(run_id=run_id),
stage6_embed_and_index.s(run_id=run_id),
]
elif status == ProcessingStatus.processing:
stages = [
stage2_segmentation.s(video_id, run_id=run_id),
stage3_extraction.s(run_id=run_id),
stage4_classification.s(run_id=run_id),
stage5_synthesis.s(run_id=run_id),
stage6_embed_and_index.s(run_id=run_id),
]
elif status == ProcessingStatus.error:
stages = [
stage2_segmentation.s(video_id, run_id=run_id),
stage3_extraction.s(run_id=run_id),
stage4_classification.s(run_id=run_id),
stage5_synthesis.s(run_id=run_id),
stage6_embed_and_index.s(run_id=run_id),
]
elif status == ProcessingStatus.complete:
logger.info(
"run_pipeline: video_id=%s already at status=%s, nothing to do.",
video_id, status.value,
)
return video_id
# Determine which stages to run
resume_from_idx = 0 # Default: start from stage 2
if stages:
if trigger != "clean_reprocess" and status in (ProcessingStatus.processing, ProcessingStatus.error):
# Try to resume from where we left off
last_completed = _get_last_completed_stage(video_id)
if last_completed and last_completed in _PIPELINE_STAGES:
completed_idx = _PIPELINE_STAGES.index(last_completed)
resume_from_idx = completed_idx + 1
if resume_from_idx >= len(_PIPELINE_STAGES):
logger.info(
"run_pipeline: all stages already completed for video_id=%s",
video_id,
)
return video_id
stages_to_run = _PIPELINE_STAGES[resume_from_idx:]
logger.info(
"run_pipeline: video_id=%s will run stages: %s (resume_from_idx=%d)",
video_id, stages_to_run, resume_from_idx,
)
# Build the Celery chain — first stage gets video_id as arg,
# subsequent stages receive it from the previous stage's return value
celery_sigs = []
for i, stage_name in enumerate(stages_to_run):
task_func = _STAGE_TASKS[stage_name]
if i == 0:
celery_sigs.append(task_func.s(video_id, run_id=run_id))
else:
celery_sigs.append(task_func.s(run_id=run_id))
if celery_sigs:
# Mark as processing before dispatching
session = _get_sync_session()
try:
@ -1256,12 +1416,12 @@ def run_pipeline(video_id: str, trigger: str = "manual") -> str:
finally:
session.close()
pipeline = celery_chain(*stages)
pipeline = celery_chain(*celery_sigs)
error_cb = mark_pipeline_error.s(video_id, run_id=run_id)
pipeline.apply_async(link_error=error_cb)
logger.info(
"run_pipeline: dispatched %d stages for video_id=%s (run_id=%s)",
len(stages), video_id, run_id,
"run_pipeline: dispatched %d stages for video_id=%s (run_id=%s, starting at %s)",
len(celery_sigs), video_id, run_id, stages_to_run[0],
)
return video_id