# backend/pipeline/llm_client.py — GSD-Task: S03/T01
# Related files: backend/config.py, backend/worker.py, backend/pipeline/schemas.py,
# backend/requirements.txt, backend/pipeline/__init__.py, backend/pipeline/stages.py
"""Synchronous LLM client with primary/fallback endpoint logic.

Uses the OpenAI-compatible API (works with Ollama, vLLM, OpenWebUI, etc.).

Celery tasks run synchronously, so this uses ``openai.OpenAI`` (not Async).
"""
from __future__ import annotations

import logging
from typing import TypeVar

import openai
from pydantic import BaseModel

from config import Settings

logger = logging.getLogger(__name__)

T = TypeVar("T", bound=BaseModel)

class LLMClient:
    """Sync LLM client that tries a primary endpoint and falls back on failure."""

    def __init__(self, settings: Settings) -> None:
        """Build OpenAI-compatible clients for the primary and fallback endpoints.

        Parameters
        ----------
        settings:
            Provides the endpoint base URLs, the shared API key, and the
            model names used by :meth:`complete`.
        """
        self.settings = settings
        self._primary = openai.OpenAI(
            base_url=settings.llm_api_url,
            api_key=settings.llm_api_key,
        )
        # The fallback endpoint reuses the same API key as the primary.
        self._fallback = openai.OpenAI(
            base_url=settings.llm_fallback_url,
            api_key=settings.llm_api_key,
        )

    # ── Core completion ──────────────────────────────────────────────────

    def _request(
        self,
        client: openai.OpenAI,
        model: str,
        messages: list[dict[str, str]],
        kwargs: dict,
    ) -> str:
        """Issue one chat-completion call and return the raw message text.

        Shared by the primary and fallback attempts in :meth:`complete`.
        Returns ``""`` when the endpoint reports a ``None`` content field.
        """
        response = client.chat.completions.create(
            model=model,
            messages=messages,
            **kwargs,
        )
        return response.choices[0].message.content or ""

    def complete(
        self,
        system_prompt: str,
        user_prompt: str,
        response_model: type[BaseModel] | None = None,
    ) -> str:
        """Send a chat completion request, falling back on connection/timeout errors.

        Parameters
        ----------
        system_prompt:
            System message content.
        user_prompt:
            User message content.
        response_model:
            If provided, ``response_format`` is set to ``{"type": "json_object"}``
            so the LLM returns parseable JSON.  Only its presence matters here;
            validation against the model happens later in :meth:`parse_response`.

        Returns
        -------
        str
            Raw completion text from the model.

        Raises
        ------
        openai.APIError
            If the fallback endpoint also fails (connection, timeout, or any
            other API error on the fallback attempt is re-raised).
        """
        kwargs: dict = {}
        if response_model is not None:
            kwargs["response_format"] = {"type": "json_object"}

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]

        # Try the primary endpoint first.  Only connection/timeout failures
        # trigger the fallback — other API errors (e.g. a malformed request)
        # would fail on the fallback too, so they propagate immediately.
        try:
            return self._request(
                self._primary, self.settings.llm_model, messages, kwargs
            )
        except (openai.APIConnectionError, openai.APITimeoutError) as exc:
            logger.warning(
                "Primary LLM endpoint failed (%s: %s), trying fallback at %s",
                type(exc).__name__,
                exc,
                self.settings.llm_fallback_url,
            )

        # Last resort: the fallback endpoint.  Any failure here is re-raised
        # for the Celery task to handle (retry/dead-letter).
        try:
            return self._request(
                self._fallback, self.settings.llm_fallback_model, messages, kwargs
            )
        except (openai.APIConnectionError, openai.APITimeoutError, openai.APIError) as exc:
            logger.error(
                "Fallback LLM endpoint also failed (%s: %s). Giving up.",
                type(exc).__name__,
                exc,
            )
            raise

    # ── Response parsing ─────────────────────────────────────────────────

    def parse_response(self, text: str, model: type[T]) -> T:
        """Parse raw LLM output as JSON and validate against a Pydantic model.

        Parameters
        ----------
        text:
            Raw JSON string from the LLM.
        model:
            Pydantic model class to validate against.

        Returns
        -------
        T
            Validated Pydantic model instance.

        Raises
        ------
        pydantic.ValidationError
            If the text is malformed JSON or does not match the schema.
            (Pydantic v2's ``model_validate_json`` raises ``ValidationError``
            for invalid JSON as well as for schema mismatches — it does not
            raise ``ValueError`` separately.)
        """
        try:
            return model.model_validate_json(text)
        except Exception:
            # Log a truncated copy of the raw response so bad completions
            # can be diagnosed from the logs, then re-raise for the caller.
            logger.error(
                "Failed to parse LLM response as %s. Response text: %.500s",
                model.__name__,
                text,
            )
            raise
|