chrysopedia/backend/pipeline/stages.py
jlightner a60f4074dc chore: Add GET/PUT shorts-template admin endpoints, collapsible templat…
- "backend/routers/creators.py"
- "backend/schemas.py"
- "frontend/src/api/templates.ts"
- "frontend/src/pages/HighlightQueue.tsx"
- "frontend/src/pages/HighlightQueue.module.css"
- "backend/routers/shorts.py"
- "backend/pipeline/stages.py"
- "frontend/src/api/shorts.ts"

GSD-Task: S04/T03
2026-04-04 11:25:29 +00:00

3188 lines
123 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Pipeline stage tasks (stages 2-5) and run_pipeline orchestrator.
Each stage reads from PostgreSQL via sync SQLAlchemy, loads its prompt
template from disk, calls the LLM client, parses the response, writes
results back, and updates processing_status on SourceVideo.
Celery tasks are synchronous — all DB access uses ``sqlalchemy.orm.Session``.
"""
from __future__ import annotations
import hashlib
import json
import logging
import os
import re
import secrets
import subprocess
import time
from collections import defaultdict
from pathlib import Path
import yaml
from pydantic import ValidationError
from sqlalchemy import create_engine, func, select
from sqlalchemy.orm import Session, sessionmaker
from config import get_settings
from sqlalchemy.dialects.postgresql import insert as pg_insert
from models import (
Creator,
HighlightCandidate,
KeyMoment,
KeyMomentContentType,
PipelineEvent,
ProcessingStatus,
SourceVideo,
TechniquePage,
TechniquePageVersion,
TechniquePageVideo,
TranscriptSegment,
)
from pipeline.embedding_client import EmbeddingClient
from pipeline.llm_client import LLMClient, LLMResponse, estimate_max_tokens
from pipeline.qdrant_client import QdrantManager
from pipeline.schemas import (
ClassificationResult,
ExtractionResult,
SegmentationResult,
SynthesisResult,
)
from worker import celery_app
logger = logging.getLogger(__name__)


class LLMTruncationError(RuntimeError):
    """Raised when the LLM response was truncated (finish_reason=length)."""
# ── Error status helper ──────────────────────────────────────────────────────
def _set_error_status(video_id: str, stage_name: str, error: Exception) -> None:
    """Mark a video as errored when a pipeline stage fails permanently.

    Best-effort: any failure while flagging is logged and swallowed so the
    original stage exception remains the one surfaced to Celery.

    Parameters
    ----------
    video_id: primary key of the SourceVideo to flag.
    stage_name: human-readable stage label used in the failure log line.
    error: the original stage exception (kept in the signature for
        call-site clarity; not re-raised here).
    """
    try:
        session = _get_sync_session()
        try:
            video = session.execute(
                select(SourceVideo).where(SourceVideo.id == video_id)
            ).scalar_one_or_none()
            if video:
                video.processing_status = ProcessingStatus.error
                session.commit()
        finally:
            # Always release the session — the original leaked it when the
            # video row was missing or the commit raised.
            session.close()
    except Exception as mark_exc:
        logger.error(
            "Failed to mark video_id=%s as error after %s failure: %s",
            video_id, stage_name, mark_exc,
        )
# ── Pipeline event persistence ───────────────────────────────────────────────
def _emit_event(
    video_id: str,
    stage: str,
    event_type: str,
    *,
    run_id: str | None = None,
    prompt_tokens: int | None = None,
    completion_tokens: int | None = None,
    total_tokens: int | None = None,
    model: str | None = None,
    duration_ms: int | None = None,
    payload: dict | None = None,
    system_prompt_text: str | None = None,
    user_prompt_text: str | None = None,
    response_text: str | None = None,
) -> None:
    """Persist a pipeline event to the DB. Best-effort -- failures logged, not raised."""
    try:
        session = _get_sync_session()
        try:
            # One insert + commit per event; short-lived session per call.
            session.add(
                PipelineEvent(
                    video_id=video_id,
                    run_id=run_id,
                    stage=stage,
                    event_type=event_type,
                    prompt_tokens=prompt_tokens,
                    completion_tokens=completion_tokens,
                    total_tokens=total_tokens,
                    model=model,
                    duration_ms=duration_ms,
                    payload=payload,
                    system_prompt_text=system_prompt_text,
                    user_prompt_text=user_prompt_text,
                    response_text=response_text,
                )
            )
            session.commit()
        finally:
            session.close()
    except Exception as exc:
        logger.warning("Failed to emit pipeline event: %s", exc)
def _is_debug_mode() -> bool:
    """Check if debug mode is enabled via Redis. Falls back to config setting.

    The Redis key ``chrysopedia:debug_mode`` holds "true"/"false" as bytes;
    any Redis failure silently falls through to the static ``debug_mode``
    setting (defaulting to False when the setting is absent).
    """
    try:
        import redis
        settings = get_settings()
        r = redis.from_url(settings.redis_url)
        try:
            val = r.get("chrysopedia:debug_mode")
        finally:
            # Close even when GET raises — the original leaked the
            # connection on any Redis error.
            r.close()
        if val is not None:
            return val.decode().lower() == "true"
    except Exception:
        # Redis unreachable/misconfigured: fall through to config.
        pass
    return getattr(get_settings(), "debug_mode", False)
def _make_llm_callback(
    video_id: str,
    stage: str,
    system_prompt: str | None = None,
    user_prompt: str | None = None,
    run_id: str | None = None,
    context_label: str | None = None,
    request_params: dict | None = None,
):
    """Create an on_complete callback for LLMClient that emits llm_call events.

    When debug mode is enabled, captures full system prompt, user prompt,
    and response text on each llm_call event.

    Parameters
    ----------
    request_params:
        Dict of LLM request parameters (max_tokens, model_override, modality,
        response_model, temperature, etc.) to store in the event payload for
        debugging which parameters were actually sent to the API.
    """
    # Snapshot debug mode once at callback-creation time.
    capture_full = _is_debug_mode()

    def callback(*, model=None, prompt_tokens=None, completion_tokens=None,
                 total_tokens=None, content=None, finish_reason=None,
                 is_fallback=False, **_kwargs):
        # Keep only the first 2000 chars in the payload preview.
        preview = content
        if content and len(content) > 2000:
            preview = content[:2000]
        event_payload = {
            "content_preview": preview,
            "content_length": len(content) if content else 0,
            "finish_reason": finish_reason,
            "is_fallback": is_fallback,
        }
        if context_label:
            event_payload["context"] = context_label
        if request_params:
            event_payload["request_params"] = request_params
        _emit_event(
            video_id=video_id,
            stage=stage,
            event_type="llm_call",
            run_id=run_id,
            model=model,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens,
            payload=event_payload,
            # Full texts are stored only in debug mode.
            system_prompt_text=system_prompt if capture_full else None,
            user_prompt_text=user_prompt if capture_full else None,
            response_text=content if capture_full else None,
        )

    return callback
def _build_request_params(
    max_tokens: int,
    model_override: str | None,
    modality: str,
    response_model: str,
    hard_limit: int,
) -> dict:
    """Build the request_params dict for pipeline event logging.

    Separates actual API params (sent to the LLM) from internal config
    (used by our estimator only) so the debug JSON is unambiguous.
    """
    settings = get_settings()
    if modality == "chat":
        response_format = "json_object"
    else:
        response_format = "none (thinking mode)"
    api_params = {
        "max_tokens": max_tokens,
        "model": model_override or settings.llm_model,
        "temperature": settings.llm_temperature,
        "response_format": response_format,
    }
    pipeline_config = {
        "modality": modality,
        "response_model": response_model,
        "estimator_hard_limit": hard_limit,
        "fallback_max_tokens": settings.llm_max_tokens,
    }
    return {"api_params": api_params, "pipeline_config": pipeline_config}
# ── Helpers ──────────────────────────────────────────────────────────────────
_engine = None
_SessionLocal = None


def _get_sync_engine():
    """Lazily build (and cache) a sync SQLAlchemy engine from settings.

    The configured URL may use the async driver; it is rewritten to the
    sync psycopg2 driver because Celery tasks run blocking DB code.
    """
    global _engine
    if _engine is not None:
        return _engine
    sync_url = get_settings().database_url.replace(
        "postgresql+asyncpg://", "postgresql+psycopg2://"
    )
    _engine = create_engine(sync_url, pool_pre_ping=True, pool_size=5, max_overflow=10)
    return _engine
def _get_sync_session() -> Session:
    """Open a new sync Session; the sessionmaker is built lazily once."""
    global _SessionLocal
    factory = _SessionLocal
    if factory is None:
        factory = sessionmaker(bind=_get_sync_engine())
        _SessionLocal = factory
    return factory()
def _load_prompt(template_name: str, video_id: str | None = None) -> str:
    """Read a prompt template from the prompts directory.

    If ``video_id`` is provided, checks Redis for a per-video prompt override
    (key: ``chrysopedia:prompt_override:{video_id}:{template_name}``) before
    falling back to the on-disk template. Overrides are set by
    ``run_single_stage`` for single-stage re-runs with custom prompts.

    Raises FileNotFoundError if no override exists and the template is missing.
    """
    settings = get_settings()  # fetched once; used for both Redis and disk path
    # Check for per-video prompt override in Redis
    if video_id:
        try:
            import redis
            r = redis.Redis.from_url(settings.redis_url)
            try:
                override_key = f"chrysopedia:prompt_override:{video_id}:{template_name}"
                override = r.get(override_key)
            finally:
                # Original code never closed this client — release it even
                # when GET raises.
                r.close()
            if override:
                prompt_text = override.decode("utf-8")
                logger.info(
                    "[PROMPT] Using override from Redis: video_id=%s, template=%s (%d chars)",
                    video_id, template_name, len(prompt_text),
                )
                return prompt_text
        except Exception as exc:
            # Override lookup is best-effort; fall back to the disk template.
            logger.warning("[PROMPT] Redis override check failed: %s", exc)
    path = Path(settings.prompts_path) / template_name
    if not path.exists():
        logger.error("Prompt template not found: %s", path)
        raise FileNotFoundError(f"Prompt template not found: {path}")
    return path.read_text(encoding="utf-8")
def _get_llm_client() -> LLMClient:
    """Return an LLMClient configured from settings."""
    settings = get_settings()
    return LLMClient(settings)
def _get_stage_config(stage_num: int) -> tuple[str | None, str]:
    """Return (model_override, modality) for a pipeline stage.

    Reads stage-specific config from Settings. If the stage-specific model
    is None/empty, returns None (LLMClient will use its default). If the
    stage-specific modality is unset, defaults to "chat".
    """
    settings = get_settings()
    model_override = getattr(settings, f"llm_stage{stage_num}_model", None)
    modality = getattr(settings, f"llm_stage{stage_num}_modality", None)
    # Empty strings collapse to the fallbacks, same as None.
    return (model_override or None, modality or "chat")
def _load_canonical_tags() -> dict:
    """Load canonical tag taxonomy from config/canonical_tags.yaml."""
    # The worker's CWD may be the repo root or backend/ — try both.
    search_paths = [
        Path("config/canonical_tags.yaml"),
        Path("../config/canonical_tags.yaml"),
    ]
    for path in search_paths:
        if path.exists():
            return yaml.safe_load(path.read_text(encoding="utf-8"))
    raise FileNotFoundError(
        "canonical_tags.yaml not found. Searched: " + ", ".join(str(c) for c in search_paths)
    )
def _format_taxonomy_for_prompt(tags_data: dict) -> str:
    """Format the canonical tags taxonomy as readable text for the LLM prompt."""
    chunks: list[str] = []
    for category in tags_data.get("categories", []):
        sub_topics = ", ".join(category.get("sub_topics", []))
        # Blank entry after each category yields a separating newline on join.
        chunks.extend([
            f"Category: {category['name']}",
            f" Description: {category['description']}",
            f" Sub-topics: {sub_topics}",
            "",
        ])
    return "\n".join(chunks)
def _safe_parse_llm_response(
    raw,
    model_cls,
    llm: LLMClient,
    system_prompt: str,
    user_prompt: str,
    modality: str = "chat",
    model_override: str | None = None,
    max_tokens: int | None = None,
):
    """Parse LLM response with truncation detection and one retry on failure.

    If the response was truncated (finish_reason=length), raises
    LLMTruncationError immediately — retrying with a JSON nudge would only
    make things worse by adding tokens to an already-too-large prompt.
    For non-truncation parse failures: retry once with a JSON nudge, then
    raise on second failure.
    """
    # Check for truncation before attempting parse
    is_truncated = isinstance(raw, LLMResponse) and raw.truncated
    if is_truncated:
        logger.warning(
            "LLM response truncated (finish=length) for %s. "
            "prompt_tokens=%s, completion_tokens=%s. Will not retry with nudge.",
            model_cls.__name__,
            getattr(raw, "prompt_tokens", "?"),
            getattr(raw, "completion_tokens", "?"),
        )
    try:
        # Parse is still attempted on a truncated response: if the cut-off
        # happened after the closing brace, the result is usable as-is.
        return llm.parse_response(raw, model_cls)
    except (ValidationError, ValueError, json.JSONDecodeError) as exc:
        if is_truncated:
            # Truncated AND unparseable — retrying cannot help, surface a
            # dedicated error so callers can distinguish it from bad JSON.
            raise LLMTruncationError(
                f"LLM output truncated for {model_cls.__name__}: "
                f"prompt_tokens={getattr(raw, 'prompt_tokens', '?')}, "
                f"completion_tokens={getattr(raw, 'completion_tokens', '?')}. "
                f"Response too large for model context window."
            ) from exc
        logger.warning(
            "First parse attempt failed for %s (%s). Retrying with JSON nudge. "
            "Raw response (first 500 chars): %.500s",
            model_cls.__name__,
            type(exc).__name__,
            raw,
        )
        # Retry with explicit JSON instruction
        nudge_prompt = user_prompt + "\n\nIMPORTANT: Output ONLY valid JSON. No markdown, no explanation."
        retry_raw = llm.complete(
            system_prompt, nudge_prompt, response_model=model_cls,
            modality=modality, model_override=model_override,
            max_tokens=max_tokens,
        )
        # Second parse failure propagates to the caller unretried.
        return llm.parse_response(retry_raw, model_cls)
# ── Stage 2: Segmentation ───────────────────────────────────────────────────
@celery_app.task(bind=True, max_retries=3, default_retry_delay=30)
def stage2_segmentation(self, video_id: str, run_id: str | None = None) -> str:
    """Analyze transcript segments and identify topic boundaries.

    Loads all TranscriptSegment rows for the video, sends them to the LLM
    for topic boundary detection, and updates topic_label on each segment.
    Returns the video_id for chain compatibility.
    """
    start = time.monotonic()
    logger.info("Stage 2 (segmentation) starting for video_id=%s", video_id)
    _emit_event(video_id, "stage2_segmentation", "start", run_id=run_id)
    session = _get_sync_session()
    try:
        # Load segments ordered by index
        segments = (
            session.execute(
                select(TranscriptSegment)
                .where(TranscriptSegment.source_video_id == video_id)
                .order_by(TranscriptSegment.segment_index)
            )
            .scalars()
            .all()
        )
        if not segments:
            # Nothing to segment — succeed quietly so the chain continues.
            logger.info("Stage 2: No segments found for video_id=%s, skipping.", video_id)
            return video_id
        # Build transcript text with indices for the LLM
        transcript_lines = []
        for seg in segments:
            transcript_lines.append(
                f"[{seg.segment_index}] ({seg.start_time:.1f}s - {seg.end_time:.1f}s) {seg.text}"
            )
        transcript_text = "\n".join(transcript_lines)
        # Load prompt and call LLM
        system_prompt = _load_prompt("stage2_segmentation.txt", video_id=video_id)
        user_prompt = f"<transcript>\n{transcript_text}\n</transcript>"
        llm = _get_llm_client()
        model_override, modality = _get_stage_config(2)
        hard_limit = get_settings().llm_max_tokens_hard_limit
        max_tokens = estimate_max_tokens(system_prompt, user_prompt, stage="stage2_segmentation", hard_limit=hard_limit)
        logger.info("Stage 2 using model=%s, modality=%s, max_tokens=%d", model_override or "default", modality, max_tokens)
        _s2_request_params = _build_request_params(max_tokens, model_override, modality, "SegmentationResult", hard_limit)
        raw = llm.complete(
            system_prompt, user_prompt, response_model=SegmentationResult,
            on_complete=_make_llm_callback(
                video_id, "stage2_segmentation",
                system_prompt=system_prompt, user_prompt=user_prompt,
                run_id=run_id, request_params=_s2_request_params,
            ),
            modality=modality, model_override=model_override, max_tokens=max_tokens,
        )
        result = _safe_parse_llm_response(
            raw, SegmentationResult, llm, system_prompt, user_prompt,
            modality=modality, model_override=model_override, max_tokens=max_tokens,
        )
        # Update topic_label on each segment row. The LLM returns inclusive
        # [start_index, end_index] ranges over the original segment indices;
        # indices outside the loaded set are silently ignored.
        seg_by_index = {s.segment_index: s for s in segments}
        for topic_seg in result.segments:
            for idx in range(topic_seg.start_index, topic_seg.end_index + 1):
                if idx in seg_by_index:
                    seg_by_index[idx].topic_label = topic_seg.topic_label
        session.commit()
        elapsed = time.monotonic() - start
        _emit_event(video_id, "stage2_segmentation", "complete", run_id=run_id)
        logger.info(
            "Stage 2 (segmentation) completed for video_id=%s in %.1fs — %d topic groups found",
            video_id, elapsed, len(result.segments),
        )
        return video_id
    except FileNotFoundError:
        raise  # Don't retry missing prompt files
    except Exception as exc:
        session.rollback()
        _emit_event(video_id, "stage2_segmentation", "error", run_id=run_id, payload={"error": str(exc)})
        logger.error("Stage 2 failed for video_id=%s: %s", video_id, exc)
        # Celery retry (up to max_retries=3, 30s apart); retry() raises,
        # so this re-raise hands control back to the worker.
        raise self.retry(exc=exc)
    finally:
        session.close()
# ── Stage 3: Extraction ─────────────────────────────────────────────────────
@celery_app.task(bind=True, max_retries=3, default_retry_delay=30)
def stage3_extraction(self, video_id: str, run_id: str | None = None) -> str:
    """Extract key moments from each topic segment group.

    Groups segments by topic_label, calls the LLM for each group to extract
    moments, creates KeyMoment rows, and sets processing_status=extracted.
    Returns the video_id for chain compatibility.
    """
    # NOTE(review): the docstring mentions processing_status=extracted but
    # this function body never updates processing_status — confirm whether
    # that happens elsewhere or the docstring is stale.
    start = time.monotonic()
    logger.info("Stage 3 (extraction) starting for video_id=%s", video_id)
    _emit_event(video_id, "stage3_extraction", "start", run_id=run_id)
    session = _get_sync_session()
    try:
        # Load segments with topic labels
        segments = (
            session.execute(
                select(TranscriptSegment)
                .where(TranscriptSegment.source_video_id == video_id)
                .order_by(TranscriptSegment.segment_index)
            )
            .scalars()
            .all()
        )
        if not segments:
            logger.info("Stage 3: No segments found for video_id=%s, skipping.", video_id)
            return video_id
        # Group segments by topic_label (stage 2 output); unlabeled segments
        # are pooled into a single "unlabeled" group.
        groups: dict[str, list[TranscriptSegment]] = defaultdict(list)
        for seg in segments:
            label = seg.topic_label or "unlabeled"
            groups[label].append(seg)
        system_prompt = _load_prompt("stage3_extraction.txt", video_id=video_id)
        llm = _get_llm_client()
        model_override, modality = _get_stage_config(3)
        hard_limit = get_settings().llm_max_tokens_hard_limit
        logger.info("Stage 3 using model=%s, modality=%s", model_override or "default", modality)
        total_moments = 0
        # One LLM call per topic group.
        for topic_label, group_segs in groups.items():
            # Build segment text for this group
            seg_lines = []
            for seg in group_segs:
                seg_lines.append(
                    f"({seg.start_time:.1f}s - {seg.end_time:.1f}s) {seg.text}"
                )
            segment_text = "\n".join(seg_lines)
            user_prompt = (
                f"Topic: {topic_label}\n\n"
                f"<segment>\n{segment_text}\n</segment>"
            )
            max_tokens = estimate_max_tokens(system_prompt, user_prompt, stage="stage3_extraction", hard_limit=hard_limit)
            _s3_request_params = _build_request_params(max_tokens, model_override, modality, "ExtractionResult", hard_limit)
            raw = llm.complete(
                system_prompt, user_prompt, response_model=ExtractionResult,
                on_complete=_make_llm_callback(
                    video_id, "stage3_extraction",
                    system_prompt=system_prompt, user_prompt=user_prompt,
                    run_id=run_id, context_label=topic_label,
                    request_params=_s3_request_params,
                ),
                modality=modality, model_override=model_override, max_tokens=max_tokens,
            )
            result = _safe_parse_llm_response(
                raw, ExtractionResult, llm, system_prompt, user_prompt,
                modality=modality, model_override=model_override, max_tokens=max_tokens,
            )
            # Create KeyMoment rows
            for moment in result.moments:
                # Validate content_type against enum
                try:
                    ct = KeyMomentContentType(moment.content_type)
                except ValueError:
                    # LLM returned an unknown content type — default it.
                    ct = KeyMomentContentType.technique
                km = KeyMoment(
                    source_video_id=video_id,
                    title=moment.title,
                    summary=moment.summary,
                    start_time=moment.start_time,
                    end_time=moment.end_time,
                    content_type=ct,
                    plugins=moment.plugins if moment.plugins else None,
                    raw_transcript=moment.raw_transcript or None,
                )
                session.add(km)
                total_moments += 1
        # Single commit after all groups — a failure in any group rolls back
        # all moments for this run.
        session.commit()
        elapsed = time.monotonic() - start
        _emit_event(video_id, "stage3_extraction", "complete", run_id=run_id)
        logger.info(
            "Stage 3 (extraction) completed for video_id=%s in %.1fs — %d moments created",
            video_id, elapsed, total_moments,
        )
        return video_id
    except FileNotFoundError:
        # Missing prompt files are permanent — don't retry.
        raise
    except Exception as exc:
        session.rollback()
        _emit_event(video_id, "stage3_extraction", "error", run_id=run_id, payload={"error": str(exc)})
        logger.error("Stage 3 failed for video_id=%s: %s", video_id, exc)
        raise self.retry(exc=exc)
    finally:
        session.close()
# ── Stage 4: Classification ─────────────────────────────────────────────────
# Maximum moments per classification batch. Keeps each LLM call well within
# context window limits. Batches are classified independently and merged.
_STAGE4_BATCH_SIZE = 20
def _classify_moment_batch(
    moments_batch: list,
    batch_offset: int,
    taxonomy_text: str,
    system_prompt: str,
    llm: LLMClient,
    model_override: str | None,
    modality: str,
    hard_limit: int,
    video_id: str,
    run_id: str | None,
) -> ClassificationResult:
    """Classify a single batch of moments. Raises on failure.

    ``moments_batch`` is a slice of KeyMoment rows; indices in the prompt
    (and in the returned classifications) are 0-based *within the batch* —
    the caller remaps them to global indices using ``batch_offset``.
    """
    moments_lines = []
    for i, m in enumerate(moments_batch):
        moments_lines.append(
            f"[{i}] Title: {m.title}\n"
            f" Summary: {m.summary}\n"
            f" Content type: {m.content_type.value}\n"
            f" Plugins: {', '.join(m.plugins) if m.plugins else 'none'}"
        )
    moments_text = "\n\n".join(moments_lines)
    user_prompt = (
        f"<taxonomy>\n{taxonomy_text}\n</taxonomy>\n\n"
        f"<moments>\n{moments_text}\n</moments>"
    )
    max_tokens = estimate_max_tokens(
        system_prompt, user_prompt,
        stage="stage4_classification", hard_limit=hard_limit,
    )
    # Human-readable label used for logging and event context.
    batch_label = f"batch {batch_offset // _STAGE4_BATCH_SIZE + 1} (moments {batch_offset}-{batch_offset + len(moments_batch) - 1})"
    logger.info(
        "Stage 4 classifying %s, max_tokens=%d",
        batch_label, max_tokens,
    )
    raw = llm.complete(
        system_prompt, user_prompt,
        response_model=ClassificationResult,
        on_complete=_make_llm_callback(
            video_id, "stage4_classification",
            system_prompt=system_prompt, user_prompt=user_prompt,
            run_id=run_id, context_label=batch_label,
            request_params=_build_request_params(max_tokens, model_override, modality, "ClassificationResult", hard_limit),
        ),
        modality=modality, model_override=model_override,
        max_tokens=max_tokens,
    )
    return _safe_parse_llm_response(
        raw, ClassificationResult, llm, system_prompt, user_prompt,
        modality=modality, model_override=model_override,
        max_tokens=max_tokens,
    )
@celery_app.task(bind=True, max_retries=3, default_retry_delay=30)
def stage4_classification(self, video_id: str, run_id: str | None = None) -> str:
    """Classify key moments against the canonical tag taxonomy.

    Loads all KeyMoment rows for the video, sends them to the LLM with the
    canonical taxonomy, and stores classification results in Redis for
    stage 5 consumption. Updates content_type if the classifier overrides it.
    For large moment sets, automatically batches into groups of
    _STAGE4_BATCH_SIZE to stay within model context window limits.
    Stage 4 does NOT change processing_status.
    Returns the video_id for chain compatibility.
    """
    start = time.monotonic()
    logger.info("Stage 4 (classification) starting for video_id=%s", video_id)
    _emit_event(video_id, "stage4_classification", "start", run_id=run_id)
    session = _get_sync_session()
    try:
        # Load key moments
        moments = (
            session.execute(
                select(KeyMoment)
                .where(KeyMoment.source_video_id == video_id)
                .order_by(KeyMoment.start_time)
            )
            .scalars()
            .all()
        )
        if not moments:
            logger.info("Stage 4: No moments found for video_id=%s, skipping.", video_id)
            # Still store an empty list so stage 5 finds a record.
            _store_classification_data(video_id, [])
            return video_id
        # Load canonical tags
        tags_data = _load_canonical_tags()
        taxonomy_text = _format_taxonomy_for_prompt(tags_data)
        system_prompt = _load_prompt("stage4_classification.txt", video_id=video_id)
        llm = _get_llm_client()
        model_override, modality = _get_stage_config(4)
        hard_limit = get_settings().llm_max_tokens_hard_limit
        # Batch moments for classification
        all_classifications = []
        for batch_start in range(0, len(moments), _STAGE4_BATCH_SIZE):
            batch = moments[batch_start:batch_start + _STAGE4_BATCH_SIZE]
            result = _classify_moment_batch(
                batch, batch_start, taxonomy_text, system_prompt,
                llm, model_override, modality, hard_limit,
                video_id, run_id,
            )
            # Reindex: batch uses 0-based indices, remap to global indices
            for cls in result.classifications:
                cls.moment_index += batch_start
            all_classifications.extend(result.classifications)
        # Apply content_type overrides and prepare classification data for stage 5
        classification_data = []
        for cls in all_classifications:
            # Out-of-range indices from the LLM are silently dropped.
            if 0 <= cls.moment_index < len(moments):
                moment = moments[cls.moment_index]
                # Apply content_type override if provided
                if cls.content_type_override:
                    try:
                        moment.content_type = KeyMomentContentType(cls.content_type_override)
                    except ValueError:
                        # Unknown override value — keep the original type.
                        pass
                classification_data.append({
                    "moment_id": str(moment.id),
                    # .title() normalizes the category casing for grouping.
                    "topic_category": cls.topic_category.strip().title(),
                    "topic_tags": cls.topic_tags,
                })
        session.commit()
        # Store classification data in Redis for stage 5
        _store_classification_data(video_id, classification_data)
        elapsed = time.monotonic() - start
        num_batches = (len(moments) + _STAGE4_BATCH_SIZE - 1) // _STAGE4_BATCH_SIZE
        _emit_event(video_id, "stage4_classification", "complete", run_id=run_id)
        logger.info(
            "Stage 4 (classification) completed for video_id=%s in %.1fs — "
            "%d moments classified in %d batch(es)",
            video_id, elapsed, len(classification_data), num_batches,
        )
        return video_id
    except FileNotFoundError:
        # Missing prompt/taxonomy files are permanent — don't retry.
        raise
    except Exception as exc:
        session.rollback()
        _emit_event(video_id, "stage4_classification", "error", run_id=run_id, payload={"error": str(exc)})
        logger.error("Stage 4 failed for video_id=%s: %s", video_id, exc)
        raise self.retry(exc=exc)
    finally:
        session.close()
def _store_classification_data(video_id: str, data: list[dict]) -> None:
    """Store classification data in Redis (cache) and PostgreSQL (durable).

    Dual-write ensures classification data survives Redis TTL expiry or flush.
    Redis serves as the fast-path cache; PostgreSQL is the durable fallback.
    Both writes are best-effort: a failure in either store is logged, never
    raised, so stage 4 completion is not blocked by storage issues.
    """
    import redis
    settings = get_settings()
    # Redis: fast cache with 7-day TTL
    try:
        r = redis.Redis.from_url(settings.redis_url)
        try:
            key = f"chrysopedia:classification:{video_id}"
            r.set(key, json.dumps(data), ex=604800)  # 7 days
        finally:
            # Original code never closed this client — release the
            # connection even when SET fails.
            r.close()
        logger.info(
            "[CLASSIFY-STORE] Redis write: video_id=%s, %d entries, ttl=7d",
            video_id, len(data),
        )
    except Exception as exc:
        logger.warning(
            "[CLASSIFY-STORE] Redis write failed for video_id=%s: %s", video_id, exc,
        )
    # PostgreSQL: durable storage on SourceVideo.classification_data
    session = _get_sync_session()
    try:
        video = session.execute(
            select(SourceVideo).where(SourceVideo.id == video_id)
        ).scalar_one_or_none()
        if video:
            video.classification_data = data
            session.commit()
            logger.info(
                "[CLASSIFY-STORE] PostgreSQL write: video_id=%s, %d entries",
                video_id, len(data),
            )
        else:
            logger.warning(
                "[CLASSIFY-STORE] Video not found for PostgreSQL write: %s", video_id,
            )
    except Exception as exc:
        session.rollback()
        logger.warning(
            "[CLASSIFY-STORE] PostgreSQL write failed for video_id=%s: %s", video_id, exc,
        )
    finally:
        session.close()
def _load_classification_data(video_id: str) -> list[dict]:
    """Load classification data from Redis (fast path) or PostgreSQL (fallback).

    Tries Redis first. If the key has expired or Redis is unavailable, falls
    back to the durable SourceVideo.classification_data column. Returns an
    empty list when neither store has data.
    """
    import redis
    settings = get_settings()
    # Try Redis first (fast path)
    try:
        r = redis.Redis.from_url(settings.redis_url)
        try:
            key = f"chrysopedia:classification:{video_id}"
            raw = r.get(key)
        finally:
            # Original code never closed this client — release the
            # connection even when GET fails.
            r.close()
        if raw is not None:
            data = json.loads(raw)
            logger.info(
                "[CLASSIFY-LOAD] Source: redis, video_id=%s, %d entries",
                video_id, len(data),
            )
            return data
    except Exception as exc:
        logger.warning(
            "[CLASSIFY-LOAD] Redis unavailable for video_id=%s: %s", video_id, exc,
        )
    # Fallback to PostgreSQL
    logger.info("[CLASSIFY-LOAD] Redis miss, falling back to PostgreSQL for video_id=%s", video_id)
    session = _get_sync_session()
    try:
        video = session.execute(
            select(SourceVideo).where(SourceVideo.id == video_id)
        ).scalar_one_or_none()
        if video and video.classification_data:
            data = video.classification_data
            logger.info(
                "[CLASSIFY-LOAD] Source: postgresql, video_id=%s, %d entries",
                video_id, len(data),
            )
            return data
    except Exception as exc:
        logger.warning(
            "[CLASSIFY-LOAD] PostgreSQL fallback failed for video_id=%s: %s", video_id, exc,
        )
    finally:
        session.close()
    logger.warning(
        "[CLASSIFY-LOAD] No classification data found in Redis or PostgreSQL for video_id=%s",
        video_id,
    )
    return []
def _get_git_commit_sha() -> str:
    """Resolve the git commit SHA used to build this image.

    Resolution order:
    1. /app/.git-commit file (written during Docker build)
    2. git rev-parse --short HEAD (local dev)
    3. GIT_COMMIT_SHA env var / config setting
    4. "unknown"
    """
    # 1. Docker build artifact
    build_file = Path("/app/.git-commit")
    if build_file.exists():
        baked = build_file.read_text(encoding="utf-8").strip()
        if baked and baked != "unknown":
            return baked
    # 2. Local dev — ask git directly
    proc = None
    try:
        proc = subprocess.run(
            ["git", "rev-parse", "--short", "HEAD"],
            capture_output=True, text=True, timeout=5,
        )
    except (FileNotFoundError, subprocess.TimeoutExpired):
        # git missing or hung — fall through to config.
        pass
    if proc is not None and proc.returncode == 0:
        head = proc.stdout.strip()
        if head:
            return head
    # 3. Config / env var fallback
    try:
        configured = get_settings().git_commit_sha
        if configured and configured != "unknown":
            return configured
    except Exception:
        pass
    # 4. Nothing resolvable
    return "unknown"
def _capture_pipeline_metadata() -> dict:
    """Capture current pipeline configuration for version metadata.

    Returns a dict with model names, prompt file SHA-256 hashes, and stage
    modality settings. Handles missing prompt files gracefully.
    """
    settings = get_settings()
    prompts_dir = Path(settings.prompts_path)
    stage_prompts = (
        "stage2_segmentation.txt",
        "stage3_extraction.txt",
        "stage4_classification.txt",
        "stage5_synthesis.txt",
    )
    # Hash each prompt template file; unreadable files hash to "".
    prompt_hashes: dict[str, str] = {}
    for filename in stage_prompts:
        target = prompts_dir / filename
        digest = ""
        try:
            digest = hashlib.sha256(target.read_bytes()).hexdigest()
        except FileNotFoundError:
            logger.warning("Prompt file not found for metadata capture: %s", target)
        except OSError as exc:
            logger.warning("Could not read prompt file %s: %s", target, exc)
        prompt_hashes[filename] = digest
    return {
        "git_commit_sha": _get_git_commit_sha(),
        "models": {
            "stage2": settings.llm_stage2_model,
            "stage3": settings.llm_stage3_model,
            "stage4": settings.llm_stage4_model,
            "stage5": settings.llm_stage5_model,
            "embedding": settings.embedding_model,
        },
        "modalities": {
            "stage2": settings.llm_stage2_modality,
            "stage3": settings.llm_stage3_modality,
            "stage4": settings.llm_stage4_modality,
            "stage5": settings.llm_stage5_modality,
        },
        "prompt_hashes": prompt_hashes,
    }
# ── Stage 5: Synthesis ───────────────────────────────────────────────────────
def _serialize_body_sections(sections) -> list | dict | None:
    """Convert body_sections to JSON-serializable form for DB storage."""
    # Non-list inputs (None or an already-serialized dict) pass through.
    if not isinstance(sections, list):
        return sections
    # Pydantic models expose model_dump(); plain items are kept as-is.
    return [item.model_dump() if hasattr(item, "model_dump") else item for item in sections]
def _compute_page_tags(
    moment_indices: list[int],
    moment_group: list[tuple],
    all_tags: set[str],
) -> list[str] | None:
    """Compute tags for a specific page from its linked moment indices.

    If moment_indices are available, collects tags only from those moments.
    Falls back to all_tags for the category group if no indices provided.
    Returns None rather than an empty list when nothing is collected.
    """
    if not moment_indices:
        return list(all_tags) if all_tags else None
    collected: set[str] = set()
    bound = len(moment_group)
    for idx in moment_indices:
        # Indices outside the group (bad LLM output) are skipped.
        if 0 <= idx < bound:
            _, cls_info = moment_group[idx]
            collected.update(cls_info.get("topic_tags", []))
    return list(collected) if collected else None
def _build_moments_text(
    moment_group: list[tuple[KeyMoment, dict]],
    category: str,
) -> tuple[str, set[str]]:
    """Build the moments prompt text and collect all tags for a group of moments.

    Returns (moments_text, all_tags).
    """
    all_tags: set[str] = set()
    entries: list[str] = []
    for i, (moment, cls_info) in enumerate(moment_group):
        tags = cls_info.get("topic_tags", [])
        all_tags.update(tags)
        plugin_text = ", ".join(moment.plugins) if moment.plugins else "none"
        tag_text = ", ".join(tags) if tags else "none"
        # Cap the transcript excerpt at 300 chars to bound prompt size.
        excerpt = (moment.raw_transcript or "")[:300]
        entries.append(
            f"[{i}] Title: {moment.title}\n"
            f" Summary: {moment.summary}\n"
            f" Content type: {moment.content_type.value}\n"
            f" Time: {moment.start_time:.1f}s - {moment.end_time:.1f}s\n"
            f" Plugins: {plugin_text}\n"
            f" Category: {category}\n"
            f" Tags: {tag_text}\n"
            f" Transcript excerpt: {excerpt}"
        )
    return "\n\n".join(entries), all_tags
def _build_compose_user_prompt(
    existing_page: TechniquePage,
    existing_moments: list[KeyMoment],
    new_moments: list[tuple[KeyMoment, dict]],
    creator_name: str,
) -> str:
    """Build the user prompt for composing new moments into an existing page.

    Existing moments keep indices [0]-[N-1].
    New moments get indices [N]-[N+M-1].
    XML-tagged prompt structure matches test_harness.py build_compose_prompt().
    """
    category = existing_page.topic_category or "Uncategorized"
    # Serialize existing page to dict matching SynthesizedPage shape
    sq = existing_page.source_quality
    # source_quality may be an enum (DB row) or already a plain value.
    sq_value = sq.value if hasattr(sq, "value") else sq
    page_dict = {
        "title": existing_page.title,
        "slug": existing_page.slug,
        "topic_category": existing_page.topic_category,
        "summary": existing_page.summary,
        "body_sections": existing_page.body_sections,
        "signal_chains": existing_page.signal_chains,
        "plugins": existing_page.plugins,
        "source_quality": sq_value,
    }
    # Format existing moments [0]-[N-1] using _build_moments_text pattern
    # Existing moments don't have classification data — use empty dict
    existing_as_tuples = [(m, {}) for m in existing_moments]
    existing_text, _ = _build_moments_text(existing_as_tuples, category)
    # Format new moments [N]-[N+M-1] with offset indices
    n = len(existing_moments)
    new_lines = []
    for i, (m, cls_info) in enumerate(new_moments):
        tags = cls_info.get("topic_tags", [])
        # Same per-moment layout as _build_moments_text, but indices are
        # shifted by n so the LLM sees one continuous numbering.
        new_lines.append(
            f"[{n + i}] Title: {m.title}\n"
            f" Summary: {m.summary}\n"
            f" Content type: {m.content_type.value}\n"
            f" Time: {m.start_time:.1f}s - {m.end_time:.1f}s\n"
            f" Plugins: {', '.join(m.plugins) if m.plugins else 'none'}\n"
            f" Category: {category}\n"
            f" Tags: {', '.join(tags) if tags else 'none'}\n"
            f" Transcript excerpt: {(m.raw_transcript or '')[:300]}"
        )
    new_text = "\n\n".join(new_lines)
    # default=str covers non-JSON types (e.g. datetimes) in page fields.
    page_json = json.dumps(page_dict, indent=2, ensure_ascii=False, default=str)
    return (
        f"<existing_page>\n{page_json}\n</existing_page>\n"
        f"<existing_moments>\n{existing_text}\n</existing_moments>\n"
        f"<new_moments>\n{new_text}\n</new_moments>\n"
        f"<creator>{creator_name}</creator>"
    )
def _compose_into_existing(
    existing_page: TechniquePage,
    existing_moments: list[KeyMoment],
    new_moment_group: list[tuple[KeyMoment, dict]],
    category: str,
    creator_name: str,
    system_prompt: str,
    llm: LLMClient,
    model_override: str | None,
    modality: str,
    hard_limit: int,
    video_id: str,
    run_id: str | None,
) -> SynthesisResult:
    """Compose new moments into an existing technique page via LLM.
    Loads the compose system prompt, builds the compose user prompt, and
    calls the LLM with the same retry/parse pattern as _synthesize_chunk().

    NOTE(review): ``system_prompt`` is accepted but never used here — the
    dedicated 'stage5_compose.txt' template is loaded instead. Presumably
    kept for signature parity with _synthesize_chunk(); confirm before
    removing, since callers pass it positionally.

    Returns the parsed SynthesisResult for the compose call.
    """
    compose_prompt = _load_prompt("stage5_compose.txt", video_id=video_id)
    user_prompt = _build_compose_user_prompt(
        existing_page, existing_moments, new_moment_group, creator_name,
    )
    # Derive an output-token budget from the prompt size, capped at hard_limit.
    estimated_input = estimate_max_tokens(
        compose_prompt, user_prompt,
        stage="stage5_synthesis", hard_limit=hard_limit,
    )
    logger.info(
        "Stage 5: Composing into '%s' — %d existing + %d new moments, max_tokens=%d",
        existing_page.slug, len(existing_moments), len(new_moment_group), estimated_input,
    )
    raw = llm.complete(
        compose_prompt, user_prompt, response_model=SynthesisResult,
        on_complete=_make_llm_callback(
            video_id, "stage5_synthesis",
            system_prompt=compose_prompt, user_prompt=user_prompt,
            run_id=run_id, context_label=f"compose:{category}",
            request_params=_build_request_params(
                estimated_input, model_override, modality, "SynthesisResult", hard_limit,
            ),
        ),
        modality=modality, model_override=model_override, max_tokens=estimated_input,
    )
    # _safe_parse_llm_response may re-call the LLM (repair path), so it gets
    # the same prompts and request parameters as the original call.
    return _safe_parse_llm_response(
        raw, SynthesisResult, llm, compose_prompt, user_prompt,
        modality=modality, model_override=model_override, max_tokens=estimated_input,
    )
def _synthesize_chunk(
    chunk: list[tuple[KeyMoment, dict]],
    category: str,
    creator_name: str,
    system_prompt: str,
    llm: LLMClient,
    model_override: str | None,
    modality: str,
    hard_limit: int,
    video_id: str,
    run_id: str | None,
    chunk_label: str,
) -> SynthesisResult:
    """Run a single synthesis LLM call for a chunk of moments.

    Args:
        chunk: (KeyMoment, classification-dict) pairs to synthesize.
        category: Topic category this chunk belongs to (used for prompt text).
        creator_name: Creator display name injected into the prompt.
        system_prompt: The stage-5 synthesis system prompt.
        llm / model_override / modality / hard_limit: LLM call configuration.
        video_id / run_id: Identifiers passed to the event-logging callback.
        chunk_label: Human-readable label for logs and callback context.

    Returns the parsed SynthesisResult.
    """
    moments_text, _ = _build_moments_text(chunk, category)
    user_prompt = f"<creator>{creator_name}</creator>\n<moments>\n{moments_text}\n</moments>"
    # Output-token budget derived from prompt size, capped at hard_limit.
    estimated_input = estimate_max_tokens(system_prompt, user_prompt, stage="stage5_synthesis", hard_limit=hard_limit)
    logger.info(
        "Stage 5: Synthesizing %s — %d moments, max_tokens=%d",
        chunk_label, len(chunk), estimated_input,
    )
    raw = llm.complete(
        system_prompt, user_prompt, response_model=SynthesisResult,
        on_complete=_make_llm_callback(
            video_id, "stage5_synthesis",
            system_prompt=system_prompt, user_prompt=user_prompt,
            run_id=run_id, context_label=chunk_label,
            request_params=_build_request_params(estimated_input, model_override, modality, "SynthesisResult", hard_limit),
        ),
        modality=modality, model_override=model_override, max_tokens=estimated_input,
    )
    # Parse (with possible repair re-call) using identical call parameters.
    return _safe_parse_llm_response(
        raw, SynthesisResult, llm, system_prompt, user_prompt,
        modality=modality, model_override=model_override, max_tokens=estimated_input,
    )
def _slug_base(slug: str) -> str:
"""Extract the slug prefix before the creator name suffix for merge grouping.
E.g. 'wavetable-sound-design-copycatt''wavetable-sound-design'
Also normalizes casing.
"""
return slug.lower().strip()
def _merge_pages_by_slug(
    all_pages: list,
    creator_name: str,
    llm: LLMClient,
    model_override: str | None,
    modality: str,
    hard_limit: int,
    video_id: str,
    run_id: str | None,
) -> list:
    """Detect pages with the same slug across chunks and merge them via LLM.

    Pages with unique slugs pass through unchanged. Pages sharing a slug
    get sent to a merge prompt that combines them into one cohesive page.
    If a merge call returns no pages, the unmerged partials are kept so no
    content is lost.

    Args:
        all_pages: SynthesizedPage objects produced by per-chunk synthesis.
        creator_name: Creator display name injected into the merge prompt.
        llm / model_override / modality / hard_limit: LLM call configuration.
        video_id / run_id: Identifiers for the event-logging callback.

    Returns:
        The final list of SynthesizedPage objects (merged + passthrough).
    """
    # Group pages by normalized slug.
    by_slug: dict[str, list] = defaultdict(list)
    for page in all_pages:
        by_slug[_slug_base(page.slug)].append(page)
    # The merge prompt template is loop-invariant — load it lazily (only when
    # at least one slug actually collides) and at most once, instead of
    # re-reading it from disk for every colliding slug.
    merge_system_prompt: str | None = None
    final_pages = []
    for slug, pages_group in by_slug.items():
        if len(pages_group) == 1:
            # Unique slug — no merge needed
            final_pages.append(pages_group[0])
            continue
        # Multiple pages share this slug — merge via LLM
        logger.info(
            "Stage 5: Merging %d partial pages with slug '%s' for video_id=%s",
            len(pages_group), slug, video_id,
        )
        # Serialize partial pages to JSON for the merge prompt
        pages_json = json.dumps(
            [p.model_dump() for p in pages_group],
            indent=2, ensure_ascii=False,
        )
        if merge_system_prompt is None:
            merge_system_prompt = _load_prompt("stage5_merge.txt", video_id=video_id)
        merge_user_prompt = f"<creator>{creator_name}</creator>\n<pages>\n{pages_json}\n</pages>"
        max_tokens = estimate_max_tokens(
            merge_system_prompt, merge_user_prompt,
            stage="stage5_synthesis", hard_limit=hard_limit,
        )
        logger.info(
            "Stage 5: Merge call for slug '%s' — %d partial pages, max_tokens=%d",
            slug, len(pages_group), max_tokens,
        )
        raw = llm.complete(
            merge_system_prompt, merge_user_prompt,
            response_model=SynthesisResult,
            on_complete=_make_llm_callback(
                video_id, "stage5_synthesis",
                system_prompt=merge_system_prompt,
                user_prompt=merge_user_prompt,
                run_id=run_id, context_label=f"merge:{slug}",
                request_params=_build_request_params(max_tokens, model_override, modality, "SynthesisResult", hard_limit),
            ),
            modality=modality, model_override=model_override,
            max_tokens=max_tokens,
        )
        merge_result = _safe_parse_llm_response(
            raw, SynthesisResult, llm,
            merge_system_prompt, merge_user_prompt,
            modality=modality, model_override=model_override,
            max_tokens=max_tokens,
        )
        if merge_result.pages:
            final_pages.extend(merge_result.pages)
            logger.info(
                "Stage 5: Merge produced %d page(s) for slug '%s'",
                len(merge_result.pages), slug,
            )
        else:
            # Merge returned nothing — fall back to keeping the partials
            logger.warning(
                "Stage 5: Merge returned 0 pages for slug '%s', keeping %d partials",
                slug, len(pages_group),
            )
            final_pages.extend(pages_group)
    return final_pages
@celery_app.task(bind=True, max_retries=3, default_retry_delay=30)
def stage5_synthesis(self, video_id: str, run_id: str | None = None) -> str:
    """Synthesize technique pages from classified key moments.
    Groups moments by (creator, topic_category), calls the LLM to synthesize
    each group into a TechniquePage, creates/updates page rows, and links
    KeyMoments to their TechniquePage.
    For large category groups (exceeding synthesis_chunk_size), moments are
    split into chronological chunks, synthesized independently, then pages
    with matching slugs are merged via a dedicated merge LLM call.
    Sets processing_status to 'complete'.
    Returns the video_id for chain compatibility.

    Raises:
        FileNotFoundError: when a prompt template is missing — propagated
            directly without Celery retry (see except clause below).
        celery.exceptions.Retry: any other failure triggers self.retry()
            (max_retries=3, 30s delay).
    """
    start = time.monotonic()
    logger.info("Stage 5 (synthesis) starting for video_id=%s", video_id)
    _emit_event(video_id, "stage5_synthesis", "start", run_id=run_id)
    settings = get_settings()
    chunk_size = settings.synthesis_chunk_size
    session = _get_sync_session()
    try:
        # Load video and moments
        video = session.execute(
            select(SourceVideo).where(SourceVideo.id == video_id)
        ).scalar_one()
        moments = (
            session.execute(
                select(KeyMoment)
                .where(KeyMoment.source_video_id == video_id)
                .order_by(KeyMoment.start_time)
            )
            .scalars()
            .all()
        )
        # Resolve creator name for the LLM prompt
        creator = session.execute(
            select(Creator).where(Creator.id == video.creator_id)
        ).scalar_one_or_none()
        creator_name = creator.name if creator else "Unknown"
        if not moments:
            # Nothing to synthesize — returns without emitting a 'complete'
            # event and without touching processing_status.
            logger.info("Stage 5: No moments found for video_id=%s, skipping.", video_id)
            return video_id
        # Load classification data from stage 4
        classification_data = _load_classification_data(video_id)
        cls_by_moment_id = {c["moment_id"]: c for c in classification_data}
        # Group moments by topic_category (from classification)
        # Normalize category casing to prevent near-duplicate groups
        # (e.g., "Sound design" vs "Sound Design")
        groups: dict[str, list[tuple[KeyMoment, dict]]] = defaultdict(list)
        for moment in moments:
            cls_info = cls_by_moment_id.get(str(moment.id), {})
            category = cls_info.get("topic_category", "Uncategorized").strip().title()
            groups[category].append((moment, cls_info))
        system_prompt = _load_prompt("stage5_synthesis.txt", video_id=video_id)
        llm = _get_llm_client()
        model_override, modality = _get_stage_config(5)
        hard_limit = settings.llm_max_tokens_hard_limit
        logger.info("Stage 5 using model=%s, modality=%s", model_override or "default", modality)
        pages_created = 0
        for category, moment_group in groups.items():
            # Collect all tags across the full group (used for DB writes later)
            all_tags: set[str] = set()
            for _, cls_info in moment_group:
                all_tags.update(cls_info.get("topic_tags", []))
            # ── Compose-or-create detection ────────────────────────
            # Check if an existing technique page already covers this
            # creator + category combination (from a prior video run).
            compose_matches = session.execute(
                select(TechniquePage).where(
                    TechniquePage.creator_id == video.creator_id,
                    func.lower(TechniquePage.topic_category) == func.lower(category),
                )
            ).scalars().all()
            if len(compose_matches) > 1:
                logger.warning(
                    "Stage 5: Multiple existing pages (%d) match creator=%s category='%s'. "
                    "Using first match '%s'.",
                    len(compose_matches), video.creator_id, category,
                    compose_matches[0].slug,
                )
            compose_target = compose_matches[0] if compose_matches else None
            if compose_target is not None:
                # Load existing moments linked to this page
                existing_moments = session.execute(
                    select(KeyMoment)
                    .where(KeyMoment.technique_page_id == compose_target.id)
                    .order_by(KeyMoment.start_time)
                ).scalars().all()
                logger.info(
                    "Stage 5: Composing into existing page '%s' "
                    "(%d existing moments + %d new moments)",
                    compose_target.slug,
                    len(existing_moments),
                    len(moment_group),
                )
                compose_result = _compose_into_existing(
                    compose_target, existing_moments, moment_group,
                    category, creator_name, system_prompt,
                    llm, model_override, modality, hard_limit,
                    video_id, run_id,
                )
                synthesized_pages = list(compose_result.pages)
            # ── Chunked synthesis with truncation recovery ─────────
            elif len(moment_group) <= chunk_size:
                # Small group — try single LLM call first
                try:
                    result = _synthesize_chunk(
                        moment_group, category, creator_name,
                        system_prompt, llm, model_override, modality, hard_limit,
                        video_id, run_id, f"category:{category}",
                    )
                    synthesized_pages = list(result.pages)
                    logger.info(
                        "Stage 5: category '%s' — %d moments, %d page(s) from single call",
                        category, len(moment_group), len(synthesized_pages),
                    )
                except LLMTruncationError:
                    # Output too large for model context — split in half and retry
                    logger.warning(
                        "Stage 5: category '%s' truncated with %d moments. "
                        "Splitting into smaller chunks and retrying.",
                        category, len(moment_group),
                    )
                    half = max(1, len(moment_group) // 2)
                    chunk_pages = []
                    for sub_start in range(0, len(moment_group), half):
                        sub_chunk = moment_group[sub_start:sub_start + half]
                        sub_label = f"category:{category} recovery-chunk:{sub_start // half + 1}"
                        sub_result = _synthesize_chunk(
                            sub_chunk, category, creator_name,
                            system_prompt, llm, model_override, modality, hard_limit,
                            video_id, run_id, sub_label,
                        )
                        # Reindex moment_indices to global offsets
                        for p in sub_result.pages:
                            if p.moment_indices:
                                p.moment_indices = [idx + sub_start for idx in p.moment_indices]
                        chunk_pages.extend(sub_result.pages)
                    synthesized_pages = chunk_pages
                    logger.info(
                        "Stage 5: category '%s' — %d page(s) from recovery split",
                        category, len(synthesized_pages),
                    )
            else:
                # Large group — split into chunks, synthesize each, then merge
                num_chunks = (len(moment_group) + chunk_size - 1) // chunk_size
                logger.info(
                    "Stage 5: category '%s' has %d moments — splitting into %d chunks of ≤%d",
                    category, len(moment_group), num_chunks, chunk_size,
                )
                chunk_pages = []
                for chunk_idx in range(num_chunks):
                    chunk_start = chunk_idx * chunk_size
                    chunk_end = min(chunk_start + chunk_size, len(moment_group))
                    chunk = moment_group[chunk_start:chunk_end]
                    chunk_label = f"category:{category} chunk:{chunk_idx + 1}/{num_chunks}"
                    result = _synthesize_chunk(
                        chunk, category, creator_name,
                        system_prompt, llm, model_override, modality, hard_limit,
                        video_id, run_id, chunk_label,
                    )
                    chunk_pages.extend(result.pages)
                    logger.info(
                        "Stage 5: %s produced %d page(s)",
                        chunk_label, len(result.pages),
                    )
                # Merge pages with matching slugs across chunks
                logger.info(
                    "Stage 5: category '%s' — %d total pages from %d chunks, checking for merges",
                    category, len(chunk_pages), num_chunks,
                )
                synthesized_pages = _merge_pages_by_slug(
                    chunk_pages, creator_name,
                    llm, model_override, modality, hard_limit,
                    video_id, run_id,
                )
                logger.info(
                    "Stage 5: category '%s' — %d final page(s) after merge",
                    category, len(synthesized_pages),
                )
            # ── Persist pages to DB ──────────────────────────────────
            # Load prior pages from this video (snapshot taken before pipeline reset)
            # NOTE: this Redis read is loop-invariant — it is re-fetched per
            # category for simplicity.
            prior_page_ids = _load_prior_pages(video_id)
            for page_data in synthesized_pages:
                page_moment_indices = getattr(page_data, "moment_indices", None) or []
                existing = None
                # First: check by slug (most specific match)
                if existing is None:
                    existing = session.execute(
                        select(TechniquePage).where(TechniquePage.slug == page_data.slug)
                    ).scalar_one_or_none()
                # Fallback: check prior pages from this video by creator + category
                # Use .first() since multiple pages may share a category
                if existing is None and prior_page_ids:
                    existing = session.execute(
                        select(TechniquePage).where(
                            TechniquePage.id.in_(prior_page_ids),
                            TechniquePage.creator_id == video.creator_id,
                            func.lower(TechniquePage.topic_category) == func.lower(page_data.topic_category or category),
                        )
                    ).scalars().first()
                    if existing:
                        logger.info(
                            "Stage 5: Matched prior page '%s' (id=%s) by creator+category for video_id=%s",
                            existing.slug, existing.id, video_id,
                        )
                if existing:
                    # Snapshot existing content before overwriting
                    try:
                        sq = existing.source_quality
                        sq_value = sq.value if hasattr(sq, 'value') else sq
                        snapshot = {
                            "title": existing.title,
                            "slug": existing.slug,
                            "topic_category": existing.topic_category,
                            "topic_tags": existing.topic_tags,
                            "summary": existing.summary,
                            "body_sections": existing.body_sections,
                            "signal_chains": existing.signal_chains,
                            "plugins": existing.plugins,
                            "source_quality": sq_value,
                        }
                        version_count = session.execute(
                            select(func.count()).where(
                                TechniquePageVersion.technique_page_id == existing.id
                            )
                        ).scalar()
                        version_number = version_count + 1
                        version = TechniquePageVersion(
                            technique_page_id=existing.id,
                            version_number=version_number,
                            content_snapshot=snapshot,
                            pipeline_metadata=_capture_pipeline_metadata(),
                        )
                        session.add(version)
                        logger.info(
                            "Version snapshot v%d created for page slug=%s",
                            version_number, existing.slug,
                        )
                    except Exception as snap_exc:
                        logger.error(
                            "Failed to create version snapshot for page slug=%s: %s",
                            existing.slug, snap_exc,
                        )
                        # Best-effort versioning — continue with page update
                    # Update existing page
                    existing.title = page_data.title
                    existing.summary = page_data.summary
                    existing.body_sections = _serialize_body_sections(page_data.body_sections)
                    existing.signal_chains = page_data.signal_chains
                    existing.plugins = page_data.plugins if page_data.plugins else None
                    page_tags = _compute_page_tags(page_moment_indices, moment_group, all_tags)
                    existing.topic_tags = page_tags
                    existing.source_quality = page_data.source_quality
                    page = existing
                else:
                    page = TechniquePage(
                        creator_id=video.creator_id,
                        title=page_data.title,
                        slug=page_data.slug,
                        topic_category=page_data.topic_category or category,
                        topic_tags=_compute_page_tags(page_moment_indices, moment_group, all_tags),
                        summary=page_data.summary,
                        body_sections=_serialize_body_sections(page_data.body_sections),
                        signal_chains=page_data.signal_chains,
                        plugins=page_data.plugins if page_data.plugins else None,
                        source_quality=page_data.source_quality,
                    )
                    session.add(page)
                    session.flush()  # Get the page.id assigned
                    pages_created += 1
                # Set body_sections_format on every page (new or updated)
                page.body_sections_format = "v2"
                # Track contributing video via TechniquePageVideo
                stmt = pg_insert(TechniquePageVideo.__table__).values(
                    technique_page_id=page.id,
                    source_video_id=video.id,
                ).on_conflict_do_nothing()
                session.execute(stmt)
                # Link moments to the technique page using moment_indices
                if page_moment_indices:
                    # LLM specified which moments belong to this page
                    for idx in page_moment_indices:
                        # Silently drop out-of-range indices from the LLM.
                        if 0 <= idx < len(moment_group):
                            moment_group[idx][0].technique_page_id = page.id
                elif len(synthesized_pages) == 1:
                    # Single page — link all moments (safe fallback)
                    for m, _ in moment_group:
                        m.technique_page_id = page.id
                else:
                    # Multiple pages but no moment_indices — log warning
                    logger.warning(
                        "Stage 5: page '%s' has no moment_indices and is one of %d pages "
                        "for category '%s'. Moments will not be linked to this page.",
                        page_data.slug, len(synthesized_pages), category,
                    )
        # Update processing_status
        video.processing_status = ProcessingStatus.complete
        session.commit()
        elapsed = time.monotonic() - start
        _emit_event(video_id, "stage5_synthesis", "complete", run_id=run_id)
        logger.info(
            "Stage 5 (synthesis) completed for video_id=%s in %.1fs — %d pages created/updated",
            video_id, elapsed, pages_created,
        )
        return video_id
    except FileNotFoundError:
        # Missing prompt template is a configuration error — propagate
        # directly. Note: this bypasses the rollback/error-event/retry
        # handling of the generic clause below.
        raise
    except Exception as exc:
        session.rollback()
        _emit_event(video_id, "stage5_synthesis", "error", run_id=run_id, payload={"error": str(exc)})
        logger.error("Stage 5 failed for video_id=%s: %s", video_id, exc)
        # Celery retry (max_retries=3, 30s delay); raises Retry here and
        # MaxRetriesExceededError once retries are exhausted.
        raise self.retry(exc=exc)
    finally:
        session.close()
# ── Heading slug helper (matches frontend TableOfContents.tsx slugify) ────────
def _slugify_heading(text: str) -> str:
"""Convert a heading string to a URL-friendly anchor slug.
Must produce identical output to the frontend's slugify in
``frontend/src/components/TableOfContents.tsx``.
"""
return re.sub(r"[^a-z0-9]+", "-", text.lower()).strip("-")
# ── Stage 6: Embed & Index ───────────────────────────────────────────────────
@celery_app.task(bind=True, max_retries=0)
def stage6_embed_and_index(self, video_id: str, run_id: str | None = None) -> str:
    """Generate embeddings for technique pages and key moments, then upsert to Qdrant.
    This is a non-blocking side-effect stage — failures are logged but do not
    fail the pipeline. Embeddings can be regenerated later. Does NOT update
    processing_status.
    Returns the video_id for chain compatibility.

    Note: when ``run_id`` is given, the run is marked 'complete' on every
    exit path (skip, success, and failure) — stage 6 is best-effort.
    """
    start = time.monotonic()
    logger.info("Stage 6 (embed & index) starting for video_id=%s", video_id)
    settings = get_settings()
    session = _get_sync_session()
    try:
        # Load technique pages created for this video's moments
        moments = (
            session.execute(
                select(KeyMoment)
                .where(KeyMoment.source_video_id == video_id)
                .order_by(KeyMoment.start_time)
            )
            .scalars()
            .all()
        )
        # Get unique technique page IDs from moments
        page_ids = {m.technique_page_id for m in moments if m.technique_page_id is not None}
        pages = []
        if page_ids:
            pages = (
                session.execute(
                    select(TechniquePage).where(TechniquePage.id.in_(page_ids))
                )
                .scalars()
                .all()
            )
        if not moments and not pages:
            logger.info("Stage 6: No moments or pages for video_id=%s, skipping.", video_id)
            if run_id:
                _finish_run(run_id, "complete")
            return video_id
        # Resolve creator names for enriched embedding text
        creator_ids = {p.creator_id for p in pages}
        creator_map: dict[str, str] = {}
        if creator_ids:
            creators = (
                session.execute(
                    select(Creator).where(Creator.id.in_(creator_ids))
                )
                .scalars()
                .all()
            )
            creator_map = {str(c.id): c.name for c in creators}
        # Resolve creator name for key moments via source_video → creator
        video_ids = {m.source_video_id for m in moments}
        video_creator_map: dict[str, str] = {}
        if video_ids:
            rows = session.execute(
                select(SourceVideo.id, Creator.name, Creator.id.label("creator_id"))
                .join(Creator, SourceVideo.creator_id == Creator.id)
                .where(SourceVideo.id.in_(video_ids))
            ).all()
            video_creator_map = {str(r[0]): r[1] for r in rows}
            # NOTE: video_creator_id_map is only bound inside this branch;
            # it is read below only when moments exist, which guarantees
            # video_ids is non-empty.
            video_creator_id_map = {str(r[0]): str(r[2]) for r in rows}
        embed_client = EmbeddingClient(settings)
        qdrant = QdrantManager(settings)
        # Ensure collection exists before upserting
        qdrant.ensure_collection()
        # ── Embed & upsert technique pages ───────────────────────────────
        if pages:
            page_texts = []
            page_dicts = []
            for p in pages:
                creator_name = creator_map.get(str(p.creator_id), "")
                tags_joined = " ".join(p.topic_tags) if p.topic_tags else ""
                text = f"{creator_name} {p.title} {p.topic_category or ''} {tags_joined} {p.summary or ''}"
                page_texts.append(text.strip())
                page_dicts.append({
                    "page_id": str(p.id),
                    "creator_id": str(p.creator_id),
                    "creator_name": creator_name,
                    "title": p.title,
                    "slug": p.slug,
                    "topic_category": p.topic_category or "",
                    "topic_tags": p.topic_tags or [],
                    "summary": p.summary or "",
                })
            page_vectors = embed_client.embed(page_texts)
            if page_vectors:
                qdrant.upsert_technique_pages(page_dicts, page_vectors)
                logger.info(
                    "Stage 6: Upserted %d technique page vectors for video_id=%s",
                    len(page_vectors), video_id,
                )
            else:
                logger.warning(
                    "Stage 6: Embedding returned empty for %d technique pages (video_id=%s). "
                    "Skipping page upsert.",
                    len(page_texts), video_id,
                )
        # ── Embed & upsert key moments ───────────────────────────────────
        if moments:
            # Build page_id → slug mapping for linking moments to technique pages
            page_id_to_slug: dict[str, str] = {}
            if pages:
                for p in pages:
                    page_id_to_slug[str(p.id)] = p.slug
            moment_texts = []
            moment_dicts = []
            for m in moments:
                creator_name = video_creator_map.get(str(m.source_video_id), "")
                text = f"{creator_name} {m.title} {m.summary or ''}"
                moment_texts.append(text.strip())
                tp_id = str(m.technique_page_id) if m.technique_page_id else ""
                moment_dicts.append({
                    "moment_id": str(m.id),
                    "source_video_id": str(m.source_video_id),
                    "creator_id": video_creator_id_map.get(str(m.source_video_id), ""),
                    "technique_page_id": tp_id,
                    "technique_page_slug": page_id_to_slug.get(tp_id, ""),
                    "title": m.title,
                    "creator_name": creator_name,
                    "start_time": m.start_time,
                    "end_time": m.end_time,
                    "content_type": m.content_type.value,
                })
            moment_vectors = embed_client.embed(moment_texts)
            if moment_vectors:
                qdrant.upsert_key_moments(moment_dicts, moment_vectors)
                logger.info(
                    "Stage 6: Upserted %d key moment vectors for video_id=%s",
                    len(moment_vectors), video_id,
                )
            else:
                logger.warning(
                    "Stage 6: Embedding returned empty for %d key moments (video_id=%s). "
                    "Skipping moment upsert.",
                    len(moment_texts), video_id,
                )
        # ── Embed & upsert technique page sections (v2 only) ────────────
        section_count = 0
        v2_pages = [p for p in pages if getattr(p, "body_sections_format", "v1") == "v2"]
        for p in v2_pages:
            body_sections = p.body_sections
            if not isinstance(body_sections, list):
                continue
            creator_name = creator_map.get(str(p.creator_id), "")
            page_id_str = str(p.id)
            # Delete stale section points before re-upserting
            try:
                qdrant.delete_sections_by_page_id(page_id_str)
            except Exception as exc:
                logger.warning(
                    "Stage 6: Failed to delete stale sections for page_id=%s: %s",
                    page_id_str, exc,
                )
            section_texts: list[str] = []
            section_dicts: list[dict] = []
            for section in body_sections:
                if not isinstance(section, dict):
                    logger.warning(
                        "Stage 6: Malformed section (not a dict) in page_id=%s. Skipping.",
                        page_id_str,
                    )
                    continue
                heading = section.get("heading", "")
                if not heading or not heading.strip():
                    continue
                # Anchor must match the frontend's slugify output.
                section_anchor = _slugify_heading(heading)
                section_content = section.get("content", "")
                # Include subsection content for richer embedding
                subsection_parts: list[str] = []
                for sub in section.get("subsections", []):
                    if isinstance(sub, dict):
                        sub_heading = sub.get("heading", "")
                        sub_content = sub.get("content", "")
                        if sub_heading:
                            subsection_parts.append(f"{sub_heading}: {sub_content}")
                        elif sub_content:
                            subsection_parts.append(sub_content)
                embed_text = (
                    f"{creator_name} {p.title}{heading}: "
                    f"{section_content} {' '.join(subsection_parts)}"
                ).strip()
                section_texts.append(embed_text)
                section_dicts.append({
                    "page_id": page_id_str,
                    "creator_id": str(p.creator_id),
                    "creator_name": creator_name,
                    "title": p.title,
                    "slug": p.slug,
                    "section_heading": heading,
                    "section_anchor": section_anchor,
                    "topic_category": p.topic_category or "",
                    "topic_tags": p.topic_tags or [],
                    "summary": (section_content or "")[:200],
                })
            if section_texts:
                try:
                    section_vectors = embed_client.embed(section_texts)
                    if section_vectors:
                        qdrant.upsert_technique_sections(section_dicts, section_vectors)
                        section_count += len(section_vectors)
                    else:
                        logger.warning(
                            "Stage 6: Embedding returned empty for %d sections of page_id=%s. Skipping.",
                            len(section_texts), page_id_str,
                        )
                except Exception as exc:
                    logger.warning(
                        "Stage 6: Section embedding failed for page_id=%s: %s. Skipping.",
                        page_id_str, exc,
                    )
        if section_count:
            logger.info(
                "Stage 6: Upserted %d technique section vectors for video_id=%s",
                section_count, video_id,
            )
        elapsed = time.monotonic() - start
        logger.info(
            "Stage 6 (embed & index) completed for video_id=%s in %.1fs — "
            "%d pages, %d moments processed",
            video_id, elapsed, len(pages), len(moments),
        )
        if run_id:
            _finish_run(run_id, "complete")
        return video_id
    except Exception as exc:
        # Non-blocking: log error but don't fail the pipeline
        logger.error(
            "Stage 6 failed for video_id=%s: %s. "
            "Pipeline continues — embeddings can be regenerated later.",
            video_id, exc,
        )
        if run_id:
            _finish_run(run_id, "complete")  # Run is still "complete" — stage6 is best-effort
        return video_id
    finally:
        session.close()
def _snapshot_prior_pages(video_id: str) -> None:
    """Save existing technique_page_ids linked to this video before pipeline resets them.

    When a video is reprocessed, stage 3 deletes and recreates key_moments,
    breaking the link to technique pages. The page IDs are therefore
    snapshotted to Redis (24h TTL) so stage 5 can find and update prior
    pages instead of creating duplicates.
    """
    import redis

    db = _get_sync_session()
    try:
        # Distinct technique pages currently linked via this video's key moments.
        linked = db.execute(
            select(KeyMoment.technique_page_id)
            .where(
                KeyMoment.source_video_id == video_id,
                KeyMoment.technique_page_id.isnot(None),
            )
            .distinct()
        ).scalars().all()
        snapshot_ids = [str(pid) for pid in linked]
        if not snapshot_ids:
            logger.info("No prior technique pages for video_id=%s", video_id)
            return
        conn = redis.Redis.from_url(get_settings().redis_url)
        conn.set(
            f"chrysopedia:prior_pages:{video_id}",
            json.dumps(snapshot_ids),
            ex=86400,
        )
        logger.info(
            "Snapshot %d prior technique pages for video_id=%s: %s",
            len(snapshot_ids), video_id, snapshot_ids,
        )
    finally:
        db.close()
def _load_prior_pages(video_id: str) -> list[str]:
    """Return the technique page IDs snapshotted for this video, or [] if none."""
    import redis

    conn = redis.Redis.from_url(get_settings().redis_url)
    payload = conn.get(f"chrysopedia:prior_pages:{video_id}")
    return [] if payload is None else json.loads(payload)
# ── Stage completion detection for auto-resume ───────────────────────────────
# Ordered list of pipeline stages for resumability logic
_PIPELINE_STAGES = [
    "stage2_segmentation",
    "stage3_extraction",
    "stage4_classification",
    "stage5_synthesis",
    "stage6_embed_and_index",
]
# Maps each stage name to its Celery task callable so run_pipeline can
# dispatch stages by name. Must stay in sync with _PIPELINE_STAGES.
_STAGE_TASKS = {
    "stage2_segmentation": stage2_segmentation,
    "stage3_extraction": stage3_extraction,
    "stage4_classification": stage4_classification,
    "stage5_synthesis": stage5_synthesis,
    "stage6_embed_and_index": stage6_embed_and_index,
}
def _get_last_completed_stage(video_id: str) -> str | None:
    """Find the last stage that completed successfully for this video.
    Queries pipeline_events for the most recent run, looking for 'complete'
    events. Returns the stage name (e.g. 'stage3_extraction') or None if
    no stages have completed.

    Only counts a contiguous prefix of _PIPELINE_STAGES: the walk stops at
    the first stage with no 'complete' event, so a later stage completed
    out of order is ignored.
    """
    session = _get_sync_session()
    try:
        # Find the most recent run for this video
        from models import PipelineRun
        latest_run = session.execute(
            select(PipelineRun)
            .where(PipelineRun.video_id == video_id)
            .order_by(PipelineRun.started_at.desc())
            .limit(1)
        ).scalar_one_or_none()
        if latest_run is None:
            return None
        # Get all 'complete' events from that run
        # (run_id is compared as a string, matching how _emit_event stores it)
        completed_events = session.execute(
            select(PipelineEvent.stage)
            .where(
                PipelineEvent.run_id == str(latest_run.id),
                PipelineEvent.event_type == "complete",
            )
        ).scalars().all()
        completed_set = set(completed_events)
        # Walk backwards through the ordered stages to find the last completed one
        last_completed = None
        for stage_name in _PIPELINE_STAGES:
            if stage_name in completed_set:
                last_completed = stage_name
            else:
                break  # Stop at first gap — stages must be sequential
        if last_completed:
            logger.info(
                "Auto-resume: video_id=%s last completed stage=%s (run_id=%s)",
                video_id, last_completed, latest_run.id,
            )
        return last_completed
    finally:
        session.close()
# ── Orchestrator ─────────────────────────────────────────────────────────────
@celery_app.task
def mark_pipeline_error(request, exc, traceback, video_id: str, run_id: str | None = None) -> None:
    """Error callback — marks video as errored when a pipeline stage fails.

    The (request, exc, traceback) leading parameters follow Celery's
    link_error callback signature; only exc is used here.
    """
    logger.error("Pipeline failed for video_id=%s: %s", video_id, exc)
    _set_error_status(video_id, "pipeline", exc)
    if run_id:
        _finish_run(run_id, "error", error_stage="pipeline")
def _create_run(video_id: str, trigger: str) -> str:
    """Create a PipelineRun row for this video and return its id as a string.

    run_number is monotonically increasing per video (max existing + 1).

    Args:
        video_id: UUID string of the SourceVideo the run belongs to.
        trigger: PipelineRunTrigger value name (e.g. 'manual').

    Returns:
        The new run's id as a string.
    """
    from models import PipelineRun, PipelineRunTrigger
    session = _get_sync_session()
    try:
        # Compute run_number: max existing + 1. coalesce() already yields 0
        # when there are no prior runs, and the module-level `func` import is
        # used directly (no local re-import needed).
        max_num = session.execute(
            select(func.coalesce(func.max(PipelineRun.run_number), 0))
            .where(PipelineRun.video_id == video_id)
        ).scalar() or 0
        run = PipelineRun(
            video_id=video_id,
            run_number=max_num + 1,
            trigger=PipelineRunTrigger(trigger),
        )
        session.add(run)
        # Flush so the primary key is assigned, then capture it BEFORE commit:
        # reading run.id after commit would trigger an expired-attribute
        # refresh query (SQLAlchemy expires instances on commit by default).
        session.flush()
        run_id = str(run.id)
        session.commit()
        return run_id
    finally:
        session.close()
def _finish_run(run_id: str, status: str, error_stage: str | None = None) -> None:
    """Update a PipelineRun's status and finished_at.

    Also aggregates total_tokens from the run's PipelineEvents. Best-effort:
    any failure is logged as a warning and swallowed so finishing a run can
    never fail the caller.
    """
    from models import PipelineRun, PipelineRunStatus, _now
    session = _get_sync_session()
    try:
        run = session.execute(
            select(PipelineRun).where(PipelineRun.id == run_id)
        ).scalar_one_or_none()
        if run:
            run.status = PipelineRunStatus(status)
            run.finished_at = _now()
            if error_stage:
                run.error_stage = error_stage
            # Aggregate total tokens from events
            total = session.execute(
                select(func.coalesce(func.sum(PipelineEvent.total_tokens), 0))
                .where(PipelineEvent.run_id == run_id)
            ).scalar() or 0
            run.total_tokens = total
            session.commit()
    except Exception as exc:
        logger.warning("Failed to finish run %s: %s", run_id, exc)
    finally:
        session.close()
@celery_app.task
def run_pipeline(video_id: str, trigger: str = "manual") -> str:
    """Orchestrate the full pipeline (stages 2-6) with auto-resume.
    For error/processing status, queries pipeline_events to find the last
    stage that completed successfully and resumes from the next stage.
    This avoids re-running expensive LLM stages that already succeeded.
    For clean_reprocess trigger, always starts from stage 2.
    Returns the video_id.

    Raises:
        ValueError: when the video does not exist.
    """
    logger.info("run_pipeline starting for video_id=%s", video_id)
    session = _get_sync_session()
    try:
        video = session.execute(
            select(SourceVideo).where(SourceVideo.id == video_id)
        ).scalar_one_or_none()
        if video is None:
            logger.error("run_pipeline: video_id=%s not found", video_id)
            raise ValueError(f"Video not found: {video_id}")
        # `status` is deliberately bound here and read after the finally
        # block closes the session (enum values are detached-safe).
        status = video.processing_status
        logger.info(
            "run_pipeline: video_id=%s current status=%s", video_id, status.value
        )
    finally:
        session.close()
    if status == ProcessingStatus.complete:
        logger.info(
            "run_pipeline: video_id=%s already at status=%s, nothing to do.",
            video_id, status.value,
        )
        return video_id
    # Snapshot prior technique pages before pipeline resets key_moments
    _snapshot_prior_pages(video_id)
    # Create a pipeline run record
    run_id = _create_run(video_id, trigger)
    logger.info("run_pipeline: created run_id=%s for video_id=%s (trigger=%s)", run_id, video_id, trigger)
    # Determine which stages to run
    resume_from_idx = 0  # Default: start from stage 2
    if trigger != "clean_reprocess" and status in (ProcessingStatus.processing, ProcessingStatus.error):
        # Try to resume from where we left off
        last_completed = _get_last_completed_stage(video_id)
        if last_completed and last_completed in _PIPELINE_STAGES:
            completed_idx = _PIPELINE_STAGES.index(last_completed)
            resume_from_idx = completed_idx + 1
    if resume_from_idx >= len(_PIPELINE_STAGES):
        logger.info(
            "run_pipeline: all stages already completed for video_id=%s",
            video_id,
        )
        return video_id
    stages_to_run = _PIPELINE_STAGES[resume_from_idx:]
    logger.info(
        "run_pipeline: video_id=%s will run stages: %s (resume_from_idx=%d)",
        video_id, stages_to_run, resume_from_idx,
    )
    # Run stages inline (synchronously) so each video completes fully
    # before the worker picks up the next queued video.
    # This replaces the previous celery_chain dispatch which caused
    # interleaved execution when multiple videos were queued.
    if stages_to_run:
        # Mark as processing before starting
        session = _get_sync_session()
        try:
            video = session.execute(
                select(SourceVideo).where(SourceVideo.id == video_id)
            ).scalar_one()
            video.processing_status = ProcessingStatus.processing
            session.commit()
        finally:
            session.close()
        logger.info(
            "run_pipeline: executing %d stages inline for video_id=%s (run_id=%s, starting at %s)",
            len(stages_to_run), video_id, run_id, stages_to_run[0],
        )
        try:
            for stage_name in stages_to_run:
                task_func = _STAGE_TASKS[stage_name]
                # Call the task directly — runs synchronously in this worker
                # process. bind=True tasks receive the task instance as self
                # automatically when called this way.
                task_func(video_id, run_id=run_id)
        except Exception as exc:
            # `stage_name` is the loop variable of the stage that raised.
            logger.error(
                "run_pipeline: stage %s failed for video_id=%s: %s",
                stage_name, video_id, exc,
            )
            _set_error_status(video_id, stage_name, exc)
            if run_id:
                _finish_run(run_id, "error", error_stage=stage_name)
            raise
    return video_id
# ── Single-Stage Re-Run ─────────────────────────────────────────────────────
@celery_app.task
def run_single_stage(
    video_id: str,
    stage_name: str,
    trigger: str = "stage_rerun",
    prompt_override: str | None = None,
) -> str:
    """Re-run a single pipeline stage without running predecessors.

    Designed for fast prompt iteration — especially stage 5 synthesis.
    Bypasses the processing_status==complete guard, creates a proper
    PipelineRun record, and sets processing_status back to ``complete``
    when the stage succeeds (regardless of the status the video had
    before the re-run).

    If ``prompt_override`` is provided, it is stored in Redis as a
    per-video override that ``_load_prompt`` reads before falling back
    to the on-disk template. The override is cleaned up after the stage runs.

    Args:
        video_id: UUID string of the SourceVideo to operate on.
        stage_name: One of ``_PIPELINE_STAGES``.
        trigger: Label recorded on the PipelineRun row.
        prompt_override: Optional prompt text stored in Redis (1-hour TTL).

    Returns:
        The video_id (for chain compatibility).

    Raises:
        ValueError: If the stage name is invalid, the video does not exist,
            or the stage's prerequisite data is missing.
    """
    import redis as redis_lib
    logger.info(
        "[RERUN] Starting single-stage re-run: video_id=%s, stage=%s, trigger=%s",
        video_id, stage_name, trigger,
    )
    # Validate stage name
    if stage_name not in _PIPELINE_STAGES:
        raise ValueError(
            f"[RERUN] Invalid stage '{stage_name}'. "
            f"Valid stages: {_PIPELINE_STAGES}"
        )
    # Validate video exists. Only existence matters here — the status is
    # set to processing below and forced to complete on success.
    session = _get_sync_session()
    try:
        video = session.execute(
            select(SourceVideo).where(SourceVideo.id == video_id)
        ).scalar_one_or_none()
        if video is None:
            raise ValueError(f"[RERUN] Video not found: {video_id}")
    finally:
        session.close()
    # Validate prerequisites for the requested stage
    prereq_ok, prereq_msg = _check_stage_prerequisites(video_id, stage_name)
    if not prereq_ok:
        logger.error("[RERUN] Prerequisite check failed: %s", prereq_msg)
        raise ValueError(f"[RERUN] Prerequisites not met: {prereq_msg}")
    logger.info("[RERUN] Prerequisite check passed: %s", prereq_msg)
    # Store prompt override in Redis if provided (best-effort — on Redis
    # failure the stage simply uses the on-disk template).
    override_key = None
    if prompt_override:
        settings = get_settings()
        try:
            r = redis_lib.Redis.from_url(settings.redis_url)
            # Map stage name to its prompt template
            stage_prompt_map = {
                "stage2_segmentation": "stage2_segmentation.txt",
                "stage3_extraction": "stage3_extraction.txt",
                "stage4_classification": "stage4_classification.txt",
                "stage5_synthesis": "stage5_synthesis.txt",
            }
            template = stage_prompt_map.get(stage_name)
            if template:
                override_key = f"chrysopedia:prompt_override:{video_id}:{template}"
                r.set(override_key, prompt_override, ex=3600)  # 1-hour TTL
                logger.info(
                    "[RERUN] Prompt override stored: key=%s (%d chars, first 100: %s)",
                    override_key, len(prompt_override), prompt_override[:100],
                )
        except Exception as exc:
            logger.warning("[RERUN] Failed to store prompt override: %s", exc)
    # Snapshot prior pages (needed for stage 5 page matching)
    if stage_name in ("stage5_synthesis",):
        _snapshot_prior_pages(video_id)
    # Create pipeline run record
    run_id = _create_run(video_id, trigger)
    logger.info("[RERUN] Created run_id=%s", run_id)
    # Temporarily set status to processing
    session = _get_sync_session()
    try:
        video = session.execute(
            select(SourceVideo).where(SourceVideo.id == video_id)
        ).scalar_one()
        video.processing_status = ProcessingStatus.processing
        session.commit()
    finally:
        session.close()
    # Run the single stage
    start = time.monotonic()
    try:
        task_func = _STAGE_TASKS[stage_name]
        task_func(video_id, run_id=run_id)
        elapsed = time.monotonic() - start
        logger.info(
            "[RERUN] Stage %s completed: %.1fs, video_id=%s",
            stage_name, elapsed, video_id,
        )
        _finish_run(run_id, "complete")
        # Restore status to complete
        session = _get_sync_session()
        try:
            video = session.execute(
                select(SourceVideo).where(SourceVideo.id == video_id)
            ).scalar_one()
            video.processing_status = ProcessingStatus.complete
            session.commit()
            logger.info("[RERUN] Status restored to complete")
        finally:
            session.close()
    except Exception as exc:
        elapsed = time.monotonic() - start
        logger.error(
            "[RERUN] Stage %s FAILED after %.1fs: %s",
            stage_name, elapsed, exc,
        )
        _set_error_status(video_id, stage_name, exc)
        _finish_run(run_id, "error", error_stage=stage_name)
        raise
    finally:
        # Clean up prompt override from Redis
        if override_key:
            try:
                settings = get_settings()
                r = redis_lib.Redis.from_url(settings.redis_url)
                r.delete(override_key)
                logger.info("[RERUN] Prompt override cleaned up: %s", override_key)
            except Exception as exc:
                logger.warning("[RERUN] Failed to clean up prompt override: %s", exc)
    return video_id
def _check_stage_prerequisites(video_id: str, stage_name: str) -> tuple[bool, str]:
    """Validate that prerequisite data exists for a stage re-run.

    Args:
        video_id: UUID string of the SourceVideo being re-run.
        stage_name: Pipeline stage whose input data should be checked.

    Returns:
        ``(ok, message)`` where ``message`` describes what was found or missing.
    """
    session = _get_sync_session()
    try:
        if stage_name == "stage2_segmentation":
            # Needs transcript segments
            count = session.execute(
                select(func.count(TranscriptSegment.id))
                .where(TranscriptSegment.source_video_id == video_id)
            ).scalar() or 0
            if count == 0:
                return False, "No transcript segments found"
            return True, f"transcript_segments={count}"
        if stage_name == "stage3_extraction":
            # Needs transcript segments with topic_labels
            count = session.execute(
                select(func.count(TranscriptSegment.id))
                .where(
                    TranscriptSegment.source_video_id == video_id,
                    TranscriptSegment.topic_label.isnot(None),
                )
            ).scalar() or 0
            if count == 0:
                return False, "No labeled transcript segments (stage 2 must complete first)"
            return True, f"labeled_segments={count}"
        if stage_name == "stage4_classification":
            # Needs key moments
            count = session.execute(
                select(func.count(KeyMoment.id))
                .where(KeyMoment.source_video_id == video_id)
            ).scalar() or 0
            if count == 0:
                return False, "No key moments found (stage 3 must complete first)"
            return True, f"key_moments={count}"
        if stage_name == "stage5_synthesis":
            # Needs key moments + classification data
            km_count = session.execute(
                select(func.count(KeyMoment.id))
                .where(KeyMoment.source_video_id == video_id)
            ).scalar() or 0
            if km_count == 0:
                return False, "No key moments found (stages 2-3 must complete first)"
            cls_data = _load_classification_data(video_id)
            if not cls_data:
                return False, f"No classification data found (stage 4 must complete first), key_moments={km_count}"
            return True, f"key_moments={km_count}, classification_entries={len(cls_data)}"
        if stage_name == "stage6_embed_and_index":
            return True, "stage 6 is non-blocking and always runs"
        return False, f"Unknown stage: {stage_name}"
    finally:
        session.close()
# ── Avatar Fetching ─────────────────────────────────────────────────────────
@celery_app.task
def fetch_creator_avatar(creator_id: str) -> dict:
    """Fetch avatar for a single creator from TheAudioDB.

    Looks up the creator by ID, calls TheAudioDB, and updates the
    avatar_url/avatar_source/avatar_fetched_at columns if a confident
    match is found. Returns a status dict.
    """
    import sys
    from datetime import datetime, timezone

    # Forked Celery workers can lose /app from sys.path — restore it so the
    # service module imports cleanly.
    if "/app" not in sys.path:
        sys.path.insert(0, "/app")
    from services.avatar import lookup_avatar

    db = _get_sync_session()
    try:
        creator = db.execute(
            select(Creator).where(Creator.id == creator_id)
        ).scalar_one_or_none()
        if not creator:
            return {"status": "error", "detail": f"Creator {creator_id} not found"}
        match = lookup_avatar(creator.name, creator.genres)
        fetched_at = datetime.now(timezone.utc)
        if not match:
            # No confident hit — record the attempt so callers fall back to
            # a generated avatar instead of retrying endlessly.
            creator.avatar_source = "generated"
            creator.avatar_fetched_at = fetched_at
            db.commit()
            return {
                "status": "not_found",
                "creator": creator.name,
                "detail": "No confident match from TheAudioDB",
            }
        creator.avatar_url = match.url
        creator.avatar_source = match.source
        creator.avatar_fetched_at = fetched_at
        db.commit()
        return {
            "status": "found",
            "creator": creator.name,
            "avatar_url": match.url,
            "confidence": match.confidence,
            "matched_artist": match.artist_name,
        }
    except Exception as exc:
        db.rollback()
        logger.error("Avatar fetch failed for creator %s: %s", creator_id, exc)
        return {"status": "error", "detail": str(exc)}
    finally:
        db.close()
# ── Highlight Detection ──────────────────────────────────────────────────────
@celery_app.task(bind=True, max_retries=3, default_retry_delay=30)
def stage_highlight_detection(self, video_id: str, run_id: str | None = None) -> str:
    """Score all KeyMoments for a video and upsert HighlightCandidates.

    For each KeyMoment belonging to the video, runs the heuristic scorer and
    bulk-upserts results into highlight_candidates (INSERT ON CONFLICT UPDATE).

    Args:
        video_id: UUID string of the SourceVideo whose moments are scored.
        run_id: Optional PipelineRun id attached to emitted events.

    Returns the video_id for chain compatibility.
    """
    from pipeline.highlight_scorer import extract_word_timings, score_moment
    start = time.monotonic()
    logger.info("Highlight detection starting for video_id=%s", video_id)
    _emit_event(video_id, "highlight_detection", "start", run_id=run_id)
    session = _get_sync_session()
    try:
        # ------------------------------------------------------------------
        # Load transcript data once for the entire video (word-level timing)
        # ------------------------------------------------------------------
        transcript_data: list | None = None
        source_video = session.execute(
            select(SourceVideo).where(SourceVideo.id == video_id)
        ).scalar_one_or_none()
        if source_video and source_video.transcript_path:
            transcript_file = source_video.transcript_path
            try:
                with open(transcript_file, "r") as fh:
                    raw = json.load(fh)
                # Accept both {"segments": [...]} and bare [...]
                if isinstance(raw, dict):
                    transcript_data = raw.get("segments", raw.get("results", []))
                elif isinstance(raw, list):
                    transcript_data = raw
                else:
                    transcript_data = None
                if transcript_data:
                    logger.info(
                        "Loaded transcript for video_id=%s (%d segments)",
                        video_id, len(transcript_data),
                    )
            # A missing or unreadable transcript is non-fatal: scoring simply
            # proceeds with neutral audio-proxy signals.
            except FileNotFoundError:
                logger.warning(
                    "Transcript file not found for video_id=%s: %s",
                    video_id, transcript_file,
                )
            except (json.JSONDecodeError, OSError) as io_exc:
                logger.warning(
                    "Failed to load transcript for video_id=%s: %s",
                    video_id, io_exc,
                )
        else:
            logger.info(
                "No transcript_path for video_id=%s — audio proxy signals will be neutral",
                video_id,
            )
        moments = (
            session.execute(
                select(KeyMoment)
                .where(KeyMoment.source_video_id == video_id)
                .order_by(KeyMoment.start_time)
            )
            .scalars()
            .all()
        )
        if not moments:
            logger.info(
                "Highlight detection: No key moments for video_id=%s, skipping.", video_id,
            )
            _emit_event(
                video_id, "highlight_detection", "complete",
                run_id=run_id, payload={"candidates": 0},
            )
            return video_id
        candidate_count = 0
        for moment in moments:
            try:
                # Extract word-level timings for this moment's window
                word_timings = None
                if transcript_data:
                    word_timings = extract_word_timings(
                        transcript_data, moment.start_time, moment.end_time,
                    ) or None # empty list → None for neutral fallback
                result = score_moment(
                    start_time=moment.start_time,
                    end_time=moment.end_time,
                    content_type=moment.content_type.value if moment.content_type else None,
                    summary=moment.summary,
                    plugins=moment.plugins,
                    raw_transcript=moment.raw_transcript,
                    source_quality=None,  # NOTE(review): never populated in this stage — scorer receives None; confirm whether technique_page quality was meant to be wired in
                    video_content_type=None,  # NOTE(review): never populated despite source_video being loaded above — confirm intent
                    word_timings=word_timings,
                )
            except Exception as score_exc:
                logger.warning(
                    "Highlight detection: score_moment failed for moment %s: %s",
                    moment.id, score_exc,
                )
                # Neutral fallback result so the moment is still recorded.
                result = {
                    "score": 0.0,
                    "score_breakdown": {},
                    "duration_secs": max(0.0, moment.end_time - moment.start_time),
                }
            # Upsert keyed on the key_moment_id unique constraint so re-runs
            # overwrite prior scores instead of duplicating rows.
            stmt = pg_insert(HighlightCandidate).values(
                key_moment_id=moment.id,
                source_video_id=moment.source_video_id,
                score=result["score"],
                score_breakdown=result["score_breakdown"],
                duration_secs=result["duration_secs"],
            )
            stmt = stmt.on_conflict_do_update(
                constraint="highlight_candidates_key_moment_id_key",
                set_={
                    "score": stmt.excluded.score,
                    "score_breakdown": stmt.excluded.score_breakdown,
                    "duration_secs": stmt.excluded.duration_secs,
                    "updated_at": func.now(),
                },
            )
            session.execute(stmt)
            candidate_count += 1
        session.commit()
        elapsed = time.monotonic() - start
        _emit_event(
            video_id, "highlight_detection", "complete",
            run_id=run_id, payload={"candidates": candidate_count},
        )
        logger.info(
            "Highlight detection completed for video_id=%s in %.1fs — %d candidates upserted",
            video_id, elapsed, candidate_count,
        )
        return video_id
    except Exception as exc:
        session.rollback()
        _emit_event(
            video_id, "highlight_detection", "error",
            run_id=run_id, payload={"error": str(exc)},
        )
        logger.error("Highlight detection failed for video_id=%s: %s", video_id, exc)
        raise self.retry(exc=exc)
    finally:
        session.close()
# ── Personality profile extraction ───────────────────────────────────────────
def _sample_creator_transcripts(
moments: list,
creator_id: str,
max_chars: int = 40000,
) -> tuple[str, int]:
"""Sample transcripts from a creator's key moments, respecting size tiers.
- Small (<20K chars total): use all text.
- Medium (20K-60K): first 300 chars from each moment, up to budget.
- Large (>60K): random sample seeded by creator_id, attempts topic diversity
via Redis classification data.
Returns (sampled_text, total_char_count).
"""
import random
transcripts = [
(m.source_video_id, m.raw_transcript)
for m in moments
if m.raw_transcript and m.raw_transcript.strip()
]
if not transcripts:
return ("", 0)
total_chars = sum(len(t) for _, t in transcripts)
# Small: use everything
if total_chars <= 20_000:
text = "\n\n---\n\n".join(t for _, t in transcripts)
return (text, total_chars)
# Medium: first 300 chars from each moment
if total_chars <= 60_000:
excerpts = []
budget = max_chars
for _, t in transcripts:
chunk = t[:300]
if budget - len(chunk) < 0:
break
excerpts.append(chunk)
budget -= len(chunk)
text = "\n\n---\n\n".join(excerpts)
return (text, total_chars)
# Large: random sample with optional topic diversity from Redis
topic_map: dict[str, list[tuple[str, str]]] = {}
try:
import redis as _redis
settings = get_settings()
r = _redis.from_url(settings.redis_url)
video_ids = {str(vid) for vid, _ in transcripts}
for vid in video_ids:
raw = r.get(f"chrysopedia:classification:{vid}")
if raw:
classification = json.loads(raw)
if isinstance(classification, list):
for item in classification:
cat = item.get("topic_category", "unknown")
moment_id = item.get("moment_id")
if moment_id:
topic_map.setdefault(cat, []).append(moment_id)
r.close()
except Exception:
# Fall back to random sampling without topic diversity
pass
rng = random.Random(creator_id)
if topic_map:
# Interleave from different categories for diversity
ordered = []
cat_lists = list(topic_map.values())
rng.shuffle(cat_lists)
idx = 0
while any(cat_lists):
for cat in cat_lists:
if cat:
ordered.append(cat.pop(0))
cat_lists = [c for c in cat_lists if c]
# Map moment IDs back to transcripts
moment_lookup = {str(m.id): m.raw_transcript for m in moments if m.raw_transcript}
diverse_transcripts = [
moment_lookup[mid] for mid in ordered if mid in moment_lookup
]
if diverse_transcripts:
transcripts_list = diverse_transcripts
else:
transcripts_list = [t for _, t in transcripts]
else:
transcripts_list = [t for _, t in transcripts]
rng.shuffle(transcripts_list)
excerpts = []
budget = max_chars
for t in transcripts_list:
chunk = t[:600]
if budget - len(chunk) < 0:
break
excerpts.append(chunk)
budget -= len(chunk)
text = "\n\n---\n\n".join(excerpts)
return (text, total_chars)
@celery_app.task(bind=True, max_retries=2, default_retry_delay=60)
def extract_personality_profile(self, creator_id: str) -> str:
    """Extract a personality profile from a creator's transcripts via LLM.

    Aggregates and samples transcripts from all of the creator's key moments,
    sends them to the LLM with the personality_extraction prompt, validates
    the response, and stores the profile as JSONB on Creator.personality_profile.

    Args:
        creator_id: UUID string of the Creator row to profile.

    Returns the creator_id for chain compatibility.
    """
    from datetime import datetime, timezone
    start = time.monotonic()
    logger.info("Personality extraction starting for creator_id=%s", creator_id)
    _emit_event(creator_id, "personality_extraction", "start")
    session = _get_sync_session()
    try:
        # Load creator
        creator = session.execute(
            select(Creator).where(Creator.id == creator_id)
        ).scalar_one_or_none()
        if not creator:
            logger.error("Creator not found: %s", creator_id)
            _emit_event(
                creator_id, "personality_extraction", "error",
                payload={"error": "creator_not_found"},
            )
            return creator_id
        # Load all key moments with transcripts for this creator
        moments = (
            session.execute(
                select(KeyMoment)
                .join(SourceVideo, KeyMoment.source_video_id == SourceVideo.id)
                .where(SourceVideo.creator_id == creator.id)
                .where(KeyMoment.raw_transcript.isnot(None))
            )
            .scalars()
            .all()
        )
        if not moments:
            logger.warning(
                "No transcripts found for creator_id=%s (%s), skipping extraction",
                creator_id, creator.name,
            )
            _emit_event(
                creator_id, "personality_extraction", "complete",
                payload={"skipped": True, "reason": "no_transcripts"},
            )
            return creator_id
        # Sample transcripts (tiered sampling — see _sample_creator_transcripts)
        sampled_text, total_chars = _sample_creator_transcripts(
            moments, creator_id,
        )
        if not sampled_text.strip():
            logger.warning(
                "Empty transcript sample for creator_id=%s, skipping", creator_id,
            )
            _emit_event(
                creator_id, "personality_extraction", "complete",
                payload={"skipped": True, "reason": "empty_sample"},
            )
            return creator_id
        # Load prompt and call LLM
        system_prompt = _load_prompt("personality_extraction.txt")
        user_prompt = (
            f"Creator: {creator.name}\n\n"
            f"Transcript excerpts ({len(moments)} moments, {total_chars} total chars, "
            f"sample below):\n\n{sampled_text}"
        )
        llm = _get_llm_client()
        callback = _make_llm_callback(
            creator_id, "personality_extraction",
            system_prompt=system_prompt,
            user_prompt=user_prompt,
        )
        response = llm.complete(
            system_prompt=system_prompt,
            user_prompt=user_prompt,
            response_model=object,  # triggers JSON mode
            on_complete=callback,
        )
        # Parse and validate
        from schemas import PersonalityProfile as ProfileValidator
        try:
            raw_profile = json.loads(str(response))
        except json.JSONDecodeError as jde:
            logger.warning(
                "LLM returned invalid JSON for creator_id=%s, retrying: %s",
                creator_id, jde,
            )
            # NOTE(review): the Retry raised by self.retry here is also caught
            # by the outer `except Exception` below, which calls self.retry
            # again — confirm this double-retry path is intended.
            raise self.retry(exc=jde)
        try:
            validated = ProfileValidator.model_validate(raw_profile)
        except ValidationError as ve:
            logger.warning(
                "LLM profile failed validation for creator_id=%s, retrying: %s",
                creator_id, ve,
            )
            raise self.retry(exc=ve)
        # Build final profile dict with metadata
        profile_dict = validated.model_dump()
        profile_dict["_metadata"] = {
            "extracted_at": datetime.now(timezone.utc).replace(tzinfo=None).isoformat(),
            "transcript_sample_size": total_chars,
            "moments_count": len(moments),
            # NOTE(review): this stores the response's finish_reason under
            # "model_used" — looks like the wrong attribute; confirm what
            # LLMResponse exposes for the model name.
            "model_used": getattr(response, "finish_reason", None) or "unknown",
        }
        # Low sample size note
        if total_chars < 500:
            profile_dict["_metadata"]["low_sample_size"] = True
        # Store on creator
        creator.personality_profile = profile_dict
        session.commit()
        elapsed = time.monotonic() - start
        _emit_event(
            creator_id, "personality_extraction", "complete",
            duration_ms=int(elapsed * 1000),
            payload={
                "moments_count": len(moments),
                "transcript_chars": total_chars,
                "sample_chars": len(sampled_text),
            },
        )
        logger.info(
            "Personality extraction completed for creator_id=%s (%s) in %.1fs — "
            "%d moments, %d chars sampled",
            creator_id, creator.name, elapsed, len(moments), len(sampled_text),
        )
        return creator_id
    except Exception as exc:
        # Let a terminal MaxRetriesExceededError propagate untouched.
        if isinstance(exc, (self.MaxRetriesExceededError,)):
            raise
        session.rollback()
        _emit_event(
            creator_id, "personality_extraction", "error",
            payload={"error": str(exc)[:500]},
        )
        logger.error(
            "Personality extraction failed for creator_id=%s: %s", creator_id, exc,
        )
        raise self.retry(exc=exc)
    finally:
        session.close()
# ── Stage: Shorts Generation ─────────────────────────────────────────────────
@celery_app.task(bind=True, max_retries=1, default_retry_delay=60)
def stage_generate_shorts(self, highlight_candidate_id: str, captions: bool = True) -> str:
    """Generate video shorts for an approved highlight candidate.

    Creates one GeneratedShort row per FormatPreset, extracts the clip via
    ffmpeg, uploads to MinIO, and updates status. Each preset is independent —
    a failure on one does not block the others.

    Args:
        highlight_candidate_id: UUID string of the approved highlight.
        captions: Whether to generate and burn in ASS subtitles (default True).

    Returns the highlight_candidate_id on completion.
    """
    from pipeline.shorts_generator import PRESETS, extract_clip_with_template, resolve_video_path
    from pipeline.caption_generator import generate_ass_captions, write_ass_file
    from pipeline.card_renderer import parse_template_config, render_card_to_file
    from models import FormatPreset, GeneratedShort, ShortStatus, SourceVideo
    start = time.monotonic()
    session = _get_sync_session()
    settings = get_settings()
    try:
        # ── Load highlight with joined relations ────────────────────────
        highlight = session.execute(
            select(HighlightCandidate)
            .where(HighlightCandidate.id == highlight_candidate_id)
        ).scalar_one_or_none()
        if highlight is None:
            logger.error(
                "Highlight candidate not found: %s", highlight_candidate_id,
            )
            return highlight_candidate_id
        if highlight.status.value != "approved":
            logger.warning(
                "Highlight %s status is %s, expected approved — skipping",
                highlight_candidate_id, highlight.status.value,
            )
            return highlight_candidate_id
        # Check for already-processing shorts (reject duplicate runs)
        existing_processing = session.execute(
            select(func.count())
            .where(GeneratedShort.highlight_candidate_id == highlight_candidate_id)
            .where(GeneratedShort.status == ShortStatus.processing)
        ).scalar()
        if existing_processing and existing_processing > 0:
            logger.warning(
                "Highlight %s already has %d processing shorts — rejecting duplicate",
                highlight_candidate_id, existing_processing,
            )
            return highlight_candidate_id
        # Eager-load relations
        key_moment = highlight.key_moment
        source_video = highlight.source_video
        # ── Resolve video file path ─────────────────────────────────────
        try:
            video_path = resolve_video_path(
                settings.video_source_path, source_video.file_path,
            )
        except FileNotFoundError as fnf:
            logger.error(
                "Video file missing for highlight %s: %s",
                highlight_candidate_id, fnf,
            )
            # Mark all presets as failed so the UI shows a terminal state
            # rather than silently missing shorts.
            for preset in FormatPreset:
                spec = PRESETS[preset]
                short = GeneratedShort(
                    highlight_candidate_id=highlight_candidate_id,
                    format_preset=preset,
                    width=spec.width,
                    height=spec.height,
                    status=ShortStatus.failed,
                    error_message=str(fnf),
                )
                session.add(short)
            session.commit()
            return highlight_candidate_id
        # ── Compute effective start/end (trim overrides) ────────────────
        clip_start = highlight.trim_start if highlight.trim_start is not None else key_moment.start_time
        clip_end = highlight.trim_end if highlight.trim_end is not None else key_moment.end_time
        logger.info(
            "Generating shorts for highlight=%s video=%s [%.1f%.1f]s",
            highlight_candidate_id, source_video.file_path,
            clip_start, clip_end,
        )
        # ── Generate captions from transcript (if available and requested) ─
        ass_path: Path | None = None
        captions_ok = False
        if not captions:
            logger.info(
                "Captions disabled for highlight=%s — skipping caption generation",
                highlight_candidate_id,
            )
        else:
            # Captions are best-effort throughout: any failure logs a warning
            # and clip generation proceeds without subtitles.
            try:
                transcript_data: list | None = None
                if source_video.transcript_path:
                    try:
                        with open(source_video.transcript_path, "r") as fh:
                            raw = json.load(fh)
                        # Accept both {"segments": [...]} and bare [...]
                        if isinstance(raw, dict):
                            transcript_data = raw.get("segments", raw.get("results", []))
                        elif isinstance(raw, list):
                            transcript_data = raw
                    except (FileNotFoundError, json.JSONDecodeError, OSError) as io_exc:
                        logger.warning(
                            "Failed to load transcript for captions highlight=%s: %s",
                            highlight_candidate_id, io_exc,
                        )
                if transcript_data:
                    from pipeline.highlight_scorer import extract_word_timings
                    word_timings = extract_word_timings(transcript_data, clip_start, clip_end)
                    if word_timings:
                        ass_content = generate_ass_captions(word_timings, clip_start)
                        ass_path = write_ass_file(
                            ass_content,
                            Path(f"/tmp/captions_{highlight_candidate_id}.ass"),
                        )
                        captions_ok = True
                        logger.info(
                            "Generated captions for highlight=%s (%d words)",
                            highlight_candidate_id, len(word_timings),
                        )
                    else:
                        logger.warning(
                            "No word timings in transcript window [%.1f%.1f]s for highlight=%s — proceeding without captions",
                            clip_start, clip_end, highlight_candidate_id,
                        )
                else:
                    logger.info(
                        "No transcript available for highlight=%s — proceeding without captions",
                        highlight_candidate_id,
                    )
            except Exception as cap_exc:
                logger.warning(
                    "Caption generation failed for highlight=%s: %s — proceeding without captions",
                    highlight_candidate_id, cap_exc,
                )
        # ── Load creator template config (if available) ─────────────────
        try:
            creator = source_video.creator
            template_cfg = parse_template_config(
                creator.shorts_template if creator else None,
            )
        except Exception as tmpl_exc:
            logger.warning(
                "Template config load failed for highlight=%s: %s — proceeding without cards",
                highlight_candidate_id, tmpl_exc,
            )
            # Fall back to the default (card-less) template config.
            template_cfg = parse_template_config(None)
        # ── Process each preset independently ───────────────────────────
        for preset in FormatPreset:
            spec = PRESETS[preset]
            preset_start = time.monotonic()
            # Create DB row (status=processing)
            short = GeneratedShort(
                highlight_candidate_id=highlight_candidate_id,
                format_preset=preset,
                width=spec.width,
                height=spec.height,
                status=ShortStatus.processing,
                duration_secs=clip_end - clip_start,
            )
            session.add(short)
            session.commit()
            session.refresh(short)
            tmp_path = Path(f"/tmp/short_{short.id}_{preset.value}.mp4")
            minio_key = f"shorts/{highlight_candidate_id}/{preset.value}.mp4"
            try:
                # Render intro/outro cards for this preset's resolution.
                # Card failures degrade gracefully to a clip without cards.
                preset_intro: Path | None = None
                preset_outro: Path | None = None
                if template_cfg["show_intro"] and template_cfg["intro_text"]:
                    preset_intro = Path(
                        f"/tmp/intro_{short.id}_{preset.value}.mp4"
                    )
                    try:
                        render_card_to_file(
                            text=template_cfg["intro_text"],
                            duration_secs=template_cfg["intro_duration"],
                            width=spec.width,
                            height=spec.height,
                            output_path=preset_intro,
                            accent_color=template_cfg["accent_color"],
                            font_family=template_cfg["font_family"],
                        )
                    except Exception as intro_exc:
                        logger.warning(
                            "Intro card render failed for highlight=%s preset=%s: %s — skipping intro",
                            highlight_candidate_id, preset.value, intro_exc,
                        )
                        preset_intro = None
                if template_cfg["show_outro"] and template_cfg["outro_text"]:
                    preset_outro = Path(
                        f"/tmp/outro_{short.id}_{preset.value}.mp4"
                    )
                    try:
                        render_card_to_file(
                            text=template_cfg["outro_text"],
                            duration_secs=template_cfg["outro_duration"],
                            width=spec.width,
                            height=spec.height,
                            output_path=preset_outro,
                            accent_color=template_cfg["accent_color"],
                            font_family=template_cfg["font_family"],
                        )
                    except Exception as outro_exc:
                        logger.warning(
                            "Outro card render failed for highlight=%s preset=%s: %s — skipping outro",
                            highlight_candidate_id, preset.value, outro_exc,
                        )
                        preset_outro = None
                # Extract clip (with optional template cards)
                extract_clip_with_template(
                    input_path=video_path,
                    output_path=tmp_path,
                    start_secs=clip_start,
                    end_secs=clip_end,
                    vf_filter=spec.vf_filter,
                    ass_path=ass_path,
                    intro_path=preset_intro,
                    outro_path=preset_outro,
                )
                # Upload to MinIO
                file_size = tmp_path.stat().st_size
                with open(tmp_path, "rb") as f:
                    from minio_client import upload_file
                    upload_file(
                        object_key=minio_key,
                        data=f,
                        length=file_size,
                        content_type="video/mp4",
                    )
                # Update DB row — complete
                short.status = ShortStatus.complete
                short.file_size_bytes = file_size
                short.minio_object_key = minio_key
                short.captions_enabled = captions_ok
                short.share_token = secrets.token_urlsafe(8)
                session.commit()
                elapsed_preset = time.monotonic() - preset_start
                logger.info(
                    "Short generated: highlight=%s preset=%s "
                    "size=%d bytes duration=%.1fs elapsed=%.1fs",
                    highlight_candidate_id, preset.value,
                    file_size, clip_end - clip_start, elapsed_preset,
                )
            except Exception as exc:
                session.rollback()
                # Re-fetch the short row after rollback (the insert above was
                # already committed, so the row still exists).
                session.refresh(short)
                short.status = ShortStatus.failed
                short.error_message = str(exc)[:2000]
                session.commit()
                elapsed_preset = time.monotonic() - preset_start
                logger.error(
                    "Short failed: highlight=%s preset=%s "
                    "error=%s elapsed=%.1fs",
                    highlight_candidate_id, preset.value,
                    str(exc)[:500], elapsed_preset,
                )
            finally:
                # Clean up temp files (main clip + intro/outro cards)
                for tmp in [tmp_path, preset_intro, preset_outro]:
                    if tmp is not None and tmp.exists():
                        try:
                            tmp.unlink()
                        except OSError:
                            pass
        # Clean up temp ASS caption file (shared across all presets)
        if ass_path is not None and ass_path.exists():
            try:
                ass_path.unlink()
            except OSError:
                pass
        elapsed = time.monotonic() - start
        logger.info(
            "Shorts generation complete for highlight=%s in %.1fs",
            highlight_candidate_id, elapsed,
        )
        return highlight_candidate_id
    except Exception as exc:
        session.rollback()
        logger.error(
            "Shorts generation failed for highlight=%s: %s",
            highlight_candidate_id, exc,
        )
        raise self.retry(exc=exc)
    finally:
        session.close()