When the LLM splits a category group into multiple technique pages, moments were blanket-linked to the last page in the loop, leaving all other pages as orphans with 0 key moments (48 out of 204 pages affected). Added moment_indices field to SynthesizedPage schema and synthesis prompt so the LLM explicitly declares which input moments each page covers. Stage 5 now uses these indices for targeted linking instead of the broken blanket approach. Tags are also computed per-page from linked moments only, fixing cross-contamination (e.g. "stereo imaging" tag appearing on gain staging pages). Deleted 48 orphan technique pages from the database. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1677 lines
64 KiB
Python
1677 lines
64 KiB
Python
"""Pipeline stage tasks (stages 2-5) and run_pipeline orchestrator.
|
|
|
|
Each stage reads from PostgreSQL via sync SQLAlchemy, loads its prompt
|
|
template from disk, calls the LLM client, parses the response, writes
|
|
results back, and updates processing_status on SourceVideo.
|
|
|
|
Celery tasks are synchronous — all DB access uses ``sqlalchemy.orm.Session``.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import hashlib
|
|
import json
|
|
import logging
|
|
import subprocess
|
|
import time
|
|
from collections import defaultdict
|
|
from pathlib import Path
|
|
|
|
import yaml
|
|
from celery import chain as celery_chain
|
|
from pydantic import ValidationError
|
|
from sqlalchemy import create_engine, func, select
|
|
from sqlalchemy.orm import Session, sessionmaker
|
|
|
|
from config import get_settings
|
|
from models import (
|
|
Creator,
|
|
KeyMoment,
|
|
KeyMomentContentType,
|
|
PipelineEvent,
|
|
ProcessingStatus,
|
|
SourceVideo,
|
|
TechniquePage,
|
|
TechniquePageVersion,
|
|
TranscriptSegment,
|
|
)
|
|
from pipeline.embedding_client import EmbeddingClient
|
|
from pipeline.llm_client import LLMClient, LLMResponse, estimate_max_tokens
|
|
from pipeline.qdrant_client import QdrantManager
|
|
from pipeline.schemas import (
|
|
ClassificationResult,
|
|
ExtractionResult,
|
|
SegmentationResult,
|
|
SynthesisResult,
|
|
)
|
|
from worker import celery_app
|
|
|
|
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class LLMTruncationError(RuntimeError):
    """Signals that an LLM reply was cut off (finish_reason=length)."""
|
|
|
|
|
|
# ── Error status helper ──────────────────────────────────────────────────────
|
|
|
|
def _set_error_status(video_id: str, stage_name: str, error: Exception) -> None:
    """Mark a video as errored when a pipeline stage fails permanently.

    Best-effort: any failure while recording the status is logged and
    swallowed so it can never mask the original stage failure.

    Args:
        video_id: Primary key of the SourceVideo row to update.
        stage_name: Stage label; used only in the failure log message.
        error: The exception that made the stage fail. It is accepted but
            not read by this function.
    """
    try:
        session = _get_sync_session()
        try:
            video = session.execute(
                select(SourceVideo).where(SourceVideo.id == video_id)
            ).scalar_one_or_none()
            if video:
                video.processing_status = ProcessingStatus.error
                session.commit()
        finally:
            # Release the connection even when the query or commit raises;
            # previously close() was skipped on any exception.
            session.close()
    except Exception as mark_exc:
        logger.error(
            "Failed to mark video_id=%s as error after %s failure: %s",
            video_id, stage_name, mark_exc,
        )
|
|
|
|
|
|
# ── Pipeline event persistence ───────────────────────────────────────────────
|
|
|
|
def _emit_event(
    video_id: str,
    stage: str,
    event_type: str,
    *,
    run_id: str | None = None,
    prompt_tokens: int | None = None,
    completion_tokens: int | None = None,
    total_tokens: int | None = None,
    model: str | None = None,
    duration_ms: int | None = None,
    payload: dict | None = None,
    system_prompt_text: str | None = None,
    user_prompt_text: str | None = None,
    response_text: str | None = None,
) -> None:
    """Write a single PipelineEvent row; never raises.

    Every argument is forwarded unchanged to the PipelineEvent constructor.
    The row is added and committed on its own short-lived session, which is
    always closed afterwards. Any exception along the way is logged as a
    warning and suppressed, so event bookkeeping cannot fail a stage.
    """
    try:
        db = _get_sync_session()
        try:
            db.add(
                PipelineEvent(
                    video_id=video_id,
                    run_id=run_id,
                    stage=stage,
                    event_type=event_type,
                    prompt_tokens=prompt_tokens,
                    completion_tokens=completion_tokens,
                    total_tokens=total_tokens,
                    model=model,
                    duration_ms=duration_ms,
                    payload=payload,
                    system_prompt_text=system_prompt_text,
                    user_prompt_text=user_prompt_text,
                    response_text=response_text,
                )
            )
            db.commit()
        finally:
            db.close()
    except Exception as exc:
        logger.warning("Failed to emit pipeline event: %s", exc)
|
|
|
|
|
|
def _is_debug_mode() -> bool:
    """Report whether debug mode is on.

    The Redis key ``chrysopedia:debug_mode`` takes precedence: when it is
    set, the result is True only if its decoded value equals "true"
    (case-insensitive). When the key is absent, or Redis cannot be reached
    for any reason, the ``debug_mode`` attribute of the settings object is
    returned instead (False if that attribute does not exist).
    """
    try:
        import redis

        client = redis.from_url(get_settings().redis_url)
        try:
            flag = client.get("chrysopedia:debug_mode")
        finally:
            # Close the client even when the GET fails, so a flaky Redis
            # does not leak one connection pool per call.
            client.close()
        if flag is not None:
            return flag.decode().lower() == "true"
    except Exception as exc:
        # Deliberately best-effort: fall through to the config setting, but
        # leave a trace so a misconfigured Redis URL is diagnosable.
        logger.debug("Redis debug-mode lookup failed, using config fallback: %s", exc)
    return getattr(get_settings(), "debug_mode", False)
|
|
|
|
|
|
def _make_llm_callback(
    video_id: str,
    stage: str,
    system_prompt: str | None = None,
    user_prompt: str | None = None,
    run_id: str | None = None,
    context_label: str | None = None,
):
    """Build the ``on_complete`` hook that records an ``llm_call`` event.

    Debug mode is sampled once, when the hook is created. With debug mode
    on, each event also carries the full system prompt, user prompt and
    response text; otherwise those three fields are stored as None.
    """
    debug = _is_debug_mode()

    def callback(*, model=None, prompt_tokens=None, completion_tokens=None,
                 total_tokens=None, content=None, finish_reason=None,
                 is_fallback=False, **_kwargs):
        # Only the first 2000 characters go into the payload preview.
        if content and len(content) > 2000:
            preview = content[:2000]
        else:
            preview = content

        event_payload = {
            "content_preview": preview,
            "content_length": len(content) if content else 0,
            "finish_reason": finish_reason,
            "is_fallback": is_fallback,
        }
        if context_label:
            event_payload["context"] = context_label

        _emit_event(
            video_id=video_id,
            stage=stage,
            event_type="llm_call",
            run_id=run_id,
            model=model,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens,
            payload=event_payload,
            system_prompt_text=system_prompt if debug else None,
            user_prompt_text=user_prompt if debug else None,
            response_text=content if debug else None,
        )

    return callback
|
|
|
|
# ── Helpers ──────────────────────────────────────────────────────────────────
|
|
|
|
# Lazily created singletons: _get_sync_engine() fills in _engine and
# _get_sync_session() fills in _SessionLocal on first use.
_engine = None
_SessionLocal = None
|
|
|
|
|
|
def _get_sync_engine():
    """Return the shared sync SQLAlchemy engine, building it on first use.

    The configured database URL has its ``postgresql+asyncpg://`` prefix
    swapped for ``postgresql+psycopg2://`` before the engine is created.
    """
    global _engine
    if _engine is not None:
        return _engine

    sync_url = get_settings().database_url.replace(
        "postgresql+asyncpg://", "postgresql+psycopg2://"
    )
    _engine = create_engine(sync_url, pool_pre_ping=True, pool_size=5, max_overflow=10)
    return _engine
|
|
|
|
|
|
def _get_sync_session() -> Session:
    """Open a new sync SQLAlchemy session for use inside a Celery task.

    The session factory is created once, bound to the shared engine, and
    reused for every later call.
    """
    global _SessionLocal
    if _SessionLocal is None:
        engine = _get_sync_engine()
        _SessionLocal = sessionmaker(bind=engine)
    session_factory = _SessionLocal
    return session_factory()
|
|
|
|
|
|
def _load_prompt(template_name: str) -> str:
    """Return the UTF-8 text of a prompt template in the prompts directory.

    Raises:
        FileNotFoundError: The template file does not exist (also logged
            at error level).
    """
    template_path = Path(get_settings().prompts_path) / template_name
    if template_path.exists():
        return template_path.read_text(encoding="utf-8")

    logger.error("Prompt template not found: %s", template_path)
    raise FileNotFoundError(f"Prompt template not found: {template_path}")
|
|
|
|
|
|
def _get_llm_client() -> LLMClient:
    """Build an LLMClient from the current settings."""
    settings = get_settings()
    return LLMClient(settings)
|
|
|
|
|
|
def _get_stage_config(stage_num: int) -> tuple[str | None, str]:
    """Look up ``(model_override, modality)`` for one pipeline stage.

    Both values come from the ``llm_stage<N>_model`` and
    ``llm_stage<N>_modality`` settings attributes. A missing or empty model
    yields None; a missing or empty modality yields "chat".
    """
    settings = get_settings()
    prefix = f"llm_stage{stage_num}"
    model_override = getattr(settings, f"{prefix}_model", None) or None
    modality = getattr(settings, f"{prefix}_modality", None) or "chat"
    return model_override, modality
|
|
|
|
|
|
def _load_canonical_tags() -> dict:
    """Parse the canonical tag taxonomy from ``canonical_tags.yaml``.

    The file is looked for at ``config/`` and then ``../config/``, both
    relative to the current working directory; the first match is loaded
    with ``yaml.safe_load``.

    Raises:
        FileNotFoundError: Neither candidate path exists.
    """
    search_paths = [
        Path("config/canonical_tags.yaml"),
        Path("../config/canonical_tags.yaml"),
    ]
    found = next((path for path in search_paths if path.exists()), None)
    if found is None:
        raise FileNotFoundError(
            "canonical_tags.yaml not found. Searched: " + ", ".join(str(p) for p in search_paths)
        )
    with open(found, encoding="utf-8") as handle:
        return yaml.safe_load(handle)
|
|
|
|
|
|
def _format_taxonomy_for_prompt(tags_data: dict) -> str:
    """Render the taxonomy's categories as plain text for an LLM prompt.

    Each entry of ``tags_data["categories"]`` becomes three lines (name,
    description, comma-joined sub-topics) followed by a blank line. A
    missing ``categories`` key yields an empty string; a missing
    ``sub_topics`` key yields an empty sub-topic list.
    """
    # NOTE(review): leading whitespace inside these literals is reproduced
    # exactly as it appears in the reviewed source view — confirm against
    # the repository in case the view collapsed runs of spaces.
    rendered: list[str] = []
    for category in tags_data.get("categories", []):
        rendered.extend(
            (
                f"Category: {category['name']}",
                f" Description: {category['description']}",
                f" Sub-topics: {', '.join(category.get('sub_topics', []))}",
                "",
            )
        )
    return "\n".join(rendered)
|
|
|
|
|
|
def _safe_parse_llm_response(
    raw,
    model_cls,
    llm: LLMClient,
    system_prompt: str,
    user_prompt: str,
    modality: str = "chat",
    model_override: str | None = None,
    max_tokens: int | None = None,
):
    """Parse an LLM reply into ``model_cls``, retrying once on bad output.

    A reply flagged as truncated is logged and still parsed; if that parse
    fails, LLMTruncationError is raised without retrying, because the retry
    only appends more text to the prompt. For any other parse failure the
    same request is re-sent once with a "JSON only" instruction appended to
    the user prompt, and a second failure propagates to the caller.

    Raises:
        LLMTruncationError: The reply was truncated and could not be parsed.
    """
    was_cut_off = isinstance(raw, LLMResponse) and raw.truncated
    if was_cut_off:
        logger.warning(
            "LLM response truncated (finish=length) for %s. "
            "prompt_tokens=%s, completion_tokens=%s. Will not retry with nudge.",
            model_cls.__name__,
            getattr(raw, "prompt_tokens", "?"),
            getattr(raw, "completion_tokens", "?"),
        )

    try:
        return llm.parse_response(raw, model_cls)
    except (ValidationError, ValueError, json.JSONDecodeError) as parse_error:
        if was_cut_off:
            prompt_used = getattr(raw, "prompt_tokens", "?")
            completion_used = getattr(raw, "completion_tokens", "?")
            raise LLMTruncationError(
                f"LLM output truncated for {model_cls.__name__}: "
                f"prompt_tokens={prompt_used}, "
                f"completion_tokens={completion_used}. "
                f"Response too large for model context window."
            ) from parse_error

        logger.warning(
            "First parse attempt failed for %s (%s). Retrying with JSON nudge. "
            "Raw response (first 500 chars): %.500s",
            model_cls.__name__,
            type(parse_error).__name__,
            raw,
        )
        # Second and final attempt: same request plus an explicit JSON nudge.
        retry_raw = llm.complete(
            system_prompt,
            user_prompt + "\n\nIMPORTANT: Output ONLY valid JSON. No markdown, no explanation.",
            response_model=model_cls,
            modality=modality,
            model_override=model_override,
            max_tokens=max_tokens,
        )
        return llm.parse_response(retry_raw, model_cls)
|
|
|
|
|
|
# ── Stage 2: Segmentation ───────────────────────────────────────────────────
|
|
|
|
@celery_app.task(bind=True, max_retries=3, default_retry_delay=30)
def stage2_segmentation(self, video_id: str, run_id: str | None = None) -> str:
    """Detect topic boundaries in a video's transcript and label its segments.

    Fetches every TranscriptSegment of the video in segment_index order,
    submits the indexed transcript to the LLM, and copies each returned
    topic label onto the segment rows inside the reported index range.
    When the video has no segments the task returns early without calling
    the LLM.

    Returns:
        The unchanged video_id, so the task can be chained.

    Raises:
        FileNotFoundError: The prompt template is missing; re-raised as-is
            rather than retried.

    Any other failure rolls the session back, emits an "error" event and is
    handed to ``self.retry``.
    """
    began = time.monotonic()
    logger.info("Stage 2 (segmentation) starting for video_id=%s", video_id)
    _emit_event(video_id, "stage2_segmentation", "start", run_id=run_id)

    session = _get_sync_session()
    try:
        segment_query = (
            select(TranscriptSegment)
            .where(TranscriptSegment.source_video_id == video_id)
            .order_by(TranscriptSegment.segment_index)
        )
        segments = session.execute(segment_query).scalars().all()

        if not segments:
            logger.info("Stage 2: No segments found for video_id=%s, skipping.", video_id)
            return video_id

        # One line per segment, prefixed with its index so the LLM can
        # refer back to index ranges.
        transcript_text = "\n".join(
            f"[{seg.segment_index}] ({seg.start_time:.1f}s - {seg.end_time:.1f}s) {seg.text}"
            for seg in segments
        )

        system_prompt = _load_prompt("stage2_segmentation.txt")
        user_prompt = f"<transcript>\n{transcript_text}\n</transcript>"

        llm = _get_llm_client()
        model_override, modality = _get_stage_config(2)
        hard_limit = get_settings().llm_max_tokens_hard_limit
        max_tokens = estimate_max_tokens(
            system_prompt, user_prompt, stage="stage2_segmentation", hard_limit=hard_limit
        )
        logger.info(
            "Stage 2 using model=%s, modality=%s, max_tokens=%d",
            model_override or "default", modality, max_tokens,
        )

        llm_options = {
            "modality": modality,
            "model_override": model_override,
            "max_tokens": max_tokens,
        }
        raw = llm.complete(
            system_prompt,
            user_prompt,
            response_model=SegmentationResult,
            on_complete=_make_llm_callback(
                video_id,
                "stage2_segmentation",
                system_prompt=system_prompt,
                user_prompt=user_prompt,
                run_id=run_id,
            ),
            **llm_options,
        )
        result = _safe_parse_llm_response(
            raw, SegmentationResult, llm, system_prompt, user_prompt, **llm_options
        )

        # Stamp the topic label on every known segment inside each range;
        # indices the LLM invented are ignored.
        rows_by_index = {row.segment_index: row for row in segments}
        for topic in result.segments:
            for position in range(topic.start_index, topic.end_index + 1):
                row = rows_by_index.get(position)
                if row is not None:
                    row.topic_label = topic.topic_label

        session.commit()
        elapsed = time.monotonic() - began
        _emit_event(video_id, "stage2_segmentation", "complete", run_id=run_id)
        logger.info(
            "Stage 2 (segmentation) completed for video_id=%s in %.1fs — %d topic groups found",
            video_id, elapsed, len(result.segments),
        )
        return video_id

    except FileNotFoundError:
        raise  # Don't retry missing prompt files
    except Exception as exc:
        session.rollback()
        _emit_event(
            video_id, "stage2_segmentation", "error", run_id=run_id, payload={"error": str(exc)}
        )
        logger.error("Stage 2 failed for video_id=%s: %s", video_id, exc)
        raise self.retry(exc=exc)
    finally:
        session.close()
|
|
|
|
|
|
# ── Stage 3: Extraction ─────────────────────────────────────────────────────
|
|
|
|
@celery_app.task(bind=True, max_retries=3, default_retry_delay=30)
def stage3_extraction(self, video_id: str, run_id: str | None = None) -> str:
    """Extract key moments from each topic group of a video's transcript.

    Segments are grouped by topic_label (segments without one fall under
    "unlabeled"). Each group is sent to the LLM separately and every moment
    it returns becomes a KeyMoment row; all rows are committed together
    after the last group. A content_type the enum does not recognise falls
    back to ``technique``.

    NOTE(review): the previous docstring said this stage sets
    processing_status=extracted, but no such write is visible in this
    function — confirm whether the status is updated elsewhere.

    Returns:
        The unchanged video_id, so the task can be chained.

    Raises:
        FileNotFoundError: The prompt template is missing; re-raised as-is
            rather than retried.

    Any other failure rolls the session back, emits an "error" event and is
    handed to ``self.retry``.
    """
    began = time.monotonic()
    logger.info("Stage 3 (extraction) starting for video_id=%s", video_id)
    _emit_event(video_id, "stage3_extraction", "start", run_id=run_id)

    session = _get_sync_session()
    try:
        segment_query = (
            select(TranscriptSegment)
            .where(TranscriptSegment.source_video_id == video_id)
            .order_by(TranscriptSegment.segment_index)
        )
        segments = session.execute(segment_query).scalars().all()

        if not segments:
            logger.info("Stage 3: No segments found for video_id=%s, skipping.", video_id)
            return video_id

        segments_by_topic: dict[str, list[TranscriptSegment]] = defaultdict(list)
        for seg in segments:
            segments_by_topic[seg.topic_label or "unlabeled"].append(seg)

        system_prompt = _load_prompt("stage3_extraction.txt")
        llm = _get_llm_client()
        model_override, modality = _get_stage_config(3)
        hard_limit = get_settings().llm_max_tokens_hard_limit
        logger.info("Stage 3 using model=%s, modality=%s", model_override or "default", modality)
        total_moments = 0

        for topic_label, topic_segments in segments_by_topic.items():
            segment_text = "\n".join(
                f"({seg.start_time:.1f}s - {seg.end_time:.1f}s) {seg.text}"
                for seg in topic_segments
            )
            user_prompt = f"Topic: {topic_label}\n\n<segment>\n{segment_text}\n</segment>"

            max_tokens = estimate_max_tokens(
                system_prompt, user_prompt, stage="stage3_extraction", hard_limit=hard_limit
            )
            llm_options = {
                "modality": modality,
                "model_override": model_override,
                "max_tokens": max_tokens,
            }
            raw = llm.complete(
                system_prompt,
                user_prompt,
                response_model=ExtractionResult,
                on_complete=_make_llm_callback(
                    video_id,
                    "stage3_extraction",
                    system_prompt=system_prompt,
                    user_prompt=user_prompt,
                    run_id=run_id,
                    context_label=topic_label,
                ),
                **llm_options,
            )
            result = _safe_parse_llm_response(
                raw, ExtractionResult, llm, system_prompt, user_prompt, **llm_options
            )

            for moment in result.moments:
                # Unknown content types fall back to the generic "technique".
                try:
                    content_type = KeyMomentContentType(moment.content_type)
                except ValueError:
                    content_type = KeyMomentContentType.technique

                session.add(
                    KeyMoment(
                        source_video_id=video_id,
                        title=moment.title,
                        summary=moment.summary,
                        start_time=moment.start_time,
                        end_time=moment.end_time,
                        content_type=content_type,
                        plugins=moment.plugins or None,
                        raw_transcript=moment.raw_transcript or None,
                    )
                )
                total_moments += 1

        session.commit()
        elapsed = time.monotonic() - began
        _emit_event(video_id, "stage3_extraction", "complete", run_id=run_id)
        logger.info(
            "Stage 3 (extraction) completed for video_id=%s in %.1fs — %d moments created",
            video_id, elapsed, total_moments,
        )
        return video_id

    except FileNotFoundError:
        raise
    except Exception as exc:
        session.rollback()
        _emit_event(
            video_id, "stage3_extraction", "error", run_id=run_id, payload={"error": str(exc)}
        )
        logger.error("Stage 3 failed for video_id=%s: %s", video_id, exc)
        raise self.retry(exc=exc)
    finally:
        session.close()
|
|
|
|
|
|
# ── Stage 4: Classification ─────────────────────────────────────────────────
|
|
|
|
# Maximum moments per classification batch. Keeps each LLM call well within
# context window limits. Batches are classified independently and merged.
# The same constant is used to slice the moment list into batches and to
# derive the 1-based batch number shown in log and event labels.
_STAGE4_BATCH_SIZE = 20
|
|
|
|
|
|
def _classify_moment_batch(
    moments_batch: list,
    batch_offset: int,
    taxonomy_text: str,
    system_prompt: str,
    llm: LLMClient,
    model_override: str | None,
    modality: str,
    hard_limit: int,
    video_id: str,
    run_id: str | None,
) -> ClassificationResult:
    """Classify one batch of moments with a single LLM call.

    Moments are numbered from 0 within the batch in the prompt, so the
    indices in the returned result are batch-local. ``batch_offset`` is
    only used to build the human-readable batch label. Errors from the LLM
    call or from parsing propagate to the caller.
    """
    # NOTE(review): leading whitespace inside these literals is reproduced
    # exactly as it appears in the reviewed source view — confirm against
    # the repository in case the view collapsed runs of spaces.
    moments_text = "\n\n".join(
        f"[{position}] Title: {moment.title}\n"
        f" Summary: {moment.summary}\n"
        f" Content type: {moment.content_type.value}\n"
        f" Plugins: {', '.join(moment.plugins) if moment.plugins else 'none'}"
        for position, moment in enumerate(moments_batch)
    )
    user_prompt = (
        f"<taxonomy>\n{taxonomy_text}\n</taxonomy>\n\n"
        f"<moments>\n{moments_text}\n</moments>"
    )

    max_tokens = estimate_max_tokens(
        system_prompt, user_prompt,
        stage="stage4_classification", hard_limit=hard_limit,
    )

    batch_number = batch_offset // _STAGE4_BATCH_SIZE + 1
    last_index = batch_offset + len(moments_batch) - 1
    batch_label = f"batch {batch_number} (moments {batch_offset}-{last_index})"
    logger.info("Stage 4 classifying %s, max_tokens=%d", batch_label, max_tokens)

    llm_options = {
        "modality": modality,
        "model_override": model_override,
        "max_tokens": max_tokens,
    }
    on_complete = _make_llm_callback(
        video_id, "stage4_classification",
        system_prompt=system_prompt, user_prompt=user_prompt,
        run_id=run_id, context_label=batch_label,
    )
    raw = llm.complete(
        system_prompt, user_prompt,
        response_model=ClassificationResult,
        on_complete=on_complete,
        **llm_options,
    )
    return _safe_parse_llm_response(
        raw, ClassificationResult, llm, system_prompt, user_prompt, **llm_options
    )
|
|
|
|
|
|
@celery_app.task(bind=True, max_retries=3, default_retry_delay=30)
def stage4_classification(self, video_id: str, run_id: str | None = None) -> str:
    """Classify key moments against the canonical tag taxonomy.

    Loads all KeyMoment rows for the video (ordered by start_time), sends
    them to the LLM in batches of _STAGE4_BATCH_SIZE together with the
    canonical taxonomy, and stores the merged classification results in
    Redis for stage 5. A moment's content_type is updated when the
    classifier supplies a valid override.

    Classifications whose batch-local moment_index falls outside the batch
    they were returned for are dropped with a warning instead of being
    remapped onto a moment from a different batch.

    Stage 4 does NOT change processing_status.

    Returns:
        The unchanged video_id, so the task can be chained.

    Raises:
        FileNotFoundError: The prompt template or taxonomy file is missing;
            re-raised as-is rather than retried.

    Any other failure rolls the session back, emits an "error" event and is
    handed to ``self.retry``.
    """
    start = time.monotonic()
    logger.info("Stage 4 (classification) starting for video_id=%s", video_id)
    _emit_event(video_id, "stage4_classification", "start", run_id=run_id)

    session = _get_sync_session()
    try:
        # Load key moments
        moments = (
            session.execute(
                select(KeyMoment)
                .where(KeyMoment.source_video_id == video_id)
                .order_by(KeyMoment.start_time)
            )
            .scalars()
            .all()
        )

        if not moments:
            logger.info("Stage 4: No moments found for video_id=%s, skipping.", video_id)
            _store_classification_data(video_id, [])
            return video_id

        # Load canonical tags
        tags_data = _load_canonical_tags()
        taxonomy_text = _format_taxonomy_for_prompt(tags_data)

        system_prompt = _load_prompt("stage4_classification.txt")
        llm = _get_llm_client()
        model_override, modality = _get_stage_config(4)
        hard_limit = get_settings().llm_max_tokens_hard_limit

        # Batch moments for classification
        all_classifications = []
        for batch_start in range(0, len(moments), _STAGE4_BATCH_SIZE):
            batch = moments[batch_start:batch_start + _STAGE4_BATCH_SIZE]
            result = _classify_moment_batch(
                batch, batch_start, taxonomy_text, system_prompt,
                llm, model_override, modality, hard_limit,
                video_id, run_id,
            )
            for cls in result.classifications:
                # Indices are 0-based within the batch. Validate against the
                # batch size BEFORE remapping: an out-of-range local index
                # would otherwise become a valid global index pointing at a
                # moment that belongs to another batch.
                if not 0 <= cls.moment_index < len(batch):
                    logger.warning(
                        "Stage 4: ignoring classification with out-of-range "
                        "moment_index=%s (batch of %d starting at %d) for video_id=%s",
                        cls.moment_index, len(batch), batch_start, video_id,
                    )
                    continue
                cls.moment_index += batch_start
                all_classifications.append(cls)

        # Apply content_type overrides and prepare classification data for stage 5
        classification_data = []
        for cls in all_classifications:
            # Every index here was range-checked during remapping above.
            moment = moments[cls.moment_index]

            # Apply content_type override if provided
            if cls.content_type_override:
                try:
                    moment.content_type = KeyMomentContentType(cls.content_type_override)
                except ValueError:
                    # Keep the existing content_type, but leave a trace.
                    logger.warning(
                        "Stage 4: ignoring unknown content_type_override=%r for moment_id=%s",
                        cls.content_type_override, moment.id,
                    )

            classification_data.append({
                "moment_id": str(moment.id),
                "topic_category": cls.topic_category.strip().title(),
                "topic_tags": cls.topic_tags,
            })

        session.commit()

        # Store classification data in Redis for stage 5
        _store_classification_data(video_id, classification_data)

        elapsed = time.monotonic() - start
        num_batches = (len(moments) + _STAGE4_BATCH_SIZE - 1) // _STAGE4_BATCH_SIZE
        _emit_event(video_id, "stage4_classification", "complete", run_id=run_id)
        logger.info(
            "Stage 4 (classification) completed for video_id=%s in %.1fs — "
            "%d moments classified in %d batch(es)",
            video_id, elapsed, len(classification_data), num_batches,
        )
        return video_id

    except FileNotFoundError:
        raise
    except Exception as exc:
        session.rollback()
        _emit_event(
            video_id, "stage4_classification", "error", run_id=run_id, payload={"error": str(exc)}
        )
        logger.error("Stage 4 failed for video_id=%s: %s", video_id, exc)
        raise self.retry(exc=exc)
    finally:
        session.close()
|
|
|
|
|
|
def _store_classification_data(video_id: str, data: list[dict]) -> None:
    """Store classification data in Redis for cross-stage communication.

    The list is JSON-encoded under ``chrysopedia:classification:<video_id>``
    with a 24-hour expiry. Redis errors propagate to the caller.
    """
    import redis

    client = redis.Redis.from_url(get_settings().redis_url)
    try:
        key = f"chrysopedia:classification:{video_id}"
        client.set(key, json.dumps(data), ex=86400)  # Expire after 24 hours
    finally:
        # Release the connection pool; previously the client was never closed.
        client.close()
|
|
|
|
|
|
def _load_classification_data(video_id: str) -> list[dict]:
    """Load classification data from Redis.

    Reads ``chrysopedia:classification:<video_id>`` and JSON-decodes it.
    Returns an empty list when the key does not exist. Redis errors
    propagate to the caller.
    """
    import redis

    client = redis.Redis.from_url(get_settings().redis_url)
    try:
        raw = client.get(f"chrysopedia:classification:{video_id}")
    finally:
        # Release the connection pool; previously the client was never closed.
        client.close()
    if raw is None:
        return []
    return json.loads(raw)
|
|
|
|
|
|
|
|
def _get_git_commit_sha() -> str:
    """Work out which git commit this code was built from.

    Sources are tried in order and the first usable value wins:
      1. the ``/app/.git-commit`` file,
      2. ``git rev-parse --short HEAD`` (5-second timeout),
      3. the ``git_commit_sha`` setting,
      4. the literal "unknown".

    An empty value or the string "unknown" from the file or the setting is
    treated as unusable.
    """
    baked_sha_file = Path("/app/.git-commit")
    if baked_sha_file.exists():
        baked = baked_sha_file.read_text(encoding="utf-8").strip()
        if baked and baked != "unknown":
            return baked

    # No baked file: ask git directly. A missing git binary or a timeout
    # simply moves on to the next source.
    try:
        proc = subprocess.run(
            ["git", "rev-parse", "--short", "HEAD"],
            capture_output=True, text=True, timeout=5,
        )
        head = proc.stdout.strip()
        if proc.returncode == 0 and head:
            return head
    except (FileNotFoundError, subprocess.TimeoutExpired):
        pass

    # Last resort before giving up: the configured value.
    try:
        configured = get_settings().git_commit_sha
        if configured and configured != "unknown":
            return configured
    except Exception:
        pass

    return "unknown"
|
|
|
|
def _capture_pipeline_metadata() -> dict:
    """Snapshot the pipeline configuration for version metadata.

    The result holds the git commit SHA, the configured model per stage
    plus the embedding model, the modality per stage, and a SHA-256 hex
    digest of each of the four stage prompt templates. A template that is
    missing or unreadable is logged as a warning and recorded with an
    empty-string hash.
    """
    settings = get_settings()
    prompts_dir = Path(settings.prompts_path)

    prompt_hashes: dict[str, str] = {}
    for template in (
        "stage2_segmentation.txt",
        "stage3_extraction.txt",
        "stage4_classification.txt",
        "stage5_synthesis.txt",
    ):
        template_path = prompts_dir / template
        try:
            digest = hashlib.sha256(template_path.read_bytes()).hexdigest()
        except FileNotFoundError:
            logger.warning("Prompt file not found for metadata capture: %s", template_path)
            digest = ""
        except OSError as exc:
            logger.warning("Could not read prompt file %s: %s", template_path, exc)
            digest = ""
        prompt_hashes[template] = digest

    metadata: dict = {"git_commit_sha": _get_git_commit_sha()}
    metadata["models"] = {
        "stage2": settings.llm_stage2_model,
        "stage3": settings.llm_stage3_model,
        "stage4": settings.llm_stage4_model,
        "stage5": settings.llm_stage5_model,
        "embedding": settings.embedding_model,
    }
    metadata["modalities"] = {
        "stage2": settings.llm_stage2_modality,
        "stage3": settings.llm_stage3_modality,
        "stage4": settings.llm_stage4_modality,
        "stage5": settings.llm_stage5_modality,
    }
    metadata["prompt_hashes"] = prompt_hashes
    return metadata
|
|
|
|
|
|
# ── Stage 5: Synthesis ───────────────────────────────────────────────────────
|
|
|
|
|
|
def _compute_page_tags(
    moment_indices: list[int],
    moment_group: list[tuple],
    all_tags: set[str],
) -> list[str] | None:
    """Compute the tags for one page from the moments linked to it.

    Args:
        moment_indices: Positions into ``moment_group`` that the page
            covers. Positions outside the group are ignored.
        moment_group: ``(moment, classification_info)`` pairs; only the
            ``topic_tags`` entry of the classification info is read.
        all_tags: Every tag of the whole group, used as the fallback when
            ``moment_indices`` is empty.

    Returns:
        The page's tags in sorted order, or None when there are none.
        Sorting makes the result independent of set iteration order, so
        the same input always produces the same list.
    """
    if not moment_indices:
        return sorted(all_tags) if all_tags else None

    page_tags: set[str] = set()
    for idx in moment_indices:
        if 0 <= idx < len(moment_group):
            _, cls_info = moment_group[idx]
            # "topic_tags" may be absent or explicitly null in stored data.
            page_tags.update(cls_info.get("topic_tags") or ())

    return sorted(page_tags) if page_tags else None
|
|
|
|
|
|
def _build_moments_text(
    moment_group: list[tuple[KeyMoment, dict]],
    category: str,
) -> tuple[str, set[str]]:
    """Render a group of moments as prompt text and gather their tags.

    Each moment becomes one numbered entry (numbered from 0 in group order)
    with its title, summary, content type, time range, plugins, category,
    tags and the first 300 characters of its transcript. Entries are
    separated by a blank line.

    Returns:
        ``(moments_text, all_tags)`` where ``all_tags`` is the union of the
        ``topic_tags`` of every moment in the group.
    """
    # NOTE(review): leading whitespace inside these literals is reproduced
    # exactly as it appears in the reviewed source view — confirm against
    # the repository in case the view collapsed runs of spaces.
    def render(position: int, moment, tags) -> str:
        return (
            f"[{position}] Title: {moment.title}\n"
            f" Summary: {moment.summary}\n"
            f" Content type: {moment.content_type.value}\n"
            f" Time: {moment.start_time:.1f}s - {moment.end_time:.1f}s\n"
            f" Plugins: {', '.join(moment.plugins) if moment.plugins else 'none'}\n"
            f" Category: {category}\n"
            f" Tags: {', '.join(tags) if tags else 'none'}\n"
            f" Transcript excerpt: {(moment.raw_transcript or '')[:300]}"
        )

    collected_tags: set[str] = set()
    entries = []
    for position, (moment, cls_info) in enumerate(moment_group):
        moment_tags = cls_info.get("topic_tags", [])
        collected_tags.update(moment_tags)
        entries.append(render(position, moment, moment_tags))
    return "\n\n".join(entries), collected_tags
|
|
|
|
|
|
def _synthesize_chunk(
    chunk: list[tuple[KeyMoment, dict]],
    category: str,
    creator_name: str,
    system_prompt: str,
    llm: LLMClient,
    model_override: str | None,
    modality: str,
    hard_limit: int,
    video_id: str,
    run_id: str | None,
    chunk_label: str,
) -> SynthesisResult:
    """Make one synthesis LLM call for a chunk of moments.

    The chunk is rendered with ``_build_moments_text`` (its tag set is not
    used here), wrapped together with the creator name into the user
    prompt, and the reply is parsed into a SynthesisResult.
    ``chunk_label`` identifies the call in the log line and in the
    recorded llm_call event.
    """
    moments_text, _unused_tags = _build_moments_text(chunk, category)
    user_prompt = f"<creator>{creator_name}</creator>\n<moments>\n{moments_text}\n</moments>"

    token_budget = estimate_max_tokens(
        system_prompt, user_prompt, stage="stage5_synthesis", hard_limit=hard_limit
    )
    logger.info(
        "Stage 5: Synthesizing %s — %d moments, max_tokens=%d",
        chunk_label, len(chunk), token_budget,
    )

    llm_options = {
        "modality": modality,
        "model_override": model_override,
        "max_tokens": token_budget,
    }
    on_complete = _make_llm_callback(
        video_id, "stage5_synthesis",
        system_prompt=system_prompt, user_prompt=user_prompt,
        run_id=run_id, context_label=chunk_label,
    )
    raw = llm.complete(
        system_prompt, user_prompt,
        response_model=SynthesisResult,
        on_complete=on_complete,
        **llm_options,
    )
    return _safe_parse_llm_response(
        raw, SynthesisResult, llm, system_prompt, user_prompt, **llm_options
    )
|
|
|
|
|
|
def _slug_base(slug: str) -> str:
    """Normalize a slug into the key used to group pages for merging.

    The slug is lowercased and surrounding whitespace is removed; nothing
    else is changed. In particular a trailing creator name is NOT stripped:
    'Wavetable-Sound-Design-CopyCatt ' → 'wavetable-sound-design-copycatt'.
    """
    lowered = slug.lower()
    return lowered.strip()
|
|
|
|
|
|
def _merge_pages_by_slug(
    all_pages: list,
    creator_name: str,
    llm: LLMClient,
    model_override: str | None,
    modality: str,
    hard_limit: int,
    video_id: str,
    run_id: str | None,
) -> list:
    """Collapse pages that share a normalized slug into merged pages via the LLM.

    Pages are bucketed by ``_slug_base(page.slug)``. A bucket holding a
    single page is passed through untouched. A bucket holding several pages
    is serialized to JSON and sent to the ``stage5_merge.txt`` prompt; the
    pages returned by that call replace the partials. When the merge call
    yields no pages, the partials are kept unchanged.

    Returns the resulting list of pages, ordered by first appearance of
    each slug bucket.
    """
    # Imported for parity with the original; not referenced below.
    from pipeline.schemas import SynthesizedPage  # noqa: F401

    buckets: dict[str, list] = defaultdict(list)
    for candidate in all_pages:
        buckets[_slug_base(candidate.slug)].append(candidate)

    merged_pages: list = []
    for slug_key, partials in buckets.items():
        if len(partials) == 1:
            # Only one page carries this slug; nothing to merge.
            merged_pages.append(partials[0])
            continue

        logger.info(
            "Stage 5: Merging %d partial pages with slug '%s' for video_id=%s",
            len(partials), slug_key, video_id,
        )

        # The merge prompt receives the partial pages as a JSON array.
        serialized = json.dumps(
            [partial.model_dump() for partial in partials],
            indent=2,
            ensure_ascii=False,
        )
        merge_system_prompt = _load_prompt("stage5_merge.txt")
        merge_user_prompt = f"<creator>{creator_name}</creator>\n<pages>\n{serialized}\n</pages>"

        token_budget = estimate_max_tokens(
            merge_system_prompt, merge_user_prompt,
            stage="stage5_synthesis", hard_limit=hard_limit,
        )
        logger.info(
            "Stage 5: Merge call for slug '%s' — %d partial pages, max_tokens=%d",
            slug_key, len(partials), token_budget,
        )

        # Options shared by the completion call and the parse/repair call.
        call_options = {
            "modality": modality,
            "model_override": model_override,
            "max_tokens": token_budget,
        }
        completion_callback = _make_llm_callback(
            video_id, "stage5_synthesis",
            system_prompt=merge_system_prompt,
            user_prompt=merge_user_prompt,
            run_id=run_id, context_label=f"merge:{slug_key}",
        )
        raw_response = llm.complete(
            merge_system_prompt, merge_user_prompt,
            response_model=SynthesisResult,
            on_complete=completion_callback,
            **call_options,
        )
        merge_result = _safe_parse_llm_response(
            raw_response, SynthesisResult, llm,
            merge_system_prompt, merge_user_prompt,
            **call_options,
        )

        if not merge_result.pages:
            # The merge produced nothing usable; keep the partial pages.
            logger.warning(
                "Stage 5: Merge returned 0 pages for slug '%s', keeping %d partials",
                slug_key, len(partials),
            )
            merged_pages.extend(partials)
            continue

        merged_pages.extend(merge_result.pages)
        logger.info(
            "Stage 5: Merge produced %d page(s) for slug '%s'",
            len(merge_result.pages), slug_key,
        )

    return merged_pages
|
|
|
|
|
|
@celery_app.task(bind=True, max_retries=3, default_retry_delay=30)
def stage5_synthesis(self, video_id: str, run_id: str | None = None) -> str:
    """Synthesize technique pages from classified key moments.

    Moments of the video are grouped by their (case-normalized) topic
    category from the stage 4 classification data. Each group is sent to
    the LLM, which may return one or more pages; every page is created or
    updated in the database and the moments it declares through
    ``moment_indices`` are linked to it.

    Groups larger than ``settings.synthesis_chunk_size`` are split into
    chunks in moment order, synthesized per chunk, and pages with matching
    slugs are merged through ``_merge_pages_by_slug``.

    On success ``processing_status`` is set to ``complete`` and the session
    is committed. ``FileNotFoundError`` is re-raised as-is; any other
    exception rolls the session back, emits an error event, and triggers a
    Celery retry.

    Returns the video_id for chain compatibility.
    """
    start = time.monotonic()
    logger.info("Stage 5 (synthesis) starting for video_id=%s", video_id)
    _emit_event(video_id, "stage5_synthesis", "start", run_id=run_id)

    settings = get_settings()
    chunk_size = settings.synthesis_chunk_size
    session = _get_sync_session()
    try:
        # Load video and moments
        video = session.execute(
            select(SourceVideo).where(SourceVideo.id == video_id)
        ).scalar_one()

        moments = (
            session.execute(
                select(KeyMoment)
                .where(KeyMoment.source_video_id == video_id)
                .order_by(KeyMoment.start_time)
            )
            .scalars()
            .all()
        )

        # Resolve creator name for the LLM prompt
        creator = session.execute(
            select(Creator).where(Creator.id == video.creator_id)
        ).scalar_one_or_none()
        creator_name = creator.name if creator else "Unknown"

        if not moments:
            logger.info("Stage 5: No moments found for video_id=%s, skipping.", video_id)
            return video_id

        # Load classification data from stage 4
        classification_data = _load_classification_data(video_id)
        cls_by_moment_id = {c["moment_id"]: c for c in classification_data}

        # Group moments by topic_category (from classification).
        # Normalize category casing to prevent near-duplicate groups
        # (e.g., "Sound design" vs "Sound Design"). A missing or empty
        # category falls back to "Uncategorized".
        groups: dict[str, list[tuple[KeyMoment, dict]]] = defaultdict(list)
        for moment in moments:
            cls_info = cls_by_moment_id.get(str(moment.id), {})
            raw_category = cls_info.get("topic_category") or "Uncategorized"
            category = raw_category.strip().title()
            groups[category].append((moment, cls_info))

        system_prompt = _load_prompt("stage5_synthesis.txt")
        llm = _get_llm_client()
        model_override, modality = _get_stage_config(5)
        hard_limit = settings.llm_max_tokens_hard_limit
        logger.info("Stage 5 using model=%s, modality=%s", model_override or "default", modality)
        pages_created = 0

        # Prior pages from this video (snapshot taken before the pipeline
        # reset). Loaded once: the value does not depend on the category.
        prior_page_ids = _load_prior_pages(video_id)

        for category, moment_group in groups.items():
            # Collect all tags across the full group; passed to
            # _compute_page_tags together with the page's moment indices.
            all_tags: set[str] = set()
            for _, cls_info in moment_group:
                all_tags.update(cls_info.get("topic_tags", []))

            # ── Chunked synthesis ────────────────────────────────────
            if len(moment_group) <= chunk_size:
                # Small group — single LLM call
                result = _synthesize_chunk(
                    moment_group, category, creator_name,
                    system_prompt, llm, model_override, modality, hard_limit,
                    video_id, run_id, f"category:{category}",
                )
                synthesized_pages = list(result.pages)
                logger.info(
                    "Stage 5: category '%s' — %d moments, %d page(s) from single call",
                    category, len(moment_group), len(synthesized_pages),
                )
            else:
                # Large group — split into chunks, synthesize each, then merge
                num_chunks = (len(moment_group) + chunk_size - 1) // chunk_size
                logger.info(
                    "Stage 5: category '%s' has %d moments — splitting into %d chunks of ≤%d",
                    category, len(moment_group), num_chunks, chunk_size,
                )

                chunk_pages = []
                for chunk_idx in range(num_chunks):
                    chunk_start = chunk_idx * chunk_size
                    chunk_end = min(chunk_start + chunk_size, len(moment_group))
                    chunk = moment_group[chunk_start:chunk_end]
                    chunk_label = f"category:{category} chunk:{chunk_idx + 1}/{num_chunks}"

                    result = _synthesize_chunk(
                        chunk, category, creator_name,
                        system_prompt, llm, model_override, modality, hard_limit,
                        video_id, run_id, chunk_label,
                    )
                    chunk_pages.extend(result.pages)
                    logger.info(
                        "Stage 5: %s produced %d page(s)",
                        chunk_label, len(result.pages),
                    )

                # Merge pages with matching slugs across chunks
                logger.info(
                    "Stage 5: category '%s' — %d total pages from %d chunks, checking for merges",
                    category, len(chunk_pages), num_chunks,
                )
                synthesized_pages = _merge_pages_by_slug(
                    chunk_pages, creator_name,
                    llm, model_override, modality, hard_limit,
                    video_id, run_id,
                )
                logger.info(
                    "Stage 5: category '%s' — %d final page(s) after merge",
                    category, len(synthesized_pages),
                )

            # ── Persist pages to DB ──────────────────────────────────
            for page_data in synthesized_pages:
                # Resolve the page's moment indices first: both the tag
                # computation and the moment linking below depend on them.
                # NOTE(review): in the chunked path these indices presumably
                # refer to positions within a chunk (and merged pages carry
                # whatever the merge prompt returns), while they are resolved
                # against the whole moment_group here — confirm against
                # _build_moments_text and the merge prompt.
                page_moment_indices = getattr(page_data, "moment_indices", None) or []
                page_tags = _compute_page_tags(page_moment_indices, moment_group, all_tags)

                # First: check by slug (most specific match)
                existing = session.execute(
                    select(TechniquePage).where(TechniquePage.slug == page_data.slug)
                ).scalar_one_or_none()

                # Fallback: check prior pages from this video by creator + category.
                # Use .first() since multiple pages may share a category.
                if existing is None and prior_page_ids:
                    existing = session.execute(
                        select(TechniquePage).where(
                            TechniquePage.id.in_(prior_page_ids),
                            TechniquePage.creator_id == video.creator_id,
                            func.lower(TechniquePage.topic_category) == func.lower(page_data.topic_category or category),
                        )
                    ).scalars().first()
                    if existing:
                        logger.info(
                            "Stage 5: Matched prior page '%s' (id=%s) by creator+category for video_id=%s",
                            existing.slug, existing.id, video_id,
                        )

                if existing:
                    # Snapshot existing content before overwriting
                    try:
                        sq = existing.source_quality
                        sq_value = sq.value if hasattr(sq, 'value') else sq
                        snapshot = {
                            "title": existing.title,
                            "slug": existing.slug,
                            "topic_category": existing.topic_category,
                            "topic_tags": existing.topic_tags,
                            "summary": existing.summary,
                            "body_sections": existing.body_sections,
                            "signal_chains": existing.signal_chains,
                            "plugins": existing.plugins,
                            "source_quality": sq_value,
                        }
                        version_count = session.execute(
                            select(func.count()).where(
                                TechniquePageVersion.technique_page_id == existing.id
                            )
                        ).scalar()
                        version_number = version_count + 1

                        version = TechniquePageVersion(
                            technique_page_id=existing.id,
                            version_number=version_number,
                            content_snapshot=snapshot,
                            pipeline_metadata=_capture_pipeline_metadata(),
                        )
                        session.add(version)
                        logger.info(
                            "Version snapshot v%d created for page slug=%s",
                            version_number, existing.slug,
                        )
                    except Exception as snap_exc:
                        # Best-effort versioning — continue with page update
                        logger.error(
                            "Failed to create version snapshot for page slug=%s: %s",
                            existing.slug, snap_exc,
                        )

                    # Update existing page
                    existing.title = page_data.title
                    existing.summary = page_data.summary
                    existing.body_sections = page_data.body_sections
                    existing.signal_chains = page_data.signal_chains
                    existing.plugins = page_data.plugins if page_data.plugins else None
                    existing.topic_tags = page_tags
                    existing.source_quality = page_data.source_quality
                    page = existing
                else:
                    page = TechniquePage(
                        creator_id=video.creator_id,
                        title=page_data.title,
                        slug=page_data.slug,
                        topic_category=page_data.topic_category or category,
                        topic_tags=page_tags,
                        summary=page_data.summary,
                        body_sections=page_data.body_sections,
                        signal_chains=page_data.signal_chains,
                        plugins=page_data.plugins if page_data.plugins else None,
                        source_quality=page_data.source_quality,
                    )
                    session.add(page)
                    session.flush()  # Get the page.id assigned

                pages_created += 1

                # Link moments to the technique page using moment_indices
                if page_moment_indices:
                    # LLM specified which moments belong to this page;
                    # out-of-range indices are ignored.
                    for idx in page_moment_indices:
                        if 0 <= idx < len(moment_group):
                            moment_group[idx][0].technique_page_id = page.id
                elif len(synthesized_pages) == 1:
                    # Single page — link all moments (safe fallback)
                    for m, _ in moment_group:
                        m.technique_page_id = page.id
                else:
                    # Multiple pages but no moment_indices — log warning
                    logger.warning(
                        "Stage 5: page '%s' has no moment_indices and is one of %d pages "
                        "for category '%s'. Moments will not be linked to this page.",
                        page_data.slug, len(synthesized_pages), category,
                    )

        # Update processing_status
        video.processing_status = ProcessingStatus.complete

        session.commit()
        elapsed = time.monotonic() - start
        _emit_event(video_id, "stage5_synthesis", "complete", run_id=run_id)
        logger.info(
            "Stage 5 (synthesis) completed for video_id=%s in %.1fs — %d pages created/updated",
            video_id, elapsed, pages_created,
        )
        return video_id

    except FileNotFoundError:
        # Propagated without a retry.
        raise
    except Exception as exc:
        session.rollback()
        _emit_event(video_id, "stage5_synthesis", "error", run_id=run_id, payload={"error": str(exc)})
        logger.error("Stage 5 failed for video_id=%s: %s", video_id, exc)
        raise self.retry(exc=exc)
    finally:
        session.close()
|
|
|
|
|
|
# ── Stage 6: Embed & Index ───────────────────────────────────────────────────
|
|
|
|
@celery_app.task(bind=True, max_retries=0)
def stage6_embed_and_index(self, video_id: str, run_id: str | None = None) -> str:
    """Embed the video's technique pages and key moments and upsert them to Qdrant.

    Pages are the technique pages referenced by the video's key moments.
    The stage is best-effort: any exception is logged and swallowed, and
    the run (when ``run_id`` is given) is still finished as "complete".
    ``processing_status`` is not touched here.

    Returns the video_id for chain compatibility.
    """
    started_at = time.monotonic()
    logger.info("Stage 6 (embed & index) starting for video_id=%s", video_id)

    settings = get_settings()
    session = _get_sync_session()
    try:
        # Key moments of this video, ordered by start time.
        moment_query = (
            select(KeyMoment)
            .where(KeyMoment.source_video_id == video_id)
            .order_by(KeyMoment.start_time)
        )
        moments = session.execute(moment_query).scalars().all()

        # Technique pages are discovered through the moments that link to them.
        page_ids = {
            moment.technique_page_id
            for moment in moments
            if moment.technique_page_id is not None
        }
        pages = []
        if page_ids:
            page_query = select(TechniquePage).where(TechniquePage.id.in_(page_ids))
            pages = session.execute(page_query).scalars().all()

        if not (moments or pages):
            logger.info("Stage 6: No moments or pages for video_id=%s, skipping.", video_id)
            if run_id:
                _finish_run(run_id, "complete")
            return video_id

        embed_client = EmbeddingClient(settings)
        qdrant = QdrantManager(settings)

        # The collection must exist before any upsert.
        qdrant.ensure_collection()

        # ── Technique pages ──────────────────────────────────────────────
        if pages:
            page_texts = []
            page_payloads = []
            for tech_page in pages:
                page_text = f"{tech_page.title} {tech_page.summary or ''} {tech_page.topic_category or ''}"
                page_texts.append(page_text.strip())
                page_payloads.append({
                    "page_id": str(tech_page.id),
                    "creator_id": str(tech_page.creator_id),
                    "title": tech_page.title,
                    "slug": tech_page.slug,
                    "topic_category": tech_page.topic_category or "",
                    "topic_tags": tech_page.topic_tags or [],
                    "summary": tech_page.summary or "",
                })

            page_vectors = embed_client.embed(page_texts)
            if not page_vectors:
                logger.warning(
                    "Stage 6: Embedding returned empty for %d technique pages (video_id=%s). "
                    "Skipping page upsert.",
                    len(page_texts), video_id,
                )
            else:
                qdrant.upsert_technique_pages(page_payloads, page_vectors)
                logger.info(
                    "Stage 6: Upserted %d technique page vectors for video_id=%s",
                    len(page_vectors), video_id,
                )

        # ── Key moments ──────────────────────────────────────────────────
        if moments:
            # page id -> slug, so each moment payload can carry its page slug.
            page_id_to_slug: dict[str, str] = {
                str(tech_page.id): tech_page.slug for tech_page in pages
            }

            moment_texts = []
            moment_payloads = []
            for moment in moments:
                moment_text = f"{moment.title} {moment.summary or ''}"
                moment_texts.append(moment_text.strip())
                linked_page_id = str(moment.technique_page_id) if moment.technique_page_id else ""
                moment_payloads.append({
                    "moment_id": str(moment.id),
                    "source_video_id": str(moment.source_video_id),
                    "technique_page_id": linked_page_id,
                    "technique_page_slug": page_id_to_slug.get(linked_page_id, ""),
                    "title": moment.title,
                    "start_time": moment.start_time,
                    "end_time": moment.end_time,
                    "content_type": moment.content_type.value,
                })

            moment_vectors = embed_client.embed(moment_texts)
            if not moment_vectors:
                logger.warning(
                    "Stage 6: Embedding returned empty for %d key moments (video_id=%s). "
                    "Skipping moment upsert.",
                    len(moment_texts), video_id,
                )
            else:
                qdrant.upsert_key_moments(moment_payloads, moment_vectors)
                logger.info(
                    "Stage 6: Upserted %d key moment vectors for video_id=%s",
                    len(moment_vectors), video_id,
                )

        elapsed = time.monotonic() - started_at
        logger.info(
            "Stage 6 (embed & index) completed for video_id=%s in %.1fs — "
            "%d pages, %d moments processed",
            video_id, elapsed, len(pages), len(moments),
        )
        if run_id:
            _finish_run(run_id, "complete")
        return video_id

    except Exception as exc:
        # Deliberately non-fatal: log and let the pipeline finish.
        logger.error(
            "Stage 6 failed for video_id=%s: %s. "
            "Pipeline continues — embeddings can be regenerated later.",
            video_id, exc,
        )
        if run_id:
            # The run still counts as "complete" because stage 6 is best-effort.
            _finish_run(run_id, "complete")
        return video_id
    finally:
        session.close()
|
|
|
|
|
|
|
|
|
|
def _snapshot_prior_pages(video_id: str) -> None:
    """Remember which technique pages this video's key moments currently link to.

    The distinct, non-null ``technique_page_id`` values of the video's key
    moments are stored in Redis under ``chrysopedia:prior_pages:<video_id>``
    as a JSON list with a 24-hour expiry, where ``_load_prior_pages`` can
    read them back. Nothing is written when no page is linked.
    """
    import redis

    session = _get_sync_session()
    try:
        linked_page_query = (
            select(KeyMoment.technique_page_id)
            .where(
                KeyMoment.source_video_id == video_id,
                KeyMoment.technique_page_id.isnot(None),
            )
            .distinct()
        )
        page_ids = [str(pid) for pid in session.execute(linked_page_query).scalars().all()]

        if not page_ids:
            logger.info("No prior technique pages for video_id=%s", video_id)
            return

        client = redis.Redis.from_url(get_settings().redis_url)
        # ex=86400: the snapshot expires after one day.
        client.set(f"chrysopedia:prior_pages:{video_id}", json.dumps(page_ids), ex=86400)
        logger.info(
            "Snapshot %d prior technique pages for video_id=%s: %s",
            len(page_ids), video_id, page_ids,
        )
    finally:
        session.close()
|
|
|
|
|
|
def _load_prior_pages(video_id: str) -> list[str]:
    """Return the technique page IDs stored in Redis for this video.

    Reads the JSON value under ``chrysopedia:prior_pages:<video_id>``;
    an absent key yields an empty list.
    """
    import redis

    client = redis.Redis.from_url(get_settings().redis_url)
    payload = client.get(f"chrysopedia:prior_pages:{video_id}")
    return [] if payload is None else json.loads(payload)
|
|
|
|
|
|
# ── Stage completion detection for auto-resume ───────────────────────────────
|
|
|
|
# Stage name -> Celery task, declared in execution order.
_STAGE_TASKS = {
    "stage2_segmentation": stage2_segmentation,
    "stage3_extraction": stage3_extraction,
    "stage4_classification": stage4_classification,
    "stage5_synthesis": stage5_synthesis,
    "stage6_embed_and_index": stage6_embed_and_index,
}

# Ordered list of pipeline stages for resumability logic. Derived from the
# mapping above (dicts keep insertion order) so the two cannot drift apart.
_PIPELINE_STAGES = list(_STAGE_TASKS)
|
|
|
|
|
|
def _get_last_completed_stage(video_id: str) -> str | None:
    """Return the last stage of the video's most recent run that completed in sequence.

    Only the most recent PipelineRun (by ``started_at``) is inspected. Its
    'complete' events are matched against ``_PIPELINE_STAGES`` in order, and
    the scan stops at the first stage without such an event. Returns the
    stage name (e.g. 'stage3_extraction'), or None when there is no run or
    the first stage has not completed.
    """
    session = _get_sync_session()
    try:
        from models import PipelineRun

        newest_run_query = (
            select(PipelineRun)
            .where(PipelineRun.video_id == video_id)
            .order_by(PipelineRun.started_at.desc())
            .limit(1)
        )
        latest_run = session.execute(newest_run_query).scalar_one_or_none()
        if latest_run is None:
            return None

        # Stages of that run that emitted a 'complete' event.
        complete_query = select(PipelineEvent.stage).where(
            PipelineEvent.run_id == str(latest_run.id),
            PipelineEvent.event_type == "complete",
        )
        completed_set = set(session.execute(complete_query).scalars().all())

        # Walk the stages in pipeline order; the first gap ends the scan
        # because stages must complete sequentially.
        last_completed = None
        for stage_name in _PIPELINE_STAGES:
            if stage_name not in completed_set:
                break
            last_completed = stage_name

        if last_completed:
            logger.info(
                "Auto-resume: video_id=%s last completed stage=%s (run_id=%s)",
                video_id, last_completed, latest_run.id,
            )
        return last_completed
    finally:
        session.close()
|
|
|
|
|
|
# ── Orchestrator ─────────────────────────────────────────────────────────────
|
|
|
|
@celery_app.task
def mark_pipeline_error(request, exc, traceback, video_id: str, run_id: str | None = None) -> None:
    """Error callback: record a failed pipeline on the video and, if given, on the run."""
    logger.error("Pipeline failed for video_id=%s: %s", video_id, exc)
    _set_error_status(video_id, "pipeline", exc)
    if not run_id:
        return
    _finish_run(run_id, "error", error_stage="pipeline")
|
|
|
|
|
|
def _create_run(video_id: str, trigger: str) -> str:
    """Insert a new PipelineRun for the video and return its id as a string.

    ``run_number`` is one more than the highest existing run number for the
    video (1 when there is none). ``trigger`` is converted through
    ``PipelineRunTrigger``.
    """
    from models import PipelineRun, PipelineRunTrigger

    session = _get_sync_session()
    try:
        # Highest run_number so far; coalesce covers the "no runs yet" case.
        highest_query = select(
            func.coalesce(func.max(PipelineRun.run_number), 0)
        ).where(PipelineRun.video_id == video_id)
        highest = session.execute(highest_query).scalar() or 0

        new_run = PipelineRun(
            video_id=video_id,
            run_number=highest + 1,
            trigger=PipelineRunTrigger(trigger),
        )
        session.add(new_run)
        session.commit()
        return str(new_run.id)
    finally:
        session.close()
|
|
|
|
|
|
def _finish_run(run_id: str, status: str, error_stage: str | None = None) -> None:
    """Set status, finished_at, and total token count on a PipelineRun.

    ``total_tokens`` is the sum of ``PipelineEvent.total_tokens`` over the
    run's events. A missing run is ignored, and any failure is logged as a
    warning instead of being raised.
    """
    from models import PipelineRun, PipelineRunStatus, _now

    session = _get_sync_session()
    try:
        run = session.execute(
            select(PipelineRun).where(PipelineRun.id == run_id)
        ).scalar_one_or_none()
        if not run:
            return

        run.status = PipelineRunStatus(status)
        run.finished_at = _now()
        if error_stage:
            run.error_stage = error_stage

        # Aggregate total tokens from the run's events.
        token_sum_query = select(
            func.coalesce(func.sum(PipelineEvent.total_tokens), 0)
        ).where(PipelineEvent.run_id == run_id)
        run.total_tokens = session.execute(token_sum_query).scalar() or 0
        session.commit()
    except Exception as exc:
        logger.warning("Failed to finish run %s: %s", run_id, exc)
    finally:
        session.close()
|
|
|
|
|
|
@celery_app.task
def run_pipeline(video_id: str, trigger: str = "manual") -> str:
    """Orchestrate the full pipeline (stages 2-6) with auto-resume.

    For videos in error/processing status, the last stage that completed in
    the previous run is looked up and the pipeline resumes from the stage
    after it, which avoids re-running LLM stages that already succeeded.
    With the ``clean_reprocess`` trigger the pipeline always starts from
    stage 2. Videos already at ``complete`` are left alone.

    The resume point is resolved before the new PipelineRun is created:
    ``_get_last_completed_stage`` inspects the most recent run, so creating
    the new (event-less) run first would hide the previous run's progress.

    Raises:
        ValueError: if no video with ``video_id`` exists.

    Returns the video_id.
    """
    logger.info("run_pipeline starting for video_id=%s", video_id)

    session = _get_sync_session()
    try:
        video = session.execute(
            select(SourceVideo).where(SourceVideo.id == video_id)
        ).scalar_one_or_none()

        if video is None:
            logger.error("run_pipeline: video_id=%s not found", video_id)
            raise ValueError(f"Video not found: {video_id}")

        status = video.processing_status
        logger.info(
            "run_pipeline: video_id=%s current status=%s", video_id, status.value
        )
    finally:
        session.close()

    if status == ProcessingStatus.complete:
        logger.info(
            "run_pipeline: video_id=%s already at status=%s, nothing to do.",
            video_id, status.value,
        )
        return video_id

    # Determine which stages to run. This must happen before _create_run(),
    # otherwise the freshly created run becomes the "most recent" one and
    # no completed stage is ever found.
    resume_from_idx = 0  # Default: start from stage 2

    if trigger != "clean_reprocess" and status in (ProcessingStatus.processing, ProcessingStatus.error):
        # Try to resume from where we left off
        last_completed = _get_last_completed_stage(video_id)
        if last_completed and last_completed in _PIPELINE_STAGES:
            resume_from_idx = _PIPELINE_STAGES.index(last_completed) + 1
            if resume_from_idx >= len(_PIPELINE_STAGES):
                # Nothing left to do; return before a run record is created
                # so no run is left unfinished.
                logger.info(
                    "run_pipeline: all stages already completed for video_id=%s",
                    video_id,
                )
                return video_id

    # Snapshot prior technique pages before pipeline resets key_moments
    _snapshot_prior_pages(video_id)

    # Create a pipeline run record
    run_id = _create_run(video_id, trigger)
    logger.info("run_pipeline: created run_id=%s for video_id=%s (trigger=%s)", run_id, video_id, trigger)

    stages_to_run = _PIPELINE_STAGES[resume_from_idx:]
    logger.info(
        "run_pipeline: video_id=%s will run stages: %s (resume_from_idx=%d)",
        video_id, stages_to_run, resume_from_idx,
    )

    # Build the Celery chain — first stage gets video_id as arg,
    # subsequent stages receive it from the previous stage's return value
    celery_sigs = []
    for i, stage_name in enumerate(stages_to_run):
        task_func = _STAGE_TASKS[stage_name]
        if i == 0:
            celery_sigs.append(task_func.s(video_id, run_id=run_id))
        else:
            celery_sigs.append(task_func.s(run_id=run_id))

    if celery_sigs:
        # Mark as processing before dispatching
        session = _get_sync_session()
        try:
            video = session.execute(
                select(SourceVideo).where(SourceVideo.id == video_id)
            ).scalar_one()
            video.processing_status = ProcessingStatus.processing
            session.commit()
        finally:
            session.close()

        pipeline = celery_chain(*celery_sigs)
        error_cb = mark_pipeline_error.s(video_id, run_id=run_id)
        pipeline.apply_async(link_error=error_cb)
        logger.info(
            "run_pipeline: dispatched %d stages for video_id=%s (run_id=%s, starting at %s)",
            len(celery_sigs), video_id, run_id, stages_to_run[0],
        )

    return video_id
|