Two fixes: 1. page_moment_indices was referenced before assignment in the page persist loop — moved assignment to top of loop body. This caused "cannot access local variable" errors on every stage 5 run. 2. Stage 5 now catches LLMTruncationError and splits the chunk in half for retry, instead of blindly retrying the same oversized prompt. This handles categories where synthesis output exceeds the model context window. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1734 lines
66 KiB
Python
1734 lines
66 KiB
Python
"""Pipeline stage tasks (stages 2-5) and run_pipeline orchestrator.
|
|
|
|
Each stage reads from PostgreSQL via sync SQLAlchemy, loads its prompt
|
|
template from disk, calls the LLM client, parses the response, writes
|
|
results back, and updates processing_status on SourceVideo.
|
|
|
|
Celery tasks are synchronous — all DB access uses ``sqlalchemy.orm.Session``.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import hashlib
|
|
import json
|
|
import logging
|
|
import subprocess
|
|
import time
|
|
from collections import defaultdict
|
|
from pathlib import Path
|
|
|
|
import yaml
|
|
from celery import chain as celery_chain
|
|
from pydantic import ValidationError
|
|
from sqlalchemy import create_engine, func, select
|
|
from sqlalchemy.orm import Session, sessionmaker
|
|
|
|
from config import get_settings
|
|
from models import (
|
|
Creator,
|
|
KeyMoment,
|
|
KeyMomentContentType,
|
|
PipelineEvent,
|
|
ProcessingStatus,
|
|
SourceVideo,
|
|
TechniquePage,
|
|
TechniquePageVersion,
|
|
TranscriptSegment,
|
|
)
|
|
from pipeline.embedding_client import EmbeddingClient
|
|
from pipeline.llm_client import LLMClient, LLMResponse, estimate_max_tokens
|
|
from pipeline.qdrant_client import QdrantManager
|
|
from pipeline.schemas import (
|
|
ClassificationResult,
|
|
ExtractionResult,
|
|
SegmentationResult,
|
|
SynthesisResult,
|
|
)
|
|
from worker import celery_app
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class LLMTruncationError(RuntimeError):
|
|
"""Raised when the LLM response was truncated (finish_reason=length)."""
|
|
pass
|
|
|
|
|
|
# ── Error status helper ──────────────────────────────────────────────────────
|
|
|
|
def _set_error_status(video_id: str, stage_name: str, error: Exception) -> None:
|
|
"""Mark a video as errored when a pipeline stage fails permanently."""
|
|
try:
|
|
session = _get_sync_session()
|
|
video = session.execute(
|
|
select(SourceVideo).where(SourceVideo.id == video_id)
|
|
).scalar_one_or_none()
|
|
if video:
|
|
video.processing_status = ProcessingStatus.error
|
|
session.commit()
|
|
session.close()
|
|
except Exception as mark_exc:
|
|
logger.error(
|
|
"Failed to mark video_id=%s as error after %s failure: %s",
|
|
video_id, stage_name, mark_exc,
|
|
)
|
|
|
|
|
|
# ── Pipeline event persistence ───────────────────────────────────────────────
|
|
|
|
def _emit_event(
|
|
video_id: str,
|
|
stage: str,
|
|
event_type: str,
|
|
*,
|
|
run_id: str | None = None,
|
|
prompt_tokens: int | None = None,
|
|
completion_tokens: int | None = None,
|
|
total_tokens: int | None = None,
|
|
model: str | None = None,
|
|
duration_ms: int | None = None,
|
|
payload: dict | None = None,
|
|
system_prompt_text: str | None = None,
|
|
user_prompt_text: str | None = None,
|
|
response_text: str | None = None,
|
|
) -> None:
|
|
"""Persist a pipeline event to the DB. Best-effort -- failures logged, not raised."""
|
|
try:
|
|
session = _get_sync_session()
|
|
try:
|
|
event = PipelineEvent(
|
|
video_id=video_id,
|
|
run_id=run_id,
|
|
stage=stage,
|
|
event_type=event_type,
|
|
prompt_tokens=prompt_tokens,
|
|
completion_tokens=completion_tokens,
|
|
total_tokens=total_tokens,
|
|
model=model,
|
|
duration_ms=duration_ms,
|
|
payload=payload,
|
|
system_prompt_text=system_prompt_text,
|
|
user_prompt_text=user_prompt_text,
|
|
response_text=response_text,
|
|
)
|
|
session.add(event)
|
|
session.commit()
|
|
finally:
|
|
session.close()
|
|
except Exception as exc:
|
|
logger.warning("Failed to emit pipeline event: %s", exc)
|
|
|
|
|
|
def _is_debug_mode() -> bool:
|
|
"""Check if debug mode is enabled via Redis. Falls back to config setting."""
|
|
try:
|
|
import redis
|
|
settings = get_settings()
|
|
r = redis.from_url(settings.redis_url)
|
|
val = r.get("chrysopedia:debug_mode")
|
|
r.close()
|
|
if val is not None:
|
|
return val.decode().lower() == "true"
|
|
except Exception:
|
|
pass
|
|
return getattr(get_settings(), "debug_mode", False)
|
|
|
|
|
|
def _make_llm_callback(
|
|
video_id: str,
|
|
stage: str,
|
|
system_prompt: str | None = None,
|
|
user_prompt: str | None = None,
|
|
run_id: str | None = None,
|
|
context_label: str | None = None,
|
|
):
|
|
"""Create an on_complete callback for LLMClient that emits llm_call events.
|
|
|
|
When debug mode is enabled, captures full system prompt, user prompt,
|
|
and response text on each llm_call event.
|
|
"""
|
|
debug = _is_debug_mode()
|
|
|
|
def callback(*, model=None, prompt_tokens=None, completion_tokens=None,
|
|
total_tokens=None, content=None, finish_reason=None,
|
|
is_fallback=False, **_kwargs):
|
|
# Truncate content for storage — keep first 2000 chars for debugging
|
|
truncated = content[:2000] if content and len(content) > 2000 else content
|
|
_emit_event(
|
|
video_id=video_id,
|
|
stage=stage,
|
|
event_type="llm_call",
|
|
run_id=run_id,
|
|
model=model,
|
|
prompt_tokens=prompt_tokens,
|
|
completion_tokens=completion_tokens,
|
|
total_tokens=total_tokens,
|
|
payload={
|
|
"content_preview": truncated,
|
|
"content_length": len(content) if content else 0,
|
|
"finish_reason": finish_reason,
|
|
"is_fallback": is_fallback,
|
|
**({"context": context_label} if context_label else {}),
|
|
},
|
|
system_prompt_text=system_prompt if debug else None,
|
|
user_prompt_text=user_prompt if debug else None,
|
|
response_text=content if debug else None,
|
|
)
|
|
return callback
|
|
|
|
# ── Helpers ──────────────────────────────────────────────────────────────────
|
|
|
|
_engine = None
|
|
_SessionLocal = None
|
|
|
|
|
|
def _get_sync_engine():
|
|
"""Create a sync SQLAlchemy engine, converting the async URL if needed."""
|
|
global _engine
|
|
if _engine is None:
|
|
settings = get_settings()
|
|
url = settings.database_url
|
|
# Convert async driver to sync driver
|
|
url = url.replace("postgresql+asyncpg://", "postgresql+psycopg2://")
|
|
_engine = create_engine(url, pool_pre_ping=True, pool_size=5, max_overflow=10)
|
|
return _engine
|
|
|
|
|
|
def _get_sync_session() -> Session:
|
|
"""Create a sync SQLAlchemy session for Celery tasks."""
|
|
global _SessionLocal
|
|
if _SessionLocal is None:
|
|
_SessionLocal = sessionmaker(bind=_get_sync_engine())
|
|
return _SessionLocal()
|
|
|
|
|
|
def _load_prompt(template_name: str) -> str:
|
|
"""Read a prompt template from the prompts directory.
|
|
|
|
Raises FileNotFoundError if the template does not exist.
|
|
"""
|
|
settings = get_settings()
|
|
path = Path(settings.prompts_path) / template_name
|
|
if not path.exists():
|
|
logger.error("Prompt template not found: %s", path)
|
|
raise FileNotFoundError(f"Prompt template not found: {path}")
|
|
return path.read_text(encoding="utf-8")
|
|
|
|
|
|
def _get_llm_client() -> LLMClient:
|
|
"""Return an LLMClient configured from settings."""
|
|
return LLMClient(get_settings())
|
|
|
|
|
|
def _get_stage_config(stage_num: int) -> tuple[str | None, str]:
|
|
"""Return (model_override, modality) for a pipeline stage.
|
|
|
|
Reads stage-specific config from Settings. If the stage-specific model
|
|
is None/empty, returns None (LLMClient will use its default). If the
|
|
stage-specific modality is unset, defaults to "chat".
|
|
"""
|
|
settings = get_settings()
|
|
model = getattr(settings, f"llm_stage{stage_num}_model", None) or None
|
|
modality = getattr(settings, f"llm_stage{stage_num}_modality", None) or "chat"
|
|
return model, modality
|
|
|
|
|
|
def _load_canonical_tags() -> dict:
|
|
"""Load canonical tag taxonomy from config/canonical_tags.yaml."""
|
|
# Walk up from backend/ to find config/
|
|
candidates = [
|
|
Path("config/canonical_tags.yaml"),
|
|
Path("../config/canonical_tags.yaml"),
|
|
]
|
|
for candidate in candidates:
|
|
if candidate.exists():
|
|
with open(candidate, encoding="utf-8") as f:
|
|
return yaml.safe_load(f)
|
|
raise FileNotFoundError(
|
|
"canonical_tags.yaml not found. Searched: " + ", ".join(str(c) for c in candidates)
|
|
)
|
|
|
|
|
|
def _format_taxonomy_for_prompt(tags_data: dict) -> str:
|
|
"""Format the canonical tags taxonomy as readable text for the LLM prompt."""
|
|
lines = []
|
|
for cat in tags_data.get("categories", []):
|
|
lines.append(f"Category: {cat['name']}")
|
|
lines.append(f" Description: {cat['description']}")
|
|
lines.append(f" Sub-topics: {', '.join(cat.get('sub_topics', []))}")
|
|
lines.append("")
|
|
return "\n".join(lines)
|
|
|
|
|
|
def _safe_parse_llm_response(
|
|
raw,
|
|
model_cls,
|
|
llm: LLMClient,
|
|
system_prompt: str,
|
|
user_prompt: str,
|
|
modality: str = "chat",
|
|
model_override: str | None = None,
|
|
max_tokens: int | None = None,
|
|
):
|
|
"""Parse LLM response with truncation detection and one retry on failure.
|
|
|
|
If the response was truncated (finish_reason=length), raises
|
|
LLMTruncationError immediately — retrying with a JSON nudge would only
|
|
make things worse by adding tokens to an already-too-large prompt.
|
|
|
|
For non-truncation parse failures: retry once with a JSON nudge, then
|
|
raise on second failure.
|
|
"""
|
|
# Check for truncation before attempting parse
|
|
is_truncated = isinstance(raw, LLMResponse) and raw.truncated
|
|
if is_truncated:
|
|
logger.warning(
|
|
"LLM response truncated (finish=length) for %s. "
|
|
"prompt_tokens=%s, completion_tokens=%s. Will not retry with nudge.",
|
|
model_cls.__name__,
|
|
getattr(raw, "prompt_tokens", "?"),
|
|
getattr(raw, "completion_tokens", "?"),
|
|
)
|
|
|
|
try:
|
|
return llm.parse_response(raw, model_cls)
|
|
except (ValidationError, ValueError, json.JSONDecodeError) as exc:
|
|
if is_truncated:
|
|
raise LLMTruncationError(
|
|
f"LLM output truncated for {model_cls.__name__}: "
|
|
f"prompt_tokens={getattr(raw, 'prompt_tokens', '?')}, "
|
|
f"completion_tokens={getattr(raw, 'completion_tokens', '?')}. "
|
|
f"Response too large for model context window."
|
|
) from exc
|
|
|
|
logger.warning(
|
|
"First parse attempt failed for %s (%s). Retrying with JSON nudge. "
|
|
"Raw response (first 500 chars): %.500s",
|
|
model_cls.__name__,
|
|
type(exc).__name__,
|
|
raw,
|
|
)
|
|
# Retry with explicit JSON instruction
|
|
nudge_prompt = user_prompt + "\n\nIMPORTANT: Output ONLY valid JSON. No markdown, no explanation."
|
|
retry_raw = llm.complete(
|
|
system_prompt, nudge_prompt, response_model=model_cls,
|
|
modality=modality, model_override=model_override,
|
|
max_tokens=max_tokens,
|
|
)
|
|
return llm.parse_response(retry_raw, model_cls)
|
|
|
|
|
|
# ── Stage 2: Segmentation ───────────────────────────────────────────────────
|
|
|
|
@celery_app.task(bind=True, max_retries=3, default_retry_delay=30)
|
|
def stage2_segmentation(self, video_id: str, run_id: str | None = None) -> str:
|
|
"""Analyze transcript segments and identify topic boundaries.
|
|
|
|
Loads all TranscriptSegment rows for the video, sends them to the LLM
|
|
for topic boundary detection, and updates topic_label on each segment.
|
|
|
|
Returns the video_id for chain compatibility.
|
|
"""
|
|
start = time.monotonic()
|
|
logger.info("Stage 2 (segmentation) starting for video_id=%s", video_id)
|
|
_emit_event(video_id, "stage2_segmentation", "start", run_id=run_id)
|
|
|
|
session = _get_sync_session()
|
|
try:
|
|
# Load segments ordered by index
|
|
segments = (
|
|
session.execute(
|
|
select(TranscriptSegment)
|
|
.where(TranscriptSegment.source_video_id == video_id)
|
|
.order_by(TranscriptSegment.segment_index)
|
|
)
|
|
.scalars()
|
|
.all()
|
|
)
|
|
|
|
if not segments:
|
|
logger.info("Stage 2: No segments found for video_id=%s, skipping.", video_id)
|
|
return video_id
|
|
|
|
# Build transcript text with indices for the LLM
|
|
transcript_lines = []
|
|
for seg in segments:
|
|
transcript_lines.append(
|
|
f"[{seg.segment_index}] ({seg.start_time:.1f}s - {seg.end_time:.1f}s) {seg.text}"
|
|
)
|
|
transcript_text = "\n".join(transcript_lines)
|
|
|
|
# Load prompt and call LLM
|
|
system_prompt = _load_prompt("stage2_segmentation.txt")
|
|
user_prompt = f"<transcript>\n{transcript_text}\n</transcript>"
|
|
|
|
llm = _get_llm_client()
|
|
model_override, modality = _get_stage_config(2)
|
|
hard_limit = get_settings().llm_max_tokens_hard_limit
|
|
max_tokens = estimate_max_tokens(system_prompt, user_prompt, stage="stage2_segmentation", hard_limit=hard_limit)
|
|
logger.info("Stage 2 using model=%s, modality=%s, max_tokens=%d", model_override or "default", modality, max_tokens)
|
|
raw = llm.complete(system_prompt, user_prompt, response_model=SegmentationResult, on_complete=_make_llm_callback(video_id, "stage2_segmentation", system_prompt=system_prompt, user_prompt=user_prompt, run_id=run_id),
|
|
modality=modality, model_override=model_override, max_tokens=max_tokens)
|
|
result = _safe_parse_llm_response(raw, SegmentationResult, llm, system_prompt, user_prompt,
|
|
modality=modality, model_override=model_override, max_tokens=max_tokens)
|
|
|
|
# Update topic_label on each segment row
|
|
seg_by_index = {s.segment_index: s for s in segments}
|
|
for topic_seg in result.segments:
|
|
for idx in range(topic_seg.start_index, topic_seg.end_index + 1):
|
|
if idx in seg_by_index:
|
|
seg_by_index[idx].topic_label = topic_seg.topic_label
|
|
|
|
session.commit()
|
|
elapsed = time.monotonic() - start
|
|
_emit_event(video_id, "stage2_segmentation", "complete", run_id=run_id)
|
|
logger.info(
|
|
"Stage 2 (segmentation) completed for video_id=%s in %.1fs — %d topic groups found",
|
|
video_id, elapsed, len(result.segments),
|
|
)
|
|
return video_id
|
|
|
|
except FileNotFoundError:
|
|
raise # Don't retry missing prompt files
|
|
except Exception as exc:
|
|
session.rollback()
|
|
_emit_event(video_id, "stage2_segmentation", "error", run_id=run_id, payload={"error": str(exc)})
|
|
logger.error("Stage 2 failed for video_id=%s: %s", video_id, exc)
|
|
raise self.retry(exc=exc)
|
|
finally:
|
|
session.close()
|
|
|
|
|
|
# ── Stage 3: Extraction ─────────────────────────────────────────────────────
|
|
|
|
@celery_app.task(bind=True, max_retries=3, default_retry_delay=30)
|
|
def stage3_extraction(self, video_id: str, run_id: str | None = None) -> str:
|
|
"""Extract key moments from each topic segment group.
|
|
|
|
Groups segments by topic_label, calls the LLM for each group to extract
|
|
moments, creates KeyMoment rows, and sets processing_status=extracted.
|
|
|
|
Returns the video_id for chain compatibility.
|
|
"""
|
|
start = time.monotonic()
|
|
logger.info("Stage 3 (extraction) starting for video_id=%s", video_id)
|
|
_emit_event(video_id, "stage3_extraction", "start", run_id=run_id)
|
|
|
|
session = _get_sync_session()
|
|
try:
|
|
# Load segments with topic labels
|
|
segments = (
|
|
session.execute(
|
|
select(TranscriptSegment)
|
|
.where(TranscriptSegment.source_video_id == video_id)
|
|
.order_by(TranscriptSegment.segment_index)
|
|
)
|
|
.scalars()
|
|
.all()
|
|
)
|
|
|
|
if not segments:
|
|
logger.info("Stage 3: No segments found for video_id=%s, skipping.", video_id)
|
|
return video_id
|
|
|
|
# Group segments by topic_label
|
|
groups: dict[str, list[TranscriptSegment]] = defaultdict(list)
|
|
for seg in segments:
|
|
label = seg.topic_label or "unlabeled"
|
|
groups[label].append(seg)
|
|
|
|
system_prompt = _load_prompt("stage3_extraction.txt")
|
|
llm = _get_llm_client()
|
|
model_override, modality = _get_stage_config(3)
|
|
hard_limit = get_settings().llm_max_tokens_hard_limit
|
|
logger.info("Stage 3 using model=%s, modality=%s", model_override or "default", modality)
|
|
total_moments = 0
|
|
|
|
for topic_label, group_segs in groups.items():
|
|
# Build segment text for this group
|
|
seg_lines = []
|
|
for seg in group_segs:
|
|
seg_lines.append(
|
|
f"({seg.start_time:.1f}s - {seg.end_time:.1f}s) {seg.text}"
|
|
)
|
|
segment_text = "\n".join(seg_lines)
|
|
|
|
user_prompt = (
|
|
f"Topic: {topic_label}\n\n"
|
|
f"<segment>\n{segment_text}\n</segment>"
|
|
)
|
|
|
|
max_tokens = estimate_max_tokens(system_prompt, user_prompt, stage="stage3_extraction", hard_limit=hard_limit)
|
|
raw = llm.complete(system_prompt, user_prompt, response_model=ExtractionResult, on_complete=_make_llm_callback(video_id, "stage3_extraction", system_prompt=system_prompt, user_prompt=user_prompt, run_id=run_id, context_label=topic_label),
|
|
modality=modality, model_override=model_override, max_tokens=max_tokens)
|
|
result = _safe_parse_llm_response(raw, ExtractionResult, llm, system_prompt, user_prompt,
|
|
modality=modality, model_override=model_override, max_tokens=max_tokens)
|
|
|
|
# Create KeyMoment rows
|
|
for moment in result.moments:
|
|
# Validate content_type against enum
|
|
try:
|
|
ct = KeyMomentContentType(moment.content_type)
|
|
except ValueError:
|
|
ct = KeyMomentContentType.technique
|
|
|
|
km = KeyMoment(
|
|
source_video_id=video_id,
|
|
title=moment.title,
|
|
summary=moment.summary,
|
|
start_time=moment.start_time,
|
|
end_time=moment.end_time,
|
|
content_type=ct,
|
|
plugins=moment.plugins if moment.plugins else None,
|
|
raw_transcript=moment.raw_transcript or None,
|
|
)
|
|
session.add(km)
|
|
total_moments += 1
|
|
|
|
session.commit()
|
|
elapsed = time.monotonic() - start
|
|
_emit_event(video_id, "stage3_extraction", "complete", run_id=run_id)
|
|
logger.info(
|
|
"Stage 3 (extraction) completed for video_id=%s in %.1fs — %d moments created",
|
|
video_id, elapsed, total_moments,
|
|
)
|
|
return video_id
|
|
|
|
except FileNotFoundError:
|
|
raise
|
|
except Exception as exc:
|
|
session.rollback()
|
|
_emit_event(video_id, "stage3_extraction", "error", run_id=run_id, payload={"error": str(exc)})
|
|
logger.error("Stage 3 failed for video_id=%s: %s", video_id, exc)
|
|
raise self.retry(exc=exc)
|
|
finally:
|
|
session.close()
|
|
|
|
|
|
# ── Stage 4: Classification ─────────────────────────────────────────────────
|
|
|
|
# Maximum moments per classification batch. Keeps each LLM call well within
|
|
# context window limits. Batches are classified independently and merged.
|
|
_STAGE4_BATCH_SIZE = 20
|
|
|
|
|
|
def _classify_moment_batch(
|
|
moments_batch: list,
|
|
batch_offset: int,
|
|
taxonomy_text: str,
|
|
system_prompt: str,
|
|
llm: LLMClient,
|
|
model_override: str | None,
|
|
modality: str,
|
|
hard_limit: int,
|
|
video_id: str,
|
|
run_id: str | None,
|
|
) -> ClassificationResult:
|
|
"""Classify a single batch of moments. Raises on failure."""
|
|
moments_lines = []
|
|
for i, m in enumerate(moments_batch):
|
|
moments_lines.append(
|
|
f"[{i}] Title: {m.title}\n"
|
|
f" Summary: {m.summary}\n"
|
|
f" Content type: {m.content_type.value}\n"
|
|
f" Plugins: {', '.join(m.plugins) if m.plugins else 'none'}"
|
|
)
|
|
moments_text = "\n\n".join(moments_lines)
|
|
|
|
user_prompt = (
|
|
f"<taxonomy>\n{taxonomy_text}\n</taxonomy>\n\n"
|
|
f"<moments>\n{moments_text}\n</moments>"
|
|
)
|
|
|
|
max_tokens = estimate_max_tokens(
|
|
system_prompt, user_prompt,
|
|
stage="stage4_classification", hard_limit=hard_limit,
|
|
)
|
|
batch_label = f"batch {batch_offset // _STAGE4_BATCH_SIZE + 1} (moments {batch_offset}-{batch_offset + len(moments_batch) - 1})"
|
|
logger.info(
|
|
"Stage 4 classifying %s, max_tokens=%d",
|
|
batch_label, max_tokens,
|
|
)
|
|
raw = llm.complete(
|
|
system_prompt, user_prompt,
|
|
response_model=ClassificationResult,
|
|
on_complete=_make_llm_callback(
|
|
video_id, "stage4_classification",
|
|
system_prompt=system_prompt, user_prompt=user_prompt,
|
|
run_id=run_id, context_label=batch_label,
|
|
),
|
|
modality=modality, model_override=model_override,
|
|
max_tokens=max_tokens,
|
|
)
|
|
return _safe_parse_llm_response(
|
|
raw, ClassificationResult, llm, system_prompt, user_prompt,
|
|
modality=modality, model_override=model_override,
|
|
max_tokens=max_tokens,
|
|
)
|
|
|
|
|
|
@celery_app.task(bind=True, max_retries=3, default_retry_delay=30)
|
|
def stage4_classification(self, video_id: str, run_id: str | None = None) -> str:
|
|
"""Classify key moments against the canonical tag taxonomy.
|
|
|
|
Loads all KeyMoment rows for the video, sends them to the LLM with the
|
|
canonical taxonomy, and stores classification results in Redis for
|
|
stage 5 consumption. Updates content_type if the classifier overrides it.
|
|
|
|
For large moment sets, automatically batches into groups of
|
|
_STAGE4_BATCH_SIZE to stay within model context window limits.
|
|
|
|
Stage 4 does NOT change processing_status.
|
|
|
|
Returns the video_id for chain compatibility.
|
|
"""
|
|
start = time.monotonic()
|
|
logger.info("Stage 4 (classification) starting for video_id=%s", video_id)
|
|
_emit_event(video_id, "stage4_classification", "start", run_id=run_id)
|
|
|
|
session = _get_sync_session()
|
|
try:
|
|
# Load key moments
|
|
moments = (
|
|
session.execute(
|
|
select(KeyMoment)
|
|
.where(KeyMoment.source_video_id == video_id)
|
|
.order_by(KeyMoment.start_time)
|
|
)
|
|
.scalars()
|
|
.all()
|
|
)
|
|
|
|
if not moments:
|
|
logger.info("Stage 4: No moments found for video_id=%s, skipping.", video_id)
|
|
_store_classification_data(video_id, [])
|
|
return video_id
|
|
|
|
# Load canonical tags
|
|
tags_data = _load_canonical_tags()
|
|
taxonomy_text = _format_taxonomy_for_prompt(tags_data)
|
|
|
|
system_prompt = _load_prompt("stage4_classification.txt")
|
|
llm = _get_llm_client()
|
|
model_override, modality = _get_stage_config(4)
|
|
hard_limit = get_settings().llm_max_tokens_hard_limit
|
|
|
|
# Batch moments for classification
|
|
all_classifications = []
|
|
for batch_start in range(0, len(moments), _STAGE4_BATCH_SIZE):
|
|
batch = moments[batch_start:batch_start + _STAGE4_BATCH_SIZE]
|
|
result = _classify_moment_batch(
|
|
batch, batch_start, taxonomy_text, system_prompt,
|
|
llm, model_override, modality, hard_limit,
|
|
video_id, run_id,
|
|
)
|
|
# Reindex: batch uses 0-based indices, remap to global indices
|
|
for cls in result.classifications:
|
|
cls.moment_index += batch_start
|
|
all_classifications.extend(result.classifications)
|
|
|
|
# Apply content_type overrides and prepare classification data for stage 5
|
|
classification_data = []
|
|
|
|
for cls in all_classifications:
|
|
if 0 <= cls.moment_index < len(moments):
|
|
moment = moments[cls.moment_index]
|
|
|
|
# Apply content_type override if provided
|
|
if cls.content_type_override:
|
|
try:
|
|
moment.content_type = KeyMomentContentType(cls.content_type_override)
|
|
except ValueError:
|
|
pass
|
|
|
|
classification_data.append({
|
|
"moment_id": str(moment.id),
|
|
"topic_category": cls.topic_category.strip().title(),
|
|
"topic_tags": cls.topic_tags,
|
|
})
|
|
|
|
session.commit()
|
|
|
|
# Store classification data in Redis for stage 5
|
|
_store_classification_data(video_id, classification_data)
|
|
|
|
elapsed = time.monotonic() - start
|
|
num_batches = (len(moments) + _STAGE4_BATCH_SIZE - 1) // _STAGE4_BATCH_SIZE
|
|
_emit_event(video_id, "stage4_classification", "complete", run_id=run_id)
|
|
logger.info(
|
|
"Stage 4 (classification) completed for video_id=%s in %.1fs — "
|
|
"%d moments classified in %d batch(es)",
|
|
video_id, elapsed, len(classification_data), num_batches,
|
|
)
|
|
return video_id
|
|
|
|
except FileNotFoundError:
|
|
raise
|
|
except Exception as exc:
|
|
session.rollback()
|
|
_emit_event(video_id, "stage4_classification", "error", run_id=run_id, payload={"error": str(exc)})
|
|
logger.error("Stage 4 failed for video_id=%s: %s", video_id, exc)
|
|
raise self.retry(exc=exc)
|
|
finally:
|
|
session.close()
|
|
|
|
|
|
def _store_classification_data(video_id: str, data: list[dict]) -> None:
|
|
"""Store classification data in Redis for cross-stage communication."""
|
|
import redis
|
|
|
|
settings = get_settings()
|
|
r = redis.Redis.from_url(settings.redis_url)
|
|
key = f"chrysopedia:classification:{video_id}"
|
|
r.set(key, json.dumps(data), ex=86400) # Expire after 24 hours
|
|
|
|
|
|
def _load_classification_data(video_id: str) -> list[dict]:
|
|
"""Load classification data from Redis."""
|
|
import redis
|
|
|
|
settings = get_settings()
|
|
r = redis.Redis.from_url(settings.redis_url)
|
|
key = f"chrysopedia:classification:{video_id}"
|
|
raw = r.get(key)
|
|
if raw is None:
|
|
return []
|
|
return json.loads(raw)
|
|
|
|
|
|
|
|
def _get_git_commit_sha() -> str:
|
|
"""Resolve the git commit SHA used to build this image.
|
|
|
|
Resolution order:
|
|
1. /app/.git-commit file (written during Docker build)
|
|
2. git rev-parse --short HEAD (local dev)
|
|
3. GIT_COMMIT_SHA env var / config setting
|
|
4. "unknown"
|
|
"""
|
|
# Docker build artifact
|
|
git_commit_file = Path("/app/.git-commit")
|
|
if git_commit_file.exists():
|
|
sha = git_commit_file.read_text(encoding="utf-8").strip()
|
|
if sha and sha != "unknown":
|
|
return sha
|
|
|
|
# Local dev — run git
|
|
try:
|
|
result = subprocess.run(
|
|
["git", "rev-parse", "--short", "HEAD"],
|
|
capture_output=True, text=True, timeout=5,
|
|
)
|
|
if result.returncode == 0 and result.stdout.strip():
|
|
return result.stdout.strip()
|
|
except (FileNotFoundError, subprocess.TimeoutExpired):
|
|
pass
|
|
|
|
# Config / env var fallback
|
|
try:
|
|
sha = get_settings().git_commit_sha
|
|
if sha and sha != "unknown":
|
|
return sha
|
|
except Exception:
|
|
pass
|
|
|
|
return "unknown"
|
|
|
|
def _capture_pipeline_metadata() -> dict:
|
|
"""Capture current pipeline configuration for version metadata.
|
|
|
|
Returns a dict with model names, prompt file SHA-256 hashes, and stage
|
|
modality settings. Handles missing prompt files gracefully.
|
|
"""
|
|
settings = get_settings()
|
|
prompts_path = Path(settings.prompts_path)
|
|
|
|
# Hash each prompt template file
|
|
prompt_hashes: dict[str, str] = {}
|
|
prompt_files = [
|
|
"stage2_segmentation.txt",
|
|
"stage3_extraction.txt",
|
|
"stage4_classification.txt",
|
|
"stage5_synthesis.txt",
|
|
]
|
|
for filename in prompt_files:
|
|
filepath = prompts_path / filename
|
|
try:
|
|
content = filepath.read_bytes()
|
|
prompt_hashes[filename] = hashlib.sha256(content).hexdigest()
|
|
except FileNotFoundError:
|
|
logger.warning("Prompt file not found for metadata capture: %s", filepath)
|
|
prompt_hashes[filename] = ""
|
|
except OSError as exc:
|
|
logger.warning("Could not read prompt file %s: %s", filepath, exc)
|
|
prompt_hashes[filename] = ""
|
|
|
|
return {
|
|
"git_commit_sha": _get_git_commit_sha(),
|
|
"models": {
|
|
"stage2": settings.llm_stage2_model,
|
|
"stage3": settings.llm_stage3_model,
|
|
"stage4": settings.llm_stage4_model,
|
|
"stage5": settings.llm_stage5_model,
|
|
"embedding": settings.embedding_model,
|
|
},
|
|
"modalities": {
|
|
"stage2": settings.llm_stage2_modality,
|
|
"stage3": settings.llm_stage3_modality,
|
|
"stage4": settings.llm_stage4_modality,
|
|
"stage5": settings.llm_stage5_modality,
|
|
},
|
|
"prompt_hashes": prompt_hashes,
|
|
}
|
|
|
|
|
|
# ── Stage 5: Synthesis ───────────────────────────────────────────────────────
|
|
|
|
|
|
def _compute_page_tags(
|
|
moment_indices: list[int],
|
|
moment_group: list[tuple],
|
|
all_tags: set[str],
|
|
) -> list[str] | None:
|
|
"""Compute tags for a specific page from its linked moment indices.
|
|
|
|
If moment_indices are available, collects tags only from those moments.
|
|
Falls back to all_tags for the category group if no indices provided.
|
|
"""
|
|
if not moment_indices:
|
|
return list(all_tags) if all_tags else None
|
|
|
|
page_tags: set[str] = set()
|
|
for idx in moment_indices:
|
|
if 0 <= idx < len(moment_group):
|
|
_, cls_info = moment_group[idx]
|
|
page_tags.update(cls_info.get("topic_tags", []))
|
|
|
|
return list(page_tags) if page_tags else None
|
|
|
|
|
|
def _build_moments_text(
|
|
moment_group: list[tuple[KeyMoment, dict]],
|
|
category: str,
|
|
) -> tuple[str, set[str]]:
|
|
"""Build the moments prompt text and collect all tags for a group of moments.
|
|
|
|
Returns (moments_text, all_tags).
|
|
"""
|
|
moments_lines = []
|
|
all_tags: set[str] = set()
|
|
for i, (m, cls_info) in enumerate(moment_group):
|
|
tags = cls_info.get("topic_tags", [])
|
|
all_tags.update(tags)
|
|
moments_lines.append(
|
|
f"[{i}] Title: {m.title}\n"
|
|
f" Summary: {m.summary}\n"
|
|
f" Content type: {m.content_type.value}\n"
|
|
f" Time: {m.start_time:.1f}s - {m.end_time:.1f}s\n"
|
|
f" Plugins: {', '.join(m.plugins) if m.plugins else 'none'}\n"
|
|
f" Category: {category}\n"
|
|
f" Tags: {', '.join(tags) if tags else 'none'}\n"
|
|
f" Transcript excerpt: {(m.raw_transcript or '')[:300]}"
|
|
)
|
|
return "\n\n".join(moments_lines), all_tags
|
|
|
|
|
|
def _synthesize_chunk(
|
|
chunk: list[tuple[KeyMoment, dict]],
|
|
category: str,
|
|
creator_name: str,
|
|
system_prompt: str,
|
|
llm: LLMClient,
|
|
model_override: str | None,
|
|
modality: str,
|
|
hard_limit: int,
|
|
video_id: str,
|
|
run_id: str | None,
|
|
chunk_label: str,
|
|
) -> SynthesisResult:
|
|
"""Run a single synthesis LLM call for a chunk of moments.
|
|
|
|
Returns the parsed SynthesisResult.
|
|
"""
|
|
moments_text, _ = _build_moments_text(chunk, category)
|
|
user_prompt = f"<creator>{creator_name}</creator>\n<moments>\n{moments_text}\n</moments>"
|
|
|
|
estimated_input = estimate_max_tokens(system_prompt, user_prompt, stage="stage5_synthesis", hard_limit=hard_limit)
|
|
logger.info(
|
|
"Stage 5: Synthesizing %s — %d moments, max_tokens=%d",
|
|
chunk_label, len(chunk), estimated_input,
|
|
)
|
|
|
|
raw = llm.complete(
|
|
system_prompt, user_prompt, response_model=SynthesisResult,
|
|
on_complete=_make_llm_callback(
|
|
video_id, "stage5_synthesis",
|
|
system_prompt=system_prompt, user_prompt=user_prompt,
|
|
run_id=run_id, context_label=chunk_label,
|
|
),
|
|
modality=modality, model_override=model_override, max_tokens=estimated_input,
|
|
)
|
|
return _safe_parse_llm_response(
|
|
raw, SynthesisResult, llm, system_prompt, user_prompt,
|
|
modality=modality, model_override=model_override, max_tokens=estimated_input,
|
|
)
|
|
|
|
|
|
def _slug_base(slug: str) -> str:
|
|
"""Extract the slug prefix before the creator name suffix for merge grouping.
|
|
|
|
E.g. 'wavetable-sound-design-copycatt' → 'wavetable-sound-design'
|
|
Also normalizes casing.
|
|
"""
|
|
return slug.lower().strip()
|
|
|
|
|
|
def _merge_pages_by_slug(
|
|
all_pages: list,
|
|
creator_name: str,
|
|
llm: LLMClient,
|
|
model_override: str | None,
|
|
modality: str,
|
|
hard_limit: int,
|
|
video_id: str,
|
|
run_id: str | None,
|
|
) -> list:
|
|
"""Detect pages with the same slug across chunks and merge them via LLM.
|
|
|
|
Pages with unique slugs pass through unchanged. Pages sharing a slug
|
|
get sent to a merge prompt that combines them into one cohesive page.
|
|
|
|
Returns the final list of SynthesizedPage objects.
|
|
"""
|
|
from pipeline.schemas import SynthesizedPage
|
|
|
|
# Group pages by slug
|
|
by_slug: dict[str, list] = defaultdict(list)
|
|
for page in all_pages:
|
|
by_slug[_slug_base(page.slug)].append(page)
|
|
|
|
final_pages = []
|
|
for slug, pages_group in by_slug.items():
|
|
if len(pages_group) == 1:
|
|
# Unique slug — no merge needed
|
|
final_pages.append(pages_group[0])
|
|
continue
|
|
|
|
# Multiple pages share this slug — merge via LLM
|
|
logger.info(
|
|
"Stage 5: Merging %d partial pages with slug '%s' for video_id=%s",
|
|
len(pages_group), slug, video_id,
|
|
)
|
|
|
|
# Serialize partial pages to JSON for the merge prompt
|
|
pages_json = json.dumps(
|
|
[p.model_dump() for p in pages_group],
|
|
indent=2, ensure_ascii=False,
|
|
)
|
|
|
|
merge_system_prompt = _load_prompt("stage5_merge.txt")
|
|
merge_user_prompt = f"<creator>{creator_name}</creator>\n<pages>\n{pages_json}\n</pages>"
|
|
|
|
max_tokens = estimate_max_tokens(
|
|
merge_system_prompt, merge_user_prompt,
|
|
stage="stage5_synthesis", hard_limit=hard_limit,
|
|
)
|
|
logger.info(
|
|
"Stage 5: Merge call for slug '%s' — %d partial pages, max_tokens=%d",
|
|
slug, len(pages_group), max_tokens,
|
|
)
|
|
|
|
raw = llm.complete(
|
|
merge_system_prompt, merge_user_prompt,
|
|
response_model=SynthesisResult,
|
|
on_complete=_make_llm_callback(
|
|
video_id, "stage5_synthesis",
|
|
system_prompt=merge_system_prompt,
|
|
user_prompt=merge_user_prompt,
|
|
run_id=run_id, context_label=f"merge:{slug}",
|
|
),
|
|
modality=modality, model_override=model_override,
|
|
max_tokens=max_tokens,
|
|
)
|
|
merge_result = _safe_parse_llm_response(
|
|
raw, SynthesisResult, llm,
|
|
merge_system_prompt, merge_user_prompt,
|
|
modality=modality, model_override=model_override,
|
|
max_tokens=max_tokens,
|
|
)
|
|
|
|
if merge_result.pages:
|
|
final_pages.extend(merge_result.pages)
|
|
logger.info(
|
|
"Stage 5: Merge produced %d page(s) for slug '%s'",
|
|
len(merge_result.pages), slug,
|
|
)
|
|
else:
|
|
# Merge returned nothing — fall back to keeping the partials
|
|
logger.warning(
|
|
"Stage 5: Merge returned 0 pages for slug '%s', keeping %d partials",
|
|
slug, len(pages_group),
|
|
)
|
|
final_pages.extend(pages_group)
|
|
|
|
return final_pages
|
|
|
|
|
|
@celery_app.task(bind=True, max_retries=3, default_retry_delay=30)
|
|
def stage5_synthesis(self, video_id: str, run_id: str | None = None) -> str:
|
|
"""Synthesize technique pages from classified key moments.
|
|
|
|
Groups moments by (creator, topic_category), calls the LLM to synthesize
|
|
each group into a TechniquePage, creates/updates page rows, and links
|
|
KeyMoments to their TechniquePage.
|
|
|
|
For large category groups (exceeding synthesis_chunk_size), moments are
|
|
split into chronological chunks, synthesized independently, then pages
|
|
with matching slugs are merged via a dedicated merge LLM call.
|
|
|
|
Sets processing_status to 'complete'.
|
|
|
|
Returns the video_id for chain compatibility.
|
|
"""
|
|
start = time.monotonic()
|
|
logger.info("Stage 5 (synthesis) starting for video_id=%s", video_id)
|
|
_emit_event(video_id, "stage5_synthesis", "start", run_id=run_id)
|
|
|
|
settings = get_settings()
|
|
chunk_size = settings.synthesis_chunk_size
|
|
session = _get_sync_session()
|
|
try:
|
|
# Load video and moments
|
|
video = session.execute(
|
|
select(SourceVideo).where(SourceVideo.id == video_id)
|
|
).scalar_one()
|
|
|
|
moments = (
|
|
session.execute(
|
|
select(KeyMoment)
|
|
.where(KeyMoment.source_video_id == video_id)
|
|
.order_by(KeyMoment.start_time)
|
|
)
|
|
.scalars()
|
|
.all()
|
|
)
|
|
|
|
# Resolve creator name for the LLM prompt
|
|
creator = session.execute(
|
|
select(Creator).where(Creator.id == video.creator_id)
|
|
).scalar_one_or_none()
|
|
creator_name = creator.name if creator else "Unknown"
|
|
|
|
if not moments:
|
|
logger.info("Stage 5: No moments found for video_id=%s, skipping.", video_id)
|
|
return video_id
|
|
|
|
# Load classification data from stage 4
|
|
classification_data = _load_classification_data(video_id)
|
|
cls_by_moment_id = {c["moment_id"]: c for c in classification_data}
|
|
|
|
# Group moments by topic_category (from classification)
|
|
# Normalize category casing to prevent near-duplicate groups
|
|
# (e.g., "Sound design" vs "Sound Design")
|
|
groups: dict[str, list[tuple[KeyMoment, dict]]] = defaultdict(list)
|
|
for moment in moments:
|
|
cls_info = cls_by_moment_id.get(str(moment.id), {})
|
|
category = cls_info.get("topic_category", "Uncategorized").strip().title()
|
|
groups[category].append((moment, cls_info))
|
|
|
|
system_prompt = _load_prompt("stage5_synthesis.txt")
|
|
llm = _get_llm_client()
|
|
model_override, modality = _get_stage_config(5)
|
|
hard_limit = settings.llm_max_tokens_hard_limit
|
|
logger.info("Stage 5 using model=%s, modality=%s", model_override or "default", modality)
|
|
pages_created = 0
|
|
|
|
for category, moment_group in groups.items():
|
|
# Collect all tags across the full group (used for DB writes later)
|
|
all_tags: set[str] = set()
|
|
for _, cls_info in moment_group:
|
|
all_tags.update(cls_info.get("topic_tags", []))
|
|
|
|
# ── Chunked synthesis with truncation recovery ─────────
|
|
if len(moment_group) <= chunk_size:
|
|
# Small group — try single LLM call first
|
|
try:
|
|
result = _synthesize_chunk(
|
|
moment_group, category, creator_name,
|
|
system_prompt, llm, model_override, modality, hard_limit,
|
|
video_id, run_id, f"category:{category}",
|
|
)
|
|
synthesized_pages = list(result.pages)
|
|
logger.info(
|
|
"Stage 5: category '%s' — %d moments, %d page(s) from single call",
|
|
category, len(moment_group), len(synthesized_pages),
|
|
)
|
|
except LLMTruncationError:
|
|
# Output too large for model context — split in half and retry
|
|
logger.warning(
|
|
"Stage 5: category '%s' truncated with %d moments. "
|
|
"Splitting into smaller chunks and retrying.",
|
|
category, len(moment_group),
|
|
)
|
|
half = max(1, len(moment_group) // 2)
|
|
chunk_pages = []
|
|
for sub_start in range(0, len(moment_group), half):
|
|
sub_chunk = moment_group[sub_start:sub_start + half]
|
|
sub_label = f"category:{category} recovery-chunk:{sub_start // half + 1}"
|
|
sub_result = _synthesize_chunk(
|
|
sub_chunk, category, creator_name,
|
|
system_prompt, llm, model_override, modality, hard_limit,
|
|
video_id, run_id, sub_label,
|
|
)
|
|
# Reindex moment_indices to global offsets
|
|
for p in sub_result.pages:
|
|
if p.moment_indices:
|
|
p.moment_indices = [idx + sub_start for idx in p.moment_indices]
|
|
chunk_pages.extend(sub_result.pages)
|
|
synthesized_pages = chunk_pages
|
|
logger.info(
|
|
"Stage 5: category '%s' — %d page(s) from recovery split",
|
|
category, len(synthesized_pages),
|
|
)
|
|
else:
|
|
# Large group — split into chunks, synthesize each, then merge
|
|
num_chunks = (len(moment_group) + chunk_size - 1) // chunk_size
|
|
logger.info(
|
|
"Stage 5: category '%s' has %d moments — splitting into %d chunks of ≤%d",
|
|
category, len(moment_group), num_chunks, chunk_size,
|
|
)
|
|
|
|
chunk_pages = []
|
|
for chunk_idx in range(num_chunks):
|
|
chunk_start = chunk_idx * chunk_size
|
|
chunk_end = min(chunk_start + chunk_size, len(moment_group))
|
|
chunk = moment_group[chunk_start:chunk_end]
|
|
chunk_label = f"category:{category} chunk:{chunk_idx + 1}/{num_chunks}"
|
|
|
|
result = _synthesize_chunk(
|
|
chunk, category, creator_name,
|
|
system_prompt, llm, model_override, modality, hard_limit,
|
|
video_id, run_id, chunk_label,
|
|
)
|
|
chunk_pages.extend(result.pages)
|
|
logger.info(
|
|
"Stage 5: %s produced %d page(s)",
|
|
chunk_label, len(result.pages),
|
|
)
|
|
|
|
# Merge pages with matching slugs across chunks
|
|
logger.info(
|
|
"Stage 5: category '%s' — %d total pages from %d chunks, checking for merges",
|
|
category, len(chunk_pages), num_chunks,
|
|
)
|
|
synthesized_pages = _merge_pages_by_slug(
|
|
chunk_pages, creator_name,
|
|
llm, model_override, modality, hard_limit,
|
|
video_id, run_id,
|
|
)
|
|
logger.info(
|
|
"Stage 5: category '%s' — %d final page(s) after merge",
|
|
category, len(synthesized_pages),
|
|
)
|
|
|
|
# ── Persist pages to DB ──────────────────────────────────
|
|
# Load prior pages from this video (snapshot taken before pipeline reset)
|
|
prior_page_ids = _load_prior_pages(video_id)
|
|
|
|
for page_data in synthesized_pages:
|
|
page_moment_indices = getattr(page_data, "moment_indices", None) or []
|
|
existing = None
|
|
|
|
# First: check by slug (most specific match)
|
|
if existing is None:
|
|
existing = session.execute(
|
|
select(TechniquePage).where(TechniquePage.slug == page_data.slug)
|
|
).scalar_one_or_none()
|
|
|
|
# Fallback: check prior pages from this video by creator + category
|
|
# Use .first() since multiple pages may share a category
|
|
if existing is None and prior_page_ids:
|
|
existing = session.execute(
|
|
select(TechniquePage).where(
|
|
TechniquePage.id.in_(prior_page_ids),
|
|
TechniquePage.creator_id == video.creator_id,
|
|
func.lower(TechniquePage.topic_category) == func.lower(page_data.topic_category or category),
|
|
)
|
|
).scalars().first()
|
|
if existing:
|
|
logger.info(
|
|
"Stage 5: Matched prior page '%s' (id=%s) by creator+category for video_id=%s",
|
|
existing.slug, existing.id, video_id,
|
|
)
|
|
|
|
if existing:
|
|
# Snapshot existing content before overwriting
|
|
try:
|
|
sq = existing.source_quality
|
|
sq_value = sq.value if hasattr(sq, 'value') else sq
|
|
snapshot = {
|
|
"title": existing.title,
|
|
"slug": existing.slug,
|
|
"topic_category": existing.topic_category,
|
|
"topic_tags": existing.topic_tags,
|
|
"summary": existing.summary,
|
|
"body_sections": existing.body_sections,
|
|
"signal_chains": existing.signal_chains,
|
|
"plugins": existing.plugins,
|
|
"source_quality": sq_value,
|
|
}
|
|
version_count = session.execute(
|
|
select(func.count()).where(
|
|
TechniquePageVersion.technique_page_id == existing.id
|
|
)
|
|
).scalar()
|
|
version_number = version_count + 1
|
|
|
|
version = TechniquePageVersion(
|
|
technique_page_id=existing.id,
|
|
version_number=version_number,
|
|
content_snapshot=snapshot,
|
|
pipeline_metadata=_capture_pipeline_metadata(),
|
|
)
|
|
session.add(version)
|
|
logger.info(
|
|
"Version snapshot v%d created for page slug=%s",
|
|
version_number, existing.slug,
|
|
)
|
|
except Exception as snap_exc:
|
|
logger.error(
|
|
"Failed to create version snapshot for page slug=%s: %s",
|
|
existing.slug, snap_exc,
|
|
)
|
|
# Best-effort versioning — continue with page update
|
|
|
|
# Update existing page
|
|
existing.title = page_data.title
|
|
existing.summary = page_data.summary
|
|
existing.body_sections = page_data.body_sections
|
|
existing.signal_chains = page_data.signal_chains
|
|
existing.plugins = page_data.plugins if page_data.plugins else None
|
|
page_tags = _compute_page_tags(page_moment_indices, moment_group, all_tags)
|
|
existing.topic_tags = page_tags
|
|
existing.source_quality = page_data.source_quality
|
|
page = existing
|
|
else:
|
|
page = TechniquePage(
|
|
creator_id=video.creator_id,
|
|
title=page_data.title,
|
|
slug=page_data.slug,
|
|
topic_category=page_data.topic_category or category,
|
|
topic_tags=_compute_page_tags(page_moment_indices, moment_group, all_tags),
|
|
summary=page_data.summary,
|
|
body_sections=page_data.body_sections,
|
|
signal_chains=page_data.signal_chains,
|
|
plugins=page_data.plugins if page_data.plugins else None,
|
|
source_quality=page_data.source_quality,
|
|
)
|
|
session.add(page)
|
|
session.flush() # Get the page.id assigned
|
|
|
|
pages_created += 1
|
|
|
|
# Link moments to the technique page using moment_indices
|
|
|
|
if page_moment_indices:
|
|
# LLM specified which moments belong to this page
|
|
for idx in page_moment_indices:
|
|
if 0 <= idx < len(moment_group):
|
|
moment_group[idx][0].technique_page_id = page.id
|
|
elif len(synthesized_pages) == 1:
|
|
# Single page — link all moments (safe fallback)
|
|
for m, _ in moment_group:
|
|
m.technique_page_id = page.id
|
|
else:
|
|
# Multiple pages but no moment_indices — log warning
|
|
logger.warning(
|
|
"Stage 5: page '%s' has no moment_indices and is one of %d pages "
|
|
"for category '%s'. Moments will not be linked to this page.",
|
|
page_data.slug, len(synthesized_pages), category,
|
|
)
|
|
|
|
# Update processing_status
|
|
video.processing_status = ProcessingStatus.complete
|
|
|
|
session.commit()
|
|
elapsed = time.monotonic() - start
|
|
_emit_event(video_id, "stage5_synthesis", "complete", run_id=run_id)
|
|
logger.info(
|
|
"Stage 5 (synthesis) completed for video_id=%s in %.1fs — %d pages created/updated",
|
|
video_id, elapsed, pages_created,
|
|
)
|
|
return video_id
|
|
|
|
except FileNotFoundError:
|
|
raise
|
|
except Exception as exc:
|
|
session.rollback()
|
|
_emit_event(video_id, "stage5_synthesis", "error", run_id=run_id, payload={"error": str(exc)})
|
|
logger.error("Stage 5 failed for video_id=%s: %s", video_id, exc)
|
|
raise self.retry(exc=exc)
|
|
finally:
|
|
session.close()
|
|
|
|
|
|
# ── Stage 6: Embed & Index ───────────────────────────────────────────────────
|
|
|
|
@celery_app.task(bind=True, max_retries=0)
|
|
def stage6_embed_and_index(self, video_id: str, run_id: str | None = None) -> str:
|
|
"""Generate embeddings for technique pages and key moments, then upsert to Qdrant.
|
|
|
|
This is a non-blocking side-effect stage — failures are logged but do not
|
|
fail the pipeline. Embeddings can be regenerated later. Does NOT update
|
|
processing_status.
|
|
|
|
Returns the video_id for chain compatibility.
|
|
"""
|
|
start = time.monotonic()
|
|
logger.info("Stage 6 (embed & index) starting for video_id=%s", video_id)
|
|
|
|
settings = get_settings()
|
|
session = _get_sync_session()
|
|
try:
|
|
# Load technique pages created for this video's moments
|
|
moments = (
|
|
session.execute(
|
|
select(KeyMoment)
|
|
.where(KeyMoment.source_video_id == video_id)
|
|
.order_by(KeyMoment.start_time)
|
|
)
|
|
.scalars()
|
|
.all()
|
|
)
|
|
|
|
# Get unique technique page IDs from moments
|
|
page_ids = {m.technique_page_id for m in moments if m.technique_page_id is not None}
|
|
pages = []
|
|
if page_ids:
|
|
pages = (
|
|
session.execute(
|
|
select(TechniquePage).where(TechniquePage.id.in_(page_ids))
|
|
)
|
|
.scalars()
|
|
.all()
|
|
)
|
|
|
|
if not moments and not pages:
|
|
logger.info("Stage 6: No moments or pages for video_id=%s, skipping.", video_id)
|
|
if run_id:
|
|
_finish_run(run_id, "complete")
|
|
return video_id
|
|
|
|
# Resolve creator names for enriched embedding text
|
|
creator_ids = {p.creator_id for p in pages}
|
|
creator_map: dict[str, str] = {}
|
|
if creator_ids:
|
|
creators = (
|
|
session.execute(
|
|
select(Creator).where(Creator.id.in_(creator_ids))
|
|
)
|
|
.scalars()
|
|
.all()
|
|
)
|
|
creator_map = {str(c.id): c.name for c in creators}
|
|
|
|
# Resolve creator name for key moments via source_video → creator
|
|
video_ids = {m.source_video_id for m in moments}
|
|
video_creator_map: dict[str, str] = {}
|
|
if video_ids:
|
|
rows = session.execute(
|
|
select(SourceVideo.id, Creator.name)
|
|
.join(Creator, SourceVideo.creator_id == Creator.id)
|
|
.where(SourceVideo.id.in_(video_ids))
|
|
).all()
|
|
video_creator_map = {str(r[0]): r[1] for r in rows}
|
|
|
|
embed_client = EmbeddingClient(settings)
|
|
qdrant = QdrantManager(settings)
|
|
|
|
# Ensure collection exists before upserting
|
|
qdrant.ensure_collection()
|
|
|
|
# ── Embed & upsert technique pages ───────────────────────────────
|
|
if pages:
|
|
page_texts = []
|
|
page_dicts = []
|
|
for p in pages:
|
|
creator_name = creator_map.get(str(p.creator_id), "")
|
|
tags_joined = " ".join(p.topic_tags) if p.topic_tags else ""
|
|
text = f"{creator_name} {p.title} {p.topic_category or ''} {tags_joined} {p.summary or ''}"
|
|
page_texts.append(text.strip())
|
|
page_dicts.append({
|
|
"page_id": str(p.id),
|
|
"creator_id": str(p.creator_id),
|
|
"creator_name": creator_name,
|
|
"title": p.title,
|
|
"slug": p.slug,
|
|
"topic_category": p.topic_category or "",
|
|
"topic_tags": p.topic_tags or [],
|
|
"summary": p.summary or "",
|
|
})
|
|
|
|
page_vectors = embed_client.embed(page_texts)
|
|
if page_vectors:
|
|
qdrant.upsert_technique_pages(page_dicts, page_vectors)
|
|
logger.info(
|
|
"Stage 6: Upserted %d technique page vectors for video_id=%s",
|
|
len(page_vectors), video_id,
|
|
)
|
|
else:
|
|
logger.warning(
|
|
"Stage 6: Embedding returned empty for %d technique pages (video_id=%s). "
|
|
"Skipping page upsert.",
|
|
len(page_texts), video_id,
|
|
)
|
|
|
|
# ── Embed & upsert key moments ───────────────────────────────────
|
|
if moments:
|
|
# Build page_id → slug mapping for linking moments to technique pages
|
|
page_id_to_slug: dict[str, str] = {}
|
|
if pages:
|
|
for p in pages:
|
|
page_id_to_slug[str(p.id)] = p.slug
|
|
|
|
moment_texts = []
|
|
moment_dicts = []
|
|
for m in moments:
|
|
creator_name = video_creator_map.get(str(m.source_video_id), "")
|
|
text = f"{creator_name} {m.title} {m.summary or ''}"
|
|
moment_texts.append(text.strip())
|
|
tp_id = str(m.technique_page_id) if m.technique_page_id else ""
|
|
moment_dicts.append({
|
|
"moment_id": str(m.id),
|
|
"source_video_id": str(m.source_video_id),
|
|
"technique_page_id": tp_id,
|
|
"technique_page_slug": page_id_to_slug.get(tp_id, ""),
|
|
"title": m.title,
|
|
"creator_name": creator_name,
|
|
"start_time": m.start_time,
|
|
"end_time": m.end_time,
|
|
"content_type": m.content_type.value,
|
|
})
|
|
|
|
moment_vectors = embed_client.embed(moment_texts)
|
|
if moment_vectors:
|
|
qdrant.upsert_key_moments(moment_dicts, moment_vectors)
|
|
logger.info(
|
|
"Stage 6: Upserted %d key moment vectors for video_id=%s",
|
|
len(moment_vectors), video_id,
|
|
)
|
|
else:
|
|
logger.warning(
|
|
"Stage 6: Embedding returned empty for %d key moments (video_id=%s). "
|
|
"Skipping moment upsert.",
|
|
len(moment_texts), video_id,
|
|
)
|
|
|
|
elapsed = time.monotonic() - start
|
|
logger.info(
|
|
"Stage 6 (embed & index) completed for video_id=%s in %.1fs — "
|
|
"%d pages, %d moments processed",
|
|
video_id, elapsed, len(pages), len(moments),
|
|
)
|
|
if run_id:
|
|
_finish_run(run_id, "complete")
|
|
return video_id
|
|
|
|
except Exception as exc:
|
|
# Non-blocking: log error but don't fail the pipeline
|
|
logger.error(
|
|
"Stage 6 failed for video_id=%s: %s. "
|
|
"Pipeline continues — embeddings can be regenerated later.",
|
|
video_id, exc,
|
|
)
|
|
if run_id:
|
|
_finish_run(run_id, "complete") # Run is still "complete" — stage6 is best-effort
|
|
return video_id
|
|
finally:
|
|
session.close()
|
|
|
|
|
|
|
|
|
|
def _snapshot_prior_pages(video_id: str) -> None:
|
|
"""Save existing technique_page_ids linked to this video before pipeline resets them.
|
|
|
|
When a video is reprocessed, stage 3 deletes and recreates key_moments,
|
|
breaking the link to technique pages. This snapshots the page IDs to Redis
|
|
so stage 5 can find and update prior pages instead of creating duplicates.
|
|
"""
|
|
import redis
|
|
|
|
session = _get_sync_session()
|
|
try:
|
|
# Find technique pages linked via this video's key moments
|
|
rows = session.execute(
|
|
select(KeyMoment.technique_page_id)
|
|
.where(
|
|
KeyMoment.source_video_id == video_id,
|
|
KeyMoment.technique_page_id.isnot(None),
|
|
)
|
|
.distinct()
|
|
).scalars().all()
|
|
|
|
page_ids = [str(pid) for pid in rows]
|
|
|
|
if page_ids:
|
|
settings = get_settings()
|
|
r = redis.Redis.from_url(settings.redis_url)
|
|
key = f"chrysopedia:prior_pages:{video_id}"
|
|
r.set(key, json.dumps(page_ids), ex=86400)
|
|
logger.info(
|
|
"Snapshot %d prior technique pages for video_id=%s: %s",
|
|
len(page_ids), video_id, page_ids,
|
|
)
|
|
else:
|
|
logger.info("No prior technique pages for video_id=%s", video_id)
|
|
finally:
|
|
session.close()
|
|
|
|
|
|
def _load_prior_pages(video_id: str) -> list[str]:
|
|
"""Load prior technique page IDs from Redis."""
|
|
import redis
|
|
|
|
settings = get_settings()
|
|
r = redis.Redis.from_url(settings.redis_url)
|
|
key = f"chrysopedia:prior_pages:{video_id}"
|
|
raw = r.get(key)
|
|
if raw is None:
|
|
return []
|
|
return json.loads(raw)
|
|
|
|
|
|
# ── Stage completion detection for auto-resume ───────────────────────────────
|
|
|
|
# Ordered list of pipeline stages for resumability logic
|
|
_PIPELINE_STAGES = [
|
|
"stage2_segmentation",
|
|
"stage3_extraction",
|
|
"stage4_classification",
|
|
"stage5_synthesis",
|
|
"stage6_embed_and_index",
|
|
]
|
|
|
|
_STAGE_TASKS = {
|
|
"stage2_segmentation": stage2_segmentation,
|
|
"stage3_extraction": stage3_extraction,
|
|
"stage4_classification": stage4_classification,
|
|
"stage5_synthesis": stage5_synthesis,
|
|
"stage6_embed_and_index": stage6_embed_and_index,
|
|
}
|
|
|
|
|
|
def _get_last_completed_stage(video_id: str) -> str | None:
|
|
"""Find the last stage that completed successfully for this video.
|
|
|
|
Queries pipeline_events for the most recent run, looking for 'complete'
|
|
events. Returns the stage name (e.g. 'stage3_extraction') or None if
|
|
no stages have completed.
|
|
"""
|
|
session = _get_sync_session()
|
|
try:
|
|
# Find the most recent run for this video
|
|
from models import PipelineRun
|
|
latest_run = session.execute(
|
|
select(PipelineRun)
|
|
.where(PipelineRun.video_id == video_id)
|
|
.order_by(PipelineRun.started_at.desc())
|
|
.limit(1)
|
|
).scalar_one_or_none()
|
|
|
|
if latest_run is None:
|
|
return None
|
|
|
|
# Get all 'complete' events from that run
|
|
completed_events = session.execute(
|
|
select(PipelineEvent.stage)
|
|
.where(
|
|
PipelineEvent.run_id == str(latest_run.id),
|
|
PipelineEvent.event_type == "complete",
|
|
)
|
|
).scalars().all()
|
|
|
|
completed_set = set(completed_events)
|
|
|
|
# Walk backwards through the ordered stages to find the last completed one
|
|
last_completed = None
|
|
for stage_name in _PIPELINE_STAGES:
|
|
if stage_name in completed_set:
|
|
last_completed = stage_name
|
|
else:
|
|
break # Stop at first gap — stages must be sequential
|
|
|
|
if last_completed:
|
|
logger.info(
|
|
"Auto-resume: video_id=%s last completed stage=%s (run_id=%s)",
|
|
video_id, last_completed, latest_run.id,
|
|
)
|
|
return last_completed
|
|
finally:
|
|
session.close()
|
|
|
|
|
|
# ── Orchestrator ─────────────────────────────────────────────────────────────
|
|
|
|
@celery_app.task
|
|
def mark_pipeline_error(request, exc, traceback, video_id: str, run_id: str | None = None) -> None:
|
|
"""Error callback — marks video as errored when a pipeline stage fails."""
|
|
logger.error("Pipeline failed for video_id=%s: %s", video_id, exc)
|
|
_set_error_status(video_id, "pipeline", exc)
|
|
if run_id:
|
|
_finish_run(run_id, "error", error_stage="pipeline")
|
|
|
|
|
|
def _create_run(video_id: str, trigger: str) -> str:
|
|
"""Create a PipelineRun and return its id."""
|
|
from models import PipelineRun, PipelineRunTrigger
|
|
|
|
session = _get_sync_session()
|
|
try:
|
|
# Compute run_number: max existing + 1
|
|
from sqlalchemy import func as sa_func
|
|
max_num = session.execute(
|
|
select(sa_func.coalesce(sa_func.max(PipelineRun.run_number), 0))
|
|
.where(PipelineRun.video_id == video_id)
|
|
).scalar() or 0
|
|
run = PipelineRun(
|
|
video_id=video_id,
|
|
run_number=max_num + 1,
|
|
trigger=PipelineRunTrigger(trigger),
|
|
)
|
|
session.add(run)
|
|
session.commit()
|
|
run_id = str(run.id)
|
|
return run_id
|
|
finally:
|
|
session.close()
|
|
|
|
|
|
def _finish_run(run_id: str, status: str, error_stage: str | None = None) -> None:
|
|
"""Update a PipelineRun's status and finished_at."""
|
|
from models import PipelineRun, PipelineRunStatus, _now
|
|
|
|
session = _get_sync_session()
|
|
try:
|
|
run = session.execute(
|
|
select(PipelineRun).where(PipelineRun.id == run_id)
|
|
).scalar_one_or_none()
|
|
if run:
|
|
run.status = PipelineRunStatus(status)
|
|
run.finished_at = _now()
|
|
if error_stage:
|
|
run.error_stage = error_stage
|
|
# Aggregate total tokens from events
|
|
total = session.execute(
|
|
select(func.coalesce(func.sum(PipelineEvent.total_tokens), 0))
|
|
.where(PipelineEvent.run_id == run_id)
|
|
).scalar() or 0
|
|
run.total_tokens = total
|
|
session.commit()
|
|
except Exception as exc:
|
|
logger.warning("Failed to finish run %s: %s", run_id, exc)
|
|
finally:
|
|
session.close()
|
|
|
|
|
|
@celery_app.task
|
|
def run_pipeline(video_id: str, trigger: str = "manual") -> str:
|
|
"""Orchestrate the full pipeline (stages 2-6) with auto-resume.
|
|
|
|
For error/processing status, queries pipeline_events to find the last
|
|
stage that completed successfully and resumes from the next stage.
|
|
This avoids re-running expensive LLM stages that already succeeded.
|
|
|
|
For clean_reprocess trigger, always starts from stage 2.
|
|
|
|
Returns the video_id.
|
|
"""
|
|
logger.info("run_pipeline starting for video_id=%s", video_id)
|
|
|
|
session = _get_sync_session()
|
|
try:
|
|
video = session.execute(
|
|
select(SourceVideo).where(SourceVideo.id == video_id)
|
|
).scalar_one_or_none()
|
|
|
|
if video is None:
|
|
logger.error("run_pipeline: video_id=%s not found", video_id)
|
|
raise ValueError(f"Video not found: {video_id}")
|
|
|
|
status = video.processing_status
|
|
logger.info(
|
|
"run_pipeline: video_id=%s current status=%s", video_id, status.value
|
|
)
|
|
finally:
|
|
session.close()
|
|
|
|
if status == ProcessingStatus.complete:
|
|
logger.info(
|
|
"run_pipeline: video_id=%s already at status=%s, nothing to do.",
|
|
video_id, status.value,
|
|
)
|
|
return video_id
|
|
|
|
# Snapshot prior technique pages before pipeline resets key_moments
|
|
_snapshot_prior_pages(video_id)
|
|
|
|
# Create a pipeline run record
|
|
run_id = _create_run(video_id, trigger)
|
|
logger.info("run_pipeline: created run_id=%s for video_id=%s (trigger=%s)", run_id, video_id, trigger)
|
|
|
|
# Determine which stages to run
|
|
resume_from_idx = 0 # Default: start from stage 2
|
|
|
|
if trigger != "clean_reprocess" and status in (ProcessingStatus.processing, ProcessingStatus.error):
|
|
# Try to resume from where we left off
|
|
last_completed = _get_last_completed_stage(video_id)
|
|
if last_completed and last_completed in _PIPELINE_STAGES:
|
|
completed_idx = _PIPELINE_STAGES.index(last_completed)
|
|
resume_from_idx = completed_idx + 1
|
|
if resume_from_idx >= len(_PIPELINE_STAGES):
|
|
logger.info(
|
|
"run_pipeline: all stages already completed for video_id=%s",
|
|
video_id,
|
|
)
|
|
return video_id
|
|
|
|
stages_to_run = _PIPELINE_STAGES[resume_from_idx:]
|
|
logger.info(
|
|
"run_pipeline: video_id=%s will run stages: %s (resume_from_idx=%d)",
|
|
video_id, stages_to_run, resume_from_idx,
|
|
)
|
|
|
|
# Build the Celery chain — first stage gets video_id as arg,
|
|
# subsequent stages receive it from the previous stage's return value
|
|
celery_sigs = []
|
|
for i, stage_name in enumerate(stages_to_run):
|
|
task_func = _STAGE_TASKS[stage_name]
|
|
if i == 0:
|
|
celery_sigs.append(task_func.s(video_id, run_id=run_id))
|
|
else:
|
|
celery_sigs.append(task_func.s(run_id=run_id))
|
|
|
|
if celery_sigs:
|
|
# Mark as processing before dispatching
|
|
session = _get_sync_session()
|
|
try:
|
|
video = session.execute(
|
|
select(SourceVideo).where(SourceVideo.id == video_id)
|
|
).scalar_one()
|
|
video.processing_status = ProcessingStatus.processing
|
|
session.commit()
|
|
finally:
|
|
session.close()
|
|
|
|
pipeline = celery_chain(*celery_sigs)
|
|
error_cb = mark_pipeline_error.s(video_id, run_id=run_id)
|
|
pipeline.apply_async(link_error=error_cb)
|
|
logger.info(
|
|
"run_pipeline: dispatched %d stages for video_id=%s (run_id=%s, starting at %s)",
|
|
len(celery_sigs), video_id, run_id, stages_to_run[0],
|
|
)
|
|
|
|
return video_id
|