MAESTRO: Implement OpenAI-compatible LLM adapter with streaming, retries, and tests
Add OpenAICompatAdapter that works with any OpenAI-compatible API endpoint (OpenWebUI, vLLM, Ollama, OpenAI, Anthropic via proxy).

Features:
- Async HTTP calls via httpx with configurable timeout
- Chat completions format with system + user messages
- Token usage parsing from responses
- Exponential backoff retries (configurable, default 3 attempts)
- Both streaming (SSE) and non-streaming modes
- Model listing and connection testing
- 21 tests covering construction, request building, response parsing, retry logic, and error handling
This commit is contained in:
parent
060f399789
commit
bf1e9d1c84
4 changed files with 563 additions and 2 deletions
|
|
@ -4,7 +4,7 @@ Implement the core experiment execution engine: LLM adapters, response caching,
|
||||||
|
|
||||||
- [x] Implement backend/engine/adapters/base.py defining the BaseAdapter abstract class with methods: complete(prompt, model, params) → AdapterResponse (containing response text, tokens_in, tokens_out, latency_ms), list_models() → list of model identifiers, and test_connection() → bool. Define the AdapterResponse dataclass.
|
- [x] Implement backend/engine/adapters/base.py defining the BaseAdapter abstract class with methods: complete(prompt, model, params) → AdapterResponse (containing response text, tokens_in, tokens_out, latency_ms), list_models() → list of model identifiers, and test_connection() → bool. Define the AdapterResponse dataclass.
|
||||||
|
|
||||||
- [ ] Implement backend/engine/adapters/openai_compat.py as the primary adapter. It should work with any OpenAI-compatible API (OpenWebUI, vLLM, Ollama, OpenAI, Anthropic via proxy). Use httpx for async HTTP calls. Support chat completions format with system + user messages. Parse token usage from the response. Handle errors gracefully with retries (3 attempts, exponential backoff). Support both streaming and non-streaming modes.
|
- [x] Implement backend/engine/adapters/openai_compat.py as the primary adapter. It should work with any OpenAI-compatible API (OpenWebUI, vLLM, Ollama, OpenAI, Anthropic via proxy). Use httpx for async HTTP calls. Support chat completions format with system + user messages. Parse token usage from the response. Handle errors gracefully with retries (3 attempts, exponential backoff). Support both streaming and non-streaming modes.
|
||||||
|
|
||||||
- [ ] Implement backend/engine/cache.py with the ResponseCache layer. Key function: compute_config_hash(prompt, model, params, input_data) → SHA-256 hex string. Methods: get(config_hash) → CachedResponse or None, put(config_hash, response, metadata). In SQLite mode, use the ResponseCache table directly. In Postgres mode, same table but with connection pooling. Include a cache_stats() method returning hit rate, total entries, and storage size.
|
- [ ] Implement backend/engine/cache.py with the ResponseCache layer. Key function: compute_config_hash(prompt, model, params, input_data) → SHA-256 hex string. Methods: get(config_hash) → CachedResponse or None, put(config_hash, response, metadata). In SQLite mode, use the ResponseCache table directly. In Postgres mode, same table but with connection pooling. Include a cache_stats() method returning hit rate, total entries, and storage size.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
"""LLM endpoint adapters."""
|
"""LLM endpoint adapters."""
|
||||||
|
|
||||||
from engine.adapters.base import AdapterResponse, BaseAdapter
|
from engine.adapters.base import AdapterResponse, BaseAdapter
|
||||||
|
from engine.adapters.openai_compat import OpenAICompatAdapter
|
||||||
|
|
||||||
__all__ = ["AdapterResponse", "BaseAdapter"]
|
__all__ = ["AdapterResponse", "BaseAdapter", "OpenAICompatAdapter"]
|
||||||
|
|
|
||||||
248
backend/engine/adapters/openai_compat.py
Normal file
248
backend/engine/adapters/openai_compat.py
Normal file
|
|
@ -0,0 +1,248 @@
|
||||||
|
"""OpenAI-compatible LLM adapter.
|
||||||
|
|
||||||
|
Works with any endpoint that speaks the OpenAI chat completions API:
|
||||||
|
OpenAI, OpenWebUI, vLLM, Ollama, Anthropic via proxy, etc.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
from typing import Any, AsyncIterator
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from engine.adapters.base import AdapterResponse, BaseAdapter
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAICompatAdapter(BaseAdapter):
    """Adapter for OpenAI-compatible chat completion APIs.

    Args:
        base_url: API base URL (e.g. "https://api.openai.com/v1"). A trailing
            slash is stripped.
        api_key: Bearer token for authentication. Optional for local endpoints.
        timeout: Request timeout in seconds.
        max_retries: Total number of attempts per completion request (the
            first try plus retries on transient failures). Values below 1
            are treated as 1 so the request is always attempted once.
    """

    # HTTP statuses treated as transient; any other non-200 status is raised
    # immediately via ``raise_for_status``.
    RETRYABLE_STATUS_CODES = {429, 500, 502, 503, 504}

    # Parameters copied verbatim from ``params`` into the request body.
    # A tuple (rather than a set) keeps the body's key order stable between
    # runs, independent of string-hash randomisation.
    _PASSTHROUGH_KEYS = (
        "temperature",
        "max_tokens",
        "top_p",
        "frequency_penalty",
        "presence_penalty",
        "stop",
        "seed",
    )

    def __init__(
        self,
        base_url: str,
        api_key: str | None = None,
        timeout: float = 120.0,
        max_retries: int = 3,
    ) -> None:
        self.base_url = base_url.rstrip("/")
        self.api_key = api_key
        self.timeout = timeout
        self.max_retries = max_retries

    def _headers(self) -> dict[str, str]:
        """Return request headers; ``Authorization`` is added only if a key is set."""
        headers: dict[str, str] = {"Content-Type": "application/json"}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        return headers

    def _client(self) -> httpx.AsyncClient:
        """Create a fresh ``httpx.AsyncClient`` with this adapter's timeout and headers."""
        return httpx.AsyncClient(
            timeout=httpx.Timeout(self.timeout),
            headers=self._headers(),
        )

    def _attempts(self) -> int:
        """Return the number of attempts to make, never fewer than one."""
        return max(1, self.max_retries)

    @staticmethod
    async def _backoff(attempt: int, attempts: int) -> None:
        """Sleep ``2**attempt`` seconds unless ``attempt`` was the final one."""
        if attempt < attempts - 1:
            # Function-scope import: the module's top-level imports do not
            # include asyncio.
            import asyncio

            await asyncio.sleep(2**attempt)  # 1s, 2s, 4s …

    async def complete(
        self, prompt: str, model: str, params: dict[str, Any]
    ) -> AdapterResponse:
        """Send one chat completion request and return the parsed response.

        ``params["stream"]`` (default ``False``) selects streaming mode; the
        remaining recognised keys are described on ``_build_request_body``.

        Raises:
            httpx.HTTPStatusError: On a non-retryable HTTP error status.
            RuntimeError: When every attempt failed with a retryable status
                or a transport error.
        """
        body = self._build_request_body(prompt, model, params)
        stream = params.get("stream", False)
        body["stream"] = stream

        if stream:
            return await self._complete_stream(body, model)
        return await self._complete_non_stream(body, model)

    async def _complete_non_stream(
        self, body: dict[str, Any], model: str
    ) -> AdapterResponse:
        """POST ``body`` and parse the JSON reply, retrying transient failures.

        The reported latency covers only the attempt that succeeded.
        """
        url = f"{self.base_url}/chat/completions"
        attempts = self._attempts()
        last_exc: Exception | None = None

        for attempt in range(attempts):
            t0 = time.perf_counter()
            try:
                async with self._client() as client:
                    resp = await client.post(url, json=body)

                latency_ms = (time.perf_counter() - t0) * 1000

                if resp.status_code == 200:
                    return self._parse_response(resp.json(), latency_ms, model)

                if resp.status_code not in self.RETRYABLE_STATUS_CODES:
                    resp.raise_for_status()

                last_exc = httpx.HTTPStatusError(
                    f"HTTP {resp.status_code}",
                    request=resp.request,
                    response=resp,
                )
            except httpx.HTTPStatusError:
                # Non-retryable status: surface it to the caller unchanged.
                raise
            except httpx.HTTPError as exc:
                # Transport-level failure (connect, timeout, ...): retry.
                last_exc = exc

            await self._backoff(attempt, attempts)

        raise RuntimeError(
            f"All {attempts} attempts failed for {url}"
        ) from last_exc

    async def _complete_stream(
        self, body: dict[str, Any], model: str
    ) -> AdapterResponse:
        """POST ``body`` in streaming mode and assemble the streamed text.

        Uses the same retry policy as ``_complete_non_stream``. Latency is
        measured from the start of the successful attempt until the stream
        has been fully consumed.
        """
        url = f"{self.base_url}/chat/completions"
        attempts = self._attempts()
        last_exc: Exception | None = None

        for attempt in range(attempts):
            t0 = time.perf_counter()
            try:
                async with self._client() as client:
                    async with client.stream("POST", url, json=body) as resp:
                        if resp.status_code != 200:
                            # Load the body so the error response is complete
                            # before the stream is closed.
                            await resp.aread()
                            if resp.status_code not in self.RETRYABLE_STATUS_CODES:
                                resp.raise_for_status()
                            last_exc = httpx.HTTPStatusError(
                                f"HTTP {resp.status_code}",
                                request=resp.request,
                                response=resp,
                            )
                        else:
                            text, usage = await self._consume_stream(resp)
                            latency_ms = (time.perf_counter() - t0) * 1000
                            return AdapterResponse(
                                text=text,
                                # ``or 0`` also covers an explicit null count.
                                tokens_in=usage.get("prompt_tokens") or 0,
                                tokens_out=usage.get("completion_tokens") or 0,
                                latency_ms=latency_ms,
                                model=model,
                                raw={"streamed": True, "usage": usage},
                            )
            except httpx.HTTPStatusError:
                raise
            except httpx.HTTPError as exc:
                last_exc = exc

            await self._backoff(attempt, attempts)

        raise RuntimeError(
            f"All {attempts} attempts failed for {url}"
        ) from last_exc

    async def _consume_stream(
        self, resp: httpx.Response
    ) -> tuple[str, dict[str, int]]:
        """Read an SSE chat-completion stream into ``(text, usage)``.

        Lines that are not ``data:`` fields, or whose payload is not a JSON
        object, are skipped. Reading stops at the ``[DONE]`` sentinel.
        ``usage`` is the last non-empty ``usage`` object seen, or ``{}``.
        """
        import json as _json

        chunks: list[str] = []
        usage: dict[str, int] = {}

        async for line in resp.aiter_lines():
            # Accept both "data: {...}" and "data:{...}"; the space after the
            # colon is optional in the SSE format.
            if not line.startswith("data:"):
                continue
            payload = line[5:].strip()
            if payload == "[DONE]":
                break
            try:
                data = _json.loads(payload)
            except _json.JSONDecodeError:
                continue
            if not isinstance(data, dict):
                continue

            # ``or`` guards against explicit nulls for "choices" / "delta".
            for choice in data.get("choices") or []:
                delta = choice.get("delta") or {}
                content = delta.get("content")
                if content:
                    chunks.append(content)

            if data.get("usage"):
                usage = data["usage"]

        return "".join(chunks), usage

    async def list_models(self) -> list[str]:
        """Return the sorted model ids reported by ``GET {base_url}/models``.

        Raises:
            httpx.HTTPStatusError: If the endpoint answers with an error status.
        """
        url = f"{self.base_url}/models"
        async with self._client() as client:
            resp = await client.get(url)
            resp.raise_for_status()
            data = resp.json()

        models: list[str] = []
        for entry in data.get("data") or []:
            model_id = entry.get("id")
            if model_id:
                models.append(model_id)
        return sorted(models)

    async def test_connection(self) -> bool:
        """Return ``True`` if ``list_models`` succeeds, ``False`` on any error."""
        try:
            await self.list_models()
            return True
        except Exception:
            # Deliberately broad: this is a yes/no probe and must not raise.
            return False

    @staticmethod
    def _build_request_body(
        prompt: str, model: str, params: dict[str, Any]
    ) -> dict[str, Any]:
        """Build the chat-completions JSON body.

        An optional system message is read from ``params["system_message"]``
        or, failing that, ``params["system"]``. Only the sampling parameters
        listed in ``_PASSTHROUGH_KEYS`` are copied over; everything else in
        ``params`` (including ``stream``) is ignored here.
        """
        messages: list[dict[str, str]] = []

        system_message = params.get("system_message") or params.get("system")
        if system_message:
            messages.append({"role": "system", "content": system_message})

        messages.append({"role": "user", "content": prompt})

        body: dict[str, Any] = {"model": model, "messages": messages}
        for key in OpenAICompatAdapter._PASSTHROUGH_KEYS:
            if key in params:
                body[key] = params[key]

        return body

    @staticmethod
    def _parse_response(
        data: dict[str, Any], latency_ms: float, model: str
    ) -> AdapterResponse:
        """Convert a non-streaming completion payload into an ``AdapterResponse``.

        Missing or null ``choices``, ``message``, ``content`` and ``usage``
        fields degrade to an empty string / zero token counts instead of
        raising. The model name reported by the server takes precedence over
        the requested ``model``.
        """
        choices = data.get("choices") or []
        text = ""
        if choices:
            message = choices[0].get("message") or {}
            # ``content`` may be an explicit null; normalise it to "".
            text = message.get("content") or ""

        # ``usage`` may be absent or an explicit null.
        usage = data.get("usage") or {}
        tokens_in = usage.get("prompt_tokens") or 0
        tokens_out = usage.get("completion_tokens") or 0

        return AdapterResponse(
            text=text,
            tokens_in=tokens_in,
            tokens_out=tokens_out,
            latency_ms=latency_ms,
            model=data.get("model", model),
            raw=data,
        )
|
||||||
312
backend/tests/test_openai_compat.py
Normal file
312
backend/tests/test_openai_compat.py
Normal file
|
|
@ -0,0 +1,312 @@
|
||||||
|
"""Tests for the OpenAI-compatible adapter."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import pytest
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from engine.adapters.base import AdapterResponse, BaseAdapter
|
||||||
|
from engine.adapters.openai_compat import OpenAICompatAdapter
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Fixtures & helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
# Canned non-streaming chat-completion payload returned by the mocked POST.
# Its "model" value intentionally differs from the model names the tests
# request, so tests can tell the server-reported model from the requested one.
FAKE_COMPLETION = {
    "id": "chatcmpl-test",
    "object": "chat.completion",
    "model": "test-model",
    "choices": [
        {
            "index": 0,
            "message": {"role": "assistant", "content": "Hello back!"},
            "finish_reason": "stop",
        }
    ],
    "usage": {"prompt_tokens": 5, "completion_tokens": 3, "total_tokens": 8},
}
|
||||||
|
|
||||||
|
# Canned model-listing payload; ids are deliberately out of alphabetical
# order so tests can check that list_models() sorts its result.
FAKE_MODELS = {
    "object": "list",
    "data": [
        {"id": "model-b", "object": "model"},
        {"id": "model-a", "object": "model"},
    ],
}
|
||||||
|
|
||||||
|
# Sample SSE lines for a streamed completion: a role-only delta, two content
# deltas ("Hi", " there"), a final chunk carrying usage, then the [DONE]
# sentinel.
# NOTE(review): not referenced by any test visible in this section — confirm
# whether streaming tests were intended.
FAKE_STREAM_LINES = [
    'data: {"id":"c1","choices":[{"delta":{"role":"assistant"},"index":0}]}\n',
    'data: {"id":"c1","choices":[{"delta":{"content":"Hi"},"index":0}]}\n',
    'data: {"id":"c1","choices":[{"delta":{"content":" there"},"index":0}]}\n',
    'data: {"id":"c1","choices":[{"delta":{},"index":0,"finish_reason":"stop"}],"usage":{"prompt_tokens":4,"completion_tokens":2}}\n',
    "data: [DONE]\n",
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def adapter():
    """Provide an adapter aimed at a fake endpoint, configured with max_retries=1."""
    settings = {
        "base_url": "https://fake.api/v1",
        "api_key": "sk-test-key",
        "max_retries": 1,
    }
    return OpenAICompatAdapter(**settings)
|
||||||
|
|
||||||
|
|
||||||
|
def _mock_response(status_code=200, json_data=None, text=""):
    """Return an ``httpx.Response`` for use as a canned reply.

    The body is ``json_data`` serialised as JSON when it is given, otherwise
    the encoded ``text``. The response is bound to a POST request for the
    fake chat-completions URL.
    """
    payload = text.encode() if json_data is None else json.dumps(json_data).encode()
    request = httpx.Request("POST", "https://fake.api/v1/chat/completions")
    return httpx.Response(status_code=status_code, content=payload, request=request)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Construction & interface
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestConstruction:
    """Constructor defaults, base-URL normalisation and header generation."""

    def test_is_base_adapter_subclass(self):
        """The adapter is part of the BaseAdapter hierarchy."""
        assert issubclass(OpenAICompatAdapter, BaseAdapter)

    def test_base_url_trailing_slash_stripped(self):
        """A trailing slash on base_url is removed."""
        instance = OpenAICompatAdapter(base_url="https://api.example.com/v1/")
        assert instance.base_url == "https://api.example.com/v1"

    def test_default_params(self):
        """Omitted arguments fall back to the documented defaults."""
        instance = OpenAICompatAdapter(base_url="https://x")
        assert instance.api_key is None
        assert instance.timeout == 120.0
        assert instance.max_retries == 3

    def test_headers_with_api_key(self, adapter):
        """A configured key is sent as a bearer token."""
        headers = adapter._headers()
        assert headers["Authorization"] == "Bearer sk-test-key"
        assert headers["Content-Type"] == "application/json"

    def test_headers_without_api_key(self):
        """Without a key no Authorization header is produced."""
        keyless = OpenAICompatAdapter(base_url="http://localhost:11434/v1")
        headers = keyless._headers()
        assert "Authorization" not in headers
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Request body building
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestBuildRequestBody:
    """Checks on the JSON body produced by ``_build_request_body``."""

    def test_basic(self):
        """With no params the body holds the model and a single user message."""
        payload = OpenAICompatAdapter._build_request_body("Hello", "gpt-4", {})
        assert payload["model"] == "gpt-4"
        assert payload["messages"] == [{"role": "user", "content": "Hello"}]

    def test_with_system_message(self):
        """``system_message`` becomes the first message, ahead of the user turn."""
        params = {"system_message": "You are helpful."}
        payload = OpenAICompatAdapter._build_request_body("Hi", "m", params)
        first, second = payload["messages"][0], payload["messages"][1]
        assert first == {
            "role": "system",
            "content": "You are helpful.",
        }
        assert second["role"] == "user"

    def test_with_system_alias(self):
        """``system`` is accepted as an alias for ``system_message``."""
        params = {"system": "Be brief."}
        payload = OpenAICompatAdapter._build_request_body("Hi", "m", params)
        assert payload["messages"][0]["content"] == "Be brief."

    def test_passthrough_params(self):
        """Recognised sampling parameters are copied into the body unchanged."""
        params = {"temperature": 0.7, "max_tokens": 100, "top_p": 0.9}
        payload = OpenAICompatAdapter._build_request_body("x", "m", params)
        assert payload["temperature"] == 0.7
        assert payload["max_tokens"] == 100
        assert payload["top_p"] == 0.9

    def test_unknown_params_ignored(self):
        """Keys outside the passthrough list never reach the body."""
        params = {"custom_thing": 42, "stream": True}
        payload = OpenAICompatAdapter._build_request_body("x", "m", params)
        assert "custom_thing" not in payload
        # stream is handled separately, not in passthrough
        assert "stream" not in payload
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Response parsing
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestParseResponse:
    """Checks on ``_parse_response`` for complete and partial payloads."""

    def test_standard_response(self):
        """A full payload is mapped field-by-field onto AdapterResponse."""
        parsed = OpenAICompatAdapter._parse_response(
            FAKE_COMPLETION, latency_ms=55.0, model="gpt-4"
        )
        assert isinstance(parsed, AdapterResponse)
        assert parsed.text == "Hello back!"
        assert parsed.tokens_in == 5
        assert parsed.tokens_out == 3
        assert parsed.latency_ms == 55.0
        # The server-reported model wins over the requested one.
        assert parsed.model == "test-model"
        assert parsed.raw == FAKE_COMPLETION

    def test_empty_choices(self):
        """An empty choices list yields empty text."""
        payload = dict(FAKE_COMPLETION)
        payload["choices"] = []
        parsed = OpenAICompatAdapter._parse_response(payload, 10.0, "m")
        assert parsed.text == ""

    def test_missing_usage(self):
        """Without a usage object both token counts default to zero."""
        payload = {
            key: value for key, value in FAKE_COMPLETION.items() if key != "usage"
        }
        parsed = OpenAICompatAdapter._parse_response(payload, 10.0, "m")
        assert parsed.tokens_in == 0
        assert parsed.tokens_out == 0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Non-streaming complete (mocked transport)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestCompleteNonStream:
    """Exercise ``complete`` in non-streaming mode with ``AsyncClient.post`` patched."""

    @staticmethod
    def _patch_sleep(monkeypatch):
        """Replace ``asyncio.sleep`` with an instant no-op to skip real backoff delays."""
        import asyncio

        async def fake_sleep(delay):
            pass

        monkeypatch.setattr(asyncio, "sleep", fake_sleep)

    @pytest.mark.asyncio
    async def test_success(self, adapter, monkeypatch):
        """A 200 reply is parsed into text and token counts."""

        async def fake_post(self_client, url, **kwargs):
            return _mock_response(200, FAKE_COMPLETION)

        monkeypatch.setattr(httpx.AsyncClient, "post", fake_post)

        resp = await adapter.complete("Hi", "gpt-4", {})
        assert resp.text == "Hello back!"
        assert resp.tokens_in == 5

    @pytest.mark.asyncio
    async def test_non_retryable_error_raises_immediately(
        self, adapter, monkeypatch
    ):
        """A 401 is raised as HTTPStatusError after a single request."""
        call_count = 0

        async def fake_post(self_client, url, **kwargs):
            nonlocal call_count
            call_count += 1
            return _mock_response(401)

        monkeypatch.setattr(httpx.AsyncClient, "post", fake_post)

        with pytest.raises(httpx.HTTPStatusError):
            await adapter.complete("Hi", "m", {})
        assert call_count == 1

    @pytest.mark.asyncio
    async def test_retryable_error_exhausts_retries(
        self, monkeypatch
    ):
        """Persistent 503s use up every attempt and end in RuntimeError."""
        a = OpenAICompatAdapter(
            base_url="https://fake.api/v1",
            api_key="k",
            max_retries=2,
        )
        call_count = 0

        async def fake_post(self_client, url, **kwargs):
            nonlocal call_count
            call_count += 1
            return _mock_response(503)

        monkeypatch.setattr(httpx.AsyncClient, "post", fake_post)
        # Patch sleep to avoid real delays
        self._patch_sleep(monkeypatch)

        with pytest.raises(RuntimeError, match="All 2 attempts failed"):
            await a.complete("Hi", "m", {})
        assert call_count == 2

    @pytest.mark.asyncio
    async def test_retry_then_success(self, monkeypatch):
        """Two 429 replies followed by a 200 succeed on the third attempt."""
        a = OpenAICompatAdapter(
            base_url="https://fake.api/v1",
            api_key="k",
            max_retries=3,
        )
        call_count = 0

        async def fake_post(self_client, url, **kwargs):
            nonlocal call_count
            call_count += 1
            if call_count < 3:
                return _mock_response(429)
            return _mock_response(200, FAKE_COMPLETION)

        monkeypatch.setattr(httpx.AsyncClient, "post", fake_post)
        self._patch_sleep(monkeypatch)

        resp = await a.complete("Hi", "m", {})
        assert resp.text == "Hello back!"
        assert call_count == 3
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# list_models & test_connection
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestListModels:
    """Checks on ``list_models`` with ``AsyncClient.get`` patched."""

    @pytest.mark.asyncio
    async def test_success(self, adapter, monkeypatch):
        """Model ids are extracted and returned in sorted order."""

        async def stub_get(self_client, url, **kwargs):
            return _mock_response(200, FAKE_MODELS)

        monkeypatch.setattr(httpx.AsyncClient, "get", stub_get)

        listed = await adapter.list_models()
        assert listed == ["model-a", "model-b"]  # sorted

    @pytest.mark.asyncio
    async def test_empty_response(self, adapter, monkeypatch):
        """An empty data array yields an empty list."""

        async def stub_get(self_client, url, **kwargs):
            return _mock_response(200, {"data": []})

        monkeypatch.setattr(httpx.AsyncClient, "get", stub_get)

        listed = await adapter.list_models()
        assert listed == []
|
||||||
|
|
||||||
|
|
||||||
|
class TestTestConnection:
    """Checks that ``test_connection`` reports reachability as a boolean."""

    @pytest.mark.asyncio
    async def test_returns_true_on_success(self, adapter, monkeypatch):
        """A successful model listing means the connection is usable."""

        async def stub_get(self_client, url, **kwargs):
            return _mock_response(200, FAKE_MODELS)

        monkeypatch.setattr(httpx.AsyncClient, "get", stub_get)
        reachable = await adapter.test_connection()
        assert reachable is True

    @pytest.mark.asyncio
    async def test_returns_false_on_failure(self, adapter, monkeypatch):
        """A connection error is swallowed and reported as False."""

        async def stub_get(self_client, url, **kwargs):
            raise httpx.ConnectError("refused")

        monkeypatch.setattr(httpx.AsyncClient, "get", stub_get)
        reachable = await adapter.test_connection()
        assert reachable is False
|
||||||
Loading…
Add table
Reference in a new issue