"""OpenAI-compatible LLM adapter. Works with any endpoint that speaks the OpenAI chat completions API: OpenAI, OpenWebUI, vLLM, Ollama, Anthropic via proxy, etc. """ import time from typing import Any, AsyncIterator import httpx from engine.adapters.base import AdapterResponse, BaseAdapter class OpenAICompatAdapter(BaseAdapter): """Adapter for OpenAI-compatible chat completion APIs. Args: base_url: API base URL (e.g. "https://api.openai.com/v1"). api_key: Bearer token for authentication. Optional for local endpoints. timeout: Request timeout in seconds. max_retries: Number of retry attempts on transient failures. """ RETRYABLE_STATUS_CODES = {429, 500, 502, 503, 504} def __init__( self, base_url: str, api_key: str | None = None, timeout: float = 120.0, max_retries: int = 3, ) -> None: self.base_url = base_url.rstrip("/") self.api_key = api_key self.timeout = timeout self.max_retries = max_retries def _headers(self) -> dict[str, str]: headers: dict[str, str] = {"Content-Type": "application/json"} if self.api_key: headers["Authorization"] = f"Bearer {self.api_key}" return headers def _client(self) -> httpx.AsyncClient: return httpx.AsyncClient( timeout=httpx.Timeout(self.timeout), headers=self._headers(), ) async def complete( self, prompt: str, model: str, params: dict[str, Any] ) -> AdapterResponse: body = self._build_request_body(prompt, model, params) stream = params.get("stream", False) body["stream"] = stream if stream: return await self._complete_stream(body, model) return await self._complete_non_stream(body, model) async def _complete_non_stream( self, body: dict[str, Any], model: str ) -> AdapterResponse: url = f"{self.base_url}/chat/completions" last_exc: Exception | None = None for attempt in range(self.max_retries): t0 = time.perf_counter() try: async with self._client() as client: resp = await client.post(url, json=body) latency_ms = (time.perf_counter() - t0) * 1000 if resp.status_code == 200: data = resp.json() return self._parse_response(data, latency_ms, model) if resp.status_code not in self.RETRYABLE_STATUS_CODES: resp.raise_for_status() last_exc = httpx.HTTPStatusError( f"HTTP {resp.status_code}", request=resp.request, response=resp, ) except httpx.HTTPStatusError: raise except httpx.HTTPError as exc: last_exc = exc if attempt < self.max_retries - 1: delay = 2**attempt # 1s, 2s, 4s … import asyncio await asyncio.sleep(delay) raise RuntimeError( f"All {self.max_retries} attempts failed for {url}" ) from last_exc async def _complete_stream( self, body: dict[str, Any], model: str ) -> AdapterResponse: url = f"{self.base_url}/chat/completions" last_exc: Exception | None = None for attempt in range(self.max_retries): t0 = time.perf_counter() try: async with self._client() as client: async with client.stream("POST", url, json=body) as resp: if resp.status_code != 200: await resp.aread() if resp.status_code not in self.RETRYABLE_STATUS_CODES: resp.raise_for_status() last_exc = httpx.HTTPStatusError( f"HTTP {resp.status_code}", request=resp.request, response=resp, ) else: text, usage = await self._consume_stream(resp) latency_ms = (time.perf_counter() - t0) * 1000 return AdapterResponse( text=text, tokens_in=usage.get("prompt_tokens", 0), tokens_out=usage.get("completion_tokens", 0), latency_ms=latency_ms, model=model, raw={"streamed": True, "usage": usage}, ) except httpx.HTTPStatusError: raise except httpx.HTTPError as exc: last_exc = exc if attempt < self.max_retries - 1: import asyncio await asyncio.sleep(2**attempt) raise RuntimeError( f"All {self.max_retries} attempts failed for {url}" ) from last_exc async def _consume_stream( self, resp: httpx.Response ) -> tuple[str, dict[str, int]]: import json as _json chunks: list[str] = [] usage: dict[str, int] = {} async for line in resp.aiter_lines(): if not line.startswith("data: "): continue payload = line[6:] if payload.strip() == "[DONE]": break try: data = _json.loads(payload) except _json.JSONDecodeError: continue for choice in data.get("choices", []): delta = choice.get("delta", {}) if "content" in delta and delta["content"]: chunks.append(delta["content"]) if "usage" in data and data["usage"]: usage = data["usage"] return "".join(chunks), usage async def list_models(self) -> list[str]: url = f"{self.base_url}/models" async with self._client() as client: resp = await client.get(url) resp.raise_for_status() data = resp.json() models: list[str] = [] for entry in data.get("data", []): model_id = entry.get("id") if model_id: models.append(model_id) return sorted(models) async def test_connection(self) -> bool: try: await self.list_models() return True except Exception: return False @staticmethod def _build_request_body( prompt: str, model: str, params: dict[str, Any] ) -> dict[str, Any]: messages: list[dict[str, str]] = [] system_message = params.get("system_message") or params.get("system") if system_message: messages.append({"role": "system", "content": system_message}) messages.append({"role": "user", "content": prompt}) body: dict[str, Any] = {"model": model, "messages": messages} passthrough_keys = { "temperature", "max_tokens", "top_p", "frequency_penalty", "presence_penalty", "stop", "seed", } for key in passthrough_keys: if key in params: body[key] = params[key] return body @staticmethod def _parse_response( data: dict[str, Any], latency_ms: float, model: str ) -> AdapterResponse: choices = data.get("choices", []) text = "" if choices: message = choices[0].get("message", {}) text = message.get("content", "") usage = data.get("usage", {}) tokens_in = usage.get("prompt_tokens", 0) tokens_out = usage.get("completion_tokens", 0) return AdapterResponse( text=text, tokens_in=tokens_in, tokens_out=tokens_out, latency_ms=latency_ms, model=data.get("model", model), raw=data, )