Add OpenAICompatAdapter, which works with any OpenAI-compatible API endpoint (OpenWebUI, vLLM, Ollama, OpenAI, or Anthropic via a proxy). Features:
- Async HTTP calls via httpx with a configurable timeout
- Chat completions format with system + user messages
- Token usage parsing from responses
- Exponential backoff retries (configurable; default 3 attempts in total)
- Both streaming (SSE) and non-streaming modes
- Model listing and connection testing
- 21 tests covering construction, request building, response parsing, retry logic, and error handling
248 lines
8 KiB
Python
248 lines
8 KiB
Python
"""OpenAI-compatible LLM adapter.
|
|
|
|
Works with any endpoint that speaks the OpenAI chat completions API:
|
|
OpenAI, OpenWebUI, vLLM, Ollama, Anthropic via proxy, etc.
|
|
"""
|
|
|
|
import asyncio
import json
import time
from typing import Any, AsyncIterator

import httpx

from engine.adapters.base import AdapterResponse, BaseAdapter
|
|
|
|
|
|
class OpenAICompatAdapter(BaseAdapter):
    """Adapter for OpenAI-compatible chat completion APIs.

    Sends requests to ``{base_url}/chat/completions`` (and
    ``{base_url}/models`` for discovery) in the OpenAI wire format, so it
    works against any server that implements that API.

    Args:
        base_url: API base URL (e.g. "https://api.openai.com/v1"). Trailing
            slashes are stripped.
        api_key: Bearer token for authentication. Optional for local
            endpoints; when empty or None no ``Authorization`` header is sent.
        timeout: Request timeout in seconds, passed to ``httpx.Timeout``.
        max_retries: Total number of attempts per completion request (the
            first try included). Transport errors and statuses in
            ``RETRYABLE_STATUS_CODES`` are retried with exponential backoff
            (1s, 2s, 4s, ...). Values below 1 are treated as 1.
    """

    # Rate limiting (429) and transient server/gateway failures (5xx).
    RETRYABLE_STATUS_CODES = {429, 500, 502, 503, 504}

    def __init__(
        self,
        base_url: str,
        api_key: str | None = None,
        timeout: float = 120.0,
        max_retries: int = 3,
    ) -> None:
        self.base_url = base_url.rstrip("/")
        self.api_key = api_key
        self.timeout = timeout
        self.max_retries = max_retries

    def _headers(self) -> dict[str, str]:
        """Return JSON request headers, plus bearer auth when an API key is set."""
        headers: dict[str, str] = {"Content-Type": "application/json"}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        return headers

    def _client(self) -> httpx.AsyncClient:
        """Create a new ``httpx.AsyncClient``; callers close it via ``async with``."""
        return httpx.AsyncClient(
            timeout=httpx.Timeout(self.timeout),
            headers=self._headers(),
        )

    async def complete(
        self, prompt: str, model: str, params: dict[str, Any]
    ) -> AdapterResponse:
        """Run a single-turn chat completion for ``prompt`` on ``model``.

        Args:
            prompt: Sent as the sole user message.
            model: Model identifier placed in the request body.
            params: Optional settings. ``system_message`` (or ``system``) adds
                a system message, ``stream`` selects SSE streaming, and the
                keys listed in ``_build_request_body`` are copied into the
                request body.

        Returns:
            The parsed ``AdapterResponse``.

        Raises:
            httpx.HTTPStatusError: On a non-retryable error status.
            RuntimeError: When every attempt failed with a retryable status
                or a transport error.
        """
        body = self._build_request_body(prompt, model, params)
        # The API expects a JSON boolean for "stream".
        stream = bool(params.get("stream", False))
        body["stream"] = stream

        if stream:
            return await self._complete_stream(body, model)
        return await self._complete_non_stream(body, model)

    @staticmethod
    def _status_error(resp: httpx.Response) -> httpx.HTTPStatusError:
        """Build the exception recorded for a retryable HTTP status."""
        return httpx.HTTPStatusError(
            f"HTTP {resp.status_code}",
            request=resp.request,
            response=resp,
        )

    @staticmethod
    async def _backoff(attempt: int) -> None:
        """Sleep ``2**attempt`` seconds (1s, 2s, 4s, ...) before the next attempt."""
        await asyncio.sleep(2**attempt)

    async def _complete_non_stream(
        self, body: dict[str, Any], model: str
    ) -> AdapterResponse:
        """POST a non-streaming chat completion, retrying transient failures.

        Any 2xx response is parsed with ``_parse_response``. Latency covers
        the request round trip and client shutdown.

        Raises:
            httpx.HTTPStatusError: On a non-retryable, non-2xx status.
            RuntimeError: When every attempt failed with a retryable status or
                a transport error. The last failure is chained as the cause.
        """
        url = f"{self.base_url}/chat/completions"
        # Always make at least one attempt, even if max_retries is 0/negative.
        attempts = max(1, self.max_retries)
        last_exc: Exception | None = None

        for attempt in range(attempts):
            t0 = time.perf_counter()
            try:
                async with self._client() as client:
                    resp = await client.post(url, json=body)
                latency_ms = (time.perf_counter() - t0) * 1000

                # raise_for_status() does not raise for 2xx, so every 2xx
                # status must be accepted here rather than only 200.
                if 200 <= resp.status_code < 300:
                    return self._parse_response(resp.json(), latency_ms, model)

                if resp.status_code not in self.RETRYABLE_STATUS_CODES:
                    resp.raise_for_status()
                last_exc = self._status_error(resp)
            except httpx.HTTPStatusError:
                # Non-retryable status: surface immediately.
                raise
            except httpx.HTTPError as exc:
                # Transport-level failure (timeout, connection error, ...).
                last_exc = exc

            if attempt < attempts - 1:
                await self._backoff(attempt)

        raise RuntimeError(
            f"All {attempts} attempts failed for {url}"
        ) from last_exc

    async def _complete_stream(
        self, body: dict[str, Any], model: str
    ) -> AdapterResponse:
        """POST a streaming (SSE) chat completion, retrying transient failures.

        Latency covers the whole stream, from request start to the last
        consumed event. The returned ``model`` is the requested one, because
        the stream's ``model`` field is not read.

        Raises:
            httpx.HTTPStatusError: On a non-retryable, non-2xx status.
            RuntimeError: When every attempt failed with a retryable status or
                a transport error. The last failure is chained as the cause.
        """
        url = f"{self.base_url}/chat/completions"
        # Always make at least one attempt, even if max_retries is 0/negative.
        attempts = max(1, self.max_retries)
        last_exc: Exception | None = None

        for attempt in range(attempts):
            t0 = time.perf_counter()
            try:
                async with self._client() as client:
                    async with client.stream("POST", url, json=body) as resp:
                        if 200 <= resp.status_code < 300:
                            text, usage = await self._consume_stream(resp)
                            latency_ms = (time.perf_counter() - t0) * 1000
                            return AdapterResponse(
                                text=text,
                                tokens_in=usage.get("prompt_tokens") or 0,
                                tokens_out=usage.get("completion_tokens") or 0,
                                latency_ms=latency_ms,
                                model=model,
                                raw={"streamed": True, "usage": usage},
                            )

                        # Streamed bodies are not read automatically; load the
                        # error body so it is available on the exception.
                        await resp.aread()
                        if resp.status_code not in self.RETRYABLE_STATUS_CODES:
                            resp.raise_for_status()
                        last_exc = self._status_error(resp)
            except httpx.HTTPStatusError:
                # Non-retryable status: surface immediately.
                raise
            except httpx.HTTPError as exc:
                # Transport-level failure, including errors mid-stream; the
                # whole request is retried and partial text is discarded.
                last_exc = exc

            if attempt < attempts - 1:
                await self._backoff(attempt)

        raise RuntimeError(
            f"All {attempts} attempts failed for {url}"
        ) from last_exc

    async def _consume_stream(
        self, resp: httpx.Response
    ) -> tuple[str, dict[str, Any]]:
        """Accumulate an SSE chat-completion stream into text and usage.

        Reads ``data:`` lines until ``[DONE]`` or the end of the stream. It
        concatenates every non-empty ``choices[*].delta.content`` and keeps
        the last non-empty ``usage`` object seen. Other SSE fields, payloads
        that are not valid JSON, and non-object payloads are skipped.

        Returns:
            ``(text, usage)``. ``usage`` is empty when the server sent none
            (e.g. OpenAI without ``stream_options={"include_usage": True}``).
        """
        chunks: list[str] = []
        usage: dict[str, Any] = {}

        async for line in resp.aiter_lines():
            # Per the SSE spec the space after "data:" is optional.
            if not line.startswith("data:"):
                continue
            payload = line[len("data:"):].strip()
            if payload == "[DONE]":
                break
            try:
                data = json.loads(payload)
            except json.JSONDecodeError:
                continue
            if not isinstance(data, dict):
                continue

            for choice in data.get("choices") or []:
                if not isinstance(choice, dict):
                    continue
                # "delta" may be missing or null on some chunks.
                content = (choice.get("delta") or {}).get("content")
                if content:
                    chunks.append(content)

            if data.get("usage"):
                usage = data["usage"]

        return "".join(chunks), usage

    async def list_models(self) -> list[str]:
        """Return the sorted model IDs from ``GET {base_url}/models``.

        Entries that are not objects or lack a truthy ``id`` are skipped.
        No retries are attempted.

        Raises:
            httpx.HTTPStatusError: On a non-2xx response.
            httpx.HTTPError: On transport failures.
        """
        url = f"{self.base_url}/models"
        async with self._client() as client:
            resp = await client.get(url)
            resp.raise_for_status()
            data = resp.json()

        return sorted(
            entry["id"]
            for entry in data.get("data") or []
            if isinstance(entry, dict) and entry.get("id")
        )

    async def test_connection(self) -> bool:
        """Return True if ``list_models`` succeeds, False on any exception.

        Deliberately best-effort: every error is reported as ``False``.
        """
        try:
            await self.list_models()
            return True
        except Exception:
            return False

    @staticmethod
    def _build_request_body(
        prompt: str, model: str, params: dict[str, Any]
    ) -> dict[str, Any]:
        """Build the chat-completions JSON body.

        An optional system message (``params["system_message"]`` or, as a
        fallback, ``params["system"]``) is followed by ``prompt`` as the user
        message. Supported sampling keys present in ``params`` are copied
        through unchanged. The ``stream`` flag is set by the caller.
        """
        messages: list[dict[str, str]] = []

        system_message = params.get("system_message") or params.get("system")
        if system_message:
            messages.append({"role": "system", "content": system_message})

        messages.append({"role": "user", "content": prompt})

        body: dict[str, Any] = {"model": model, "messages": messages}

        # A tuple gives the body a stable key order across processes; string
        # hashing (and thus set iteration order) is randomized per process.
        # "stream_options" lets callers request usage in streaming mode.
        passthrough_keys = (
            "temperature",
            "max_tokens",
            "top_p",
            "frequency_penalty",
            "presence_penalty",
            "stop",
            "seed",
            "stream_options",
        )
        body.update({key: params[key] for key in passthrough_keys if key in params})

        return body

    @staticmethod
    def _parse_response(
        data: dict[str, Any], latency_ms: float, model: str
    ) -> AdapterResponse:
        """Convert a non-streaming chat-completions JSON body to AdapterResponse.

        Uses the first choice's message content as the text. Missing or null
        fields fall back to defaults: text ``""``, token counts ``0``, and the
        requested ``model`` when the response does not name one.
        """
        choices = data.get("choices") or []
        message = (choices[0].get("message") or {}) if choices else {}
        # "content" may be null (e.g. a reply containing only tool calls).
        text = message.get("content") or ""

        # Some servers send "usage": null instead of omitting the field.
        usage = data.get("usage") or {}

        return AdapterResponse(
            text=text,
            tokens_in=usage.get("prompt_tokens") or 0,
            tokens_out=usage.get("completion_tokens") or 0,
            latency_ms=latency_ms,
            model=data.get("model") or model,
            raw=data,
        )
|