"""Synchronous LLM client with primary/fallback endpoint logic.

Uses the OpenAI-compatible API (works with Ollama, vLLM, OpenWebUI, etc.).
Celery tasks run synchronously, so this uses ``openai.OpenAI`` (not Async).

Supports two modalities:

- **chat**: Standard JSON mode with ``response_format: {"type": "json_object"}``
- **thinking**: For reasoning models that emit ``<think>...</think>`` blocks
  before their answer. Skips ``response_format``, appends JSON instructions to
  the system prompt, and strips think tags from the response.
"""

from __future__ import annotations

import logging
import re
from typing import TYPE_CHECKING, TypeVar

if TYPE_CHECKING:
    from collections.abc import Callable

import openai
from pydantic import BaseModel

from config import Settings

logger = logging.getLogger(__name__)

T = TypeVar("T", bound=BaseModel)


# ── Think-tag stripping ──────────────────────────────────────────────────────

# A complete ``<think>...</think>`` block. Non-greedy so that several blocks in
# one response are removed individually without swallowing the text between
# them; DOTALL so reasoning that spans multiple lines is matched.
_THINK_PATTERN = re.compile(r"<think>.*?</think>", re.DOTALL)

# Closing tag looked for after complete blocks are removed (see below).
_THINK_CLOSE_TAG = "</think>"


def strip_think_tags(text: str) -> str:
    """Remove ``<think>...</think>`` blocks from LLM output.

    Thinking/reasoning models often prefix their JSON with a reasoning trace
    wrapped in ``<think>`` tags. This strips all such blocks (including
    multiline and multiple occurrences) and returns the cleaned text with
    surrounding whitespace removed.

    Handles:
    - Single ``<think>...</think>`` block
    - Multiple blocks in one response
    - Multiline content inside think tags
    - A dangling ``</think>`` with no opening tag: everything up to and
      including the last such tag is treated as reasoning and dropped
    - Responses with no think tags (passthrough, whitespace-stripped)
    - Empty input (returned unchanged)
    """
    if not text:
        return text
    cleaned = _THINK_PATTERN.sub("", text)
    # Any closing tag still present has no opener. This can happen when the
    # opening tag is supplied by the prompt/template rather than generated.
    # In that case the answer is whatever follows the last closing tag.
    _, sep, tail = cleaned.rpartition(_THINK_CLOSE_TAG)
    if sep:
        cleaned = tail
    return cleaned.strip()


# ── Token estimation ─────────────────────────────────────────────────────────

# Stage-specific output multipliers: estimated output tokens as a ratio of input tokens.
# These are empirically tuned based on observed pipeline behavior.
_STAGE_OUTPUT_RATIOS: dict[str, float] = {
    "stage2_segmentation": 0.3,  # Compact topic groups — much smaller than input
    "stage3_extraction": 1.2,  # Detailed moments with summaries — can exceed input
    "stage4_classification": 0.15,  # Index + category + tags per moment — very compact
    "stage5_synthesis": 1.5,  # Full prose technique pages — heaviest output
}

# Minimum floor so we never send a trivially small max_tokens
_MIN_MAX_TOKENS = 2048


def estimate_tokens(text: str) -> int:
    """Estimate token count from text using a chars-per-token heuristic.

    Uses 3.5 chars/token which is conservative for English + JSON markup.
    Returns 0 for empty input and at least 1 for any non-empty input.
    """
    if not text:
        return 0
    return max(1, int(len(text) / 3.5))


def estimate_max_tokens(
    system_prompt: str,
    user_prompt: str,
    stage: str | None = None,
    hard_limit: int = 32768,
) -> int:
    """Estimate the max_tokens parameter for an LLM call.

    Calculates expected output size from the input size and a stage-specific
    multiplier, then adds a 20% buffer. The result is raised to at least
    ``_MIN_MAX_TOKENS`` and then capped at ``hard_limit``. If ``hard_limit``
    is below ``_MIN_MAX_TOKENS``, the ceiling wins.

    Parameters
    ----------
    system_prompt:
        The system prompt text.
    user_prompt:
        The user prompt text (transcript, moments, etc.).
    stage:
        Pipeline stage name (e.g. "stage3_extraction"). If None or unknown,
        uses a default 1.0x multiplier.
    hard_limit:
        Absolute ceiling — the returned value never exceeds this.

    Returns
    -------
    int
        Estimated max_tokens value to pass to the LLM API.
    """
    input_tokens = estimate_tokens(system_prompt) + estimate_tokens(user_prompt)
    ratio = _STAGE_OUTPUT_RATIOS.get(stage or "", 1.0)
    estimated_output = int(input_tokens * ratio)
    # Add a 20% buffer for JSON overhead and variability
    estimated_output = int(estimated_output * 1.2)

    # Apply the floor first and the ceiling last. This keeps hard_limit an
    # absolute cap even when it is smaller than _MIN_MAX_TOKENS.
    result = min(hard_limit, max(_MIN_MAX_TOKENS, estimated_output))
    logger.info(
        "Token estimate: input≈%d, stage=%s, ratio=%.2f, estimated_output=%d, max_tokens=%d (hard_limit=%d)",
        input_tokens,
        stage or "default",
        ratio,
        estimated_output,
        result,
        hard_limit,
    )
    return result


class LLMClient:
    """Sync LLM client that tries a primary endpoint and falls back on failure.

    Two ``openai.OpenAI`` clients are built from ``settings``. They share
    ``llm_api_key``; the primary uses ``llm_api_url`` and the fallback uses
    ``llm_fallback_url``.
    """

    # Appended to the system prompt in "thinking" modality, where
    # response_format is not sent.
    _JSON_ONLY_HINT = (
        "\n\nYou MUST respond with ONLY valid JSON. "
        "No markdown code fences, no explanation, no preamble — "
        "just the raw JSON object."
    )

    def __init__(self, settings: Settings) -> None:
        self.settings = settings
        self._primary = openai.OpenAI(
            base_url=settings.llm_api_url,
            api_key=settings.llm_api_key,
        )
        self._fallback = openai.OpenAI(
            base_url=settings.llm_fallback_url,
            api_key=settings.llm_api_key,
        )

    # ── Request / response helpers ───────────────────────────────────────

    @staticmethod
    def _prepare_request(
        system_prompt: str,
        response_model: type[BaseModel] | None,
        modality: str,
    ) -> tuple[str, dict]:
        """Return ``(effective_system_prompt, extra_create_kwargs)`` for a call.

        With no ``response_model``, the prompt is returned unchanged and no
        extra kwargs are added. Otherwise, "thinking" modality appends
        ``_JSON_ONLY_HINT`` to the prompt. Any other modality is treated as
        chat and requests ``response_format: json_object``.
        """
        if response_model is None:
            return system_prompt, {}
        if modality == "thinking":
            # Thinking models often don't support response_format: json_object.
            # Instead, append explicit JSON instructions to the system prompt.
            return system_prompt + LLMClient._JSON_ONLY_HINT, {}
        # Chat modality — use standard JSON mode
        return system_prompt, {"response_format": {"type": "json_object"}}

    @staticmethod
    def _finish_response(
        response,
        *,
        model: str,
        modality: str,
        on_complete: "Callable | None",
        is_fallback: bool = False,
    ) -> str:
        """Extract, log, post-process and report the text of a completion.

        Reads the first choice's content ("" if None) and logs token usage when
        the response has any. Think tags are stripped in "thinking" modality.
        ``on_complete`` is then invoked best-effort, and the text is returned.
        An empty ``choices`` list raises ``IndexError``.
        """
        choice = response.choices[0]
        raw = choice.message.content or ""
        usage = getattr(response, "usage", None)
        if usage:
            label = "LLM response (fallback)" if is_fallback else "LLM response"
            logger.info(
                "%s: prompt_tokens=%s, completion_tokens=%s, total=%s, content_len=%d, finish=%s",
                label,
                usage.prompt_tokens,
                usage.completion_tokens,
                usage.total_tokens,
                len(raw),
                choice.finish_reason,
            )

        if modality == "thinking":
            raw = strip_think_tags(raw)

        if on_complete is not None:
            callback_kwargs = {
                "model": model,
                "prompt_tokens": usage.prompt_tokens if usage else None,
                "completion_tokens": usage.completion_tokens if usage else None,
                "total_tokens": usage.total_tokens if usage else None,
                "content": raw,
                "finish_reason": choice.finish_reason,
            }
            # is_fallback is passed only (as True) for fallback responses;
            # primary responses omit the keyword entirely.
            if is_fallback:
                callback_kwargs["is_fallback"] = True
            try:
                on_complete(**callback_kwargs)
            except Exception as cb_exc:
                # Best-effort hook: a failing callback must not fail the call.
                logger.warning("on_complete callback failed: %s", cb_exc)
        return raw

    # ── Core completion ──────────────────────────────────────────────────

    def complete(
        self,
        system_prompt: str,
        user_prompt: str,
        response_model: type[BaseModel] | None = None,
        modality: str = "chat",
        model_override: str | None = None,
        on_complete: "Callable | None" = None,
        max_tokens: int | None = None,
    ) -> str:
        """Send a chat completion request, falling back on connection/timeout errors.

        The primary endpoint is tried first. Only ``openai.APIConnectionError``
        and ``openai.APITimeoutError`` from the primary cause a retry against
        the fallback endpoint. Any other exception from the primary propagates.

        Parameters
        ----------
        system_prompt:
            System message content.
        user_prompt:
            User message content.
        response_model:
            If provided and modality is "chat", ``response_format`` is set to
            ``{"type": "json_object"}``. For "thinking" modality, JSON
            instructions are appended to the system prompt instead. Only its
            presence and ``__name__`` (for logging) are used here.
        modality:
            Either "chat" (default) or "thinking". Thinking modality skips
            response_format and strips ``<think>`` tags from output.
        model_override:
            Model name to use on the primary endpoint instead of
            ``settings.llm_model``. The fallback always uses
            ``settings.llm_fallback_model``.
        on_complete:
            Optional callback invoked after a successful response with the
            keyword arguments ``model``, ``prompt_tokens``,
            ``completion_tokens``, ``total_tokens``, ``content`` and
            ``finish_reason``. ``is_fallback=True`` is added when the fallback
            answered. Token counts are None when the response has no usage.
            Exceptions raised by the callback are logged and suppressed.
        max_tokens:
            Override for max_tokens on this call. If None, falls back to the
            configured ``llm_max_tokens`` from settings.

        Returns
        -------
        str
            Raw completion text from the model (think tags stripped if
            thinking). Empty string when the model returned no content.

        Raises
        ------
        Exception
            Any error from the primary other than a connection/timeout error,
            or any ``openai.APIError`` from the fallback (logged, then
            re-raised).
        """
        effective_system, extra_kwargs = self._prepare_request(
            system_prompt, response_model, modality
        )
        messages = [
            {"role": "system", "content": effective_system},
            {"role": "user", "content": user_prompt},
        ]

        primary_model = model_override or self.settings.llm_model
        fallback_model = self.settings.llm_fallback_model
        effective_max_tokens = (
            max_tokens if max_tokens is not None else self.settings.llm_max_tokens
        )

        logger.info(
            "LLM request: model=%s, modality=%s, response_model=%s, max_tokens=%d",
            primary_model,
            modality,
            response_model.__name__ if response_model else None,
            effective_max_tokens,
        )

        # --- Try primary endpoint ---
        try:
            response = self._primary.chat.completions.create(
                model=primary_model,
                messages=messages,
                max_tokens=effective_max_tokens,
                **extra_kwargs,
            )
        except (openai.APIConnectionError, openai.APITimeoutError) as exc:
            logger.warning(
                "Primary LLM endpoint failed (%s: %s), trying fallback at %s",
                type(exc).__name__,
                exc,
                self.settings.llm_fallback_url,
            )
        else:
            return self._finish_response(
                response,
                model=primary_model,
                modality=modality,
                on_complete=on_complete,
            )

        # --- Try fallback endpoint ---
        try:
            response = self._fallback.chat.completions.create(
                model=fallback_model,
                messages=messages,
                max_tokens=effective_max_tokens,
                **extra_kwargs,
            )
        except openai.APIError as exc:
            # APIConnectionError and APITimeoutError are APIError subclasses.
            logger.error(
                "Fallback LLM endpoint also failed (%s: %s). Giving up.",
                type(exc).__name__,
                exc,
            )
            raise
        return self._finish_response(
            response,
            model=fallback_model,
            modality=modality,
            on_complete=on_complete,
            is_fallback=True,
        )

    # ── Response parsing ─────────────────────────────────────────────────

    def parse_response(self, text: str, model: type[T]) -> T:
        """Parse raw LLM output as JSON and validate against a Pydantic model.

        Parameters
        ----------
        text:
            Raw JSON string from the LLM.
        model:
            Pydantic model class to validate against.

        Returns
        -------
        T
            Validated Pydantic model instance.

        Raises
        ------
        pydantic.ValidationError
            If the text is not valid JSON or doesn't match the schema. It is a
            ``ValueError`` subclass, so ``except ValueError`` also catches it.
            The first 500 characters of ``text`` are logged before re-raising.
        """
        try:
            return model.model_validate_json(text)
        except Exception:
            logger.error(
                "Failed to parse LLM response as %s. Response text: %.500s",
                model.__name__,
                text,
            )
            raise