Add OpenAICompatAdapter, which works with any OpenAI-compatible API endpoint (OpenWebUI, vLLM, Ollama, OpenAI, or Anthropic via a proxy).

Features:
- Async HTTP calls via httpx with a configurable timeout
- Chat-completions request format with system and user messages
- Token-usage parsing from responses
- Exponential-backoff retries (configurable; default 3 attempts)
- Both streaming (SSE) and non-streaming modes
- Model listing and connection testing
- 21 tests covering construction, request building, response parsing, retry logic, and error handling
312 lines
9.9 KiB
Python
312 lines
9.9 KiB
Python
"""Tests for the OpenAI-compatible adapter."""
|
|
|
|
import json
|
|
import pytest
|
|
import httpx
|
|
|
|
from engine.adapters.base import AdapterResponse, BaseAdapter
|
|
from engine.adapters.openai_compat import OpenAICompatAdapter
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Fixtures & helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Canned non-streaming /chat/completions payload. Its "model" ("test-model")
# deliberately differs from the model names the tests request, so the tests can
# tell which one the parser reports.
FAKE_COMPLETION = {
    "id": "chatcmpl-test",
    "object": "chat.completion",
    "model": "test-model",
    "choices": [
        {
            "index": 0,
            "message": {"role": "assistant", "content": "Hello back!"},
            "finish_reason": "stop",
        },
    ],
    "usage": {"prompt_tokens": 5, "completion_tokens": 3, "total_tokens": 8},
}

# Canned /models payload; the ids are intentionally listed out of order.
FAKE_MODELS = {
    "object": "list",
    "data": [{"id": model_id, "object": "model"} for model_id in ("model-b", "model-a")],
}

# Server-sent-event lines of a streamed completion: a role delta, two content
# deltas, a final chunk carrying finish_reason + usage, then the [DONE] sentinel.
# Each chunk is rendered as compact JSON (no spaces after separators).
# NOTE(review): no test in this file appears to reference this yet.
FAKE_STREAM_LINES = [
    f"data: {json.dumps(chunk, separators=(',', ':'))}\n"
    for chunk in (
        {"id": "c1", "choices": [{"delta": {"role": "assistant"}, "index": 0}]},
        {"id": "c1", "choices": [{"delta": {"content": "Hi"}, "index": 0}]},
        {"id": "c1", "choices": [{"delta": {"content": " there"}, "index": 0}]},
        {
            "id": "c1",
            "choices": [{"delta": {}, "index": 0, "finish_reason": "stop"}],
            "usage": {"prompt_tokens": 4, "completion_tokens": 2},
        },
    )
] + ["data: [DONE]\n"]
|
|
|
|
|
|
@pytest.fixture
def adapter():
    """Adapter aimed at a fake endpoint, allowed a single attempt per call."""
    settings = {
        "base_url": "https://fake.api/v1",
        "api_key": "sk-test-key",
        "max_retries": 1,
    }
    return OpenAICompatAdapter(**settings)
|
|
|
|
|
|
def _mock_response(
    status_code=200,
    json_data=None,
    text="",
    *,
    method="POST",
    url="https://fake.api/v1/chat/completions",
):
    """Build an ``httpx.Response`` suitable for returning from a patched client method.

    Args:
        status_code: HTTP status code of the fake response.
        json_data: When not ``None``, serialised with ``json.dumps`` as the body
            and advertised with ``Content-Type: application/json``, as a real
            OpenAI-compatible server would send it.
        text: Raw body used when *json_data* is ``None``.
        method: HTTP method of the request the response is attached to
            (keyword-only; pass ``"GET"`` for mocked ``/models`` calls).
        url: URL of the request the response is attached to (keyword-only),
            so error messages built from the request name the right endpoint.

    Returns:
        An ``httpx.Response`` with a request attached; httpx needs the request
        for ``raise_for_status()``.
    """
    headers = {}
    if json_data is not None:
        content = json.dumps(json_data).encode()
        headers["Content-Type"] = "application/json"
    else:
        content = text.encode()
    return httpx.Response(
        status_code=status_code,
        headers=headers,
        content=content,
        request=httpx.Request(method, url),
    )
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Construction & interface
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestConstruction:
    """Constructor defaults, base-URL normalisation and request headers."""

    def test_is_base_adapter_subclass(self):
        assert issubclass(OpenAICompatAdapter, BaseAdapter)

    def test_base_url_trailing_slash_stripped(self):
        with_slash = OpenAICompatAdapter(base_url="https://api.example.com/v1/")
        assert with_slash.base_url == "https://api.example.com/v1"

    def test_default_params(self):
        defaults = OpenAICompatAdapter(base_url="https://x")
        assert defaults.api_key is None
        assert defaults.timeout == 120.0
        assert defaults.max_retries == 3

    def test_headers_with_api_key(self, adapter):
        headers = adapter._headers()
        assert headers["Authorization"] == "Bearer sk-test-key"
        assert headers["Content-Type"] == "application/json"

    def test_headers_without_api_key(self):
        # Keyless local endpoint (Ollama-style URL): no Authorization header.
        keyless = OpenAICompatAdapter(base_url="http://localhost:11434/v1")
        assert "Authorization" not in keyless._headers()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Request body building
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestBuildRequestBody:
    """``_build_request_body`` turns (prompt, model, params) into a chat payload."""

    def test_basic(self):
        build = OpenAICompatAdapter._build_request_body
        payload = build("Hello", "gpt-4", {})
        assert payload["model"] == "gpt-4"
        assert payload["messages"] == [{"role": "user", "content": "Hello"}]

    def test_with_system_message(self):
        build = OpenAICompatAdapter._build_request_body
        messages = build("Hi", "m", {"system_message": "You are helpful."})["messages"]
        assert messages[0] == {"role": "system", "content": "You are helpful."}
        assert messages[1]["role"] == "user"

    def test_with_system_alias(self):
        # "system" is accepted as a shorthand for "system_message".
        build = OpenAICompatAdapter._build_request_body
        messages = build("Hi", "m", {"system": "Be brief."})["messages"]
        assert messages[0]["content"] == "Be brief."

    def test_passthrough_params(self):
        build = OpenAICompatAdapter._build_request_body
        expected = {"temperature": 0.7, "max_tokens": 100, "top_p": 0.9}
        payload = build("x", "m", dict(expected))
        for name, value in expected.items():
            assert payload[name] == value

    def test_unknown_params_ignored(self):
        build = OpenAICompatAdapter._build_request_body
        payload = build("x", "m", {"custom_thing": 42, "stream": True})
        assert "custom_thing" not in payload
        # stream is handled separately, not in passthrough
        assert "stream" not in payload
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Response parsing
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestParseResponse:
    """``_parse_response`` maps a completion payload onto ``AdapterResponse``."""

    def test_standard_response(self):
        parsed = OpenAICompatAdapter._parse_response(
            FAKE_COMPLETION, latency_ms=55.0, model="gpt-4"
        )
        assert isinstance(parsed, AdapterResponse)
        assert (parsed.text, parsed.tokens_in, parsed.tokens_out) == ("Hello back!", 5, 3)
        assert parsed.latency_ms == 55.0
        # The model named in the payload is reported, not the requested "gpt-4".
        assert parsed.model == "test-model"
        assert parsed.raw == FAKE_COMPLETION

    def test_empty_choices(self):
        without_choices = dict(FAKE_COMPLETION, choices=[])
        parsed = OpenAICompatAdapter._parse_response(without_choices, 10.0, "m")
        assert parsed.text == ""

    def test_missing_usage(self):
        without_usage = {
            key: value for key, value in FAKE_COMPLETION.items() if key != "usage"
        }
        parsed = OpenAICompatAdapter._parse_response(without_usage, 10.0, "m")
        assert (parsed.tokens_in, parsed.tokens_out) == (0, 0)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Non-streaming complete() (httpx.AsyncClient.post monkeypatched)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestCompleteNonStream:
    """Non-streaming ``complete()`` with ``httpx.AsyncClient.post`` monkeypatched."""

    @pytest.fixture
    def no_sleep(self, monkeypatch):
        """Make ``asyncio.sleep`` return at once so retry backoff adds no delay.

        Returns:
            The list of delays the code under test asked to sleep for, in order.
        """
        import asyncio

        delays = []

        async def fake_sleep(delay, result=None):
            delays.append(delay)
            return result

        monkeypatch.setattr(asyncio, "sleep", fake_sleep)
        return delays

    @pytest.mark.asyncio
    async def test_success(self, adapter, monkeypatch):
        async def fake_post(self_client, url, **kwargs):
            return _mock_response(200, FAKE_COMPLETION)

        monkeypatch.setattr(httpx.AsyncClient, "post", fake_post)

        resp = await adapter.complete("Hi", "gpt-4", {})
        assert resp.text == "Hello back!"
        assert resp.tokens_in == 5

    @pytest.mark.asyncio
    async def test_non_retryable_error_raises_immediately(self, monkeypatch, no_sleep):
        # Allow several attempts: with the single-attempt ``adapter`` fixture,
        # ``call_count == 1`` would hold even if a 401 were (wrongly) retried.
        a = OpenAICompatAdapter(
            base_url="https://fake.api/v1",
            api_key="k",
            max_retries=3,
        )
        call_count = 0

        async def fake_post(self_client, url, **kwargs):
            nonlocal call_count
            call_count += 1
            return _mock_response(401)

        monkeypatch.setattr(httpx.AsyncClient, "post", fake_post)

        with pytest.raises(httpx.HTTPStatusError) as exc_info:
            await a.complete("Hi", "m", {})
        assert exc_info.value.response.status_code == 401
        assert call_count == 1
        # "Immediately": no backoff sleep before giving up.
        assert no_sleep == []

    @pytest.mark.asyncio
    async def test_retryable_error_exhausts_retries(self, monkeypatch, no_sleep):
        a = OpenAICompatAdapter(
            base_url="https://fake.api/v1",
            api_key="k",
            max_retries=2,
        )
        call_count = 0

        async def fake_post(self_client, url, **kwargs):
            nonlocal call_count
            call_count += 1
            return _mock_response(503)

        monkeypatch.setattr(httpx.AsyncClient, "post", fake_post)

        with pytest.raises(RuntimeError, match="All 2 attempts failed"):
            await a.complete("Hi", "m", {})
        assert call_count == 2

    @pytest.mark.asyncio
    async def test_retry_then_success(self, monkeypatch, no_sleep):
        a = OpenAICompatAdapter(
            base_url="https://fake.api/v1",
            api_key="k",
            max_retries=3,
        )
        call_count = 0

        async def fake_post(self_client, url, **kwargs):
            nonlocal call_count
            call_count += 1
            # Rate-limit the first two attempts, then succeed.
            if call_count < 3:
                return _mock_response(429)
            return _mock_response(200, FAKE_COMPLETION)

        monkeypatch.setattr(httpx.AsyncClient, "post", fake_post)

        resp = await a.complete("Hi", "m", {})
        assert resp.text == "Hello back!"
        assert call_count == 3
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# list_models & test_connection
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestListModels:
    """``list_models()`` with ``httpx.AsyncClient.get`` stubbed out."""

    @staticmethod
    def _stub_get(monkeypatch, payload):
        """Make every ``AsyncClient.get`` call return a 200 carrying *payload*."""

        async def fake_get(self_client, url, **kwargs):
            return _mock_response(200, payload)

        monkeypatch.setattr(httpx.AsyncClient, "get", fake_get)

    @pytest.mark.asyncio
    async def test_success(self, adapter, monkeypatch):
        self._stub_get(monkeypatch, FAKE_MODELS)
        # FAKE_MODELS lists model-b first, so this also checks the result is sorted.
        assert await adapter.list_models() == ["model-a", "model-b"]

    @pytest.mark.asyncio
    async def test_empty_response(self, adapter, monkeypatch):
        self._stub_get(monkeypatch, {"data": []})
        assert await adapter.list_models() == []
|
|
|
|
|
|
class TestTestConnection:
    """``test_connection()`` reports reachability as a bool instead of raising."""

    @pytest.mark.asyncio
    async def test_returns_true_on_success(self, adapter, monkeypatch):
        async def reachable_get(self_client, url, **kwargs):
            return _mock_response(200, FAKE_MODELS)

        monkeypatch.setattr(httpx.AsyncClient, "get", reachable_get)
        connected = await adapter.test_connection()
        assert connected is True

    @pytest.mark.asyncio
    async def test_returns_false_on_failure(self, adapter, monkeypatch):
        async def refusing_get(self_client, url, **kwargs):
            raise httpx.ConnectError("refused")

        monkeypatch.setattr(httpx.AsyncClient, "get", refusing_get)
        connected = await adapter.test_connection()
        assert connected is False
|