feat: Per-stage LLM model routing with thinking modality and think-tag stripping

- Added 8 per-stage config fields: llm_stage{2-5}_model and llm_stage{2-5}_modality
- LLMClient.complete() accepts modality ('chat'/'thinking') and model_override
- Thinking modality: appends JSON instructions to system prompt, strips <think> tags
- strip_think_tags() handles multiline, multiple blocks, and edge cases
- Pipeline stages 2-5 read per-stage config and pass to LLM client
- Updated .env.example with per-stage model/modality documentation
- All 59 tests pass including new think-tag stripping test
This commit is contained in:
jlightner 2026-03-30 02:12:14 +00:00
parent 9fdef3b720
commit 4aa4b08a7f
5 changed files with 187 additions and 20 deletions

View file

@ -16,6 +16,19 @@ LLM_MODEL=FYN-QWEN35
LLM_FALLBACK_URL=https://chat.forgetyour.name/api/v1
LLM_FALLBACK_MODEL=fyn-qwen35-chat
# Per-stage LLM model overrides (optional — defaults to LLM_MODEL)
# Modality: "chat" = standard JSON mode, "thinking" = reasoning model (strips <think> tags)
# Stages 2 (segmentation) and 4 (classification) are mechanical — use fast chat model
# Stages 3 (extraction) and 5 (synthesis) need reasoning — use thinking model
#LLM_STAGE2_MODEL=fyn-qwen35-chat
#LLM_STAGE2_MODALITY=chat
#LLM_STAGE3_MODEL=fyn-qwen35-thinking
#LLM_STAGE3_MODALITY=thinking
#LLM_STAGE4_MODEL=fyn-qwen35-chat
#LLM_STAGE4_MODALITY=chat
#LLM_STAGE5_MODEL=fyn-qwen35-thinking
#LLM_STAGE5_MODALITY=thinking
# Embedding endpoint (Ollama container in the compose stack)
EMBEDDING_API_URL=http://chrysopedia-ollama:11434/v1
EMBEDDING_MODEL=nomic-embed-text

View file

@ -33,6 +33,16 @@ class Settings(BaseSettings):
llm_fallback_url: str = "http://localhost:11434/v1"
llm_fallback_model: str = "qwen2.5:14b-q8_0"
# Per-stage model overrides (optional — falls back to llm_model / "chat")
llm_stage2_model: str | None = None # segmentation — fast chat model recommended
llm_stage2_modality: str = "chat" # "chat" or "thinking"
llm_stage3_model: str | None = None # extraction — thinking model recommended
llm_stage3_modality: str = "chat"
llm_stage4_model: str | None = None # classification — fast chat model recommended
llm_stage4_modality: str = "chat"
llm_stage5_model: str | None = None # synthesis — thinking model recommended
llm_stage5_modality: str = "chat"
# Embedding endpoint
embedding_api_url: str = "http://localhost:11434/v1"
embedding_model: str = "nomic-embed-text"

View file

@ -2,11 +2,18 @@
Uses the OpenAI-compatible API (works with Ollama, vLLM, OpenWebUI, etc.).
Celery tasks run synchronously, so this uses ``openai.OpenAI`` (not Async).
Supports two modalities:
- **chat**: Standard JSON mode with ``response_format: {"type": "json_object"}``
- **thinking**: For reasoning models that emit ``<think>...</think>`` blocks
before their answer. Skips ``response_format``, appends JSON instructions to
the system prompt, and strips think tags from the response.
"""
from __future__ import annotations
import logging
import re
from typing import TypeVar
import openai
@ -18,6 +25,30 @@ logger = logging.getLogger(__name__)
T = TypeVar("T", bound=BaseModel)
# ── Think-tag stripping ──────────────────────────────────────────────────────

_THINK_PATTERN = re.compile(r"<think>.*?</think>", re.DOTALL)


def strip_think_tags(text: str) -> str:
    """Strip every ``<think>...</think>`` block from *text*.

    Reasoning models commonly prefix their actual (JSON) answer with a
    chain-of-thought trace wrapped in ``<think>`` tags.  All such blocks are
    removed — single or multiple occurrences, including multiline content —
    and leading/trailing whitespace is trimmed from what remains.

    Falsy input (``""`` or ``None``) is returned unchanged, so callers can
    pass a possibly-empty completion straight through.
    """
    if not text:
        return text
    return _THINK_PATTERN.sub("", text).strip()
class LLMClient:
"""Sync LLM client that tries a primary endpoint and falls back on failure."""
@ -40,6 +71,8 @@ class LLMClient:
system_prompt: str,
user_prompt: str,
response_model: type[BaseModel] | None = None,
modality: str = "chat",
model_override: str | None = None,
) -> str:
"""Send a chat completion request, falling back on connection/timeout errors.
@ -50,31 +83,65 @@ class LLMClient:
user_prompt:
User message content.
response_model:
If provided, ``response_format`` is set to ``{"type": "json_object"}``
so the LLM returns parseable JSON.
If provided and modality is "chat", ``response_format`` is set to
``{"type": "json_object"}``. For "thinking" modality, JSON
instructions are appended to the system prompt instead.
modality:
Either "chat" (default) or "thinking". Thinking modality skips
response_format and strips ``<think>`` tags from output.
model_override:
Model name to use instead of the default. If None, uses the
configured default for the endpoint.
Returns
-------
str
Raw completion text from the model.
Raw completion text from the model (think tags stripped if thinking).
"""
kwargs: dict = {}
effective_system = system_prompt
if modality == "thinking":
# Thinking models often don't support response_format: json_object.
# Instead, append explicit JSON instructions to the system prompt.
if response_model is not None:
json_schema_hint = (
"\n\nYou MUST respond with ONLY valid JSON. "
"No markdown code fences, no explanation, no preamble — "
"just the raw JSON object."
)
effective_system = system_prompt + json_schema_hint
else:
# Chat modality — use standard JSON mode
if response_model is not None:
kwargs["response_format"] = {"type": "json_object"}
messages = [
{"role": "system", "content": system_prompt},
{"role": "system", "content": effective_system},
{"role": "user", "content": user_prompt},
]
primary_model = model_override or self.settings.llm_model
fallback_model = self.settings.llm_fallback_model
logger.info(
"LLM request: model=%s, modality=%s, response_model=%s",
primary_model,
modality,
response_model.__name__ if response_model else None,
)
# --- Try primary endpoint ---
try:
response = self._primary.chat.completions.create(
model=self.settings.llm_model,
model=primary_model,
messages=messages,
**kwargs,
)
return response.choices[0].message.content or ""
raw = response.choices[0].message.content or ""
if modality == "thinking":
raw = strip_think_tags(raw)
return raw
except (openai.APIConnectionError, openai.APITimeoutError) as exc:
logger.warning(
@ -87,11 +154,14 @@ class LLMClient:
# --- Try fallback endpoint ---
try:
response = self._fallback.chat.completions.create(
model=self.settings.llm_fallback_model,
model=fallback_model,
messages=messages,
**kwargs,
)
return response.choices[0].message.content or ""
raw = response.choices[0].message.content or ""
if modality == "thinking":
raw = strip_think_tags(raw)
return raw
except (openai.APIConnectionError, openai.APITimeoutError, openai.APIError) as exc:
logger.error(

View file

@ -87,6 +87,19 @@ def _get_llm_client() -> LLMClient:
return LLMClient(get_settings())
def _get_stage_config(stage_num: int) -> tuple[str | None, str]:
    """Return ``(model_override, modality)`` for pipeline stage *stage_num*.

    Looks up ``llm_stage{N}_model`` and ``llm_stage{N}_modality`` on the
    application Settings.  An unset or empty model collapses to ``None``
    (LLMClient then uses its configured default); an unset or empty
    modality collapses to ``"chat"``.
    """
    settings = get_settings()
    prefix = f"llm_stage{stage_num}"
    override = getattr(settings, f"{prefix}_model", None)
    mode = getattr(settings, f"{prefix}_modality", None)
    return (override or None), (mode or "chat")
def _load_canonical_tags() -> dict:
"""Load canonical tag taxonomy from config/canonical_tags.yaml."""
# Walk up from backend/ to find config/
@ -114,7 +127,15 @@ def _format_taxonomy_for_prompt(tags_data: dict) -> str:
return "\n".join(lines)
def _safe_parse_llm_response(raw: str, model_cls, llm: LLMClient, system_prompt: str, user_prompt: str):
def _safe_parse_llm_response(
raw: str,
model_cls,
llm: LLMClient,
system_prompt: str,
user_prompt: str,
modality: str = "chat",
model_override: str | None = None,
):
"""Parse LLM response with one retry on failure.
On malformed response: log the raw text, retry once with a JSON nudge,
@ -132,7 +153,10 @@ def _safe_parse_llm_response(raw: str, model_cls, llm: LLMClient, system_prompt:
)
# Retry with explicit JSON instruction
nudge_prompt = user_prompt + "\n\nIMPORTANT: Output ONLY valid JSON. No markdown, no explanation."
retry_raw = llm.complete(system_prompt, nudge_prompt, response_model=model_cls)
retry_raw = llm.complete(
system_prompt, nudge_prompt, response_model=model_cls,
modality=modality, model_override=model_override,
)
return llm.parse_response(retry_raw, model_cls)
@ -180,8 +204,12 @@ def stage2_segmentation(self, video_id: str) -> str:
user_prompt = f"<transcript>\n{transcript_text}\n</transcript>"
llm = _get_llm_client()
raw = llm.complete(system_prompt, user_prompt, response_model=SegmentationResult)
result = _safe_parse_llm_response(raw, SegmentationResult, llm, system_prompt, user_prompt)
model_override, modality = _get_stage_config(2)
logger.info("Stage 2 using model=%s, modality=%s", model_override or "default", modality)
raw = llm.complete(system_prompt, user_prompt, response_model=SegmentationResult,
modality=modality, model_override=model_override)
result = _safe_parse_llm_response(raw, SegmentationResult, llm, system_prompt, user_prompt,
modality=modality, model_override=model_override)
# Update topic_label on each segment row
seg_by_index = {s.segment_index: s for s in segments}
@ -247,6 +275,8 @@ def stage3_extraction(self, video_id: str) -> str:
system_prompt = _load_prompt("stage3_extraction.txt")
llm = _get_llm_client()
model_override, modality = _get_stage_config(3)
logger.info("Stage 3 using model=%s, modality=%s", model_override or "default", modality)
total_moments = 0
for topic_label, group_segs in groups.items():
@ -263,8 +293,10 @@ def stage3_extraction(self, video_id: str) -> str:
f"<segment>\n{segment_text}\n</segment>"
)
raw = llm.complete(system_prompt, user_prompt, response_model=ExtractionResult)
result = _safe_parse_llm_response(raw, ExtractionResult, llm, system_prompt, user_prompt)
raw = llm.complete(system_prompt, user_prompt, response_model=ExtractionResult,
modality=modality, model_override=model_override)
result = _safe_parse_llm_response(raw, ExtractionResult, llm, system_prompt, user_prompt,
modality=modality, model_override=model_override)
# Create KeyMoment rows
for moment in result.moments:
@ -369,8 +401,12 @@ def stage4_classification(self, video_id: str) -> str:
)
llm = _get_llm_client()
raw = llm.complete(system_prompt, user_prompt, response_model=ClassificationResult)
result = _safe_parse_llm_response(raw, ClassificationResult, llm, system_prompt, user_prompt)
model_override, modality = _get_stage_config(4)
logger.info("Stage 4 using model=%s, modality=%s", model_override or "default", modality)
raw = llm.complete(system_prompt, user_prompt, response_model=ClassificationResult,
modality=modality, model_override=model_override)
result = _safe_parse_llm_response(raw, ClassificationResult, llm, system_prompt, user_prompt,
modality=modality, model_override=model_override)
# Apply content_type overrides and prepare classification data for stage 5
classification_data = []
@ -490,6 +526,8 @@ def stage5_synthesis(self, video_id: str) -> str:
system_prompt = _load_prompt("stage5_synthesis.txt")
llm = _get_llm_client()
model_override, modality = _get_stage_config(5)
logger.info("Stage 5 using model=%s, modality=%s", model_override or "default", modality)
pages_created = 0
for category, moment_group in groups.items():
@ -513,8 +551,10 @@ def stage5_synthesis(self, video_id: str) -> str:
user_prompt = f"<moments>\n{moments_text}\n</moments>"
raw = llm.complete(system_prompt, user_prompt, response_model=SynthesisResult)
result = _safe_parse_llm_response(raw, SynthesisResult, llm, system_prompt, user_prompt)
raw = llm.complete(system_prompt, user_prompt, response_model=SynthesisResult,
modality=modality, model_override=model_override)
result = _safe_parse_llm_response(raw, SynthesisResult, llm, system_prompt, user_prompt,
modality=modality, model_override=model_override)
# Create/update TechniquePage rows
for page_data in result.pages:

View file

@ -737,3 +737,37 @@ def test_llm_fallback_on_primary_failure():
assert result == '{"result": "ok"}'
primary_client.chat.completions.create.assert_called_once()
fallback_client.chat.completions.create.assert_called_once()
# ── Think-tag stripping ─────────────────────────────────────────────────────


def test_strip_think_tags():
    """strip_think_tags should handle all edge cases correctly."""
    from pipeline.llm_client import strip_think_tags

    # Table of (raw LLM output, expected cleaned text) pairs.
    cases = [
        # Single block with JSON after
        ('<think>reasoning here</think>{"a": 1}', '{"a": 1}'),
        # Multiline think block
        (
            '<think>\nI need to analyze this.\nLet me think step by step.\n</think>\n{"result": "ok"}',
            '{"result": "ok"}',
        ),
        # Multiple think blocks
        ('<think>first</think>hello<think>second</think> world', "hello world"),
        # No think tags — passthrough
        ('{"clean": true}', '{"clean": true}'),
        # Empty string
        ("", ""),
        # Think block with special characters
        ('<think>analyzing "complex" <data> & stuff</think>{"done": true}', '{"done": true}'),
        # Only a think block, no actual content
        ("<think>just thinking</think>", ""),
    ]
    for raw, expected in cases:
        assert strip_think_tags(raw) == expected