feat: Per-stage LLM model routing with thinking modality and think-tag stripping
- Added 8 per-stage config fields: llm_stage{2-5}_model and llm_stage{2-5}_modality
- LLMClient.complete() accepts modality ('chat'/'thinking') and model_override
- Thinking modality: appends JSON instructions to system prompt, strips <think> tags
- strip_think_tags() handles multiline, multiple blocks, and edge cases
- Pipeline stages 2-5 read per-stage config and pass to LLM client
- Updated .env.example with per-stage model/modality documentation
- All 59 tests pass including new think-tag stripping test
This commit is contained in:
parent
9fdef3b720
commit
4aa4b08a7f
5 changed files with 187 additions and 20 deletions
13
.env.example
13
.env.example
|
|
@ -16,6 +16,19 @@ LLM_MODEL=FYN-QWEN35
|
||||||
LLM_FALLBACK_URL=https://chat.forgetyour.name/api/v1
|
LLM_FALLBACK_URL=https://chat.forgetyour.name/api/v1
|
||||||
LLM_FALLBACK_MODEL=fyn-qwen35-chat
|
LLM_FALLBACK_MODEL=fyn-qwen35-chat
|
||||||
|
|
||||||
|
# Per-stage LLM model overrides (optional — defaults to LLM_MODEL)
|
||||||
|
# Modality: "chat" = standard JSON mode, "thinking" = reasoning model (strips <think> tags)
|
||||||
|
# Stages 2 (segmentation) and 4 (classification) are mechanical — use fast chat model
|
||||||
|
# Stages 3 (extraction) and 5 (synthesis) need reasoning — use thinking model
|
||||||
|
#LLM_STAGE2_MODEL=fyn-qwen35-chat
|
||||||
|
#LLM_STAGE2_MODALITY=chat
|
||||||
|
#LLM_STAGE3_MODEL=fyn-qwen35-thinking
|
||||||
|
#LLM_STAGE3_MODALITY=thinking
|
||||||
|
#LLM_STAGE4_MODEL=fyn-qwen35-chat
|
||||||
|
#LLM_STAGE4_MODALITY=chat
|
||||||
|
#LLM_STAGE5_MODEL=fyn-qwen35-thinking
|
||||||
|
#LLM_STAGE5_MODALITY=thinking
|
||||||
|
|
||||||
# Embedding endpoint (Ollama container in the compose stack)
|
# Embedding endpoint (Ollama container in the compose stack)
|
||||||
EMBEDDING_API_URL=http://chrysopedia-ollama:11434/v1
|
EMBEDDING_API_URL=http://chrysopedia-ollama:11434/v1
|
||||||
EMBEDDING_MODEL=nomic-embed-text
|
EMBEDDING_MODEL=nomic-embed-text
|
||||||
|
|
|
||||||
|
|
@ -33,6 +33,16 @@ class Settings(BaseSettings):
|
||||||
llm_fallback_url: str = "http://localhost:11434/v1"
|
llm_fallback_url: str = "http://localhost:11434/v1"
|
||||||
llm_fallback_model: str = "qwen2.5:14b-q8_0"
|
llm_fallback_model: str = "qwen2.5:14b-q8_0"
|
||||||
|
|
||||||
|
# Per-stage model overrides (optional — falls back to llm_model / "chat")
|
||||||
|
llm_stage2_model: str | None = None # segmentation — fast chat model recommended
|
||||||
|
llm_stage2_modality: str = "chat" # "chat" or "thinking"
|
||||||
|
llm_stage3_model: str | None = None # extraction — thinking model recommended
|
||||||
|
llm_stage3_modality: str = "chat"
|
||||||
|
llm_stage4_model: str | None = None # classification — fast chat model recommended
|
||||||
|
llm_stage4_modality: str = "chat"
|
||||||
|
llm_stage5_model: str | None = None # synthesis — thinking model recommended
|
||||||
|
llm_stage5_modality: str = "chat"
|
||||||
|
|
||||||
# Embedding endpoint
|
# Embedding endpoint
|
||||||
embedding_api_url: str = "http://localhost:11434/v1"
|
embedding_api_url: str = "http://localhost:11434/v1"
|
||||||
embedding_model: str = "nomic-embed-text"
|
embedding_model: str = "nomic-embed-text"
|
||||||
|
|
|
||||||
|
|
@ -2,11 +2,18 @@
|
||||||
|
|
||||||
Uses the OpenAI-compatible API (works with Ollama, vLLM, OpenWebUI, etc.).
|
Uses the OpenAI-compatible API (works with Ollama, vLLM, OpenWebUI, etc.).
|
||||||
Celery tasks run synchronously, so this uses ``openai.OpenAI`` (not Async).
|
Celery tasks run synchronously, so this uses ``openai.OpenAI`` (not Async).
|
||||||
|
|
||||||
|
Supports two modalities:
|
||||||
|
- **chat**: Standard JSON mode with ``response_format: {"type": "json_object"}``
|
||||||
|
- **thinking**: For reasoning models that emit ``<think>...</think>`` blocks
|
||||||
|
before their answer. Skips ``response_format``, appends JSON instructions to
|
||||||
|
the system prompt, and strips think tags from the response.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
from typing import TypeVar
|
from typing import TypeVar
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
|
|
@ -18,6 +25,30 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
T = TypeVar("T", bound=BaseModel)
|
T = TypeVar("T", bound=BaseModel)
|
||||||
|
|
||||||
|
# ── Think-tag stripping ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
# ── Think-tag stripping ──────────────────────────────────────────────────────

# Closed reasoning blocks: non-greedy so multiple blocks in one response are
# each matched individually; DOTALL so blocks may span lines.
_THINK_PATTERN = re.compile(r"<think>.*?</think>", re.DOTALL)
# Unclosed trailing block: a model cut off mid-reasoning emits "<think>..."
# with no closing tag. Applied AFTER closed blocks are removed, so any
# remaining "<think>" has no matching "</think>" and the tail is dropped.
_UNCLOSED_THINK_PATTERN = re.compile(r"<think>.*\Z", re.DOTALL)


def strip_think_tags(text: str) -> str:
    """Remove ``<think>...</think>`` blocks from LLM output.

    Thinking/reasoning models often prefix their JSON with a reasoning trace
    wrapped in ``<think>`` tags. This strips all such blocks (including
    multiline and multiple occurrences) and returns the cleaned text.

    Handles:
    - Single ``<think>...</think>`` block
    - Multiple blocks in one response
    - Multiline content inside think tags
    - An unclosed trailing ``<think>`` block (response truncated
      mid-reasoning) — everything from the tag onward is dropped so the
      reasoning trace never leaks into downstream JSON parsing
    - Responses with no think tags (passthrough)
    - Empty input (passthrough)
    """
    if not text:
        return text
    cleaned = _THINK_PATTERN.sub("", text)
    cleaned = _UNCLOSED_THINK_PATTERN.sub("", cleaned)
    return cleaned.strip()
|
||||||
|
|
||||||
|
|
||||||
class LLMClient:
|
class LLMClient:
|
||||||
"""Sync LLM client that tries a primary endpoint and falls back on failure."""
|
"""Sync LLM client that tries a primary endpoint and falls back on failure."""
|
||||||
|
|
@ -40,6 +71,8 @@ class LLMClient:
|
||||||
system_prompt: str,
|
system_prompt: str,
|
||||||
user_prompt: str,
|
user_prompt: str,
|
||||||
response_model: type[BaseModel] | None = None,
|
response_model: type[BaseModel] | None = None,
|
||||||
|
modality: str = "chat",
|
||||||
|
model_override: str | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Send a chat completion request, falling back on connection/timeout errors.
|
"""Send a chat completion request, falling back on connection/timeout errors.
|
||||||
|
|
||||||
|
|
@ -50,31 +83,65 @@ class LLMClient:
|
||||||
user_prompt:
|
user_prompt:
|
||||||
User message content.
|
User message content.
|
||||||
response_model:
|
response_model:
|
||||||
If provided, ``response_format`` is set to ``{"type": "json_object"}``
|
If provided and modality is "chat", ``response_format`` is set to
|
||||||
so the LLM returns parseable JSON.
|
``{"type": "json_object"}``. For "thinking" modality, JSON
|
||||||
|
instructions are appended to the system prompt instead.
|
||||||
|
modality:
|
||||||
|
Either "chat" (default) or "thinking". Thinking modality skips
|
||||||
|
response_format and strips ``<think>`` tags from output.
|
||||||
|
model_override:
|
||||||
|
Model name to use instead of the default. If None, uses the
|
||||||
|
configured default for the endpoint.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
str
|
str
|
||||||
Raw completion text from the model.
|
Raw completion text from the model (think tags stripped if thinking).
|
||||||
"""
|
"""
|
||||||
kwargs: dict = {}
|
kwargs: dict = {}
|
||||||
if response_model is not None:
|
effective_system = system_prompt
|
||||||
kwargs["response_format"] = {"type": "json_object"}
|
|
||||||
|
if modality == "thinking":
|
||||||
|
# Thinking models often don't support response_format: json_object.
|
||||||
|
# Instead, append explicit JSON instructions to the system prompt.
|
||||||
|
if response_model is not None:
|
||||||
|
json_schema_hint = (
|
||||||
|
"\n\nYou MUST respond with ONLY valid JSON. "
|
||||||
|
"No markdown code fences, no explanation, no preamble — "
|
||||||
|
"just the raw JSON object."
|
||||||
|
)
|
||||||
|
effective_system = system_prompt + json_schema_hint
|
||||||
|
else:
|
||||||
|
# Chat modality — use standard JSON mode
|
||||||
|
if response_model is not None:
|
||||||
|
kwargs["response_format"] = {"type": "json_object"}
|
||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
{"role": "system", "content": system_prompt},
|
{"role": "system", "content": effective_system},
|
||||||
{"role": "user", "content": user_prompt},
|
{"role": "user", "content": user_prompt},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
primary_model = model_override or self.settings.llm_model
|
||||||
|
fallback_model = self.settings.llm_fallback_model
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"LLM request: model=%s, modality=%s, response_model=%s",
|
||||||
|
primary_model,
|
||||||
|
modality,
|
||||||
|
response_model.__name__ if response_model else None,
|
||||||
|
)
|
||||||
|
|
||||||
# --- Try primary endpoint ---
|
# --- Try primary endpoint ---
|
||||||
try:
|
try:
|
||||||
response = self._primary.chat.completions.create(
|
response = self._primary.chat.completions.create(
|
||||||
model=self.settings.llm_model,
|
model=primary_model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
return response.choices[0].message.content or ""
|
raw = response.choices[0].message.content or ""
|
||||||
|
if modality == "thinking":
|
||||||
|
raw = strip_think_tags(raw)
|
||||||
|
return raw
|
||||||
|
|
||||||
except (openai.APIConnectionError, openai.APITimeoutError) as exc:
|
except (openai.APIConnectionError, openai.APITimeoutError) as exc:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|
@ -87,11 +154,14 @@ class LLMClient:
|
||||||
# --- Try fallback endpoint ---
|
# --- Try fallback endpoint ---
|
||||||
try:
|
try:
|
||||||
response = self._fallback.chat.completions.create(
|
response = self._fallback.chat.completions.create(
|
||||||
model=self.settings.llm_fallback_model,
|
model=fallback_model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
return response.choices[0].message.content or ""
|
raw = response.choices[0].message.content or ""
|
||||||
|
if modality == "thinking":
|
||||||
|
raw = strip_think_tags(raw)
|
||||||
|
return raw
|
||||||
|
|
||||||
except (openai.APIConnectionError, openai.APITimeoutError, openai.APIError) as exc:
|
except (openai.APIConnectionError, openai.APITimeoutError, openai.APIError) as exc:
|
||||||
logger.error(
|
logger.error(
|
||||||
|
|
|
||||||
|
|
@ -87,6 +87,19 @@ def _get_llm_client() -> LLMClient:
|
||||||
return LLMClient(get_settings())
|
return LLMClient(get_settings())
|
||||||
|
|
||||||
|
|
||||||
|
def _get_stage_config(stage_num: int) -> tuple[str | None, str]:
    """Look up the per-stage LLM configuration for a pipeline stage.

    Returns a ``(model_override, modality)`` pair read from Settings fields
    ``llm_stage{N}_model`` / ``llm_stage{N}_modality``. An unset or empty
    stage model yields ``None`` (LLMClient then uses its configured default);
    an unset or empty stage modality yields ``"chat"``.
    """
    cfg = get_settings()
    raw_model = getattr(cfg, f"llm_stage{stage_num}_model", None)
    raw_modality = getattr(cfg, f"llm_stage{stage_num}_modality", None)
    # Normalize falsy values ("" or None) to the documented defaults.
    model_override = raw_model if raw_model else None
    modality = raw_modality if raw_modality else "chat"
    return model_override, modality
|
||||||
|
|
||||||
|
|
||||||
def _load_canonical_tags() -> dict:
|
def _load_canonical_tags() -> dict:
|
||||||
"""Load canonical tag taxonomy from config/canonical_tags.yaml."""
|
"""Load canonical tag taxonomy from config/canonical_tags.yaml."""
|
||||||
# Walk up from backend/ to find config/
|
# Walk up from backend/ to find config/
|
||||||
|
|
@ -114,7 +127,15 @@ def _format_taxonomy_for_prompt(tags_data: dict) -> str:
|
||||||
return "\n".join(lines)
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
def _safe_parse_llm_response(raw: str, model_cls, llm: LLMClient, system_prompt: str, user_prompt: str):
|
def _safe_parse_llm_response(
|
||||||
|
raw: str,
|
||||||
|
model_cls,
|
||||||
|
llm: LLMClient,
|
||||||
|
system_prompt: str,
|
||||||
|
user_prompt: str,
|
||||||
|
modality: str = "chat",
|
||||||
|
model_override: str | None = None,
|
||||||
|
):
|
||||||
"""Parse LLM response with one retry on failure.
|
"""Parse LLM response with one retry on failure.
|
||||||
|
|
||||||
On malformed response: log the raw text, retry once with a JSON nudge,
|
On malformed response: log the raw text, retry once with a JSON nudge,
|
||||||
|
|
@ -132,7 +153,10 @@ def _safe_parse_llm_response(raw: str, model_cls, llm: LLMClient, system_prompt:
|
||||||
)
|
)
|
||||||
# Retry with explicit JSON instruction
|
# Retry with explicit JSON instruction
|
||||||
nudge_prompt = user_prompt + "\n\nIMPORTANT: Output ONLY valid JSON. No markdown, no explanation."
|
nudge_prompt = user_prompt + "\n\nIMPORTANT: Output ONLY valid JSON. No markdown, no explanation."
|
||||||
retry_raw = llm.complete(system_prompt, nudge_prompt, response_model=model_cls)
|
retry_raw = llm.complete(
|
||||||
|
system_prompt, nudge_prompt, response_model=model_cls,
|
||||||
|
modality=modality, model_override=model_override,
|
||||||
|
)
|
||||||
return llm.parse_response(retry_raw, model_cls)
|
return llm.parse_response(retry_raw, model_cls)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -180,8 +204,12 @@ def stage2_segmentation(self, video_id: str) -> str:
|
||||||
user_prompt = f"<transcript>\n{transcript_text}\n</transcript>"
|
user_prompt = f"<transcript>\n{transcript_text}\n</transcript>"
|
||||||
|
|
||||||
llm = _get_llm_client()
|
llm = _get_llm_client()
|
||||||
raw = llm.complete(system_prompt, user_prompt, response_model=SegmentationResult)
|
model_override, modality = _get_stage_config(2)
|
||||||
result = _safe_parse_llm_response(raw, SegmentationResult, llm, system_prompt, user_prompt)
|
logger.info("Stage 2 using model=%s, modality=%s", model_override or "default", modality)
|
||||||
|
raw = llm.complete(system_prompt, user_prompt, response_model=SegmentationResult,
|
||||||
|
modality=modality, model_override=model_override)
|
||||||
|
result = _safe_parse_llm_response(raw, SegmentationResult, llm, system_prompt, user_prompt,
|
||||||
|
modality=modality, model_override=model_override)
|
||||||
|
|
||||||
# Update topic_label on each segment row
|
# Update topic_label on each segment row
|
||||||
seg_by_index = {s.segment_index: s for s in segments}
|
seg_by_index = {s.segment_index: s for s in segments}
|
||||||
|
|
@ -247,6 +275,8 @@ def stage3_extraction(self, video_id: str) -> str:
|
||||||
|
|
||||||
system_prompt = _load_prompt("stage3_extraction.txt")
|
system_prompt = _load_prompt("stage3_extraction.txt")
|
||||||
llm = _get_llm_client()
|
llm = _get_llm_client()
|
||||||
|
model_override, modality = _get_stage_config(3)
|
||||||
|
logger.info("Stage 3 using model=%s, modality=%s", model_override or "default", modality)
|
||||||
total_moments = 0
|
total_moments = 0
|
||||||
|
|
||||||
for topic_label, group_segs in groups.items():
|
for topic_label, group_segs in groups.items():
|
||||||
|
|
@ -263,8 +293,10 @@ def stage3_extraction(self, video_id: str) -> str:
|
||||||
f"<segment>\n{segment_text}\n</segment>"
|
f"<segment>\n{segment_text}\n</segment>"
|
||||||
)
|
)
|
||||||
|
|
||||||
raw = llm.complete(system_prompt, user_prompt, response_model=ExtractionResult)
|
raw = llm.complete(system_prompt, user_prompt, response_model=ExtractionResult,
|
||||||
result = _safe_parse_llm_response(raw, ExtractionResult, llm, system_prompt, user_prompt)
|
modality=modality, model_override=model_override)
|
||||||
|
result = _safe_parse_llm_response(raw, ExtractionResult, llm, system_prompt, user_prompt,
|
||||||
|
modality=modality, model_override=model_override)
|
||||||
|
|
||||||
# Create KeyMoment rows
|
# Create KeyMoment rows
|
||||||
for moment in result.moments:
|
for moment in result.moments:
|
||||||
|
|
@ -369,8 +401,12 @@ def stage4_classification(self, video_id: str) -> str:
|
||||||
)
|
)
|
||||||
|
|
||||||
llm = _get_llm_client()
|
llm = _get_llm_client()
|
||||||
raw = llm.complete(system_prompt, user_prompt, response_model=ClassificationResult)
|
model_override, modality = _get_stage_config(4)
|
||||||
result = _safe_parse_llm_response(raw, ClassificationResult, llm, system_prompt, user_prompt)
|
logger.info("Stage 4 using model=%s, modality=%s", model_override or "default", modality)
|
||||||
|
raw = llm.complete(system_prompt, user_prompt, response_model=ClassificationResult,
|
||||||
|
modality=modality, model_override=model_override)
|
||||||
|
result = _safe_parse_llm_response(raw, ClassificationResult, llm, system_prompt, user_prompt,
|
||||||
|
modality=modality, model_override=model_override)
|
||||||
|
|
||||||
# Apply content_type overrides and prepare classification data for stage 5
|
# Apply content_type overrides and prepare classification data for stage 5
|
||||||
classification_data = []
|
classification_data = []
|
||||||
|
|
@ -490,6 +526,8 @@ def stage5_synthesis(self, video_id: str) -> str:
|
||||||
|
|
||||||
system_prompt = _load_prompt("stage5_synthesis.txt")
|
system_prompt = _load_prompt("stage5_synthesis.txt")
|
||||||
llm = _get_llm_client()
|
llm = _get_llm_client()
|
||||||
|
model_override, modality = _get_stage_config(5)
|
||||||
|
logger.info("Stage 5 using model=%s, modality=%s", model_override or "default", modality)
|
||||||
pages_created = 0
|
pages_created = 0
|
||||||
|
|
||||||
for category, moment_group in groups.items():
|
for category, moment_group in groups.items():
|
||||||
|
|
@ -513,8 +551,10 @@ def stage5_synthesis(self, video_id: str) -> str:
|
||||||
|
|
||||||
user_prompt = f"<moments>\n{moments_text}\n</moments>"
|
user_prompt = f"<moments>\n{moments_text}\n</moments>"
|
||||||
|
|
||||||
raw = llm.complete(system_prompt, user_prompt, response_model=SynthesisResult)
|
raw = llm.complete(system_prompt, user_prompt, response_model=SynthesisResult,
|
||||||
result = _safe_parse_llm_response(raw, SynthesisResult, llm, system_prompt, user_prompt)
|
modality=modality, model_override=model_override)
|
||||||
|
result = _safe_parse_llm_response(raw, SynthesisResult, llm, system_prompt, user_prompt,
|
||||||
|
modality=modality, model_override=model_override)
|
||||||
|
|
||||||
# Create/update TechniquePage rows
|
# Create/update TechniquePage rows
|
||||||
for page_data in result.pages:
|
for page_data in result.pages:
|
||||||
|
|
|
||||||
|
|
@ -737,3 +737,37 @@ def test_llm_fallback_on_primary_failure():
|
||||||
assert result == '{"result": "ok"}'
|
assert result == '{"result": "ok"}'
|
||||||
primary_client.chat.completions.create.assert_called_once()
|
primary_client.chat.completions.create.assert_called_once()
|
||||||
fallback_client.chat.completions.create.assert_called_once()
|
fallback_client.chat.completions.create.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Think-tag stripping ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_strip_think_tags():
    """strip_think_tags should handle all edge cases correctly."""
    from pipeline.llm_client import strip_think_tags

    # (raw LLM output, expected cleaned text)
    cases = [
        # Single block with JSON after
        ('<think>reasoning here</think>{"a": 1}', '{"a": 1}'),
        # Multiline think block
        (
            '<think>\nI need to analyze this.\nLet me think step by step.\n</think>\n{"result": "ok"}',
            '{"result": "ok"}',
        ),
        # Multiple think blocks
        ('<think>first</think>hello<think>second</think> world', "hello world"),
        # No think tags — passthrough
        ('{"clean": true}', '{"clean": true}'),
        # Empty string
        ("", ""),
        # Think block with special characters
        ('<think>analyzing "complex" <data> & stuff</think>{"done": true}', '{"done": true}'),
        # Only a think block, no actual content
        ("<think>just thinking</think>", ""),
    ]
    for raw, expected in cases:
        assert strip_think_tags(raw) == expected, f"unexpected result for {raw!r}"
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue