- "backend/pipeline/shorts_generator.py" - "backend/pipeline/stages.py" GSD-Task: S03/T02
3051 lines
117 KiB
Python
3051 lines
117 KiB
Python
"""Pipeline stage tasks (stages 2-5) and run_pipeline orchestrator.
|
||
|
||
Each stage reads from PostgreSQL via sync SQLAlchemy, loads its prompt
|
||
template from disk, calls the LLM client, parses the response, writes
|
||
results back, and updates processing_status on SourceVideo.
|
||
|
||
Celery tasks are synchronous — all DB access uses ``sqlalchemy.orm.Session``.
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import hashlib
|
||
import json
|
||
import logging
|
||
import os
|
||
import re
|
||
import subprocess
|
||
import time
|
||
from collections import defaultdict
|
||
from pathlib import Path
|
||
|
||
import yaml
|
||
|
||
from pydantic import ValidationError
|
||
from sqlalchemy import create_engine, func, select
|
||
from sqlalchemy.orm import Session, sessionmaker
|
||
|
||
from config import get_settings
|
||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||
|
||
from models import (
|
||
Creator,
|
||
HighlightCandidate,
|
||
KeyMoment,
|
||
KeyMomentContentType,
|
||
PipelineEvent,
|
||
ProcessingStatus,
|
||
SourceVideo,
|
||
TechniquePage,
|
||
TechniquePageVersion,
|
||
TechniquePageVideo,
|
||
TranscriptSegment,
|
||
)
|
||
from pipeline.embedding_client import EmbeddingClient
|
||
from pipeline.llm_client import LLMClient, LLMResponse, estimate_max_tokens
|
||
from pipeline.qdrant_client import QdrantManager
|
||
from pipeline.schemas import (
|
||
ClassificationResult,
|
||
ExtractionResult,
|
||
SegmentationResult,
|
||
SynthesisResult,
|
||
)
|
||
from worker import celery_app
|
||
|
||
# Module-level logger (one per module); tasks log through this so the
# Celery worker's logging configuration applies.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class LLMTruncationError(RuntimeError):
    """Raised when the LLM response was truncated (finish_reason=length).

    Raised by ``_safe_parse_llm_response`` instead of retrying: a truncated
    response means the prompt/output no longer fits the context window, so
    a nudge retry (which adds tokens) cannot succeed.
    """
    # (Removed redundant `pass` — the docstring alone is a valid class body.)
|
||
|
||
|
||
# ── Error status helper ──────────────────────────────────────────────────────
|
||
|
||
def _set_error_status(video_id: str, stage_name: str, error: Exception) -> None:
    """Mark a video as errored when a pipeline stage fails permanently.

    Best-effort: any failure here is logged and swallowed so the caller's
    original exception propagation is not disturbed.

    Parameters
    ----------
    video_id: primary key of the SourceVideo to mark.
    stage_name: stage name, used only for log context.
    error: the original stage failure (kept for interface stability; only
        the caller's context uses it).
    """
    try:
        session = _get_sync_session()
        try:
            video = session.execute(
                select(SourceVideo).where(SourceVideo.id == video_id)
            ).scalar_one_or_none()
            if video:
                video.processing_status = ProcessingStatus.error
                session.commit()
        finally:
            # Fix: always release the session/connection. The original only
            # closed on the success path, leaking the session whenever the
            # lookup or commit raised (cf. _emit_event's try/finally).
            session.close()
    except Exception as mark_exc:
        logger.error(
            "Failed to mark video_id=%s as error after %s failure: %s",
            video_id, stage_name, mark_exc,
        )
|
||
|
||
|
||
# ── Pipeline event persistence ───────────────────────────────────────────────
|
||
|
||
def _emit_event(
    video_id: str,
    stage: str,
    event_type: str,
    *,
    run_id: str | None = None,
    prompt_tokens: int | None = None,
    completion_tokens: int | None = None,
    total_tokens: int | None = None,
    model: str | None = None,
    duration_ms: int | None = None,
    payload: dict | None = None,
    system_prompt_text: str | None = None,
    user_prompt_text: str | None = None,
    response_text: str | None = None,
) -> None:
    """Persist a pipeline event to the DB. Best-effort -- failures logged, not raised.

    Parameters
    ----------
    video_id / stage / event_type: required event identity (e.g. stage name
        plus "start" / "complete" / "error" / "llm_call").
    run_id: optional pipeline run correlation id.
    prompt_tokens / completion_tokens / total_tokens / model / duration_ms:
        optional LLM call metrics.
    payload: arbitrary JSON-serializable extras.
    system_prompt_text / user_prompt_text / response_text: full texts,
        populated only in debug mode by _make_llm_callback.
    """
    try:
        session = _get_sync_session()
        try:
            event = PipelineEvent(
                video_id=video_id,
                run_id=run_id,
                stage=stage,
                event_type=event_type,
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=total_tokens,
                model=model,
                duration_ms=duration_ms,
                payload=payload,
                system_prompt_text=system_prompt_text,
                user_prompt_text=user_prompt_text,
                response_text=response_text,
            )
            session.add(event)
            session.commit()
        finally:
            # Session is always released, even if add/commit raise.
            session.close()
    except Exception as exc:
        # Telemetry must never break the pipeline stage that emitted it.
        logger.warning("Failed to emit pipeline event: %s", exc)
|
||
|
||
|
||
def _is_debug_mode() -> bool:
    """Check if debug mode is enabled via Redis. Falls back to config setting.

    Reads the ``chrysopedia:debug_mode`` key; a value of "true" (case
    insensitive) enables debug capture. Any Redis failure falls through to
    the ``debug_mode`` config setting (default False).
    """
    try:
        import redis

        settings = get_settings()
        r = redis.from_url(settings.redis_url)
        try:
            val = r.get("chrysopedia:debug_mode")
        finally:
            # Fix: close the client even when get() raises. The original
            # only closed on success, leaking the connection on errors
            # after from_url().
            r.close()
        if val is not None:
            return val.decode().lower() == "true"
    except Exception:
        # Deliberate best-effort: Redis being down must not break tasks.
        pass
    return getattr(get_settings(), "debug_mode", False)
|
||
|
||
|
||
def _make_llm_callback(
    video_id: str,
    stage: str,
    system_prompt: str | None = None,
    user_prompt: str | None = None,
    run_id: str | None = None,
    context_label: str | None = None,
    request_params: dict | None = None,
):
    """Create an on_complete callback for LLMClient that emits llm_call events.

    When debug mode is enabled, captures full system prompt, user prompt,
    and response text on each llm_call event.

    Parameters
    ----------
    request_params:
        Dict of LLM request parameters (max_tokens, model_override, modality,
        response_model, temperature, etc.) to store in the event payload for
        debugging which parameters were actually sent to the API.
    """
    # Debug mode is sampled once at callback creation time; toggling the
    # Redis flag mid-flight does not affect callbacks already created.
    debug = _is_debug_mode()

    def callback(*, model=None, prompt_tokens=None, completion_tokens=None,
                 total_tokens=None, content=None, finish_reason=None,
                 is_fallback=False, **_kwargs):
        # Truncate content for storage — keep first 2000 chars for debugging
        truncated = content[:2000] if content and len(content) > 2000 else content
        _emit_event(
            video_id=video_id,
            stage=stage,
            event_type="llm_call",
            run_id=run_id,
            model=model,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens,
            payload={
                "content_preview": truncated,
                "content_length": len(content) if content else 0,
                "finish_reason": finish_reason,
                "is_fallback": is_fallback,
                # Optional keys are merged in only when set, keeping payloads compact.
                **({"context": context_label} if context_label else {}),
                **({"request_params": request_params} if request_params else {}),
            },
            # Full prompt/response text is persisted only in debug mode.
            system_prompt_text=system_prompt if debug else None,
            user_prompt_text=user_prompt if debug else None,
            response_text=content if debug else None,
        )
    return callback
|
||
|
||
|
||
def _build_request_params(
    max_tokens: int,
    model_override: str | None,
    modality: str,
    response_model: str,
    hard_limit: int,
) -> dict:
    """Build the request_params dict for pipeline event logging.

    Distinguishes parameters actually sent to the LLM API ("api_params")
    from internal estimator/pipeline configuration ("pipeline_config") so
    the debug JSON is unambiguous.
    """
    settings = get_settings()
    if modality == "chat":
        response_format = "json_object"
    else:
        response_format = "none (thinking mode)"
    api_params = {
        "max_tokens": max_tokens,
        "model": model_override or settings.llm_model,
        "temperature": settings.llm_temperature,
        "response_format": response_format,
    }
    pipeline_config = {
        "modality": modality,
        "response_model": response_model,
        "estimator_hard_limit": hard_limit,
        "fallback_max_tokens": settings.llm_max_tokens,
    }
    return {"api_params": api_params, "pipeline_config": pipeline_config}
|
||
|
||
|
||
# ── Helpers ──────────────────────────────────────────────────────────────────
|
||
|
||
# Lazily-initialized module-level singletons (one engine / session factory
# per worker process); created on first use by _get_sync_engine /
# _get_sync_session.
_engine = None
_SessionLocal = None
|
||
|
||
|
||
def _get_sync_engine():
    """Return the module-singleton sync engine, creating it on first call.

    The configured database URL may name the asyncpg driver; it is rewritten
    to psycopg2 so synchronous Celery sessions can use it.
    """
    global _engine
    if _engine is not None:
        return _engine
    sync_url = get_settings().database_url.replace(
        "postgresql+asyncpg://", "postgresql+psycopg2://"
    )
    _engine = create_engine(sync_url, pool_pre_ping=True, pool_size=5, max_overflow=10)
    return _engine
|
||
|
||
|
||
def _get_sync_session() -> Session:
    """Return a new sync SQLAlchemy session for Celery tasks.

    The sessionmaker is built lazily once per process and cached at module
    level; each call returns a fresh session from that factory.
    """
    global _SessionLocal
    factory = _SessionLocal
    if factory is None:
        factory = sessionmaker(bind=_get_sync_engine())
        _SessionLocal = factory
    return factory()
|
||
|
||
|
||
def _load_prompt(template_name: str, video_id: str | None = None) -> str:
    """Read a prompt template from the prompts directory.

    If ``video_id`` is provided, checks Redis for a per-video prompt override
    (key: ``chrysopedia:prompt_override:{video_id}:{template_name}``) before
    falling back to the on-disk template. Overrides are set by
    ``run_single_stage`` for single-stage re-runs with custom prompts.

    Raises FileNotFoundError if no override exists and the template is missing.
    """
    # Check for per-video prompt override in Redis
    if video_id:
        try:
            import redis

            settings = get_settings()
            r = redis.Redis.from_url(settings.redis_url)
            override_key = f"chrysopedia:prompt_override:{video_id}:{template_name}"
            try:
                override = r.get(override_key)
            finally:
                # Fix: release the Redis connection. The original never
                # closed the client (cf. _is_debug_mode, which does).
                r.close()
            if override:
                prompt_text = override.decode("utf-8")
                logger.info(
                    "[PROMPT] Using override from Redis: video_id=%s, template=%s (%d chars)",
                    video_id, template_name, len(prompt_text),
                )
                return prompt_text
        except Exception as exc:
            # Best-effort: a Redis failure degrades to the on-disk template.
            logger.warning("[PROMPT] Redis override check failed: %s", exc)

    settings = get_settings()
    path = Path(settings.prompts_path) / template_name
    if not path.exists():
        logger.error("Prompt template not found: %s", path)
        raise FileNotFoundError(f"Prompt template not found: {path}")
    return path.read_text(encoding="utf-8")
|
||
|
||
|
||
def _get_llm_client() -> LLMClient:
    """Construct an LLMClient from the current application settings."""
    settings = get_settings()
    return LLMClient(settings)
|
||
|
||
|
||
def _get_stage_config(stage_num: int) -> tuple[str | None, str]:
    """Return (model_override, modality) for a pipeline stage.

    Looks up the stage-specific Settings attributes. An unset/empty model
    yields None (LLMClient then uses its default); an unset/empty modality
    defaults to "chat".
    """
    settings = get_settings()
    model_value = getattr(settings, f"llm_stage{stage_num}_model", None)
    modality_value = getattr(settings, f"llm_stage{stage_num}_modality", None)
    model = model_value if model_value else None
    modality = modality_value if modality_value else "chat"
    return model, modality
|
||
|
||
|
||
def _load_canonical_tags() -> dict:
    """Load canonical tag taxonomy from config/canonical_tags.yaml.

    Looks in the working directory first, then one level up (walking up
    from backend/), and raises FileNotFoundError if neither path exists.
    """
    candidates = [
        Path("config/canonical_tags.yaml"),
        Path("../config/canonical_tags.yaml"),
    ]
    for path in candidates:
        if not path.exists():
            continue
        with path.open(encoding="utf-8") as f:
            return yaml.safe_load(f)
    searched = ", ".join(str(c) for c in candidates)
    raise FileNotFoundError(
        "canonical_tags.yaml not found. Searched: " + searched
    )
|
||
|
||
|
||
def _format_taxonomy_for_prompt(tags_data: dict) -> str:
|
||
"""Format the canonical tags taxonomy as readable text for the LLM prompt."""
|
||
lines = []
|
||
for cat in tags_data.get("categories", []):
|
||
lines.append(f"Category: {cat['name']}")
|
||
lines.append(f" Description: {cat['description']}")
|
||
lines.append(f" Sub-topics: {', '.join(cat.get('sub_topics', []))}")
|
||
lines.append("")
|
||
return "\n".join(lines)
|
||
|
||
|
||
def _safe_parse_llm_response(
    raw,
    model_cls,
    llm: LLMClient,
    system_prompt: str,
    user_prompt: str,
    modality: str = "chat",
    model_override: str | None = None,
    max_tokens: int | None = None,
):
    """Parse LLM response with truncation detection and one retry on failure.

    If the response was truncated (finish_reason=length), raises
    LLMTruncationError immediately — retrying with a JSON nudge would only
    make things worse by adding tokens to an already-too-large prompt.

    For non-truncation parse failures: retry once with a JSON nudge, then
    raise on second failure.

    Parameters
    ----------
    raw: the LLMClient response (LLMResponse or raw text — both are accepted
        by ``llm.parse_response``).
    model_cls: pydantic model class the JSON must validate against.
    llm / system_prompt / user_prompt / modality / model_override / max_tokens:
        everything needed to re-issue the call for the nudge retry.
    """
    # Check for truncation before attempting parse
    is_truncated = isinstance(raw, LLMResponse) and raw.truncated
    if is_truncated:
        logger.warning(
            "LLM response truncated (finish=length) for %s. "
            "prompt_tokens=%s, completion_tokens=%s. Will not retry with nudge.",
            model_cls.__name__,
            getattr(raw, "prompt_tokens", "?"),
            getattr(raw, "completion_tokens", "?"),
        )

    try:
        # A truncated response may still parse if the cut landed after the
        # closing brace — only a parse failure escalates it.
        return llm.parse_response(raw, model_cls)
    except (ValidationError, ValueError, json.JSONDecodeError) as exc:
        if is_truncated:
            # Truncated AND unparseable: unrecoverable here; raise a dedicated
            # error so callers can distinguish it from flaky JSON output.
            raise LLMTruncationError(
                f"LLM output truncated for {model_cls.__name__}: "
                f"prompt_tokens={getattr(raw, 'prompt_tokens', '?')}, "
                f"completion_tokens={getattr(raw, 'completion_tokens', '?')}. "
                f"Response too large for model context window."
            ) from exc

        logger.warning(
            "First parse attempt failed for %s (%s). Retrying with JSON nudge. "
            "Raw response (first 500 chars): %.500s",
            model_cls.__name__,
            type(exc).__name__,
            raw,
        )
        # Retry with explicit JSON instruction
        nudge_prompt = user_prompt + "\n\nIMPORTANT: Output ONLY valid JSON. No markdown, no explanation."
        retry_raw = llm.complete(
            system_prompt, nudge_prompt, response_model=model_cls,
            modality=modality, model_override=model_override,
            max_tokens=max_tokens,
        )
        # A second parse failure propagates to the caller — no further retries.
        return llm.parse_response(retry_raw, model_cls)
|
||
|
||
|
||
# ── Stage 2: Segmentation ───────────────────────────────────────────────────
|
||
|
||
@celery_app.task(bind=True, max_retries=3, default_retry_delay=30)
def stage2_segmentation(self, video_id: str, run_id: str | None = None) -> str:
    """Analyze transcript segments and identify topic boundaries.

    Loads all TranscriptSegment rows for the video, sends them to the LLM
    for topic boundary detection, and updates topic_label on each segment.

    Retried up to 3 times (30s delay) on any failure except a missing
    prompt template.

    Returns the video_id for chain compatibility.
    """
    start = time.monotonic()
    logger.info("Stage 2 (segmentation) starting for video_id=%s", video_id)
    _emit_event(video_id, "stage2_segmentation", "start", run_id=run_id)

    session = _get_sync_session()
    try:
        # Load segments ordered by index
        segments = (
            session.execute(
                select(TranscriptSegment)
                .where(TranscriptSegment.source_video_id == video_id)
                .order_by(TranscriptSegment.segment_index)
            )
            .scalars()
            .all()
        )

        if not segments:
            # No transcript: nothing to label; still succeed for chain flow.
            logger.info("Stage 2: No segments found for video_id=%s, skipping.", video_id)
            return video_id

        # Build transcript text with indices for the LLM
        transcript_lines = []
        for seg in segments:
            transcript_lines.append(
                f"[{seg.segment_index}] ({seg.start_time:.1f}s - {seg.end_time:.1f}s) {seg.text}"
            )
        transcript_text = "\n".join(transcript_lines)

        # Load prompt and call LLM
        system_prompt = _load_prompt("stage2_segmentation.txt", video_id=video_id)
        user_prompt = f"<transcript>\n{transcript_text}\n</transcript>"

        llm = _get_llm_client()
        model_override, modality = _get_stage_config(2)
        hard_limit = get_settings().llm_max_tokens_hard_limit
        max_tokens = estimate_max_tokens(system_prompt, user_prompt, stage="stage2_segmentation", hard_limit=hard_limit)
        logger.info("Stage 2 using model=%s, modality=%s, max_tokens=%d", model_override or "default", modality, max_tokens)
        _s2_request_params = _build_request_params(max_tokens, model_override, modality, "SegmentationResult", hard_limit)
        raw = llm.complete(system_prompt, user_prompt, response_model=SegmentationResult, on_complete=_make_llm_callback(video_id, "stage2_segmentation", system_prompt=system_prompt, user_prompt=user_prompt, run_id=run_id, request_params=_s2_request_params),
                           modality=modality, model_override=model_override, max_tokens=max_tokens)
        result = _safe_parse_llm_response(raw, SegmentationResult, llm, system_prompt, user_prompt,
                                          modality=modality, model_override=model_override, max_tokens=max_tokens)

        # Update topic_label on each segment row. Indices the LLM returns
        # that don't exist in the DB are silently skipped.
        seg_by_index = {s.segment_index: s for s in segments}
        for topic_seg in result.segments:
            for idx in range(topic_seg.start_index, topic_seg.end_index + 1):
                if idx in seg_by_index:
                    seg_by_index[idx].topic_label = topic_seg.topic_label

        session.commit()
        elapsed = time.monotonic() - start
        _emit_event(video_id, "stage2_segmentation", "complete", run_id=run_id)
        logger.info(
            "Stage 2 (segmentation) completed for video_id=%s in %.1fs — %d topic groups found",
            video_id, elapsed, len(result.segments),
        )
        return video_id

    except FileNotFoundError:
        raise  # Don't retry missing prompt files
    except Exception as exc:
        session.rollback()
        _emit_event(video_id, "stage2_segmentation", "error", run_id=run_id, payload={"error": str(exc)})
        logger.error("Stage 2 failed for video_id=%s: %s", video_id, exc)
        # Celery retry: re-raises a Retry exception; gives up after max_retries.
        raise self.retry(exc=exc)
    finally:
        session.close()
|
||
|
||
|
||
# ── Stage 3: Extraction ─────────────────────────────────────────────────────
|
||
|
||
@celery_app.task(bind=True, max_retries=3, default_retry_delay=30)
def stage3_extraction(self, video_id: str, run_id: str | None = None) -> str:
    """Extract key moments from each topic segment group.

    Groups segments by topic_label, calls the LLM for each group to extract
    moments, creates KeyMoment rows, and sets processing_status=extracted.

    NOTE(review): the visible body never writes processing_status — confirm
    whether that update happens in a later stage or was dropped.

    Returns the video_id for chain compatibility.
    """
    start = time.monotonic()
    logger.info("Stage 3 (extraction) starting for video_id=%s", video_id)
    _emit_event(video_id, "stage3_extraction", "start", run_id=run_id)

    session = _get_sync_session()
    try:
        # Load segments with topic labels
        segments = (
            session.execute(
                select(TranscriptSegment)
                .where(TranscriptSegment.source_video_id == video_id)
                .order_by(TranscriptSegment.segment_index)
            )
            .scalars()
            .all()
        )

        if not segments:
            logger.info("Stage 3: No segments found for video_id=%s, skipping.", video_id)
            return video_id

        # Group segments by topic_label; segments stage 2 left unlabeled are
        # pooled under the synthetic "unlabeled" group.
        groups: dict[str, list[TranscriptSegment]] = defaultdict(list)
        for seg in segments:
            label = seg.topic_label or "unlabeled"
            groups[label].append(seg)

        system_prompt = _load_prompt("stage3_extraction.txt", video_id=video_id)
        llm = _get_llm_client()
        model_override, modality = _get_stage_config(3)
        hard_limit = get_settings().llm_max_tokens_hard_limit
        logger.info("Stage 3 using model=%s, modality=%s", model_override or "default", modality)
        total_moments = 0

        # One LLM call per topic group.
        for topic_label, group_segs in groups.items():
            # Build segment text for this group
            seg_lines = []
            for seg in group_segs:
                seg_lines.append(
                    f"({seg.start_time:.1f}s - {seg.end_time:.1f}s) {seg.text}"
                )
            segment_text = "\n".join(seg_lines)

            user_prompt = (
                f"Topic: {topic_label}\n\n"
                f"<segment>\n{segment_text}\n</segment>"
            )

            max_tokens = estimate_max_tokens(system_prompt, user_prompt, stage="stage3_extraction", hard_limit=hard_limit)
            _s3_request_params = _build_request_params(max_tokens, model_override, modality, "ExtractionResult", hard_limit)
            raw = llm.complete(system_prompt, user_prompt, response_model=ExtractionResult, on_complete=_make_llm_callback(video_id, "stage3_extraction", system_prompt=system_prompt, user_prompt=user_prompt, run_id=run_id, context_label=topic_label, request_params=_s3_request_params),
                               modality=modality, model_override=model_override, max_tokens=max_tokens)
            result = _safe_parse_llm_response(raw, ExtractionResult, llm, system_prompt, user_prompt,
                                              modality=modality, model_override=model_override, max_tokens=max_tokens)

            # Create KeyMoment rows
            for moment in result.moments:
                # Validate content_type against enum; unknown values fall
                # back to "technique" rather than failing the whole stage.
                try:
                    ct = KeyMomentContentType(moment.content_type)
                except ValueError:
                    ct = KeyMomentContentType.technique

                km = KeyMoment(
                    source_video_id=video_id,
                    title=moment.title,
                    summary=moment.summary,
                    start_time=moment.start_time,
                    end_time=moment.end_time,
                    content_type=ct,
                    plugins=moment.plugins if moment.plugins else None,
                    raw_transcript=moment.raw_transcript or None,
                )
                session.add(km)
                total_moments += 1

        # Single commit covers all groups: a failure in any group rolls
        # back every moment from this run.
        session.commit()
        elapsed = time.monotonic() - start
        _emit_event(video_id, "stage3_extraction", "complete", run_id=run_id)
        logger.info(
            "Stage 3 (extraction) completed for video_id=%s in %.1fs — %d moments created",
            video_id, elapsed, total_moments,
        )
        return video_id

    except FileNotFoundError:
        raise  # Don't retry missing prompt files
    except Exception as exc:
        session.rollback()
        _emit_event(video_id, "stage3_extraction", "error", run_id=run_id, payload={"error": str(exc)})
        logger.error("Stage 3 failed for video_id=%s: %s", video_id, exc)
        raise self.retry(exc=exc)
    finally:
        session.close()
|
||
|
||
|
||
# ── Stage 4: Classification ─────────────────────────────────────────────────
|
||
|
||
# Maximum moments per classification batch. Keeps each LLM call well within
# context window limits. Batches are classified independently and merged.
# Used by stage4_classification and _classify_moment_batch.
_STAGE4_BATCH_SIZE = 20
|
||
|
||
|
||
def _classify_moment_batch(
    moments_batch: list,
    batch_offset: int,
    taxonomy_text: str,
    system_prompt: str,
    llm: LLMClient,
    model_override: str | None,
    modality: str,
    hard_limit: int,
    video_id: str,
    run_id: str | None,
) -> ClassificationResult:
    """Classify a single batch of moments. Raises on failure.

    Parameters
    ----------
    moments_batch: KeyMoment rows for this batch (at most _STAGE4_BATCH_SIZE).
    batch_offset: global index of the batch's first moment — used only for
        log/event labelling; indices in the prompt itself are batch-local.
    taxonomy_text: pre-formatted canonical taxonomy block.
    system_prompt: stage 4 prompt template text.
    llm / model_override / modality / hard_limit: LLM call configuration.
    video_id / run_id: identifiers for pipeline event emission.

    Returns
    -------
    ClassificationResult with batch-local (0-based) moment indices; the
    caller remaps them to global indices.
    """
    # Render each moment with a batch-local index the LLM echoes back.
    moments_lines = []
    for i, m in enumerate(moments_batch):
        moments_lines.append(
            f"[{i}] Title: {m.title}\n"
            f"    Summary: {m.summary}\n"
            f"    Content type: {m.content_type.value}\n"
            f"    Plugins: {', '.join(m.plugins) if m.plugins else 'none'}"
        )
    moments_text = "\n\n".join(moments_lines)

    user_prompt = (
        f"<taxonomy>\n{taxonomy_text}\n</taxonomy>\n\n"
        f"<moments>\n{moments_text}\n</moments>"
    )

    max_tokens = estimate_max_tokens(
        system_prompt, user_prompt,
        stage="stage4_classification", hard_limit=hard_limit,
    )
    # Human-readable label for logs/events, e.g. "batch 2 (moments 20-39)".
    batch_label = f"batch {batch_offset // _STAGE4_BATCH_SIZE + 1} (moments {batch_offset}-{batch_offset + len(moments_batch) - 1})"
    logger.info(
        "Stage 4 classifying %s, max_tokens=%d",
        batch_label, max_tokens,
    )
    raw = llm.complete(
        system_prompt, user_prompt,
        response_model=ClassificationResult,
        on_complete=_make_llm_callback(
            video_id, "stage4_classification",
            system_prompt=system_prompt, user_prompt=user_prompt,
            run_id=run_id, context_label=batch_label,
            request_params=_build_request_params(max_tokens, model_override, modality, "ClassificationResult", hard_limit),
        ),
        modality=modality, model_override=model_override,
        max_tokens=max_tokens,
    )
    # Truncation-aware parse with a single JSON-nudge retry.
    return _safe_parse_llm_response(
        raw, ClassificationResult, llm, system_prompt, user_prompt,
        modality=modality, model_override=model_override,
        max_tokens=max_tokens,
    )
|
||
|
||
|
||
@celery_app.task(bind=True, max_retries=3, default_retry_delay=30)
def stage4_classification(self, video_id: str, run_id: str | None = None) -> str:
    """Classify key moments against the canonical tag taxonomy.

    Loads all KeyMoment rows for the video, sends them to the LLM with the
    canonical taxonomy, and stores classification results in Redis for
    stage 5 consumption. Updates content_type if the classifier overrides it.

    For large moment sets, automatically batches into groups of
    _STAGE4_BATCH_SIZE to stay within model context window limits.

    Stage 4 does NOT change processing_status.

    Returns the video_id for chain compatibility.
    """
    start = time.monotonic()
    logger.info("Stage 4 (classification) starting for video_id=%s", video_id)
    _emit_event(video_id, "stage4_classification", "start", run_id=run_id)

    session = _get_sync_session()
    try:
        # Load key moments
        moments = (
            session.execute(
                select(KeyMoment)
                .where(KeyMoment.source_video_id == video_id)
                .order_by(KeyMoment.start_time)
            )
            .scalars()
            .all()
        )

        if not moments:
            logger.info("Stage 4: No moments found for video_id=%s, skipping.", video_id)
            # Store an explicit empty result so stage 5 doesn't treat this
            # as "classification never ran".
            _store_classification_data(video_id, [])
            return video_id

        # Load canonical tags
        tags_data = _load_canonical_tags()
        taxonomy_text = _format_taxonomy_for_prompt(tags_data)

        system_prompt = _load_prompt("stage4_classification.txt", video_id=video_id)
        llm = _get_llm_client()
        model_override, modality = _get_stage_config(4)
        hard_limit = get_settings().llm_max_tokens_hard_limit

        # Batch moments for classification
        all_classifications = []
        for batch_start in range(0, len(moments), _STAGE4_BATCH_SIZE):
            batch = moments[batch_start:batch_start + _STAGE4_BATCH_SIZE]
            result = _classify_moment_batch(
                batch, batch_start, taxonomy_text, system_prompt,
                llm, model_override, modality, hard_limit,
                video_id, run_id,
            )
            # Reindex: batch uses 0-based indices, remap to global indices
            for cls in result.classifications:
                cls.moment_index += batch_start
            all_classifications.extend(result.classifications)

        # Apply content_type overrides and prepare classification data for stage 5
        classification_data = []

        for cls in all_classifications:
            # Guard against out-of-range indices returned by the LLM.
            if 0 <= cls.moment_index < len(moments):
                moment = moments[cls.moment_index]

                # Apply content_type override if provided
                if cls.content_type_override:
                    try:
                        moment.content_type = KeyMomentContentType(cls.content_type_override)
                    except ValueError:
                        # Unknown override value: keep the stage 3 content_type.
                        pass

                classification_data.append({
                    "moment_id": str(moment.id),
                    # Normalized to Title Case for consistent grouping downstream.
                    "topic_category": cls.topic_category.strip().title(),
                    "topic_tags": cls.topic_tags,
                })

        session.commit()

        # Store classification data in Redis for stage 5
        _store_classification_data(video_id, classification_data)

        elapsed = time.monotonic() - start
        # Ceiling division: number of batches actually issued.
        num_batches = (len(moments) + _STAGE4_BATCH_SIZE - 1) // _STAGE4_BATCH_SIZE
        _emit_event(video_id, "stage4_classification", "complete", run_id=run_id)
        logger.info(
            "Stage 4 (classification) completed for video_id=%s in %.1fs — "
            "%d moments classified in %d batch(es)",
            video_id, elapsed, len(classification_data), num_batches,
        )
        return video_id

    except FileNotFoundError:
        raise  # Don't retry missing prompt files
    except Exception as exc:
        session.rollback()
        _emit_event(video_id, "stage4_classification", "error", run_id=run_id, payload={"error": str(exc)})
        logger.error("Stage 4 failed for video_id=%s: %s", video_id, exc)
        raise self.retry(exc=exc)
    finally:
        session.close()
|
||
|
||
|
||
def _store_classification_data(video_id: str, data: list[dict]) -> None:
    """Store classification data in Redis (cache) and PostgreSQL (durable).

    Dual-write ensures classification data survives Redis TTL expiry or flush.
    Redis serves as the fast-path cache; PostgreSQL is the durable fallback.

    Both writes are best-effort: failures are logged as warnings and never
    raised, so stage 4 cannot fail on storage alone.
    """
    import redis

    settings = get_settings()

    # Redis: fast cache with 7-day TTL
    try:
        r = redis.Redis.from_url(settings.redis_url)
        key = f"chrysopedia:classification:{video_id}"
        r.set(key, json.dumps(data), ex=604800)  # 7 days
        logger.info(
            "[CLASSIFY-STORE] Redis write: video_id=%s, %d entries, ttl=7d",
            video_id, len(data),
        )
    except Exception as exc:
        logger.warning(
            "[CLASSIFY-STORE] Redis write failed for video_id=%s: %s", video_id, exc,
        )

    # PostgreSQL: durable storage on SourceVideo.classification_data
    session = _get_sync_session()
    try:
        video = session.execute(
            select(SourceVideo).where(SourceVideo.id == video_id)
        ).scalar_one_or_none()
        if video:
            video.classification_data = data
            session.commit()
            logger.info(
                "[CLASSIFY-STORE] PostgreSQL write: video_id=%s, %d entries",
                video_id, len(data),
            )
        else:
            logger.warning(
                "[CLASSIFY-STORE] Video not found for PostgreSQL write: %s", video_id,
            )
    except Exception as exc:
        session.rollback()
        logger.warning(
            "[CLASSIFY-STORE] PostgreSQL write failed for video_id=%s: %s", video_id, exc,
        )
    finally:
        session.close()
|
||
|
||
|
||
def _load_classification_data(video_id: str) -> list[dict]:
    """Load classification data from Redis (fast path) or PostgreSQL (fallback).

    Tries Redis first. If the key has expired or Redis is unavailable, falls
    back to the durable SourceVideo.classification_data column. Returns an
    empty list when neither store has data.
    """
    import redis

    settings = get_settings()

    # Try Redis first (fast path)
    try:
        r = redis.Redis.from_url(settings.redis_url)
        key = f"chrysopedia:classification:{video_id}"
        raw = r.get(key)
        if raw is not None:
            data = json.loads(raw)
            logger.info(
                "[CLASSIFY-LOAD] Source: redis, video_id=%s, %d entries",
                video_id, len(data),
            )
            return data
    except Exception as exc:
        logger.warning(
            "[CLASSIFY-LOAD] Redis unavailable for video_id=%s: %s", video_id, exc,
        )

    # Fallback to PostgreSQL
    logger.info("[CLASSIFY-LOAD] Redis miss, falling back to PostgreSQL for video_id=%s", video_id)
    session = _get_sync_session()
    try:
        video = session.execute(
            select(SourceVideo).where(SourceVideo.id == video_id)
        ).scalar_one_or_none()
        # Empty/None column falls through to the final warning + [] return.
        if video and video.classification_data:
            data = video.classification_data
            logger.info(
                "[CLASSIFY-LOAD] Source: postgresql, video_id=%s, %d entries",
                video_id, len(data),
            )
            return data
    except Exception as exc:
        logger.warning(
            "[CLASSIFY-LOAD] PostgreSQL fallback failed for video_id=%s: %s", video_id, exc,
        )
    finally:
        session.close()

    logger.warning(
        "[CLASSIFY-LOAD] No classification data found in Redis or PostgreSQL for video_id=%s",
        video_id,
    )
    return []
|
||
|
||
|
||
|
||
def _get_git_commit_sha() -> str:
    """Resolve the git commit SHA used to build this image.

    Resolution order:
    1. /app/.git-commit file (written during Docker build)
    2. git rev-parse --short HEAD (local dev)
    3. GIT_COMMIT_SHA env var / config setting
    4. "unknown"
    """
    # 1. Docker build artifact
    build_file = Path("/app/.git-commit")
    if build_file.exists():
        from_file = build_file.read_text(encoding="utf-8").strip()
        if from_file and from_file != "unknown":
            return from_file

    # 2. Ask git directly (local dev checkout)
    try:
        proc = subprocess.run(
            ["git", "rev-parse", "--short", "HEAD"],
            capture_output=True, text=True, timeout=5,
        )
    except (FileNotFoundError, subprocess.TimeoutExpired):
        proc = None
    if proc is not None and proc.returncode == 0:
        from_git = proc.stdout.strip()
        if from_git:
            return from_git

    # 3. Config / env var fallback
    try:
        configured = get_settings().git_commit_sha
    except Exception:
        configured = None
    if configured and configured != "unknown":
        return configured

    # 4. Nothing resolvable
    return "unknown"
|
||
|
||
def _capture_pipeline_metadata() -> dict:
    """Capture current pipeline configuration for version metadata.

    Returns a dict with model names, prompt file SHA-256 hashes, and stage
    modality settings. Handles missing prompt files gracefully — an
    unreadable prompt yields an empty-string hash rather than an error.
    """
    settings = get_settings()
    prompts_dir = Path(settings.prompts_path)

    # SHA-256 each stage prompt template; empty string when unreadable
    stage_prompt_files = (
        "stage2_segmentation.txt",
        "stage3_extraction.txt",
        "stage4_classification.txt",
        "stage5_synthesis.txt",
    )
    prompt_hashes: dict[str, str] = {}
    for filename in stage_prompt_files:
        filepath = prompts_dir / filename
        digest = ""
        try:
            digest = hashlib.sha256(filepath.read_bytes()).hexdigest()
        except FileNotFoundError:
            logger.warning("Prompt file not found for metadata capture: %s", filepath)
        except OSError as exc:
            logger.warning("Could not read prompt file %s: %s", filepath, exc)
        prompt_hashes[filename] = digest

    return {
        "git_commit_sha": _get_git_commit_sha(),
        "models": {
            "stage2": settings.llm_stage2_model,
            "stage3": settings.llm_stage3_model,
            "stage4": settings.llm_stage4_model,
            "stage5": settings.llm_stage5_model,
            "embedding": settings.embedding_model,
        },
        "modalities": {
            "stage2": settings.llm_stage2_modality,
            "stage3": settings.llm_stage3_modality,
            "stage4": settings.llm_stage4_modality,
            "stage5": settings.llm_stage5_modality,
        },
        "prompt_hashes": prompt_hashes,
    }
|
||
|
||
|
||
# ── Stage 5: Synthesis ───────────────────────────────────────────────────────
|
||
|
||
|
||
def _serialize_body_sections(sections) -> list | dict | None:
|
||
"""Convert body_sections to JSON-serializable form for DB storage."""
|
||
if isinstance(sections, list):
|
||
return [s.model_dump() if hasattr(s, 'model_dump') else s for s in sections]
|
||
return sections
|
||
|
||
|
||
def _compute_page_tags(
|
||
moment_indices: list[int],
|
||
moment_group: list[tuple],
|
||
all_tags: set[str],
|
||
) -> list[str] | None:
|
||
"""Compute tags for a specific page from its linked moment indices.
|
||
|
||
If moment_indices are available, collects tags only from those moments.
|
||
Falls back to all_tags for the category group if no indices provided.
|
||
"""
|
||
if not moment_indices:
|
||
return list(all_tags) if all_tags else None
|
||
|
||
page_tags: set[str] = set()
|
||
for idx in moment_indices:
|
||
if 0 <= idx < len(moment_group):
|
||
_, cls_info = moment_group[idx]
|
||
page_tags.update(cls_info.get("topic_tags", []))
|
||
|
||
return list(page_tags) if page_tags else None
|
||
|
||
|
||
def _build_moments_text(
|
||
moment_group: list[tuple[KeyMoment, dict]],
|
||
category: str,
|
||
) -> tuple[str, set[str]]:
|
||
"""Build the moments prompt text and collect all tags for a group of moments.
|
||
|
||
Returns (moments_text, all_tags).
|
||
"""
|
||
moments_lines = []
|
||
all_tags: set[str] = set()
|
||
for i, (m, cls_info) in enumerate(moment_group):
|
||
tags = cls_info.get("topic_tags", [])
|
||
all_tags.update(tags)
|
||
moments_lines.append(
|
||
f"[{i}] Title: {m.title}\n"
|
||
f" Summary: {m.summary}\n"
|
||
f" Content type: {m.content_type.value}\n"
|
||
f" Time: {m.start_time:.1f}s - {m.end_time:.1f}s\n"
|
||
f" Plugins: {', '.join(m.plugins) if m.plugins else 'none'}\n"
|
||
f" Category: {category}\n"
|
||
f" Tags: {', '.join(tags) if tags else 'none'}\n"
|
||
f" Transcript excerpt: {(m.raw_transcript or '')[:300]}"
|
||
)
|
||
return "\n\n".join(moments_lines), all_tags
|
||
|
||
|
||
def _build_compose_user_prompt(
    existing_page: TechniquePage,
    existing_moments: list[KeyMoment],
    new_moments: list[tuple[KeyMoment, dict]],
    creator_name: str,
) -> str:
    """Build the user prompt for composing new moments into an existing page.

    Existing moments keep indices [0]-[N-1]; new moments continue at
    [N]-[N+M-1]. The XML-tagged prompt structure matches
    test_harness.py build_compose_prompt().
    """
    category = existing_page.topic_category or "Uncategorized"

    # Serialize the existing page into the SynthesizedPage-shaped dict
    quality = existing_page.source_quality
    page_dict = {
        "title": existing_page.title,
        "slug": existing_page.slug,
        "topic_category": existing_page.topic_category,
        "summary": existing_page.summary,
        "body_sections": existing_page.body_sections,
        "signal_chains": existing_page.signal_chains,
        "plugins": existing_page.plugins,
        "source_quality": quality.value if hasattr(quality, "value") else quality,
    }

    # Existing moments [0]-[N-1]; they carry no classification data, so
    # pair each with an empty dict to satisfy _build_moments_text.
    existing_text, _ = _build_moments_text([(m, {}) for m in existing_moments], category)

    # New moments [N]-[N+M-1], indices offset past the existing block
    offset = len(existing_moments)
    rendered_new: list[str] = []
    for pos, (m, cls_info) in enumerate(new_moments):
        tags = cls_info.get("topic_tags", [])
        rendered_new.append(
            f"[{offset + pos}] Title: {m.title}\n"
            f"    Summary: {m.summary}\n"
            f"    Content type: {m.content_type.value}\n"
            f"    Time: {m.start_time:.1f}s - {m.end_time:.1f}s\n"
            f"    Plugins: {', '.join(m.plugins) if m.plugins else 'none'}\n"
            f"    Category: {category}\n"
            f"    Tags: {', '.join(tags) if tags else 'none'}\n"
            f"    Transcript excerpt: {(m.raw_transcript or '')[:300]}"
        )
    new_text = "\n\n".join(rendered_new)

    page_json = json.dumps(page_dict, indent=2, ensure_ascii=False, default=str)

    return (
        f"<existing_page>\n{page_json}\n</existing_page>\n"
        f"<existing_moments>\n{existing_text}\n</existing_moments>\n"
        f"<new_moments>\n{new_text}\n</new_moments>\n"
        f"<creator>{creator_name}</creator>"
    )
|
||
|
||
|
||
def _compose_into_existing(
    existing_page: TechniquePage,
    existing_moments: list[KeyMoment],
    new_moment_group: list[tuple[KeyMoment, dict]],
    category: str,
    creator_name: str,
    system_prompt: str,
    llm: LLMClient,
    model_override: str | None,
    modality: str,
    hard_limit: int,
    video_id: str,
    run_id: str | None,
) -> SynthesisResult:
    """Compose new moments into an existing technique page via LLM.

    Loads the compose system prompt, builds the compose user prompt, and
    calls the LLM with the same retry/parse pattern as _synthesize_chunk().

    NOTE(review): the ``system_prompt`` parameter is accepted but never
    used — the dedicated compose prompt (stage5_compose.txt) replaces it
    for this call. Confirm whether the parameter can be dropped.
    """
    # Compose flow uses its own system prompt, loaded fresh per call
    compose_prompt = _load_prompt("stage5_compose.txt", video_id=video_id)
    user_prompt = _build_compose_user_prompt(
        existing_page, existing_moments, new_moment_group, creator_name,
    )

    # Output token budget sized from the prompts, capped at hard_limit
    estimated_input = estimate_max_tokens(
        compose_prompt, user_prompt,
        stage="stage5_synthesis", hard_limit=hard_limit,
    )
    logger.info(
        "Stage 5: Composing into '%s' — %d existing + %d new moments, max_tokens=%d",
        existing_page.slug, len(existing_moments), len(new_moment_group), estimated_input,
    )

    raw = llm.complete(
        compose_prompt, user_prompt, response_model=SynthesisResult,
        # Telemetry callback logs the exchange under stage5_synthesis,
        # tagged with this category so compose calls are distinguishable.
        on_complete=_make_llm_callback(
            video_id, "stage5_synthesis",
            system_prompt=compose_prompt, user_prompt=user_prompt,
            run_id=run_id, context_label=f"compose:{category}",
            request_params=_build_request_params(
                estimated_input, model_override, modality, "SynthesisResult", hard_limit,
            ),
        ),
        modality=modality, model_override=model_override, max_tokens=estimated_input,
    )
    # Parse/validate the response; the llm client and prompts are passed
    # through, presumably so the helper can retry on parse failure — see
    # _safe_parse_llm_response.
    return _safe_parse_llm_response(
        raw, SynthesisResult, llm, compose_prompt, user_prompt,
        modality=modality, model_override=model_override, max_tokens=estimated_input,
    )
|
||
|
||
|
||
def _synthesize_chunk(
    chunk: list[tuple[KeyMoment, dict]],
    category: str,
    creator_name: str,
    system_prompt: str,
    llm: LLMClient,
    model_override: str | None,
    modality: str,
    hard_limit: int,
    video_id: str,
    run_id: str | None,
    chunk_label: str,
) -> SynthesisResult:
    """Synthesize one chunk of moments with a single LLM call.

    Builds the user prompt from the chunk, sizes the output token budget,
    invokes the LLM with a telemetry callback, and returns the parsed
    SynthesisResult.
    """
    rendered, _ = _build_moments_text(chunk, category)
    user_prompt = f"<creator>{creator_name}</creator>\n<moments>\n{rendered}\n</moments>"

    # Output token budget sized from prompt length, capped at hard_limit
    token_budget = estimate_max_tokens(system_prompt, user_prompt, stage="stage5_synthesis", hard_limit=hard_limit)
    logger.info(
        "Stage 5: Synthesizing %s — %d moments, max_tokens=%d",
        chunk_label, len(chunk), token_budget,
    )

    # Telemetry callback labels this exchange with the chunk identifier
    callback = _make_llm_callback(
        video_id, "stage5_synthesis",
        system_prompt=system_prompt, user_prompt=user_prompt,
        run_id=run_id, context_label=chunk_label,
        request_params=_build_request_params(token_budget, model_override, modality, "SynthesisResult", hard_limit),
    )
    raw = llm.complete(
        system_prompt, user_prompt, response_model=SynthesisResult,
        on_complete=callback,
        modality=modality, model_override=model_override, max_tokens=token_budget,
    )
    return _safe_parse_llm_response(
        raw, SynthesisResult, llm, system_prompt, user_prompt,
        modality=modality, model_override=model_override, max_tokens=token_budget,
    )
|
||
|
||
|
||
def _slug_base(slug: str) -> str:
|
||
"""Extract the slug prefix before the creator name suffix for merge grouping.
|
||
|
||
E.g. 'wavetable-sound-design-copycatt' → 'wavetable-sound-design'
|
||
Also normalizes casing.
|
||
"""
|
||
return slug.lower().strip()
|
||
|
||
|
||
def _merge_pages_by_slug(
    all_pages: list,
    creator_name: str,
    llm: LLMClient,
    model_override: str | None,
    modality: str,
    hard_limit: int,
    video_id: str,
    run_id: str | None,
) -> list:
    """Detect pages with the same slug across chunks and merge them via LLM.

    Pages with unique slugs pass through unchanged. Pages sharing a slug
    get sent to a merge prompt that combines them into one cohesive page.
    If a merge call yields no pages, the unmerged partials are kept so no
    content is lost.

    Returns the final list of SynthesizedPage objects.
    """
    # (Removed an unused function-scope import of SynthesizedPage — the
    # function only references SynthesisResult, which is module-level.)

    # Group pages by normalized slug (lowercased/stripped via _slug_base)
    by_slug: dict[str, list] = defaultdict(list)
    for page in all_pages:
        by_slug[_slug_base(page.slug)].append(page)

    final_pages = []
    for slug, pages_group in by_slug.items():
        if len(pages_group) == 1:
            # Unique slug — no merge needed
            final_pages.append(pages_group[0])
            continue

        # Multiple pages share this slug — merge via LLM
        logger.info(
            "Stage 5: Merging %d partial pages with slug '%s' for video_id=%s",
            len(pages_group), slug, video_id,
        )

        # Serialize partial pages to JSON for the merge prompt
        pages_json = json.dumps(
            [p.model_dump() for p in pages_group],
            indent=2, ensure_ascii=False,
        )

        merge_system_prompt = _load_prompt("stage5_merge.txt", video_id=video_id)
        merge_user_prompt = f"<creator>{creator_name}</creator>\n<pages>\n{pages_json}\n</pages>"

        # Output token budget sized from the prompts, capped at hard_limit
        max_tokens = estimate_max_tokens(
            merge_system_prompt, merge_user_prompt,
            stage="stage5_synthesis", hard_limit=hard_limit,
        )
        logger.info(
            "Stage 5: Merge call for slug '%s' — %d partial pages, max_tokens=%d",
            slug, len(pages_group), max_tokens,
        )

        raw = llm.complete(
            merge_system_prompt, merge_user_prompt,
            response_model=SynthesisResult,
            # Telemetry callback tags this exchange as a merge for the slug
            on_complete=_make_llm_callback(
                video_id, "stage5_synthesis",
                system_prompt=merge_system_prompt,
                user_prompt=merge_user_prompt,
                run_id=run_id, context_label=f"merge:{slug}",
                request_params=_build_request_params(max_tokens, model_override, modality, "SynthesisResult", hard_limit),
            ),
            modality=modality, model_override=model_override,
            max_tokens=max_tokens,
        )
        merge_result = _safe_parse_llm_response(
            raw, SynthesisResult, llm,
            merge_system_prompt, merge_user_prompt,
            modality=modality, model_override=model_override,
            max_tokens=max_tokens,
        )

        if merge_result.pages:
            final_pages.extend(merge_result.pages)
            logger.info(
                "Stage 5: Merge produced %d page(s) for slug '%s'",
                len(merge_result.pages), slug,
            )
        else:
            # Merge returned nothing — fall back to keeping the partials
            logger.warning(
                "Stage 5: Merge returned 0 pages for slug '%s', keeping %d partials",
                slug, len(pages_group),
            )
            final_pages.extend(pages_group)

    return final_pages
|
||
|
||
|
||
@celery_app.task(bind=True, max_retries=3, default_retry_delay=30)
def stage5_synthesis(self, video_id: str, run_id: str | None = None) -> str:
    """Synthesize technique pages from classified key moments.

    Groups moments by (creator, topic_category), calls the LLM to synthesize
    each group into a TechniquePage, creates/updates page rows, and links
    KeyMoments to their TechniquePage.

    For large category groups (exceeding synthesis_chunk_size), moments are
    split into chronological chunks, synthesized independently, then pages
    with matching slugs are merged via a dedicated merge LLM call.

    Sets processing_status to 'complete'.

    Returns the video_id for chain compatibility.
    """
    start = time.monotonic()
    logger.info("Stage 5 (synthesis) starting for video_id=%s", video_id)
    _emit_event(video_id, "stage5_synthesis", "start", run_id=run_id)

    settings = get_settings()
    chunk_size = settings.synthesis_chunk_size
    session = _get_sync_session()
    try:
        # Load video and moments
        video = session.execute(
            select(SourceVideo).where(SourceVideo.id == video_id)
        ).scalar_one()

        # Moments ordered chronologically — chunking below relies on this
        moments = (
            session.execute(
                select(KeyMoment)
                .where(KeyMoment.source_video_id == video_id)
                .order_by(KeyMoment.start_time)
            )
            .scalars()
            .all()
        )

        # Resolve creator name for the LLM prompt
        creator = session.execute(
            select(Creator).where(Creator.id == video.creator_id)
        ).scalar_one_or_none()
        creator_name = creator.name if creator else "Unknown"

        if not moments:
            # NOTE(review): this early return skips the "complete" event
            # and leaves processing_status untouched — confirm intended.
            logger.info("Stage 5: No moments found for video_id=%s, skipping.", video_id)
            return video_id

        # Load classification data from stage 4
        classification_data = _load_classification_data(video_id)
        cls_by_moment_id = {c["moment_id"]: c for c in classification_data}

        # Group moments by topic_category (from classification)
        # Normalize category casing to prevent near-duplicate groups
        # (e.g., "Sound design" vs "Sound Design")
        groups: dict[str, list[tuple[KeyMoment, dict]]] = defaultdict(list)
        for moment in moments:
            cls_info = cls_by_moment_id.get(str(moment.id), {})
            category = cls_info.get("topic_category", "Uncategorized").strip().title()
            groups[category].append((moment, cls_info))

        system_prompt = _load_prompt("stage5_synthesis.txt", video_id=video_id)
        llm = _get_llm_client()
        model_override, modality = _get_stage_config(5)
        hard_limit = settings.llm_max_tokens_hard_limit
        logger.info("Stage 5 using model=%s, modality=%s", model_override or "default", modality)
        pages_created = 0

        for category, moment_group in groups.items():
            # Collect all tags across the full group (used for DB writes later)
            all_tags: set[str] = set()
            for _, cls_info in moment_group:
                all_tags.update(cls_info.get("topic_tags", []))

            # ── Compose-or-create detection ────────────────────────
            # Check if an existing technique page already covers this
            # creator + category combination (from a prior video run).
            compose_matches = session.execute(
                select(TechniquePage).where(
                    TechniquePage.creator_id == video.creator_id,
                    func.lower(TechniquePage.topic_category) == func.lower(category),
                )
            ).scalars().all()

            if len(compose_matches) > 1:
                logger.warning(
                    "Stage 5: Multiple existing pages (%d) match creator=%s category='%s'. "
                    "Using first match '%s'.",
                    len(compose_matches), video.creator_id, category,
                    compose_matches[0].slug,
                )

            compose_target = compose_matches[0] if compose_matches else None

            if compose_target is not None:
                # Load existing moments linked to this page
                existing_moments = session.execute(
                    select(KeyMoment)
                    .where(KeyMoment.technique_page_id == compose_target.id)
                    .order_by(KeyMoment.start_time)
                ).scalars().all()

                logger.info(
                    "Stage 5: Composing into existing page '%s' "
                    "(%d existing moments + %d new moments)",
                    compose_target.slug,
                    len(existing_moments),
                    len(moment_group),
                )

                compose_result = _compose_into_existing(
                    compose_target, existing_moments, moment_group,
                    category, creator_name, system_prompt,
                    llm, model_override, modality, hard_limit,
                    video_id, run_id,
                )
                synthesized_pages = list(compose_result.pages)

            # ── Chunked synthesis with truncation recovery ─────────
            elif len(moment_group) <= chunk_size:
                # Small group — try single LLM call first
                try:
                    result = _synthesize_chunk(
                        moment_group, category, creator_name,
                        system_prompt, llm, model_override, modality, hard_limit,
                        video_id, run_id, f"category:{category}",
                    )
                    synthesized_pages = list(result.pages)
                    logger.info(
                        "Stage 5: category '%s' — %d moments, %d page(s) from single call",
                        category, len(moment_group), len(synthesized_pages),
                    )
                except LLMTruncationError:
                    # Output too large for model context — split in half and retry
                    logger.warning(
                        "Stage 5: category '%s' truncated with %d moments. "
                        "Splitting into smaller chunks and retrying.",
                        category, len(moment_group),
                    )
                    half = max(1, len(moment_group) // 2)
                    chunk_pages = []
                    for sub_start in range(0, len(moment_group), half):
                        sub_chunk = moment_group[sub_start:sub_start + half]
                        sub_label = f"category:{category} recovery-chunk:{sub_start // half + 1}"
                        sub_result = _synthesize_chunk(
                            sub_chunk, category, creator_name,
                            system_prompt, llm, model_override, modality, hard_limit,
                            video_id, run_id, sub_label,
                        )
                        # Reindex moment_indices to global offsets
                        # (each sub-call sees its chunk as indices [0..len))
                        for p in sub_result.pages:
                            if p.moment_indices:
                                p.moment_indices = [idx + sub_start for idx in p.moment_indices]
                        chunk_pages.extend(sub_result.pages)
                    synthesized_pages = chunk_pages
                    logger.info(
                        "Stage 5: category '%s' — %d page(s) from recovery split",
                        category, len(synthesized_pages),
                    )
            else:
                # Large group — split into chunks, synthesize each, then merge
                num_chunks = (len(moment_group) + chunk_size - 1) // chunk_size
                logger.info(
                    "Stage 5: category '%s' has %d moments — splitting into %d chunks of ≤%d",
                    category, len(moment_group), num_chunks, chunk_size,
                )

                chunk_pages = []
                for chunk_idx in range(num_chunks):
                    chunk_start = chunk_idx * chunk_size
                    chunk_end = min(chunk_start + chunk_size, len(moment_group))
                    chunk = moment_group[chunk_start:chunk_end]
                    chunk_label = f"category:{category} chunk:{chunk_idx + 1}/{num_chunks}"

                    result = _synthesize_chunk(
                        chunk, category, creator_name,
                        system_prompt, llm, model_override, modality, hard_limit,
                        video_id, run_id, chunk_label,
                    )
                    # NOTE(review): unlike the recovery path above, chunk
                    # pages here keep chunk-local moment_indices — the
                    # merge step discards/combines them; confirm linkage
                    # behavior for multi-chunk categories.
                    chunk_pages.extend(result.pages)
                    logger.info(
                        "Stage 5: %s produced %d page(s)",
                        chunk_label, len(result.pages),
                    )

                # Merge pages with matching slugs across chunks
                logger.info(
                    "Stage 5: category '%s' — %d total pages from %d chunks, checking for merges",
                    category, len(chunk_pages), num_chunks,
                )
                synthesized_pages = _merge_pages_by_slug(
                    chunk_pages, creator_name,
                    llm, model_override, modality, hard_limit,
                    video_id, run_id,
                )
                logger.info(
                    "Stage 5: category '%s' — %d final page(s) after merge",
                    category, len(synthesized_pages),
                )

            # ── Persist pages to DB ──────────────────────────────────
            # Load prior pages from this video (snapshot taken before pipeline reset)
            prior_page_ids = _load_prior_pages(video_id)

            for page_data in synthesized_pages:
                page_moment_indices = getattr(page_data, "moment_indices", None) or []
                existing = None

                # First: check by slug (most specific match)
                if existing is None:
                    existing = session.execute(
                        select(TechniquePage).where(TechniquePage.slug == page_data.slug)
                    ).scalar_one_or_none()

                # Fallback: check prior pages from this video by creator + category
                # Use .first() since multiple pages may share a category
                if existing is None and prior_page_ids:
                    existing = session.execute(
                        select(TechniquePage).where(
                            TechniquePage.id.in_(prior_page_ids),
                            TechniquePage.creator_id == video.creator_id,
                            func.lower(TechniquePage.topic_category) == func.lower(page_data.topic_category or category),
                        )
                    ).scalars().first()
                    if existing:
                        logger.info(
                            "Stage 5: Matched prior page '%s' (id=%s) by creator+category for video_id=%s",
                            existing.slug, existing.id, video_id,
                        )

                if existing:
                    # Snapshot existing content before overwriting
                    try:
                        sq = existing.source_quality
                        # source_quality may be an Enum or a plain value
                        sq_value = sq.value if hasattr(sq, 'value') else sq
                        snapshot = {
                            "title": existing.title,
                            "slug": existing.slug,
                            "topic_category": existing.topic_category,
                            "topic_tags": existing.topic_tags,
                            "summary": existing.summary,
                            "body_sections": existing.body_sections,
                            "signal_chains": existing.signal_chains,
                            "plugins": existing.plugins,
                            "source_quality": sq_value,
                        }
                        version_count = session.execute(
                            select(func.count()).where(
                                TechniquePageVersion.technique_page_id == existing.id
                            )
                        ).scalar()
                        version_number = version_count + 1

                        version = TechniquePageVersion(
                            technique_page_id=existing.id,
                            version_number=version_number,
                            content_snapshot=snapshot,
                            pipeline_metadata=_capture_pipeline_metadata(),
                        )
                        session.add(version)
                        logger.info(
                            "Version snapshot v%d created for page slug=%s",
                            version_number, existing.slug,
                        )
                    except Exception as snap_exc:
                        logger.error(
                            "Failed to create version snapshot for page slug=%s: %s",
                            existing.slug, snap_exc,
                        )
                        # Best-effort versioning — continue with page update

                    # Update existing page
                    existing.title = page_data.title
                    existing.summary = page_data.summary
                    existing.body_sections = _serialize_body_sections(page_data.body_sections)
                    existing.signal_chains = page_data.signal_chains
                    existing.plugins = page_data.plugins if page_data.plugins else None
                    page_tags = _compute_page_tags(page_moment_indices, moment_group, all_tags)
                    existing.topic_tags = page_tags
                    existing.source_quality = page_data.source_quality
                    page = existing
                else:
                    # No match anywhere — create a fresh page row
                    page = TechniquePage(
                        creator_id=video.creator_id,
                        title=page_data.title,
                        slug=page_data.slug,
                        topic_category=page_data.topic_category or category,
                        topic_tags=_compute_page_tags(page_moment_indices, moment_group, all_tags),
                        summary=page_data.summary,
                        body_sections=_serialize_body_sections(page_data.body_sections),
                        signal_chains=page_data.signal_chains,
                        plugins=page_data.plugins if page_data.plugins else None,
                        source_quality=page_data.source_quality,
                    )
                    session.add(page)
                    session.flush()  # Get the page.id assigned

                pages_created += 1

                # Set body_sections_format on every page (new or updated)
                page.body_sections_format = "v2"

                # Track contributing video via TechniquePageVideo
                # (idempotent — on_conflict_do_nothing tolerates reruns)
                stmt = pg_insert(TechniquePageVideo.__table__).values(
                    technique_page_id=page.id,
                    source_video_id=video.id,
                ).on_conflict_do_nothing()
                session.execute(stmt)

                # Link moments to the technique page using moment_indices

                if page_moment_indices:
                    # LLM specified which moments belong to this page
                    for idx in page_moment_indices:
                        if 0 <= idx < len(moment_group):
                            moment_group[idx][0].technique_page_id = page.id
                elif len(synthesized_pages) == 1:
                    # Single page — link all moments (safe fallback)
                    for m, _ in moment_group:
                        m.technique_page_id = page.id
                else:
                    # Multiple pages but no moment_indices — log warning
                    logger.warning(
                        "Stage 5: page '%s' has no moment_indices and is one of %d pages "
                        "for category '%s'. Moments will not be linked to this page.",
                        page_data.slug, len(synthesized_pages), category,
                    )

        # Update processing_status
        video.processing_status = ProcessingStatus.complete

        # Single commit covers all categories — any failure above rolls
        # back the whole stage (see except handler).
        session.commit()
        elapsed = time.monotonic() - start
        _emit_event(video_id, "stage5_synthesis", "complete", run_id=run_id)
        logger.info(
            "Stage 5 (synthesis) completed for video_id=%s in %.1fs — %d pages created/updated",
            video_id, elapsed, pages_created,
        )
        return video_id

    except FileNotFoundError:
        # Missing prompt template file — configuration error; propagate
        # without the Celery retry path below.
        raise
    except Exception as exc:
        session.rollback()
        _emit_event(video_id, "stage5_synthesis", "error", run_id=run_id, payload={"error": str(exc)})
        logger.error("Stage 5 failed for video_id=%s: %s", video_id, exc)
        raise self.retry(exc=exc)
    finally:
        session.close()
|
||
|
||
|
||
# ── Heading slug helper (matches frontend TableOfContents.tsx slugify) ────────
|
||
|
||
def _slugify_heading(text: str) -> str:
|
||
"""Convert a heading string to a URL-friendly anchor slug.
|
||
|
||
Must produce identical output to the frontend's slugify in
|
||
``frontend/src/components/TableOfContents.tsx``.
|
||
"""
|
||
return re.sub(r"[^a-z0-9]+", "-", text.lower()).strip("-")
|
||
|
||
|
||
# ── Stage 6: Embed & Index ───────────────────────────────────────────────────
|
||
|
||
@celery_app.task(bind=True, max_retries=0)
def stage6_embed_and_index(self, video_id: str, run_id: str | None = None) -> str:
    """Generate embeddings for technique pages and key moments, then upsert to Qdrant.

    This is a non-blocking side-effect stage — failures are logged but do not
    fail the pipeline. Embeddings can be regenerated later. Does NOT update
    processing_status.

    Returns the video_id for chain compatibility.
    """
    start = time.monotonic()
    logger.info("Stage 6 (embed & index) starting for video_id=%s", video_id)

    settings = get_settings()
    session = _get_sync_session()
    try:
        # Load technique pages created for this video's moments
        moments = (
            session.execute(
                select(KeyMoment)
                .where(KeyMoment.source_video_id == video_id)
                .order_by(KeyMoment.start_time)
            )
            .scalars()
            .all()
        )

        # Get unique technique page IDs from moments
        page_ids = {m.technique_page_id for m in moments if m.technique_page_id is not None}
        pages = []
        if page_ids:
            pages = (
                session.execute(
                    select(TechniquePage).where(TechniquePage.id.in_(page_ids))
                )
                .scalars()
                .all()
            )

        if not moments and not pages:
            # Nothing to index — still mark the run complete (best-effort stage)
            logger.info("Stage 6: No moments or pages for video_id=%s, skipping.", video_id)
            if run_id:
                _finish_run(run_id, "complete")
            return video_id

        # Resolve creator names for enriched embedding text
        creator_ids = {p.creator_id for p in pages}
        creator_map: dict[str, str] = {}
        if creator_ids:
            creators = (
                session.execute(
                    select(Creator).where(Creator.id.in_(creator_ids))
                )
                .scalars()
                .all()
            )
            creator_map = {str(c.id): c.name for c in creators}

        # Resolve creator name for key moments via source_video → creator
        video_ids = {m.source_video_id for m in moments}
        video_creator_map: dict[str, str] = {}
        if video_ids:
            rows = session.execute(
                select(SourceVideo.id, Creator.name, Creator.id.label("creator_id"))
                .join(Creator, SourceVideo.creator_id == Creator.id)
                .where(SourceVideo.id.in_(video_ids))
            ).all()
            # row = (video_id, creator_name, creator_id)
            video_creator_map = {str(r[0]): r[1] for r in rows}
            # video_creator_id_map is only defined here; safe because the
            # moments section below only runs when moments (and therefore
            # video_ids) are non-empty.
            video_creator_id_map = {str(r[0]): str(r[2]) for r in rows}

        embed_client = EmbeddingClient(settings)
        qdrant = QdrantManager(settings)

        # Ensure collection exists before upserting
        qdrant.ensure_collection()

        # ── Embed & upsert technique pages ───────────────────────────────
        if pages:
            page_texts = []
            page_dicts = []
            for p in pages:
                creator_name = creator_map.get(str(p.creator_id), "")
                tags_joined = " ".join(p.topic_tags) if p.topic_tags else ""
                # Embedding text = creator + title + category + tags + summary
                text = f"{creator_name} {p.title} {p.topic_category or ''} {tags_joined} {p.summary or ''}"
                page_texts.append(text.strip())
                page_dicts.append({
                    "page_id": str(p.id),
                    "creator_id": str(p.creator_id),
                    "creator_name": creator_name,
                    "title": p.title,
                    "slug": p.slug,
                    "topic_category": p.topic_category or "",
                    "topic_tags": p.topic_tags or [],
                    "summary": p.summary or "",
                })

            page_vectors = embed_client.embed(page_texts)
            if page_vectors:
                qdrant.upsert_technique_pages(page_dicts, page_vectors)
                logger.info(
                    "Stage 6: Upserted %d technique page vectors for video_id=%s",
                    len(page_vectors), video_id,
                )
            else:
                logger.warning(
                    "Stage 6: Embedding returned empty for %d technique pages (video_id=%s). "
                    "Skipping page upsert.",
                    len(page_texts), video_id,
                )

        # ── Embed & upsert key moments ───────────────────────────────────
        if moments:
            # Build page_id → slug mapping for linking moments to technique pages
            page_id_to_slug: dict[str, str] = {}
            if pages:
                for p in pages:
                    page_id_to_slug[str(p.id)] = p.slug

            moment_texts = []
            moment_dicts = []
            for m in moments:
                creator_name = video_creator_map.get(str(m.source_video_id), "")
                # Embedding text = creator + title + summary
                text = f"{creator_name} {m.title} {m.summary or ''}"
                moment_texts.append(text.strip())
                tp_id = str(m.technique_page_id) if m.technique_page_id else ""
                moment_dicts.append({
                    "moment_id": str(m.id),
                    "source_video_id": str(m.source_video_id),
                    "creator_id": video_creator_id_map.get(str(m.source_video_id), ""),
                    "technique_page_id": tp_id,
                    "technique_page_slug": page_id_to_slug.get(tp_id, ""),
                    "title": m.title,
                    "creator_name": creator_name,
                    "start_time": m.start_time,
                    "end_time": m.end_time,
                    "content_type": m.content_type.value,
                })

            moment_vectors = embed_client.embed(moment_texts)
            if moment_vectors:
                qdrant.upsert_key_moments(moment_dicts, moment_vectors)
                logger.info(
                    "Stage 6: Upserted %d key moment vectors for video_id=%s",
                    len(moment_vectors), video_id,
                )
            else:
                logger.warning(
                    "Stage 6: Embedding returned empty for %d key moments (video_id=%s). "
                    "Skipping moment upsert.",
                    len(moment_texts), video_id,
                )

        # ── Embed & upsert technique page sections (v2 only) ────────────
        section_count = 0
        v2_pages = [p for p in pages if getattr(p, "body_sections_format", "v1") == "v2"]
        for p in v2_pages:
            body_sections = p.body_sections
            if not isinstance(body_sections, list):
                continue

            creator_name = creator_map.get(str(p.creator_id), "")
            page_id_str = str(p.id)

            # Delete stale section points before re-upserting
            try:
                qdrant.delete_sections_by_page_id(page_id_str)
            except Exception as exc:
                # Best-effort cleanup — stale points may linger on failure
                logger.warning(
                    "Stage 6: Failed to delete stale sections for page_id=%s: %s",
                    page_id_str, exc,
                )

            section_texts: list[str] = []
            section_dicts: list[dict] = []

            for section in body_sections:
                if not isinstance(section, dict):
                    logger.warning(
                        "Stage 6: Malformed section (not a dict) in page_id=%s. Skipping.",
                        page_id_str,
                    )
                    continue
                heading = section.get("heading", "")
                if not heading or not heading.strip():
                    # Heading is required — it becomes the anchor
                    continue

                section_anchor = _slugify_heading(heading)
                section_content = section.get("content", "")
                # Include subsection content for richer embedding
                subsection_parts: list[str] = []
                for sub in section.get("subsections", []):
                    if isinstance(sub, dict):
                        sub_heading = sub.get("heading", "")
                        sub_content = sub.get("content", "")
                        if sub_heading:
                            subsection_parts.append(f"{sub_heading}: {sub_content}")
                        elif sub_content:
                            subsection_parts.append(sub_content)

                embed_text = (
                    f"{creator_name} {p.title} — {heading}: "
                    f"{section_content} {' '.join(subsection_parts)}"
                ).strip()
                section_texts.append(embed_text)

                section_dicts.append({
                    "page_id": page_id_str,
                    "creator_id": str(p.creator_id),
                    "creator_name": creator_name,
                    "title": p.title,
                    "slug": p.slug,
                    "section_heading": heading,
                    "section_anchor": section_anchor,
                    "topic_category": p.topic_category or "",
                    "topic_tags": p.topic_tags or [],
                    # Truncated section content doubles as the payload summary
                    "summary": (section_content or "")[:200],
                })

            if section_texts:
                try:
                    section_vectors = embed_client.embed(section_texts)
                    if section_vectors:
                        qdrant.upsert_technique_sections(section_dicts, section_vectors)
                        section_count += len(section_vectors)
                    else:
                        logger.warning(
                            "Stage 6: Embedding returned empty for %d sections of page_id=%s. Skipping.",
                            len(section_texts), page_id_str,
                        )
                except Exception as exc:
                    # Per-page isolation — one page's failure doesn't stop others
                    logger.warning(
                        "Stage 6: Section embedding failed for page_id=%s: %s. Skipping.",
                        page_id_str, exc,
                    )

        if section_count:
            logger.info(
                "Stage 6: Upserted %d technique section vectors for video_id=%s",
                section_count, video_id,
            )

        elapsed = time.monotonic() - start
        logger.info(
            "Stage 6 (embed & index) completed for video_id=%s in %.1fs — "
            "%d pages, %d moments processed",
            video_id, elapsed, len(pages), len(moments),
        )
        if run_id:
            _finish_run(run_id, "complete")
        return video_id

    except Exception as exc:
        # Non-blocking: log error but don't fail the pipeline
        logger.error(
            "Stage 6 failed for video_id=%s: %s. "
            "Pipeline continues — embeddings can be regenerated later.",
            video_id, exc,
        )
        if run_id:
            _finish_run(run_id, "complete")  # Run is still "complete" — stage6 is best-effort
        return video_id
    finally:
        session.close()
|
||
|
||
|
||
|
||
|
||
def _snapshot_prior_pages(video_id: str) -> None:
    """Persist this video's currently-linked technique page IDs to Redis.

    Reprocessing deletes and recreates key_moments in stage 3, which severs
    the key_moment -> technique_page link. Stage 5 reads this snapshot so it
    can update the prior pages instead of creating duplicates. The Redis
    entry expires after 24 hours.
    """
    import redis

    session = _get_sync_session()
    try:
        # Distinct technique pages reachable through this video's key moments.
        linked = session.execute(
            select(KeyMoment.technique_page_id)
            .where(
                KeyMoment.source_video_id == video_id,
                KeyMoment.technique_page_id.isnot(None),
            )
            .distinct()
        ).scalars().all()

        page_ids = [str(page_id) for page_id in linked]
        if not page_ids:
            logger.info("No prior technique pages for video_id=%s", video_id)
            return

        conn = redis.Redis.from_url(get_settings().redis_url)
        conn.set(
            f"chrysopedia:prior_pages:{video_id}",
            json.dumps(page_ids),
            ex=86400,
        )
        logger.info(
            "Snapshot %d prior technique pages for video_id=%s: %s",
            len(page_ids), video_id, page_ids,
        )
    finally:
        session.close()
|
||
|
||
|
||
def _load_prior_pages(video_id: str) -> list[str]:
    """Return the technique page IDs snapshotted for this video, or []."""
    import redis

    conn = redis.Redis.from_url(get_settings().redis_url)
    raw = conn.get(f"chrysopedia:prior_pages:{video_id}")
    # Missing key means no snapshot was taken (or it expired).
    return [] if raw is None else json.loads(raw)
|
||
|
||
|
||
# ── Stage completion detection for auto-resume ───────────────────────────────
|
||
|
||
# Ordered list of pipeline stages for resumability logic.
# Order matters: _get_last_completed_stage walks this list front-to-back and
# run_pipeline slices it to resume after the last completed stage.
_PIPELINE_STAGES = [
    "stage2_segmentation",
    "stage3_extraction",
    "stage4_classification",
    "stage5_synthesis",
    "stage6_embed_and_index",
]

# Maps each stage name to its Celery task callable. run_pipeline and
# run_single_stage call these directly (synchronously) rather than via
# .delay(), so each video finishes fully before the next is picked up.
_STAGE_TASKS = {
    "stage2_segmentation": stage2_segmentation,
    "stage3_extraction": stage3_extraction,
    "stage4_classification": stage4_classification,
    "stage5_synthesis": stage5_synthesis,
    "stage6_embed_and_index": stage6_embed_and_index,
}
|
||
|
||
|
||
def _get_last_completed_stage(video_id: str) -> str | None:
    """Return the name of the last sequentially-completed stage for a video.

    Inspects the most recent PipelineRun and its 'complete' events, then
    walks the ordered stage list; the walk stops at the first stage without
    a completion event, so completions after a gap are ignored (stages must
    be sequential). Returns None when no run exists or nothing completed.
    """
    session = _get_sync_session()
    try:
        from models import PipelineRun

        # Most recent run for this video, by start time.
        newest_run = session.execute(
            select(PipelineRun)
            .where(PipelineRun.video_id == video_id)
            .order_by(PipelineRun.started_at.desc())
            .limit(1)
        ).scalar_one_or_none()

        if newest_run is None:
            return None

        # All stages of that run that emitted a 'complete' event.
        done = set(
            session.execute(
                select(PipelineEvent.stage)
                .where(
                    PipelineEvent.run_id == str(newest_run.id),
                    PipelineEvent.event_type == "complete",
                )
            ).scalars().all()
        )

        last_completed = None
        for candidate in _PIPELINE_STAGES:
            if candidate not in done:
                break  # First gap ends the contiguous prefix.
            last_completed = candidate

        if last_completed:
            logger.info(
                "Auto-resume: video_id=%s last completed stage=%s (run_id=%s)",
                video_id, last_completed, newest_run.id,
            )
        return last_completed
    finally:
        session.close()
|
||
|
||
|
||
# ── Orchestrator ─────────────────────────────────────────────────────────────
|
||
|
||
@celery_app.task
def mark_pipeline_error(request, exc, traceback, video_id: str, run_id: str | None = None) -> None:
    """Celery error callback: flag the video (and run, if any) as errored."""
    logger.error("Pipeline failed for video_id=%s: %s", video_id, exc)
    _set_error_status(video_id, "pipeline", exc)
    if not run_id:
        return
    _finish_run(run_id, "error", error_stage="pipeline")
|
||
|
||
|
||
def _create_run(video_id: str, trigger: str) -> str:
    """Create a PipelineRun for this video and return its id as a string.

    Args:
        video_id: SourceVideo primary key the run belongs to.
        trigger: Trigger label, coerced into PipelineRunTrigger (raises
            ValueError if not a valid enum value).

    Returns:
        The new run's id, stringified while the session is still open so
        the ORM instance need not stay attached.
    """
    from models import PipelineRun, PipelineRunTrigger

    session = _get_sync_session()
    try:
        # run_number is per-video and monotonic: max existing + 1.
        # `func` is already imported at module level — no local re-import needed.
        max_num = session.execute(
            select(func.coalesce(func.max(PipelineRun.run_number), 0))
            .where(PipelineRun.video_id == video_id)
        ).scalar() or 0
        run = PipelineRun(
            video_id=video_id,
            run_number=max_num + 1,
            trigger=PipelineRunTrigger(trigger),
        )
        session.add(run)
        session.commit()
        # Capture the id before the session closes and the instance expires.
        return str(run.id)
    finally:
        session.close()
|
||
|
||
|
||
def _finish_run(run_id: str, status: str, error_stage: str | None = None) -> None:
    """Finalize a PipelineRun: set status, finished_at, and total token count.

    Failures are logged and swallowed — finishing a run record must never
    take down the calling stage.
    """
    from models import PipelineRun, PipelineRunStatus, _now

    session = _get_sync_session()
    try:
        run = session.execute(
            select(PipelineRun).where(PipelineRun.id == run_id)
        ).scalar_one_or_none()
        if run is None:
            return

        run.status = PipelineRunStatus(status)
        run.finished_at = _now()
        if error_stage:
            run.error_stage = error_stage

        # Roll up token usage from every event emitted during this run.
        token_total = session.execute(
            select(func.coalesce(func.sum(PipelineEvent.total_tokens), 0))
            .where(PipelineEvent.run_id == run_id)
        ).scalar() or 0
        run.total_tokens = token_total

        session.commit()
    except Exception as exc:
        logger.warning("Failed to finish run %s: %s", run_id, exc)
    finally:
        session.close()
|
||
|
||
|
||
@celery_app.task
def run_pipeline(video_id: str, trigger: str = "manual") -> str:
    """Orchestrate the full pipeline (stages 2-6) with auto-resume.

    For error/processing status, queries pipeline_events to find the last
    stage that completed successfully and resumes from the next stage.
    This avoids re-running expensive LLM stages that already succeeded.

    For clean_reprocess trigger, always starts from stage 2.

    Args:
        video_id: Primary key of the SourceVideo to process.
        trigger: Why the run started (e.g. "manual", "clean_reprocess");
            recorded on the PipelineRun via _create_run.

    Returns the video_id.

    Raises:
        ValueError: If no SourceVideo with ``video_id`` exists.
        Exception: Re-raises whatever a stage task raised, after marking the
            video and run as errored.
    """
    logger.info("run_pipeline starting for video_id=%s", video_id)

    # Read the current status in a short-lived session; close it before any
    # long-running stage work so we don't hold a connection across stages.
    session = _get_sync_session()
    try:
        video = session.execute(
            select(SourceVideo).where(SourceVideo.id == video_id)
        ).scalar_one_or_none()

        if video is None:
            logger.error("run_pipeline: video_id=%s not found", video_id)
            raise ValueError(f"Video not found: {video_id}")

        status = video.processing_status
        logger.info(
            "run_pipeline: video_id=%s current status=%s", video_id, status.value
        )
    finally:
        session.close()

    # Already-complete videos are a no-op here (run_single_stage exists to
    # bypass this guard for targeted re-runs).
    if status == ProcessingStatus.complete:
        logger.info(
            "run_pipeline: video_id=%s already at status=%s, nothing to do.",
            video_id, status.value,
        )
        return video_id

    # Snapshot prior technique pages before pipeline resets key_moments
    _snapshot_prior_pages(video_id)

    # Create a pipeline run record
    run_id = _create_run(video_id, trigger)
    logger.info("run_pipeline: created run_id=%s for video_id=%s (trigger=%s)", run_id, video_id, trigger)

    # Determine which stages to run
    resume_from_idx = 0  # Default: start from stage 2

    if trigger != "clean_reprocess" and status in (ProcessingStatus.processing, ProcessingStatus.error):
        # Try to resume from where we left off
        last_completed = _get_last_completed_stage(video_id)
        if last_completed and last_completed in _PIPELINE_STAGES:
            completed_idx = _PIPELINE_STAGES.index(last_completed)
            resume_from_idx = completed_idx + 1
            if resume_from_idx >= len(_PIPELINE_STAGES):
                # Every stage already has a completion event — nothing left.
                logger.info(
                    "run_pipeline: all stages already completed for video_id=%s",
                    video_id,
                )
                return video_id

    stages_to_run = _PIPELINE_STAGES[resume_from_idx:]
    logger.info(
        "run_pipeline: video_id=%s will run stages: %s (resume_from_idx=%d)",
        video_id, stages_to_run, resume_from_idx,
    )

    # Run stages inline (synchronously) so each video completes fully
    # before the worker picks up the next queued video.
    # This replaces the previous celery_chain dispatch which caused
    # interleaved execution when multiple videos were queued.
    if stages_to_run:
        # Mark as processing before starting
        session = _get_sync_session()
        try:
            video = session.execute(
                select(SourceVideo).where(SourceVideo.id == video_id)
            ).scalar_one()
            video.processing_status = ProcessingStatus.processing
            session.commit()
        finally:
            session.close()

        logger.info(
            "run_pipeline: executing %d stages inline for video_id=%s (run_id=%s, starting at %s)",
            len(stages_to_run), video_id, run_id, stages_to_run[0],
        )

        try:
            for stage_name in stages_to_run:
                task_func = _STAGE_TASKS[stage_name]
                # Call the task directly — runs synchronously in this worker
                # process. bind=True tasks receive the task instance as self
                # automatically when called this way.
                task_func(video_id, run_id=run_id)
        except Exception as exc:
            # stage_name is still bound to the stage whose call raised —
            # Python keeps the loop variable after the body raises.
            logger.error(
                "run_pipeline: stage %s failed for video_id=%s: %s",
                stage_name, video_id, exc,
            )
            _set_error_status(video_id, stage_name, exc)
            if run_id:
                _finish_run(run_id, "error", error_stage=stage_name)
            raise

    return video_id
|
||
|
||
|
||
# ── Single-Stage Re-Run ─────────────────────────────────────────────────────
|
||
|
||
@celery_app.task
def run_single_stage(
    video_id: str,
    stage_name: str,
    trigger: str = "stage_rerun",
    prompt_override: str | None = None,
) -> str:
    """Re-run a single pipeline stage without running predecessors.

    Designed for fast prompt iteration — especially stage 5 synthesis.
    Bypasses the processing_status==complete guard, creates a proper
    PipelineRun record, and restores status on completion.

    If ``prompt_override`` is provided, it is stored in Redis as a
    per-video override that ``_load_prompt`` reads before falling back
    to the on-disk template. The override is cleaned up after the stage runs.

    Args:
        video_id: SourceVideo primary key.
        stage_name: One of _PIPELINE_STAGES.
        trigger: Trigger label stored on the PipelineRun.
        prompt_override: Optional prompt text used instead of the on-disk
            template for this one run (1-hour Redis TTL as a safety net).

    Returns the video_id.

    Raises:
        ValueError: On an unknown stage, missing video, or unmet
            stage prerequisites.
        Exception: Re-raises whatever the stage task raised, after marking
            the video and run as errored.
    """
    import redis as redis_lib

    logger.info(
        "[RERUN] Starting single-stage re-run: video_id=%s, stage=%s, trigger=%s",
        video_id, stage_name, trigger,
    )

    # Validate stage name
    if stage_name not in _PIPELINE_STAGES:
        raise ValueError(
            f"[RERUN] Invalid stage '{stage_name}'. "
            f"Valid stages: {_PIPELINE_STAGES}"
        )

    # Validate video exists
    session = _get_sync_session()
    try:
        video = session.execute(
            select(SourceVideo).where(SourceVideo.id == video_id)
        ).scalar_one_or_none()
        if video is None:
            raise ValueError(f"[RERUN] Video not found: {video_id}")
        # NOTE(review): original_status is captured but not read again in this
        # function — status is always restored to `complete` on success and
        # set to error on failure. Confirm whether restoring the original
        # status was intended.
        original_status = video.processing_status
    finally:
        session.close()

    # Validate prerequisites for the requested stage
    prereq_ok, prereq_msg = _check_stage_prerequisites(video_id, stage_name)
    if not prereq_ok:
        logger.error("[RERUN] Prerequisite check failed: %s", prereq_msg)
        raise ValueError(f"[RERUN] Prerequisites not met: {prereq_msg}")

    logger.info("[RERUN] Prerequisite check passed: %s", prereq_msg)

    # Store prompt override in Redis if provided. Failures here are logged
    # and ignored — the stage then just uses the on-disk template.
    override_key = None
    if prompt_override:
        settings = get_settings()
        try:
            r = redis_lib.Redis.from_url(settings.redis_url)
            # Map stage name to its prompt template
            stage_prompt_map = {
                "stage2_segmentation": "stage2_segmentation.txt",
                "stage3_extraction": "stage3_extraction.txt",
                "stage4_classification": "stage4_classification.txt",
                "stage5_synthesis": "stage5_synthesis.txt",
            }
            template = stage_prompt_map.get(stage_name)
            if template:
                override_key = f"chrysopedia:prompt_override:{video_id}:{template}"
                r.set(override_key, prompt_override, ex=3600)  # 1-hour TTL
                logger.info(
                    "[RERUN] Prompt override stored: key=%s (%d chars, first 100: %s)",
                    override_key, len(prompt_override), prompt_override[:100],
                )
        except Exception as exc:
            logger.warning("[RERUN] Failed to store prompt override: %s", exc)

    # Snapshot prior pages (needed for stage 5 page matching)
    if stage_name in ("stage5_synthesis",):
        _snapshot_prior_pages(video_id)

    # Create pipeline run record
    run_id = _create_run(video_id, trigger)
    logger.info("[RERUN] Created run_id=%s", run_id)

    # Temporarily set status to processing
    session = _get_sync_session()
    try:
        video = session.execute(
            select(SourceVideo).where(SourceVideo.id == video_id)
        ).scalar_one()
        video.processing_status = ProcessingStatus.processing
        session.commit()
    finally:
        session.close()

    # Run the single stage
    start = time.monotonic()
    try:
        task_func = _STAGE_TASKS[stage_name]
        # Direct synchronous call, same as run_pipeline's inline execution.
        task_func(video_id, run_id=run_id)

        elapsed = time.monotonic() - start
        logger.info(
            "[RERUN] Stage %s completed: %.1fs, video_id=%s",
            stage_name, elapsed, video_id,
        )
        _finish_run(run_id, "complete")

        # Restore status to complete
        session = _get_sync_session()
        try:
            video = session.execute(
                select(SourceVideo).where(SourceVideo.id == video_id)
            ).scalar_one()
            video.processing_status = ProcessingStatus.complete
            session.commit()
            logger.info("[RERUN] Status restored to complete")
        finally:
            session.close()

    except Exception as exc:
        elapsed = time.monotonic() - start
        logger.error(
            "[RERUN] Stage %s FAILED after %.1fs: %s",
            stage_name, elapsed, exc,
        )
        _set_error_status(video_id, stage_name, exc)
        _finish_run(run_id, "error", error_stage=stage_name)
        raise

    finally:
        # Clean up prompt override from Redis so later runs fall back to the
        # on-disk template. Runs on both success and failure.
        if override_key:
            try:
                settings = get_settings()
                r = redis_lib.Redis.from_url(settings.redis_url)
                r.delete(override_key)
                logger.info("[RERUN] Prompt override cleaned up: %s", override_key)
            except Exception as exc:
                logger.warning("[RERUN] Failed to clean up prompt override: %s", exc)

    return video_id
|
||
|
||
|
||
def _check_stage_prerequisites(video_id: str, stage_name: str) -> tuple[bool, str]:
    """Validate that prerequisite data exists for a stage re-run.

    Each stage depends on the previous stage's persisted output:
    stage 2 needs transcript segments, stage 3 needs labeled segments,
    stage 4 needs key moments, stage 5 needs key moments plus classification
    data, and stage 6 always runs (best-effort).

    Args:
        video_id: SourceVideo primary key.
        stage_name: One of _PIPELINE_STAGES.

    Returns:
        (ok, message) where message describes what was found or missing.
    """
    session = _get_sync_session()
    try:
        if stage_name == "stage2_segmentation":
            # Needs transcript segments
            count = session.execute(
                select(func.count(TranscriptSegment.id))
                .where(TranscriptSegment.source_video_id == video_id)
            ).scalar() or 0
            if count == 0:
                return False, "No transcript segments found"
            return True, f"transcript_segments={count}"

        if stage_name == "stage3_extraction":
            # Needs transcript segments with topic_labels
            count = session.execute(
                select(func.count(TranscriptSegment.id))
                .where(
                    TranscriptSegment.source_video_id == video_id,
                    TranscriptSegment.topic_label.isnot(None),
                )
            ).scalar() or 0
            if count == 0:
                return False, "No labeled transcript segments (stage 2 must complete first)"
            return True, f"labeled_segments={count}"

        if stage_name == "stage4_classification":
            # Needs key moments
            count = session.execute(
                select(func.count(KeyMoment.id))
                .where(KeyMoment.source_video_id == video_id)
            ).scalar() or 0
            if count == 0:
                return False, "No key moments found (stage 3 must complete first)"
            return True, f"key_moments={count}"

        if stage_name == "stage5_synthesis":
            # Needs key moments + classification data
            km_count = session.execute(
                select(func.count(KeyMoment.id))
                .where(KeyMoment.source_video_id == video_id)
            ).scalar() or 0
            if km_count == 0:
                return False, "No key moments found (stages 2-3 must complete first)"

            # Fix: dropped the unused `cls_source = "redis+pg"` local — it was
            # assigned but never read.
            cls_data = _load_classification_data(video_id)
            if not cls_data:
                return False, f"No classification data found (stage 4 must complete first), key_moments={km_count}"

            return True, f"key_moments={km_count}, classification_entries={len(cls_data)}"

        if stage_name == "stage6_embed_and_index":
            return True, "stage 6 is non-blocking and always runs"

        return False, f"Unknown stage: {stage_name}"
    finally:
        session.close()
|
||
|
||
|
||
# ── Avatar Fetching ─────────────────────────────────────────────────────────
|
||
|
||
@celery_app.task
def fetch_creator_avatar(creator_id: str) -> dict:
    """Fetch an avatar for one creator from TheAudioDB and persist the result.

    Looks up the creator by ID, calls TheAudioDB, and updates the
    avatar_url/avatar_source/avatar_fetched_at columns if a confident
    match is found. Returns a status dict describing the outcome.
    """
    import sys
    from datetime import datetime, timezone

    # Ensure /app is on sys.path for forked Celery workers
    if "/app" not in sys.path:
        sys.path.insert(0, "/app")
    from services.avatar import lookup_avatar

    session = _get_sync_session()
    try:
        creator = session.execute(
            select(Creator).where(Creator.id == creator_id)
        ).scalar_one_or_none()

        if creator is None:
            return {"status": "error", "detail": f"Creator {creator_id} not found"}

        match = lookup_avatar(creator.name, creator.genres)
        fetched_at = datetime.now(timezone.utc)

        if not match:
            # No confident match — record that we tried so it isn't retried
            # on every pass, and mark the avatar as locally generated.
            creator.avatar_source = "generated"
            creator.avatar_fetched_at = fetched_at
            session.commit()
            return {
                "status": "not_found",
                "creator": creator.name,
                "detail": "No confident match from TheAudioDB",
            }

        creator.avatar_url = match.url
        creator.avatar_source = match.source
        creator.avatar_fetched_at = fetched_at
        session.commit()
        return {
            "status": "found",
            "creator": creator.name,
            "avatar_url": match.url,
            "confidence": match.confidence,
            "matched_artist": match.artist_name,
        }
    except Exception as exc:
        session.rollback()
        logger.error("Avatar fetch failed for creator %s: %s", creator_id, exc)
        return {"status": "error", "detail": str(exc)}
    finally:
        session.close()
|
||
|
||
|
||
# ── Highlight Detection ──────────────────────────────────────────────────────
|
||
|
||
@celery_app.task(bind=True, max_retries=3, default_retry_delay=30)
def stage_highlight_detection(self, video_id: str, run_id: str | None = None) -> str:
    """Score all KeyMoments for a video and upsert HighlightCandidates.

    For each KeyMoment belonging to the video, runs the heuristic scorer and
    bulk-upserts results into highlight_candidates (INSERT ON CONFLICT UPDATE).
    A per-moment scoring failure falls back to a zero score rather than
    aborting the batch; a batch-level failure rolls back, emits an error
    event, and retries (up to max_retries, 30s apart).

    Args:
        video_id: SourceVideo primary key.
        run_id: Optional PipelineRun id attached to emitted events.

    Returns the video_id for chain compatibility.
    """
    from pipeline.highlight_scorer import extract_word_timings, score_moment

    start = time.monotonic()
    logger.info("Highlight detection starting for video_id=%s", video_id)
    _emit_event(video_id, "highlight_detection", "start", run_id=run_id)

    session = _get_sync_session()
    try:
        # ------------------------------------------------------------------
        # Load transcript data once for the entire video (word-level timing)
        # ------------------------------------------------------------------
        transcript_data: list | None = None
        source_video = session.execute(
            select(SourceVideo).where(SourceVideo.id == video_id)
        ).scalar_one_or_none()

        if source_video and source_video.transcript_path:
            transcript_file = source_video.transcript_path
            try:
                with open(transcript_file, "r") as fh:
                    raw = json.load(fh)
                # Accept both {"segments": [...]} and bare [...]
                if isinstance(raw, dict):
                    # Falls back to "results" then [] when "segments" is absent.
                    transcript_data = raw.get("segments", raw.get("results", []))
                elif isinstance(raw, list):
                    transcript_data = raw
                else:
                    transcript_data = None
                if transcript_data:
                    logger.info(
                        "Loaded transcript for video_id=%s (%d segments)",
                        video_id, len(transcript_data),
                    )
            except FileNotFoundError:
                # Missing transcript is non-fatal — scoring proceeds without
                # word timings.
                logger.warning(
                    "Transcript file not found for video_id=%s: %s",
                    video_id, transcript_file,
                )
            except (json.JSONDecodeError, OSError) as io_exc:
                logger.warning(
                    "Failed to load transcript for video_id=%s: %s",
                    video_id, io_exc,
                )
        else:
            logger.info(
                "No transcript_path for video_id=%s — audio proxy signals will be neutral",
                video_id,
            )

        moments = (
            session.execute(
                select(KeyMoment)
                .where(KeyMoment.source_video_id == video_id)
                .order_by(KeyMoment.start_time)
            )
            .scalars()
            .all()
        )

        if not moments:
            logger.info(
                "Highlight detection: No key moments for video_id=%s, skipping.", video_id,
            )
            _emit_event(
                video_id, "highlight_detection", "complete",
                run_id=run_id, payload={"candidates": 0},
            )
            return video_id

        candidate_count = 0
        for moment in moments:
            try:
                # Extract word-level timings for this moment's window
                word_timings = None
                if transcript_data:
                    word_timings = extract_word_timings(
                        transcript_data, moment.start_time, moment.end_time,
                    ) or None  # empty list → None for neutral fallback

                result = score_moment(
                    start_time=moment.start_time,
                    end_time=moment.end_time,
                    content_type=moment.content_type.value if moment.content_type else None,
                    summary=moment.summary,
                    plugins=moment.plugins,
                    raw_transcript=moment.raw_transcript,
                    source_quality=None,  # filled below if technique_page loaded
                    video_content_type=None,  # filled below if source_video loaded
                    word_timings=word_timings,
                )
                # NOTE(review): the two "filled below" comments above have no
                # corresponding fill-in code in this function — confirm whether
                # that enrichment was dropped or lives elsewhere.
            except Exception as score_exc:
                # Per-moment failures degrade to a zero-score candidate so the
                # rest of the batch still lands.
                logger.warning(
                    "Highlight detection: score_moment failed for moment %s: %s",
                    moment.id, score_exc,
                )
                result = {
                    "score": 0.0,
                    "score_breakdown": {},
                    "duration_secs": max(0.0, moment.end_time - moment.start_time),
                }

            # Idempotent upsert keyed on key_moment_id (unique constraint).
            stmt = pg_insert(HighlightCandidate).values(
                key_moment_id=moment.id,
                source_video_id=moment.source_video_id,
                score=result["score"],
                score_breakdown=result["score_breakdown"],
                duration_secs=result["duration_secs"],
            )
            stmt = stmt.on_conflict_do_update(
                constraint="highlight_candidates_key_moment_id_key",
                set_={
                    "score": stmt.excluded.score,
                    "score_breakdown": stmt.excluded.score_breakdown,
                    "duration_secs": stmt.excluded.duration_secs,
                    "updated_at": func.now(),
                },
            )
            session.execute(stmt)
            candidate_count += 1

        # Single commit for all upserts — all-or-nothing per batch.
        session.commit()
        elapsed = time.monotonic() - start
        _emit_event(
            video_id, "highlight_detection", "complete",
            run_id=run_id, payload={"candidates": candidate_count},
        )
        logger.info(
            "Highlight detection completed for video_id=%s in %.1fs — %d candidates upserted",
            video_id, elapsed, candidate_count,
        )
        return video_id

    except Exception as exc:
        session.rollback()
        _emit_event(
            video_id, "highlight_detection", "error",
            run_id=run_id, payload={"error": str(exc)},
        )
        logger.error("Highlight detection failed for video_id=%s: %s", video_id, exc)
        raise self.retry(exc=exc)
    finally:
        session.close()
|
||
|
||
|
||
# ── Personality profile extraction ───────────────────────────────────────────
|
||
|
||
|
||
def _sample_creator_transcripts(
|
||
moments: list,
|
||
creator_id: str,
|
||
max_chars: int = 40000,
|
||
) -> tuple[str, int]:
|
||
"""Sample transcripts from a creator's key moments, respecting size tiers.
|
||
|
||
- Small (<20K chars total): use all text.
|
||
- Medium (20K-60K): first 300 chars from each moment, up to budget.
|
||
- Large (>60K): random sample seeded by creator_id, attempts topic diversity
|
||
via Redis classification data.
|
||
|
||
Returns (sampled_text, total_char_count).
|
||
"""
|
||
import random
|
||
|
||
transcripts = [
|
||
(m.source_video_id, m.raw_transcript)
|
||
for m in moments
|
||
if m.raw_transcript and m.raw_transcript.strip()
|
||
]
|
||
if not transcripts:
|
||
return ("", 0)
|
||
|
||
total_chars = sum(len(t) for _, t in transcripts)
|
||
|
||
# Small: use everything
|
||
if total_chars <= 20_000:
|
||
text = "\n\n---\n\n".join(t for _, t in transcripts)
|
||
return (text, total_chars)
|
||
|
||
# Medium: first 300 chars from each moment
|
||
if total_chars <= 60_000:
|
||
excerpts = []
|
||
budget = max_chars
|
||
for _, t in transcripts:
|
||
chunk = t[:300]
|
||
if budget - len(chunk) < 0:
|
||
break
|
||
excerpts.append(chunk)
|
||
budget -= len(chunk)
|
||
text = "\n\n---\n\n".join(excerpts)
|
||
return (text, total_chars)
|
||
|
||
# Large: random sample with optional topic diversity from Redis
|
||
topic_map: dict[str, list[tuple[str, str]]] = {}
|
||
try:
|
||
import redis as _redis
|
||
settings = get_settings()
|
||
r = _redis.from_url(settings.redis_url)
|
||
video_ids = {str(vid) for vid, _ in transcripts}
|
||
for vid in video_ids:
|
||
raw = r.get(f"chrysopedia:classification:{vid}")
|
||
if raw:
|
||
classification = json.loads(raw)
|
||
if isinstance(classification, list):
|
||
for item in classification:
|
||
cat = item.get("topic_category", "unknown")
|
||
moment_id = item.get("moment_id")
|
||
if moment_id:
|
||
topic_map.setdefault(cat, []).append(moment_id)
|
||
r.close()
|
||
except Exception:
|
||
# Fall back to random sampling without topic diversity
|
||
pass
|
||
|
||
rng = random.Random(creator_id)
|
||
|
||
if topic_map:
|
||
# Interleave from different categories for diversity
|
||
ordered = []
|
||
cat_lists = list(topic_map.values())
|
||
rng.shuffle(cat_lists)
|
||
idx = 0
|
||
while any(cat_lists):
|
||
for cat in cat_lists:
|
||
if cat:
|
||
ordered.append(cat.pop(0))
|
||
cat_lists = [c for c in cat_lists if c]
|
||
# Map moment IDs back to transcripts
|
||
moment_lookup = {str(m.id): m.raw_transcript for m in moments if m.raw_transcript}
|
||
diverse_transcripts = [
|
||
moment_lookup[mid] for mid in ordered if mid in moment_lookup
|
||
]
|
||
if diverse_transcripts:
|
||
transcripts_list = diverse_transcripts
|
||
else:
|
||
transcripts_list = [t for _, t in transcripts]
|
||
else:
|
||
transcripts_list = [t for _, t in transcripts]
|
||
rng.shuffle(transcripts_list)
|
||
|
||
excerpts = []
|
||
budget = max_chars
|
||
for t in transcripts_list:
|
||
chunk = t[:600]
|
||
if budget - len(chunk) < 0:
|
||
break
|
||
excerpts.append(chunk)
|
||
budget -= len(chunk)
|
||
|
||
text = "\n\n---\n\n".join(excerpts)
|
||
return (text, total_chars)
|
||
|
||
|
||
@celery_app.task(bind=True, max_retries=2, default_retry_delay=60)
def extract_personality_profile(self, creator_id: str) -> str:
    """Extract a personality profile from a creator's transcripts via LLM.

    Aggregates and samples transcripts from all of the creator's key moments,
    sends them to the LLM with the personality_extraction prompt, validates
    the response, and stores the profile as JSONB on Creator.personality_profile.

    Invalid JSON or a schema-validation failure triggers a Celery retry
    (max_retries=2, 60s apart). Missing creator or empty transcripts are
    treated as a skipped-but-complete outcome, not an error.

    Args:
        creator_id: Creator primary key.

    Returns the creator_id for chain compatibility.
    """
    from datetime import datetime, timezone

    start = time.monotonic()
    logger.info("Personality extraction starting for creator_id=%s", creator_id)
    _emit_event(creator_id, "personality_extraction", "start")

    session = _get_sync_session()
    try:
        # Load creator
        creator = session.execute(
            select(Creator).where(Creator.id == creator_id)
        ).scalar_one_or_none()
        if not creator:
            logger.error("Creator not found: %s", creator_id)
            _emit_event(
                creator_id, "personality_extraction", "error",
                payload={"error": "creator_not_found"},
            )
            return creator_id

        # Load all key moments with transcripts for this creator
        moments = (
            session.execute(
                select(KeyMoment)
                .join(SourceVideo, KeyMoment.source_video_id == SourceVideo.id)
                .where(SourceVideo.creator_id == creator.id)
                .where(KeyMoment.raw_transcript.isnot(None))
            )
            .scalars()
            .all()
        )

        if not moments:
            logger.warning(
                "No transcripts found for creator_id=%s (%s), skipping extraction",
                creator_id, creator.name,
            )
            _emit_event(
                creator_id, "personality_extraction", "complete",
                payload={"skipped": True, "reason": "no_transcripts"},
            )
            return creator_id

        # Sample transcripts
        sampled_text, total_chars = _sample_creator_transcripts(
            moments, creator_id,
        )

        if not sampled_text.strip():
            logger.warning(
                "Empty transcript sample for creator_id=%s, skipping", creator_id,
            )
            _emit_event(
                creator_id, "personality_extraction", "complete",
                payload={"skipped": True, "reason": "empty_sample"},
            )
            return creator_id

        # Load prompt and call LLM
        system_prompt = _load_prompt("personality_extraction.txt")
        user_prompt = (
            f"Creator: {creator.name}\n\n"
            f"Transcript excerpts ({len(moments)} moments, {total_chars} total chars, "
            f"sample below):\n\n{sampled_text}"
        )

        llm = _get_llm_client()
        callback = _make_llm_callback(
            creator_id, "personality_extraction",
            system_prompt=system_prompt,
            user_prompt=user_prompt,
        )

        response = llm.complete(
            system_prompt=system_prompt,
            user_prompt=user_prompt,
            response_model=object,  # triggers JSON mode
            on_complete=callback,
        )

        # Parse and validate. Both failure modes retry the whole task.
        from schemas import PersonalityProfile as ProfileValidator
        try:
            raw_profile = json.loads(str(response))
        except json.JSONDecodeError as jde:
            logger.warning(
                "LLM returned invalid JSON for creator_id=%s, retrying: %s",
                creator_id, jde,
            )
            raise self.retry(exc=jde)

        try:
            validated = ProfileValidator.model_validate(raw_profile)
        except ValidationError as ve:
            logger.warning(
                "LLM profile failed validation for creator_id=%s, retrying: %s",
                creator_id, ve,
            )
            raise self.retry(exc=ve)

        # Build final profile dict with metadata
        profile_dict = validated.model_dump()
        profile_dict["_metadata"] = {
            # Stored naive (tzinfo stripped) but computed from UTC.
            "extracted_at": datetime.now(timezone.utc).replace(tzinfo=None).isoformat(),
            "transcript_sample_size": total_chars,
            "moments_count": len(moments),
            # NOTE(review): this stores the response's finish_reason, not a
            # model name — confirm whether "model_used" is the intended key.
            "model_used": getattr(response, "finish_reason", None) or "unknown",
        }

        # Low sample size note
        if total_chars < 500:
            profile_dict["_metadata"]["low_sample_size"] = True

        # Store on creator
        creator.personality_profile = profile_dict
        session.commit()

        elapsed = time.monotonic() - start
        _emit_event(
            creator_id, "personality_extraction", "complete",
            duration_ms=int(elapsed * 1000),
            payload={
                "moments_count": len(moments),
                "transcript_chars": total_chars,
                "sample_chars": len(sampled_text),
            },
        )
        logger.info(
            "Personality extraction completed for creator_id=%s (%s) in %.1fs — "
            "%d moments, %d chars sampled",
            creator_id, creator.name, elapsed, len(moments), len(sampled_text),
        )
        return creator_id

    except Exception as exc:
        # Let exhausted-retry errors propagate without another retry cycle.
        if isinstance(exc, (self.MaxRetriesExceededError,)):
            raise
        session.rollback()
        _emit_event(
            creator_id, "personality_extraction", "error",
            payload={"error": str(exc)[:500]},
        )
        logger.error(
            "Personality extraction failed for creator_id=%s: %s", creator_id, exc,
        )
        raise self.retry(exc=exc)
    finally:
        session.close()
|
||
|
||
|
||
# ── Stage: Shorts Generation ─────────────────────────────────────────────────
|
||
|
||
@celery_app.task(bind=True, max_retries=1, default_retry_delay=60)
def stage_generate_shorts(self, highlight_candidate_id: str) -> str:
    """Generate video shorts for an approved highlight candidate.

    Creates one GeneratedShort row per FormatPreset, extracts the clip via
    ffmpeg, uploads to MinIO, and updates status. Each preset is independent —
    a failure on one does not block the others.

    Args:
        highlight_candidate_id: Primary key of the HighlightCandidate to render.

    Returns:
        The highlight_candidate_id, in every outcome (missing highlight,
        wrong status, duplicate run, per-preset failures) except an
        unexpected error, which triggers a Celery retry.
    """
    # Local imports keep worker import time low and avoid circular imports.
    # upload_file is hoisted here (was previously imported inside the
    # per-preset loop, re-executing the import machinery each iteration).
    from minio_client import upload_file
    from models import FormatPreset, GeneratedShort, ShortStatus
    from pipeline.shorts_generator import PRESETS, extract_clip, resolve_video_path

    start = time.monotonic()
    session = _get_sync_session()
    settings = get_settings()

    try:
        # ── Load highlight ──────────────────────────────────────────────
        highlight = session.execute(
            select(HighlightCandidate)
            .where(HighlightCandidate.id == highlight_candidate_id)
        ).scalar_one_or_none()

        if highlight is None:
            logger.error(
                "Highlight candidate not found: %s", highlight_candidate_id,
            )
            return highlight_candidate_id

        if highlight.status.value != "approved":
            logger.warning(
                "Highlight %s status is %s, expected approved — skipping",
                highlight_candidate_id, highlight.status.value,
            )
            return highlight_candidate_id

        # Reject duplicate runs: another worker is already rendering this
        # highlight (best-effort check; not race-free across workers).
        existing_processing = session.execute(
            select(func.count())
            .where(GeneratedShort.highlight_candidate_id == highlight_candidate_id)
            .where(GeneratedShort.status == ShortStatus.processing)
        ).scalar()
        if existing_processing:
            logger.warning(
                "Highlight %s already has %d processing shorts — rejecting duplicate",
                highlight_candidate_id, existing_processing,
            )
            return highlight_candidate_id

        key_moment = highlight.key_moment
        source_video = highlight.source_video

        # ── Compute effective start/end (trim overrides) ────────────────
        # Manual trim values win over the detected key-moment boundaries.
        # Computed before path resolution so failure rows can record the
        # clip duration too (previously omitted on the missing-file path).
        clip_start = highlight.trim_start if highlight.trim_start is not None else key_moment.start_time
        clip_end = highlight.trim_end if highlight.trim_end is not None else key_moment.end_time

        # ── Resolve video file path ─────────────────────────────────────
        try:
            video_path = resolve_video_path(
                settings.video_source_path, source_video.file_path,
            )
        except FileNotFoundError as fnf:
            logger.error(
                "Video file missing for highlight %s: %s",
                highlight_candidate_id, fnf,
            )
            # Mark all presets as failed so the UI shows a terminal state.
            for preset in FormatPreset:
                spec = PRESETS[preset]
                session.add(GeneratedShort(
                    highlight_candidate_id=highlight_candidate_id,
                    format_preset=preset,
                    width=spec.width,
                    height=spec.height,
                    status=ShortStatus.failed,
                    duration_secs=clip_end - clip_start,
                    error_message=str(fnf),
                ))
            session.commit()
            return highlight_candidate_id

        logger.info(
            "Generating shorts for highlight=%s video=%s [%.1f–%.1f]s",
            highlight_candidate_id, source_video.file_path,
            clip_start, clip_end,
        )

        # ── Process each preset independently ───────────────────────────
        for preset in FormatPreset:
            spec = PRESETS[preset]
            preset_start = time.monotonic()

            # Create the row up front (status=processing) so the duplicate
            # check above can see in-flight work from other tasks.
            short = GeneratedShort(
                highlight_candidate_id=highlight_candidate_id,
                format_preset=preset,
                width=spec.width,
                height=spec.height,
                status=ShortStatus.processing,
                duration_secs=clip_end - clip_start,
            )
            session.add(short)
            session.commit()
            session.refresh(short)

            tmp_path = Path(f"/tmp/short_{short.id}_{preset.value}.mp4")
            minio_key = f"shorts/{highlight_candidate_id}/{preset.value}.mp4"

            try:
                # Extract the clip via ffmpeg with the preset's filter chain.
                extract_clip(
                    input_path=video_path,
                    output_path=tmp_path,
                    start_secs=clip_start,
                    end_secs=clip_end,
                    vf_filter=spec.vf_filter,
                )

                # Upload to MinIO.
                file_size = tmp_path.stat().st_size
                with open(tmp_path, "rb") as f:
                    upload_file(
                        object_key=minio_key,
                        data=f,
                        length=file_size,
                        content_type="video/mp4",
                    )

                # Mark complete.
                short.status = ShortStatus.complete
                short.file_size_bytes = file_size
                short.minio_object_key = minio_key
                session.commit()

                elapsed_preset = time.monotonic() - preset_start
                logger.info(
                    "Short generated: highlight=%s preset=%s "
                    "size=%d bytes duration=%.1fs elapsed=%.1fs",
                    highlight_candidate_id, preset.value,
                    file_size, clip_end - clip_start, elapsed_preset,
                )

            except Exception as exc:
                # Per-preset failure: roll back any partial state, then
                # record the failure on this row without aborting the
                # remaining presets.
                session.rollback()
                # refresh() re-fetches the (already committed) row after
                # rollback expired the in-memory state.
                session.refresh(short)
                short.status = ShortStatus.failed
                short.error_message = str(exc)[:2000]
                session.commit()

                elapsed_preset = time.monotonic() - preset_start
                logger.error(
                    "Short failed: highlight=%s preset=%s "
                    "error=%s elapsed=%.1fs",
                    highlight_candidate_id, preset.value,
                    str(exc)[:500], elapsed_preset,
                )

            finally:
                # Always remove the temp clip, even on upload failure.
                if tmp_path.exists():
                    try:
                        tmp_path.unlink()
                    except OSError:
                        pass

        elapsed = time.monotonic() - start
        logger.info(
            "Shorts generation complete for highlight=%s in %.1fs",
            highlight_candidate_id, elapsed,
        )
        return highlight_candidate_id

    except Exception as exc:
        session.rollback()
        logger.error(
            "Shorts generation failed for highlight=%s: %s",
            highlight_candidate_id, exc,
        )
        raise self.retry(exc=exc)
    finally:
        session.close()
|