chrysopedia/backend/pipeline/llm_client.py
jlightner 12cc86aef9 chore: Extended Settings with 12 LLM/embedding/Qdrant config fields, cr…
- "backend/config.py"
- "backend/worker.py"
- "backend/pipeline/schemas.py"
- "backend/pipeline/llm_client.py"
- "backend/requirements.txt"
- "backend/pipeline/__init__.py"
- "backend/pipeline/stages.py"

GSD-Task: S03/T01
2026-03-29 22:30:31 +00:00

136 lines
4.2 KiB
Python

"""Synchronous LLM client with primary/fallback endpoint logic.
Uses the OpenAI-compatible API (works with Ollama, vLLM, OpenWebUI, etc.).
Celery tasks run synchronously, so this uses ``openai.OpenAI`` (not Async).
"""
from __future__ import annotations
import logging
from typing import TypeVar
import openai
from pydantic import BaseModel
from config import Settings
# Module-level logger, named after this module per logging convention.
logger = logging.getLogger(__name__)
# Type variable for parse_response: any subclass of pydantic.BaseModel.
T = TypeVar("T", bound=BaseModel)
class LLMClient:
    """Sync LLM client that tries a primary endpoint and falls back on failure.

    Wraps two ``openai.OpenAI`` clients pointed at OpenAI-compatible
    endpoints (works with Ollama, vLLM, OpenWebUI, etc.). Celery tasks run
    synchronously, hence the sync client rather than ``AsyncOpenAI``.
    """

    def __init__(self, settings: Settings) -> None:
        """Build primary and fallback clients from application settings.

        Parameters
        ----------
        settings:
            Must provide ``llm_api_url``, ``llm_fallback_url``,
            ``llm_api_key``, ``llm_model`` and ``llm_fallback_model``.
        """
        self.settings = settings
        # Both endpoints share the same API key; only the base URL differs.
        self._primary = openai.OpenAI(
            base_url=settings.llm_api_url,
            api_key=settings.llm_api_key,
        )
        self._fallback = openai.OpenAI(
            base_url=settings.llm_fallback_url,
            api_key=settings.llm_api_key,
        )

    # ── Core completion ──────────────────────────────────────────────────
    @staticmethod
    def _request(
        client: openai.OpenAI,
        model: str,
        messages: list[dict[str, str]],
        kwargs: dict[str, object],
    ) -> str:
        """Issue one chat-completion call and return the message text ('' if None)."""
        response = client.chat.completions.create(
            model=model,
            messages=messages,
            **kwargs,
        )
        return response.choices[0].message.content or ""

    def complete(
        self,
        system_prompt: str,
        user_prompt: str,
        response_model: type[BaseModel] | None = None,
    ) -> str:
        """Send a chat completion request, falling back on connection/timeout errors.

        Parameters
        ----------
        system_prompt:
            System message content.
        user_prompt:
            User message content.
        response_model:
            If provided, ``response_format`` is set to ``{"type": "json_object"}``
            so the LLM returns parseable JSON.

        Returns
        -------
        str
            Raw completion text from the model.

        Raises
        ------
        openai.APIError
            If the primary fails with a non-transport error, or if the
            fallback endpoint fails as well.
        """
        kwargs: dict[str, object] = {}
        if response_model is not None:
            kwargs["response_format"] = {"type": "json_object"}
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        # --- Try primary endpoint ---
        # Only transport-level failures (connection refused, timeout) trigger
        # the fallback; other API errors (e.g. bad request) propagate directly.
        try:
            return self._request(
                self._primary, self.settings.llm_model, messages, kwargs
            )
        except (openai.APIConnectionError, openai.APITimeoutError) as exc:
            logger.warning(
                "Primary LLM endpoint failed (%s: %s), trying fallback at %s",
                type(exc).__name__,
                exc,
                self.settings.llm_fallback_url,
            )
        # --- Try fallback endpoint ---
        # openai.APIError is the root of the client's exception hierarchy
        # (APIConnectionError and APITimeoutError are subclasses), so one
        # except clause covers every failure mode the original tuple did.
        try:
            return self._request(
                self._fallback, self.settings.llm_fallback_model, messages, kwargs
            )
        except openai.APIError as exc:
            logger.error(
                "Fallback LLM endpoint also failed (%s: %s). Giving up.",
                type(exc).__name__,
                exc,
            )
            raise

    # ── Response parsing ─────────────────────────────────────────────────
    def parse_response(self, text: str, model: type[T]) -> T:
        """Parse raw LLM output as JSON and validate against a Pydantic model.

        Parameters
        ----------
        text:
            Raw JSON string from the LLM.
        model:
            Pydantic model class to validate against.

        Returns
        -------
        T
            Validated Pydantic model instance.

        Raises
        ------
        pydantic.ValidationError
            If the JSON doesn't match the schema.
        ValueError
            If the text is not valid JSON.
        """
        try:
            return model.model_validate_json(text)
        except Exception:
            # logger.exception includes the traceback, unlike logger.error,
            # which makes schema mismatches much easier to debug from logs.
            logger.exception(
                "Failed to parse LLM response as %s. Response text: %.500s",
                model.__name__,
                text,
            )
            raise