promptlooper/backend/engine/adapters/openai_compat.py
John Lightner bf1e9d1c84 MAESTRO: Implement OpenAI-compatible LLM adapter with streaming, retries, and tests
Add OpenAICompatAdapter that works with any OpenAI-compatible API endpoint
(OpenWebUI, vLLM, Ollama, OpenAI, Anthropic via proxy). Features:
- Async HTTP calls via httpx with configurable timeout
- Chat completions format with system + user messages
- Token usage parsing from responses
- Exponential backoff retries (configurable, default 3 attempts)
- Both streaming (SSE) and non-streaming modes
- Model listing and connection testing
- 21 tests covering construction, request building, response parsing,
  retry logic, and error handling
2026-04-07 02:35:52 -05:00

248 lines
8 KiB
Python

"""OpenAI-compatible LLM adapter.
Works with any endpoint that speaks the OpenAI chat completions API:
OpenAI, OpenWebUI, vLLM, Ollama, Anthropic via proxy, etc.
"""
import asyncio
import json
import time
from typing import Any, AsyncIterator

import httpx

from engine.adapters.base import AdapterResponse, BaseAdapter
class OpenAICompatAdapter(BaseAdapter):
    """Adapter for OpenAI-compatible chat completion APIs.

    Args:
        base_url: API base URL (e.g. "https://api.openai.com/v1"). Trailing
            slashes are stripped before endpoint paths are appended.
        api_key: Bearer token for authentication. Optional for local endpoints.
        timeout: Request timeout in seconds.
        max_retries: Number of attempts made per completion request when
            transient failures occur. Values below 1 still result in a
            single attempt.
    """

    RETRYABLE_STATUS_CODES = {429, 500, 502, 503, 504}

    # Keys copied unchanged from ``params`` into the request body. A tuple
    # (rather than a set) keeps the order of keys in the body deterministic.
    _PASSTHROUGH_KEYS = (
        "temperature",
        "max_tokens",
        "top_p",
        "frequency_penalty",
        "presence_penalty",
        "stop",
        "seed",
    )

    def __init__(
        self,
        base_url: str,
        api_key: str | None = None,
        timeout: float = 120.0,
        max_retries: int = 3,
    ) -> None:
        self.base_url = base_url.rstrip("/")
        self.api_key = api_key
        self.timeout = timeout
        self.max_retries = max_retries

    def _headers(self) -> dict[str, str]:
        """Return request headers, adding a bearer token when a key is set."""
        headers: dict[str, str] = {"Content-Type": "application/json"}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        return headers

    def _client(self) -> httpx.AsyncClient:
        """Create a new async HTTP client with this adapter's timeout/headers."""
        return httpx.AsyncClient(
            timeout=httpx.Timeout(self.timeout),
            headers=self._headers(),
        )

    def _attempt_count(self) -> int:
        """Return how many attempts to make; never fewer than one."""
        return max(1, self.max_retries)

    @staticmethod
    async def _backoff(attempt: int, attempts: int) -> None:
        """Sleep ``2**attempt`` seconds unless ``attempt`` was the last one."""
        if attempt < attempts - 1:
            await asyncio.sleep(2**attempt)  # 1s, 2s, 4s ...

    async def complete(
        self, prompt: str, model: str, params: dict[str, Any]
    ) -> AdapterResponse:
        """Run one chat completion and return the parsed response.

        Args:
            prompt: Text sent as the user message.
            model: Model identifier placed in the request body.
            params: Request options. ``stream`` selects the streaming path;
                ``system_message``/``system`` and the sampling keys are
                handled by ``_build_request_body``.

        Raises:
            httpx.HTTPStatusError: On a non-retryable HTTP error status.
            RuntimeError: When every attempt failed with a retryable error.
        """
        body = self._build_request_body(prompt, model, params)
        # Coerce so a ``None`` value is never serialized as JSON ``null``.
        stream = bool(params.get("stream", False))
        body["stream"] = stream
        if stream:
            return await self._complete_stream(body, model)
        return await self._complete_non_stream(body, model)

    async def _complete_non_stream(
        self, body: dict[str, Any], model: str
    ) -> AdapterResponse:
        """POST ``body`` to ``/chat/completions`` and parse the JSON reply.

        Statuses in ``RETRYABLE_STATUS_CODES`` and ``httpx.HTTPError``
        transport failures are retried with exponential backoff.
        """
        url = f"{self.base_url}/chat/completions"
        attempts = self._attempt_count()
        last_exc: Exception | None = None
        for attempt in range(attempts):
            t0 = time.perf_counter()
            try:
                async with self._client() as client:
                    resp = await client.post(url, json=body)
                    latency_ms = (time.perf_counter() - t0) * 1000
                    if resp.status_code == 200:
                        return self._parse_response(
                            resp.json(), latency_ms, model
                        )
                    if resp.status_code not in self.RETRYABLE_STATUS_CODES:
                        resp.raise_for_status()
                    last_exc = httpx.HTTPStatusError(
                        f"HTTP {resp.status_code}",
                        request=resp.request,
                        response=resp,
                    )
            except httpx.HTTPStatusError:
                # Raised by raise_for_status() above: not retryable.
                raise
            except httpx.HTTPError as exc:
                last_exc = exc
            await self._backoff(attempt, attempts)
        raise RuntimeError(
            f"All {attempts} attempts failed for {url}"
        ) from last_exc

    async def _complete_stream(
        self, body: dict[str, Any], model: str
    ) -> AdapterResponse:
        """POST ``body`` as a streaming request and collect the full text.

        Uses the same retry policy as ``_complete_non_stream``. The returned
        ``raw`` field holds ``{"streamed": True, "usage": ...}`` rather than
        the individual stream events.
        """
        url = f"{self.base_url}/chat/completions"
        attempts = self._attempt_count()
        last_exc: Exception | None = None
        for attempt in range(attempts):
            t0 = time.perf_counter()
            try:
                async with self._client() as client:
                    async with client.stream("POST", url, json=body) as resp:
                        if resp.status_code != 200:
                            # Read the body so it is available on the
                            # response attached to the raised/stored error.
                            await resp.aread()
                            if (
                                resp.status_code
                                not in self.RETRYABLE_STATUS_CODES
                            ):
                                resp.raise_for_status()
                            last_exc = httpx.HTTPStatusError(
                                f"HTTP {resp.status_code}",
                                request=resp.request,
                                response=resp,
                            )
                        else:
                            text, usage = await self._consume_stream(resp)
                            latency_ms = (time.perf_counter() - t0) * 1000
                            return AdapterResponse(
                                text=text,
                                tokens_in=usage.get("prompt_tokens") or 0,
                                tokens_out=usage.get("completion_tokens") or 0,
                                latency_ms=latency_ms,
                                model=model,
                                raw={"streamed": True, "usage": usage},
                            )
            except httpx.HTTPStatusError:
                # Raised by raise_for_status() above: not retryable.
                raise
            except httpx.HTTPError as exc:
                last_exc = exc
            await self._backoff(attempt, attempts)
        raise RuntimeError(
            f"All {attempts} attempts failed for {url}"
        ) from last_exc

    async def _consume_stream(
        self, resp: httpx.Response
    ) -> tuple[str, dict[str, int]]:
        """Read SSE lines from ``resp`` and return ``(text, usage)``.

        Only ``data:`` lines are considered. Reading stops at the ``[DONE]``
        sentinel. Lines whose payload is not a JSON object are skipped.
        ``usage`` is the last non-empty ``usage`` object seen, or ``{}``.
        """
        chunks: list[str] = []
        usage: dict[str, int] = {}
        async for line in resp.aiter_lines():
            # The space after "data:" is optional in the SSE format, so
            # match on the bare field name and strip the payload.
            if not line.startswith("data:"):
                continue
            payload = line[5:].strip()
            if payload == "[DONE]":
                break
            try:
                data = json.loads(payload)
            except json.JSONDecodeError:
                continue
            if not isinstance(data, dict):
                continue
            # ``or`` guards against explicit JSON nulls, which ``.get`` with
            # a default would pass through as ``None``.
            for choice in data.get("choices") or []:
                delta = choice.get("delta") or {}
                content = delta.get("content")
                if content:
                    chunks.append(content)
            if data.get("usage"):
                usage = data["usage"]
        return "".join(chunks), usage

    async def list_models(self) -> list[str]:
        """Return the sorted model ids reported by ``GET {base_url}/models``.

        Raises:
            httpx.HTTPStatusError: If the endpoint returns an error status.
        """
        url = f"{self.base_url}/models"
        async with self._client() as client:
            resp = await client.get(url)
            resp.raise_for_status()
            data = resp.json()
        entries = data.get("data") or []
        return sorted(entry["id"] for entry in entries if entry.get("id"))

    async def test_connection(self) -> bool:
        """Return True if ``list_models`` succeeds, False on any exception."""
        try:
            await self.list_models()
            return True
        except Exception:
            # Deliberate best-effort probe: any failure means "not reachable".
            return False

    @staticmethod
    def _build_request_body(
        prompt: str, model: str, params: dict[str, Any]
    ) -> dict[str, Any]:
        """Build the chat-completions request body.

        A system message is added first when ``params`` has a truthy
        ``system_message`` (preferred) or ``system`` entry, followed by
        ``prompt`` as the user message. Supported sampling keys present in
        ``params`` are copied through unchanged.
        """
        messages: list[dict[str, str]] = []
        system_message = params.get("system_message") or params.get("system")
        if system_message:
            messages.append({"role": "system", "content": system_message})
        messages.append({"role": "user", "content": prompt})

        body: dict[str, Any] = {"model": model, "messages": messages}
        body.update(
            (key, params[key])
            for key in OpenAICompatAdapter._PASSTHROUGH_KEYS
            if key in params
        )
        return body

    @staticmethod
    def _parse_response(
        data: dict[str, Any], latency_ms: float, model: str
    ) -> AdapterResponse:
        """Convert a non-streaming completion payload to an AdapterResponse.

        Text comes from the first choice's message content. Missing or null
        ``choices``, ``message``, ``content`` or ``usage`` fields fall back
        to an empty string / zero token counts, and a missing or null
        ``model`` falls back to the requested ``model``.
        """
        choices = data.get("choices") or []
        text = ""
        if choices:
            message = choices[0].get("message") or {}
            # ``content`` may be an explicit null; normalise to "".
            text = message.get("content") or ""
        usage = data.get("usage") or {}
        return AdapterResponse(
            text=text,
            tokens_in=usage.get("prompt_tokens") or 0,
            tokens_out=usage.get("completion_tokens") or 0,
            latency_ms=latency_ms,
            model=data.get("model") or model,
            raw=data,
        )