"""Pipeline stage tasks (stages 2-5) and run_pipeline orchestrator. Each stage reads from PostgreSQL via sync SQLAlchemy, loads its prompt template from disk, calls the LLM client, parses the response, writes results back, and updates processing_status on SourceVideo. Celery tasks are synchronous — all DB access uses ``sqlalchemy.orm.Session``. """ from __future__ import annotations import json import logging import time from collections import defaultdict from pathlib import Path import yaml from celery import chain as celery_chain from pydantic import ValidationError from sqlalchemy import create_engine, select from sqlalchemy.orm import Session, sessionmaker from config import get_settings from models import ( KeyMoment, KeyMomentContentType, ProcessingStatus, SourceVideo, TechniquePage, TranscriptSegment, ) from pipeline.embedding_client import EmbeddingClient from pipeline.llm_client import LLMClient from pipeline.qdrant_client import QdrantManager from pipeline.schemas import ( ClassificationResult, ExtractionResult, SegmentationResult, SynthesisResult, ) from worker import celery_app logger = logging.getLogger(__name__) # ── Helpers ────────────────────────────────────────────────────────────────── _engine = None _SessionLocal = None def _get_sync_engine(): """Create a sync SQLAlchemy engine, converting the async URL if needed.""" global _engine if _engine is None: settings = get_settings() url = settings.database_url # Convert async driver to sync driver url = url.replace("postgresql+asyncpg://", "postgresql+psycopg2://") _engine = create_engine(url, pool_pre_ping=True, pool_size=5, max_overflow=10) return _engine def _get_sync_session() -> Session: """Create a sync SQLAlchemy session for Celery tasks.""" global _SessionLocal if _SessionLocal is None: _SessionLocal = sessionmaker(bind=_get_sync_engine()) return _SessionLocal() def _load_prompt(template_name: str) -> str: """Read a prompt template from the prompts directory. Raises FileNotFoundError if the template does not exist. """ settings = get_settings() path = Path(settings.prompts_path) / template_name if not path.exists(): logger.error("Prompt template not found: %s", path) raise FileNotFoundError(f"Prompt template not found: {path}") return path.read_text(encoding="utf-8") def _get_llm_client() -> LLMClient: """Return an LLMClient configured from settings.""" return LLMClient(get_settings()) def _load_canonical_tags() -> dict: """Load canonical tag taxonomy from config/canonical_tags.yaml.""" # Walk up from backend/ to find config/ candidates = [ Path("config/canonical_tags.yaml"), Path("../config/canonical_tags.yaml"), ] for candidate in candidates: if candidate.exists(): with open(candidate, encoding="utf-8") as f: return yaml.safe_load(f) raise FileNotFoundError( "canonical_tags.yaml not found. Searched: " + ", ".join(str(c) for c in candidates) ) def _format_taxonomy_for_prompt(tags_data: dict) -> str: """Format the canonical tags taxonomy as readable text for the LLM prompt.""" lines = [] for cat in tags_data.get("categories", []): lines.append(f"Category: {cat['name']}") lines.append(f" Description: {cat['description']}") lines.append(f" Sub-topics: {', '.join(cat.get('sub_topics', []))}") lines.append("") return "\n".join(lines) def _safe_parse_llm_response(raw: str, model_cls, llm: LLMClient, system_prompt: str, user_prompt: str): """Parse LLM response with one retry on failure. On malformed response: log the raw text, retry once with a JSON nudge, then raise on second failure. """ try: return llm.parse_response(raw, model_cls) except (ValidationError, ValueError, json.JSONDecodeError) as exc: logger.warning( "First parse attempt failed for %s (%s). Retrying with JSON nudge. " "Raw response (first 500 chars): %.500s", model_cls.__name__, type(exc).__name__, raw, ) # Retry with explicit JSON instruction nudge_prompt = user_prompt + "\n\nIMPORTANT: Output ONLY valid JSON. No markdown, no explanation." retry_raw = llm.complete(system_prompt, nudge_prompt, response_model=model_cls) return llm.parse_response(retry_raw, model_cls) # ── Stage 2: Segmentation ─────────────────────────────────────────────────── @celery_app.task(bind=True, max_retries=3, default_retry_delay=30) def stage2_segmentation(self, video_id: str) -> str: """Analyze transcript segments and identify topic boundaries. Loads all TranscriptSegment rows for the video, sends them to the LLM for topic boundary detection, and updates topic_label on each segment. Returns the video_id for chain compatibility. """ start = time.monotonic() logger.info("Stage 2 (segmentation) starting for video_id=%s", video_id) session = _get_sync_session() try: # Load segments ordered by index segments = ( session.execute( select(TranscriptSegment) .where(TranscriptSegment.source_video_id == video_id) .order_by(TranscriptSegment.segment_index) ) .scalars() .all() ) if not segments: logger.info("Stage 2: No segments found for video_id=%s, skipping.", video_id) return video_id # Build transcript text with indices for the LLM transcript_lines = [] for seg in segments: transcript_lines.append( f"[{seg.segment_index}] ({seg.start_time:.1f}s - {seg.end_time:.1f}s) {seg.text}" ) transcript_text = "\n".join(transcript_lines) # Load prompt and call LLM system_prompt = _load_prompt("stage2_segmentation.txt") user_prompt = f"\n{transcript_text}\n" llm = _get_llm_client() raw = llm.complete(system_prompt, user_prompt, response_model=SegmentationResult) result = _safe_parse_llm_response(raw, SegmentationResult, llm, system_prompt, user_prompt) # Update topic_label on each segment row seg_by_index = {s.segment_index: s for s in segments} for topic_seg in result.segments: for idx in range(topic_seg.start_index, topic_seg.end_index + 1): if idx in seg_by_index: seg_by_index[idx].topic_label = topic_seg.topic_label session.commit() elapsed = time.monotonic() - start logger.info( "Stage 2 (segmentation) completed for video_id=%s in %.1fs — %d topic groups found", video_id, elapsed, len(result.segments), ) return video_id except FileNotFoundError: raise # Don't retry missing prompt files except Exception as exc: session.rollback() logger.error("Stage 2 failed for video_id=%s: %s", video_id, exc) raise self.retry(exc=exc) finally: session.close() # ── Stage 3: Extraction ───────────────────────────────────────────────────── @celery_app.task(bind=True, max_retries=3, default_retry_delay=30) def stage3_extraction(self, video_id: str) -> str: """Extract key moments from each topic segment group. Groups segments by topic_label, calls the LLM for each group to extract moments, creates KeyMoment rows, and sets processing_status=extracted. Returns the video_id for chain compatibility. """ start = time.monotonic() logger.info("Stage 3 (extraction) starting for video_id=%s", video_id) session = _get_sync_session() try: # Load segments with topic labels segments = ( session.execute( select(TranscriptSegment) .where(TranscriptSegment.source_video_id == video_id) .order_by(TranscriptSegment.segment_index) ) .scalars() .all() ) if not segments: logger.info("Stage 3: No segments found for video_id=%s, skipping.", video_id) return video_id # Group segments by topic_label groups: dict[str, list[TranscriptSegment]] = defaultdict(list) for seg in segments: label = seg.topic_label or "unlabeled" groups[label].append(seg) system_prompt = _load_prompt("stage3_extraction.txt") llm = _get_llm_client() total_moments = 0 for topic_label, group_segs in groups.items(): # Build segment text for this group seg_lines = [] for seg in group_segs: seg_lines.append( f"({seg.start_time:.1f}s - {seg.end_time:.1f}s) {seg.text}" ) segment_text = "\n".join(seg_lines) user_prompt = ( f"Topic: {topic_label}\n\n" f"\n{segment_text}\n" ) raw = llm.complete(system_prompt, user_prompt, response_model=ExtractionResult) result = _safe_parse_llm_response(raw, ExtractionResult, llm, system_prompt, user_prompt) # Create KeyMoment rows for moment in result.moments: # Validate content_type against enum try: ct = KeyMomentContentType(moment.content_type) except ValueError: ct = KeyMomentContentType.technique km = KeyMoment( source_video_id=video_id, title=moment.title, summary=moment.summary, start_time=moment.start_time, end_time=moment.end_time, content_type=ct, plugins=moment.plugins if moment.plugins else None, raw_transcript=moment.raw_transcript or None, ) session.add(km) total_moments += 1 # Update processing_status to extracted video = session.execute( select(SourceVideo).where(SourceVideo.id == video_id) ).scalar_one() video.processing_status = ProcessingStatus.extracted session.commit() elapsed = time.monotonic() - start logger.info( "Stage 3 (extraction) completed for video_id=%s in %.1fs — %d moments created", video_id, elapsed, total_moments, ) return video_id except FileNotFoundError: raise except Exception as exc: session.rollback() logger.error("Stage 3 failed for video_id=%s: %s", video_id, exc) raise self.retry(exc=exc) finally: session.close() # ── Stage 4: Classification ───────────────────────────────────────────────── @celery_app.task(bind=True, max_retries=3, default_retry_delay=30) def stage4_classification(self, video_id: str) -> str: """Classify key moments against the canonical tag taxonomy. Loads all KeyMoment rows for the video, sends them to the LLM with the canonical taxonomy, and stores classification results in Redis for stage 5 consumption. Updates content_type if the classifier overrides it. Stage 4 does NOT change processing_status. Returns the video_id for chain compatibility. """ start = time.monotonic() logger.info("Stage 4 (classification) starting for video_id=%s", video_id) session = _get_sync_session() try: # Load key moments moments = ( session.execute( select(KeyMoment) .where(KeyMoment.source_video_id == video_id) .order_by(KeyMoment.start_time) ) .scalars() .all() ) if not moments: logger.info("Stage 4: No moments found for video_id=%s, skipping.", video_id) # Store empty classification data _store_classification_data(video_id, []) return video_id # Load canonical tags tags_data = _load_canonical_tags() taxonomy_text = _format_taxonomy_for_prompt(tags_data) # Build moments text for the LLM moments_lines = [] for i, m in enumerate(moments): moments_lines.append( f"[{i}] Title: {m.title}\n" f" Summary: {m.summary}\n" f" Content type: {m.content_type.value}\n" f" Plugins: {', '.join(m.plugins) if m.plugins else 'none'}" ) moments_text = "\n\n".join(moments_lines) system_prompt = _load_prompt("stage4_classification.txt") user_prompt = ( f"\n{taxonomy_text}\n\n\n" f"\n{moments_text}\n" ) llm = _get_llm_client() raw = llm.complete(system_prompt, user_prompt, response_model=ClassificationResult) result = _safe_parse_llm_response(raw, ClassificationResult, llm, system_prompt, user_prompt) # Apply content_type overrides and prepare classification data for stage 5 classification_data = [] moment_ids = [str(m.id) for m in moments] for cls in result.classifications: if 0 <= cls.moment_index < len(moments): moment = moments[cls.moment_index] # Apply content_type override if provided if cls.content_type_override: try: moment.content_type = KeyMomentContentType(cls.content_type_override) except ValueError: pass classification_data.append({ "moment_id": str(moment.id), "topic_category": cls.topic_category, "topic_tags": cls.topic_tags, }) session.commit() # Store classification data in Redis for stage 5 _store_classification_data(video_id, classification_data) elapsed = time.monotonic() - start logger.info( "Stage 4 (classification) completed for video_id=%s in %.1fs — %d moments classified", video_id, elapsed, len(classification_data), ) return video_id except FileNotFoundError: raise except Exception as exc: session.rollback() logger.error("Stage 4 failed for video_id=%s: %s", video_id, exc) raise self.retry(exc=exc) finally: session.close() def _store_classification_data(video_id: str, data: list[dict]) -> None: """Store classification data in Redis for cross-stage communication.""" import redis settings = get_settings() r = redis.Redis.from_url(settings.redis_url) key = f"chrysopedia:classification:{video_id}" r.set(key, json.dumps(data), ex=86400) # Expire after 24 hours def _load_classification_data(video_id: str) -> list[dict]: """Load classification data from Redis.""" import redis settings = get_settings() r = redis.Redis.from_url(settings.redis_url) key = f"chrysopedia:classification:{video_id}" raw = r.get(key) if raw is None: return [] return json.loads(raw) # ── Stage 5: Synthesis ─────────────────────────────────────────────────────── @celery_app.task(bind=True, max_retries=3, default_retry_delay=30) def stage5_synthesis(self, video_id: str) -> str: """Synthesize technique pages from classified key moments. Groups moments by (creator, topic_category), calls the LLM to synthesize each group into a TechniquePage, creates/updates page rows, and links KeyMoments to their TechniquePage. Sets processing_status to 'reviewed' (or 'published' if review_mode is False). Returns the video_id for chain compatibility. """ start = time.monotonic() logger.info("Stage 5 (synthesis) starting for video_id=%s", video_id) settings = get_settings() session = _get_sync_session() try: # Load video and moments video = session.execute( select(SourceVideo).where(SourceVideo.id == video_id) ).scalar_one() moments = ( session.execute( select(KeyMoment) .where(KeyMoment.source_video_id == video_id) .order_by(KeyMoment.start_time) ) .scalars() .all() ) if not moments: logger.info("Stage 5: No moments found for video_id=%s, skipping.", video_id) return video_id # Load classification data from stage 4 classification_data = _load_classification_data(video_id) cls_by_moment_id = {c["moment_id"]: c for c in classification_data} # Group moments by topic_category (from classification) groups: dict[str, list[tuple[KeyMoment, dict]]] = defaultdict(list) for moment in moments: cls_info = cls_by_moment_id.get(str(moment.id), {}) category = cls_info.get("topic_category", "Uncategorized") groups[category].append((moment, cls_info)) system_prompt = _load_prompt("stage5_synthesis.txt") llm = _get_llm_client() pages_created = 0 for category, moment_group in groups.items(): # Build moments text for the LLM moments_lines = [] all_tags: set[str] = set() for i, (m, cls_info) in enumerate(moment_group): tags = cls_info.get("topic_tags", []) all_tags.update(tags) moments_lines.append( f"[{i}] Title: {m.title}\n" f" Summary: {m.summary}\n" f" Content type: {m.content_type.value}\n" f" Time: {m.start_time:.1f}s - {m.end_time:.1f}s\n" f" Plugins: {', '.join(m.plugins) if m.plugins else 'none'}\n" f" Category: {category}\n" f" Tags: {', '.join(tags) if tags else 'none'}\n" f" Transcript excerpt: {(m.raw_transcript or '')[:300]}" ) moments_text = "\n\n".join(moments_lines) user_prompt = f"\n{moments_text}\n" raw = llm.complete(system_prompt, user_prompt, response_model=SynthesisResult) result = _safe_parse_llm_response(raw, SynthesisResult, llm, system_prompt, user_prompt) # Create/update TechniquePage rows for page_data in result.pages: # Check if page with this slug already exists existing = session.execute( select(TechniquePage).where(TechniquePage.slug == page_data.slug) ).scalar_one_or_none() if existing: # Update existing page existing.title = page_data.title existing.summary = page_data.summary existing.body_sections = page_data.body_sections existing.signal_chains = page_data.signal_chains existing.plugins = page_data.plugins if page_data.plugins else None existing.topic_tags = list(all_tags) if all_tags else None existing.source_quality = page_data.source_quality page = existing else: page = TechniquePage( creator_id=video.creator_id, title=page_data.title, slug=page_data.slug, topic_category=page_data.topic_category or category, topic_tags=list(all_tags) if all_tags else None, summary=page_data.summary, body_sections=page_data.body_sections, signal_chains=page_data.signal_chains, plugins=page_data.plugins if page_data.plugins else None, source_quality=page_data.source_quality, ) session.add(page) session.flush() # Get the page.id assigned pages_created += 1 # Link moments to the technique page for m, _ in moment_group: m.technique_page_id = page.id # Update processing_status if settings.review_mode: video.processing_status = ProcessingStatus.reviewed else: video.processing_status = ProcessingStatus.published session.commit() elapsed = time.monotonic() - start logger.info( "Stage 5 (synthesis) completed for video_id=%s in %.1fs — %d pages created/updated", video_id, elapsed, pages_created, ) return video_id except FileNotFoundError: raise except Exception as exc: session.rollback() logger.error("Stage 5 failed for video_id=%s: %s", video_id, exc) raise self.retry(exc=exc) finally: session.close() # ── Stage 6: Embed & Index ─────────────────────────────────────────────────── @celery_app.task(bind=True, max_retries=0) def stage6_embed_and_index(self, video_id: str) -> str: """Generate embeddings for technique pages and key moments, then upsert to Qdrant. This is a non-blocking side-effect stage — failures are logged but do not fail the pipeline. Embeddings can be regenerated later. Does NOT update processing_status. Returns the video_id for chain compatibility. """ start = time.monotonic() logger.info("Stage 6 (embed & index) starting for video_id=%s", video_id) settings = get_settings() session = _get_sync_session() try: # Load technique pages created for this video's moments moments = ( session.execute( select(KeyMoment) .where(KeyMoment.source_video_id == video_id) .order_by(KeyMoment.start_time) ) .scalars() .all() ) # Get unique technique page IDs from moments page_ids = {m.technique_page_id for m in moments if m.technique_page_id is not None} pages = [] if page_ids: pages = ( session.execute( select(TechniquePage).where(TechniquePage.id.in_(page_ids)) ) .scalars() .all() ) if not moments and not pages: logger.info("Stage 6: No moments or pages for video_id=%s, skipping.", video_id) return video_id embed_client = EmbeddingClient(settings) qdrant = QdrantManager(settings) # Ensure collection exists before upserting qdrant.ensure_collection() # ── Embed & upsert technique pages ─────────────────────────────── if pages: page_texts = [] page_dicts = [] for p in pages: text = f"{p.title} {p.summary or ''} {p.topic_category or ''}" page_texts.append(text.strip()) page_dicts.append({ "page_id": str(p.id), "creator_id": str(p.creator_id), "title": p.title, "topic_category": p.topic_category or "", "topic_tags": p.topic_tags or [], "summary": p.summary or "", }) page_vectors = embed_client.embed(page_texts) if page_vectors: qdrant.upsert_technique_pages(page_dicts, page_vectors) logger.info( "Stage 6: Upserted %d technique page vectors for video_id=%s", len(page_vectors), video_id, ) else: logger.warning( "Stage 6: Embedding returned empty for %d technique pages (video_id=%s). " "Skipping page upsert.", len(page_texts), video_id, ) # ── Embed & upsert key moments ─────────────────────────────────── if moments: moment_texts = [] moment_dicts = [] for m in moments: text = f"{m.title} {m.summary or ''}" moment_texts.append(text.strip()) moment_dicts.append({ "moment_id": str(m.id), "source_video_id": str(m.source_video_id), "title": m.title, "start_time": m.start_time, "end_time": m.end_time, "content_type": m.content_type.value, }) moment_vectors = embed_client.embed(moment_texts) if moment_vectors: qdrant.upsert_key_moments(moment_dicts, moment_vectors) logger.info( "Stage 6: Upserted %d key moment vectors for video_id=%s", len(moment_vectors), video_id, ) else: logger.warning( "Stage 6: Embedding returned empty for %d key moments (video_id=%s). " "Skipping moment upsert.", len(moment_texts), video_id, ) elapsed = time.monotonic() - start logger.info( "Stage 6 (embed & index) completed for video_id=%s in %.1fs — " "%d pages, %d moments processed", video_id, elapsed, len(pages), len(moments), ) return video_id except Exception as exc: # Non-blocking: log error but don't fail the pipeline logger.error( "Stage 6 failed for video_id=%s: %s. " "Pipeline continues — embeddings can be regenerated later.", video_id, exc, ) return video_id finally: session.close() # ── Orchestrator ───────────────────────────────────────────────────────────── @celery_app.task def run_pipeline(video_id: str) -> str: """Orchestrate the full pipeline (stages 2-5) with resumability. Checks the current processing_status of the video and chains only the stages that still need to run. For example: - pending/transcribed → stages 2, 3, 4, 5 - extracted → stages 4, 5 - reviewed/published → no-op Returns the video_id. """ logger.info("run_pipeline starting for video_id=%s", video_id) session = _get_sync_session() try: video = session.execute( select(SourceVideo).where(SourceVideo.id == video_id) ).scalar_one_or_none() if video is None: logger.error("run_pipeline: video_id=%s not found", video_id) raise ValueError(f"Video not found: {video_id}") status = video.processing_status logger.info( "run_pipeline: video_id=%s current status=%s", video_id, status.value ) finally: session.close() # Build the chain based on current status stages = [] if status in (ProcessingStatus.pending, ProcessingStatus.transcribed): stages = [ stage2_segmentation.s(video_id), stage3_extraction.s(), # receives video_id from previous stage4_classification.s(), stage5_synthesis.s(), stage6_embed_and_index.s(), ] elif status == ProcessingStatus.extracted: stages = [ stage4_classification.s(video_id), stage5_synthesis.s(), stage6_embed_and_index.s(), ] elif status in (ProcessingStatus.reviewed, ProcessingStatus.published): logger.info( "run_pipeline: video_id=%s already at status=%s, nothing to do.", video_id, status.value, ) return video_id if stages: pipeline = celery_chain(*stages) pipeline.apply_async() logger.info( "run_pipeline: dispatched %d stages for video_id=%s", len(stages), video_id, ) return video_id