chrysopedia/backend/pipeline/stages.py
jlightner 4b0914b12b fix: restore complete project tree from ub01 canonical state
Auto-mode commit 7aa33cd accidentally deleted 78 files (14,814 lines) during M005
execution. Subsequent commits rebuilt some frontend files but backend/, alembic/,
tests/, whisper/, docker configs, and prompts were never restored in this repo.

This commit restores the full project tree by syncing from ub01's working directory,
which has all M001-M007 features running in production containers.

Restored: backend/ (config, models, routers, database, redis, search_service, worker),
alembic/ (6 migrations), docker/ (Dockerfiles, nginx, compose), prompts/ (4 stages),
tests/, whisper/, README.md, .env.example, chrysopedia-spec.md
2026-03-31 02:10:41 +00:00

1145 lines
44 KiB
Python

"""Pipeline stage tasks (stages 2-5) and run_pipeline orchestrator.
Each stage reads from PostgreSQL via sync SQLAlchemy, loads its prompt
template from disk, calls the LLM client, parses the response, writes
results back, and updates processing_status on SourceVideo.
Celery tasks are synchronous — all DB access uses ``sqlalchemy.orm.Session``.
"""
from __future__ import annotations
import hashlib
import json
import logging
import subprocess
import time
from collections import defaultdict
from pathlib import Path
import yaml
from celery import chain as celery_chain
from pydantic import ValidationError
from sqlalchemy import create_engine, func, select
from sqlalchemy.orm import Session, sessionmaker
from config import get_settings
from models import (
Creator,
KeyMoment,
KeyMomentContentType,
PipelineEvent,
ProcessingStatus,
SourceVideo,
TechniquePage,
TechniquePageVersion,
TranscriptSegment,
)
from pipeline.embedding_client import EmbeddingClient
from pipeline.llm_client import LLMClient, estimate_max_tokens
from pipeline.qdrant_client import QdrantManager
from pipeline.schemas import (
ClassificationResult,
ExtractionResult,
SegmentationResult,
SynthesisResult,
)
from worker import celery_app
# Module-level logger shared by every stage task and helper in this file.
logger = logging.getLogger(__name__)
# ── Pipeline event persistence ───────────────────────────────────────────────
def _emit_event(
    video_id: str,
    stage: str,
    event_type: str,
    *,
    prompt_tokens: int | None = None,
    completion_tokens: int | None = None,
    total_tokens: int | None = None,
    model: str | None = None,
    duration_ms: int | None = None,
    payload: dict | None = None,
    system_prompt_text: str | None = None,
    user_prompt_text: str | None = None,
    response_text: str | None = None,
) -> None:
    """Record one PipelineEvent row for *video_id* at *stage*.

    Event logging is observability only, so nothing is ever raised to the
    caller: any failure while opening the session, inserting, or committing
    is logged as a warning and swallowed.
    """
    columns = dict(
        video_id=video_id,
        stage=stage,
        event_type=event_type,
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=total_tokens,
        model=model,
        duration_ms=duration_ms,
        payload=payload,
        system_prompt_text=system_prompt_text,
        user_prompt_text=user_prompt_text,
        response_text=response_text,
    )
    try:
        db = _get_sync_session()
        try:
            db.add(PipelineEvent(**columns))
            db.commit()
        finally:
            db.close()
    except Exception as err:
        logger.warning("Failed to emit pipeline event: %s", err)
def _is_debug_mode() -> bool:
    """Return whether pipeline debug mode is currently enabled.

    The runtime toggle is the Redis key ``chrysopedia:debug_mode``; its value
    is compared case-insensitively with ``"true"``. When the key is absent or
    Redis cannot be reached, fall back to the ``debug_mode`` setting
    (``False`` if the setting does not exist).
    """
    try:
        import redis

        settings = get_settings()
        client = redis.from_url(settings.redis_url)
        try:
            val = client.get("chrysopedia:debug_mode")
        finally:
            # Close even when GET fails so the connection pool is not leaked.
            client.close()
        if val is not None:
            return val.decode().lower() == "true"
    except Exception as exc:
        # Best-effort: an unreachable Redis must never break a stage; just
        # leave a trace instead of swallowing the error silently.
        logger.debug("Could not read debug_mode from Redis: %s", exc)
    return getattr(get_settings(), "debug_mode", False)
def _make_llm_callback(
    video_id: str,
    stage: str,
    system_prompt: str | None = None,
    user_prompt: str | None = None,
):
    """Build the ``on_complete`` hook passed to ``LLMClient.complete``.

    Each time the hook fires it records an ``llm_call`` PipelineEvent with
    token counts, the model name, and a response preview of at most 2000
    characters. The debug flag is sampled once, when the hook is built. If
    debug was on at that point, the full system prompt, user prompt, and
    untruncated response are stored on every event as well.
    """
    capture_full_text = _is_debug_mode()
    preview_limit = 2000

    def on_complete(*, model=None, prompt_tokens=None, completion_tokens=None,
                    total_tokens=None, content=None, finish_reason=None,
                    is_fallback=False, **_kwargs):
        # Only a bounded preview goes into the JSON payload.
        if content and len(content) > preview_limit:
            preview = content[:preview_limit]
        else:
            preview = content
        event_payload = {
            "content_preview": preview,
            "content_length": len(content) if content else 0,
            "finish_reason": finish_reason,
            "is_fallback": is_fallback,
        }
        _emit_event(
            video_id=video_id,
            stage=stage,
            event_type="llm_call",
            model=model,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens,
            payload=event_payload,
            system_prompt_text=system_prompt if capture_full_text else None,
            user_prompt_text=user_prompt if capture_full_text else None,
            response_text=content if capture_full_text else None,
        )

    return on_complete
# ── Helpers ──────────────────────────────────────────────────────────────────
# Lazily-initialised, module-wide sync engine and session factory; both are
# created on first use by _get_sync_engine() / _get_sync_session().
_engine = None
_SessionLocal = None
def _get_sync_engine():
    """Return the module's shared sync SQLAlchemy engine, creating it once.

    ``settings.database_url`` may name the async ``asyncpg`` driver; Celery
    tasks need a blocking driver, so that URL prefix is swapped for
    ``psycopg2`` before the engine is built.
    """
    global _engine
    if _engine is not None:
        return _engine
    async_prefix = "postgresql+asyncpg://"
    sync_prefix = "postgresql+psycopg2://"
    sync_url = get_settings().database_url.replace(async_prefix, sync_prefix)
    _engine = create_engine(sync_url, pool_pre_ping=True, pool_size=5, max_overflow=10)
    return _engine
def _get_sync_session() -> Session:
    """Open a new sync ORM session bound to the shared engine.

    The session factory is built on first call. The caller owns the
    returned session and must close it.
    """
    global _SessionLocal
    factory = _SessionLocal
    if factory is None:
        factory = sessionmaker(bind=_get_sync_engine())
        _SessionLocal = factory
    return factory()
def _load_prompt(template_name: str) -> str:
    """Return the UTF-8 text of prompt template *template_name*.

    Templates are resolved relative to ``settings.prompts_path``.

    Raises:
        FileNotFoundError: if the template does not exist (the missing path
            is also logged at ERROR level).
    """
    template_path = Path(get_settings().prompts_path) / template_name
    if template_path.exists():
        return template_path.read_text(encoding="utf-8")
    logger.error("Prompt template not found: %s", template_path)
    raise FileNotFoundError(f"Prompt template not found: {template_path}")
def _get_llm_client() -> LLMClient:
    """Build a new LLMClient from the current application settings."""
    settings = get_settings()
    return LLMClient(settings)
def _get_stage_config(stage_num: int) -> tuple[str | None, str]:
    """Look up the model override and modality for pipeline stage *stage_num*.

    Reads ``llm_stage{N}_model`` and ``llm_stage{N}_modality`` from settings.
    A missing or empty model yields ``None``, meaning LLMClient uses its
    default. A missing or empty modality yields ``"chat"``.

    Returns:
        ``(model_override, modality)``.
    """
    settings = get_settings()
    prefix = f"llm_stage{stage_num}"
    model_override = getattr(settings, f"{prefix}_model", None) or None
    modality = getattr(settings, f"{prefix}_modality", None) or "chat"
    return model_override, modality
def _load_canonical_tags() -> dict:
    """Load the canonical tag taxonomy from ``config/canonical_tags.yaml``.

    The file is looked up relative to the current working directory, first
    as ``config/canonical_tags.yaml`` and then as
    ``../config/canonical_tags.yaml``. This covers a process started in the
    project root as well as one started a directory below it.

    Returns:
        The parsed YAML document, or ``{}`` when the file is empty.

    Raises:
        FileNotFoundError: if neither candidate path exists.
    """
    candidates = [
        Path("config/canonical_tags.yaml"),
        Path("../config/canonical_tags.yaml"),
    ]
    for candidate in candidates:
        if candidate.exists():
            with open(candidate, encoding="utf-8") as f:
                # safe_load returns None for an empty document; normalise it
                # so callers can use dict methods (e.g. .get) unconditionally.
                return yaml.safe_load(f) or {}
    raise FileNotFoundError(
        "canonical_tags.yaml not found. Searched: " + ", ".join(str(c) for c in candidates)
    )
def _format_taxonomy_for_prompt(tags_data: dict) -> str:
"""Format the canonical tags taxonomy as readable text for the LLM prompt."""
lines = []
for cat in tags_data.get("categories", []):
lines.append(f"Category: {cat['name']}")
lines.append(f" Description: {cat['description']}")
lines.append(f" Sub-topics: {', '.join(cat.get('sub_topics', []))}")
lines.append("")
return "\n".join(lines)
def _safe_parse_llm_response(
    raw: str,
    model_cls,
    llm: LLMClient,
    system_prompt: str,
    user_prompt: str,
    modality: str = "chat",
    model_override: str | None = None,
):
    """Parse *raw* into *model_cls*, re-asking the LLM once if that fails.

    A first-attempt failure (ValidationError, ValueError or JSONDecodeError)
    is logged together with the first 500 characters of *raw*. The same
    prompts are then re-sent, with an instruction appended to the user prompt
    that demands bare JSON. That second response is parsed with no further
    fallback, so any error from it propagates to the caller.
    """
    try:
        return llm.parse_response(raw, model_cls)
    except (ValidationError, ValueError, json.JSONDecodeError) as exc:
        logger.warning(
            "First parse attempt failed for %s (%s). Retrying with JSON nudge. "
            "Raw response (first 500 chars): %.500s",
            model_cls.__name__,
            type(exc).__name__,
            raw,
        )
        json_only_suffix = "\n\nIMPORTANT: Output ONLY valid JSON. No markdown, no explanation."
        strict_user_prompt = user_prompt + json_only_suffix
        second_raw = llm.complete(
            system_prompt,
            strict_user_prompt,
            response_model=model_cls,
            modality=modality,
            model_override=model_override,
        )
        return llm.parse_response(second_raw, model_cls)
# ── Stage 2: Segmentation ───────────────────────────────────────────────────
@celery_app.task(bind=True, max_retries=3, default_retry_delay=30)
def stage2_segmentation(self, video_id: str) -> str:
    """Stage 2: detect topic boundaries in a video's transcript.

    Sends every TranscriptSegment of the video, ordered by segment_index, to
    the LLM as one indexed transcript. The returned topic_label is written
    onto each segment inside a returned [start_index, end_index] range.
    Labels left by an earlier run are cleared first, so a segment the LLM
    does not cover ends up unlabeled rather than keeping a stale label.

    Returns:
        *video_id*, so the task can be chained.

    Raises:
        FileNotFoundError: when the prompt template is missing. An error
            event is recorded, but the task is not retried.
        Any other failure is rolled back, recorded as an error event and
        retried via ``self.retry`` (max_retries=3).
    """
    stage = "stage2_segmentation"
    start = time.monotonic()
    logger.info("Stage 2 (segmentation) starting for video_id=%s", video_id)
    _emit_event(video_id, stage, "start")
    session = _get_sync_session()
    try:
        # Load segments ordered by index
        segments = (
            session.execute(
                select(TranscriptSegment)
                .where(TranscriptSegment.source_video_id == video_id)
                .order_by(TranscriptSegment.segment_index)
            )
            .scalars()
            .all()
        )
        if not segments:
            logger.info("Stage 2: No segments found for video_id=%s, skipping.", video_id)
            return video_id

        # Indexed transcript so the LLM can answer with index ranges.
        transcript_text = "\n".join(
            f"[{seg.segment_index}] ({seg.start_time:.1f}s - {seg.end_time:.1f}s) {seg.text}"
            for seg in segments
        )
        system_prompt = _load_prompt("stage2_segmentation.txt")
        user_prompt = f"<transcript>\n{transcript_text}\n</transcript>"
        llm = _get_llm_client()
        model_override, modality = _get_stage_config(2)
        hard_limit = get_settings().llm_max_tokens_hard_limit
        max_tokens = estimate_max_tokens(system_prompt, user_prompt, stage=stage, hard_limit=hard_limit)
        logger.info("Stage 2 using model=%s, modality=%s, max_tokens=%d", model_override or "default", modality, max_tokens)
        raw = llm.complete(
            system_prompt,
            user_prompt,
            response_model=SegmentationResult,
            on_complete=_make_llm_callback(video_id, stage, system_prompt=system_prompt, user_prompt=user_prompt),
            modality=modality,
            model_override=model_override,
            max_tokens=max_tokens,
        )
        result = _safe_parse_llm_response(raw, SegmentationResult, llm, system_prompt, user_prompt,
                                          modality=modality, model_override=model_override)

        # Reset labels from any previous run, then apply this run's ranges.
        for seg in segments:
            seg.topic_label = None
        seg_by_index = {s.segment_index: s for s in segments}
        for topic_seg in result.segments:
            for idx in range(topic_seg.start_index, topic_seg.end_index + 1):
                seg = seg_by_index.get(idx)
                if seg is not None:  # ignore indices the LLM invented
                    seg.topic_label = topic_seg.topic_label
        session.commit()

        elapsed = time.monotonic() - start
        _emit_event(video_id, stage, "complete")
        logger.info(
            "Stage 2 (segmentation) completed for video_id=%s in %.1fs — %d topic groups found",
            video_id, elapsed, len(result.segments),
        )
        return video_id
    except FileNotFoundError as exc:
        # Don't retry missing prompt files, but make the failure visible.
        _emit_event(video_id, stage, "error", payload={"error": str(exc)})
        logger.error("Stage 2 failed for video_id=%s: %s", video_id, exc)
        raise
    except Exception as exc:
        session.rollback()
        _emit_event(video_id, stage, "error", payload={"error": str(exc)})
        logger.error("Stage 2 failed for video_id=%s: %s", video_id, exc)
        raise self.retry(exc=exc)
    finally:
        session.close()
# ── Stage 3: Extraction ─────────────────────────────────────────────────────
@celery_app.task(bind=True, max_retries=3, default_retry_delay=30)
def stage3_extraction(self, video_id: str) -> str:
    """Stage 3: extract key moments from each topic group of the transcript.

    Segments are grouped by the topic_label stage 2 assigned (``"unlabeled"``
    when missing). Each group is sent to the LLM separately, and every
    returned moment becomes a KeyMoment row; an unrecognised content_type
    falls back to ``technique``. The video's processing_status is then set
    to ``extracted``, and everything is committed in one transaction.

    NOTE(review): this task only inserts KeyMoment rows; it does not remove
    moments from an earlier run. Confirm that reprocessing clears them
    elsewhere, otherwise a re-run duplicates moments.

    Returns:
        *video_id*, so the task can be chained.

    Raises:
        FileNotFoundError: when the prompt template is missing. An error
            event is recorded, but the task is not retried.
        Any other failure is rolled back, recorded as an error event and
        retried via ``self.retry`` (max_retries=3).
    """
    stage = "stage3_extraction"
    start = time.monotonic()
    logger.info("Stage 3 (extraction) starting for video_id=%s", video_id)
    _emit_event(video_id, stage, "start")
    session = _get_sync_session()
    try:
        # Load segments with topic labels
        segments = (
            session.execute(
                select(TranscriptSegment)
                .where(TranscriptSegment.source_video_id == video_id)
                .order_by(TranscriptSegment.segment_index)
            )
            .scalars()
            .all()
        )
        if not segments:
            logger.info("Stage 3: No segments found for video_id=%s, skipping.", video_id)
            return video_id

        # Group segments by topic_label (insertion order = first appearance).
        groups: dict[str, list[TranscriptSegment]] = defaultdict(list)
        for seg in segments:
            groups[seg.topic_label or "unlabeled"].append(seg)

        system_prompt = _load_prompt("stage3_extraction.txt")
        llm = _get_llm_client()
        model_override, modality = _get_stage_config(3)
        hard_limit = get_settings().llm_max_tokens_hard_limit
        logger.info("Stage 3 using model=%s, modality=%s", model_override or "default", modality)

        total_moments = 0
        for topic_label, group_segs in groups.items():
            segment_text = "\n".join(
                f"({seg.start_time:.1f}s - {seg.end_time:.1f}s) {seg.text}"
                for seg in group_segs
            )
            user_prompt = (
                f"Topic: {topic_label}\n\n"
                f"<segment>\n{segment_text}\n</segment>"
            )
            max_tokens = estimate_max_tokens(system_prompt, user_prompt, stage=stage, hard_limit=hard_limit)
            raw = llm.complete(
                system_prompt,
                user_prompt,
                response_model=ExtractionResult,
                on_complete=_make_llm_callback(video_id, stage, system_prompt=system_prompt, user_prompt=user_prompt),
                modality=modality,
                model_override=model_override,
                max_tokens=max_tokens,
            )
            result = _safe_parse_llm_response(raw, ExtractionResult, llm, system_prompt, user_prompt,
                                              modality=modality, model_override=model_override)

            # Create KeyMoment rows
            for moment in result.moments:
                # Validate content_type against enum
                try:
                    ct = KeyMomentContentType(moment.content_type)
                except ValueError:
                    ct = KeyMomentContentType.technique
                session.add(KeyMoment(
                    source_video_id=video_id,
                    title=moment.title,
                    summary=moment.summary,
                    start_time=moment.start_time,
                    end_time=moment.end_time,
                    content_type=ct,
                    plugins=moment.plugins if moment.plugins else None,
                    raw_transcript=moment.raw_transcript or None,
                ))
                total_moments += 1

        # Update processing_status to extracted
        video = session.execute(
            select(SourceVideo).where(SourceVideo.id == video_id)
        ).scalar_one()
        video.processing_status = ProcessingStatus.extracted
        session.commit()

        elapsed = time.monotonic() - start
        _emit_event(video_id, stage, "complete")
        logger.info(
            "Stage 3 (extraction) completed for video_id=%s in %.1fs — %d moments created",
            video_id, elapsed, total_moments,
        )
        return video_id
    except FileNotFoundError as exc:
        # Don't retry missing prompt files, but make the failure visible.
        _emit_event(video_id, stage, "error", payload={"error": str(exc)})
        logger.error("Stage 3 failed for video_id=%s: %s", video_id, exc)
        raise
    except Exception as exc:
        session.rollback()
        _emit_event(video_id, stage, "error", payload={"error": str(exc)})
        logger.error("Stage 3 failed for video_id=%s: %s", video_id, exc)
        raise self.retry(exc=exc)
    finally:
        session.close()
# ── Stage 4: Classification ─────────────────────────────────────────────────
@celery_app.task(bind=True, max_retries=3, default_retry_delay=30)
def stage4_classification(self, video_id: str) -> str:
    """Stage 4: classify the video's key moments against the canonical taxonomy.

    Sends all KeyMoments (ordered by start_time) plus the taxonomy to the LLM
    in one call. Valid ``content_type_override`` values are applied to the
    moments and committed. ``{moment_id, topic_category, topic_tags}``
    records are stored in Redis for stage 5; when there are no moments, an
    empty list is stored. Does NOT change processing_status.

    Returns:
        *video_id*, so the task can be chained.

    Raises:
        FileNotFoundError: when the prompt template or taxonomy file is
            missing. An error event is recorded, but the task is not retried.
        Any other failure is rolled back, recorded as an error event and
        retried via ``self.retry`` (max_retries=3).
    """
    stage = "stage4_classification"
    start = time.monotonic()
    logger.info("Stage 4 (classification) starting for video_id=%s", video_id)
    _emit_event(video_id, stage, "start")
    session = _get_sync_session()
    try:
        # Load key moments
        moments = (
            session.execute(
                select(KeyMoment)
                .where(KeyMoment.source_video_id == video_id)
                .order_by(KeyMoment.start_time)
            )
            .scalars()
            .all()
        )
        if not moments:
            logger.info("Stage 4: No moments found for video_id=%s, skipping.", video_id)
            # Publish an empty result for stage 5 (overwrites any older data).
            _store_classification_data(video_id, [])
            return video_id

        tags_data = _load_canonical_tags()
        taxonomy_text = _format_taxonomy_for_prompt(tags_data)
        # Moments are numbered so the LLM can refer back via moment_index.
        moments_text = "\n\n".join(
            f"[{i}] Title: {m.title}\n"
            f" Summary: {m.summary}\n"
            f" Content type: {m.content_type.value}\n"
            f" Plugins: {', '.join(m.plugins) if m.plugins else 'none'}"
            for i, m in enumerate(moments)
        )
        system_prompt = _load_prompt("stage4_classification.txt")
        user_prompt = (
            f"<taxonomy>\n{taxonomy_text}\n</taxonomy>\n\n"
            f"<moments>\n{moments_text}\n</moments>"
        )
        llm = _get_llm_client()
        model_override, modality = _get_stage_config(4)
        hard_limit = get_settings().llm_max_tokens_hard_limit
        max_tokens = estimate_max_tokens(system_prompt, user_prompt, stage=stage, hard_limit=hard_limit)
        logger.info("Stage 4 using model=%s, modality=%s, max_tokens=%d", model_override or "default", modality, max_tokens)
        raw = llm.complete(
            system_prompt,
            user_prompt,
            response_model=ClassificationResult,
            on_complete=_make_llm_callback(video_id, stage, system_prompt=system_prompt, user_prompt=user_prompt),
            modality=modality,
            model_override=model_override,
            max_tokens=max_tokens,
        )
        result = _safe_parse_llm_response(raw, ClassificationResult, llm, system_prompt, user_prompt,
                                          modality=modality, model_override=model_override)

        # Apply content_type overrides and prepare classification data for stage 5
        classification_data = []
        for cls in result.classifications:
            if not 0 <= cls.moment_index < len(moments):
                continue  # ignore indices that don't refer to a loaded moment
            moment = moments[cls.moment_index]
            if cls.content_type_override:
                try:
                    moment.content_type = KeyMomentContentType(cls.content_type_override)
                except ValueError:
                    logger.warning(
                        "Stage 4: ignoring unknown content_type_override=%r for moment_id=%s",
                        cls.content_type_override, moment.id,
                    )
            classification_data.append({
                "moment_id": str(moment.id),
                "topic_category": cls.topic_category,
                "topic_tags": cls.topic_tags,
            })
        session.commit()

        # Store classification data in Redis for stage 5
        _store_classification_data(video_id, classification_data)

        elapsed = time.monotonic() - start
        _emit_event(video_id, stage, "complete")
        logger.info(
            "Stage 4 (classification) completed for video_id=%s in %.1fs — %d moments classified",
            video_id, elapsed, len(classification_data),
        )
        return video_id
    except FileNotFoundError as exc:
        # Missing prompt/taxonomy: don't retry, but make the failure visible.
        _emit_event(video_id, stage, "error", payload={"error": str(exc)})
        logger.error("Stage 4 failed for video_id=%s: %s", video_id, exc)
        raise
    except Exception as exc:
        session.rollback()
        _emit_event(video_id, stage, "error", payload={"error": str(exc)})
        logger.error("Stage 4 failed for video_id=%s: %s", video_id, exc)
        raise self.retry(exc=exc)
    finally:
        session.close()
def _store_classification_data(video_id: str, data: list[dict]) -> None:
    """Publish stage-4 classification results for stage 5 through Redis.

    *data* is JSON-encoded under ``chrysopedia:classification:{video_id}``
    with a 24-hour expiry, replacing any value already stored there.
    """
    import redis

    settings = get_settings()
    client = redis.Redis.from_url(settings.redis_url)
    try:
        key = f"chrysopedia:classification:{video_id}"
        client.set(key, json.dumps(data), ex=86400)  # Expire after 24 hours
    finally:
        # Release this call's connection pool instead of leaving it to GC.
        client.close()
def _load_classification_data(video_id: str) -> list[dict]:
    """Fetch the classification results stored in Redis for *video_id*.

    Returns:
        The decoded JSON list stored under
        ``chrysopedia:classification:{video_id}``, or ``[]`` when the key is
        absent (never written, or expired).
    """
    import redis

    settings = get_settings()
    client = redis.Redis.from_url(settings.redis_url)
    try:
        raw = client.get(f"chrysopedia:classification:{video_id}")
    finally:
        # Release this call's connection pool instead of leaving it to GC.
        client.close()
    if raw is None:
        return []
    return json.loads(raw)
def _get_git_commit_sha() -> str:
"""Resolve the git commit SHA used to build this image.
Resolution order:
1. /app/.git-commit file (written during Docker build)
2. git rev-parse --short HEAD (local dev)
3. GIT_COMMIT_SHA env var / config setting
4. "unknown"
"""
# Docker build artifact
git_commit_file = Path("/app/.git-commit")
if git_commit_file.exists():
sha = git_commit_file.read_text(encoding="utf-8").strip()
if sha and sha != "unknown":
return sha
# Local dev — run git
try:
result = subprocess.run(
["git", "rev-parse", "--short", "HEAD"],
capture_output=True, text=True, timeout=5,
)
if result.returncode == 0 and result.stdout.strip():
return result.stdout.strip()
except (FileNotFoundError, subprocess.TimeoutExpired):
pass
# Config / env var fallback
try:
sha = get_settings().git_commit_sha
if sha and sha != "unknown":
return sha
except Exception:
pass
return "unknown"
def _capture_pipeline_metadata() -> dict:
    """Describe the pipeline configuration used to produce a page version.

    Returns a dict with:
    - the git commit SHA;
    - per-stage and embedding model names;
    - per-stage modalities;
    - a SHA-256 hex digest of each stage prompt template.

    A template that cannot be read is logged and recorded with an
    empty-string hash instead of raising.
    """
    settings = get_settings()
    prompts_dir = Path(settings.prompts_path)

    def _hash_prompt(filename: str) -> str:
        # SHA-256 of the template bytes, or "" when it can't be read.
        filepath = prompts_dir / filename
        try:
            return hashlib.sha256(filepath.read_bytes()).hexdigest()
        except FileNotFoundError:
            logger.warning("Prompt file not found for metadata capture: %s", filepath)
        except OSError as exc:
            logger.warning("Could not read prompt file %s: %s", filepath, exc)
        return ""

    prompt_names = (
        "stage2_segmentation.txt",
        "stage3_extraction.txt",
        "stage4_classification.txt",
        "stage5_synthesis.txt",
    )
    prompt_hashes: dict[str, str] = {name: _hash_prompt(name) for name in prompt_names}

    commit_sha = _get_git_commit_sha()
    stage_keys = ("stage2", "stage3", "stage4", "stage5")
    models = {key: getattr(settings, f"llm_{key}_model") for key in stage_keys}
    models["embedding"] = settings.embedding_model
    modalities = {key: getattr(settings, f"llm_{key}_modality") for key in stage_keys}
    return {
        "git_commit_sha": commit_sha,
        "models": models,
        "modalities": modalities,
        "prompt_hashes": prompt_hashes,
    }
# ── Stage 5: Synthesis ───────────────────────────────────────────────────────
def _stage5_find_prior_page(
    session: Session,
    prior_page_ids: list[str],
    claimed_page_ids: set,
    creator_id,
    topic_category: str,
    slug: str,
):
    """Pick which of this video's prior pages a synthesized page should update.

    Candidates are pages that meet all of these conditions:
    - listed in *prior_page_ids*;
    - belonging to *creator_id*;
    - in *topic_category*;
    - not already updated earlier in this run (*claimed_page_ids*).

    A candidate whose slug equals *slug* wins; otherwise the first candidate
    in slug order is used. Returns None when there is no candidate. Several
    matches are tolerated, whereas ``scalar_one_or_none()`` would raise
    MultipleResultsFound.
    """
    if not prior_page_ids:
        return None
    rows = session.execute(
        select(TechniquePage)
        .where(
            TechniquePage.id.in_(prior_page_ids),
            TechniquePage.creator_id == creator_id,
            TechniquePage.topic_category == topic_category,
        )
        .order_by(TechniquePage.slug)
    ).scalars().all()
    candidates = [p for p in rows if p.id not in claimed_page_ids]
    for candidate in candidates:
        if candidate.slug == slug:
            return candidate
    return candidates[0] if candidates else None


def _stage5_snapshot_page_version(session: Session, page: TechniquePage) -> None:
    """Best-effort: save *page*'s current content as its next TechniquePageVersion.

    The version number is one more than the page's existing version count.
    Any exception is logged and swallowed so the page update can proceed.
    NOTE(review): a DB error raised here likely leaves the session's
    transaction unusable, so "best-effort" mainly covers non-DB errors —
    consider wrapping in ``session.begin_nested()``.
    """
    try:
        snapshot = {
            "title": page.title,
            "slug": page.slug,
            "topic_category": page.topic_category,
            "topic_tags": page.topic_tags,
            "summary": page.summary,
            "body_sections": page.body_sections,
            "signal_chains": page.signal_chains,
            "plugins": page.plugins,
            "source_quality": page.source_quality.value if page.source_quality else None,
        }
        version_count = session.execute(
            select(func.count()).where(
                TechniquePageVersion.technique_page_id == page.id
            )
        ).scalar()
        version_number = version_count + 1
        session.add(TechniquePageVersion(
            technique_page_id=page.id,
            version_number=version_number,
            content_snapshot=snapshot,
            pipeline_metadata=_capture_pipeline_metadata(),
        ))
        logger.info(
            "Version snapshot v%d created for page slug=%s",
            version_number, page.slug,
        )
    except Exception as snap_exc:
        logger.error(
            "Failed to create version snapshot for page slug=%s: %s",
            page.slug, snap_exc,
        )


@celery_app.task(bind=True, max_retries=3, default_retry_delay=30)
def stage5_synthesis(self, video_id: str) -> str:
    """Stage 5: synthesize TechniquePages from the video's classified key moments.

    Moments are grouped by the topic_category that stage 4 stored in Redis
    ("Uncategorized" when none), and each group is sent to the LLM. For every
    synthesized page, the target row is chosen in this order:
    - one of this video's prior pages (snapshotted by run_pipeline), matched
      by creator + category;
    - otherwise an existing page with the same slug;
    - otherwise a newly created page.

    An existing page gets a TechniquePageVersion snapshot of its old content
    before it is overwritten. The group's moments are then linked to the
    page. Finally the video's processing_status becomes ``reviewed`` when
    ``settings.review_mode`` is on, otherwise ``published``.

    Returns:
        *video_id*, so the task can be chained.

    Raises:
        FileNotFoundError: when the prompt template is missing. An error
            event is recorded, but the task is not retried.
        Any other failure is rolled back, recorded as an error event and
        retried via ``self.retry`` (max_retries=3).
    """
    stage = "stage5_synthesis"
    start = time.monotonic()
    logger.info("Stage 5 (synthesis) starting for video_id=%s", video_id)
    _emit_event(video_id, stage, "start")
    settings = get_settings()
    session = _get_sync_session()
    try:
        # Load video and moments
        video = session.execute(
            select(SourceVideo).where(SourceVideo.id == video_id)
        ).scalar_one()
        moments = (
            session.execute(
                select(KeyMoment)
                .where(KeyMoment.source_video_id == video_id)
                .order_by(KeyMoment.start_time)
            )
            .scalars()
            .all()
        )
        # Resolve creator name for the LLM prompt
        creator = session.execute(
            select(Creator).where(Creator.id == video.creator_id)
        ).scalar_one_or_none()
        creator_name = creator.name if creator else "Unknown"
        if not moments:
            logger.info("Stage 5: No moments found for video_id=%s, skipping.", video_id)
            return video_id

        # Group moments by the topic_category stage 4 assigned to them.
        classification_data = _load_classification_data(video_id)
        cls_by_moment_id = {c["moment_id"]: c for c in classification_data}
        groups: dict[str, list[tuple[KeyMoment, dict]]] = defaultdict(list)
        for moment in moments:
            cls_info = cls_by_moment_id.get(str(moment.id), {})
            category = cls_info.get("topic_category", "Uncategorized")
            groups[category].append((moment, cls_info))

        system_prompt = _load_prompt("stage5_synthesis.txt")
        llm = _get_llm_client()
        model_override, modality = _get_stage_config(5)
        hard_limit = settings.llm_max_tokens_hard_limit
        logger.info("Stage 5 using model=%s, modality=%s", model_override or "default", modality)

        # The prior-page snapshot (written by run_pipeline) does not change
        # between groups, so read it from Redis once instead of per group.
        prior_page_ids = _load_prior_pages(video_id)
        # Pages already written in this run; stops two synthesized pages of
        # the same category from overwriting the same prior page.
        claimed_page_ids: set = set()

        pages_created = 0
        for category, moment_group in groups.items():
            # Build moments text for the LLM
            moments_lines = []
            all_tags: set[str] = set()
            for i, (m, cls_info) in enumerate(moment_group):
                tags = cls_info.get("topic_tags", [])
                all_tags.update(tags)
                moments_lines.append(
                    f"[{i}] Title: {m.title}\n"
                    f" Summary: {m.summary}\n"
                    f" Content type: {m.content_type.value}\n"
                    f" Time: {m.start_time:.1f}s - {m.end_time:.1f}s\n"
                    f" Plugins: {', '.join(m.plugins) if m.plugins else 'none'}\n"
                    f" Category: {category}\n"
                    f" Tags: {', '.join(tags) if tags else 'none'}\n"
                    f" Transcript excerpt: {(m.raw_transcript or '')[:300]}"
                )
            moments_text = "\n\n".join(moments_lines)
            user_prompt = f"<creator>{creator_name}</creator>\n<moments>\n{moments_text}\n</moments>"
            max_tokens = estimate_max_tokens(system_prompt, user_prompt, stage=stage, hard_limit=hard_limit)
            raw = llm.complete(
                system_prompt,
                user_prompt,
                response_model=SynthesisResult,
                on_complete=_make_llm_callback(video_id, stage, system_prompt=system_prompt, user_prompt=user_prompt),
                modality=modality,
                model_override=model_override,
                max_tokens=max_tokens,
            )
            result = _safe_parse_llm_response(raw, SynthesisResult, llm, system_prompt, user_prompt,
                                              modality=modality, model_override=model_override)

            # Create/update TechniquePage rows
            for page_data in result.pages:
                page_category = page_data.topic_category or category
                # First: prior pages from this video by creator + category
                existing = _stage5_find_prior_page(
                    session, prior_page_ids, claimed_page_ids,
                    video.creator_id, page_category, page_data.slug,
                )
                if existing is not None:
                    logger.info(
                        "Stage 5: Matched prior page '%s' (id=%s) by creator+category for video_id=%s",
                        existing.slug, existing.id, video_id,
                    )
                else:
                    # Fallback: check by slug (handles cross-video dedup)
                    existing = session.execute(
                        select(TechniquePage).where(TechniquePage.slug == page_data.slug)
                    ).scalar_one_or_none()

                if existing is not None:
                    _stage5_snapshot_page_version(session, existing)
                    # Update existing page (topic_category is kept as-is)
                    existing.title = page_data.title
                    existing.summary = page_data.summary
                    existing.body_sections = page_data.body_sections
                    existing.signal_chains = page_data.signal_chains
                    existing.plugins = page_data.plugins if page_data.plugins else None
                    existing.topic_tags = list(all_tags) if all_tags else None
                    existing.source_quality = page_data.source_quality
                    page = existing
                else:
                    page = TechniquePage(
                        creator_id=video.creator_id,
                        title=page_data.title,
                        slug=page_data.slug,
                        topic_category=page_category,
                        topic_tags=list(all_tags) if all_tags else None,
                        summary=page_data.summary,
                        body_sections=page_data.body_sections,
                        signal_chains=page_data.signal_chains,
                        plugins=page_data.plugins if page_data.plugins else None,
                        source_quality=page_data.source_quality,
                    )
                    session.add(page)
                session.flush()  # Get the page.id assigned
                claimed_page_ids.add(page.id)
                pages_created += 1

                # Link the group's moments to this page; with several pages
                # per group, the last page wins.
                for m, _ in moment_group:
                    m.technique_page_id = page.id

        # Update processing_status
        if settings.review_mode:
            video.processing_status = ProcessingStatus.reviewed
        else:
            video.processing_status = ProcessingStatus.published
        session.commit()

        elapsed = time.monotonic() - start
        _emit_event(video_id, stage, "complete")
        logger.info(
            "Stage 5 (synthesis) completed for video_id=%s in %.1fs — %d pages created/updated",
            video_id, elapsed, pages_created,
        )
        return video_id
    except FileNotFoundError as exc:
        # Don't retry missing prompt files, but make the failure visible.
        _emit_event(video_id, stage, "error", payload={"error": str(exc)})
        logger.error("Stage 5 failed for video_id=%s: %s", video_id, exc)
        raise
    except Exception as exc:
        session.rollback()
        _emit_event(video_id, stage, "error", payload={"error": str(exc)})
        logger.error("Stage 5 failed for video_id=%s: %s", video_id, exc)
        raise self.retry(exc=exc)
    finally:
        session.close()
# ── Stage 6: Embed & Index ───────────────────────────────────────────────────
@celery_app.task(bind=True, max_retries=0)
def stage6_embed_and_index(self, video_id: str) -> str:
    """Stage 6: embed the video's key moments and their technique pages into Qdrant.

    This stage is purely a side effect. Any exception is logged and
    swallowed, and the task is declared with max_retries=0, so a failure
    here never fails the chain. processing_status is not touched. The
    embedded text is:
    - for a page: "title summary topic_category";
    - for a moment: "title summary".

    Returns:
        *video_id*, so the task can be chained.
    """
    start = time.monotonic()
    logger.info("Stage 6 (embed & index) starting for video_id=%s", video_id)
    settings = get_settings()
    session = _get_sync_session()
    try:
        moments = (
            session.execute(
                select(KeyMoment)
                .where(KeyMoment.source_video_id == video_id)
                .order_by(KeyMoment.start_time)
            )
            .scalars()
            .all()
        )
        # Pages are discovered through the moments that link to them.
        linked_page_ids = {
            km.technique_page_id for km in moments if km.technique_page_id is not None
        }
        pages = (
            session.execute(
                select(TechniquePage).where(TechniquePage.id.in_(linked_page_ids))
            ).scalars().all()
            if linked_page_ids
            else []
        )
        if not (moments or pages):
            logger.info("Stage 6: No moments or pages for video_id=%s, skipping.", video_id)
            return video_id

        embedder = EmbeddingClient(settings)
        index = QdrantManager(settings)
        index.ensure_collection()  # must exist before any upsert

        if pages:
            page_texts = [
                f"{tp.title} {tp.summary or ''} {tp.topic_category or ''}".strip()
                for tp in pages
            ]
            page_payloads = [
                {
                    "page_id": str(tp.id),
                    "creator_id": str(tp.creator_id),
                    "title": tp.title,
                    "topic_category": tp.topic_category or "",
                    "topic_tags": tp.topic_tags or [],
                    "summary": tp.summary or "",
                }
                for tp in pages
            ]
            page_vectors = embedder.embed(page_texts)
            if not page_vectors:
                logger.warning(
                    "Stage 6: Embedding returned empty for %d technique pages (video_id=%s). "
                    "Skipping page upsert.",
                    len(page_texts), video_id,
                )
            else:
                index.upsert_technique_pages(page_payloads, page_vectors)
                logger.info(
                    "Stage 6: Upserted %d technique page vectors for video_id=%s",
                    len(page_vectors), video_id,
                )

        if moments:
            moment_texts = [f"{km.title} {km.summary or ''}".strip() for km in moments]
            moment_payloads = [
                {
                    "moment_id": str(km.id),
                    "source_video_id": str(km.source_video_id),
                    "title": km.title,
                    "start_time": km.start_time,
                    "end_time": km.end_time,
                    "content_type": km.content_type.value,
                }
                for km in moments
            ]
            moment_vectors = embedder.embed(moment_texts)
            if not moment_vectors:
                logger.warning(
                    "Stage 6: Embedding returned empty for %d key moments (video_id=%s). "
                    "Skipping moment upsert.",
                    len(moment_texts), video_id,
                )
            else:
                index.upsert_key_moments(moment_payloads, moment_vectors)
                logger.info(
                    "Stage 6: Upserted %d key moment vectors for video_id=%s",
                    len(moment_vectors), video_id,
                )

        elapsed = time.monotonic() - start
        logger.info(
            "Stage 6 (embed & index) completed for video_id=%s in %.1fs — "
            "%d pages, %d moments processed",
            video_id, elapsed, len(pages), len(moments),
        )
        return video_id
    except Exception as exc:
        # Non-blocking by design: report and let the chain continue.
        logger.error(
            "Stage 6 failed for video_id=%s: %s. "
            "Pipeline continues — embeddings can be regenerated later.",
            video_id, exc,
        )
        return video_id
    finally:
        session.close()
def _snapshot_prior_pages(video_id: str) -> None:
    """Record in Redis which technique pages this video's key moments link to.

    The distinct non-null ``technique_page_id`` values of the video's
    KeyMoments are stored as a JSON list under
    ``chrysopedia:prior_pages:{video_id}``, with a 24-hour expiry. Stage 5
    can then update those pages instead of creating duplicates after the
    moments are recreated. NOTE(review): confirm where existing key_moments
    are removed on reprocessing; the stage 3 task in this module only
    inserts them.

    When no pages are linked, any earlier snapshot for the video is deleted
    so stage 5 cannot act on stale IDs from a previous run. That deletion is
    best-effort: failures are logged, not raised.
    """
    import redis

    session = _get_sync_session()
    try:
        # Find technique pages linked via this video's key moments
        rows = session.execute(
            select(KeyMoment.technique_page_id)
            .where(
                KeyMoment.source_video_id == video_id,
                KeyMoment.technique_page_id.isnot(None),
            )
            .distinct()
        ).scalars().all()
    finally:
        session.close()

    page_ids = [str(pid) for pid in rows]
    key = f"chrysopedia:prior_pages:{video_id}"
    client = redis.Redis.from_url(get_settings().redis_url)
    try:
        if page_ids:
            client.set(key, json.dumps(page_ids), ex=86400)
            logger.info(
                "Snapshot %d prior technique pages for video_id=%s: %s",
                len(page_ids), video_id, page_ids,
            )
        else:
            logger.info("No prior technique pages for video_id=%s", video_id)
            try:
                client.delete(key)  # drop any stale snapshot from an earlier run
            except Exception as exc:
                logger.warning(
                    "Could not clear stale prior-pages snapshot for video_id=%s: %s",
                    video_id, exc,
                )
    finally:
        client.close()
def _load_prior_pages(video_id: str) -> list[str]:
    """Fetch the prior technique page IDs snapshotted for *video_id*.

    Returns:
        The JSON list stored under ``chrysopedia:prior_pages:{video_id}``,
        or ``[]`` when the key is absent (no snapshot, or expired).
    """
    import redis

    settings = get_settings()
    client = redis.Redis.from_url(settings.redis_url)
    try:
        raw = client.get(f"chrysopedia:prior_pages:{video_id}")
    finally:
        # Release this call's connection pool instead of leaving it to GC.
        client.close()
    if raw is None:
        return []
    return json.loads(raw)
# ── Orchestrator ─────────────────────────────────────────────────────────────
@celery_app.task
def run_pipeline(video_id: str) -> str:
    """Dispatch the remaining pipeline stages for *video_id* as a Celery chain.

    Resumable: which stages run depends on the video's current
    processing_status.

    - pending / transcribed → stages 2, 3, 4, 5, 6
    - extracted → stages 4, 5, 6
    - reviewed / published → no-op
    - any other status → no-op, with a warning logged

    Prior technique-page links are snapshotted for stage 5 only when a chain
    is actually dispatched.

    Returns:
        *video_id*.

    Raises:
        ValueError: if no SourceVideo with *video_id* exists.
    """
    logger.info("run_pipeline starting for video_id=%s", video_id)
    session = _get_sync_session()
    try:
        video = session.execute(
            select(SourceVideo).where(SourceVideo.id == video_id)
        ).scalar_one_or_none()
        if video is None:
            logger.error("run_pipeline: video_id=%s not found", video_id)
            raise ValueError(f"Video not found: {video_id}")
        status = video.processing_status
        logger.info(
            "run_pipeline: video_id=%s current status=%s", video_id, status.value
        )
    finally:
        session.close()

    if status in (ProcessingStatus.reviewed, ProcessingStatus.published):
        logger.info(
            "run_pipeline: video_id=%s already at status=%s, nothing to do.",
            video_id, status.value,
        )
        return video_id

    # Build the chain based on current status; each stage passes video_id on.
    if status in (ProcessingStatus.pending, ProcessingStatus.transcribed):
        stages = [
            stage2_segmentation.s(video_id),
            stage3_extraction.s(),  # receives video_id from previous
            stage4_classification.s(),
            stage5_synthesis.s(),
            stage6_embed_and_index.s(),
        ]
    elif status == ProcessingStatus.extracted:
        stages = [
            stage4_classification.s(video_id),
            stage5_synthesis.s(),
            stage6_embed_and_index.s(),
        ]
    else:
        logger.warning(
            "run_pipeline: video_id=%s has unhandled status=%s; no stages dispatched.",
            video_id, status.value,
        )
        return video_id

    # Snapshot prior technique pages before the re-run resets key_moments.
    _snapshot_prior_pages(video_id)
    celery_chain(*stages).apply_async()
    logger.info(
        "run_pipeline: dispatched %d stages for video_id=%s",
        len(stages), video_id,
    )
    return video_id