Created tests/test_engine_core.py with 52 tests covering the webhook dispatch engine (sync and async delivery, retries, dispatch), format scorer structure/length edge cases, cache hash determinism with nested parameters and special characters, adapter mock call tracking, grid sweep combination counts, scorer integration with known inputs, and the EventBus. Engine coverage improved from 83% to 90%; webhooks.py coverage improved from 27% to 99%.
674 lines
25 KiB
Python
674 lines
25 KiB
Python
"""Comprehensive tests for the core engine module.
|
||
|
||
Targets gaps not covered by individual module tests:
|
||
- engine/webhooks.py (dispatch, delivery logging, retries)
|
||
- engine/scorers/format.py (structure validation, length edge cases)
|
||
- engine/cache.py (additional edge cases)
|
||
- engine/adapters (adapter mock call patterns)
|
||
- engine/sweep.py (grid combo count verification)
|
||
- Cross-module integration tests
|
||
"""
|
||
|
||
import asyncio
|
||
import json
|
||
import math
|
||
import uuid
|
||
from typing import Any
|
||
from unittest.mock import AsyncMock, MagicMock, patch
|
||
|
||
import httpx
|
||
import pytest
|
||
from sqlalchemy import create_engine
|
||
from sqlalchemy.orm import Session
|
||
|
||
from engine.adapters.base import AdapterResponse, BaseAdapter
|
||
from engine.cache import ResponseCacheLayer, compute_config_hash
|
||
from engine.runner import EventBus, render_prompt, run_single
|
||
from engine.scorers.format import FormatScorer
|
||
from engine.scorers.keyword import KeywordScorer
|
||
from engine.sweep import (
|
||
generate_grid_configs,
|
||
generate_random_configs,
|
||
)
|
||
from engine.webhooks import (
|
||
_log_delivery,
|
||
deliver_webhook,
|
||
deliver_webhook_async,
|
||
dispatch_webhooks,
|
||
dispatch_webhooks_async,
|
||
get_active_webhooks,
|
||
)
|
||
from models import (
|
||
Base,
|
||
Experiment,
|
||
ExperimentStatus,
|
||
Project,
|
||
ResponseCache,
|
||
Run,
|
||
RunStatus,
|
||
Score,
|
||
StageResult,
|
||
User,
|
||
WebhookConfig,
|
||
WebhookDelivery,
|
||
)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# DB helpers
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
def _engine():
    """Return a fresh in-memory SQLite engine with all ``Base`` tables created."""
    db_engine = create_engine("sqlite:///:memory:")
    # Create every table registered on the declarative metadata.
    Base.metadata.create_all(bind=db_engine)
    return db_engine
|
||
|
||
|
||
def _session(engine):
    """Open a new SQLAlchemy ``Session`` bound to ``engine``."""
    session = Session(engine)
    return session
|
||
|
||
|
||
def _make_webhook(
    db: Session,
    event_type: str = "run.completed",
    url: str = "https://example.com/hook",
    is_active: bool = True,
    headers: dict | None = None,
) -> WebhookConfig:
    """Insert a ``WebhookConfig`` row and return it after a refresh.

    The row is added, committed, and refreshed on ``db`` so the returned
    object carries whatever the database populated on insert.
    """
    webhook = WebhookConfig(
        event_type=event_type,
        url=url,
        is_active=is_active,
        headers=headers,
    )
    db.add(webhook)
    db.commit()
    db.refresh(webhook)
    return webhook
|
||
|
||
|
||
class MockAdapter(BaseAdapter):
    """Adapter test double that returns a canned response and records calls.

    Every ``complete`` invocation is appended to ``self.calls`` as a
    ``(prompt, model, params)`` tuple.
    """

    def __init__(self, response_text: str = "mock response", tokens_in: int = 10, tokens_out: int = 5):
        # Canned values echoed back by every ``complete`` call.
        self.response_text = response_text
        self.tokens_in = tokens_in
        self.tokens_out = tokens_out
        # Chronological log of (prompt, model, params) tuples.
        self.calls: list[tuple[str, str, dict]] = []

    async def complete(self, prompt: str, model: str, params: dict[str, Any]) -> AdapterResponse:
        """Record the call and return the canned ``AdapterResponse``."""
        recorded = (prompt, model, params)
        self.calls.append(recorded)
        canned = AdapterResponse(
            text=self.response_text,
            tokens_in=self.tokens_in,
            tokens_out=self.tokens_out,
            latency_ms=42.0,
            model=model,
        )
        return canned

    async def list_models(self) -> list[str]:
        """Return the single fixed model name."""
        return ["mock-model"]

    async def test_connection(self) -> bool:
        """Always report a working connection."""
        return True
|
||
|
||
|
||
# ===========================================================================
|
||
# Webhook engine tests — engine/webhooks.py
|
||
# ===========================================================================
|
||
|
||
|
||
class TestGetActiveWebhooks:
    """``get_active_webhooks`` lookups by event type and active flag."""

    def test_returns_matching_active(self):
        """Only the active webhook for the requested event is returned."""
        db_engine = _engine()
        with _session(db_engine) as db:
            # Same inserts, in the same order: active match, other event,
            # inactive match.
            seed_rows = (
                ("run.completed", True),
                ("run.failed", True),
                ("run.completed", False),
            )
            for seeded_event, seeded_active in seed_rows:
                _make_webhook(db, event_type=seeded_event, is_active=seeded_active)

            matches = get_active_webhooks(db, "run.completed")
            assert len(matches) == 1
            only_match = matches[0]
            assert only_match.event_type == "run.completed"
            assert only_match.is_active is True

    def test_returns_empty_when_no_match(self):
        """An event type with no registered webhook yields an empty list."""
        db_engine = _engine()
        with _session(db_engine) as db:
            _make_webhook(db, event_type="run.failed")
            matches = get_active_webhooks(db, "run.completed")
            assert matches == []

    def test_returns_multiple_active(self):
        """Every active webhook registered for the event is returned."""
        db_engine = _engine()
        with _session(db_engine) as db:
            for target_url in ("https://a.com", "https://b.com"):
                _make_webhook(db, event_type="sweep.completed", url=target_url)
            matches = get_active_webhooks(db, "sweep.completed")
            assert len(matches) == 2
|
||
|
||
|
||
class TestLogDelivery:
    """``_log_delivery`` persists the fields it is given."""

    def test_creates_delivery_record(self):
        """A successful delivery is stored with its status and attempt count."""
        db_engine = _engine()
        with _session(db_engine) as db:
            webhook = _make_webhook(db)
            record_fields = {
                "webhook_id": webhook.id,
                "event_type": "run.completed",
                "payload": {"key": "value"},
                "status_code": 200,
                "success": True,
                "attempts": 1,
            }
            record = _log_delivery(db, **record_fields)

            assert record.id is not None
            assert record.webhook_id == webhook.id
            assert record.success is True
            assert record.attempts == 1
            assert record.status_code == 200

    def test_logs_failure_with_error(self):
        """A failed delivery keeps its error message and attempt count."""
        db_engine = _engine()
        with _session(db_engine) as db:
            webhook = _make_webhook(db)
            record_fields = {
                "webhook_id": webhook.id,
                "event_type": "run.failed",
                "payload": {},
                "status_code": 500,
                "success": False,
                "attempts": 3,
                "error_message": "Server error",
            }
            record = _log_delivery(db, **record_fields)

            assert record.success is False
            assert record.error_message == "Server error"
            assert record.attempts == 3
|
||
|
||
|
||
class TestDeliverWebhook:
    """Synchronous ``deliver_webhook`` against a mocked ``httpx.Client``."""

    @staticmethod
    def _context_client():
        """Build a ``MagicMock`` that returns itself from a ``with`` statement."""
        client = MagicMock()
        client.__enter__ = MagicMock(return_value=client)
        client.__exit__ = MagicMock(return_value=False)
        return client

    def test_successful_delivery(self):
        """A 200 response yields ``True`` and one successful delivery row."""
        db_engine = _engine()
        with _session(db_engine) as db:
            webhook = _make_webhook(db, url="https://example.com/ok")
            ok_response = httpx.Response(200, request=httpx.Request("POST", "https://example.com/ok"))
            client = self._context_client()
            client.post.return_value = ok_response

            with patch("engine.webhooks.httpx.Client", return_value=client):
                delivered = deliver_webhook(db, webhook, "run.completed", {"data": 1})

            assert delivered is True
            logged = db.query(WebhookDelivery).all()
            assert len(logged) == 1
            assert logged[0].success is True

    def test_failed_delivery_after_retries(self):
        """Repeated 500 responses yield ``False`` and one row with 3 attempts."""
        db_engine = _engine()
        with _session(db_engine) as db:
            webhook = _make_webhook(db, url="https://example.com/fail")
            error_response = httpx.Response(
                500,
                request=httpx.Request("POST", "https://example.com/fail"),
                text="Internal error",
            )
            client = self._context_client()
            client.post.return_value = error_response

            # time.sleep is patched so the test does not actually wait.
            with (
                patch("engine.webhooks.httpx.Client", return_value=client),
                patch("engine.webhooks.time.sleep"),
            ):
                delivered = deliver_webhook(db, webhook, "run.completed", {"data": 1})

            assert delivered is False
            logged = db.query(WebhookDelivery).all()
            assert len(logged) == 1
            assert logged[0].success is False
            assert logged[0].attempts == 3

    def test_network_error_retries(self):
        """A connection error on every POST makes delivery return ``False``."""
        db_engine = _engine()
        with _session(db_engine) as db:
            webhook = _make_webhook(db, url="https://example.com/timeout")
            client = self._context_client()
            client.post.side_effect = httpx.ConnectError("Connection refused")

            with (
                patch("engine.webhooks.httpx.Client", return_value=client),
                patch("engine.webhooks.time.sleep"),
            ):
                delivered = deliver_webhook(db, webhook, "run.completed", {})

            assert delivered is False

    def test_custom_headers_included(self):
        """Headers configured on the webhook are passed to ``post``."""
        db_engine = _engine()
        with _session(db_engine) as db:
            webhook = _make_webhook(
                db,
                url="https://example.com/hook",
                headers={"X-Custom": "test-value"},
            )
            ok_response = httpx.Response(200, request=httpx.Request("POST", webhook.url))
            client = self._context_client()
            client.post.return_value = ok_response

            with patch("engine.webhooks.httpx.Client", return_value=client):
                deliver_webhook(db, webhook, "run.completed", {})

            # Index 1 of a recorded call is its keyword-argument dict.
            recorded_call = client.post.call_args
            headers_sent = recorded_call[1].get("headers")
            assert headers_sent["X-Custom"] == "test-value"
|
||
|
||
|
||
class TestDeliverWebhookAsync:
    """``deliver_webhook_async`` against a mocked ``httpx.AsyncClient``."""

    @staticmethod
    def _async_context_client():
        """Build an ``AsyncMock`` that returns itself from ``async with``."""
        client = AsyncMock()
        client.__aenter__ = AsyncMock(return_value=client)
        client.__aexit__ = AsyncMock(return_value=False)
        return client

    @pytest.mark.asyncio
    async def test_successful_async_delivery(self):
        """A 200 response makes the async delivery return ``True``."""
        db_engine = _engine()
        with _session(db_engine) as db:
            webhook = _make_webhook(db, url="https://example.com/ok")
            ok_response = httpx.Response(200, request=httpx.Request("POST", webhook.url))
            client = self._async_context_client()
            client.post.return_value = ok_response

            with patch("engine.webhooks.httpx.AsyncClient", return_value=client):
                delivered = await deliver_webhook_async(db, webhook, "run.completed", {"test": True})

            assert delivered is True

    @pytest.mark.asyncio
    async def test_failed_async_delivery(self):
        """Repeated 503 responses make the async delivery return ``False``."""
        db_engine = _engine()
        with _session(db_engine) as db:
            webhook = _make_webhook(db, url="https://example.com/fail")
            unavailable = httpx.Response(
                503,
                request=httpx.Request("POST", webhook.url),
                text="Service unavailable",
            )
            client = self._async_context_client()
            client.post.return_value = unavailable

            # asyncio.sleep is replaced so the test does not actually wait.
            with (
                patch("engine.webhooks.httpx.AsyncClient", return_value=client),
                patch("asyncio.sleep", new_callable=AsyncMock),
            ):
                delivered = await deliver_webhook_async(db, webhook, "run.completed", {})

            assert delivered is False

    @pytest.mark.asyncio
    async def test_async_network_error(self):
        """A connection error on every POST makes delivery return ``False``."""
        db_engine = _engine()
        with _session(db_engine) as db:
            webhook = _make_webhook(db, url="https://example.com/err")
            client = self._async_context_client()
            client.post.side_effect = httpx.ConnectError("refused")

            with (
                patch("engine.webhooks.httpx.AsyncClient", return_value=client),
                patch("asyncio.sleep", new_callable=AsyncMock),
            ):
                delivered = await deliver_webhook_async(db, webhook, "run.failed", {})

            assert delivered is False
|
||
|
||
|
||
class TestDispatchWebhooks:
    """Synchronous ``dispatch_webhooks`` fan-out and its returned count."""

    def test_dispatches_to_matching_webhooks(self):
        """Two matching webhooks with 200 responses give a count of 2."""
        db_engine = _engine()
        with _session(db_engine) as db:
            for target_url in ("https://a.com", "https://b.com"):
                _make_webhook(db, event_type="run.completed", url=target_url)
            ok_response = httpx.Response(200, request=httpx.Request("POST", "https://a.com"))

            # The mock client hands itself back from the ``with`` statement.
            client = MagicMock()
            client.post.return_value = ok_response
            client.__enter__ = MagicMock(return_value=client)
            client.__exit__ = MagicMock(return_value=False)

            with patch("engine.webhooks.httpx.Client", return_value=client):
                dispatched = dispatch_webhooks(db, "run.completed", {"data": 1})

            assert dispatched == 2

    def test_returns_zero_when_no_webhooks(self):
        """With no webhooks registered the count is 0."""
        db_engine = _engine()
        with _session(db_engine) as db:
            dispatched = dispatch_webhooks(db, "nonexistent.event", {})
            assert dispatched == 0
|
||
|
||
|
||
class TestDispatchWebhooksAsync:
    """``dispatch_webhooks_async`` fan-out and its returned count."""

    @pytest.mark.asyncio
    async def test_dispatches_async(self):
        """One matching webhook with a 200 response gives a count of 1."""
        db_engine = _engine()
        with _session(db_engine) as db:
            _make_webhook(db, event_type="sweep.completed", url="https://hook.com")
            ok_response = httpx.Response(200, request=httpx.Request("POST", "https://hook.com"))

            # The mock client hands itself back from ``async with``.
            client = AsyncMock()
            client.post.return_value = ok_response
            client.__aenter__ = AsyncMock(return_value=client)
            client.__aexit__ = AsyncMock(return_value=False)

            with patch("engine.webhooks.httpx.AsyncClient", return_value=client):
                dispatched = await dispatch_webhooks_async(db, "sweep.completed", {"info": "done"})

            assert dispatched == 1

    @pytest.mark.asyncio
    async def test_returns_zero_when_none_match(self):
        """With no webhooks registered the count is 0."""
        db_engine = _engine()
        with _session(db_engine) as db:
            dispatched = await dispatch_webhooks_async(db, "no.match", {})
            assert dispatched == 0
|
||
|
||
|
||
# ===========================================================================
|
||
# Format scorer — structure / length edge cases
|
||
# ===========================================================================
|
||
|
||
|
||
class TestFormatScorerStructure:
    """``FormatScorer`` in ``structure`` mode, plus direct ``_basic_schema_check`` calls."""

    def test_valid_against_schema(self):
        """JSON that satisfies the schema scores 1.0."""
        schema = {
            "type": "object",
            "required": ["name", "age"],
            "properties": {
                "name": {"type": "string"},
                "age": {"type": "integer"},
            },
        }
        scorer = FormatScorer(format_type="structure", json_schema=schema)
        assert scorer.score(None, '{"name": "Alice", "age": 30}', {}) == 1.0

    def test_invalid_against_schema(self):
        """JSON missing a required key scores below 1.0."""
        schema = {
            "type": "object",
            "required": ["name", "age"],
        }
        scorer = FormatScorer(format_type="structure", json_schema=schema)
        assert scorer.score(None, '{"name": "Alice"}', {}) < 1.0

    def test_not_json(self):
        """Output that is not JSON scores 0.0."""
        schema = {"type": "object"}
        scorer = FormatScorer(format_type="structure", json_schema=schema)
        assert scorer.score(None, "not json at all", {}) == 0.0

    def test_no_schema_returns_zero(self):
        """Structure mode without a schema scores 0.0."""
        scorer = FormatScorer(format_type="structure", json_schema=None)
        assert scorer.score(None, '{"anything": true}', {}) == 0.0

    def test_basic_schema_check_fallback(self):
        """Test the _basic_schema_check path when jsonschema is unavailable."""
        schema = {
            "type": "object",
            "required": ["a", "b", "c"],
        }
        scorer = FormatScorer(format_type="structure", json_schema=schema)
        # The fallback helper is called directly, so no import interception
        # is needed. Patching ``builtins.__import__`` here was fragile:
        # ``__builtins__`` is normally a dict inside an imported module, so
        # the replacement would raise on any unrelated import.
        result = scorer._basic_schema_check({"a": 1, "b": 2}, schema)
        # 2 of 3 required fields → 2/3
        assert result == pytest.approx(2 / 3)

    def test_basic_schema_check_type_mismatch(self):
        """A non-dict value checked against ``type: object`` scores 0.0."""
        scorer = FormatScorer(format_type="structure")
        result = scorer._basic_schema_check("not a dict", {"type": "object"})
        assert result == 0.0

    def test_basic_schema_check_array_type(self):
        """A list checked against ``type: array`` scores 1.0."""
        scorer = FormatScorer(format_type="structure")
        result = scorer._basic_schema_check([1, 2], {"type": "array"})
        assert result == 1.0

    def test_basic_schema_check_no_type(self):
        """An empty schema accepts any value with a score of 1.0."""
        scorer = FormatScorer(format_type="structure")
        result = scorer._basic_schema_check("anything", {})
        assert result == 1.0
|
||
|
||
|
||
class TestFormatScorerLengthEdgeCases:
    """Boundary behaviour of ``FormatScorer`` in ``length`` mode."""

    def test_over_max_score_decays(self):
        """Twenty words against ``max_tokens=10`` scores 0.0."""
        scorer = FormatScorer(format_type="length", max_tokens=10)
        # 20 tokens = 2x max, should return 0.0
        twenty_words = ("word " * 20).rstrip()
        assert scorer.score(None, twenty_words, {}) == 0.0

    def test_at_max_boundary(self):
        """Exactly ``max_tokens`` words still scores 1.0."""
        scorer = FormatScorer(format_type="length", max_tokens=5)
        five_words = ("word " * 5).rstrip()
        assert scorer.score(None, five_words, {}) == 1.0

    def test_under_min_proportional(self):
        """Five words against ``min_tokens=10`` scores about 0.5."""
        scorer = FormatScorer(format_type="length", min_tokens=10)
        five_words = ("word " * 5).rstrip()
        observed = scorer.score(None, five_words, {})
        assert observed == pytest.approx(0.5)

    def test_zero_min_tokens(self):
        """Empty output with ``min_tokens=0`` scores 1.0."""
        scorer = FormatScorer(format_type="length", min_tokens=0, max_tokens=100)
        assert scorer.score(None, "", {}) == 1.0
|
||
|
||
|
||
# ===========================================================================
|
||
# Cache — additional edge cases
|
||
# ===========================================================================
|
||
|
||
|
||
class TestCacheEdgeCases:
    """Edge cases for ``compute_config_hash`` and ``ResponseCacheLayer``."""

    def test_hash_with_nested_params(self):
        """Key order inside nested params does not change the hash."""
        first_order = {"nested": {"a": 1, "b": [2, 3]}}
        second_order = {"nested": {"b": [2, 3], "a": 1}}
        assert compute_config_hash("p", "m", first_order) == compute_config_hash("p", "m", second_order)

    def test_hash_with_special_characters(self):
        """Non-ASCII text and embedded quotes still give a 64-character hash."""
        digest = compute_config_hash("Hello 世界! 🌍", "model-ñ", {"key": "value with \"quotes\""})
        assert len(digest) == 64

    def test_hash_with_numeric_input_data(self):
        """Hashing the same numeric ``input_data`` twice gives equal hashes."""
        first = compute_config_hash("p", "m", {}, input_data=42)
        second = compute_config_hash("p", "m", {}, input_data=42)
        assert first == second

    def test_cache_put_with_none_fields(self):
        """Optional fields omitted on ``put`` read back as ``None``."""
        db_engine = _engine()
        with _session(db_engine) as db:
            layer = ResponseCacheLayer()
            cache_key = "a" * 64
            layer.put(db, cache_key, response="resp", model="m")
            entry = layer.get(db, "a" * 64)
            assert entry.tokens_in is None
            assert entry.tokens_out is None
            assert entry.latency_ms is None

    def test_independent_hit_tracking(self):
        """Two cache instances track hits independently."""
        db_engine = _engine()
        with _session(db_engine) as db:
            first_layer = ResponseCacheLayer()
            second_layer = ResponseCacheLayer()
            cache_key = "b" * 64
            first_layer.put(db, cache_key, response="r", model="m")

            # One lookup per instance, first instance first.
            first_layer.get(db, cache_key)
            second_layer.get(db, cache_key)

            first_stats = first_layer.cache_stats(db)
            second_stats = second_layer.cache_stats(db)
            assert first_stats.hit_rate == 1.0
            assert second_stats.hit_rate == 1.0
|
||
|
||
|
||
# ===========================================================================
|
||
# Adapter mock call tracking
|
||
# ===========================================================================
|
||
|
||
|
||
class TestAdapterMockCallTracking:
    """``MockAdapter`` records its calls and echoes its configured values."""

    @pytest.mark.asyncio
    async def test_adapter_records_all_calls(self):
        """Each ``complete`` call is logged in order as a tuple."""
        adapter = MockAdapter()
        await adapter.complete("prompt1", "model-a", {"temp": 0.5})
        await adapter.complete("prompt2", "model-b", {"temp": 0.9})

        assert len(adapter.calls) == 2
        first_call, second_call = adapter.calls[0], adapter.calls[1]
        assert first_call == ("prompt1", "model-a", {"temp": 0.5})
        assert second_call == ("prompt2", "model-b", {"temp": 0.9})

    @pytest.mark.asyncio
    async def test_adapter_returns_correct_model(self):
        """The response carries the requested model and the canned text."""
        adapter = MockAdapter(response_text="hello")
        reply = await adapter.complete("test", "specific-model", {})
        assert reply.model == "specific-model"
        assert reply.text == "hello"

    @pytest.mark.asyncio
    async def test_adapter_token_counts(self):
        """The response carries the configured token counts."""
        adapter = MockAdapter(tokens_in=100, tokens_out=50)
        reply = await adapter.complete("test", "m", {})
        assert reply.tokens_in == 100
        assert reply.tokens_out == 50
|
||
|
||
|
||
# ===========================================================================
|
||
# Grid sweep — combo count verification
|
||
# ===========================================================================
|
||
|
||
|
||
class TestGridComboVerification:
    """Combination counts for grid sweeps and range checks for random sweeps."""

    def test_single_param_3_values(self):
        """One swept parameter with three values gives three configs."""
        base_config = {"prompt": "test", "params": {}}
        sweep_space = {"params.temperature": [0.1, 0.5, 0.9]}
        configs = generate_grid_configs(base_config, sweep_space)
        assert len(configs) == 3

    def test_cartesian_product_2x3(self):
        """Two models by three temperatures gives six configs."""
        base_config = {"prompt": "test", "model": "m", "params": {}}
        sweep_space = {
            "model": ["gpt-4", "claude"],
            "params.temperature": [0.1, 0.5, 0.9],
        }
        configs = generate_grid_configs(base_config, sweep_space)
        assert len(configs) == 6  # 2 × 3

    def test_cartesian_product_2x2x3(self):
        """Three swept axes of sizes 2, 2 and 3 give twelve configs."""
        base_config = {"prompt": "test", "model": "m", "params": {}}
        sweep_space = {
            "model": ["a", "b"],
            "params.temperature": [0.1, 0.9],
            "params.top_p": [0.5, 0.7, 1.0],
        }
        configs = generate_grid_configs(base_config, sweep_space)
        assert len(configs) == 12  # 2 × 2 × 3

    def test_all_combos_are_unique(self):
        """No two generated grid configs serialize to the same JSON."""
        base_config = {"prompt": "test", "params": {}}
        sweep_space = {
            "params.temperature": [0.1, 0.5, 0.9],
            "params.top_p": [0.5, 1.0],
        }
        configs = generate_grid_configs(base_config, sweep_space)
        # Canonical JSON makes the dict configs comparable as strings.
        canonical = [json.dumps(config, sort_keys=True) for config in configs]
        assert len(set(canonical)) == len(canonical)

    def test_random_respects_ranges_for_all_samples(self):
        """Every sampled value lies inside its declared min/max range."""
        base_config = {"prompt": "test", "params": {}}
        sweep_ranges = {
            "params.temperature": {"min": 0.0, "max": 1.0},
            "params.top_p": {"min": 0.5, "max": 1.0},
        }
        configs = generate_random_configs(base_config, sweep_ranges, n_trials=100)
        for config in configs:
            assert 0.0 <= config["params"]["temperature"] <= 1.0
            assert 0.5 <= config["params"]["top_p"] <= 1.0
|
||
|
||
|
||
# ===========================================================================
|
||
# Scorer integration with known inputs
|
||
# ===========================================================================
|
||
|
||
|
||
class TestScorerKnownInputs:
    """``KeywordScorer`` and ``FormatScorer`` against hand-picked inputs."""

    def test_keyword_all_present(self):
        """All three required keywords present scores 1.0."""
        scorer = KeywordScorer(required_present=["python", "fast", "simple"])
        observed = scorer.score(None, "Python is fast and simple to learn", {})
        assert observed == 1.0

    def test_keyword_partial_match(self):
        """One of three required keywords present scores about 1/3."""
        scorer = KeywordScorer(required_present=["python", "fast", "simple"])
        observed = scorer.score(None, "Python is a great language", {})
        assert observed == pytest.approx(1 / 3)

    def test_keyword_absent_violation(self):
        """A forbidden keyword in the output scores below 1.0."""
        scorer = KeywordScorer(required_absent=["error", "bug"])
        observed = scorer.score(None, "There was an error in the code", {})
        assert observed < 1.0

    def test_keyword_absent_clean(self):
        """Output free of forbidden keywords scores 1.0."""
        scorer = KeywordScorer(required_absent=["error", "bug"])
        observed = scorer.score(None, "The code works perfectly", {})
        assert observed == 1.0

    def test_format_json_valid(self):
        """Valid JSON scores 1.0 in ``json`` mode."""
        scorer = FormatScorer(format_type="json")
        valid_json = '{"key": "value", "number": 42}'
        assert scorer.score(None, valid_json, {}) == 1.0

    def test_format_json_invalid(self):
        """Non-JSON text scores 0.0 in ``json`` mode."""
        scorer = FormatScorer(format_type="json")
        assert scorer.score(None, "not json", {}) == 0.0

    def test_format_markdown_headers_only(self):
        """A header followed by plain text scores 0.5 in ``markdown`` mode."""
        scorer = FormatScorer(format_type="markdown")
        header_only = "# Title\nSome text"
        assert scorer.score(None, header_only, {}) == 0.5

    def test_format_markdown_full(self):
        """A header plus a bullet list scores 1.0 in ``markdown`` mode."""
        scorer = FormatScorer(format_type="markdown")
        header_and_list = "# Title\n\n- Item 1\n- Item 2\n"
        assert scorer.score(None, header_and_list, {}) == 1.0
|
||
|
||
|
||
# ===========================================================================
|
||
# EventBus tests
|
||
# ===========================================================================
|
||
|
||
|
||
class TestEventBus:
    """``EventBus`` publish and listener behaviour."""

    def test_publish_and_listen(self):
        """A registered listener receives the published event."""
        bus = EventBus()
        received = []
        bus.add_listener(lambda event: received.append(event))
        bus.publish({"type": "test", "data": 42})
        assert len(received) == 1
        assert received[0]["data"] == 42

    def test_multiple_listeners(self):
        """Each registered listener receives the event once."""
        bus = EventBus()
        first_inbox = []
        second_inbox = []
        bus.add_listener(lambda event: first_inbox.append(event))
        bus.add_listener(lambda event: second_inbox.append(event))
        bus.publish({"type": "test"})
        assert len(first_inbox) == 1
        assert len(second_inbox) == 1

    def test_no_listeners_no_error(self):
        """Publishing with no listeners registered does not raise."""
        bus = EventBus()
        bus.publish({"type": "test"})  # Should not raise
|