"""Pipeline stage tasks (stages 2-5) and run_pipeline orchestrator. Each stage reads from PostgreSQL via sync SQLAlchemy, loads its prompt template from disk, calls the LLM client, parses the response, writes results back, and updates processing_status on SourceVideo. Celery tasks are synchronous — all DB access uses ``sqlalchemy.orm.Session``. """ from __future__ import annotations import hashlib import json import logging import os import re import secrets import subprocess import time from collections import defaultdict from pathlib import Path import yaml from pydantic import ValidationError from sqlalchemy import create_engine, func, select from sqlalchemy.orm import Session, sessionmaker from config import get_settings from sqlalchemy.dialects.postgresql import insert as pg_insert from models import ( Creator, HighlightCandidate, KeyMoment, KeyMomentContentType, PipelineEvent, ProcessingStatus, SourceVideo, TechniquePage, TechniquePageVersion, TechniquePageVideo, TranscriptSegment, ) from pipeline.embedding_client import EmbeddingClient from pipeline.llm_client import LLMClient, LLMResponse, estimate_max_tokens from pipeline.qdrant_client import QdrantManager from pipeline.schemas import ( ClassificationResult, ExtractionResult, SegmentationResult, SynthesisResult, ) from worker import celery_app logger = logging.getLogger(__name__) class LLMTruncationError(RuntimeError): """Raised when the LLM response was truncated (finish_reason=length).""" pass # ── Error status helper ────────────────────────────────────────────────────── def _set_error_status(video_id: str, stage_name: str, error: Exception) -> None: """Mark a video as errored when a pipeline stage fails permanently.""" try: session = _get_sync_session() video = session.execute( select(SourceVideo).where(SourceVideo.id == video_id) ).scalar_one_or_none() if video: video.processing_status = ProcessingStatus.error session.commit() session.close() except Exception as mark_exc: logger.error( "Failed to mark 
video_id=%s as error after %s failure: %s", video_id, stage_name, mark_exc, ) # ── Pipeline event persistence ─────────────────────────────────────────────── def _emit_event( video_id: str, stage: str, event_type: str, *, run_id: str | None = None, prompt_tokens: int | None = None, completion_tokens: int | None = None, total_tokens: int | None = None, model: str | None = None, duration_ms: int | None = None, payload: dict | None = None, system_prompt_text: str | None = None, user_prompt_text: str | None = None, response_text: str | None = None, ) -> None: """Persist a pipeline event to the DB. Best-effort -- failures logged, not raised.""" try: session = _get_sync_session() try: event = PipelineEvent( video_id=video_id, run_id=run_id, stage=stage, event_type=event_type, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens, model=model, duration_ms=duration_ms, payload=payload, system_prompt_text=system_prompt_text, user_prompt_text=user_prompt_text, response_text=response_text, ) session.add(event) session.commit() finally: session.close() except Exception as exc: logger.warning("Failed to emit pipeline event: %s", exc) def _is_debug_mode() -> bool: """Check if debug mode is enabled via Redis. Falls back to config setting.""" try: import redis settings = get_settings() r = redis.from_url(settings.redis_url) val = r.get("chrysopedia:debug_mode") r.close() if val is not None: return val.decode().lower() == "true" except Exception: pass return getattr(get_settings(), "debug_mode", False) def _make_llm_callback( video_id: str, stage: str, system_prompt: str | None = None, user_prompt: str | None = None, run_id: str | None = None, context_label: str | None = None, request_params: dict | None = None, ): """Create an on_complete callback for LLMClient that emits llm_call events. When debug mode is enabled, captures full system prompt, user prompt, and response text on each llm_call event. 
Parameters ---------- request_params: Dict of LLM request parameters (max_tokens, model_override, modality, response_model, temperature, etc.) to store in the event payload for debugging which parameters were actually sent to the API. """ debug = _is_debug_mode() def callback(*, model=None, prompt_tokens=None, completion_tokens=None, total_tokens=None, content=None, finish_reason=None, is_fallback=False, **_kwargs): # Truncate content for storage — keep first 2000 chars for debugging truncated = content[:2000] if content and len(content) > 2000 else content _emit_event( video_id=video_id, stage=stage, event_type="llm_call", run_id=run_id, model=model, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens, payload={ "content_preview": truncated, "content_length": len(content) if content else 0, "finish_reason": finish_reason, "is_fallback": is_fallback, **({"context": context_label} if context_label else {}), **({"request_params": request_params} if request_params else {}), }, system_prompt_text=system_prompt if debug else None, user_prompt_text=user_prompt if debug else None, response_text=content if debug else None, ) return callback def _build_request_params( max_tokens: int, model_override: str | None, modality: str, response_model: str, hard_limit: int, ) -> dict: """Build the request_params dict for pipeline event logging. Separates actual API params (sent to the LLM) from internal config (used by our estimator only) so the debug JSON is unambiguous. 
""" settings = get_settings() return { "api_params": { "max_tokens": max_tokens, "model": model_override or settings.llm_model, "temperature": settings.llm_temperature, "response_format": "json_object" if modality == "chat" else "none (thinking mode)", }, "pipeline_config": { "modality": modality, "response_model": response_model, "estimator_hard_limit": hard_limit, "fallback_max_tokens": settings.llm_max_tokens, }, } # ── Helpers ────────────────────────────────────────────────────────────────── _engine = None _SessionLocal = None def _get_sync_engine(): """Create a sync SQLAlchemy engine, converting the async URL if needed.""" global _engine if _engine is None: settings = get_settings() url = settings.database_url # Convert async driver to sync driver url = url.replace("postgresql+asyncpg://", "postgresql+psycopg2://") _engine = create_engine(url, pool_pre_ping=True, pool_size=5, max_overflow=10) return _engine def _get_sync_session() -> Session: """Create a sync SQLAlchemy session for Celery tasks.""" global _SessionLocal if _SessionLocal is None: _SessionLocal = sessionmaker(bind=_get_sync_engine()) return _SessionLocal() def _load_prompt(template_name: str, video_id: str | None = None) -> str: """Read a prompt template from the prompts directory. If ``video_id`` is provided, checks Redis for a per-video prompt override (key: ``chrysopedia:prompt_override:{video_id}:{template_name}``) before falling back to the on-disk template. Overrides are set by ``run_single_stage`` for single-stage re-runs with custom prompts. Raises FileNotFoundError if no override exists and the template is missing. 
""" # Check for per-video prompt override in Redis if video_id: try: import redis settings = get_settings() r = redis.Redis.from_url(settings.redis_url) override_key = f"chrysopedia:prompt_override:{video_id}:{template_name}" override = r.get(override_key) if override: prompt_text = override.decode("utf-8") logger.info( "[PROMPT] Using override from Redis: video_id=%s, template=%s (%d chars)", video_id, template_name, len(prompt_text), ) return prompt_text except Exception as exc: logger.warning("[PROMPT] Redis override check failed: %s", exc) settings = get_settings() path = Path(settings.prompts_path) / template_name if not path.exists(): logger.error("Prompt template not found: %s", path) raise FileNotFoundError(f"Prompt template not found: {path}") return path.read_text(encoding="utf-8") def _get_llm_client() -> LLMClient: """Return an LLMClient configured from settings.""" return LLMClient(get_settings()) def _get_stage_config(stage_num: int) -> tuple[str | None, str]: """Return (model_override, modality) for a pipeline stage. Reads stage-specific config from Settings. If the stage-specific model is None/empty, returns None (LLMClient will use its default). If the stage-specific modality is unset, defaults to "chat". """ settings = get_settings() model = getattr(settings, f"llm_stage{stage_num}_model", None) or None modality = getattr(settings, f"llm_stage{stage_num}_modality", None) or "chat" return model, modality def _load_canonical_tags() -> dict: """Load canonical tag taxonomy from config/canonical_tags.yaml.""" # Walk up from backend/ to find config/ candidates = [ Path("config/canonical_tags.yaml"), Path("../config/canonical_tags.yaml"), ] for candidate in candidates: if candidate.exists(): with open(candidate, encoding="utf-8") as f: return yaml.safe_load(f) raise FileNotFoundError( "canonical_tags.yaml not found. 
Searched: " + ", ".join(str(c) for c in candidates) ) def _format_taxonomy_for_prompt(tags_data: dict) -> str: """Format the canonical tags taxonomy as readable text for the LLM prompt.""" lines = [] for cat in tags_data.get("categories", []): lines.append(f"Category: {cat['name']}") lines.append(f" Description: {cat['description']}") lines.append(f" Sub-topics: {', '.join(cat.get('sub_topics', []))}") lines.append("") return "\n".join(lines) def _safe_parse_llm_response( raw, model_cls, llm: LLMClient, system_prompt: str, user_prompt: str, modality: str = "chat", model_override: str | None = None, max_tokens: int | None = None, ): """Parse LLM response with truncation detection and one retry on failure. If the response was truncated (finish_reason=length), raises LLMTruncationError immediately — retrying with a JSON nudge would only make things worse by adding tokens to an already-too-large prompt. For non-truncation parse failures: retry once with a JSON nudge, then raise on second failure. """ # Check for truncation before attempting parse is_truncated = isinstance(raw, LLMResponse) and raw.truncated if is_truncated: logger.warning( "LLM response truncated (finish=length) for %s. " "prompt_tokens=%s, completion_tokens=%s. Will not retry with nudge.", model_cls.__name__, getattr(raw, "prompt_tokens", "?"), getattr(raw, "completion_tokens", "?"), ) try: return llm.parse_response(raw, model_cls) except (ValidationError, ValueError, json.JSONDecodeError) as exc: if is_truncated: raise LLMTruncationError( f"LLM output truncated for {model_cls.__name__}: " f"prompt_tokens={getattr(raw, 'prompt_tokens', '?')}, " f"completion_tokens={getattr(raw, 'completion_tokens', '?')}. " f"Response too large for model context window." ) from exc logger.warning( "First parse attempt failed for %s (%s). Retrying with JSON nudge. 
" "Raw response (first 500 chars): %.500s", model_cls.__name__, type(exc).__name__, raw, ) # Retry with explicit JSON instruction nudge_prompt = user_prompt + "\n\nIMPORTANT: Output ONLY valid JSON. No markdown, no explanation." retry_raw = llm.complete( system_prompt, nudge_prompt, response_model=model_cls, modality=modality, model_override=model_override, max_tokens=max_tokens, ) return llm.parse_response(retry_raw, model_cls) # ── Stage 2: Segmentation ─────────────────────────────────────────────────── @celery_app.task(bind=True, max_retries=3, default_retry_delay=30) def stage2_segmentation(self, video_id: str, run_id: str | None = None) -> str: """Analyze transcript segments and identify topic boundaries. Loads all TranscriptSegment rows for the video, sends them to the LLM for topic boundary detection, and updates topic_label on each segment. Returns the video_id for chain compatibility. """ start = time.monotonic() logger.info("Stage 2 (segmentation) starting for video_id=%s", video_id) _emit_event(video_id, "stage2_segmentation", "start", run_id=run_id) session = _get_sync_session() try: # Load segments ordered by index segments = ( session.execute( select(TranscriptSegment) .where(TranscriptSegment.source_video_id == video_id) .order_by(TranscriptSegment.segment_index) ) .scalars() .all() ) if not segments: logger.info("Stage 2: No segments found for video_id=%s, skipping.", video_id) return video_id # Build transcript text with indices for the LLM transcript_lines = [] for seg in segments: transcript_lines.append( f"[{seg.segment_index}] ({seg.start_time:.1f}s - {seg.end_time:.1f}s) {seg.text}" ) transcript_text = "\n".join(transcript_lines) # Load prompt and call LLM system_prompt = _load_prompt("stage2_segmentation.txt", video_id=video_id) user_prompt = f"\n{transcript_text}\n" llm = _get_llm_client() model_override, modality = _get_stage_config(2) hard_limit = get_settings().llm_max_tokens_hard_limit max_tokens = estimate_max_tokens(system_prompt, 
user_prompt, stage="stage2_segmentation", hard_limit=hard_limit) logger.info("Stage 2 using model=%s, modality=%s, max_tokens=%d", model_override or "default", modality, max_tokens) _s2_request_params = _build_request_params(max_tokens, model_override, modality, "SegmentationResult", hard_limit) raw = llm.complete(system_prompt, user_prompt, response_model=SegmentationResult, on_complete=_make_llm_callback(video_id, "stage2_segmentation", system_prompt=system_prompt, user_prompt=user_prompt, run_id=run_id, request_params=_s2_request_params), modality=modality, model_override=model_override, max_tokens=max_tokens) result = _safe_parse_llm_response(raw, SegmentationResult, llm, system_prompt, user_prompt, modality=modality, model_override=model_override, max_tokens=max_tokens) # Update topic_label on each segment row seg_by_index = {s.segment_index: s for s in segments} for topic_seg in result.segments: for idx in range(topic_seg.start_index, topic_seg.end_index + 1): if idx in seg_by_index: seg_by_index[idx].topic_label = topic_seg.topic_label session.commit() elapsed = time.monotonic() - start _emit_event(video_id, "stage2_segmentation", "complete", run_id=run_id) logger.info( "Stage 2 (segmentation) completed for video_id=%s in %.1fs — %d topic groups found", video_id, elapsed, len(result.segments), ) return video_id except FileNotFoundError: raise # Don't retry missing prompt files except Exception as exc: session.rollback() _emit_event(video_id, "stage2_segmentation", "error", run_id=run_id, payload={"error": str(exc)}) logger.error("Stage 2 failed for video_id=%s: %s", video_id, exc) raise self.retry(exc=exc) finally: session.close() # ── Stage 3: Extraction ───────────────────────────────────────────────────── @celery_app.task(bind=True, max_retries=3, default_retry_delay=30) def stage3_extraction(self, video_id: str, run_id: str | None = None) -> str: """Extract key moments from each topic segment group. 
Groups segments by topic_label, calls the LLM for each group to extract moments, creates KeyMoment rows, and sets processing_status=extracted. Returns the video_id for chain compatibility. """ start = time.monotonic() logger.info("Stage 3 (extraction) starting for video_id=%s", video_id) _emit_event(video_id, "stage3_extraction", "start", run_id=run_id) session = _get_sync_session() try: # Load segments with topic labels segments = ( session.execute( select(TranscriptSegment) .where(TranscriptSegment.source_video_id == video_id) .order_by(TranscriptSegment.segment_index) ) .scalars() .all() ) if not segments: logger.info("Stage 3: No segments found for video_id=%s, skipping.", video_id) return video_id # Group segments by topic_label groups: dict[str, list[TranscriptSegment]] = defaultdict(list) for seg in segments: label = seg.topic_label or "unlabeled" groups[label].append(seg) system_prompt = _load_prompt("stage3_extraction.txt", video_id=video_id) llm = _get_llm_client() model_override, modality = _get_stage_config(3) hard_limit = get_settings().llm_max_tokens_hard_limit logger.info("Stage 3 using model=%s, modality=%s", model_override or "default", modality) total_moments = 0 for topic_label, group_segs in groups.items(): # Build segment text for this group seg_lines = [] for seg in group_segs: seg_lines.append( f"({seg.start_time:.1f}s - {seg.end_time:.1f}s) {seg.text}" ) segment_text = "\n".join(seg_lines) user_prompt = ( f"Topic: {topic_label}\n\n" f"\n{segment_text}\n" ) max_tokens = estimate_max_tokens(system_prompt, user_prompt, stage="stage3_extraction", hard_limit=hard_limit) _s3_request_params = _build_request_params(max_tokens, model_override, modality, "ExtractionResult", hard_limit) raw = llm.complete(system_prompt, user_prompt, response_model=ExtractionResult, on_complete=_make_llm_callback(video_id, "stage3_extraction", system_prompt=system_prompt, user_prompt=user_prompt, run_id=run_id, context_label=topic_label, 
request_params=_s3_request_params), modality=modality, model_override=model_override, max_tokens=max_tokens) result = _safe_parse_llm_response(raw, ExtractionResult, llm, system_prompt, user_prompt, modality=modality, model_override=model_override, max_tokens=max_tokens) # Create KeyMoment rows for moment in result.moments: # Validate content_type against enum try: ct = KeyMomentContentType(moment.content_type) except ValueError: ct = KeyMomentContentType.technique km = KeyMoment( source_video_id=video_id, title=moment.title, summary=moment.summary, start_time=moment.start_time, end_time=moment.end_time, content_type=ct, plugins=moment.plugins if moment.plugins else None, raw_transcript=moment.raw_transcript or None, ) session.add(km) total_moments += 1 session.commit() elapsed = time.monotonic() - start _emit_event(video_id, "stage3_extraction", "complete", run_id=run_id) logger.info( "Stage 3 (extraction) completed for video_id=%s in %.1fs — %d moments created", video_id, elapsed, total_moments, ) return video_id except FileNotFoundError: raise except Exception as exc: session.rollback() _emit_event(video_id, "stage3_extraction", "error", run_id=run_id, payload={"error": str(exc)}) logger.error("Stage 3 failed for video_id=%s: %s", video_id, exc) raise self.retry(exc=exc) finally: session.close() # ── Stage 4: Classification ───────────────────────────────────────────────── # Maximum moments per classification batch. Keeps each LLM call well within # context window limits. Batches are classified independently and merged. _STAGE4_BATCH_SIZE = 20 def _classify_moment_batch( moments_batch: list, batch_offset: int, taxonomy_text: str, system_prompt: str, llm: LLMClient, model_override: str | None, modality: str, hard_limit: int, video_id: str, run_id: str | None, ) -> ClassificationResult: """Classify a single batch of moments. 
Raises on failure.""" moments_lines = [] for i, m in enumerate(moments_batch): moments_lines.append( f"[{i}] Title: {m.title}\n" f" Summary: {m.summary}\n" f" Content type: {m.content_type.value}\n" f" Plugins: {', '.join(m.plugins) if m.plugins else 'none'}" ) moments_text = "\n\n".join(moments_lines) user_prompt = ( f"\n{taxonomy_text}\n\n\n" f"\n{moments_text}\n" ) max_tokens = estimate_max_tokens( system_prompt, user_prompt, stage="stage4_classification", hard_limit=hard_limit, ) batch_label = f"batch {batch_offset // _STAGE4_BATCH_SIZE + 1} (moments {batch_offset}-{batch_offset + len(moments_batch) - 1})" logger.info( "Stage 4 classifying %s, max_tokens=%d", batch_label, max_tokens, ) raw = llm.complete( system_prompt, user_prompt, response_model=ClassificationResult, on_complete=_make_llm_callback( video_id, "stage4_classification", system_prompt=system_prompt, user_prompt=user_prompt, run_id=run_id, context_label=batch_label, request_params=_build_request_params(max_tokens, model_override, modality, "ClassificationResult", hard_limit), ), modality=modality, model_override=model_override, max_tokens=max_tokens, ) return _safe_parse_llm_response( raw, ClassificationResult, llm, system_prompt, user_prompt, modality=modality, model_override=model_override, max_tokens=max_tokens, ) @celery_app.task(bind=True, max_retries=3, default_retry_delay=30) def stage4_classification(self, video_id: str, run_id: str | None = None) -> str: """Classify key moments against the canonical tag taxonomy. Loads all KeyMoment rows for the video, sends them to the LLM with the canonical taxonomy, and stores classification results in Redis for stage 5 consumption. Updates content_type if the classifier overrides it. For large moment sets, automatically batches into groups of _STAGE4_BATCH_SIZE to stay within model context window limits. Stage 4 does NOT change processing_status. Returns the video_id for chain compatibility. 
""" start = time.monotonic() logger.info("Stage 4 (classification) starting for video_id=%s", video_id) _emit_event(video_id, "stage4_classification", "start", run_id=run_id) session = _get_sync_session() try: # Load key moments moments = ( session.execute( select(KeyMoment) .where(KeyMoment.source_video_id == video_id) .order_by(KeyMoment.start_time) ) .scalars() .all() ) if not moments: logger.info("Stage 4: No moments found for video_id=%s, skipping.", video_id) _store_classification_data(video_id, []) return video_id # Load canonical tags tags_data = _load_canonical_tags() taxonomy_text = _format_taxonomy_for_prompt(tags_data) system_prompt = _load_prompt("stage4_classification.txt", video_id=video_id) llm = _get_llm_client() model_override, modality = _get_stage_config(4) hard_limit = get_settings().llm_max_tokens_hard_limit # Batch moments for classification all_classifications = [] for batch_start in range(0, len(moments), _STAGE4_BATCH_SIZE): batch = moments[batch_start:batch_start + _STAGE4_BATCH_SIZE] result = _classify_moment_batch( batch, batch_start, taxonomy_text, system_prompt, llm, model_override, modality, hard_limit, video_id, run_id, ) # Reindex: batch uses 0-based indices, remap to global indices for cls in result.classifications: cls.moment_index += batch_start all_classifications.extend(result.classifications) # Apply content_type overrides and prepare classification data for stage 5 classification_data = [] for cls in all_classifications: if 0 <= cls.moment_index < len(moments): moment = moments[cls.moment_index] # Apply content_type override if provided if cls.content_type_override: try: moment.content_type = KeyMomentContentType(cls.content_type_override) except ValueError: pass classification_data.append({ "moment_id": str(moment.id), "topic_category": cls.topic_category.strip().title(), "topic_tags": cls.topic_tags, }) session.commit() # Store classification data in Redis for stage 5 _store_classification_data(video_id, 
classification_data) elapsed = time.monotonic() - start num_batches = (len(moments) + _STAGE4_BATCH_SIZE - 1) // _STAGE4_BATCH_SIZE _emit_event(video_id, "stage4_classification", "complete", run_id=run_id) logger.info( "Stage 4 (classification) completed for video_id=%s in %.1fs — " "%d moments classified in %d batch(es)", video_id, elapsed, len(classification_data), num_batches, ) return video_id except FileNotFoundError: raise except Exception as exc: session.rollback() _emit_event(video_id, "stage4_classification", "error", run_id=run_id, payload={"error": str(exc)}) logger.error("Stage 4 failed for video_id=%s: %s", video_id, exc) raise self.retry(exc=exc) finally: session.close() def _store_classification_data(video_id: str, data: list[dict]) -> None: """Store classification data in Redis (cache) and PostgreSQL (durable). Dual-write ensures classification data survives Redis TTL expiry or flush. Redis serves as the fast-path cache; PostgreSQL is the durable fallback. """ import redis settings = get_settings() # Redis: fast cache with 7-day TTL try: r = redis.Redis.from_url(settings.redis_url) key = f"chrysopedia:classification:{video_id}" r.set(key, json.dumps(data), ex=604800) # 7 days logger.info( "[CLASSIFY-STORE] Redis write: video_id=%s, %d entries, ttl=7d", video_id, len(data), ) except Exception as exc: logger.warning( "[CLASSIFY-STORE] Redis write failed for video_id=%s: %s", video_id, exc, ) # PostgreSQL: durable storage on SourceVideo.classification_data session = _get_sync_session() try: video = session.execute( select(SourceVideo).where(SourceVideo.id == video_id) ).scalar_one_or_none() if video: video.classification_data = data session.commit() logger.info( "[CLASSIFY-STORE] PostgreSQL write: video_id=%s, %d entries", video_id, len(data), ) else: logger.warning( "[CLASSIFY-STORE] Video not found for PostgreSQL write: %s", video_id, ) except Exception as exc: session.rollback() logger.warning( "[CLASSIFY-STORE] PostgreSQL write failed for 
video_id=%s: %s", video_id, exc, ) finally: session.close() def _load_classification_data(video_id: str) -> list[dict]: """Load classification data from Redis (fast path) or PostgreSQL (fallback). Tries Redis first. If the key has expired or Redis is unavailable, falls back to the durable SourceVideo.classification_data column. """ import redis settings = get_settings() # Try Redis first (fast path) try: r = redis.Redis.from_url(settings.redis_url) key = f"chrysopedia:classification:{video_id}" raw = r.get(key) if raw is not None: data = json.loads(raw) logger.info( "[CLASSIFY-LOAD] Source: redis, video_id=%s, %d entries", video_id, len(data), ) return data except Exception as exc: logger.warning( "[CLASSIFY-LOAD] Redis unavailable for video_id=%s: %s", video_id, exc, ) # Fallback to PostgreSQL logger.info("[CLASSIFY-LOAD] Redis miss, falling back to PostgreSQL for video_id=%s", video_id) session = _get_sync_session() try: video = session.execute( select(SourceVideo).where(SourceVideo.id == video_id) ).scalar_one_or_none() if video and video.classification_data: data = video.classification_data logger.info( "[CLASSIFY-LOAD] Source: postgresql, video_id=%s, %d entries", video_id, len(data), ) return data except Exception as exc: logger.warning( "[CLASSIFY-LOAD] PostgreSQL fallback failed for video_id=%s: %s", video_id, exc, ) finally: session.close() logger.warning( "[CLASSIFY-LOAD] No classification data found in Redis or PostgreSQL for video_id=%s", video_id, ) return [] def _get_git_commit_sha() -> str: """Resolve the git commit SHA used to build this image. Resolution order: 1. /app/.git-commit file (written during Docker build) 2. git rev-parse --short HEAD (local dev) 3. GIT_COMMIT_SHA env var / config setting 4. 
"unknown" """ # Docker build artifact git_commit_file = Path("/app/.git-commit") if git_commit_file.exists(): sha = git_commit_file.read_text(encoding="utf-8").strip() if sha and sha != "unknown": return sha # Local dev — run git try: result = subprocess.run( ["git", "rev-parse", "--short", "HEAD"], capture_output=True, text=True, timeout=5, ) if result.returncode == 0 and result.stdout.strip(): return result.stdout.strip() except (FileNotFoundError, subprocess.TimeoutExpired): pass # Config / env var fallback try: sha = get_settings().git_commit_sha if sha and sha != "unknown": return sha except Exception: pass return "unknown" def _capture_pipeline_metadata() -> dict: """Capture current pipeline configuration for version metadata. Returns a dict with model names, prompt file SHA-256 hashes, and stage modality settings. Handles missing prompt files gracefully. """ settings = get_settings() prompts_path = Path(settings.prompts_path) # Hash each prompt template file prompt_hashes: dict[str, str] = {} prompt_files = [ "stage2_segmentation.txt", "stage3_extraction.txt", "stage4_classification.txt", "stage5_synthesis.txt", ] for filename in prompt_files: filepath = prompts_path / filename try: content = filepath.read_bytes() prompt_hashes[filename] = hashlib.sha256(content).hexdigest() except FileNotFoundError: logger.warning("Prompt file not found for metadata capture: %s", filepath) prompt_hashes[filename] = "" except OSError as exc: logger.warning("Could not read prompt file %s: %s", filepath, exc) prompt_hashes[filename] = "" return { "git_commit_sha": _get_git_commit_sha(), "models": { "stage2": settings.llm_stage2_model, "stage3": settings.llm_stage3_model, "stage4": settings.llm_stage4_model, "stage5": settings.llm_stage5_model, "embedding": settings.embedding_model, }, "modalities": { "stage2": settings.llm_stage2_modality, "stage3": settings.llm_stage3_modality, "stage4": settings.llm_stage4_modality, "stage5": settings.llm_stage5_modality, }, "prompt_hashes": 
prompt_hashes, } # ── Stage 5: Synthesis ─────────────────────────────────────────────────────── def _serialize_body_sections(sections) -> list | dict | None: """Convert body_sections to JSON-serializable form for DB storage.""" if isinstance(sections, list): return [s.model_dump() if hasattr(s, 'model_dump') else s for s in sections] return sections def _compute_page_tags( moment_indices: list[int], moment_group: list[tuple], all_tags: set[str], ) -> list[str] | None: """Compute tags for a specific page from its linked moment indices. If moment_indices are available, collects tags only from those moments. Falls back to all_tags for the category group if no indices provided. """ if not moment_indices: return list(all_tags) if all_tags else None page_tags: set[str] = set() for idx in moment_indices: if 0 <= idx < len(moment_group): _, cls_info = moment_group[idx] page_tags.update(cls_info.get("topic_tags", [])) return list(page_tags) if page_tags else None def _build_moments_text( moment_group: list[tuple[KeyMoment, dict]], category: str, ) -> tuple[str, set[str]]: """Build the moments prompt text and collect all tags for a group of moments. Returns (moments_text, all_tags). 
""" moments_lines = [] all_tags: set[str] = set() for i, (m, cls_info) in enumerate(moment_group): tags = cls_info.get("topic_tags", []) all_tags.update(tags) moments_lines.append( f"[{i}] Title: {m.title}\n" f" Summary: {m.summary}\n" f" Content type: {m.content_type.value}\n" f" Time: {m.start_time:.1f}s - {m.end_time:.1f}s\n" f" Plugins: {', '.join(m.plugins) if m.plugins else 'none'}\n" f" Category: {category}\n" f" Tags: {', '.join(tags) if tags else 'none'}\n" f" Transcript excerpt: {(m.raw_transcript or '')[:300]}" ) return "\n\n".join(moments_lines), all_tags def _build_compose_user_prompt( existing_page: TechniquePage, existing_moments: list[KeyMoment], new_moments: list[tuple[KeyMoment, dict]], creator_name: str, ) -> str: """Build the user prompt for composing new moments into an existing page. Existing moments keep indices [0]-[N-1]. New moments get indices [N]-[N+M-1]. XML-tagged prompt structure matches test_harness.py build_compose_prompt(). """ category = existing_page.topic_category or "Uncategorized" # Serialize existing page to dict matching SynthesizedPage shape sq = existing_page.source_quality sq_value = sq.value if hasattr(sq, "value") else sq page_dict = { "title": existing_page.title, "slug": existing_page.slug, "topic_category": existing_page.topic_category, "summary": existing_page.summary, "body_sections": existing_page.body_sections, "signal_chains": existing_page.signal_chains, "plugins": existing_page.plugins, "source_quality": sq_value, } # Format existing moments [0]-[N-1] using _build_moments_text pattern # Existing moments don't have classification data — use empty dict existing_as_tuples = [(m, {}) for m in existing_moments] existing_text, _ = _build_moments_text(existing_as_tuples, category) # Format new moments [N]-[N+M-1] with offset indices n = len(existing_moments) new_lines = [] for i, (m, cls_info) in enumerate(new_moments): tags = cls_info.get("topic_tags", []) new_lines.append( f"[{n + i}] Title: {m.title}\n" f" Summary: 
{m.summary}\n" f" Content type: {m.content_type.value}\n" f" Time: {m.start_time:.1f}s - {m.end_time:.1f}s\n" f" Plugins: {', '.join(m.plugins) if m.plugins else 'none'}\n" f" Category: {category}\n" f" Tags: {', '.join(tags) if tags else 'none'}\n" f" Transcript excerpt: {(m.raw_transcript or '')[:300]}" ) new_text = "\n\n".join(new_lines) page_json = json.dumps(page_dict, indent=2, ensure_ascii=False, default=str) return ( f"\n{page_json}\n\n" f"\n{existing_text}\n\n" f"\n{new_text}\n\n" f"{creator_name}" ) def _compose_into_existing( existing_page: TechniquePage, existing_moments: list[KeyMoment], new_moment_group: list[tuple[KeyMoment, dict]], category: str, creator_name: str, system_prompt: str, llm: LLMClient, model_override: str | None, modality: str, hard_limit: int, video_id: str, run_id: str | None, ) -> SynthesisResult: """Compose new moments into an existing technique page via LLM. Loads the compose system prompt, builds the compose user prompt, and calls the LLM with the same retry/parse pattern as _synthesize_chunk(). 
""" compose_prompt = _load_prompt("stage5_compose.txt", video_id=video_id) user_prompt = _build_compose_user_prompt( existing_page, existing_moments, new_moment_group, creator_name, ) estimated_input = estimate_max_tokens( compose_prompt, user_prompt, stage="stage5_synthesis", hard_limit=hard_limit, ) logger.info( "Stage 5: Composing into '%s' — %d existing + %d new moments, max_tokens=%d", existing_page.slug, len(existing_moments), len(new_moment_group), estimated_input, ) raw = llm.complete( compose_prompt, user_prompt, response_model=SynthesisResult, on_complete=_make_llm_callback( video_id, "stage5_synthesis", system_prompt=compose_prompt, user_prompt=user_prompt, run_id=run_id, context_label=f"compose:{category}", request_params=_build_request_params( estimated_input, model_override, modality, "SynthesisResult", hard_limit, ), ), modality=modality, model_override=model_override, max_tokens=estimated_input, ) return _safe_parse_llm_response( raw, SynthesisResult, llm, compose_prompt, user_prompt, modality=modality, model_override=model_override, max_tokens=estimated_input, ) def _synthesize_chunk( chunk: list[tuple[KeyMoment, dict]], category: str, creator_name: str, system_prompt: str, llm: LLMClient, model_override: str | None, modality: str, hard_limit: int, video_id: str, run_id: str | None, chunk_label: str, ) -> SynthesisResult: """Run a single synthesis LLM call for a chunk of moments. Returns the parsed SynthesisResult. 
""" moments_text, _ = _build_moments_text(chunk, category) user_prompt = f"{creator_name}\n\n{moments_text}\n" estimated_input = estimate_max_tokens(system_prompt, user_prompt, stage="stage5_synthesis", hard_limit=hard_limit) logger.info( "Stage 5: Synthesizing %s — %d moments, max_tokens=%d", chunk_label, len(chunk), estimated_input, ) raw = llm.complete( system_prompt, user_prompt, response_model=SynthesisResult, on_complete=_make_llm_callback( video_id, "stage5_synthesis", system_prompt=system_prompt, user_prompt=user_prompt, run_id=run_id, context_label=chunk_label, request_params=_build_request_params(estimated_input, model_override, modality, "SynthesisResult", hard_limit), ), modality=modality, model_override=model_override, max_tokens=estimated_input, ) return _safe_parse_llm_response( raw, SynthesisResult, llm, system_prompt, user_prompt, modality=modality, model_override=model_override, max_tokens=estimated_input, ) def _slug_base(slug: str) -> str: """Extract the slug prefix before the creator name suffix for merge grouping. E.g. 'wavetable-sound-design-copycatt' → 'wavetable-sound-design' Also normalizes casing. """ return slug.lower().strip() def _merge_pages_by_slug( all_pages: list, creator_name: str, llm: LLMClient, model_override: str | None, modality: str, hard_limit: int, video_id: str, run_id: str | None, ) -> list: """Detect pages with the same slug across chunks and merge them via LLM. Pages with unique slugs pass through unchanged. Pages sharing a slug get sent to a merge prompt that combines them into one cohesive page. Returns the final list of SynthesizedPage objects. 
""" from pipeline.schemas import SynthesizedPage # Group pages by slug by_slug: dict[str, list] = defaultdict(list) for page in all_pages: by_slug[_slug_base(page.slug)].append(page) final_pages = [] for slug, pages_group in by_slug.items(): if len(pages_group) == 1: # Unique slug — no merge needed final_pages.append(pages_group[0]) continue # Multiple pages share this slug — merge via LLM logger.info( "Stage 5: Merging %d partial pages with slug '%s' for video_id=%s", len(pages_group), slug, video_id, ) # Serialize partial pages to JSON for the merge prompt pages_json = json.dumps( [p.model_dump() for p in pages_group], indent=2, ensure_ascii=False, ) merge_system_prompt = _load_prompt("stage5_merge.txt", video_id=video_id) merge_user_prompt = f"{creator_name}\n\n{pages_json}\n" max_tokens = estimate_max_tokens( merge_system_prompt, merge_user_prompt, stage="stage5_synthesis", hard_limit=hard_limit, ) logger.info( "Stage 5: Merge call for slug '%s' — %d partial pages, max_tokens=%d", slug, len(pages_group), max_tokens, ) raw = llm.complete( merge_system_prompt, merge_user_prompt, response_model=SynthesisResult, on_complete=_make_llm_callback( video_id, "stage5_synthesis", system_prompt=merge_system_prompt, user_prompt=merge_user_prompt, run_id=run_id, context_label=f"merge:{slug}", request_params=_build_request_params(max_tokens, model_override, modality, "SynthesisResult", hard_limit), ), modality=modality, model_override=model_override, max_tokens=max_tokens, ) merge_result = _safe_parse_llm_response( raw, SynthesisResult, llm, merge_system_prompt, merge_user_prompt, modality=modality, model_override=model_override, max_tokens=max_tokens, ) if merge_result.pages: final_pages.extend(merge_result.pages) logger.info( "Stage 5: Merge produced %d page(s) for slug '%s'", len(merge_result.pages), slug, ) else: # Merge returned nothing — fall back to keeping the partials logger.warning( "Stage 5: Merge returned 0 pages for slug '%s', keeping %d partials", slug, 
                len(pages_group),
            )
            final_pages.extend(pages_group)

    return final_pages


@celery_app.task(bind=True, max_retries=3, default_retry_delay=30)
def stage5_synthesis(self, video_id: str, run_id: str | None = None) -> str:
    """Synthesize technique pages from classified key moments.

    Groups moments by (creator, topic_category), calls the LLM to synthesize
    each group into a TechniquePage, creates/updates page rows, and links
    KeyMoments to their TechniquePage.

    For large category groups (exceeding synthesis_chunk_size), moments are
    split into chronological chunks, synthesized independently, then pages
    with matching slugs are merged via a dedicated merge LLM call.

    Sets processing_status to 'complete'. Returns the video_id for chain
    compatibility.
    """
    start = time.monotonic()
    logger.info("Stage 5 (synthesis) starting for video_id=%s", video_id)
    _emit_event(video_id, "stage5_synthesis", "start", run_id=run_id)
    settings = get_settings()
    chunk_size = settings.synthesis_chunk_size
    session = _get_sync_session()
    try:
        # Load video and moments
        video = session.execute(
            select(SourceVideo).where(SourceVideo.id == video_id)
        ).scalar_one()
        # Moments are ordered chronologically so chunking preserves narrative order
        moments = (
            session.execute(
                select(KeyMoment)
                .where(KeyMoment.source_video_id == video_id)
                .order_by(KeyMoment.start_time)
            )
            .scalars()
            .all()
        )
        # Resolve creator name for the LLM prompt
        creator = session.execute(
            select(Creator).where(Creator.id == video.creator_id)
        ).scalar_one_or_none()
        creator_name = creator.name if creator else "Unknown"
        if not moments:
            logger.info("Stage 5: No moments found for video_id=%s, skipping.", video_id)
            return video_id
        # Load classification data from stage 4
        classification_data = _load_classification_data(video_id)
        cls_by_moment_id = {c["moment_id"]: c for c in classification_data}
        # Group moments by topic_category (from classification)
        # Normalize category casing to prevent near-duplicate groups
        # (e.g., "Sound design" vs "Sound Design")
        groups: dict[str, list[tuple[KeyMoment, dict]]] = defaultdict(list)
        for moment in moments:
            cls_info = cls_by_moment_id.get(str(moment.id), {})
            category = cls_info.get("topic_category", "Uncategorized").strip().title()
            groups[category].append((moment, cls_info))
        system_prompt = _load_prompt("stage5_synthesis.txt", video_id=video_id)
        llm = _get_llm_client()
        model_override, modality = _get_stage_config(5)
        hard_limit = settings.llm_max_tokens_hard_limit
        logger.info("Stage 5 using model=%s, modality=%s", model_override or "default", modality)
        pages_created = 0
        for category, moment_group in groups.items():
            # Collect all tags across the full group (used for DB writes later)
            all_tags: set[str] = set()
            for _, cls_info in moment_group:
                all_tags.update(cls_info.get("topic_tags", []))
            # ── Compose-or-create detection ────────────────────────
            # Check if an existing technique page already covers this
            # creator + category combination (from a prior video run).
            compose_matches = session.execute(
                select(TechniquePage).where(
                    TechniquePage.creator_id == video.creator_id,
                    func.lower(TechniquePage.topic_category) == func.lower(category),
                )
            ).scalars().all()
            if len(compose_matches) > 1:
                logger.warning(
                    "Stage 5: Multiple existing pages (%d) match creator=%s category='%s'. "
                    "Using first match '%s'.",
                    len(compose_matches), video.creator_id, category, compose_matches[0].slug,
                )
            compose_target = compose_matches[0] if compose_matches else None
            if compose_target is not None:
                # Load existing moments linked to this page
                existing_moments = session.execute(
                    select(KeyMoment)
                    .where(KeyMoment.technique_page_id == compose_target.id)
                    .order_by(KeyMoment.start_time)
                ).scalars().all()
                logger.info(
                    "Stage 5: Composing into existing page '%s' "
                    "(%d existing moments + %d new moments)",
                    compose_target.slug, len(existing_moments), len(moment_group),
                )
                compose_result = _compose_into_existing(
                    compose_target,
                    existing_moments,
                    moment_group,
                    category,
                    creator_name,
                    system_prompt,
                    llm,
                    model_override,
                    modality,
                    hard_limit,
                    video_id,
                    run_id,
                )
                synthesized_pages = list(compose_result.pages)
            # ── Chunked synthesis with truncation recovery ─────────
            elif len(moment_group) <= chunk_size:
                # Small group — try single LLM call first
                try:
                    result = _synthesize_chunk(
                        moment_group,
                        category,
                        creator_name,
                        system_prompt,
                        llm,
                        model_override,
                        modality,
                        hard_limit,
                        video_id,
                        run_id,
                        f"category:{category}",
                    )
                    synthesized_pages = list(result.pages)
                    logger.info(
                        "Stage 5: category '%s' — %d moments, %d page(s) from single call",
                        category, len(moment_group), len(synthesized_pages),
                    )
                except LLMTruncationError:
                    # Output too large for model context — split in half and retry
                    logger.warning(
                        "Stage 5: category '%s' truncated with %d moments. "
                        "Splitting into smaller chunks and retrying.",
                        category, len(moment_group),
                    )
                    half = max(1, len(moment_group) // 2)
                    chunk_pages = []
                    for sub_start in range(0, len(moment_group), half):
                        sub_chunk = moment_group[sub_start:sub_start + half]
                        sub_label = f"category:{category} recovery-chunk:{sub_start // half + 1}"
                        sub_result = _synthesize_chunk(
                            sub_chunk,
                            category,
                            creator_name,
                            system_prompt,
                            llm,
                            model_override,
                            modality,
                            hard_limit,
                            video_id,
                            run_id,
                            sub_label,
                        )
                        # Reindex moment_indices to global offsets
                        # (the sub-chunk LLM call only sees local indices)
                        for p in sub_result.pages:
                            if p.moment_indices:
                                p.moment_indices = [idx + sub_start for idx in p.moment_indices]
                        chunk_pages.extend(sub_result.pages)
                    synthesized_pages = chunk_pages
                    logger.info(
                        "Stage 5: category '%s' — %d page(s) from recovery split",
                        category, len(synthesized_pages),
                    )
            else:
                # Large group — split into chunks, synthesize each, then merge
                num_chunks = (len(moment_group) + chunk_size - 1) // chunk_size
                logger.info(
                    "Stage 5: category '%s' has %d moments — splitting into %d chunks of ≤%d",
                    category, len(moment_group), num_chunks, chunk_size,
                )
                chunk_pages = []
                for chunk_idx in range(num_chunks):
                    chunk_start = chunk_idx * chunk_size
                    chunk_end = min(chunk_start + chunk_size, len(moment_group))
                    chunk = moment_group[chunk_start:chunk_end]
                    chunk_label = f"category:{category} chunk:{chunk_idx + 1}/{num_chunks}"
                    result = _synthesize_chunk(
                        chunk,
                        category,
                        creator_name,
                        system_prompt,
                        llm,
                        model_override,
                        modality,
                        hard_limit,
                        video_id,
                        run_id,
                        chunk_label,
                    )
                    chunk_pages.extend(result.pages)
                    logger.info(
                        "Stage 5: %s produced %d page(s)",
                        chunk_label, len(result.pages),
                    )
                # Merge pages with matching slugs across chunks
                logger.info(
                    "Stage 5: category '%s' — %d total pages from %d chunks, checking for merges",
                    category, len(chunk_pages), num_chunks,
                )
                synthesized_pages = _merge_pages_by_slug(
                    chunk_pages,
                    creator_name,
                    llm,
                    model_override,
                    modality,
                    hard_limit,
                    video_id,
                    run_id,
                )
                logger.info(
                    "Stage 5: category '%s' — %d final page(s) after merge",
category, len(synthesized_pages), ) # ── Persist pages to DB ────────────────────────────────── # Load prior pages from this video (snapshot taken before pipeline reset) prior_page_ids = _load_prior_pages(video_id) for page_data in synthesized_pages: page_moment_indices = getattr(page_data, "moment_indices", None) or [] existing = None # First: check by slug (most specific match) if existing is None: existing = session.execute( select(TechniquePage).where(TechniquePage.slug == page_data.slug) ).scalar_one_or_none() # Fallback: check prior pages from this video by creator + category # Use .first() since multiple pages may share a category if existing is None and prior_page_ids: existing = session.execute( select(TechniquePage).where( TechniquePage.id.in_(prior_page_ids), TechniquePage.creator_id == video.creator_id, func.lower(TechniquePage.topic_category) == func.lower(page_data.topic_category or category), ) ).scalars().first() if existing: logger.info( "Stage 5: Matched prior page '%s' (id=%s) by creator+category for video_id=%s", existing.slug, existing.id, video_id, ) if existing: # Snapshot existing content before overwriting try: sq = existing.source_quality sq_value = sq.value if hasattr(sq, 'value') else sq snapshot = { "title": existing.title, "slug": existing.slug, "topic_category": existing.topic_category, "topic_tags": existing.topic_tags, "summary": existing.summary, "body_sections": existing.body_sections, "signal_chains": existing.signal_chains, "plugins": existing.plugins, "source_quality": sq_value, } version_count = session.execute( select(func.count()).where( TechniquePageVersion.technique_page_id == existing.id ) ).scalar() version_number = version_count + 1 version = TechniquePageVersion( technique_page_id=existing.id, version_number=version_number, content_snapshot=snapshot, pipeline_metadata=_capture_pipeline_metadata(), ) session.add(version) logger.info( "Version snapshot v%d created for page slug=%s", version_number, existing.slug, ) 
                    except Exception as snap_exc:
                        logger.error(
                            "Failed to create version snapshot for page slug=%s: %s",
                            existing.slug, snap_exc,
                        )
                        # Best-effort versioning — continue with page update
                    # Update existing page
                    existing.title = page_data.title
                    existing.summary = page_data.summary
                    existing.body_sections = _serialize_body_sections(page_data.body_sections)
                    existing.signal_chains = page_data.signal_chains
                    existing.plugins = page_data.plugins if page_data.plugins else None
                    page_tags = _compute_page_tags(page_moment_indices, moment_group, all_tags)
                    existing.topic_tags = page_tags
                    existing.source_quality = page_data.source_quality
                    page = existing
                else:
                    # No match — create a fresh page row
                    page = TechniquePage(
                        creator_id=video.creator_id,
                        title=page_data.title,
                        slug=page_data.slug,
                        topic_category=page_data.topic_category or category,
                        topic_tags=_compute_page_tags(page_moment_indices, moment_group, all_tags),
                        summary=page_data.summary,
                        body_sections=_serialize_body_sections(page_data.body_sections),
                        signal_chains=page_data.signal_chains,
                        plugins=page_data.plugins if page_data.plugins else None,
                        source_quality=page_data.source_quality,
                    )
                    session.add(page)
                    session.flush()  # Get the page.id assigned
                    pages_created += 1
                # Set body_sections_format on every page (new or updated)
                page.body_sections_format = "v2"
                # Track contributing video via TechniquePageVideo
                # (ON CONFLICT DO NOTHING — link may already exist from a prior run)
                stmt = pg_insert(TechniquePageVideo.__table__).values(
                    technique_page_id=page.id,
                    source_video_id=video.id,
                ).on_conflict_do_nothing()
                session.execute(stmt)
                # Link moments to the technique page using moment_indices
                if page_moment_indices:
                    # LLM specified which moments belong to this page
                    for idx in page_moment_indices:
                        if 0 <= idx < len(moment_group):
                            moment_group[idx][0].technique_page_id = page.id
                elif len(synthesized_pages) == 1:
                    # Single page — link all moments (safe fallback)
                    for m, _ in moment_group:
                        m.technique_page_id = page.id
                else:
                    # Multiple pages but no moment_indices — log warning
                    logger.warning(
                        "Stage 5: page '%s' has no moment_indices and is one of %d pages "
                        "for category '%s'. Moments will not be linked to this page.",
                        page_data.slug, len(synthesized_pages), category,
                    )
        # Update processing_status
        video.processing_status = ProcessingStatus.complete
        session.commit()
        elapsed = time.monotonic() - start
        _emit_event(video_id, "stage5_synthesis", "complete", run_id=run_id)
        logger.info(
            "Stage 5 (synthesis) completed for video_id=%s in %.1fs — %d pages created/updated",
            video_id, elapsed, pages_created,
        )
        return video_id
    except FileNotFoundError:
        # Missing prompt template — not retryable, propagate as-is
        raise
    except Exception as exc:
        session.rollback()
        _emit_event(video_id, "stage5_synthesis", "error", run_id=run_id, payload={"error": str(exc)})
        logger.error("Stage 5 failed for video_id=%s: %s", video_id, exc)
        # Celery retry (max_retries=3, default_retry_delay=30)
        raise self.retry(exc=exc)
    finally:
        session.close()


# ── Heading slug helper (matches frontend TableOfContents.tsx slugify) ────────
def _slugify_heading(text: str) -> str:
    """Convert a heading string to a URL-friendly anchor slug.

    Must produce identical output to the frontend's slugify in
    ``frontend/src/components/TableOfContents.tsx``.
    """
    return re.sub(r"[^a-z0-9]+", "-", text.lower()).strip("-")


# ── Stage 6: Embed & Index ───────────────────────────────────────────────────
@celery_app.task(bind=True, max_retries=0)
def stage6_embed_and_index(self, video_id: str, run_id: str | None = None) -> str:
    """Generate embeddings for technique pages and key moments, then upsert to Qdrant.

    This is a non-blocking side-effect stage — failures are logged but do not
    fail the pipeline. Embeddings can be regenerated later.

    Does NOT update processing_status. Returns the video_id for chain
    compatibility.
""" start = time.monotonic() logger.info("Stage 6 (embed & index) starting for video_id=%s", video_id) settings = get_settings() session = _get_sync_session() try: # Load technique pages created for this video's moments moments = ( session.execute( select(KeyMoment) .where(KeyMoment.source_video_id == video_id) .order_by(KeyMoment.start_time) ) .scalars() .all() ) # Get unique technique page IDs from moments page_ids = {m.technique_page_id for m in moments if m.technique_page_id is not None} pages = [] if page_ids: pages = ( session.execute( select(TechniquePage).where(TechniquePage.id.in_(page_ids)) ) .scalars() .all() ) if not moments and not pages: logger.info("Stage 6: No moments or pages for video_id=%s, skipping.", video_id) if run_id: _finish_run(run_id, "complete") return video_id # Resolve creator names for enriched embedding text creator_ids = {p.creator_id for p in pages} creator_map: dict[str, str] = {} if creator_ids: creators = ( session.execute( select(Creator).where(Creator.id.in_(creator_ids)) ) .scalars() .all() ) creator_map = {str(c.id): c.name for c in creators} # Resolve creator name for key moments via source_video → creator video_ids = {m.source_video_id for m in moments} video_creator_map: dict[str, str] = {} if video_ids: rows = session.execute( select(SourceVideo.id, Creator.name, Creator.id.label("creator_id")) .join(Creator, SourceVideo.creator_id == Creator.id) .where(SourceVideo.id.in_(video_ids)) ).all() video_creator_map = {str(r[0]): r[1] for r in rows} video_creator_id_map = {str(r[0]): str(r[2]) for r in rows} embed_client = EmbeddingClient(settings) qdrant = QdrantManager(settings) # Ensure collection exists before upserting qdrant.ensure_collection() # ── Embed & upsert technique pages ─────────────────────────────── if pages: page_texts = [] page_dicts = [] for p in pages: creator_name = creator_map.get(str(p.creator_id), "") tags_joined = " ".join(p.topic_tags) if p.topic_tags else "" text = f"{creator_name} {p.title} 
{p.topic_category or ''} {tags_joined} {p.summary or ''}" page_texts.append(text.strip()) page_dicts.append({ "page_id": str(p.id), "creator_id": str(p.creator_id), "creator_name": creator_name, "title": p.title, "slug": p.slug, "topic_category": p.topic_category or "", "topic_tags": p.topic_tags or [], "summary": p.summary or "", }) page_vectors = embed_client.embed(page_texts) if page_vectors: qdrant.upsert_technique_pages(page_dicts, page_vectors) logger.info( "Stage 6: Upserted %d technique page vectors for video_id=%s", len(page_vectors), video_id, ) else: logger.warning( "Stage 6: Embedding returned empty for %d technique pages (video_id=%s). " "Skipping page upsert.", len(page_texts), video_id, ) # ── Embed & upsert key moments ─────────────────────────────────── if moments: # Build page_id → slug mapping for linking moments to technique pages page_id_to_slug: dict[str, str] = {} if pages: for p in pages: page_id_to_slug[str(p.id)] = p.slug moment_texts = [] moment_dicts = [] for m in moments: creator_name = video_creator_map.get(str(m.source_video_id), "") text = f"{creator_name} {m.title} {m.summary or ''}" moment_texts.append(text.strip()) tp_id = str(m.technique_page_id) if m.technique_page_id else "" moment_dicts.append({ "moment_id": str(m.id), "source_video_id": str(m.source_video_id), "creator_id": video_creator_id_map.get(str(m.source_video_id), ""), "technique_page_id": tp_id, "technique_page_slug": page_id_to_slug.get(tp_id, ""), "title": m.title, "creator_name": creator_name, "start_time": m.start_time, "end_time": m.end_time, "content_type": m.content_type.value, }) moment_vectors = embed_client.embed(moment_texts) if moment_vectors: qdrant.upsert_key_moments(moment_dicts, moment_vectors) logger.info( "Stage 6: Upserted %d key moment vectors for video_id=%s", len(moment_vectors), video_id, ) else: logger.warning( "Stage 6: Embedding returned empty for %d key moments (video_id=%s). 
" "Skipping moment upsert.", len(moment_texts), video_id, ) # ── Embed & upsert technique page sections (v2 only) ──────────── section_count = 0 v2_pages = [p for p in pages if getattr(p, "body_sections_format", "v1") == "v2"] for p in v2_pages: body_sections = p.body_sections if not isinstance(body_sections, list): continue creator_name = creator_map.get(str(p.creator_id), "") page_id_str = str(p.id) # Delete stale section points before re-upserting try: qdrant.delete_sections_by_page_id(page_id_str) except Exception as exc: logger.warning( "Stage 6: Failed to delete stale sections for page_id=%s: %s", page_id_str, exc, ) section_texts: list[str] = [] section_dicts: list[dict] = [] for section in body_sections: if not isinstance(section, dict): logger.warning( "Stage 6: Malformed section (not a dict) in page_id=%s. Skipping.", page_id_str, ) continue heading = section.get("heading", "") if not heading or not heading.strip(): continue section_anchor = _slugify_heading(heading) section_content = section.get("content", "") # Include subsection content for richer embedding subsection_parts: list[str] = [] for sub in section.get("subsections", []): if isinstance(sub, dict): sub_heading = sub.get("heading", "") sub_content = sub.get("content", "") if sub_heading: subsection_parts.append(f"{sub_heading}: {sub_content}") elif sub_content: subsection_parts.append(sub_content) embed_text = ( f"{creator_name} {p.title} — {heading}: " f"{section_content} {' '.join(subsection_parts)}" ).strip() section_texts.append(embed_text) section_dicts.append({ "page_id": page_id_str, "creator_id": str(p.creator_id), "creator_name": creator_name, "title": p.title, "slug": p.slug, "section_heading": heading, "section_anchor": section_anchor, "topic_category": p.topic_category or "", "topic_tags": p.topic_tags or [], "summary": (section_content or "")[:200], }) if section_texts: try: section_vectors = embed_client.embed(section_texts) if section_vectors: 
qdrant.upsert_technique_sections(section_dicts, section_vectors) section_count += len(section_vectors) else: logger.warning( "Stage 6: Embedding returned empty for %d sections of page_id=%s. Skipping.", len(section_texts), page_id_str, ) except Exception as exc: logger.warning( "Stage 6: Section embedding failed for page_id=%s: %s. Skipping.", page_id_str, exc, ) if section_count: logger.info( "Stage 6: Upserted %d technique section vectors for video_id=%s", section_count, video_id, ) elapsed = time.monotonic() - start logger.info( "Stage 6 (embed & index) completed for video_id=%s in %.1fs — " "%d pages, %d moments processed", video_id, elapsed, len(pages), len(moments), ) if run_id: _finish_run(run_id, "complete") return video_id except Exception as exc: # Non-blocking: log error but don't fail the pipeline logger.error( "Stage 6 failed for video_id=%s: %s. " "Pipeline continues — embeddings can be regenerated later.", video_id, exc, ) if run_id: _finish_run(run_id, "complete") # Run is still "complete" — stage6 is best-effort return video_id finally: session.close() def _snapshot_prior_pages(video_id: str) -> None: """Save existing technique_page_ids linked to this video before pipeline resets them. When a video is reprocessed, stage 3 deletes and recreates key_moments, breaking the link to technique pages. This snapshots the page IDs to Redis so stage 5 can find and update prior pages instead of creating duplicates. 
""" import redis session = _get_sync_session() try: # Find technique pages linked via this video's key moments rows = session.execute( select(KeyMoment.technique_page_id) .where( KeyMoment.source_video_id == video_id, KeyMoment.technique_page_id.isnot(None), ) .distinct() ).scalars().all() page_ids = [str(pid) for pid in rows] if page_ids: settings = get_settings() r = redis.Redis.from_url(settings.redis_url) key = f"chrysopedia:prior_pages:{video_id}" r.set(key, json.dumps(page_ids), ex=86400) logger.info( "Snapshot %d prior technique pages for video_id=%s: %s", len(page_ids), video_id, page_ids, ) else: logger.info("No prior technique pages for video_id=%s", video_id) finally: session.close() def _load_prior_pages(video_id: str) -> list[str]: """Load prior technique page IDs from Redis.""" import redis settings = get_settings() r = redis.Redis.from_url(settings.redis_url) key = f"chrysopedia:prior_pages:{video_id}" raw = r.get(key) if raw is None: return [] return json.loads(raw) # ── Stage completion detection for auto-resume ─────────────────────────────── # Ordered list of pipeline stages for resumability logic _PIPELINE_STAGES = [ "stage2_segmentation", "stage3_extraction", "stage4_classification", "stage5_synthesis", "stage6_embed_and_index", ] _STAGE_TASKS = { "stage2_segmentation": stage2_segmentation, "stage3_extraction": stage3_extraction, "stage4_classification": stage4_classification, "stage5_synthesis": stage5_synthesis, "stage6_embed_and_index": stage6_embed_and_index, } def _get_last_completed_stage(video_id: str) -> str | None: """Find the last stage that completed successfully for this video. Queries pipeline_events for the most recent run, looking for 'complete' events. Returns the stage name (e.g. 'stage3_extraction') or None if no stages have completed. 
""" session = _get_sync_session() try: # Find the most recent run for this video from models import PipelineRun latest_run = session.execute( select(PipelineRun) .where(PipelineRun.video_id == video_id) .order_by(PipelineRun.started_at.desc()) .limit(1) ).scalar_one_or_none() if latest_run is None: return None # Get all 'complete' events from that run completed_events = session.execute( select(PipelineEvent.stage) .where( PipelineEvent.run_id == str(latest_run.id), PipelineEvent.event_type == "complete", ) ).scalars().all() completed_set = set(completed_events) # Walk backwards through the ordered stages to find the last completed one last_completed = None for stage_name in _PIPELINE_STAGES: if stage_name in completed_set: last_completed = stage_name else: break # Stop at first gap — stages must be sequential if last_completed: logger.info( "Auto-resume: video_id=%s last completed stage=%s (run_id=%s)", video_id, last_completed, latest_run.id, ) return last_completed finally: session.close() # ── Orchestrator ───────────────────────────────────────────────────────────── @celery_app.task def mark_pipeline_error(request, exc, traceback, video_id: str, run_id: str | None = None) -> None: """Error callback — marks video as errored when a pipeline stage fails.""" logger.error("Pipeline failed for video_id=%s: %s", video_id, exc) _set_error_status(video_id, "pipeline", exc) if run_id: _finish_run(run_id, "error", error_stage="pipeline") def _create_run(video_id: str, trigger: str) -> str: """Create a PipelineRun and return its id.""" from models import PipelineRun, PipelineRunTrigger session = _get_sync_session() try: # Compute run_number: max existing + 1 from sqlalchemy import func as sa_func max_num = session.execute( select(sa_func.coalesce(sa_func.max(PipelineRun.run_number), 0)) .where(PipelineRun.video_id == video_id) ).scalar() or 0 run = PipelineRun( video_id=video_id, run_number=max_num + 1, trigger=PipelineRunTrigger(trigger), ) session.add(run) 
session.commit() run_id = str(run.id) return run_id finally: session.close() def _finish_run(run_id: str, status: str, error_stage: str | None = None) -> None: """Update a PipelineRun's status and finished_at.""" from models import PipelineRun, PipelineRunStatus, _now session = _get_sync_session() try: run = session.execute( select(PipelineRun).where(PipelineRun.id == run_id) ).scalar_one_or_none() if run: run.status = PipelineRunStatus(status) run.finished_at = _now() if error_stage: run.error_stage = error_stage # Aggregate total tokens from events total = session.execute( select(func.coalesce(func.sum(PipelineEvent.total_tokens), 0)) .where(PipelineEvent.run_id == run_id) ).scalar() or 0 run.total_tokens = total session.commit() except Exception as exc: logger.warning("Failed to finish run %s: %s", run_id, exc) finally: session.close() @celery_app.task def run_pipeline(video_id: str, trigger: str = "manual") -> str: """Orchestrate the full pipeline (stages 2-6) with auto-resume. For error/processing status, queries pipeline_events to find the last stage that completed successfully and resumes from the next stage. This avoids re-running expensive LLM stages that already succeeded. For clean_reprocess trigger, always starts from stage 2. Returns the video_id. 
""" logger.info("run_pipeline starting for video_id=%s", video_id) session = _get_sync_session() try: video = session.execute( select(SourceVideo).where(SourceVideo.id == video_id) ).scalar_one_or_none() if video is None: logger.error("run_pipeline: video_id=%s not found", video_id) raise ValueError(f"Video not found: {video_id}") status = video.processing_status logger.info( "run_pipeline: video_id=%s current status=%s", video_id, status.value ) finally: session.close() if status == ProcessingStatus.complete: logger.info( "run_pipeline: video_id=%s already at status=%s, nothing to do.", video_id, status.value, ) return video_id # Snapshot prior technique pages before pipeline resets key_moments _snapshot_prior_pages(video_id) # Create a pipeline run record run_id = _create_run(video_id, trigger) logger.info("run_pipeline: created run_id=%s for video_id=%s (trigger=%s)", run_id, video_id, trigger) # Determine which stages to run resume_from_idx = 0 # Default: start from stage 2 if trigger != "clean_reprocess" and status in (ProcessingStatus.processing, ProcessingStatus.error): # Try to resume from where we left off last_completed = _get_last_completed_stage(video_id) if last_completed and last_completed in _PIPELINE_STAGES: completed_idx = _PIPELINE_STAGES.index(last_completed) resume_from_idx = completed_idx + 1 if resume_from_idx >= len(_PIPELINE_STAGES): logger.info( "run_pipeline: all stages already completed for video_id=%s", video_id, ) return video_id stages_to_run = _PIPELINE_STAGES[resume_from_idx:] logger.info( "run_pipeline: video_id=%s will run stages: %s (resume_from_idx=%d)", video_id, stages_to_run, resume_from_idx, ) # Run stages inline (synchronously) so each video completes fully # before the worker picks up the next queued video. # This replaces the previous celery_chain dispatch which caused # interleaved execution when multiple videos were queued. 
    if stages_to_run:
        # Mark as processing before starting
        session = _get_sync_session()
        try:
            video = session.execute(
                select(SourceVideo).where(SourceVideo.id == video_id)
            ).scalar_one()
            video.processing_status = ProcessingStatus.processing
            session.commit()
        finally:
            session.close()
        logger.info(
            "run_pipeline: executing %d stages inline for video_id=%s (run_id=%s, starting at %s)",
            len(stages_to_run), video_id, run_id, stages_to_run[0],
        )
        try:
            for stage_name in stages_to_run:
                task_func = _STAGE_TASKS[stage_name]
                # Call the task directly — runs synchronously in this worker
                # process. bind=True tasks receive the task instance as self
                # automatically when called this way.
                task_func(video_id, run_id=run_id)
        except Exception as exc:
            # stage_name is the for-loop variable from the failed iteration
            # (Python keeps the loop variable bound after the loop exits)
            logger.error(
                "run_pipeline: stage %s failed for video_id=%s: %s",
                stage_name, video_id, exc,
            )
            _set_error_status(video_id, stage_name, exc)
            if run_id:
                _finish_run(run_id, "error", error_stage=stage_name)
            raise
    return video_id


# ── Single-Stage Re-Run ─────────────────────────────────────────────────────
@celery_app.task
def run_single_stage(
    video_id: str,
    stage_name: str,
    trigger: str = "stage_rerun",
    prompt_override: str | None = None,
) -> str:
    """Re-run a single pipeline stage without running predecessors.

    Designed for fast prompt iteration — especially stage 5 synthesis.
    Bypasses the processing_status==complete guard, creates a proper
    PipelineRun record, and restores status on completion.

    If ``prompt_override`` is provided, it is stored in Redis as a per-video
    override that ``_load_prompt`` reads before falling back to the on-disk
    template. The override is cleaned up after the stage runs.

    Returns the video_id.
    """
    import redis as redis_lib

    logger.info(
        "[RERUN] Starting single-stage re-run: video_id=%s, stage=%s, trigger=%s",
        video_id, stage_name, trigger,
    )
    # Validate stage name
    if stage_name not in _PIPELINE_STAGES:
        raise ValueError(
            f"[RERUN] Invalid stage '{stage_name}'. "
            f"Valid stages: {_PIPELINE_STAGES}"
        )
    # Validate video exists
    session = _get_sync_session()
    try:
        video = session.execute(
            select(SourceVideo).where(SourceVideo.id == video_id)
        ).scalar_one_or_none()
        if video is None:
            raise ValueError(f"[RERUN] Video not found: {video_id}")
        # NOTE(review): original_status is captured but never used afterwards —
        # the status is restored to 'complete' unconditionally on success;
        # confirm whether restoring the original status was intended.
        original_status = video.processing_status
    finally:
        session.close()
    # Validate prerequisites for the requested stage
    prereq_ok, prereq_msg = _check_stage_prerequisites(video_id, stage_name)
    if not prereq_ok:
        logger.error("[RERUN] Prerequisite check failed: %s", prereq_msg)
        raise ValueError(f"[RERUN] Prerequisites not met: {prereq_msg}")
    logger.info("[RERUN] Prerequisite check passed: %s", prereq_msg)
    # Store prompt override in Redis if provided
    override_key = None
    if prompt_override:
        settings = get_settings()
        try:
            r = redis_lib.Redis.from_url(settings.redis_url)
            # Map stage name to its prompt template
            stage_prompt_map = {
                "stage2_segmentation": "stage2_segmentation.txt",
                "stage3_extraction": "stage3_extraction.txt",
                "stage4_classification": "stage4_classification.txt",
                "stage5_synthesis": "stage5_synthesis.txt",
            }
            template = stage_prompt_map.get(stage_name)
            if template:
                override_key = f"chrysopedia:prompt_override:{video_id}:{template}"
                r.set(override_key, prompt_override, ex=3600)  # 1-hour TTL
                logger.info(
                    "[RERUN] Prompt override stored: key=%s (%d chars, first 100: %s)",
                    override_key, len(prompt_override), prompt_override[:100],
                )
        except Exception as exc:
            # Best-effort — a failed override store falls back to on-disk prompt
            logger.warning("[RERUN] Failed to store prompt override: %s", exc)
    # Snapshot prior pages (needed for stage 5 page matching)
    if stage_name in ("stage5_synthesis",):
        _snapshot_prior_pages(video_id)
    # Create pipeline run record
    run_id = _create_run(video_id, trigger)
    logger.info("[RERUN] Created run_id=%s", run_id)
    # Temporarily set status to processing
    session = _get_sync_session()
    try:
        video = session.execute(
            select(SourceVideo).where(SourceVideo.id == video_id)
        ).scalar_one()
        video.processing_status = ProcessingStatus.processing
session.commit() finally: session.close() # Run the single stage start = time.monotonic() try: task_func = _STAGE_TASKS[stage_name] task_func(video_id, run_id=run_id) elapsed = time.monotonic() - start logger.info( "[RERUN] Stage %s completed: %.1fs, video_id=%s", stage_name, elapsed, video_id, ) _finish_run(run_id, "complete") # Restore status to complete session = _get_sync_session() try: video = session.execute( select(SourceVideo).where(SourceVideo.id == video_id) ).scalar_one() video.processing_status = ProcessingStatus.complete session.commit() logger.info("[RERUN] Status restored to complete") finally: session.close() except Exception as exc: elapsed = time.monotonic() - start logger.error( "[RERUN] Stage %s FAILED after %.1fs: %s", stage_name, elapsed, exc, ) _set_error_status(video_id, stage_name, exc) _finish_run(run_id, "error", error_stage=stage_name) raise finally: # Clean up prompt override from Redis if override_key: try: settings = get_settings() r = redis_lib.Redis.from_url(settings.redis_url) r.delete(override_key) logger.info("[RERUN] Prompt override cleaned up: %s", override_key) except Exception as exc: logger.warning("[RERUN] Failed to clean up prompt override: %s", exc) return video_id def _check_stage_prerequisites(video_id: str, stage_name: str) -> tuple[bool, str]: """Validate that prerequisite data exists for a stage re-run. Returns (ok, message) where message describes what was found or missing. 
""" session = _get_sync_session() try: if stage_name == "stage2_segmentation": # Needs transcript segments count = session.execute( select(func.count(TranscriptSegment.id)) .where(TranscriptSegment.source_video_id == video_id) ).scalar() or 0 if count == 0: return False, "No transcript segments found" return True, f"transcript_segments={count}" if stage_name == "stage3_extraction": # Needs transcript segments with topic_labels count = session.execute( select(func.count(TranscriptSegment.id)) .where( TranscriptSegment.source_video_id == video_id, TranscriptSegment.topic_label.isnot(None), ) ).scalar() or 0 if count == 0: return False, "No labeled transcript segments (stage 2 must complete first)" return True, f"labeled_segments={count}" if stage_name == "stage4_classification": # Needs key moments count = session.execute( select(func.count(KeyMoment.id)) .where(KeyMoment.source_video_id == video_id) ).scalar() or 0 if count == 0: return False, "No key moments found (stage 3 must complete first)" return True, f"key_moments={count}" if stage_name == "stage5_synthesis": # Needs key moments + classification data km_count = session.execute( select(func.count(KeyMoment.id)) .where(KeyMoment.source_video_id == video_id) ).scalar() or 0 if km_count == 0: return False, "No key moments found (stages 2-3 must complete first)" cls_data = _load_classification_data(video_id) cls_source = "redis+pg" if not cls_data: return False, f"No classification data found (stage 4 must complete first), key_moments={km_count}" return True, f"key_moments={km_count}, classification_entries={len(cls_data)}" if stage_name == "stage6_embed_and_index": return True, "stage 6 is non-blocking and always runs" return False, f"Unknown stage: {stage_name}" finally: session.close() # ── Avatar Fetching ───────────────────────────────────────────────────────── @celery_app.task def fetch_creator_avatar(creator_id: str) -> dict: """Fetch avatar for a single creator from TheAudioDB. 
Looks up the creator by ID, calls TheAudioDB, and updates the avatar_url/avatar_source/avatar_fetched_at columns if a confident match is found. Returns a status dict. """ import sys from datetime import datetime, timezone # Ensure /app is on sys.path for forked Celery workers if "/app" not in sys.path: sys.path.insert(0, "/app") from services.avatar import lookup_avatar session = _get_sync_session() try: creator = session.execute( select(Creator).where(Creator.id == creator_id) ).scalar_one_or_none() if not creator: return {"status": "error", "detail": f"Creator {creator_id} not found"} result = lookup_avatar(creator.name, creator.genres) if result: creator.avatar_url = result.url creator.avatar_source = result.source creator.avatar_fetched_at = datetime.now(timezone.utc) session.commit() return { "status": "found", "creator": creator.name, "avatar_url": result.url, "confidence": result.confidence, "matched_artist": result.artist_name, } else: creator.avatar_source = "generated" creator.avatar_fetched_at = datetime.now(timezone.utc) session.commit() return { "status": "not_found", "creator": creator.name, "detail": "No confident match from TheAudioDB", } except Exception as exc: session.rollback() logger.error("Avatar fetch failed for creator %s: %s", creator_id, exc) return {"status": "error", "detail": str(exc)} finally: session.close() # ── Highlight Detection ────────────────────────────────────────────────────── @celery_app.task(bind=True, max_retries=3, default_retry_delay=30) def stage_highlight_detection(self, video_id: str, run_id: str | None = None) -> str: """Score all KeyMoments for a video and upsert HighlightCandidates. For each KeyMoment belonging to the video, runs the heuristic scorer and bulk-upserts results into highlight_candidates (INSERT ON CONFLICT UPDATE). Returns the video_id for chain compatibility. 
""" from pipeline.highlight_scorer import extract_word_timings, score_moment start = time.monotonic() logger.info("Highlight detection starting for video_id=%s", video_id) _emit_event(video_id, "highlight_detection", "start", run_id=run_id) session = _get_sync_session() try: # ------------------------------------------------------------------ # Load transcript data once for the entire video (word-level timing) # ------------------------------------------------------------------ transcript_data: list | None = None source_video = session.execute( select(SourceVideo).where(SourceVideo.id == video_id) ).scalar_one_or_none() if source_video and source_video.transcript_path: transcript_file = source_video.transcript_path try: with open(transcript_file, "r") as fh: raw = json.load(fh) # Accept both {"segments": [...]} and bare [...] if isinstance(raw, dict): transcript_data = raw.get("segments", raw.get("results", [])) elif isinstance(raw, list): transcript_data = raw else: transcript_data = None if transcript_data: logger.info( "Loaded transcript for video_id=%s (%d segments)", video_id, len(transcript_data), ) except FileNotFoundError: logger.warning( "Transcript file not found for video_id=%s: %s", video_id, transcript_file, ) except (json.JSONDecodeError, OSError) as io_exc: logger.warning( "Failed to load transcript for video_id=%s: %s", video_id, io_exc, ) else: logger.info( "No transcript_path for video_id=%s — audio proxy signals will be neutral", video_id, ) moments = ( session.execute( select(KeyMoment) .where(KeyMoment.source_video_id == video_id) .order_by(KeyMoment.start_time) ) .scalars() .all() ) if not moments: logger.info( "Highlight detection: No key moments for video_id=%s, skipping.", video_id, ) _emit_event( video_id, "highlight_detection", "complete", run_id=run_id, payload={"candidates": 0}, ) return video_id candidate_count = 0 for moment in moments: try: # Extract word-level timings for this moment's window word_timings = None if transcript_data: 
word_timings = extract_word_timings( transcript_data, moment.start_time, moment.end_time, ) or None # empty list → None for neutral fallback result = score_moment( start_time=moment.start_time, end_time=moment.end_time, content_type=moment.content_type.value if moment.content_type else None, summary=moment.summary, plugins=moment.plugins, raw_transcript=moment.raw_transcript, source_quality=None, # filled below if technique_page loaded video_content_type=None, # filled below if source_video loaded word_timings=word_timings, ) except Exception as score_exc: logger.warning( "Highlight detection: score_moment failed for moment %s: %s", moment.id, score_exc, ) result = { "score": 0.0, "score_breakdown": {}, "duration_secs": max(0.0, moment.end_time - moment.start_time), } stmt = pg_insert(HighlightCandidate).values( key_moment_id=moment.id, source_video_id=moment.source_video_id, score=result["score"], score_breakdown=result["score_breakdown"], duration_secs=result["duration_secs"], ) stmt = stmt.on_conflict_do_update( constraint="highlight_candidates_key_moment_id_key", set_={ "score": stmt.excluded.score, "score_breakdown": stmt.excluded.score_breakdown, "duration_secs": stmt.excluded.duration_secs, "updated_at": func.now(), }, ) session.execute(stmt) candidate_count += 1 session.commit() elapsed = time.monotonic() - start _emit_event( video_id, "highlight_detection", "complete", run_id=run_id, payload={"candidates": candidate_count}, ) logger.info( "Highlight detection completed for video_id=%s in %.1fs — %d candidates upserted", video_id, elapsed, candidate_count, ) return video_id except Exception as exc: session.rollback() _emit_event( video_id, "highlight_detection", "error", run_id=run_id, payload={"error": str(exc)}, ) logger.error("Highlight detection failed for video_id=%s: %s", video_id, exc) raise self.retry(exc=exc) finally: session.close() # ── Personality profile extraction ─────────────────────────────────────────── def _sample_creator_transcripts( 
    moments: list,
    creator_id: str,
    max_chars: int = 40000,
) -> tuple[str, int]:
    """Sample transcripts from a creator's key moments, respecting size tiers.

    - Small (<20K chars total): use all text.
    - Medium (20K-60K): first 300 chars from each moment, up to budget.
    - Large (>60K): random sample seeded by creator_id, attempts topic
      diversity via Redis classification data.

    Returns (sampled_text, total_char_count).
    """
    import random

    # Keep (video_id, text) pairs; video_id is needed for the Redis
    # classification lookup in the large tier.
    transcripts = [
        (m.source_video_id, m.raw_transcript)
        for m in moments
        if m.raw_transcript and m.raw_transcript.strip()
    ]
    if not transcripts:
        return ("", 0)

    total_chars = sum(len(t) for _, t in transcripts)

    # Small: use everything
    if total_chars <= 20_000:
        text = "\n\n---\n\n".join(t for _, t in transcripts)
        return (text, total_chars)

    # Medium: first 300 chars from each moment
    if total_chars <= 60_000:
        excerpts = []
        budget = max_chars
        for _, t in transcripts:
            chunk = t[:300]
            if budget - len(chunk) < 0:
                break
            excerpts.append(chunk)
            budget -= len(chunk)
        text = "\n\n---\n\n".join(excerpts)
        return (text, total_chars)

    # Large: random sample with optional topic diversity from Redis.
    # Maps topic_category → list of moment-id strings.
    topic_map: dict[str, list[str]] = {}
    try:
        import redis as _redis

        settings = get_settings()
        r = _redis.from_url(settings.redis_url)
        video_ids = {str(vid) for vid, _ in transcripts}
        for vid in video_ids:
            raw = r.get(f"chrysopedia:classification:{vid}")
            if raw:
                classification = json.loads(raw)
                if isinstance(classification, list):
                    for item in classification:
                        cat = item.get("topic_category", "unknown")
                        moment_id = item.get("moment_id")
                        if moment_id:
                            topic_map.setdefault(cat, []).append(moment_id)
        r.close()
    except Exception:
        # Fall back to random sampling without topic diversity
        pass

    # Seeded by creator_id so sampling is deterministic per creator.
    rng = random.Random(creator_id)

    if topic_map:
        # Interleave from different categories for diversity: round-robin
        # one moment id per category until every list is drained.
        ordered = []
        cat_lists = list(topic_map.values())
        rng.shuffle(cat_lists)
        idx = 0  # NOTE(review): unused — left over from an earlier loop shape
        while any(cat_lists):
            for cat in cat_lists:
                if cat:
                    ordered.append(cat.pop(0))
            cat_lists = [c for c in cat_lists if c]

        # Map moment IDs back to transcripts
        moment_lookup = {str(m.id): m.raw_transcript for m in moments if m.raw_transcript}
        diverse_transcripts = [
            moment_lookup[mid] for mid in ordered if mid in moment_lookup
        ]
        if diverse_transcripts:
            transcripts_list = diverse_transcripts
        else:
            transcripts_list = [t for _, t in transcripts]
    else:
        transcripts_list = [t for _, t in transcripts]
        rng.shuffle(transcripts_list)

    # Take 600-char excerpts in the chosen order until the budget runs out.
    excerpts = []
    budget = max_chars
    for t in transcripts_list:
        chunk = t[:600]
        if budget - len(chunk) < 0:
            break
        excerpts.append(chunk)
        budget -= len(chunk)

    text = "\n\n---\n\n".join(excerpts)
    return (text, total_chars)


@celery_app.task(bind=True, max_retries=2, default_retry_delay=60)
def extract_personality_profile(self, creator_id: str) -> str:
    """Extract a personality profile from a creator's transcripts via LLM.

    Aggregates and samples transcripts from all of the creator's key moments,
    sends them to the LLM with the personality_extraction prompt, validates
    the response, and stores the profile as JSONB on Creator.personality_profile.

    Returns the creator_id for chain compatibility.
""" from datetime import datetime, timezone start = time.monotonic() logger.info("Personality extraction starting for creator_id=%s", creator_id) _emit_event(creator_id, "personality_extraction", "start") session = _get_sync_session() try: # Load creator creator = session.execute( select(Creator).where(Creator.id == creator_id) ).scalar_one_or_none() if not creator: logger.error("Creator not found: %s", creator_id) _emit_event( creator_id, "personality_extraction", "error", payload={"error": "creator_not_found"}, ) return creator_id # Load all key moments with transcripts for this creator moments = ( session.execute( select(KeyMoment) .join(SourceVideo, KeyMoment.source_video_id == SourceVideo.id) .where(SourceVideo.creator_id == creator.id) .where(KeyMoment.raw_transcript.isnot(None)) ) .scalars() .all() ) if not moments: logger.warning( "No transcripts found for creator_id=%s (%s), skipping extraction", creator_id, creator.name, ) _emit_event( creator_id, "personality_extraction", "complete", payload={"skipped": True, "reason": "no_transcripts"}, ) return creator_id # Sample transcripts sampled_text, total_chars = _sample_creator_transcripts( moments, creator_id, ) if not sampled_text.strip(): logger.warning( "Empty transcript sample for creator_id=%s, skipping", creator_id, ) _emit_event( creator_id, "personality_extraction", "complete", payload={"skipped": True, "reason": "empty_sample"}, ) return creator_id # Load prompt and call LLM system_prompt = _load_prompt("personality_extraction.txt") user_prompt = ( f"Creator: {creator.name}\n\n" f"Transcript excerpts ({len(moments)} moments, {total_chars} total chars, " f"sample below):\n\n{sampled_text}" ) llm = _get_llm_client() callback = _make_llm_callback( creator_id, "personality_extraction", system_prompt=system_prompt, user_prompt=user_prompt, ) response = llm.complete( system_prompt=system_prompt, user_prompt=user_prompt, response_model=object, # triggers JSON mode on_complete=callback, ) # Parse and validate 
from schemas import PersonalityProfile as ProfileValidator try: raw_profile = json.loads(str(response)) except json.JSONDecodeError as jde: logger.warning( "LLM returned invalid JSON for creator_id=%s, retrying: %s", creator_id, jde, ) raise self.retry(exc=jde) try: validated = ProfileValidator.model_validate(raw_profile) except ValidationError as ve: logger.warning( "LLM profile failed validation for creator_id=%s, retrying: %s", creator_id, ve, ) raise self.retry(exc=ve) # Build final profile dict with metadata profile_dict = validated.model_dump() profile_dict["_metadata"] = { "extracted_at": datetime.now(timezone.utc).replace(tzinfo=None).isoformat(), "transcript_sample_size": total_chars, "moments_count": len(moments), "model_used": getattr(response, "finish_reason", None) or "unknown", } # Low sample size note if total_chars < 500: profile_dict["_metadata"]["low_sample_size"] = True # Store on creator creator.personality_profile = profile_dict session.commit() elapsed = time.monotonic() - start _emit_event( creator_id, "personality_extraction", "complete", duration_ms=int(elapsed * 1000), payload={ "moments_count": len(moments), "transcript_chars": total_chars, "sample_chars": len(sampled_text), }, ) logger.info( "Personality extraction completed for creator_id=%s (%s) in %.1fs — " "%d moments, %d chars sampled", creator_id, creator.name, elapsed, len(moments), len(sampled_text), ) return creator_id except Exception as exc: if isinstance(exc, (self.MaxRetriesExceededError,)): raise session.rollback() _emit_event( creator_id, "personality_extraction", "error", payload={"error": str(exc)[:500]}, ) logger.error( "Personality extraction failed for creator_id=%s: %s", creator_id, exc, ) raise self.retry(exc=exc) finally: session.close() # ── Stage: Shorts Generation ───────────────────────────────────────────────── @celery_app.task(bind=True, max_retries=1, default_retry_delay=60) def stage_generate_shorts(self, highlight_candidate_id: str, captions: bool = True) 
-> str: """Generate video shorts for an approved highlight candidate. Creates one GeneratedShort row per FormatPreset, extracts the clip via ffmpeg, uploads to MinIO, and updates status. Each preset is independent — a failure on one does not block the others. Args: highlight_candidate_id: UUID string of the approved highlight. captions: Whether to generate and burn in ASS subtitles (default True). Returns the highlight_candidate_id on completion. """ from pipeline.shorts_generator import PRESETS, extract_clip_with_template, resolve_video_path from pipeline.caption_generator import generate_ass_captions, write_ass_file from pipeline.card_renderer import parse_template_config, render_card_to_file from models import FormatPreset, GeneratedShort, ShortStatus, SourceVideo start = time.monotonic() session = _get_sync_session() settings = get_settings() try: # ── Load highlight with joined relations ──────────────────────── highlight = session.execute( select(HighlightCandidate) .where(HighlightCandidate.id == highlight_candidate_id) ).scalar_one_or_none() if highlight is None: logger.error( "Highlight candidate not found: %s", highlight_candidate_id, ) return highlight_candidate_id if highlight.status.value != "approved": logger.warning( "Highlight %s status is %s, expected approved — skipping", highlight_candidate_id, highlight.status.value, ) return highlight_candidate_id # Check for already-processing shorts (reject duplicate runs) existing_processing = session.execute( select(func.count()) .where(GeneratedShort.highlight_candidate_id == highlight_candidate_id) .where(GeneratedShort.status == ShortStatus.processing) ).scalar() if existing_processing and existing_processing > 0: logger.warning( "Highlight %s already has %d processing shorts — rejecting duplicate", highlight_candidate_id, existing_processing, ) return highlight_candidate_id # Eager-load relations key_moment = highlight.key_moment source_video = highlight.source_video # ── Resolve video file path 
───────────────────────────────────── try: video_path = resolve_video_path( settings.video_source_path, source_video.file_path, ) except FileNotFoundError as fnf: logger.error( "Video file missing for highlight %s: %s", highlight_candidate_id, fnf, ) # Mark all presets as failed for preset in FormatPreset: spec = PRESETS[preset] short = GeneratedShort( highlight_candidate_id=highlight_candidate_id, format_preset=preset, width=spec.width, height=spec.height, status=ShortStatus.failed, error_message=str(fnf), ) session.add(short) session.commit() return highlight_candidate_id # ── Compute effective start/end (trim overrides) ──────────────── clip_start = highlight.trim_start if highlight.trim_start is not None else key_moment.start_time clip_end = highlight.trim_end if highlight.trim_end is not None else key_moment.end_time logger.info( "Generating shorts for highlight=%s video=%s [%.1f–%.1f]s", highlight_candidate_id, source_video.file_path, clip_start, clip_end, ) # ── Generate captions from transcript (if available and requested) ─ ass_path: Path | None = None captions_ok = False if not captions: logger.info( "Captions disabled for highlight=%s — skipping caption generation", highlight_candidate_id, ) else: try: transcript_data: list | None = None if source_video.transcript_path: try: with open(source_video.transcript_path, "r") as fh: raw = json.load(fh) if isinstance(raw, dict): transcript_data = raw.get("segments", raw.get("results", [])) elif isinstance(raw, list): transcript_data = raw except (FileNotFoundError, json.JSONDecodeError, OSError) as io_exc: logger.warning( "Failed to load transcript for captions highlight=%s: %s", highlight_candidate_id, io_exc, ) if transcript_data: from pipeline.highlight_scorer import extract_word_timings word_timings = extract_word_timings(transcript_data, clip_start, clip_end) if word_timings: ass_content = generate_ass_captions(word_timings, clip_start) ass_path = write_ass_file( ass_content, 
Path(f"/tmp/captions_{highlight_candidate_id}.ass"), ) captions_ok = True logger.info( "Generated captions for highlight=%s (%d words)", highlight_candidate_id, len(word_timings), ) else: logger.warning( "No word timings in transcript window [%.1f–%.1f]s for highlight=%s — proceeding without captions", clip_start, clip_end, highlight_candidate_id, ) else: logger.info( "No transcript available for highlight=%s — proceeding without captions", highlight_candidate_id, ) except Exception as cap_exc: logger.warning( "Caption generation failed for highlight=%s: %s — proceeding without captions", highlight_candidate_id, cap_exc, ) # ── Load creator template config (if available) ───────────────── intro_path: Path | None = None outro_path: Path | None = None try: creator = source_video.creator template_cfg = parse_template_config( creator.shorts_template if creator else None, ) except Exception as tmpl_exc: logger.warning( "Template config load failed for highlight=%s: %s — proceeding without cards", highlight_candidate_id, tmpl_exc, ) template_cfg = parse_template_config(None) # ── Process each preset independently ─────────────────────────── for preset in FormatPreset: spec = PRESETS[preset] preset_start = time.monotonic() # Create DB row (status=processing) short = GeneratedShort( highlight_candidate_id=highlight_candidate_id, format_preset=preset, width=spec.width, height=spec.height, status=ShortStatus.processing, duration_secs=clip_end - clip_start, ) session.add(short) session.commit() session.refresh(short) tmp_path = Path(f"/tmp/short_{short.id}_{preset.value}.mp4") minio_key = f"shorts/{highlight_candidate_id}/{preset.value}.mp4" try: # Render intro/outro cards for this preset's resolution preset_intro: Path | None = None preset_outro: Path | None = None if template_cfg["show_intro"] and template_cfg["intro_text"]: preset_intro = Path( f"/tmp/intro_{short.id}_{preset.value}.mp4" ) try: render_card_to_file( text=template_cfg["intro_text"], 
duration_secs=template_cfg["intro_duration"], width=spec.width, height=spec.height, output_path=preset_intro, accent_color=template_cfg["accent_color"], font_family=template_cfg["font_family"], ) except Exception as intro_exc: logger.warning( "Intro card render failed for highlight=%s preset=%s: %s — skipping intro", highlight_candidate_id, preset.value, intro_exc, ) preset_intro = None if template_cfg["show_outro"] and template_cfg["outro_text"]: preset_outro = Path( f"/tmp/outro_{short.id}_{preset.value}.mp4" ) try: render_card_to_file( text=template_cfg["outro_text"], duration_secs=template_cfg["outro_duration"], width=spec.width, height=spec.height, output_path=preset_outro, accent_color=template_cfg["accent_color"], font_family=template_cfg["font_family"], ) except Exception as outro_exc: logger.warning( "Outro card render failed for highlight=%s preset=%s: %s — skipping outro", highlight_candidate_id, preset.value, outro_exc, ) preset_outro = None # Extract clip (with optional template cards) extract_clip_with_template( input_path=video_path, output_path=tmp_path, start_secs=clip_start, end_secs=clip_end, vf_filter=spec.vf_filter, ass_path=ass_path, intro_path=preset_intro, outro_path=preset_outro, ) # Upload to MinIO file_size = tmp_path.stat().st_size with open(tmp_path, "rb") as f: from minio_client import upload_file upload_file( object_key=minio_key, data=f, length=file_size, content_type="video/mp4", ) # Update DB row — complete short.status = ShortStatus.complete short.file_size_bytes = file_size short.minio_object_key = minio_key short.captions_enabled = captions_ok short.share_token = secrets.token_urlsafe(8) session.commit() elapsed_preset = time.monotonic() - preset_start logger.info( "Short generated: highlight=%s preset=%s " "size=%d bytes duration=%.1fs elapsed=%.1fs", highlight_candidate_id, preset.value, file_size, clip_end - clip_start, elapsed_preset, ) except Exception as exc: session.rollback() # Re-fetch the short row after rollback 
session.refresh(short) short.status = ShortStatus.failed short.error_message = str(exc)[:2000] session.commit() elapsed_preset = time.monotonic() - preset_start logger.error( "Short failed: highlight=%s preset=%s " "error=%s elapsed=%.1fs", highlight_candidate_id, preset.value, str(exc)[:500], elapsed_preset, ) finally: # Clean up temp files (main clip + intro/outro cards) for tmp in [tmp_path, preset_intro, preset_outro]: if tmp is not None and tmp.exists(): try: tmp.unlink() except OSError: pass # Clean up temp ASS caption file if ass_path is not None and ass_path.exists(): try: ass_path.unlink() except OSError: pass elapsed = time.monotonic() - start logger.info( "Shorts generation complete for highlight=%s in %.1fs", highlight_candidate_id, elapsed, ) return highlight_candidate_id except Exception as exc: session.rollback() logger.error( "Shorts generation failed for highlight=%s: %s", highlight_candidate_id, exc, ) raise self.retry(exc=exc) finally: session.close()