diff --git a/.env.example b/.env.example index 0ddbc51..ac7a180 100644 --- a/.env.example +++ b/.env.example @@ -16,6 +16,19 @@ LLM_MODEL=FYN-QWEN35 LLM_FALLBACK_URL=https://chat.forgetyour.name/api/v1 LLM_FALLBACK_MODEL=fyn-qwen35-chat +# Per-stage LLM model overrides (optional — defaults to LLM_MODEL) +# Modality: "chat" = standard JSON mode, "thinking" = reasoning model (strips tags) +# Stages 2 (segmentation) and 4 (classification) are mechanical — use fast chat model +# Stages 3 (extraction) and 5 (synthesis) need reasoning — use thinking model +#LLM_STAGE2_MODEL=fyn-qwen35-chat +#LLM_STAGE2_MODALITY=chat +#LLM_STAGE3_MODEL=fyn-qwen35-thinking +#LLM_STAGE3_MODALITY=thinking +#LLM_STAGE4_MODEL=fyn-qwen35-chat +#LLM_STAGE4_MODALITY=chat +#LLM_STAGE5_MODEL=fyn-qwen35-thinking +#LLM_STAGE5_MODALITY=thinking + # Embedding endpoint (Ollama container in the compose stack) EMBEDDING_API_URL=http://chrysopedia-ollama:11434/v1 EMBEDDING_MODEL=nomic-embed-text diff --git a/backend/config.py b/backend/config.py index 878e983..8a2b9a9 100644 --- a/backend/config.py +++ b/backend/config.py @@ -33,6 +33,16 @@ class Settings(BaseSettings): llm_fallback_url: str = "http://localhost:11434/v1" llm_fallback_model: str = "qwen2.5:14b-q8_0" + # Per-stage model overrides (optional — falls back to llm_model / "chat") + llm_stage2_model: str | None = None # segmentation — fast chat model recommended + llm_stage2_modality: str = "chat" # "chat" or "thinking" + llm_stage3_model: str | None = None # extraction — thinking model recommended + llm_stage3_modality: str = "chat" + llm_stage4_model: str | None = None # classification — fast chat model recommended + llm_stage4_modality: str = "chat" + llm_stage5_model: str | None = None # synthesis — thinking model recommended + llm_stage5_modality: str = "chat" + # Embedding endpoint embedding_api_url: str = "http://localhost:11434/v1" embedding_model: str = "nomic-embed-text" diff --git a/backend/pipeline/llm_client.py b/backend/pipeline/llm_client.py index 1ee095c..3c6671a 100644 --- a/backend/pipeline/llm_client.py +++ b/backend/pipeline/llm_client.py @@ -2,11 +2,18 @@ Uses the OpenAI-compatible API (works with Ollama, vLLM, OpenWebUI, etc.). Celery tasks run synchronously, so this uses ``openai.OpenAI`` (not Async). + +Supports two modalities: +- **chat**: Standard JSON mode with ``response_format: {"type": "json_object"}`` +- **thinking**: For reasoning models that emit ``...`` blocks + before their answer. Skips ``response_format``, appends JSON instructions to + the system prompt, and strips think tags from the response. """ from __future__ import annotations import logging +import re from typing import TypeVar import openai @@ -18,6 +25,30 @@ logger = logging.getLogger(__name__) T = TypeVar("T", bound=BaseModel) +# ── Think-tag stripping ────────────────────────────────────────────────────── + +_THINK_PATTERN = re.compile(r".*?", re.DOTALL) + + +def strip_think_tags(text: str) -> str: + """Remove ``...`` blocks from LLM output. + + Thinking/reasoning models often prefix their JSON with a reasoning trace + wrapped in ```` tags. This strips all such blocks (including + multiline and multiple occurrences) and returns the cleaned text. + + Handles: + - Single ``...`` block + - Multiple blocks in one response + - Multiline content inside think tags + - Responses with no think tags (passthrough) + - Empty input (passthrough) + """ + if not text: + return text + cleaned = _THINK_PATTERN.sub("", text) + return cleaned.strip() + class LLMClient: """Sync LLM client that tries a primary endpoint and falls back on failure.""" @@ -40,6 +71,8 @@ class LLMClient: system_prompt: str, user_prompt: str, response_model: type[BaseModel] | None = None, + modality: str = "chat", + model_override: str | None = None, ) -> str: """Send a chat completion request, falling back on connection/timeout errors. @@ -50,31 +83,65 @@ class LLMClient: user_prompt: User message content. response_model: - If provided, ``response_format`` is set to ``{"type": "json_object"}`` - so the LLM returns parseable JSON. + If provided and modality is "chat", ``response_format`` is set to + ``{"type": "json_object"}``. For "thinking" modality, JSON + instructions are appended to the system prompt instead. + modality: + Either "chat" (default) or "thinking". Thinking modality skips + response_format and strips ```` tags from output. + model_override: + Model name to use instead of the default. If None, uses the + configured default for the endpoint. Returns ------- str - Raw completion text from the model. + Raw completion text from the model (think tags stripped if thinking). """ kwargs: dict = {} - if response_model is not None: - kwargs["response_format"] = {"type": "json_object"} + effective_system = system_prompt + + if modality == "thinking": + # Thinking models often don't support response_format: json_object. + # Instead, append explicit JSON instructions to the system prompt. + if response_model is not None: + json_schema_hint = ( + "\n\nYou MUST respond with ONLY valid JSON. " + "No markdown code fences, no explanation, no preamble — " + "just the raw JSON object." + ) + effective_system = system_prompt + json_schema_hint + else: + # Chat modality — use standard JSON mode + if response_model is not None: + kwargs["response_format"] = {"type": "json_object"} messages = [ - {"role": "system", "content": system_prompt}, + {"role": "system", "content": effective_system}, {"role": "user", "content": user_prompt}, ] + primary_model = model_override or self.settings.llm_model + fallback_model = self.settings.llm_fallback_model + + logger.info( + "LLM request: model=%s, modality=%s, response_model=%s", + primary_model, + modality, + response_model.__name__ if response_model else None, + ) + # --- Try primary endpoint --- try: response = self._primary.chat.completions.create( - model=self.settings.llm_model, + model=primary_model, messages=messages, **kwargs, ) - return response.choices[0].message.content or "" + raw = response.choices[0].message.content or "" + if modality == "thinking": + raw = strip_think_tags(raw) + return raw except (openai.APIConnectionError, openai.APITimeoutError) as exc: logger.warning( @@ -87,11 +154,14 @@ class LLMClient: # --- Try fallback endpoint --- try: response = self._fallback.chat.completions.create( - model=self.settings.llm_fallback_model, + model=fallback_model, messages=messages, **kwargs, ) - return response.choices[0].message.content or "" + raw = response.choices[0].message.content or "" + if modality == "thinking": + raw = strip_think_tags(raw) + return raw except (openai.APIConnectionError, openai.APITimeoutError, openai.APIError) as exc: logger.error( diff --git a/backend/pipeline/stages.py b/backend/pipeline/stages.py index a2893b0..fd979a6 100644 --- a/backend/pipeline/stages.py +++ b/backend/pipeline/stages.py @@ -87,6 +87,19 @@ def _get_llm_client() -> LLMClient: return LLMClient(get_settings()) +def _get_stage_config(stage_num: int) -> tuple[str | None, str]: + """Return (model_override, modality) for a pipeline stage. + + Reads stage-specific config from Settings. If the stage-specific model + is None/empty, returns None (LLMClient will use its default). If the + stage-specific modality is unset, defaults to "chat". + """ + settings = get_settings() + model = getattr(settings, f"llm_stage{stage_num}_model", None) or None + modality = getattr(settings, f"llm_stage{stage_num}_modality", None) or "chat" + return model, modality + + def _load_canonical_tags() -> dict: """Load canonical tag taxonomy from config/canonical_tags.yaml.""" # Walk up from backend/ to find config/ @@ -114,7 +127,15 @@ def _format_taxonomy_for_prompt(tags_data: dict) -> str: return "\n".join(lines) -def _safe_parse_llm_response(raw: str, model_cls, llm: LLMClient, system_prompt: str, user_prompt: str): +def _safe_parse_llm_response( + raw: str, + model_cls, + llm: LLMClient, + system_prompt: str, + user_prompt: str, + modality: str = "chat", + model_override: str | None = None, +): """Parse LLM response with one retry on failure. On malformed response: log the raw text, retry once with a JSON nudge, @@ -132,7 +153,10 @@ def _safe_parse_llm_response(raw: str, model_cls, llm: LLMClient, system_prompt: ) # Retry with explicit JSON instruction nudge_prompt = user_prompt + "\n\nIMPORTANT: Output ONLY valid JSON. No markdown, no explanation." - retry_raw = llm.complete(system_prompt, nudge_prompt, response_model=model_cls) + retry_raw = llm.complete( + system_prompt, nudge_prompt, response_model=model_cls, + modality=modality, model_override=model_override, + ) return llm.parse_response(retry_raw, model_cls) @@ -180,8 +204,12 @@ def stage2_segmentation(self, video_id: str) -> str: user_prompt = f"\n{transcript_text}\n" llm = _get_llm_client() - raw = llm.complete(system_prompt, user_prompt, response_model=SegmentationResult) - result = _safe_parse_llm_response(raw, SegmentationResult, llm, system_prompt, user_prompt) + model_override, modality = _get_stage_config(2) + logger.info("Stage 2 using model=%s, modality=%s", model_override or "default", modality) + raw = llm.complete(system_prompt, user_prompt, response_model=SegmentationResult, + modality=modality, model_override=model_override) + result = _safe_parse_llm_response(raw, SegmentationResult, llm, system_prompt, user_prompt, + modality=modality, model_override=model_override) # Update topic_label on each segment row seg_by_index = {s.segment_index: s for s in segments} @@ -247,6 +275,8 @@ def stage3_extraction(self, video_id: str) -> str: system_prompt = _load_prompt("stage3_extraction.txt") llm = _get_llm_client() + model_override, modality = _get_stage_config(3) + logger.info("Stage 3 using model=%s, modality=%s", model_override or "default", modality) total_moments = 0 for topic_label, group_segs in groups.items(): @@ -263,8 +293,10 @@ def stage3_extraction(self, video_id: str) -> str: f"\n{segment_text}\n" ) - raw = llm.complete(system_prompt, user_prompt, response_model=ExtractionResult) - result = _safe_parse_llm_response(raw, ExtractionResult, llm, system_prompt, user_prompt) + raw = llm.complete(system_prompt, user_prompt, response_model=ExtractionResult, + modality=modality, model_override=model_override) + result = _safe_parse_llm_response(raw, ExtractionResult, llm, system_prompt, user_prompt, + modality=modality, model_override=model_override) # Create KeyMoment rows for moment in result.moments: @@ -369,8 +401,12 @@ def stage4_classification(self, video_id: str) -> str: ) llm = _get_llm_client() - raw = llm.complete(system_prompt, user_prompt, response_model=ClassificationResult) - result = _safe_parse_llm_response(raw, ClassificationResult, llm, system_prompt, user_prompt) + model_override, modality = _get_stage_config(4) + logger.info("Stage 4 using model=%s, modality=%s", model_override or "default", modality) + raw = llm.complete(system_prompt, user_prompt, response_model=ClassificationResult, + modality=modality, model_override=model_override) + result = _safe_parse_llm_response(raw, ClassificationResult, llm, system_prompt, user_prompt, + modality=modality, model_override=model_override) # Apply content_type overrides and prepare classification data for stage 5 classification_data = [] @@ -490,6 +526,8 @@ def stage5_synthesis(self, video_id: str) -> str: system_prompt = _load_prompt("stage5_synthesis.txt") llm = _get_llm_client() + model_override, modality = _get_stage_config(5) + logger.info("Stage 5 using model=%s, modality=%s", model_override or "default", modality) pages_created = 0 for category, moment_group in groups.items(): @@ -513,8 +551,10 @@ def stage5_synthesis(self, video_id: str) -> str: user_prompt = f"\n{moments_text}\n" - raw = llm.complete(system_prompt, user_prompt, response_model=SynthesisResult) - result = _safe_parse_llm_response(raw, SynthesisResult, llm, system_prompt, user_prompt) + raw = llm.complete(system_prompt, user_prompt, response_model=SynthesisResult, + modality=modality, model_override=model_override) + result = _safe_parse_llm_response(raw, SynthesisResult, llm, system_prompt, user_prompt, + modality=modality, model_override=model_override) # Create/update TechniquePage rows for page_data in result.pages: diff --git a/backend/tests/test_pipeline.py b/backend/tests/test_pipeline.py index b0d01a3..9e5d00a 100644 --- a/backend/tests/test_pipeline.py +++ b/backend/tests/test_pipeline.py @@ -737,3 +737,37 @@ def test_llm_fallback_on_primary_failure(): assert result == '{"result": "ok"}' primary_client.chat.completions.create.assert_called_once() fallback_client.chat.completions.create.assert_called_once() + + +# ── Think-tag stripping ───────────────────────────────────────────────────── + + +def test_strip_think_tags(): + """strip_think_tags should handle all edge cases correctly.""" + from pipeline.llm_client import strip_think_tags + + # Single block with JSON after + assert strip_think_tags('reasoning here{"a": 1}') == '{"a": 1}' + + # Multiline think block + assert strip_think_tags( + '\nI need to analyze this.\nLet me think step by step.\n\n{"result": "ok"}' + ) == '{"result": "ok"}' + + # Multiple think blocks + result = strip_think_tags('firsthellosecond world') + assert result == "hello world" + + # No think tags — passthrough + assert strip_think_tags('{"clean": true}') == '{"clean": true}' + + # Empty string + assert strip_think_tags("") == "" + + # Think block with special characters + assert strip_think_tags( + 'analyzing "complex" & stuff{"done": true}' + ) == '{"done": true}' + + # Only a think block, no actual content + assert strip_think_tags("just thinking") == ""