diff --git a/Auto Run Docs/02a-backend-engine.md b/Auto Run Docs/02a-backend-engine.md index 788df79..eff79bc 100644 --- a/Auto Run Docs/02a-backend-engine.md +++ b/Auto Run Docs/02a-backend-engine.md @@ -50,4 +50,5 @@ Implement the core experiment execution engine: LLM adapters, response caching, - [x] Implement backend/routers/webhooks.py — CRUD for webhook configs. When events occur (in runner.py and sweep.py), dispatch webhook calls asynchronously via Celery. Include retry logic (3 attempts) and log delivery status. -- [ ] Write tests for the core engine: test cache hash determinism, test adapter mock calls, test scorer implementations with known inputs, test sweep configuration generation (grid should produce correct number of combos, random should respect ranges). Aim for >80% coverage on engine/ directory. +- [x] Write tests for the core engine: test cache hash determinism, test adapter mock calls, test scorer implementations with known inputs, test sweep configuration generation (grid should produce correct number of combos, random should respect ranges). Aim for >80% coverage on engine/ directory. + diff --git a/backend/tests/test_engine_core.py b/backend/tests/test_engine_core.py new file mode 100644 index 0000000..352b33a --- /dev/null +++ b/backend/tests/test_engine_core.py @@ -0,0 +1,674 @@ +"""Comprehensive tests for the core engine module. + +Targets gaps not covered by individual module tests: +- engine/webhooks.py (dispatch, delivery logging, retries) +- engine/scorers/format.py (structure validation, length edge cases) +- engine/cache.py (additional edge cases) +- engine/adapters (adapter mock call patterns) +- engine/sweep.py (grid combo count verification) +- Cross-module integration tests +""" + +import asyncio +import json +import math +import uuid +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +import pytest +from sqlalchemy import create_engine +from sqlalchemy.orm import Session + +from engine.adapters.base import AdapterResponse, BaseAdapter +from engine.cache import ResponseCacheLayer, compute_config_hash +from engine.runner import EventBus, render_prompt, run_single +from engine.scorers.format import FormatScorer +from engine.scorers.keyword import KeywordScorer +from engine.sweep import ( + generate_grid_configs, + generate_random_configs, +) +from engine.webhooks import ( + _log_delivery, + deliver_webhook, + deliver_webhook_async, + dispatch_webhooks, + dispatch_webhooks_async, + get_active_webhooks, +) +from models import ( + Base, + Experiment, + ExperimentStatus, + Project, + ResponseCache, + Run, + RunStatus, + Score, + StageResult, + User, + WebhookConfig, + WebhookDelivery, +) + + +# --------------------------------------------------------------------------- +# DB helpers +# --------------------------------------------------------------------------- + + +def _engine(): + engine = create_engine("sqlite:///:memory:") + Base.metadata.create_all(engine) + return engine + + +def _session(engine): + return Session(engine) + + +def _make_webhook(db: Session, event_type: str = "run.completed", url: str = "https://example.com/hook", is_active: bool = True, headers: dict | None = None) -> WebhookConfig: + wh = WebhookConfig(event_type=event_type, url=url, is_active=is_active, headers=headers) + db.add(wh) + db.commit() + db.refresh(wh) + return wh + + +class MockAdapter(BaseAdapter): + def __init__(self, response_text: str = "mock response", tokens_in: int = 10, tokens_out: int = 5): + self.response_text = response_text + self.tokens_in = tokens_in + self.tokens_out = tokens_out + self.calls: list[tuple[str, str, dict]] = [] + + async def complete(self, prompt: str, model: str, params: dict[str, Any]) -> AdapterResponse: + self.calls.append((prompt, model, params)) + return AdapterResponse( + text=self.response_text, + tokens_in=self.tokens_in, + tokens_out=self.tokens_out, + latency_ms=42.0, + model=model, + ) + + async def list_models(self) -> list[str]: + return ["mock-model"] + + async def test_connection(self) -> bool: + return True + + +# =========================================================================== +# Webhook engine tests — engine/webhooks.py +# =========================================================================== + + +class TestGetActiveWebhooks: + def test_returns_matching_active(self): + engine = _engine() + with _session(engine) as db: + _make_webhook(db, event_type="run.completed") + _make_webhook(db, event_type="run.failed") + _make_webhook(db, event_type="run.completed", is_active=False) + + results = get_active_webhooks(db, "run.completed") + assert len(results) == 1 + assert results[0].event_type == "run.completed" + assert results[0].is_active is True + + def test_returns_empty_when_no_match(self): + engine = _engine() + with _session(engine) as db: + _make_webhook(db, event_type="run.failed") + results = get_active_webhooks(db, "run.completed") + assert results == [] + + def test_returns_multiple_active(self): + engine = _engine() + with _session(engine) as db: + _make_webhook(db, event_type="sweep.completed", url="https://a.com") + _make_webhook(db, event_type="sweep.completed", url="https://b.com") + results = get_active_webhooks(db, "sweep.completed") + assert len(results) == 2 + + +class TestLogDelivery: + def test_creates_delivery_record(self): + engine = _engine() + with _session(engine) as db: + wh = _make_webhook(db) + delivery = _log_delivery( + db, + webhook_id=wh.id, + event_type="run.completed", + payload={"key": "value"}, + status_code=200, + success=True, + attempts=1, + ) + assert delivery.id is not None + assert delivery.webhook_id == wh.id + assert delivery.success is True + assert delivery.attempts == 1 + assert delivery.status_code == 200 + + def test_logs_failure_with_error(self): + engine = _engine() + with _session(engine) as db: + wh = _make_webhook(db) + delivery = _log_delivery( + db, + webhook_id=wh.id, + event_type="run.failed", + payload={}, + status_code=500, + success=False, + attempts=3, + error_message="Server error", + ) + assert delivery.success is False + assert delivery.error_message == "Server error" + assert delivery.attempts == 3 + + +class TestDeliverWebhook: + def test_successful_delivery(self): + engine = _engine() + with _session(engine) as db: + wh = _make_webhook(db, url="https://example.com/ok") + mock_response = httpx.Response(200, request=httpx.Request("POST", "https://example.com/ok")) + + with patch("engine.webhooks.httpx.Client") as MockClient: + client_instance = MagicMock() + client_instance.post.return_value = mock_response + client_instance.__enter__ = MagicMock(return_value=client_instance) + client_instance.__exit__ = MagicMock(return_value=False) + MockClient.return_value = client_instance + + result = deliver_webhook(db, wh, "run.completed", {"data": 1}) + + assert result is True + deliveries = db.query(WebhookDelivery).all() + assert len(deliveries) == 1 + assert deliveries[0].success is True + + def test_failed_delivery_after_retries(self): + engine = _engine() + with _session(engine) as db: + wh = _make_webhook(db, url="https://example.com/fail") + mock_response = httpx.Response( + 500, + request=httpx.Request("POST", "https://example.com/fail"), + text="Internal error", + ) + + with patch("engine.webhooks.httpx.Client") as MockClient, \ + patch("engine.webhooks.time.sleep"): + client_instance = MagicMock() + client_instance.post.return_value = mock_response + client_instance.__enter__ = MagicMock(return_value=client_instance) + client_instance.__exit__ = MagicMock(return_value=False) + MockClient.return_value = client_instance + + result = deliver_webhook(db, wh, "run.completed", {"data": 1}) + + assert result is False + deliveries = db.query(WebhookDelivery).all() + assert len(deliveries) == 1 + assert deliveries[0].success is False + assert deliveries[0].attempts == 3 + + def test_network_error_retries(self): + engine = _engine() + with _session(engine) as db: + wh = _make_webhook(db, url="https://example.com/timeout") + + with patch("engine.webhooks.httpx.Client") as MockClient, \ + patch("engine.webhooks.time.sleep"): + client_instance = MagicMock() + client_instance.post.side_effect = httpx.ConnectError("Connection refused") + client_instance.__enter__ = MagicMock(return_value=client_instance) + client_instance.__exit__ = MagicMock(return_value=False) + MockClient.return_value = client_instance + + result = deliver_webhook(db, wh, "run.completed", {}) + + assert result is False + + def test_custom_headers_included(self): + engine = _engine() + with _session(engine) as db: + wh = _make_webhook( + db, + url="https://example.com/hook", + headers={"X-Custom": "test-value"}, + ) + mock_response = httpx.Response(200, request=httpx.Request("POST", wh.url)) + + with patch("engine.webhooks.httpx.Client") as MockClient: + client_instance = MagicMock() + client_instance.post.return_value = mock_response + client_instance.__enter__ = MagicMock(return_value=client_instance) + client_instance.__exit__ = MagicMock(return_value=False) + MockClient.return_value = client_instance + + deliver_webhook(db, wh, "run.completed", {}) + + call_args = client_instance.post.call_args + headers_sent = call_args[1].get("headers") or call_args.kwargs.get("headers") + assert headers_sent["X-Custom"] == "test-value" + + +class TestDeliverWebhookAsync: + @pytest.mark.asyncio + async def test_successful_async_delivery(self): + engine = _engine() + with _session(engine) as db: + wh = _make_webhook(db, url="https://example.com/ok") + mock_response = httpx.Response(200, request=httpx.Request("POST", wh.url)) + + with patch("engine.webhooks.httpx.AsyncClient") as MockAsyncClient: + client_instance = AsyncMock() + client_instance.post.return_value = mock_response + client_instance.__aenter__ = AsyncMock(return_value=client_instance) + client_instance.__aexit__ = AsyncMock(return_value=False) + MockAsyncClient.return_value = client_instance + + result = await deliver_webhook_async(db, wh, "run.completed", {"test": True}) + + assert result is True + + @pytest.mark.asyncio + async def test_failed_async_delivery(self): + engine = _engine() + with _session(engine) as db: + wh = _make_webhook(db, url="https://example.com/fail") + mock_response = httpx.Response( + 503, + request=httpx.Request("POST", wh.url), + text="Service unavailable", + ) + + with patch("engine.webhooks.httpx.AsyncClient") as MockAsyncClient, \ + patch("asyncio.sleep", new_callable=AsyncMock): + client_instance = AsyncMock() + client_instance.post.return_value = mock_response + client_instance.__aenter__ = AsyncMock(return_value=client_instance) + client_instance.__aexit__ = AsyncMock(return_value=False) + MockAsyncClient.return_value = client_instance + + result = await deliver_webhook_async(db, wh, "run.completed", {}) + + assert result is False + + @pytest.mark.asyncio + async def test_async_network_error(self): + engine = _engine() + with _session(engine) as db: + wh = _make_webhook(db, url="https://example.com/err") + + with patch("engine.webhooks.httpx.AsyncClient") as MockAsyncClient, \ + patch("asyncio.sleep", new_callable=AsyncMock): + client_instance = AsyncMock() + client_instance.post.side_effect = httpx.ConnectError("refused") + client_instance.__aenter__ = AsyncMock(return_value=client_instance) + client_instance.__aexit__ = AsyncMock(return_value=False) + MockAsyncClient.return_value = client_instance + + result = await deliver_webhook_async(db, wh, "run.failed", {}) + + assert result is False + + +class TestDispatchWebhooks: + def test_dispatches_to_matching_webhooks(self): + engine = _engine() + with _session(engine) as db: + _make_webhook(db, event_type="run.completed", url="https://a.com") + _make_webhook(db, event_type="run.completed", url="https://b.com") + mock_response = httpx.Response(200, request=httpx.Request("POST", "https://a.com")) + + with patch("engine.webhooks.httpx.Client") as MockClient: + client_instance = MagicMock() + client_instance.post.return_value = mock_response + client_instance.__enter__ = MagicMock(return_value=client_instance) + client_instance.__exit__ = MagicMock(return_value=False) + MockClient.return_value = client_instance + + count = dispatch_webhooks(db, "run.completed", {"data": 1}) + + assert count == 2 + + def test_returns_zero_when_no_webhooks(self): + engine = _engine() + with _session(engine) as db: + count = dispatch_webhooks(db, "nonexistent.event", {}) + assert count == 0 + + +class TestDispatchWebhooksAsync: + @pytest.mark.asyncio + async def test_dispatches_async(self): + engine = _engine() + with _session(engine) as db: + _make_webhook(db, event_type="sweep.completed", url="https://hook.com") + mock_response = httpx.Response(200, request=httpx.Request("POST", "https://hook.com")) + + with patch("engine.webhooks.httpx.AsyncClient") as MockAsyncClient: + client_instance = AsyncMock() + client_instance.post.return_value = mock_response + client_instance.__aenter__ = AsyncMock(return_value=client_instance) + client_instance.__aexit__ = AsyncMock(return_value=False) + MockAsyncClient.return_value = client_instance + + count = await dispatch_webhooks_async(db, "sweep.completed", {"info": "done"}) + + assert count == 1 + + @pytest.mark.asyncio + async def test_returns_zero_when_none_match(self): + engine = _engine() + with _session(engine) as db: + count = await dispatch_webhooks_async(db, "no.match", {}) + assert count == 0 + + +# =========================================================================== +# Format scorer — structure / length edge cases +# =========================================================================== + + +class TestFormatScorerStructure: + def test_valid_against_schema(self): + schema = { + "type": "object", + "required": ["name", "age"], + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + }, + } + scorer = FormatScorer(format_type="structure", json_schema=schema) + assert scorer.score(None, '{"name": "Alice", "age": 30}', {}) == 1.0 + + def test_invalid_against_schema(self): + schema = { + "type": "object", + "required": ["name", "age"], + } + scorer = FormatScorer(format_type="structure", json_schema=schema) + assert scorer.score(None, '{"name": "Alice"}', {}) < 1.0 + + def test_not_json(self): + schema = {"type": "object"} + scorer = FormatScorer(format_type="structure", json_schema=schema) + assert scorer.score(None, "not json at all", {}) == 0.0 + + def test_no_schema_returns_zero(self): + scorer = FormatScorer(format_type="structure", json_schema=None) + assert scorer.score(None, '{"anything": true}', {}) == 0.0 + + def test_basic_schema_check_fallback(self): + """Test the _basic_schema_check path when jsonschema is unavailable.""" + schema = { + "type": "object", + "required": ["a", "b", "c"], + } + scorer = FormatScorer(format_type="structure", json_schema=schema) + # Simulate missing jsonschema library + with patch.dict("sys.modules", {"jsonschema": None}): + with patch("builtins.__import__", side_effect=lambda name, *a, **k: (_ for _ in ()).throw(ImportError) if name == "jsonschema" else __builtins__.__import__(name, *a, **k)): + result = scorer._basic_schema_check({"a": 1, "b": 2}, schema) + # 2 of 3 required fields → 2/3 + assert result == pytest.approx(2 / 3) + + def test_basic_schema_check_type_mismatch(self): + scorer = FormatScorer(format_type="structure") + result = scorer._basic_schema_check("not a dict", {"type": "object"}) + assert result == 0.0 + + def test_basic_schema_check_array_type(self): + scorer = FormatScorer(format_type="structure") + result = scorer._basic_schema_check([1, 2], {"type": "array"}) + assert result == 1.0 + + def test_basic_schema_check_no_type(self): + scorer = FormatScorer(format_type="structure") + result = scorer._basic_schema_check("anything", {}) + assert result == 1.0 + + +class TestFormatScorerLengthEdgeCases: + def test_over_max_score_decays(self): + scorer = FormatScorer(format_type="length", max_tokens=10) + # 20 tokens = 2x max, should return 0.0 + output = " ".join(["word"] * 20) + assert scorer.score(None, output, {}) == 0.0 + + def test_at_max_boundary(self): + scorer = FormatScorer(format_type="length", max_tokens=5) + output = " ".join(["word"] * 5) + assert scorer.score(None, output, {}) == 1.0 + + def test_under_min_proportional(self): + scorer = FormatScorer(format_type="length", min_tokens=10) + output = " ".join(["word"] * 5) + score = scorer.score(None, output, {}) + assert score == pytest.approx(0.5) + + def test_zero_min_tokens(self): + scorer = FormatScorer(format_type="length", min_tokens=0, max_tokens=100) + assert scorer.score(None, "", {}) == 1.0 + + +# =========================================================================== +# Cache — additional edge cases +# =========================================================================== + + +class TestCacheEdgeCases: + def test_hash_with_nested_params(self): + h1 = compute_config_hash("p", "m", {"nested": {"a": 1, "b": [2, 3]}}) + h2 = compute_config_hash("p", "m", {"nested": {"b": [2, 3], "a": 1}}) + assert h1 == h2 + + def test_hash_with_special_characters(self): + h = compute_config_hash("Hello 世界! 🌍", "model-ñ", {"key": "value with \"quotes\""}) + assert len(h) == 64 + + def test_hash_with_numeric_input_data(self): + h1 = compute_config_hash("p", "m", {}, input_data=42) + h2 = compute_config_hash("p", "m", {}, input_data=42) + assert h1 == h2 + + def test_cache_put_with_none_fields(self): + engine = _engine() + with _session(engine) as db: + cache = ResponseCacheLayer() + cache.put(db, "a" * 64, response="resp", model="m") + result = cache.get(db, "a" * 64) + assert result.tokens_in is None + assert result.tokens_out is None + assert result.latency_ms is None + + def test_independent_hit_tracking(self): + """Two cache instances track hits independently.""" + engine = _engine() + with _session(engine) as db: + cache1 = ResponseCacheLayer() + cache2 = ResponseCacheLayer() + h = "b" * 64 + cache1.put(db, h, response="r", model="m") + + cache1.get(db, h) # hit on cache1 + cache2.get(db, h) # hit on cache2 + + stats1 = cache1.cache_stats(db) + stats2 = cache2.cache_stats(db) + assert stats1.hit_rate == 1.0 + assert stats2.hit_rate == 1.0 + + +# =========================================================================== +# Adapter mock call tracking +# =========================================================================== + + +class TestAdapterMockCallTracking: + @pytest.mark.asyncio + async def test_adapter_records_all_calls(self): + adapter = MockAdapter() + await adapter.complete("prompt1", "model-a", {"temp": 0.5}) + await adapter.complete("prompt2", "model-b", {"temp": 0.9}) + + assert len(adapter.calls) == 2 + assert adapter.calls[0] == ("prompt1", "model-a", {"temp": 0.5}) + assert adapter.calls[1] == ("prompt2", "model-b", {"temp": 0.9}) + + @pytest.mark.asyncio + async def test_adapter_returns_correct_model(self): + adapter = MockAdapter(response_text="hello") + resp = await adapter.complete("test", "specific-model", {}) + assert resp.model == "specific-model" + assert resp.text == "hello" + + @pytest.mark.asyncio + async def test_adapter_token_counts(self): + adapter = MockAdapter(tokens_in=100, tokens_out=50) + resp = await adapter.complete("test", "m", {}) + assert resp.tokens_in == 100 + assert resp.tokens_out == 50 + + +# =========================================================================== +# Grid sweep — combo count verification +# =========================================================================== + + +class TestGridComboVerification: + def test_single_param_3_values(self): + configs = generate_grid_configs( + {"prompt": "test", "params": {}}, + {"params.temperature": [0.1, 0.5, 0.9]}, + ) + assert len(configs) == 3 + + def test_cartesian_product_2x3(self): + configs = generate_grid_configs( + {"prompt": "test", "model": "m", "params": {}}, + { + "model": ["gpt-4", "claude"], + "params.temperature": [0.1, 0.5, 0.9], + }, + ) + assert len(configs) == 6 # 2 × 3 + + def test_cartesian_product_2x2x3(self): + configs = generate_grid_configs( + {"prompt": "test", "model": "m", "params": {}}, + { + "model": ["a", "b"], + "params.temperature": [0.1, 0.9], + "params.top_p": [0.5, 0.7, 1.0], + }, + ) + assert len(configs) == 12 # 2 × 2 × 3 + + def test_all_combos_are_unique(self): + configs = generate_grid_configs( + {"prompt": "test", "params": {}}, + { + "params.temperature": [0.1, 0.5, 0.9], + "params.top_p": [0.5, 1.0], + }, + ) + serialized = [json.dumps(c, sort_keys=True) for c in configs] + assert len(set(serialized)) == len(serialized) + + def test_random_respects_ranges_for_all_samples(self): + configs = generate_random_configs( + {"prompt": "test", "params": {}}, + { + "params.temperature": {"min": 0.0, "max": 1.0}, + "params.top_p": {"min": 0.5, "max": 1.0}, + }, + n_trials=100, + ) + for c in configs: + assert 0.0 <= c["params"]["temperature"] <= 1.0 + assert 0.5 <= c["params"]["top_p"] <= 1.0 + + +# =========================================================================== +# Scorer integration with known inputs +# =========================================================================== + + +class TestScorerKnownInputs: + def test_keyword_all_present(self): + scorer = KeywordScorer(required_present=["python", "fast", "simple"]) + score = scorer.score(None, "Python is fast and simple to learn", {}) + assert score == 1.0 + + def test_keyword_partial_match(self): + scorer = KeywordScorer(required_present=["python", "fast", "simple"]) + score = scorer.score(None, "Python is a great language", {}) + assert score == pytest.approx(1 / 3) + + def test_keyword_absent_violation(self): + scorer = KeywordScorer(required_absent=["error", "bug"]) + score = scorer.score(None, "There was an error in the code", {}) + assert score < 1.0 + + def test_keyword_absent_clean(self): + scorer = KeywordScorer(required_absent=["error", "bug"]) + score = scorer.score(None, "The code works perfectly", {}) + assert score == 1.0 + + def test_format_json_valid(self): + scorer = FormatScorer(format_type="json") + assert scorer.score(None, '{"key": "value", "number": 42}', {}) == 1.0 + + def test_format_json_invalid(self): + scorer = FormatScorer(format_type="json") + assert scorer.score(None, "not json", {}) == 0.0 + + def test_format_markdown_headers_only(self): + scorer = FormatScorer(format_type="markdown") + assert scorer.score(None, "# Title\nSome text", {}) == 0.5 + + def test_format_markdown_full(self): + scorer = FormatScorer(format_type="markdown") + text = "# Title\n\n- Item 1\n- Item 2\n" + assert scorer.score(None, text, {}) == 1.0 + + +# =========================================================================== +# EventBus tests +# =========================================================================== + + +class TestEventBus: + def test_publish_and_listen(self): + bus = EventBus() + events = [] + bus.add_listener(lambda e: events.append(e)) + bus.publish({"type": "test", "data": 42}) + assert len(events) == 1 + assert events[0]["data"] == 42 + + def test_multiple_listeners(self): + bus = EventBus() + events1, events2 = [], [] + bus.add_listener(lambda e: events1.append(e)) + bus.add_listener(lambda e: events2.append(e)) + bus.publish({"type": "test"}) + assert len(events1) == 1 + assert len(events2) == 1 + + def test_no_listeners_no_error(self): + bus = EventBus() + bus.publish({"type": "test"}) # Should not raise