chrysopedia/backend/routers/pipeline.py
jlightner fa82f1079a feat: Enriched Qdrant embedding text with creator_name/tags and added r…
- "backend/pipeline/stages.py"
- "backend/pipeline/qdrant_client.py"
- "backend/routers/pipeline.py"

GSD-Task: S01/T02
2026-04-01 06:41:52 +00:00

679 lines
26 KiB
Python

"""Pipeline management endpoints — public trigger + admin dashboard.
Public:
POST /pipeline/trigger/{video_id} Trigger pipeline for a video
Admin:
GET /admin/pipeline/videos Video list with status + event counts
POST /admin/pipeline/trigger/{video_id} Retrigger (same as public but under admin prefix)
POST /admin/pipeline/revoke/{video_id} Revoke/cancel active tasks for a video
GET /admin/pipeline/events/{video_id} Event log for a video (paginated)
GET /admin/pipeline/worker-status Active/reserved tasks from Celery inspect
"""
import logging
import uuid
from datetime import datetime, timezone
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy import func, select, case
from sqlalchemy.ext.asyncio import AsyncSession
from config import get_settings
from database import get_session
from models import PipelineEvent, PipelineRun, PipelineRunStatus, SourceVideo, Creator, KeyMoment, TranscriptSegment, ProcessingStatus
from redis_client import get_redis
from schemas import DebugModeResponse, DebugModeUpdate, TokenStageSummary, TokenSummaryResponse
logger = logging.getLogger("chrysopedia.pipeline")
router = APIRouter(tags=["pipeline"])
REDIS_DEBUG_MODE_KEY = "chrysopedia:debug_mode"
# ── Public trigger ───────────────────────────────────────────────────────────
@router.post("/pipeline/trigger/{video_id}")
async def trigger_pipeline(
video_id: str,
db: AsyncSession = Depends(get_session),
):
"""Manually trigger (or re-trigger) the LLM extraction pipeline for a video."""
stmt = select(SourceVideo).where(SourceVideo.id == video_id)
result = await db.execute(stmt)
video = result.scalar_one_or_none()
if video is None:
raise HTTPException(status_code=404, detail=f"Video not found: {video_id}")
from pipeline.stages import run_pipeline
try:
run_pipeline.delay(str(video.id), trigger="manual")
logger.info("Pipeline manually triggered for video_id=%s", video_id)
except Exception as exc:
logger.warning("Failed to dispatch pipeline for video_id=%s: %s", video_id, exc)
raise HTTPException(
status_code=503,
detail="Pipeline dispatch failed — Celery/Redis may be unavailable",
) from exc
return {
"status": "triggered",
"video_id": str(video.id),
"current_processing_status": video.processing_status.value,
}
# ── Admin: Video list ────────────────────────────────────────────────────────
@router.get("/admin/pipeline/videos")
async def list_pipeline_videos(
db: AsyncSession = Depends(get_session),
):
"""List all videos with processing status and pipeline event counts."""
# Subquery for event counts per video
event_counts = (
select(
PipelineEvent.video_id,
func.count().label("event_count"),
func.sum(case(
(PipelineEvent.event_type == "llm_call", PipelineEvent.total_tokens),
else_=0
)).label("total_tokens_used"),
func.max(PipelineEvent.created_at).label("last_event_at"),
)
.group_by(PipelineEvent.video_id)
.subquery()
)
# Subquery for the most recent pipeline run per video
latest_run = (
select(
PipelineRun.video_id,
PipelineRun.id.label("run_id"),
PipelineRun.run_number,
PipelineRun.trigger.label("run_trigger"),
PipelineRun.status.label("run_status"),
PipelineRun.started_at.label("run_started_at"),
PipelineRun.finished_at.label("run_finished_at"),
PipelineRun.error_stage.label("run_error_stage"),
PipelineRun.total_tokens.label("run_total_tokens"),
)
.order_by(PipelineRun.video_id, PipelineRun.started_at.desc())
.distinct(PipelineRun.video_id)
.subquery()
)
# Subquery for the most recent stage start event per video (active stage indicator)
latest_stage = (
select(
PipelineEvent.video_id,
PipelineEvent.stage.label("active_stage"),
PipelineEvent.event_type.label("active_stage_status"),
PipelineEvent.created_at.label("stage_started_at"),
)
.where(PipelineEvent.event_type.in_(["start", "complete", "error"]))
.order_by(PipelineEvent.video_id, PipelineEvent.created_at.desc())
.distinct(PipelineEvent.video_id)
.subquery()
)
stmt = (
select(
SourceVideo.id,
SourceVideo.filename,
SourceVideo.processing_status,
SourceVideo.content_hash,
SourceVideo.created_at,
SourceVideo.updated_at,
Creator.name.label("creator_name"),
event_counts.c.event_count,
event_counts.c.total_tokens_used,
event_counts.c.last_event_at,
latest_stage.c.active_stage,
latest_stage.c.active_stage_status,
latest_stage.c.stage_started_at,
latest_run.c.run_id,
latest_run.c.run_number,
latest_run.c.run_trigger,
latest_run.c.run_status,
latest_run.c.run_started_at,
latest_run.c.run_finished_at,
latest_run.c.run_error_stage,
latest_run.c.run_total_tokens,
)
.join(Creator, SourceVideo.creator_id == Creator.id)
.outerjoin(event_counts, SourceVideo.id == event_counts.c.video_id)
.outerjoin(latest_stage, SourceVideo.id == latest_stage.c.video_id)
.outerjoin(latest_run, SourceVideo.id == latest_run.c.video_id)
.order_by(SourceVideo.updated_at.desc())
)
result = await db.execute(stmt)
rows = result.all()
return {
"items": [
{
"id": str(r.id),
"filename": r.filename,
"processing_status": r.processing_status.value if hasattr(r.processing_status, 'value') else str(r.processing_status),
"content_hash": r.content_hash,
"creator_name": r.creator_name,
"created_at": r.created_at.isoformat() if r.created_at else None,
"updated_at": r.updated_at.isoformat() if r.updated_at else None,
"event_count": r.event_count or 0,
"total_tokens_used": r.total_tokens_used or 0,
"last_event_at": r.last_event_at.isoformat() if r.last_event_at else None,
"active_stage": r.active_stage,
"active_stage_status": r.active_stage_status,
"stage_started_at": r.stage_started_at.isoformat() if r.stage_started_at else None,
"latest_run": {
"id": str(r.run_id),
"run_number": r.run_number,
"trigger": r.run_trigger.value if hasattr(r.run_trigger, 'value') else r.run_trigger,
"status": r.run_status.value if hasattr(r.run_status, 'value') else r.run_status,
"started_at": r.run_started_at.isoformat() if r.run_started_at else None,
"finished_at": r.run_finished_at.isoformat() if r.run_finished_at else None,
"error_stage": r.run_error_stage,
"total_tokens": r.run_total_tokens or 0,
} if r.run_id else None,
}
for r in rows
],
"total": len(rows),
}
# ── Admin: Retrigger ─────────────────────────────────────────────────────────
@router.post("/admin/pipeline/trigger/{video_id}")
async def admin_trigger_pipeline(
video_id: str,
db: AsyncSession = Depends(get_session),
):
"""Admin retrigger — same as public trigger."""
return await trigger_pipeline(video_id, db)
# ── Admin: Clean Retrigger ───────────────────────────────────────────────────
@router.post("/admin/pipeline/clean-retrigger/{video_id}")
async def clean_retrigger_pipeline(
video_id: str,
db: AsyncSession = Depends(get_session),
):
"""Wipe prior pipeline output for a video, then retrigger.
Deletes: pipeline_events, key_moments, transcript_segments,
and associated Qdrant vectors. Resets processing_status to 'not_started'.
Does NOT delete technique_pages — the pipeline re-synthesizes via upsert.
"""
stmt = select(SourceVideo).where(SourceVideo.id == video_id)
result = await db.execute(stmt)
video = result.scalar_one_or_none()
if video is None:
raise HTTPException(status_code=404, detail=f"Video not found: {video_id}")
# Delete pipeline events
await db.execute(
PipelineEvent.__table__.delete().where(PipelineEvent.video_id == video_id)
)
# Delete key moments
await db.execute(
KeyMoment.__table__.delete().where(KeyMoment.source_video_id == video_id)
)
# Note: transcript_segments are NOT deleted — they are the pipeline's input
# data created during ingest, not pipeline output. Deleting them would leave
# the pipeline with nothing to process.
# Reset status
video.processing_status = ProcessingStatus.not_started
await db.commit()
deleted_counts = {
"pipeline_events": "cleared",
"key_moments": "cleared",
}
# Best-effort Qdrant cleanup (non-blocking)
try:
settings = get_settings()
from pipeline.qdrant_client import QdrantManager
qdrant = QdrantManager(settings)
qdrant.delete_by_video_id(str(video_id))
deleted_counts["qdrant_vectors"] = "cleared"
except Exception as exc:
logger.warning("Qdrant cleanup failed for video_id=%s: %s", video_id, exc)
deleted_counts["qdrant_vectors"] = f"skipped: {exc}"
# Now trigger the pipeline
from pipeline.stages import run_pipeline
try:
run_pipeline.delay(str(video.id), trigger="clean_reprocess")
logger.info("Clean retrigger dispatched for video_id=%s", video_id)
except Exception as exc:
logger.warning("Failed to dispatch pipeline after cleanup for video_id=%s: %s", video_id, exc)
raise HTTPException(
status_code=503,
detail="Cleanup succeeded but pipeline dispatch failed — Celery/Redis may be unavailable",
) from exc
return {
"status": "clean_retriggered",
"video_id": str(video.id),
"cleaned": deleted_counts,
}
# ── Admin: Revoke ────────────────────────────────────────────────────────────
@router.post("/admin/pipeline/revoke/{video_id}")
async def revoke_pipeline(
video_id: str,
db: AsyncSession = Depends(get_session),
):
"""Revoke/cancel active Celery tasks for a video.
Uses Celery's revoke with terminate=True to kill running tasks.
Also marks the latest running pipeline_run as cancelled.
This is best-effort — the task may have already completed.
"""
from worker import celery_app
try:
# Get active tasks and revoke any matching this video_id
inspector = celery_app.control.inspect()
active = inspector.active() or {}
revoked_count = 0
for _worker, tasks in active.items():
for task in tasks:
task_args = task.get("args", [])
if task_args and str(task_args[0]) == video_id:
celery_app.control.revoke(task["id"], terminate=True)
revoked_count += 1
logger.info("Revoked task %s for video_id=%s", task["id"], video_id)
# Mark any running pipeline_runs as cancelled
running_runs = await db.execute(
select(PipelineRun).where(
PipelineRun.video_id == video_id,
PipelineRun.status == PipelineRunStatus.running,
)
)
for run in running_runs.scalars().all():
run.status = PipelineRunStatus.cancelled
run.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
await db.commit()
return {
"status": "revoked" if revoked_count > 0 else "no_active_tasks",
"video_id": video_id,
"tasks_revoked": revoked_count,
}
except Exception as exc:
logger.warning("Failed to revoke tasks for video_id=%s: %s", video_id, exc)
raise HTTPException(
status_code=503,
detail="Failed to communicate with Celery worker",
) from exc
# ── Admin: Recent activity feed ──────────────────────────────────────────────
@router.get("/admin/pipeline/recent-activity")
async def recent_pipeline_activity(
limit: Annotated[int, Query(ge=1, le=20)] = 10,
db: AsyncSession = Depends(get_session),
):
"""Get the most recent pipeline stage completions and errors with video context."""
stmt = (
select(
PipelineEvent.id,
PipelineEvent.video_id,
PipelineEvent.stage,
PipelineEvent.event_type,
PipelineEvent.total_tokens,
PipelineEvent.duration_ms,
PipelineEvent.created_at,
SourceVideo.filename,
Creator.name.label("creator_name"),
)
.join(SourceVideo, PipelineEvent.video_id == SourceVideo.id)
.join(Creator, SourceVideo.creator_id == Creator.id)
.where(PipelineEvent.event_type.in_(["complete", "error"]))
.order_by(PipelineEvent.created_at.desc())
.limit(limit)
)
result = await db.execute(stmt)
rows = result.all()
return {
"items": [
{
"id": str(r.id),
"video_id": str(r.video_id),
"filename": r.filename,
"creator_name": r.creator_name,
"stage": r.stage,
"event_type": r.event_type,
"total_tokens": r.total_tokens,
"duration_ms": r.duration_ms,
"created_at": r.created_at.isoformat() if r.created_at else None,
}
for r in rows
],
}
# ── Admin: Pipeline runs ─────────────────────────────────────────────────────
@router.get("/admin/pipeline/runs/{video_id}")
async def list_pipeline_runs(
video_id: str,
db: AsyncSession = Depends(get_session),
):
"""List all pipeline runs for a video, newest first."""
# Count events per run
event_counts = (
select(
PipelineEvent.run_id,
func.count().label("event_count"),
)
.where(PipelineEvent.run_id.isnot(None))
.group_by(PipelineEvent.run_id)
.subquery()
)
stmt = (
select(
PipelineRun,
event_counts.c.event_count,
)
.outerjoin(event_counts, PipelineRun.id == event_counts.c.run_id)
.where(PipelineRun.video_id == video_id)
.order_by(PipelineRun.started_at.desc())
)
result = await db.execute(stmt)
rows = result.all()
# Also count legacy events (run_id IS NULL) for this video
legacy_count_result = await db.execute(
select(func.count())
.select_from(PipelineEvent)
.where(PipelineEvent.video_id == video_id, PipelineEvent.run_id.is_(None))
)
legacy_count = legacy_count_result.scalar() or 0
items = []
for run, evt_count in rows:
items.append({
"id": str(run.id),
"run_number": run.run_number,
"trigger": run.trigger.value if hasattr(run.trigger, 'value') else str(run.trigger),
"status": run.status.value if hasattr(run.status, 'value') else str(run.status),
"started_at": run.started_at.isoformat() if run.started_at else None,
"finished_at": run.finished_at.isoformat() if run.finished_at else None,
"error_stage": run.error_stage,
"total_tokens": run.total_tokens or 0,
"event_count": evt_count or 0,
})
return {
"items": items,
"legacy_event_count": legacy_count,
}
# ── Admin: Event log ─────────────────────────────────────────────────────────
@router.get("/admin/pipeline/events/{video_id}")
async def list_pipeline_events(
video_id: str,
offset: Annotated[int, Query(ge=0)] = 0,
limit: Annotated[int, Query(ge=1, le=200)] = 100,
stage: Annotated[str | None, Query(description="Filter by stage name")] = None,
event_type: Annotated[str | None, Query(description="Filter by event type")] = None,
run_id: Annotated[str | None, Query(description="Filter by pipeline run ID")] = None,
order: Annotated[str, Query(description="Sort order: asc or desc")] = "desc",
db: AsyncSession = Depends(get_session),
):
"""Get pipeline events for a video. Default: newest first (desc)."""
stmt = select(PipelineEvent).where(PipelineEvent.video_id == video_id)
if run_id:
stmt = stmt.where(PipelineEvent.run_id == run_id)
if stage:
stmt = stmt.where(PipelineEvent.stage == stage)
if event_type:
stmt = stmt.where(PipelineEvent.event_type == event_type)
# Validate order param
if order not in ("asc", "desc"):
raise HTTPException(status_code=400, detail="order must be 'asc' or 'desc'")
# Count
count_stmt = select(func.count()).select_from(stmt.subquery())
total = (await db.execute(count_stmt)).scalar() or 0
# Fetch
order_clause = PipelineEvent.created_at.asc() if order == "asc" else PipelineEvent.created_at.desc()
stmt = stmt.order_by(order_clause)
stmt = stmt.offset(offset).limit(limit)
result = await db.execute(stmt)
events = result.scalars().all()
return {
"items": [
{
"id": str(e.id),
"video_id": str(e.video_id),
"stage": e.stage,
"event_type": e.event_type,
"prompt_tokens": e.prompt_tokens,
"completion_tokens": e.completion_tokens,
"total_tokens": e.total_tokens,
"model": e.model,
"duration_ms": e.duration_ms,
"payload": e.payload,
"created_at": e.created_at.isoformat() if e.created_at else None,
"system_prompt_text": e.system_prompt_text,
"user_prompt_text": e.user_prompt_text,
"response_text": e.response_text,
}
for e in events
],
"total": total,
"offset": offset,
"limit": limit,
}
# ── Admin: Debug mode ─────────────────────────────────────────────────────────
@router.get("/admin/pipeline/debug-mode", response_model=DebugModeResponse)
async def get_debug_mode() -> DebugModeResponse:
"""Get the current pipeline debug mode (on/off)."""
settings = get_settings()
try:
redis = await get_redis()
try:
value = await redis.get(REDIS_DEBUG_MODE_KEY)
if value is not None:
return DebugModeResponse(debug_mode=value.lower() == "true")
finally:
await redis.aclose()
except Exception as exc:
logger.warning("Redis unavailable for debug mode read, using config default: %s", exc)
return DebugModeResponse(debug_mode=settings.debug_mode)
@router.put("/admin/pipeline/debug-mode", response_model=DebugModeResponse)
async def set_debug_mode(body: DebugModeUpdate) -> DebugModeResponse:
"""Set the pipeline debug mode (on/off)."""
try:
redis = await get_redis()
try:
await redis.set(REDIS_DEBUG_MODE_KEY, str(body.debug_mode))
finally:
await redis.aclose()
except Exception as exc:
logger.error("Failed to set debug mode in Redis: %s", exc)
raise HTTPException(
status_code=503,
detail=f"Redis unavailable: {exc}",
)
logger.info("Pipeline debug mode set to %s", body.debug_mode)
return DebugModeResponse(debug_mode=body.debug_mode)
# ── Admin: Token summary ─────────────────────────────────────────────────────
@router.get("/admin/pipeline/token-summary/{video_id}", response_model=TokenSummaryResponse)
async def get_token_summary(
video_id: str,
db: AsyncSession = Depends(get_session),
) -> TokenSummaryResponse:
"""Get per-stage token usage summary for a video."""
stmt = (
select(
PipelineEvent.stage,
func.count().label("call_count"),
func.coalesce(func.sum(PipelineEvent.prompt_tokens), 0).label("total_prompt_tokens"),
func.coalesce(func.sum(PipelineEvent.completion_tokens), 0).label("total_completion_tokens"),
func.coalesce(func.sum(PipelineEvent.total_tokens), 0).label("total_tokens"),
)
.where(PipelineEvent.video_id == video_id)
.where(PipelineEvent.event_type == "llm_call")
.group_by(PipelineEvent.stage)
.order_by(PipelineEvent.stage)
)
result = await db.execute(stmt)
rows = result.all()
stages = [
TokenStageSummary(
stage=r.stage,
call_count=r.call_count,
total_prompt_tokens=r.total_prompt_tokens,
total_completion_tokens=r.total_completion_tokens,
total_tokens=r.total_tokens,
)
for r in rows
]
grand_total = sum(s.total_tokens for s in stages)
return TokenSummaryResponse(
video_id=video_id,
stages=stages,
grand_total_tokens=grand_total,
)
# ── Admin: Reindex all ───────────────────────────────────────────────────────
@router.post("/admin/pipeline/reindex-all")
async def reindex_all(
db: AsyncSession = Depends(get_session),
):
"""Re-run stage 6 (embed & index) for all videos with processing_status='complete'.
Use after changing embedding text composition or Qdrant payload fields
to regenerate all vectors and payloads without re-running the full pipeline.
"""
stmt = select(SourceVideo.id).where(
SourceVideo.processing_status == ProcessingStatus.complete
)
result = await db.execute(stmt)
video_ids = [str(row[0]) for row in result.all()]
if not video_ids:
return {
"status": "no_videos",
"message": "No videos with processing_status='complete' found.",
"dispatched": 0,
}
from pipeline.stages import stage6_embed_and_index
dispatched = 0
errors = []
for vid in video_ids:
try:
stage6_embed_and_index.delay(vid)
dispatched += 1
except Exception as exc:
logger.warning("Failed to dispatch reindex for video_id=%s: %s", vid, exc)
errors.append({"video_id": vid, "error": str(exc)})
logger.info(
"Reindex-all dispatched %d/%d stage6 tasks.",
dispatched, len(video_ids),
)
return {
"status": "dispatched",
"dispatched": dispatched,
"total_complete_videos": len(video_ids),
"errors": errors if errors else None,
}
# ── Admin: Worker status ─────────────────────────────────────────────────────
@router.get("/admin/pipeline/worker-status")
async def worker_status():
"""Get current Celery worker status — active, reserved, and stats."""
from worker import celery_app
try:
inspector = celery_app.control.inspect()
active = inspector.active() or {}
reserved = inspector.reserved() or {}
stats = inspector.stats() or {}
workers = []
for worker_name in set(list(active.keys()) + list(reserved.keys()) + list(stats.keys())):
worker_active = active.get(worker_name, [])
worker_reserved = reserved.get(worker_name, [])
worker_stats = stats.get(worker_name, {})
workers.append({
"name": worker_name,
"active_tasks": [
{
"id": t.get("id"),
"name": t.get("name"),
"args": t.get("args", []),
"time_start": t.get("time_start"),
}
for t in worker_active
],
"reserved_tasks": len(worker_reserved),
"total_completed": worker_stats.get("total", {}).get("tasks.pipeline.stages.stage2_segmentation", 0)
+ worker_stats.get("total", {}).get("tasks.pipeline.stages.stage3_extraction", 0)
+ worker_stats.get("total", {}).get("tasks.pipeline.stages.stage4_classification", 0)
+ worker_stats.get("total", {}).get("tasks.pipeline.stages.stage5_synthesis", 0),
"uptime": worker_stats.get("clock", None),
"pool_size": worker_stats.get("pool", {}).get("max-concurrency") if isinstance(worker_stats.get("pool"), dict) else None,
})
return {
"online": len(workers) > 0,
"workers": workers,
}
except Exception as exc:
logger.warning("Failed to inspect Celery workers: %s", exc)
return {
"online": False,
"workers": [],
"error": str(exc),
}