diff --git a/backend/config.py b/backend/config.py
index d6e2b6f..482e601 100644
--- a/backend/config.py
+++ b/backend/config.py
@@ -43,8 +43,9 @@ class Settings(BaseSettings):
llm_stage5_model: str | None = "fyn-llm-agent-think" # synthesis — reasoning
llm_stage5_modality: str = "thinking"
- # Max tokens for LLM responses (OpenWebUI defaults to 1000 which truncates pipeline JSON)
- llm_max_tokens: int = 65536
+ # Dynamic token estimation — each stage calculates max_tokens from input size
+ llm_max_tokens_hard_limit: int = 32768 # Hard ceiling for dynamic estimator
+ llm_max_tokens: int = 65536 # Fallback when no estimate is provided
# Embedding endpoint
embedding_api_url: str = "http://localhost:11434/v1"
diff --git a/backend/pipeline/llm_client.py b/backend/pipeline/llm_client.py
index f7d23f0..1c26f95 100644
--- a/backend/pipeline/llm_client.py
+++ b/backend/pipeline/llm_client.py
@@ -14,7 +14,10 @@ from __future__ import annotations
import logging
import re
-from typing import TypeVar
+from typing import TYPE_CHECKING, TypeVar
+
+if TYPE_CHECKING:
+ from collections.abc import Callable
import openai
from pydantic import BaseModel
@@ -50,6 +53,77 @@ def strip_think_tags(text: str) -> str:
return cleaned.strip()
+
+# ── Token estimation ─────────────────────────────────────────────────────────
+
+# Stage-specific output multipliers: estimated output tokens as a ratio of input tokens.
+# These are empirically tuned based on observed pipeline behavior.
+_STAGE_OUTPUT_RATIOS: dict[str, float] = {
+ "stage2_segmentation": 0.3, # Compact topic groups — much smaller than input
+ "stage3_extraction": 1.2, # Detailed moments with summaries — can exceed input
+ "stage4_classification": 0.15, # Index + category + tags per moment — very compact
+ "stage5_synthesis": 1.5, # Full prose technique pages — heaviest output
+}
+
+# Minimum floor so we never send a trivially small max_tokens
+_MIN_MAX_TOKENS = 2048
+
+
+def estimate_tokens(text: str) -> int:
+ """Estimate token count from text using a chars-per-token heuristic.
+
+ Uses 3.5 chars/token which is conservative for English + JSON markup.
+ """
+ if not text:
+ return 0
+ return max(1, int(len(text) / 3.5))
+
+
+def estimate_max_tokens(
+ system_prompt: str,
+ user_prompt: str,
+ stage: str | None = None,
+ hard_limit: int = 32768,
+) -> int:
+ """Estimate the max_tokens parameter for an LLM call.
+
+ Calculates expected output size based on input size and stage-specific
+ multipliers. The result is clamped between _MIN_MAX_TOKENS and hard_limit.
+
+ Parameters
+ ----------
+ system_prompt:
+ The system prompt text.
+ user_prompt:
+ The user prompt text (transcript, moments, etc.).
+ stage:
+ Pipeline stage name (e.g. "stage3_extraction"). If None or unknown,
+ uses a default 1.0x multiplier.
+ hard_limit:
+ Absolute ceiling — never exceed this value.
+
+ Returns
+ -------
+ int
+ Estimated max_tokens value to pass to the LLM API.
+ """
+ input_tokens = estimate_tokens(system_prompt) + estimate_tokens(user_prompt)
+ ratio = _STAGE_OUTPUT_RATIOS.get(stage or "", 1.0)
+ estimated_output = int(input_tokens * ratio)
+
+ # Add a 20% buffer for JSON overhead and variability
+ estimated_output = int(estimated_output * 1.2)
+
+ # Clamp to [_MIN_MAX_TOKENS, hard_limit]
+ result = max(_MIN_MAX_TOKENS, min(estimated_output, hard_limit))
+
+ logger.info(
+ "Token estimate: input≈%d, stage=%s, ratio=%.2f, estimated_output=%d, max_tokens=%d (hard_limit=%d)",
+ input_tokens, stage or "default", ratio, estimated_output, result, hard_limit,
+ )
+ return result
+
+
class LLMClient:
"""Sync LLM client that tries a primary endpoint and falls back on failure."""
@@ -73,6 +147,8 @@ class LLMClient:
response_model: type[BaseModel] | None = None,
modality: str = "chat",
model_override: str | None = None,
+ on_complete: "Callable | None" = None,
+ max_tokens: int | None = None,
) -> str:
"""Send a chat completion request, falling back on connection/timeout errors.
@@ -92,6 +168,9 @@ class LLMClient:
model_override:
Model name to use instead of the default. If None, uses the
configured default for the endpoint.
+ max_tokens:
+ Override for max_tokens on this call. If None, falls back to
+ the configured ``llm_max_tokens`` from settings.
Returns
-------
@@ -123,12 +202,14 @@ class LLMClient:
primary_model = model_override or self.settings.llm_model
fallback_model = self.settings.llm_fallback_model
+ effective_max_tokens = max_tokens if max_tokens is not None else self.settings.llm_max_tokens
logger.info(
- "LLM request: model=%s, modality=%s, response_model=%s",
+ "LLM request: model=%s, modality=%s, response_model=%s, max_tokens=%d",
primary_model,
modality,
response_model.__name__ if response_model else None,
+ effective_max_tokens,
)
# --- Try primary endpoint ---
@@ -136,7 +217,7 @@ class LLMClient:
response = self._primary.chat.completions.create(
model=primary_model,
messages=messages,
- max_tokens=self.settings.llm_max_tokens,
+ max_tokens=effective_max_tokens,
**kwargs,
)
raw = response.choices[0].message.content or ""
@@ -149,6 +230,18 @@ class LLMClient:
)
if modality == "thinking":
raw = strip_think_tags(raw)
+ if on_complete is not None:
+ try:
+ on_complete(
+ model=primary_model,
+ prompt_tokens=usage.prompt_tokens if usage else None,
+ completion_tokens=usage.completion_tokens if usage else None,
+ total_tokens=usage.total_tokens if usage else None,
+ content=raw,
+ finish_reason=response.choices[0].finish_reason if response.choices else None,
+ )
+ except Exception as cb_exc:
+ logger.warning("on_complete callback failed: %s", cb_exc)
return raw
except (openai.APIConnectionError, openai.APITimeoutError) as exc:
@@ -164,7 +257,7 @@ class LLMClient:
response = self._fallback.chat.completions.create(
model=fallback_model,
messages=messages,
- max_tokens=self.settings.llm_max_tokens,
+ max_tokens=effective_max_tokens,
**kwargs,
)
raw = response.choices[0].message.content or ""
@@ -177,6 +270,19 @@ class LLMClient:
)
if modality == "thinking":
raw = strip_think_tags(raw)
+ if on_complete is not None:
+ try:
+ on_complete(
+ model=fallback_model,
+ prompt_tokens=usage.prompt_tokens if usage else None,
+ completion_tokens=usage.completion_tokens if usage else None,
+ total_tokens=usage.total_tokens if usage else None,
+ content=raw,
+ finish_reason=response.choices[0].finish_reason if response.choices else None,
+ is_fallback=True,
+ )
+ except Exception as cb_exc:
+ logger.warning("on_complete callback failed: %s", cb_exc)
return raw
except (openai.APIConnectionError, openai.APITimeoutError, openai.APIError) as exc:
diff --git a/backend/pipeline/stages.py b/backend/pipeline/stages.py
index 3a49f86..fdc2204 100644
--- a/backend/pipeline/stages.py
+++ b/backend/pipeline/stages.py
@@ -26,6 +26,7 @@ from config import get_settings
from models import (
KeyMoment,
KeyMomentContentType,
+ PipelineEvent,
ProcessingStatus,
SourceVideo,
TechniquePage,
@@ -33,7 +34,7 @@ from models import (
TranscriptSegment,
)
from pipeline.embedding_client import EmbeddingClient
-from pipeline.llm_client import LLMClient
+from pipeline.llm_client import LLMClient, estimate_max_tokens
from pipeline.qdrant_client import QdrantManager
from pipeline.schemas import (
ClassificationResult,
@@ -45,6 +46,68 @@ from worker import celery_app
logger = logging.getLogger(__name__)
+
+# ── Pipeline event persistence ───────────────────────────────────────────────
+
+def _emit_event(
+ video_id: str,
+ stage: str,
+ event_type: str,
+ *,
+ prompt_tokens: int | None = None,
+ completion_tokens: int | None = None,
+ total_tokens: int | None = None,
+ model: str | None = None,
+ duration_ms: int | None = None,
+ payload: dict | None = None,
+) -> None:
+ """Persist a pipeline event to the DB. Best-effort -- failures logged, not raised."""
+ try:
+ session = _get_sync_session()
+ try:
+ event = PipelineEvent(
+ video_id=video_id,
+ stage=stage,
+ event_type=event_type,
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ total_tokens=total_tokens,
+ model=model,
+ duration_ms=duration_ms,
+ payload=payload,
+ )
+ session.add(event)
+ session.commit()
+ finally:
+ session.close()
+ except Exception as exc:
+ logger.warning("Failed to emit pipeline event: %s", exc)
+
+
+def _make_llm_callback(video_id: str, stage: str):
+ """Create an on_complete callback for LLMClient that emits llm_call events."""
+ def callback(*, model=None, prompt_tokens=None, completion_tokens=None,
+ total_tokens=None, content=None, finish_reason=None,
+ is_fallback=False, **_kwargs):
+ # Truncate content for storage — keep first 2000 chars for debugging
+ truncated = content[:2000] if content and len(content) > 2000 else content
+ _emit_event(
+ video_id=video_id,
+ stage=stage,
+ event_type="llm_call",
+ model=model,
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ total_tokens=total_tokens,
+ payload={
+ "content_preview": truncated,
+ "content_length": len(content) if content else 0,
+ "finish_reason": finish_reason,
+ "is_fallback": is_fallback,
+ },
+ )
+ return callback
+
# ── Helpers ──────────────────────────────────────────────────────────────────
_engine = None
@@ -175,6 +238,7 @@ def stage2_segmentation(self, video_id: str) -> str:
"""
start = time.monotonic()
logger.info("Stage 2 (segmentation) starting for video_id=%s", video_id)
+ _emit_event(video_id, "stage2_segmentation", "start")
session = _get_sync_session()
try:
@@ -207,9 +271,11 @@ def stage2_segmentation(self, video_id: str) -> str:
llm = _get_llm_client()
model_override, modality = _get_stage_config(2)
- logger.info("Stage 2 using model=%s, modality=%s", model_override or "default", modality)
- raw = llm.complete(system_prompt, user_prompt, response_model=SegmentationResult,
- modality=modality, model_override=model_override)
+ hard_limit = get_settings().llm_max_tokens_hard_limit
+ max_tokens = estimate_max_tokens(system_prompt, user_prompt, stage="stage2_segmentation", hard_limit=hard_limit)
+ logger.info("Stage 2 using model=%s, modality=%s, max_tokens=%d", model_override or "default", modality, max_tokens)
+ raw = llm.complete(system_prompt, user_prompt, response_model=SegmentationResult, on_complete=_make_llm_callback(video_id, "stage2_segmentation"),
+ modality=modality, model_override=model_override, max_tokens=max_tokens)
result = _safe_parse_llm_response(raw, SegmentationResult, llm, system_prompt, user_prompt,
modality=modality, model_override=model_override)
@@ -222,6 +288,7 @@ def stage2_segmentation(self, video_id: str) -> str:
session.commit()
elapsed = time.monotonic() - start
+ _emit_event(video_id, "stage2_segmentation", "complete")
logger.info(
"Stage 2 (segmentation) completed for video_id=%s in %.1fs — %d topic groups found",
video_id, elapsed, len(result.segments),
@@ -232,6 +299,7 @@ def stage2_segmentation(self, video_id: str) -> str:
raise # Don't retry missing prompt files
except Exception as exc:
session.rollback()
+ _emit_event(video_id, "stage2_segmentation", "error", payload={"error": str(exc)})
logger.error("Stage 2 failed for video_id=%s: %s", video_id, exc)
raise self.retry(exc=exc)
finally:
@@ -251,6 +319,7 @@ def stage3_extraction(self, video_id: str) -> str:
"""
start = time.monotonic()
logger.info("Stage 3 (extraction) starting for video_id=%s", video_id)
+ _emit_event(video_id, "stage3_extraction", "start")
session = _get_sync_session()
try:
@@ -278,6 +347,7 @@ def stage3_extraction(self, video_id: str) -> str:
system_prompt = _load_prompt("stage3_extraction.txt")
llm = _get_llm_client()
model_override, modality = _get_stage_config(3)
+ hard_limit = get_settings().llm_max_tokens_hard_limit
logger.info("Stage 3 using model=%s, modality=%s", model_override or "default", modality)
total_moments = 0
@@ -295,8 +365,9 @@ def stage3_extraction(self, video_id: str) -> str:
f"\n{segment_text}\n"
)
- raw = llm.complete(system_prompt, user_prompt, response_model=ExtractionResult,
- modality=modality, model_override=model_override)
+ max_tokens = estimate_max_tokens(system_prompt, user_prompt, stage="stage3_extraction", hard_limit=hard_limit)
+ raw = llm.complete(system_prompt, user_prompt, response_model=ExtractionResult, on_complete=_make_llm_callback(video_id, "stage3_extraction"),
+ modality=modality, model_override=model_override, max_tokens=max_tokens)
result = _safe_parse_llm_response(raw, ExtractionResult, llm, system_prompt, user_prompt,
modality=modality, model_override=model_override)
@@ -329,6 +400,7 @@ def stage3_extraction(self, video_id: str) -> str:
session.commit()
elapsed = time.monotonic() - start
+ _emit_event(video_id, "stage3_extraction", "complete")
logger.info(
"Stage 3 (extraction) completed for video_id=%s in %.1fs — %d moments created",
video_id, elapsed, total_moments,
@@ -339,6 +411,7 @@ def stage3_extraction(self, video_id: str) -> str:
raise
except Exception as exc:
session.rollback()
+ _emit_event(video_id, "stage3_extraction", "error", payload={"error": str(exc)})
logger.error("Stage 3 failed for video_id=%s: %s", video_id, exc)
raise self.retry(exc=exc)
finally:
@@ -361,6 +434,7 @@ def stage4_classification(self, video_id: str) -> str:
"""
start = time.monotonic()
logger.info("Stage 4 (classification) starting for video_id=%s", video_id)
+ _emit_event(video_id, "stage4_classification", "start")
session = _get_sync_session()
try:
@@ -404,9 +478,11 @@ def stage4_classification(self, video_id: str) -> str:
llm = _get_llm_client()
model_override, modality = _get_stage_config(4)
- logger.info("Stage 4 using model=%s, modality=%s", model_override or "default", modality)
- raw = llm.complete(system_prompt, user_prompt, response_model=ClassificationResult,
- modality=modality, model_override=model_override)
+ hard_limit = get_settings().llm_max_tokens_hard_limit
+ max_tokens = estimate_max_tokens(system_prompt, user_prompt, stage="stage4_classification", hard_limit=hard_limit)
+ logger.info("Stage 4 using model=%s, modality=%s, max_tokens=%d", model_override or "default", modality, max_tokens)
+ raw = llm.complete(system_prompt, user_prompt, response_model=ClassificationResult, on_complete=_make_llm_callback(video_id, "stage4_classification"),
+ modality=modality, model_override=model_override, max_tokens=max_tokens)
result = _safe_parse_llm_response(raw, ClassificationResult, llm, system_prompt, user_prompt,
modality=modality, model_override=model_override)
@@ -437,6 +513,7 @@ def stage4_classification(self, video_id: str) -> str:
_store_classification_data(video_id, classification_data)
elapsed = time.monotonic() - start
+ _emit_event(video_id, "stage4_classification", "complete")
logger.info(
"Stage 4 (classification) completed for video_id=%s in %.1fs — %d moments classified",
video_id, elapsed, len(classification_data),
@@ -447,6 +524,7 @@ def stage4_classification(self, video_id: str) -> str:
raise
except Exception as exc:
session.rollback()
+ _emit_event(video_id, "stage4_classification", "error", payload={"error": str(exc)})
logger.error("Stage 4 failed for video_id=%s: %s", video_id, exc)
raise self.retry(exc=exc)
finally:
@@ -539,6 +617,7 @@ def stage5_synthesis(self, video_id: str) -> str:
"""
start = time.monotonic()
logger.info("Stage 5 (synthesis) starting for video_id=%s", video_id)
+ _emit_event(video_id, "stage5_synthesis", "start")
settings = get_settings()
session = _get_sync_session()
@@ -576,6 +655,7 @@ def stage5_synthesis(self, video_id: str) -> str:
system_prompt = _load_prompt("stage5_synthesis.txt")
llm = _get_llm_client()
model_override, modality = _get_stage_config(5)
+ hard_limit = get_settings().llm_max_tokens_hard_limit
logger.info("Stage 5 using model=%s, modality=%s", model_override or "default", modality)
pages_created = 0
@@ -600,17 +680,39 @@ def stage5_synthesis(self, video_id: str) -> str:
user_prompt = f"\n{moments_text}\n"
- raw = llm.complete(system_prompt, user_prompt, response_model=SynthesisResult,
- modality=modality, model_override=model_override)
+ max_tokens = estimate_max_tokens(system_prompt, user_prompt, stage="stage5_synthesis", hard_limit=hard_limit)
+ raw = llm.complete(system_prompt, user_prompt, response_model=SynthesisResult, on_complete=_make_llm_callback(video_id, "stage5_synthesis"),
+ modality=modality, model_override=model_override, max_tokens=max_tokens)
result = _safe_parse_llm_response(raw, SynthesisResult, llm, system_prompt, user_prompt,
modality=modality, model_override=model_override)
+ # Load prior pages from this video (snapshot taken before pipeline reset)
+ prior_page_ids = _load_prior_pages(video_id)
+
# Create/update TechniquePage rows
for page_data in result.pages:
- # Check if page with this slug already exists
- existing = session.execute(
- select(TechniquePage).where(TechniquePage.slug == page_data.slug)
- ).scalar_one_or_none()
+ existing = None
+
+ # First: check prior pages from this video by creator + category
+ if prior_page_ids:
+ existing = session.execute(
+ select(TechniquePage).where(
+ TechniquePage.id.in_(prior_page_ids),
+ TechniquePage.creator_id == video.creator_id,
+ TechniquePage.topic_category == (page_data.topic_category or category),
+ )
+ ).scalar_one_or_none()
+ if existing:
+ logger.info(
+ "Stage 5: Matched prior page '%s' (id=%s) by creator+category for video_id=%s",
+ existing.slug, existing.id, video_id,
+ )
+
+ # Fallback: check by slug (handles cross-video dedup)
+ if existing is None:
+ existing = session.execute(
+ select(TechniquePage).where(TechniquePage.slug == page_data.slug)
+ ).scalar_one_or_none()
if existing:
# Snapshot existing content before overwriting
@@ -690,6 +792,7 @@ def stage5_synthesis(self, video_id: str) -> str:
session.commit()
elapsed = time.monotonic() - start
+ _emit_event(video_id, "stage5_synthesis", "complete")
logger.info(
"Stage 5 (synthesis) completed for video_id=%s in %.1fs — %d pages created/updated",
video_id, elapsed, pages_created,
@@ -700,6 +803,7 @@ def stage5_synthesis(self, video_id: str) -> str:
raise
except Exception as exc:
session.rollback()
+ _emit_event(video_id, "stage5_synthesis", "error", payload={"error": str(exc)})
logger.error("Stage 5 failed for video_id=%s: %s", video_id, exc)
raise self.retry(exc=exc)
finally:
@@ -837,6 +941,58 @@ def stage6_embed_and_index(self, video_id: str) -> str:
session.close()
+
+
+def _snapshot_prior_pages(video_id: str) -> None:
+ """Save existing technique_page_ids linked to this video before pipeline resets them.
+
+ When a video is reprocessed, stage 3 deletes and recreates key_moments,
+ breaking the link to technique pages. This snapshots the page IDs to Redis
+ so stage 5 can find and update prior pages instead of creating duplicates.
+ """
+ import redis
+
+ session = _get_sync_session()
+ try:
+ # Find technique pages linked via this video's key moments
+ rows = session.execute(
+ select(KeyMoment.technique_page_id)
+ .where(
+ KeyMoment.source_video_id == video_id,
+ KeyMoment.technique_page_id.isnot(None),
+ )
+ .distinct()
+ ).scalars().all()
+
+ page_ids = [str(pid) for pid in rows]
+
+ if page_ids:
+ settings = get_settings()
+ r = redis.Redis.from_url(settings.redis_url)
+ key = f"chrysopedia:prior_pages:{video_id}"
+ r.set(key, json.dumps(page_ids), ex=86400)
+ logger.info(
+ "Snapshot %d prior technique pages for video_id=%s: %s",
+ len(page_ids), video_id, page_ids,
+ )
+ else:
+ logger.info("No prior technique pages for video_id=%s", video_id)
+ finally:
+ session.close()
+
+
+def _load_prior_pages(video_id: str) -> list[str]:
+ """Load prior technique page IDs from Redis."""
+ import redis
+
+ settings = get_settings()
+ r = redis.Redis.from_url(settings.redis_url)
+ key = f"chrysopedia:prior_pages:{video_id}"
+ raw = r.get(key)
+ if raw is None:
+ return []
+ return json.loads(raw)
+
# ── Orchestrator ─────────────────────────────────────────────────────────────
@celery_app.task
@@ -870,6 +1026,9 @@ def run_pipeline(video_id: str) -> str:
finally:
session.close()
+ # Snapshot prior technique pages before pipeline resets key_moments
+ _snapshot_prior_pages(video_id)
+
# Build the chain based on current status
stages = []
if status in (ProcessingStatus.pending, ProcessingStatus.transcribed):