chrysopedia/backend/pipeline/llm_client.py
jlightner 906b6491fe fix: static 96k max_tokens for all pipeline stages — dynamic estimator was truncating thinking model output
The dynamic token estimator calculated max_tokens from input size × stage ratio,
which produced ~9k for stage 5 compose calls. Thinking models consume unpredictable
budget for internal reasoning, leaving 0 visible output tokens.

Changed: hard_limit 32768→96000, estimate_max_tokens now returns hard_limit directly.
2026-04-03 08:18:28 +00:00

357 lines
14 KiB
Python

"""Synchronous LLM client with primary/fallback endpoint logic.
Uses the OpenAI-compatible API (works with Ollama, vLLM, OpenWebUI, etc.).
Celery tasks run synchronously, so this uses ``openai.OpenAI`` (not Async).
Supports two modalities:
- **chat**: Standard JSON mode with ``response_format: {"type": "json_object"}``
- **thinking**: For reasoning models that emit ``<think>...</think>`` blocks
before their answer. Skips ``response_format``, appends JSON instructions to
the system prompt, and strips think tags from the response.
"""
from __future__ import annotations
import logging
import re
from typing import TYPE_CHECKING, TypeVar
if TYPE_CHECKING:
from collections.abc import Callable
import openai
from pydantic import BaseModel
from config import Settings
logger = logging.getLogger(__name__)
T = TypeVar("T", bound=BaseModel)
# ── LLM Response wrapper ─────────────────────────────────────────────────────
class LLMResponse(str):
"""String subclass that carries LLM response metadata.
Backward-compatible with all code that treats the response as a plain
string, but callers that know about it can inspect finish_reason and
the truncated property.
"""
finish_reason: str | None
prompt_tokens: int | None
completion_tokens: int | None
def __new__(
cls,
text: str,
finish_reason: str | None = None,
prompt_tokens: int | None = None,
completion_tokens: int | None = None,
):
obj = super().__new__(cls, text)
obj.finish_reason = finish_reason
obj.prompt_tokens = prompt_tokens
obj.completion_tokens = completion_tokens
return obj
@property
def truncated(self) -> bool:
"""True if the model hit its token limit before finishing."""
return self.finish_reason == "length"
# ── Think-tag stripping ──────────────────────────────────────────────────────
# Matches one complete reasoning block; DOTALL lets it span newlines.
_THINK_PATTERN = re.compile(r"<think>.*?</think>", re.DOTALL)


def strip_think_tags(text: str) -> str:
    """Drop every ``<think>...</think>`` block and trim the remainder.

    Reasoning models commonly prefix their JSON answer with one or more
    think traces, possibly multiline. All such blocks are removed and the
    surviving text is stripped of surrounding whitespace. Input with no
    think tags is returned stripped; empty input passes through unchanged.
    """
    if not text:
        return text
    return _THINK_PATTERN.sub("", text).strip()
# ── Token estimation ─────────────────────────────────────────────────────────
# Stage-specific output multipliers: estimated output tokens as a ratio of input tokens.
# Tuned from actual pipeline data (KCL Ep 31 audit, April 2026):
# stage2: actual compl/prompt = 680/39312 = 0.017 → use 0.05 with buffer
# stage3: actual compl/prompt ≈ 1000/7000 = 0.14 → use 0.3 with buffer
# stage4: actual compl/prompt = 740/3736 = 0.20 → use 0.3 with buffer
# stage5: actual compl/prompt ≈ 2500/7000 = 0.36 → use 0.8 with buffer
# NOTE(review): since estimate_max_tokens switched to returning the static
# hard_limit, neither these ratios nor the floor below are read by any code
# visible in this file. Retained for reference / potential rollback — confirm
# no external importers before deleting.
_STAGE_OUTPUT_RATIOS: dict[str, float] = {
"stage2_segmentation": 0.05, # Compact topic groups — much smaller than input
"stage3_extraction": 0.3, # Key moments with summaries — moderate
"stage4_classification": 0.3, # Tags + categories per moment
"stage5_synthesis": 0.8, # Full prose technique pages — heaviest output
}
# Minimum floor so we never send a trivially small max_tokens
_MIN_MAX_TOKENS = 4096
def estimate_tokens(text: str, chars_per_token: float = 3.5) -> int:
    """Estimate the token count of *text* using a chars-per-token heuristic.

    Parameters
    ----------
    text:
        Text to measure. Empty input yields 0.
    chars_per_token:
        Average characters per token. The 3.5 default is conservative for
        English prose mixed with JSON markup; callers targeting denser
        tokenizations may lower it. (New keyword parameter with the previous
        hard-coded value as default — fully backward compatible.)

    Returns
    -------
    int
        0 for empty input; otherwise at least 1, so a non-empty string is
        never estimated as zero tokens.
    """
    if not text:
        return 0
    return max(1, int(len(text) / chars_per_token))
def estimate_max_tokens(
    system_prompt: str,
    user_prompt: str,
    stage: str | None = None,
    hard_limit: int = 32768,
) -> int:
    """Return ``hard_limit`` unconditionally as the max_tokens budget.

    Dynamic estimation (input size × stage-specific ratio) was abandoned:
    thinking models spend an unpredictable share of the budget on internal
    reasoning, which could leave no room for visible output. The input-size
    estimate is still computed, but only to enrich the log line.

    In production the hard_limit comes from Settings.llm_max_tokens_hard_limit.
    """
    approx_input = sum(estimate_tokens(part) for part in (system_prompt, user_prompt))
    logger.info(
        "Token estimate: input≈%d, stage=%s, max_tokens=%d (static hard_limit)",
        approx_input,
        stage or "default",
        hard_limit,
    )
    return hard_limit
class LLMClient:
    """Sync LLM client that tries a primary endpoint and falls back on failure.

    Both endpoints speak the OpenAI-compatible chat-completions API and share
    one API key; they differ only in base URL and (optionally) model name.
    The primary and fallback request paths were previously duplicated inline
    and had started to drift; they now share :meth:`_attempt`, with the
    per-path differences (log label, ``is_fallback`` callback flag, caught
    exception set) kept exactly as before.
    """

    def __init__(self, settings: Settings) -> None:
        self.settings = settings
        self._primary = openai.OpenAI(
            base_url=settings.llm_api_url,
            api_key=settings.llm_api_key,
        )
        self._fallback = openai.OpenAI(
            base_url=settings.llm_fallback_url,
            api_key=settings.llm_api_key,
        )

    # ── Core completion ──────────────────────────────────────────────────
    def complete(
        self,
        system_prompt: str,
        user_prompt: str,
        response_model: type[BaseModel] | None = None,
        modality: str = "chat",
        model_override: str | None = None,
        on_complete: "Callable | None" = None,
        max_tokens: int | None = None,
    ) -> "LLMResponse":
        """Send a chat completion request, falling back on connection/timeout errors.

        Parameters
        ----------
        system_prompt:
            System message content.
        user_prompt:
            User message content.
        response_model:
            If provided and modality is "chat", ``response_format`` is set to
            ``{"type": "json_object"}``. For "thinking" modality, JSON
            instructions are appended to the system prompt instead.
        modality:
            Either "chat" (default) or "thinking". Thinking modality skips
            response_format and strips ``<think>`` tags from output.
        model_override:
            Model name to use instead of the default. If None, uses the
            configured default for the endpoint.
        on_complete:
            Optional observer callback invoked with token usage and content;
            exceptions it raises are logged, never propagated.
        max_tokens:
            Override for max_tokens on this call. If None, falls back to
            the configured ``llm_max_tokens`` from settings.

        Returns
        -------
        LLMResponse
            Raw completion text (str subclass) with finish_reason metadata.

        Raises
        ------
        openai.APIError
            (or subclasses) when the fallback endpoint also fails.
        """
        extra_kwargs: dict = {}
        effective_system = system_prompt
        if modality == "thinking":
            # Thinking models often don't support response_format: json_object.
            # Instead, append explicit JSON instructions to the system prompt.
            if response_model is not None:
                effective_system = system_prompt + (
                    "\n\nYou MUST respond with ONLY valid JSON. "
                    "No markdown code fences, no explanation, no preamble — "
                    "just the raw JSON object."
                )
        elif response_model is not None:
            # Chat modality — use standard JSON mode
            extra_kwargs["response_format"] = {"type": "json_object"}

        messages = [
            {"role": "system", "content": effective_system},
            {"role": "user", "content": user_prompt},
        ]
        primary_model = model_override or self.settings.llm_model
        effective_max_tokens = (
            max_tokens if max_tokens is not None else self.settings.llm_max_tokens
        )
        effective_temperature = self.settings.llm_temperature
        logger.info(
            "LLM request: model=%s, modality=%s, response_model=%s, max_tokens=%d, temperature=%.1f",
            primary_model,
            modality,
            response_model.__name__ if response_model else None,
            effective_max_tokens,
            effective_temperature,
        )

        # --- Try primary endpoint ---
        try:
            return self._attempt(
                client=self._primary,
                model=primary_model,
                messages=messages,
                max_tokens=effective_max_tokens,
                temperature=effective_temperature,
                extra_kwargs=extra_kwargs,
                modality=modality,
                on_complete=on_complete,
                is_fallback=False,
            )
        except (openai.APIConnectionError, openai.APITimeoutError) as exc:
            # Only connectivity-class errors trigger the fallback; other
            # APIErrors from the primary propagate to the caller unchanged.
            logger.warning(
                "Primary LLM endpoint failed (%s: %s), trying fallback at %s",
                type(exc).__name__,
                exc,
                self.settings.llm_fallback_url,
            )

        # --- Try fallback endpoint ---
        try:
            return self._attempt(
                client=self._fallback,
                model=self.settings.llm_fallback_model,
                messages=messages,
                max_tokens=effective_max_tokens,
                temperature=effective_temperature,
                extra_kwargs=extra_kwargs,
                modality=modality,
                on_complete=on_complete,
                is_fallback=True,
            )
        except (openai.APIConnectionError, openai.APITimeoutError, openai.APIError) as exc:
            logger.error(
                "Fallback LLM endpoint also failed (%s: %s). Giving up.",
                type(exc).__name__,
                exc,
            )
            raise

    def _attempt(
        self,
        *,
        client: "openai.OpenAI",
        model: str,
        messages: list,
        max_tokens: int,
        temperature: float,
        extra_kwargs: dict,
        modality: str,
        on_complete: "Callable | None",
        is_fallback: bool,
    ) -> "LLMResponse":
        """Issue one chat-completions call against *client* and wrap the result.

        Shared by the primary and fallback paths of :meth:`complete`. Logs
        token usage, strips think tags for "thinking" modality, notifies the
        ``on_complete`` observer (swallowing its errors), and returns an
        :class:`LLMResponse`. API exceptions propagate to the caller, which
        decides whether to fall back or give up.
        """
        response = client.chat.completions.create(
            model=model,
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
            **extra_kwargs,
        )
        raw = response.choices[0].message.content or ""
        usage = getattr(response, "usage", None)
        if usage:
            logger.info(
                "%s: prompt_tokens=%s, completion_tokens=%s, total=%s, content_len=%d, finish=%s",
                "LLM response (fallback)" if is_fallback else "LLM response",
                usage.prompt_tokens,
                usage.completion_tokens,
                usage.total_tokens,
                len(raw),
                response.choices[0].finish_reason,
            )
        if modality == "thinking":
            raw = strip_think_tags(raw)
        finish = response.choices[0].finish_reason if response.choices else None
        if on_complete is not None:
            cb_kwargs: dict = {
                "model": model,
                "prompt_tokens": usage.prompt_tokens if usage else None,
                "completion_tokens": usage.completion_tokens if usage else None,
                "total_tokens": usage.total_tokens if usage else None,
                "content": raw,
                "finish_reason": finish,
            }
            if is_fallback:
                # Historical contract: the flag is passed only on fallback
                # calls, so primary-path callbacks need not accept it.
                cb_kwargs["is_fallback"] = True
            try:
                on_complete(**cb_kwargs)
            except Exception as cb_exc:
                # A broken observer must never fail the pipeline call itself.
                logger.warning("on_complete callback failed: %s", cb_exc)
        return LLMResponse(
            raw,
            finish_reason=finish,
            prompt_tokens=usage.prompt_tokens if usage else None,
            completion_tokens=usage.completion_tokens if usage else None,
        )

    # ── Response parsing ─────────────────────────────────────────────────
    def parse_response(self, text: str, model: type[T]) -> T:
        """Parse raw LLM output as JSON and validate against a Pydantic model.

        Parameters
        ----------
        text:
            Raw JSON string from the LLM.
        model:
            Pydantic model class to validate against.

        Returns
        -------
        T
            Validated Pydantic model instance.

        Raises
        ------
        pydantic.ValidationError
            If the JSON doesn't match the schema.
        ValueError
            If the text is not valid JSON.
        """
        try:
            return model.model_validate_json(text)
        except Exception:
            # Log a bounded prefix of the response so bad output is diagnosable
            # without flooding the log, then re-raise for the caller/retry layer.
            logger.error(
                "Failed to parse LLM response as %s. Response text: %.500s",
                model.__name__,
                text,
            )
            raise