diff --git a/Auto Run Docs/02a-backend-engine.md b/Auto Run Docs/02a-backend-engine.md index 7c11252..933a620 100644 --- a/Auto Run Docs/02a-backend-engine.md +++ b/Auto Run Docs/02a-backend-engine.md @@ -4,7 +4,7 @@ Implement the core experiment execution engine: LLM adapters, response caching, - [x] Implement backend/engine/adapters/base.py defining the BaseAdapter abstract class with methods: complete(prompt, model, params) → AdapterResponse (containing response text, tokens_in, tokens_out, latency_ms), list_models() → list of model identifiers, and test_connection() → bool. Define the AdapterResponse dataclass. -- [ ] Implement backend/engine/adapters/openai_compat.py as the primary adapter. It should work with any OpenAI-compatible API (OpenWebUI, vLLM, Ollama, OpenAI, Anthropic via proxy). Use httpx for async HTTP calls. Support chat completions format with system + user messages. Parse token usage from the response. Handle errors gracefully with retries (3 attempts, exponential backoff). Support both streaming and non-streaming modes. +- [x] Implement backend/engine/adapters/openai_compat.py as the primary adapter. It should work with any OpenAI-compatible API (OpenWebUI, vLLM, Ollama, OpenAI, Anthropic via proxy). Use httpx for async HTTP calls. Support chat completions format with system + user messages. Parse token usage from the response. Handle errors gracefully with retries (3 attempts, exponential backoff). Support both streaming and non-streaming modes. - [ ] Implement backend/engine/cache.py with the ResponseCache layer. Key function: compute_config_hash(prompt, model, params, input_data) → SHA-256 hex string. Methods: get(config_hash) → CachedResponse or None, put(config_hash, response, metadata). In SQLite mode, use the ResponseCache table directly. In Postgres mode, same table but with connection pooling. Include a cache_stats() method returning hit rate, total entries, and storage size. 
"""OpenAI-compatible LLM adapter.

Works with any endpoint that speaks the OpenAI chat completions API:
OpenAI, OpenWebUI, vLLM, Ollama, Anthropic via proxy, etc.

``OpenAICompatAdapter`` is re-exported from the ``engine.adapters`` package.
"""

import asyncio
import json
import time
from typing import Any, AsyncIterator

import httpx

from engine.adapters.base import AdapterResponse, BaseAdapter


class OpenAICompatAdapter(BaseAdapter):
    """Adapter for OpenAI-compatible chat completion APIs.

    Args:
        base_url: API base URL (e.g. "https://api.openai.com/v1"). Trailing
            slashes are stripped.
        api_key: Bearer token for authentication. Optional for local endpoints.
        timeout: Request timeout in seconds.
        max_retries: Total number of attempts made per completion request
            (the first try included). Values below 1 are treated as 1 so a
            request is always sent at least once.
    """

    # HTTP status codes treated as transient; requests hitting them are retried.
    RETRYABLE_STATUS_CODES = {429, 500, 502, 503, 504}

    # Caller-supplied parameters copied verbatim into the request body.
    # A tuple (not a set) keeps the key order of the body deterministic.
    _PASSTHROUGH_KEYS = (
        "temperature",
        "max_tokens",
        "top_p",
        "frequency_penalty",
        "presence_penalty",
        "stop",
        "seed",
    )

    def __init__(
        self,
        base_url: str,
        api_key: str | None = None,
        timeout: float = 120.0,
        max_retries: int = 3,
    ) -> None:
        self.base_url = base_url.rstrip("/")
        self.api_key = api_key
        self.timeout = timeout
        self.max_retries = max_retries

    def _headers(self) -> dict[str, str]:
        """Return the default request headers, adding auth when a key is set."""
        headers: dict[str, str] = {"Content-Type": "application/json"}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        return headers

    def _client(self) -> httpx.AsyncClient:
        """Build a fresh async client configured with timeout and headers."""
        return httpx.AsyncClient(
            timeout=httpx.Timeout(self.timeout),
            headers=self._headers(),
        )

    def _attempts(self) -> int:
        """Return how many times a request may be tried (never fewer than 1)."""
        return max(1, self.max_retries)

    @staticmethod
    async def _backoff(attempt: int, attempts: int) -> None:
        """Sleep ``2**attempt`` seconds unless ``attempt`` was the final one."""
        if attempt < attempts - 1:
            # Resolved through the module at call time so tests can patch it.
            await asyncio.sleep(2**attempt)

    @staticmethod
    def _status_error(resp: httpx.Response) -> httpx.HTTPStatusError:
        """Wrap a failed response in an ``HTTPStatusError`` for chaining."""
        return httpx.HTTPStatusError(
            f"HTTP {resp.status_code}",
            request=resp.request,
            response=resp,
        )

    async def complete(
        self, prompt: str, model: str, params: dict[str, Any]
    ) -> AdapterResponse:
        """Run one chat completion and return the parsed response.

        Args:
            prompt: Text sent as the user message.
            model: Model identifier placed in the request body.
            params: Extra options. ``system_message``/``system`` add a system
                message, ``stream`` selects streaming mode, and the keys in
                ``_PASSTHROUGH_KEYS`` are forwarded as-is; anything else is
                ignored.

        Raises:
            httpx.HTTPStatusError: On a non-retryable HTTP error status.
            RuntimeError: When every attempt failed with a transient error.
        """
        body = self._build_request_body(prompt, model, params)
        # Coerce so a falsy non-bool (e.g. None) is sent as JSON ``false``.
        wants_stream = bool(params.get("stream", False))
        body["stream"] = wants_stream

        if wants_stream:
            return await self._complete_stream(body, model)
        return await self._complete_non_stream(body, model)

    async def _complete_non_stream(
        self, body: dict[str, Any], model: str
    ) -> AdapterResponse:
        """POST the request and parse the single JSON response, with retries.

        Latency is measured for the successful attempt only.
        """
        url = f"{self.base_url}/chat/completions"
        attempts = self._attempts()
        last_exc: Exception | None = None

        # One client for all attempts so retries can reuse the connection.
        async with self._client() as client:
            for attempt in range(attempts):
                started = time.perf_counter()
                try:
                    resp = await client.post(url, json=body)
                    latency_ms = (time.perf_counter() - started) * 1000

                    if resp.status_code == 200:
                        return self._parse_response(
                            resp.json(), latency_ms, model
                        )

                    if resp.status_code not in self.RETRYABLE_STATUS_CODES:
                        resp.raise_for_status()

                    # Any other outcome counts as a failed, retryable attempt.
                    last_exc = self._status_error(resp)
                except httpx.HTTPStatusError:
                    # HTTPStatusError subclasses HTTPError; re-raise it first
                    # so non-retryable statuses are never swallowed below.
                    raise
                except httpx.HTTPError as exc:
                    last_exc = exc

                await self._backoff(attempt, attempts)

        raise RuntimeError(
            f"All {attempts} attempts failed for {url}"
        ) from last_exc

    async def _complete_stream(
        self, body: dict[str, Any], model: str
    ) -> AdapterResponse:
        """POST the request in streaming mode and assemble the full text.

        The streamed chunks are concatenated into one ``AdapterResponse``;
        token counts come from a ``usage`` object if the stream carries one,
        otherwise they are reported as 0.
        """
        url = f"{self.base_url}/chat/completions"
        attempts = self._attempts()
        last_exc: Exception | None = None

        async with self._client() as client:
            for attempt in range(attempts):
                started = time.perf_counter()
                try:
                    async with client.stream("POST", url, json=body) as resp:
                        if resp.status_code == 200:
                            text, usage = await self._consume_stream(resp)
                            latency_ms = (time.perf_counter() - started) * 1000
                            return AdapterResponse(
                                text=text,
                                tokens_in=usage.get("prompt_tokens") or 0,
                                tokens_out=usage.get("completion_tokens") or 0,
                                latency_ms=latency_ms,
                                model=model,
                                raw={"streamed": True, "usage": usage},
                            )

                        # Load the error body before the stream is closed so
                        # it is available on the raised/chained exception.
                        await resp.aread()
                        if resp.status_code not in self.RETRYABLE_STATUS_CODES:
                            resp.raise_for_status()
                        last_exc = self._status_error(resp)
                except httpx.HTTPStatusError:
                    # See _complete_non_stream: must precede HTTPError.
                    raise
                except httpx.HTTPError as exc:
                    last_exc = exc

                await self._backoff(attempt, attempts)

        raise RuntimeError(
            f"All {attempts} attempts failed for {url}"
        ) from last_exc

    async def _consume_stream(
        self, resp: httpx.Response
    ) -> tuple[str, dict[str, Any]]:
        """Read server-sent events from ``resp`` and collect text and usage.

        Returns:
            A ``(text, usage)`` pair: the concatenated ``delta.content``
            pieces and the last non-empty ``usage`` object seen (``{}`` if
            none was sent).
        """
        chunks: list[str] = []
        usage: dict[str, Any] = {}

        async for line in resp.aiter_lines():
            # The space after "data:" is optional in the SSE format.
            if not line.startswith("data:"):
                continue
            payload = line[5:].strip()
            if payload == "[DONE]":
                break
            try:
                event = json.loads(payload)
            except json.JSONDecodeError:
                # Skip malformed events rather than abort the whole stream.
                continue
            if not isinstance(event, dict):
                continue

            for choice in event.get("choices") or []:
                piece = (choice.get("delta") or {}).get("content")
                if piece:
                    chunks.append(piece)

            if event.get("usage"):
                usage = event["usage"]

        return "".join(chunks), usage

    async def list_models(self) -> list[str]:
        """Return the sorted model ids reported by ``GET {base_url}/models``.

        Raises:
            httpx.HTTPStatusError: If the endpoint answers with an error status.
        """
        url = f"{self.base_url}/models"
        async with self._client() as client:
            resp = await client.get(url)
            resp.raise_for_status()
            payload = resp.json()

        return sorted(
            entry["id"]
            for entry in (payload.get("data") or [])
            if entry.get("id")
        )

    async def test_connection(self) -> bool:
        """Return True if ``list_models`` succeeds, False on any failure."""
        try:
            await self.list_models()
        except Exception:
            # Deliberately broad: this is a yes/no health probe, so every
            # failure mode (network, auth, bad JSON) maps to False.
            return False
        return True

    @staticmethod
    def _build_request_body(
        prompt: str, model: str, params: dict[str, Any]
    ) -> dict[str, Any]:
        """Build the chat-completions JSON body (without the ``stream`` flag).

        A system message is taken from ``params["system_message"]`` or, failing
        that, ``params["system"]``. Only keys listed in ``_PASSTHROUGH_KEYS``
        are copied from ``params``; other keys are dropped.
        """
        messages: list[dict[str, str]] = []

        system_message = params.get("system_message") or params.get("system")
        if system_message:
            messages.append({"role": "system", "content": system_message})
        messages.append({"role": "user", "content": prompt})

        body: dict[str, Any] = {"model": model, "messages": messages}
        for key in OpenAICompatAdapter._PASSTHROUGH_KEYS:
            if key in params:
                body[key] = params[key]
        return body

    @staticmethod
    def _parse_response(
        data: dict[str, Any], latency_ms: float, model: str
    ) -> AdapterResponse:
        """Convert a non-streaming completion payload to ``AdapterResponse``.

        Missing or ``null`` fields degrade gracefully: text becomes ``""``,
        token counts become 0, and the requested ``model`` is used when the
        payload does not name one.
        """
        text = ""
        choices = data.get("choices") or []
        if choices:
            message = choices[0].get("message") or {}
            # ``content`` may be present but null; normalise to "".
            text = message.get("content") or ""

        # ``usage`` may be absent or null depending on the server.
        usage = data.get("usage") or {}

        return AdapterResponse(
            text=text,
            tokens_in=usage.get("prompt_tokens") or 0,
            tokens_out=usage.get("completion_tokens") or 0,
            latency_ms=latency_ms,
            model=data.get("model") or model,
            raw=data,
        )
"""Tests for the OpenAI-compatible adapter."""

import asyncio
import json

import httpx
import pytest

from engine.adapters.base import AdapterResponse, BaseAdapter
from engine.adapters.openai_compat import OpenAICompatAdapter


# ---------------------------------------------------------------------------
# Fixtures & helpers
# ---------------------------------------------------------------------------

FAKE_COMPLETION = {
    "id": "chatcmpl-test",
    "object": "chat.completion",
    "model": "test-model",
    "choices": [
        {
            "index": 0,
            "message": {"role": "assistant", "content": "Hello back!"},
            "finish_reason": "stop",
        }
    ],
    "usage": {"prompt_tokens": 5, "completion_tokens": 3, "total_tokens": 8},
}

FAKE_MODELS = {
    "object": "list",
    "data": [
        {"id": "model-b", "object": "model"},
        {"id": "model-a", "object": "model"},
    ],
}

FAKE_STREAM_LINES = [
    'data: {"id":"c1","choices":[{"delta":{"role":"assistant"},"index":0}]}\n',
    'data: {"id":"c1","choices":[{"delta":{"content":"Hi"},"index":0}]}\n',
    'data: {"id":"c1","choices":[{"delta":{"content":" there"},"index":0}]}\n',
    'data: {"id":"c1","choices":[{"delta":{},"index":0,"finish_reason":"stop"}],"usage":{"prompt_tokens":4,"completion_tokens":2}}\n',
    "data: [DONE]\n",
]


@pytest.fixture
def adapter():
    """Adapter pointed at a fake URL, limited to a single attempt."""
    return OpenAICompatAdapter(
        base_url="https://fake.api/v1",
        api_key="sk-test-key",
        max_retries=1,
    )


@pytest.fixture
def no_sleep(monkeypatch):
    """Replace ``asyncio.sleep`` with a no-op so retry backoff is instant."""

    async def _instant(_delay):
        return None

    monkeypatch.setattr(asyncio, "sleep", _instant)


def _mock_response(status_code=200, json_data=None, text=""):
    """Build an httpx.Response suitable for mocking."""
    if json_data is not None:
        content = json.dumps(json_data).encode()
    else:
        content = text.encode()
    return httpx.Response(
        status_code=status_code,
        content=content,
        request=httpx.Request("POST", "https://fake.api/v1/chat/completions"),
    )


class _FakeStreamResponse:
    """Minimal stand-in exposing the ``aiter_lines`` method the adapter reads."""

    def __init__(self, lines):
        self._lines = lines

    async def aiter_lines(self):
        for line in self._lines:
            yield line


# ---------------------------------------------------------------------------
# Construction & interface
# ---------------------------------------------------------------------------


class TestConstruction:
    def test_is_base_adapter_subclass(self):
        assert issubclass(OpenAICompatAdapter, BaseAdapter)

    def test_base_url_trailing_slash_stripped(self):
        a = OpenAICompatAdapter(base_url="https://api.example.com/v1/")
        assert a.base_url == "https://api.example.com/v1"

    def test_default_params(self):
        a = OpenAICompatAdapter(base_url="https://x")
        assert a.api_key is None
        assert a.timeout == 120.0
        assert a.max_retries == 3

    def test_headers_with_api_key(self, adapter):
        h = adapter._headers()
        assert h["Authorization"] == "Bearer sk-test-key"
        assert h["Content-Type"] == "application/json"

    def test_headers_without_api_key(self):
        a = OpenAICompatAdapter(base_url="http://localhost:11434/v1")
        h = a._headers()
        assert "Authorization" not in h


# ---------------------------------------------------------------------------
# Request body building
# ---------------------------------------------------------------------------


class TestBuildRequestBody:
    def test_basic(self):
        body = OpenAICompatAdapter._build_request_body("Hello", "gpt-4", {})
        assert body["model"] == "gpt-4"
        assert body["messages"] == [{"role": "user", "content": "Hello"}]

    def test_with_system_message(self):
        body = OpenAICompatAdapter._build_request_body(
            "Hi", "m", {"system_message": "You are helpful."}
        )
        assert body["messages"][0] == {
            "role": "system",
            "content": "You are helpful.",
        }
        assert body["messages"][1]["role"] == "user"

    def test_with_system_alias(self):
        body = OpenAICompatAdapter._build_request_body(
            "Hi", "m", {"system": "Be brief."}
        )
        assert body["messages"][0]["content"] == "Be brief."

    def test_passthrough_params(self):
        body = OpenAICompatAdapter._build_request_body(
            "x", "m", {"temperature": 0.7, "max_tokens": 100, "top_p": 0.9}
        )
        assert body["temperature"] == 0.7
        assert body["max_tokens"] == 100
        assert body["top_p"] == 0.9

    def test_unknown_params_ignored(self):
        body = OpenAICompatAdapter._build_request_body(
            "x", "m", {"custom_thing": 42, "stream": True}
        )
        assert "custom_thing" not in body
        # stream is handled separately, not in passthrough
        assert "stream" not in body


# ---------------------------------------------------------------------------
# Response parsing
# ---------------------------------------------------------------------------


class TestParseResponse:
    def test_standard_response(self):
        resp = OpenAICompatAdapter._parse_response(
            FAKE_COMPLETION, latency_ms=55.0, model="gpt-4"
        )
        assert isinstance(resp, AdapterResponse)
        assert resp.text == "Hello back!"
        assert resp.tokens_in == 5
        assert resp.tokens_out == 3
        assert resp.latency_ms == 55.0
        assert resp.model == "test-model"
        assert resp.raw == FAKE_COMPLETION

    def test_empty_choices(self):
        data = {**FAKE_COMPLETION, "choices": []}
        resp = OpenAICompatAdapter._parse_response(data, 10.0, "m")
        assert resp.text == ""

    def test_missing_usage(self):
        data = dict(FAKE_COMPLETION)
        del data["usage"]
        resp = OpenAICompatAdapter._parse_response(data, 10.0, "m")
        assert resp.tokens_in == 0
        assert resp.tokens_out == 0


# ---------------------------------------------------------------------------
# Non-streaming complete (mocked transport)
# ---------------------------------------------------------------------------


class TestCompleteNonStream:
    @pytest.mark.asyncio
    async def test_success(self, adapter, monkeypatch):
        async def fake_post(self_client, url, **kwargs):
            return _mock_response(200, FAKE_COMPLETION)

        monkeypatch.setattr(httpx.AsyncClient, "post", fake_post)

        resp = await adapter.complete("Hi", "gpt-4", {})
        assert resp.text == "Hello back!"
        assert resp.tokens_in == 5

    @pytest.mark.asyncio
    async def test_non_retryable_error_raises_immediately(
        self, adapter, monkeypatch
    ):
        call_count = 0

        async def fake_post(self_client, url, **kwargs):
            nonlocal call_count
            call_count += 1
            return _mock_response(401)

        monkeypatch.setattr(httpx.AsyncClient, "post", fake_post)

        with pytest.raises(httpx.HTTPStatusError):
            await adapter.complete("Hi", "m", {})
        assert call_count == 1

    @pytest.mark.asyncio
    async def test_retryable_error_exhausts_retries(
        self, monkeypatch, no_sleep
    ):
        a = OpenAICompatAdapter(
            base_url="https://fake.api/v1",
            api_key="k",
            max_retries=2,
        )
        call_count = 0

        async def fake_post(self_client, url, **kwargs):
            nonlocal call_count
            call_count += 1
            return _mock_response(503)

        monkeypatch.setattr(httpx.AsyncClient, "post", fake_post)

        with pytest.raises(RuntimeError, match="All 2 attempts failed"):
            await a.complete("Hi", "m", {})
        assert call_count == 2

    @pytest.mark.asyncio
    async def test_retry_then_success(self, monkeypatch, no_sleep):
        a = OpenAICompatAdapter(
            base_url="https://fake.api/v1",
            api_key="k",
            max_retries=3,
        )
        call_count = 0

        async def fake_post(self_client, url, **kwargs):
            nonlocal call_count
            call_count += 1
            if call_count < 3:
                return _mock_response(429)
            return _mock_response(200, FAKE_COMPLETION)

        monkeypatch.setattr(httpx.AsyncClient, "post", fake_post)

        resp = await a.complete("Hi", "m", {})
        assert resp.text == "Hello back!"
        assert call_count == 3


# ---------------------------------------------------------------------------
# Stream parsing
# ---------------------------------------------------------------------------


class TestConsumeStream:
    @pytest.mark.asyncio
    async def test_collects_text_and_usage(self, adapter):
        text, usage = await adapter._consume_stream(
            _FakeStreamResponse(FAKE_STREAM_LINES)
        )
        assert text == "Hi there"
        assert usage == {"prompt_tokens": 4, "completion_tokens": 2}

    @pytest.mark.asyncio
    async def test_skips_non_data_and_malformed_lines(self, adapter):
        lines = [": keep-alive", "", "data: not-json", *FAKE_STREAM_LINES]
        text, usage = await adapter._consume_stream(_FakeStreamResponse(lines))
        assert text == "Hi there"
        assert usage == {"prompt_tokens": 4, "completion_tokens": 2}


# ---------------------------------------------------------------------------
# list_models & test_connection
# ---------------------------------------------------------------------------


class TestListModels:
    @pytest.mark.asyncio
    async def test_success(self, adapter, monkeypatch):
        async def fake_get(self_client, url, **kwargs):
            return _mock_response(200, FAKE_MODELS)

        monkeypatch.setattr(httpx.AsyncClient, "get", fake_get)

        models = await adapter.list_models()
        assert models == ["model-a", "model-b"]  # sorted

    @pytest.mark.asyncio
    async def test_empty_response(self, adapter, monkeypatch):
        async def fake_get(self_client, url, **kwargs):
            return _mock_response(200, {"data": []})

        monkeypatch.setattr(httpx.AsyncClient, "get", fake_get)

        models = await adapter.list_models()
        assert models == []


class TestTestConnection:
    @pytest.mark.asyncio
    async def test_returns_true_on_success(self, adapter, monkeypatch):
        async def fake_get(self_client, url, **kwargs):
            return _mock_response(200, FAKE_MODELS)

        monkeypatch.setattr(httpx.AsyncClient, "get", fake_get)
        assert await adapter.test_connection() is True

    @pytest.mark.asyncio
    async def test_returns_false_on_failure(self, adapter, monkeypatch):
        async def fake_get(self_client, url, **kwargs):
            raise httpx.ConnectError("refused")

        monkeypatch.setattr(httpx.AsyncClient, "get", fake_get)
        assert await adapter.test_connection() is False