feat: Dynamic token estimation for per-stage max_tokens

- Add estimate_tokens() and estimate_max_tokens() to llm_client with
  stage-specific output ratios (0.3x segmentation, 1.2x extraction,
  0.15x classification, 1.5x synthesis)
- Add max_tokens override parameter to LLMClient.complete()
- Wire all 4 pipeline stages to estimate max_tokens from actual prompt
  content with 20% buffer and 2048 floor
- Add LLM_MAX_TOKENS_HARD_LIMIT=32768 config (dynamic estimator ceiling)
- Log token estimates alongside every LLM request

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
jlightner 2026-03-30 05:55:17 -05:00
parent df33d15360
commit c6c15defee
3 changed files with 287 additions and 21 deletions

View file

@ -43,8 +43,9 @@ class Settings(BaseSettings):
llm_stage5_model: str | None = "fyn-llm-agent-think" # synthesis — reasoning
llm_stage5_modality: str = "thinking"
# Max tokens for LLM responses (OpenWebUI defaults to 1000 which truncates pipeline JSON)
llm_max_tokens: int = 65536
# Dynamic token estimation — each stage calculates max_tokens from input size
llm_max_tokens_hard_limit: int = 32768 # Hard ceiling for dynamic estimator
llm_max_tokens: int = 65536 # Fallback when no estimate is provided
# Embedding endpoint
embedding_api_url: str = "http://localhost:11434/v1"

View file

@ -14,7 +14,10 @@ from __future__ import annotations
import logging
import re
from typing import TypeVar
from typing import TYPE_CHECKING, TypeVar
if TYPE_CHECKING:
from collections.abc import Callable
import openai
from pydantic import BaseModel
@ -50,6 +53,77 @@ def strip_think_tags(text: str) -> str:
return cleaned.strip()
# ── Token estimation ─────────────────────────────────────────────────────────
# Stage-specific output multipliers: estimated output tokens as a ratio of input tokens.
# These are empirically tuned based on observed pipeline behavior.
_STAGE_OUTPUT_RATIOS: dict[str, float] = {
"stage2_segmentation": 0.3, # Compact topic groups — much smaller than input
"stage3_extraction": 1.2, # Detailed moments with summaries — can exceed input
"stage4_classification": 0.15, # Index + category + tags per moment — very compact
"stage5_synthesis": 1.5, # Full prose technique pages — heaviest output
}
# Minimum floor so we never send a trivially small max_tokens
_MIN_MAX_TOKENS = 2048
def estimate_tokens(text: str) -> int:
"""Estimate token count from text using a chars-per-token heuristic.
Uses 3.5 chars/token which is conservative for English + JSON markup.
"""
if not text:
return 0
return max(1, int(len(text) / 3.5))
def estimate_max_tokens(
system_prompt: str,
user_prompt: str,
stage: str | None = None,
hard_limit: int = 32768,
) -> int:
"""Estimate the max_tokens parameter for an LLM call.
Calculates expected output size based on input size and stage-specific
multipliers. The result is clamped between _MIN_MAX_TOKENS and hard_limit.
Parameters
----------
system_prompt:
The system prompt text.
user_prompt:
The user prompt text (transcript, moments, etc.).
stage:
Pipeline stage name (e.g. "stage3_extraction"). If None or unknown,
uses a default 1.0x multiplier.
hard_limit:
Absolute ceiling never exceed this value.
Returns
-------
int
Estimated max_tokens value to pass to the LLM API.
"""
input_tokens = estimate_tokens(system_prompt) + estimate_tokens(user_prompt)
ratio = _STAGE_OUTPUT_RATIOS.get(stage or "", 1.0)
estimated_output = int(input_tokens * ratio)
# Add a 20% buffer for JSON overhead and variability
estimated_output = int(estimated_output * 1.2)
# Clamp to [_MIN_MAX_TOKENS, hard_limit]
result = max(_MIN_MAX_TOKENS, min(estimated_output, hard_limit))
logger.info(
"Token estimate: input≈%d, stage=%s, ratio=%.2f, estimated_output=%d, max_tokens=%d (hard_limit=%d)",
input_tokens, stage or "default", ratio, estimated_output, result, hard_limit,
)
return result
class LLMClient:
"""Sync LLM client that tries a primary endpoint and falls back on failure."""
@ -73,6 +147,8 @@ class LLMClient:
response_model: type[BaseModel] | None = None,
modality: str = "chat",
model_override: str | None = None,
on_complete: "Callable | None" = None,
max_tokens: int | None = None,
) -> str:
"""Send a chat completion request, falling back on connection/timeout errors.
@ -92,6 +168,9 @@ class LLMClient:
model_override:
Model name to use instead of the default. If None, uses the
configured default for the endpoint.
max_tokens:
Override for max_tokens on this call. If None, falls back to
the configured ``llm_max_tokens`` from settings.
Returns
-------
@ -123,12 +202,14 @@ class LLMClient:
primary_model = model_override or self.settings.llm_model
fallback_model = self.settings.llm_fallback_model
effective_max_tokens = max_tokens if max_tokens is not None else self.settings.llm_max_tokens
logger.info(
"LLM request: model=%s, modality=%s, response_model=%s",
"LLM request: model=%s, modality=%s, response_model=%s, max_tokens=%d",
primary_model,
modality,
response_model.__name__ if response_model else None,
effective_max_tokens,
)
# --- Try primary endpoint ---
@ -136,7 +217,7 @@ class LLMClient:
response = self._primary.chat.completions.create(
model=primary_model,
messages=messages,
max_tokens=self.settings.llm_max_tokens,
max_tokens=effective_max_tokens,
**kwargs,
)
raw = response.choices[0].message.content or ""
@ -149,6 +230,18 @@ class LLMClient:
)
if modality == "thinking":
raw = strip_think_tags(raw)
if on_complete is not None:
try:
on_complete(
model=primary_model,
prompt_tokens=usage.prompt_tokens if usage else None,
completion_tokens=usage.completion_tokens if usage else None,
total_tokens=usage.total_tokens if usage else None,
content=raw,
finish_reason=response.choices[0].finish_reason if response.choices else None,
)
except Exception as cb_exc:
logger.warning("on_complete callback failed: %s", cb_exc)
return raw
except (openai.APIConnectionError, openai.APITimeoutError) as exc:
@ -164,7 +257,7 @@ class LLMClient:
response = self._fallback.chat.completions.create(
model=fallback_model,
messages=messages,
max_tokens=self.settings.llm_max_tokens,
max_tokens=effective_max_tokens,
**kwargs,
)
raw = response.choices[0].message.content or ""
@ -177,6 +270,19 @@ class LLMClient:
)
if modality == "thinking":
raw = strip_think_tags(raw)
if on_complete is not None:
try:
on_complete(
model=fallback_model,
prompt_tokens=usage.prompt_tokens if usage else None,
completion_tokens=usage.completion_tokens if usage else None,
total_tokens=usage.total_tokens if usage else None,
content=raw,
finish_reason=response.choices[0].finish_reason if response.choices else None,
is_fallback=True,
)
except Exception as cb_exc:
logger.warning("on_complete callback failed: %s", cb_exc)
return raw
except (openai.APIConnectionError, openai.APITimeoutError, openai.APIError) as exc:

View file

@ -26,6 +26,7 @@ from config import get_settings
from models import (
KeyMoment,
KeyMomentContentType,
PipelineEvent,
ProcessingStatus,
SourceVideo,
TechniquePage,
@ -33,7 +34,7 @@ from models import (
TranscriptSegment,
)
from pipeline.embedding_client import EmbeddingClient
from pipeline.llm_client import LLMClient
from pipeline.llm_client import LLMClient, estimate_max_tokens
from pipeline.qdrant_client import QdrantManager
from pipeline.schemas import (
ClassificationResult,
@ -45,6 +46,68 @@ from worker import celery_app
logger = logging.getLogger(__name__)
# ── Pipeline event persistence ───────────────────────────────────────────────
def _emit_event(
video_id: str,
stage: str,
event_type: str,
*,
prompt_tokens: int | None = None,
completion_tokens: int | None = None,
total_tokens: int | None = None,
model: str | None = None,
duration_ms: int | None = None,
payload: dict | None = None,
) -> None:
"""Persist a pipeline event to the DB. Best-effort -- failures logged, not raised."""
try:
session = _get_sync_session()
try:
event = PipelineEvent(
video_id=video_id,
stage=stage,
event_type=event_type,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
model=model,
duration_ms=duration_ms,
payload=payload,
)
session.add(event)
session.commit()
finally:
session.close()
except Exception as exc:
logger.warning("Failed to emit pipeline event: %s", exc)
def _make_llm_callback(video_id: str, stage: str):
"""Create an on_complete callback for LLMClient that emits llm_call events."""
def callback(*, model=None, prompt_tokens=None, completion_tokens=None,
total_tokens=None, content=None, finish_reason=None,
is_fallback=False, **_kwargs):
# Truncate content for storage — keep first 2000 chars for debugging
truncated = content[:2000] if content and len(content) > 2000 else content
_emit_event(
video_id=video_id,
stage=stage,
event_type="llm_call",
model=model,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
payload={
"content_preview": truncated,
"content_length": len(content) if content else 0,
"finish_reason": finish_reason,
"is_fallback": is_fallback,
},
)
return callback
# ── Helpers ──────────────────────────────────────────────────────────────────
_engine = None
@ -175,6 +238,7 @@ def stage2_segmentation(self, video_id: str) -> str:
"""
start = time.monotonic()
logger.info("Stage 2 (segmentation) starting for video_id=%s", video_id)
_emit_event(video_id, "stage2_segmentation", "start")
session = _get_sync_session()
try:
@ -207,9 +271,11 @@ def stage2_segmentation(self, video_id: str) -> str:
llm = _get_llm_client()
model_override, modality = _get_stage_config(2)
logger.info("Stage 2 using model=%s, modality=%s", model_override or "default", modality)
raw = llm.complete(system_prompt, user_prompt, response_model=SegmentationResult,
modality=modality, model_override=model_override)
hard_limit = get_settings().llm_max_tokens_hard_limit
max_tokens = estimate_max_tokens(system_prompt, user_prompt, stage="stage2_segmentation", hard_limit=hard_limit)
logger.info("Stage 2 using model=%s, modality=%s, max_tokens=%d", model_override or "default", modality, max_tokens)
raw = llm.complete(system_prompt, user_prompt, response_model=SegmentationResult, on_complete=_make_llm_callback(video_id, "stage2_segmentation"),
modality=modality, model_override=model_override, max_tokens=max_tokens)
result = _safe_parse_llm_response(raw, SegmentationResult, llm, system_prompt, user_prompt,
modality=modality, model_override=model_override)
@ -222,6 +288,7 @@ def stage2_segmentation(self, video_id: str) -> str:
session.commit()
elapsed = time.monotonic() - start
_emit_event(video_id, "stage2_segmentation", "complete")
logger.info(
"Stage 2 (segmentation) completed for video_id=%s in %.1fs — %d topic groups found",
video_id, elapsed, len(result.segments),
@ -232,6 +299,7 @@ def stage2_segmentation(self, video_id: str) -> str:
raise # Don't retry missing prompt files
except Exception as exc:
session.rollback()
_emit_event(video_id, "stage2_segmentation", "error", payload={"error": str(exc)})
logger.error("Stage 2 failed for video_id=%s: %s", video_id, exc)
raise self.retry(exc=exc)
finally:
@ -251,6 +319,7 @@ def stage3_extraction(self, video_id: str) -> str:
"""
start = time.monotonic()
logger.info("Stage 3 (extraction) starting for video_id=%s", video_id)
_emit_event(video_id, "stage3_extraction", "start")
session = _get_sync_session()
try:
@ -278,6 +347,7 @@ def stage3_extraction(self, video_id: str) -> str:
system_prompt = _load_prompt("stage3_extraction.txt")
llm = _get_llm_client()
model_override, modality = _get_stage_config(3)
hard_limit = get_settings().llm_max_tokens_hard_limit
logger.info("Stage 3 using model=%s, modality=%s", model_override or "default", modality)
total_moments = 0
@ -295,8 +365,9 @@ def stage3_extraction(self, video_id: str) -> str:
f"<segment>\n{segment_text}\n</segment>"
)
raw = llm.complete(system_prompt, user_prompt, response_model=ExtractionResult,
modality=modality, model_override=model_override)
max_tokens = estimate_max_tokens(system_prompt, user_prompt, stage="stage3_extraction", hard_limit=hard_limit)
raw = llm.complete(system_prompt, user_prompt, response_model=ExtractionResult, on_complete=_make_llm_callback(video_id, "stage3_extraction"),
modality=modality, model_override=model_override, max_tokens=max_tokens)
result = _safe_parse_llm_response(raw, ExtractionResult, llm, system_prompt, user_prompt,
modality=modality, model_override=model_override)
@ -329,6 +400,7 @@ def stage3_extraction(self, video_id: str) -> str:
session.commit()
elapsed = time.monotonic() - start
_emit_event(video_id, "stage3_extraction", "complete")
logger.info(
"Stage 3 (extraction) completed for video_id=%s in %.1fs — %d moments created",
video_id, elapsed, total_moments,
@ -339,6 +411,7 @@ def stage3_extraction(self, video_id: str) -> str:
raise
except Exception as exc:
session.rollback()
_emit_event(video_id, "stage3_extraction", "error", payload={"error": str(exc)})
logger.error("Stage 3 failed for video_id=%s: %s", video_id, exc)
raise self.retry(exc=exc)
finally:
@ -361,6 +434,7 @@ def stage4_classification(self, video_id: str) -> str:
"""
start = time.monotonic()
logger.info("Stage 4 (classification) starting for video_id=%s", video_id)
_emit_event(video_id, "stage4_classification", "start")
session = _get_sync_session()
try:
@ -404,9 +478,11 @@ def stage4_classification(self, video_id: str) -> str:
llm = _get_llm_client()
model_override, modality = _get_stage_config(4)
logger.info("Stage 4 using model=%s, modality=%s", model_override or "default", modality)
raw = llm.complete(system_prompt, user_prompt, response_model=ClassificationResult,
modality=modality, model_override=model_override)
hard_limit = get_settings().llm_max_tokens_hard_limit
max_tokens = estimate_max_tokens(system_prompt, user_prompt, stage="stage4_classification", hard_limit=hard_limit)
logger.info("Stage 4 using model=%s, modality=%s, max_tokens=%d", model_override or "default", modality, max_tokens)
raw = llm.complete(system_prompt, user_prompt, response_model=ClassificationResult, on_complete=_make_llm_callback(video_id, "stage4_classification"),
modality=modality, model_override=model_override, max_tokens=max_tokens)
result = _safe_parse_llm_response(raw, ClassificationResult, llm, system_prompt, user_prompt,
modality=modality, model_override=model_override)
@ -437,6 +513,7 @@ def stage4_classification(self, video_id: str) -> str:
_store_classification_data(video_id, classification_data)
elapsed = time.monotonic() - start
_emit_event(video_id, "stage4_classification", "complete")
logger.info(
"Stage 4 (classification) completed for video_id=%s in %.1fs — %d moments classified",
video_id, elapsed, len(classification_data),
@ -447,6 +524,7 @@ def stage4_classification(self, video_id: str) -> str:
raise
except Exception as exc:
session.rollback()
_emit_event(video_id, "stage4_classification", "error", payload={"error": str(exc)})
logger.error("Stage 4 failed for video_id=%s: %s", video_id, exc)
raise self.retry(exc=exc)
finally:
@ -539,6 +617,7 @@ def stage5_synthesis(self, video_id: str) -> str:
"""
start = time.monotonic()
logger.info("Stage 5 (synthesis) starting for video_id=%s", video_id)
_emit_event(video_id, "stage5_synthesis", "start")
settings = get_settings()
session = _get_sync_session()
@ -576,6 +655,7 @@ def stage5_synthesis(self, video_id: str) -> str:
system_prompt = _load_prompt("stage5_synthesis.txt")
llm = _get_llm_client()
model_override, modality = _get_stage_config(5)
hard_limit = get_settings().llm_max_tokens_hard_limit
logger.info("Stage 5 using model=%s, modality=%s", model_override or "default", modality)
pages_created = 0
@ -600,14 +680,36 @@ def stage5_synthesis(self, video_id: str) -> str:
user_prompt = f"<moments>\n{moments_text}\n</moments>"
raw = llm.complete(system_prompt, user_prompt, response_model=SynthesisResult,
modality=modality, model_override=model_override)
max_tokens = estimate_max_tokens(system_prompt, user_prompt, stage="stage5_synthesis", hard_limit=hard_limit)
raw = llm.complete(system_prompt, user_prompt, response_model=SynthesisResult, on_complete=_make_llm_callback(video_id, "stage5_synthesis"),
modality=modality, model_override=model_override, max_tokens=max_tokens)
result = _safe_parse_llm_response(raw, SynthesisResult, llm, system_prompt, user_prompt,
modality=modality, model_override=model_override)
# Load prior pages from this video (snapshot taken before pipeline reset)
prior_page_ids = _load_prior_pages(video_id)
# Create/update TechniquePage rows
for page_data in result.pages:
# Check if page with this slug already exists
existing = None
# First: check prior pages from this video by creator + category
if prior_page_ids:
existing = session.execute(
select(TechniquePage).where(
TechniquePage.id.in_(prior_page_ids),
TechniquePage.creator_id == video.creator_id,
TechniquePage.topic_category == (page_data.topic_category or category),
)
).scalar_one_or_none()
if existing:
logger.info(
"Stage 5: Matched prior page '%s' (id=%s) by creator+category for video_id=%s",
existing.slug, existing.id, video_id,
)
# Fallback: check by slug (handles cross-video dedup)
if existing is None:
existing = session.execute(
select(TechniquePage).where(TechniquePage.slug == page_data.slug)
).scalar_one_or_none()
@ -690,6 +792,7 @@ def stage5_synthesis(self, video_id: str) -> str:
session.commit()
elapsed = time.monotonic() - start
_emit_event(video_id, "stage5_synthesis", "complete")
logger.info(
"Stage 5 (synthesis) completed for video_id=%s in %.1fs — %d pages created/updated",
video_id, elapsed, pages_created,
@ -700,6 +803,7 @@ def stage5_synthesis(self, video_id: str) -> str:
raise
except Exception as exc:
session.rollback()
_emit_event(video_id, "stage5_synthesis", "error", payload={"error": str(exc)})
logger.error("Stage 5 failed for video_id=%s: %s", video_id, exc)
raise self.retry(exc=exc)
finally:
@ -837,6 +941,58 @@ def stage6_embed_and_index(self, video_id: str) -> str:
session.close()
def _snapshot_prior_pages(video_id: str) -> None:
"""Save existing technique_page_ids linked to this video before pipeline resets them.
When a video is reprocessed, stage 3 deletes and recreates key_moments,
breaking the link to technique pages. This snapshots the page IDs to Redis
so stage 5 can find and update prior pages instead of creating duplicates.
"""
import redis
session = _get_sync_session()
try:
# Find technique pages linked via this video's key moments
rows = session.execute(
select(KeyMoment.technique_page_id)
.where(
KeyMoment.source_video_id == video_id,
KeyMoment.technique_page_id.isnot(None),
)
.distinct()
).scalars().all()
page_ids = [str(pid) for pid in rows]
if page_ids:
settings = get_settings()
r = redis.Redis.from_url(settings.redis_url)
key = f"chrysopedia:prior_pages:{video_id}"
r.set(key, json.dumps(page_ids), ex=86400)
logger.info(
"Snapshot %d prior technique pages for video_id=%s: %s",
len(page_ids), video_id, page_ids,
)
else:
logger.info("No prior technique pages for video_id=%s", video_id)
finally:
session.close()
def _load_prior_pages(video_id: str) -> list[str]:
"""Load prior technique page IDs from Redis."""
import redis
settings = get_settings()
r = redis.Redis.from_url(settings.redis_url)
key = f"chrysopedia:prior_pages:{video_id}"
raw = r.get(key)
if raw is None:
return []
return json.loads(raw)
# ── Orchestrator ─────────────────────────────────────────────────────────────
@celery_app.task
@ -870,6 +1026,9 @@ def run_pipeline(video_id: str) -> str:
finally:
session.close()
# Snapshot prior technique pages before pipeline resets key_moments
_snapshot_prior_pages(video_id)
# Build the chain based on current status
stages = []
if status in (ProcessingStatus.pending, ProcessingStatus.transcribed):