"""Tests for the OpenAI-compatible adapter.""" import json import pytest import httpx from engine.adapters.base import AdapterResponse, BaseAdapter from engine.adapters.openai_compat import OpenAICompatAdapter # --------------------------------------------------------------------------- # Fixtures & helpers # --------------------------------------------------------------------------- FAKE_COMPLETION = { "id": "chatcmpl-test", "object": "chat.completion", "model": "test-model", "choices": [ { "index": 0, "message": {"role": "assistant", "content": "Hello back!"}, "finish_reason": "stop", } ], "usage": {"prompt_tokens": 5, "completion_tokens": 3, "total_tokens": 8}, } FAKE_MODELS = { "object": "list", "data": [ {"id": "model-b", "object": "model"}, {"id": "model-a", "object": "model"}, ], } FAKE_STREAM_LINES = [ 'data: {"id":"c1","choices":[{"delta":{"role":"assistant"},"index":0}]}\n', 'data: {"id":"c1","choices":[{"delta":{"content":"Hi"},"index":0}]}\n', 'data: {"id":"c1","choices":[{"delta":{"content":" there"},"index":0}]}\n', 'data: {"id":"c1","choices":[{"delta":{},"index":0,"finish_reason":"stop"}],"usage":{"prompt_tokens":4,"completion_tokens":2}}\n', "data: [DONE]\n", ] @pytest.fixture def adapter(): return OpenAICompatAdapter( base_url="https://fake.api/v1", api_key="sk-test-key", max_retries=1, ) def _mock_response(status_code=200, json_data=None, text=""): """Build an httpx.Response suitable for mocking.""" if json_data is not None: content = json.dumps(json_data).encode() else: content = text.encode() return httpx.Response( status_code=status_code, content=content, request=httpx.Request("POST", "https://fake.api/v1/chat/completions"), ) # --------------------------------------------------------------------------- # Construction & interface # --------------------------------------------------------------------------- class TestConstruction: def test_is_base_adapter_subclass(self): assert issubclass(OpenAICompatAdapter, BaseAdapter) def test_base_url_trailing_slash_stripped(self): a = OpenAICompatAdapter(base_url="https://api.example.com/v1/") assert a.base_url == "https://api.example.com/v1" def test_default_params(self): a = OpenAICompatAdapter(base_url="https://x") assert a.api_key is None assert a.timeout == 120.0 assert a.max_retries == 3 def test_headers_with_api_key(self, adapter): h = adapter._headers() assert h["Authorization"] == "Bearer sk-test-key" assert h["Content-Type"] == "application/json" def test_headers_without_api_key(self): a = OpenAICompatAdapter(base_url="http://localhost:11434/v1") h = a._headers() assert "Authorization" not in h # --------------------------------------------------------------------------- # Request body building # --------------------------------------------------------------------------- class TestBuildRequestBody: def test_basic(self): body = OpenAICompatAdapter._build_request_body( "Hello", "gpt-4", {} ) assert body["model"] == "gpt-4" assert body["messages"] == [{"role": "user", "content": "Hello"}] def test_with_system_message(self): body = OpenAICompatAdapter._build_request_body( "Hi", "m", {"system_message": "You are helpful."} ) assert body["messages"][0] == { "role": "system", "content": "You are helpful.", } assert body["messages"][1]["role"] == "user" def test_with_system_alias(self): body = OpenAICompatAdapter._build_request_body( "Hi", "m", {"system": "Be brief."} ) assert body["messages"][0]["content"] == "Be brief." def test_passthrough_params(self): body = OpenAICompatAdapter._build_request_body( "x", "m", {"temperature": 0.7, "max_tokens": 100, "top_p": 0.9} ) assert body["temperature"] == 0.7 assert body["max_tokens"] == 100 assert body["top_p"] == 0.9 def test_unknown_params_ignored(self): body = OpenAICompatAdapter._build_request_body( "x", "m", {"custom_thing": 42, "stream": True} ) assert "custom_thing" not in body # stream is handled separately, not in passthrough assert "stream" not in body # --------------------------------------------------------------------------- # Response parsing # --------------------------------------------------------------------------- class TestParseResponse: def test_standard_response(self): resp = OpenAICompatAdapter._parse_response( FAKE_COMPLETION, latency_ms=55.0, model="gpt-4" ) assert isinstance(resp, AdapterResponse) assert resp.text == "Hello back!" assert resp.tokens_in == 5 assert resp.tokens_out == 3 assert resp.latency_ms == 55.0 assert resp.model == "test-model" assert resp.raw == FAKE_COMPLETION def test_empty_choices(self): data = {**FAKE_COMPLETION, "choices": []} resp = OpenAICompatAdapter._parse_response(data, 10.0, "m") assert resp.text == "" def test_missing_usage(self): data = dict(FAKE_COMPLETION) del data["usage"] resp = OpenAICompatAdapter._parse_response(data, 10.0, "m") assert resp.tokens_in == 0 assert resp.tokens_out == 0 # --------------------------------------------------------------------------- # Non-streaming complete (mocked transport) # --------------------------------------------------------------------------- class TestCompleteNonStream: @pytest.mark.asyncio async def test_success(self, adapter, monkeypatch): async def fake_post(self_client, url, **kwargs): return _mock_response(200, FAKE_COMPLETION) monkeypatch.setattr(httpx.AsyncClient, "post", fake_post) resp = await adapter.complete("Hi", "gpt-4", {}) assert resp.text == "Hello back!" assert resp.tokens_in == 5 @pytest.mark.asyncio async def test_non_retryable_error_raises_immediately( self, adapter, monkeypatch ): call_count = 0 async def fake_post(self_client, url, **kwargs): nonlocal call_count call_count += 1 return _mock_response(401) monkeypatch.setattr(httpx.AsyncClient, "post", fake_post) with pytest.raises(httpx.HTTPStatusError): await adapter.complete("Hi", "m", {}) assert call_count == 1 @pytest.mark.asyncio async def test_retryable_error_exhausts_retries( self, monkeypatch ): a = OpenAICompatAdapter( base_url="https://fake.api/v1", api_key="k", max_retries=2, ) call_count = 0 async def fake_post(self_client, url, **kwargs): nonlocal call_count call_count += 1 return _mock_response(503) monkeypatch.setattr(httpx.AsyncClient, "post", fake_post) # Patch sleep to avoid real delays monkeypatch.setattr("asyncio.sleep", lambda _: __import__("asyncio").coroutine(lambda: None)()) import asyncio async def fake_sleep(delay): pass monkeypatch.setattr(asyncio, "sleep", fake_sleep) with pytest.raises(RuntimeError, match="All 2 attempts failed"): await a.complete("Hi", "m", {}) assert call_count == 2 @pytest.mark.asyncio async def test_retry_then_success(self, monkeypatch): a = OpenAICompatAdapter( base_url="https://fake.api/v1", api_key="k", max_retries=3, ) call_count = 0 async def fake_post(self_client, url, **kwargs): nonlocal call_count call_count += 1 if call_count < 3: return _mock_response(429) return _mock_response(200, FAKE_COMPLETION) monkeypatch.setattr(httpx.AsyncClient, "post", fake_post) import asyncio async def fake_sleep(delay): pass monkeypatch.setattr(asyncio, "sleep", fake_sleep) resp = await a.complete("Hi", "m", {}) assert resp.text == "Hello back!" assert call_count == 3 # --------------------------------------------------------------------------- # list_models & test_connection # --------------------------------------------------------------------------- class TestListModels: @pytest.mark.asyncio async def test_success(self, adapter, monkeypatch): async def fake_get(self_client, url, **kwargs): return _mock_response(200, FAKE_MODELS) monkeypatch.setattr(httpx.AsyncClient, "get", fake_get) models = await adapter.list_models() assert models == ["model-a", "model-b"] # sorted @pytest.mark.asyncio async def test_empty_response(self, adapter, monkeypatch): async def fake_get(self_client, url, **kwargs): return _mock_response(200, {"data": []}) monkeypatch.setattr(httpx.AsyncClient, "get", fake_get) models = await adapter.list_models() assert models == [] class TestTestConnection: @pytest.mark.asyncio async def test_returns_true_on_success(self, adapter, monkeypatch): async def fake_get(self_client, url, **kwargs): return _mock_response(200, FAKE_MODELS) monkeypatch.setattr(httpx.AsyncClient, "get", fake_get) assert await adapter.test_connection() is True @pytest.mark.asyncio async def test_returns_false_on_failure(self, adapter, monkeypatch): async def fake_get(self_client, url, **kwargs): raise httpx.ConnectError("refused") monkeypatch.setattr(httpx.AsyncClient, "get", fake_get) assert await adapter.test_connection() is False