From c6c15defee9f9f32b35a70bb94e8d6f7e165be66 Mon Sep 17 00:00:00 2001 From: jlightner Date: Mon, 30 Mar 2026 05:55:17 -0500 Subject: [PATCH] feat: Dynamic token estimation for per-stage max_tokens - Add estimate_tokens() and estimate_max_tokens() to llm_client with stage-specific output ratios (0.3x segmentation, 1.2x extraction, 0.15x classification, 1.5x synthesis) - Add max_tokens override parameter to LLMClient.complete() - Wire all 4 pipeline stages to estimate max_tokens from actual prompt content with 20% buffer and 2048 floor - Add LLM_MAX_TOKENS_HARD_LIMIT=32768 config (dynamic estimator ceiling) - Log token estimates alongside every LLM request Co-Authored-By: Claude Opus 4.6 (1M context) --- backend/config.py | 5 +- backend/pipeline/llm_client.py | 114 +++++++++++++++++++- backend/pipeline/stages.py | 189 ++++++++++++++++++++++++++++++--- 3 files changed, 287 insertions(+), 21 deletions(-) diff --git a/backend/config.py b/backend/config.py index d6e2b6f..482e601 100644 --- a/backend/config.py +++ b/backend/config.py @@ -43,8 +43,9 @@ class Settings(BaseSettings): llm_stage5_model: str | None = "fyn-llm-agent-think" # synthesis — reasoning llm_stage5_modality: str = "thinking" - # Max tokens for LLM responses (OpenWebUI defaults to 1000 which truncates pipeline JSON) - llm_max_tokens: int = 65536 + # Dynamic token estimation — each stage calculates max_tokens from input size + llm_max_tokens_hard_limit: int = 32768 # Hard ceiling for dynamic estimator + llm_max_tokens: int = 65536 # Fallback when no estimate is provided # Embedding endpoint embedding_api_url: str = "http://localhost:11434/v1" diff --git a/backend/pipeline/llm_client.py b/backend/pipeline/llm_client.py index f7d23f0..1c26f95 100644 --- a/backend/pipeline/llm_client.py +++ b/backend/pipeline/llm_client.py @@ -14,7 +14,10 @@ from __future__ import annotations import logging import re -from typing import TypeVar +from typing import TYPE_CHECKING, TypeVar + +if TYPE_CHECKING: + from collections.abc import Callable import openai from pydantic import BaseModel @@ -50,6 +53,77 @@ def strip_think_tags(text: str) -> str: return cleaned.strip() + +# ── Token estimation ───────────────────────────────────────────────────────── + +# Stage-specific output multipliers: estimated output tokens as a ratio of input tokens. +# These are empirically tuned based on observed pipeline behavior. +_STAGE_OUTPUT_RATIOS: dict[str, float] = { + "stage2_segmentation": 0.3, # Compact topic groups — much smaller than input + "stage3_extraction": 1.2, # Detailed moments with summaries — can exceed input + "stage4_classification": 0.15, # Index + category + tags per moment — very compact + "stage5_synthesis": 1.5, # Full prose technique pages — heaviest output +} + +# Minimum floor so we never send a trivially small max_tokens +_MIN_MAX_TOKENS = 2048 + + +def estimate_tokens(text: str) -> int: + """Estimate token count from text using a chars-per-token heuristic. + + Uses 3.5 chars/token which is conservative for English + JSON markup. + """ + if not text: + return 0 + return max(1, int(len(text) / 3.5)) + + +def estimate_max_tokens( + system_prompt: str, + user_prompt: str, + stage: str | None = None, + hard_limit: int = 32768, +) -> int: + """Estimate the max_tokens parameter for an LLM call. + + Calculates expected output size based on input size and stage-specific + multipliers. The result is clamped between _MIN_MAX_TOKENS and hard_limit. + + Parameters + ---------- + system_prompt: + The system prompt text. + user_prompt: + The user prompt text (transcript, moments, etc.). + stage: + Pipeline stage name (e.g. "stage3_extraction"). If None or unknown, + uses a default 1.0x multiplier. + hard_limit: + Absolute ceiling — never exceed this value. + + Returns + ------- + int + Estimated max_tokens value to pass to the LLM API. + """ + input_tokens = estimate_tokens(system_prompt) + estimate_tokens(user_prompt) + ratio = _STAGE_OUTPUT_RATIOS.get(stage or "", 1.0) + estimated_output = int(input_tokens * ratio) + + # Add a 20% buffer for JSON overhead and variability + estimated_output = int(estimated_output * 1.2) + + # Clamp to [_MIN_MAX_TOKENS, hard_limit] + result = max(_MIN_MAX_TOKENS, min(estimated_output, hard_limit)) + + logger.info( + "Token estimate: input≈%d, stage=%s, ratio=%.2f, estimated_output=%d, max_tokens=%d (hard_limit=%d)", + input_tokens, stage or "default", ratio, estimated_output, result, hard_limit, + ) + return result + + class LLMClient: """Sync LLM client that tries a primary endpoint and falls back on failure.""" @@ -73,6 +147,8 @@ class LLMClient: response_model: type[BaseModel] | None = None, modality: str = "chat", model_override: str | None = None, + on_complete: "Callable | None" = None, + max_tokens: int | None = None, ) -> str: """Send a chat completion request, falling back on connection/timeout errors. @@ -92,6 +168,9 @@ class LLMClient: model_override: Model name to use instead of the default. If None, uses the configured default for the endpoint. + max_tokens: + Override for max_tokens on this call. If None, falls back to + the configured ``llm_max_tokens`` from settings. Returns ------- @@ -123,12 +202,14 @@ class LLMClient: primary_model = model_override or self.settings.llm_model fallback_model = self.settings.llm_fallback_model + effective_max_tokens = max_tokens if max_tokens is not None else self.settings.llm_max_tokens logger.info( - "LLM request: model=%s, modality=%s, response_model=%s", + "LLM request: model=%s, modality=%s, response_model=%s, max_tokens=%d", primary_model, modality, response_model.__name__ if response_model else None, + effective_max_tokens, ) # --- Try primary endpoint --- @@ -136,7 +217,7 @@ class LLMClient: response = self._primary.chat.completions.create( model=primary_model, messages=messages, - max_tokens=self.settings.llm_max_tokens, + max_tokens=effective_max_tokens, **kwargs, ) raw = response.choices[0].message.content or "" @@ -149,6 +230,18 @@ class LLMClient: ) if modality == "thinking": raw = strip_think_tags(raw) + if on_complete is not None: + try: + on_complete( + model=primary_model, + prompt_tokens=usage.prompt_tokens if usage else None, + completion_tokens=usage.completion_tokens if usage else None, + total_tokens=usage.total_tokens if usage else None, + content=raw, + finish_reason=response.choices[0].finish_reason if response.choices else None, + ) + except Exception as cb_exc: + logger.warning("on_complete callback failed: %s", cb_exc) return raw except (openai.APIConnectionError, openai.APITimeoutError) as exc: @@ -164,7 +257,7 @@ class LLMClient: response = self._fallback.chat.completions.create( model=fallback_model, messages=messages, - max_tokens=self.settings.llm_max_tokens, + max_tokens=effective_max_tokens, **kwargs, ) raw = response.choices[0].message.content or "" @@ -177,6 +270,19 @@ class LLMClient: ) if modality == "thinking": raw = strip_think_tags(raw) + if on_complete is not None: + try: + on_complete( + model=fallback_model, + prompt_tokens=usage.prompt_tokens if usage else None, + completion_tokens=usage.completion_tokens if usage else None, + total_tokens=usage.total_tokens if usage else None, + content=raw, + finish_reason=response.choices[0].finish_reason if response.choices else None, + is_fallback=True, + ) + except Exception as cb_exc: + logger.warning("on_complete callback failed: %s", cb_exc) return raw except (openai.APIConnectionError, openai.APITimeoutError, openai.APIError) as exc: diff --git a/backend/pipeline/stages.py b/backend/pipeline/stages.py index 3a49f86..fdc2204 100644 --- a/backend/pipeline/stages.py +++ b/backend/pipeline/stages.py @@ -26,6 +26,7 @@ from config import get_settings from models import ( KeyMoment, KeyMomentContentType, + PipelineEvent, ProcessingStatus, SourceVideo, TechniquePage, @@ -33,7 +34,7 @@ from models import ( TranscriptSegment, ) from pipeline.embedding_client import EmbeddingClient -from pipeline.llm_client import LLMClient +from pipeline.llm_client import LLMClient, estimate_max_tokens from pipeline.qdrant_client import QdrantManager from pipeline.schemas import ( ClassificationResult, @@ -45,6 +46,68 @@ from worker import celery_app logger = logging.getLogger(__name__) + +# ── Pipeline event persistence ─────────────────────────────────────────────── + +def _emit_event( + video_id: str, + stage: str, + event_type: str, + *, + prompt_tokens: int | None = None, + completion_tokens: int | None = None, + total_tokens: int | None = None, + model: str | None = None, + duration_ms: int | None = None, + payload: dict | None = None, +) -> None: + """Persist a pipeline event to the DB. Best-effort -- failures logged, not raised.""" + try: + session = _get_sync_session() + try: + event = PipelineEvent( + video_id=video_id, + stage=stage, + event_type=event_type, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + model=model, + duration_ms=duration_ms, + payload=payload, + ) + session.add(event) + session.commit() + finally: + session.close() + except Exception as exc: + logger.warning("Failed to emit pipeline event: %s", exc) + + +def _make_llm_callback(video_id: str, stage: str): + """Create an on_complete callback for LLMClient that emits llm_call events.""" + def callback(*, model=None, prompt_tokens=None, completion_tokens=None, + total_tokens=None, content=None, finish_reason=None, + is_fallback=False, **_kwargs): + # Truncate content for storage — keep first 2000 chars for debugging + truncated = content[:2000] if content and len(content) > 2000 else content + _emit_event( + video_id=video_id, + stage=stage, + event_type="llm_call", + model=model, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + payload={ + "content_preview": truncated, + "content_length": len(content) if content else 0, + "finish_reason": finish_reason, + "is_fallback": is_fallback, + }, + ) + return callback + # ── Helpers ────────────────────────────────────────────────────────────────── _engine = None @@ -175,6 +238,7 @@ def stage2_segmentation(self, video_id: str) -> str: """ start = time.monotonic() logger.info("Stage 2 (segmentation) starting for video_id=%s", video_id) + _emit_event(video_id, "stage2_segmentation", "start") session = _get_sync_session() try: @@ -207,9 +271,11 @@ def stage2_segmentation(self, video_id: str) -> str: llm = _get_llm_client() model_override, modality = _get_stage_config(2) - logger.info("Stage 2 using model=%s, modality=%s", model_override or "default", modality) - raw = llm.complete(system_prompt, user_prompt, response_model=SegmentationResult, - modality=modality, model_override=model_override) + hard_limit = get_settings().llm_max_tokens_hard_limit + max_tokens = estimate_max_tokens(system_prompt, user_prompt, stage="stage2_segmentation", hard_limit=hard_limit) + logger.info("Stage 2 using model=%s, modality=%s, max_tokens=%d", model_override or "default", modality, max_tokens) + raw = llm.complete(system_prompt, user_prompt, response_model=SegmentationResult, on_complete=_make_llm_callback(video_id, "stage2_segmentation"), + modality=modality, model_override=model_override, max_tokens=max_tokens) result = _safe_parse_llm_response(raw, SegmentationResult, llm, system_prompt, user_prompt, modality=modality, model_override=model_override) @@ -222,6 +288,7 @@ def stage2_segmentation(self, video_id: str) -> str: session.commit() elapsed = time.monotonic() - start + _emit_event(video_id, "stage2_segmentation", "complete") logger.info( "Stage 2 (segmentation) completed for video_id=%s in %.1fs — %d topic groups found", video_id, elapsed, len(result.segments), @@ -232,6 +299,7 @@ def stage2_segmentation(self, video_id: str) -> str: raise # Don't retry missing prompt files except Exception as exc: session.rollback() + _emit_event(video_id, "stage2_segmentation", "error", payload={"error": str(exc)}) logger.error("Stage 2 failed for video_id=%s: %s", video_id, exc) raise self.retry(exc=exc) finally: @@ -251,6 +319,7 @@ def stage3_extraction(self, video_id: str) -> str: """ start = time.monotonic() logger.info("Stage 3 (extraction) starting for video_id=%s", video_id) + _emit_event(video_id, "stage3_extraction", "start") session = _get_sync_session() try: @@ -278,6 +347,7 @@ def stage3_extraction(self, video_id: str) -> str: system_prompt = _load_prompt("stage3_extraction.txt") llm = _get_llm_client() model_override, modality = _get_stage_config(3) + hard_limit = get_settings().llm_max_tokens_hard_limit logger.info("Stage 3 using model=%s, modality=%s", model_override or "default", modality) total_moments = 0 @@ -295,8 +365,9 @@ def stage3_extraction(self, video_id: str) -> str: f"\n{segment_text}\n" ) - raw = llm.complete(system_prompt, user_prompt, response_model=ExtractionResult, - modality=modality, model_override=model_override) + max_tokens = estimate_max_tokens(system_prompt, user_prompt, stage="stage3_extraction", hard_limit=hard_limit) + raw = llm.complete(system_prompt, user_prompt, response_model=ExtractionResult, on_complete=_make_llm_callback(video_id, "stage3_extraction"), + modality=modality, model_override=model_override, max_tokens=max_tokens) result = _safe_parse_llm_response(raw, ExtractionResult, llm, system_prompt, user_prompt, modality=modality, model_override=model_override) @@ -329,6 +400,7 @@ def stage3_extraction(self, video_id: str) -> str: session.commit() elapsed = time.monotonic() - start + _emit_event(video_id, "stage3_extraction", "complete") logger.info( "Stage 3 (extraction) completed for video_id=%s in %.1fs — %d moments created", video_id, elapsed, total_moments, @@ -339,6 +411,7 @@ def stage3_extraction(self, video_id: str) -> str: raise except Exception as exc: session.rollback() + _emit_event(video_id, "stage3_extraction", "error", payload={"error": str(exc)}) logger.error("Stage 3 failed for video_id=%s: %s", video_id, exc) raise self.retry(exc=exc) finally: @@ -361,6 +434,7 @@ def stage4_classification(self, video_id: str) -> str: """ start = time.monotonic() logger.info("Stage 4 (classification) starting for video_id=%s", video_id) + _emit_event(video_id, "stage4_classification", "start") session = _get_sync_session() try: @@ -404,9 +478,11 @@ def stage4_classification(self, video_id: str) -> str: llm = _get_llm_client() model_override, modality = _get_stage_config(4) - logger.info("Stage 4 using model=%s, modality=%s", model_override or "default", modality) - raw = llm.complete(system_prompt, user_prompt, response_model=ClassificationResult, - modality=modality, model_override=model_override) + hard_limit = get_settings().llm_max_tokens_hard_limit + max_tokens = estimate_max_tokens(system_prompt, user_prompt, stage="stage4_classification", hard_limit=hard_limit) + logger.info("Stage 4 using model=%s, modality=%s, max_tokens=%d", model_override or "default", modality, max_tokens) + raw = llm.complete(system_prompt, user_prompt, response_model=ClassificationResult, on_complete=_make_llm_callback(video_id, "stage4_classification"), + modality=modality, model_override=model_override, max_tokens=max_tokens) result = _safe_parse_llm_response(raw, ClassificationResult, llm, system_prompt, user_prompt, modality=modality, model_override=model_override) @@ -437,6 +513,7 @@ def stage4_classification(self, video_id: str) -> str: _store_classification_data(video_id, classification_data) elapsed = time.monotonic() - start + _emit_event(video_id, "stage4_classification", "complete") logger.info( "Stage 4 (classification) completed for video_id=%s in %.1fs — %d moments classified", video_id, elapsed, len(classification_data), @@ -447,6 +524,7 @@ def stage4_classification(self, video_id: str) -> str: raise except Exception as exc: session.rollback() + _emit_event(video_id, "stage4_classification", "error", payload={"error": str(exc)}) logger.error("Stage 4 failed for video_id=%s: %s", video_id, exc) raise self.retry(exc=exc) finally: @@ -539,6 +617,7 @@ def stage5_synthesis(self, video_id: str) -> str: """ start = time.monotonic() logger.info("Stage 5 (synthesis) starting for video_id=%s", video_id) + _emit_event(video_id, "stage5_synthesis", "start") settings = get_settings() session = _get_sync_session() @@ -576,6 +655,7 @@ def stage5_synthesis(self, video_id: str) -> str: system_prompt = _load_prompt("stage5_synthesis.txt") llm = _get_llm_client() model_override, modality = _get_stage_config(5) + hard_limit = get_settings().llm_max_tokens_hard_limit logger.info("Stage 5 using model=%s, modality=%s", model_override or "default", modality) pages_created = 0 @@ -600,17 +680,39 @@ def stage5_synthesis(self, video_id: str) -> str: user_prompt = f"\n{moments_text}\n" - raw = llm.complete(system_prompt, user_prompt, response_model=SynthesisResult, - modality=modality, model_override=model_override) + max_tokens = estimate_max_tokens(system_prompt, user_prompt, stage="stage5_synthesis", hard_limit=hard_limit) + raw = llm.complete(system_prompt, user_prompt, response_model=SynthesisResult, on_complete=_make_llm_callback(video_id, "stage5_synthesis"), + modality=modality, model_override=model_override, max_tokens=max_tokens) result = _safe_parse_llm_response(raw, SynthesisResult, llm, system_prompt, user_prompt, modality=modality, model_override=model_override) + # Load prior pages from this video (snapshot taken before pipeline reset) + prior_page_ids = _load_prior_pages(video_id) + # Create/update TechniquePage rows for page_data in result.pages: - # Check if page with this slug already exists - existing = session.execute( - select(TechniquePage).where(TechniquePage.slug == page_data.slug) - ).scalar_one_or_none() + existing = None + + # First: check prior pages from this video by creator + category + if prior_page_ids: + existing = session.execute( + select(TechniquePage).where( + TechniquePage.id.in_(prior_page_ids), + TechniquePage.creator_id == video.creator_id, + TechniquePage.topic_category == (page_data.topic_category or category), + ) + ).scalar_one_or_none() + if existing: + logger.info( + "Stage 5: Matched prior page '%s' (id=%s) by creator+category for video_id=%s", + existing.slug, existing.id, video_id, + ) + + # Fallback: check by slug (handles cross-video dedup) + if existing is None: + existing = session.execute( + select(TechniquePage).where(TechniquePage.slug == page_data.slug) + ).scalar_one_or_none() if existing: # Snapshot existing content before overwriting @@ -690,6 +792,7 @@ def stage5_synthesis(self, video_id: str) -> str: session.commit() elapsed = time.monotonic() - start + _emit_event(video_id, "stage5_synthesis", "complete") logger.info( "Stage 5 (synthesis) completed for video_id=%s in %.1fs — %d pages created/updated", video_id, elapsed, pages_created, @@ -700,6 +803,7 @@ def stage5_synthesis(self, video_id: str) -> str: raise except Exception as exc: session.rollback() + _emit_event(video_id, "stage5_synthesis", "error", payload={"error": str(exc)}) logger.error("Stage 5 failed for video_id=%s: %s", video_id, exc) raise self.retry(exc=exc) finally: @@ -837,6 +941,58 @@ def stage6_embed_and_index(self, video_id: str) -> str: session.close() + + +def _snapshot_prior_pages(video_id: str) -> None: + """Save existing technique_page_ids linked to this video before pipeline resets them. + + When a video is reprocessed, stage 3 deletes and recreates key_moments, + breaking the link to technique pages. This snapshots the page IDs to Redis + so stage 5 can find and update prior pages instead of creating duplicates. + """ + import redis + + session = _get_sync_session() + try: + # Find technique pages linked via this video's key moments + rows = session.execute( + select(KeyMoment.technique_page_id) + .where( + KeyMoment.source_video_id == video_id, + KeyMoment.technique_page_id.isnot(None), + ) + .distinct() + ).scalars().all() + + page_ids = [str(pid) for pid in rows] + + if page_ids: + settings = get_settings() + r = redis.Redis.from_url(settings.redis_url) + key = f"chrysopedia:prior_pages:{video_id}" + r.set(key, json.dumps(page_ids), ex=86400) + logger.info( + "Snapshot %d prior technique pages for video_id=%s: %s", + len(page_ids), video_id, page_ids, + ) + else: + logger.info("No prior technique pages for video_id=%s", video_id) + finally: + session.close() + + +def _load_prior_pages(video_id: str) -> list[str]: + """Load prior technique page IDs from Redis.""" + import redis + + settings = get_settings() + r = redis.Redis.from_url(settings.redis_url) + key = f"chrysopedia:prior_pages:{video_id}" + raw = r.get(key) + if raw is None: + return [] + return json.loads(raw) + # ── Orchestrator ───────────────────────────────────────────────────────────── @celery_app.task @@ -870,6 +1026,9 @@ def run_pipeline(video_id: str) -> str: finally: session.close() + # Snapshot prior technique pages before pipeline resets key_moments + _snapshot_prior_pages(video_id) + # Build the chain based on current status stages = [] if status in (ProcessingStatus.pending, ProcessingStatus.transcribed):