"""Tests for the run execution engine.""" import asyncio import uuid from datetime import datetime, timezone from typing import Any from unittest.mock import AsyncMock, MagicMock import pytest from sqlalchemy import create_engine from sqlalchemy.orm import Session from engine.adapters.base import AdapterResponse, BaseAdapter from engine.cache import ResponseCacheLayer, compute_config_hash from engine.runner import EventBus, render_prompt, run_single from models import Base, Run, RunStatus, Score, StageResult # --------------------------------------------------------------------------- # Fixtures / helpers # --------------------------------------------------------------------------- def _engine(): engine = create_engine("sqlite:///:memory:") Base.metadata.create_all(engine) return engine def _session(engine): return Session(engine) def _make_run(db: Session, config: dict, experiment_id: uuid.UUID | None = None) -> Run: """Create and persist a Run for testing.""" exp_id = experiment_id or uuid.uuid4() config_hash = compute_config_hash( config.get("prompt", ""), config.get("model", ""), config.get("params", {}) ) run = Run( experiment_id=exp_id, config_hash=config_hash, config=config, status=RunStatus.pending, ) db.add(run) db.commit() db.refresh(run) return run class MockAdapter(BaseAdapter): """Test adapter that returns a fixed response.""" def __init__(self, response_text: str = "mock response", tokens_in: int = 10, tokens_out: int = 5): self.response_text = response_text self.tokens_in = tokens_in self.tokens_out = tokens_out self.calls: list[tuple[str, str, dict]] = [] async def complete(self, prompt: str, model: str, params: dict[str, Any]) -> AdapterResponse: self.calls.append((prompt, model, params)) return AdapterResponse( text=self.response_text, tokens_in=self.tokens_in, tokens_out=self.tokens_out, latency_ms=42.0, model=model, ) async def list_models(self) -> list[str]: return ["mock-model"] async def test_connection(self) -> bool: return True class MockScorer: """Test scorer that returns a fixed value.""" def __init__(self, name: str = "mock_scorer", value: float = 0.85): self.name = name self._value = value def score(self, input_data: Any, output: str, context: dict) -> float: return self._value class AsyncMockScorer: """Test async scorer.""" def __init__(self, name: str = "async_scorer", value: float = 0.9): self.name = name self._value = value async def score(self, input_data: Any, output: str, context: dict) -> float: return self._value # --------------------------------------------------------------------------- # render_prompt tests # --------------------------------------------------------------------------- class TestRenderPrompt: def test_simple_variable(self): result = render_prompt("Hello {{ name }}!", {"name": "world"}) assert result == "Hello world!" def test_prev_output(self): result = render_prompt( "Summarize: {{ prev_output }}", {"prev_output": "The quick brown fox"}, ) assert result == "Summarize: The quick brown fox" def test_no_variables(self): result = render_prompt("Static prompt", {}) assert result == "Static prompt" def test_nested_context(self): result = render_prompt( "Model: {{ config.model }}", {"config": {"model": "gpt-4"}}, ) assert result == "Model: gpt-4" def test_conditional(self): result = render_prompt( "{% if input %}Input: {{ input }}{% else %}No input{% endif %}", {"input": "test data"}, ) assert result == "Input: test data" # --------------------------------------------------------------------------- # run_single tests # --------------------------------------------------------------------------- class TestRunSingle: def test_single_stage_run(self): engine = _engine() with _session(engine) as db: config = { "prompt": "Tell me a joke", "model": "test-model", "params": {"temperature": 0.7}, } run = _make_run(db, config) adapter = MockAdapter(response_text="Why did the chicken...") cache = ResponseCacheLayer() result = asyncio.get_event_loop().run_until_complete( run_single(db, run, adapter, cache) ) assert result.status == RunStatus.completed assert result.tokens_in == 10 assert result.tokens_out == 5 assert result.duration_ms is not None assert result.started_at is not None assert result.completed_at is not None # Verify stage result was created stages = db.query(StageResult).filter_by(run_id=run.id).all() assert len(stages) == 1 assert stages[0].response_raw == "Why did the chicken..." assert stages[0].model_used == "test-model" assert stages[0].stage_index == 0 def test_multi_stage_pipeline(self): engine = _engine() with _session(engine) as db: config = { "pipeline_stages": [ { "prompt_template": "Generate a question about {{ input }}", "model": "model-a", "params": {}, }, { "prompt_template": "Answer this: {{ prev_output }}", "model": "model-b", "params": {}, }, ], "input_data": "Python programming", } run = _make_run(db, config) call_count = 0 class StagedAdapter(BaseAdapter): async def complete(self, prompt, model, params): nonlocal call_count call_count += 1 return AdapterResponse( text=f"stage-{call_count}-output", tokens_in=5, tokens_out=3, latency_ms=20.0, model=model, ) async def list_models(self): return [] async def test_connection(self): return True adapter = StagedAdapter() cache = ResponseCacheLayer() result = asyncio.get_event_loop().run_until_complete( run_single(db, run, adapter, cache) ) assert result.status == RunStatus.completed assert result.tokens_in == 10 # 5 + 5 assert result.tokens_out == 6 # 3 + 3 stages = ( db.query(StageResult) .filter_by(run_id=run.id) .order_by(StageResult.stage_index) .all() ) assert len(stages) == 2 assert stages[0].model_used == "model-a" assert stages[1].model_used == "model-b" # Second stage prompt should contain first stage output assert "stage-1-output" in stages[1].prompt_sent def test_cache_hit_skips_adapter(self): engine = _engine() with _session(engine) as db: config = { "prompt": "cached prompt", "model": "test-model", "params": {}, } run = _make_run(db, config) adapter = MockAdapter() cache = ResponseCacheLayer() # Pre-populate cache rendered = render_prompt("cached prompt", { "input": "", "config": config, "stages": [], "prev_output": "", "stage_index": 0, }) h = compute_config_hash(rendered, "test-model", {}, None) cache.put(db, h, response="cached!", model="test-model", tokens_in=1, tokens_out=1) result = asyncio.get_event_loop().run_until_complete( run_single(db, run, adapter, cache) ) assert result.status == RunStatus.completed # Adapter should NOT have been called assert len(adapter.calls) == 0 stages = db.query(StageResult).filter_by(run_id=run.id).all() assert stages[0].response_raw == "cached!" def test_scorer_creates_score_records(self): engine = _engine() with _session(engine) as db: config = {"prompt": "test", "model": "m", "params": {}} run = _make_run(db, config) adapter = MockAdapter() cache = ResponseCacheLayer() scorers = [MockScorer("quality", 0.9), MockScorer("relevance", 0.7)] asyncio.get_event_loop().run_until_complete( run_single(db, run, adapter, cache, scorers=scorers) ) scores = db.query(Score).filter_by(run_id=run.id).all() assert len(scores) == 2 score_map = {s.scorer_name: s.value for s in scores} assert score_map["quality"] == pytest.approx(0.9) assert score_map["relevance"] == pytest.approx(0.7) def test_async_scorer(self): engine = _engine() with _session(engine) as db: config = {"prompt": "test", "model": "m", "params": {}} run = _make_run(db, config) adapter = MockAdapter() cache = ResponseCacheLayer() scorers = [AsyncMockScorer("async_quality", 0.95)] asyncio.get_event_loop().run_until_complete( run_single(db, run, adapter, cache, scorers=scorers) ) scores = db.query(Score).filter_by(run_id=run.id).all() assert len(scores) == 1 assert scores[0].value == pytest.approx(0.95) assert scores[0].scorer_name == "async_quality" def test_score_clamped_to_range(self): engine = _engine() with _session(engine) as db: config = {"prompt": "test", "model": "m", "params": {}} run = _make_run(db, config) adapter = MockAdapter() cache = ResponseCacheLayer() scorers = [MockScorer("over", 1.5), MockScorer("under", -0.3)] asyncio.get_event_loop().run_until_complete( run_single(db, run, adapter, cache, scorers=scorers) ) scores = db.query(Score).filter_by(run_id=run.id).all() score_map = {s.scorer_name: s.value for s in scores} assert score_map["over"] == pytest.approx(1.0) assert score_map["under"] == pytest.approx(0.0) def test_failed_scorer_does_not_crash_run(self): engine = _engine() with _session(engine) as db: config = {"prompt": "test", "model": "m", "params": {}} run = _make_run(db, config) adapter = MockAdapter() cache = ResponseCacheLayer() class BrokenScorer: name = "broken" def score(self, input_data, output, context): raise ValueError("scorer exploded") scorers = [BrokenScorer(), MockScorer("ok", 0.8)] asyncio.get_event_loop().run_until_complete( run_single(db, run, adapter, cache, scorers=scorers) ) assert run.status == RunStatus.completed scores = db.query(Score).filter_by(run_id=run.id).all() # Only the working scorer should have a record assert len(scores) == 1 assert scores[0].scorer_name == "ok" def test_adapter_failure_marks_run_failed(self): engine = _engine() with _session(engine) as db: config = {"prompt": "test", "model": "m", "params": {}} run = _make_run(db, config) cache = ResponseCacheLayer() class FailAdapter(BaseAdapter): async def complete(self, prompt, model, params): raise RuntimeError("LLM endpoint down") async def list_models(self): return [] async def test_connection(self): return False adapter = FailAdapter() with pytest.raises(RuntimeError, match="LLM endpoint down"): asyncio.get_event_loop().run_until_complete( run_single(db, run, adapter, cache) ) assert run.status == RunStatus.failed assert run.completed_at is not None def test_event_bus_receives_events(self): engine = _engine() with _session(engine) as db: config = {"prompt": "test", "model": "m", "params": {}} run = _make_run(db, config) adapter = MockAdapter() cache = ResponseCacheLayer() events: list[dict] = [] bus = EventBus() bus.add_listener(lambda e: events.append(e)) asyncio.get_event_loop().run_until_complete( run_single(db, run, adapter, cache, event_bus=bus) ) event_types = [e["type"] for e in events] assert "run.started" in event_types assert "run.stage_completed" in event_types assert "run.completed" in event_types def test_event_bus_on_failure(self): engine = _engine() with _session(engine) as db: config = {"prompt": "test", "model": "m", "params": {}} run = _make_run(db, config) cache = ResponseCacheLayer() class FailAdapter(BaseAdapter): async def complete(self, prompt, model, params): raise RuntimeError("boom") async def list_models(self): return [] async def test_connection(self): return False events: list[dict] = [] bus = EventBus() bus.add_listener(lambda e: events.append(e)) with pytest.raises(RuntimeError): asyncio.get_event_loop().run_until_complete( run_single(db, run, FailAdapter(), cache, event_bus=bus) ) event_types = [e["type"] for e in events] assert "run.started" in event_types assert "run.failed" in event_types def test_jinja2_template_with_stage_history(self): engine = _engine() with _session(engine) as db: config = { "pipeline_stages": [ {"prompt_template": "First: {{ input }}", "model": "m", "params": {}}, { "prompt_template": "Stage 0 said: {{ stages[0].output }}. Now continue.", "model": "m", "params": {}, }, ], "input_data": "hello", } run = _make_run(db, config) adapter = MockAdapter(response_text="stage output") cache = ResponseCacheLayer() asyncio.get_event_loop().run_until_complete( run_single(db, run, adapter, cache) ) stages = ( db.query(StageResult) .filter_by(run_id=run.id) .order_by(StageResult.stage_index) .all() ) assert "hello" in stages[0].prompt_sent assert "stage output" in stages[1].prompt_sent def test_cache_stores_response_after_llm_call(self): engine = _engine() with _session(engine) as db: config = {"prompt": "store me", "model": "m", "params": {}} run = _make_run(db, config) adapter = MockAdapter(response_text="fresh response") cache = ResponseCacheLayer() asyncio.get_event_loop().run_until_complete( run_single(db, run, adapter, cache) ) # Verify the response was cached stats = cache.cache_stats(db) assert stats.total_entries == 1 # --------------------------------------------------------------------------- # EventBus tests # --------------------------------------------------------------------------- class TestEventBus: def test_in_process_listener(self): events = [] bus = EventBus() bus.add_listener(lambda e: events.append(e)) bus.publish({"type": "test", "data": 42}) assert len(events) == 1 assert events[0]["type"] == "test" def test_redis_publish(self): mock_redis = MagicMock() bus = EventBus(redis_client=mock_redis) bus.publish({"type": "test"}) mock_redis.publish.assert_called_once() def test_broken_listener_does_not_crash(self): def broken(event): raise ValueError("boom") events = [] bus = EventBus() bus.add_listener(broken) bus.add_listener(lambda e: events.append(e)) bus.publish({"type": "test"}) # Second listener still receives the event assert len(events) == 1 def test_redis_failure_does_not_crash(self): mock_redis = MagicMock() mock_redis.publish.side_effect = ConnectionError("Redis down") bus = EventBus(redis_client=mock_redis) # Should not raise bus.publish({"type": "test"})