Adds backend/engine/runner.py with run_single(), which iterates over pipeline stages, renders Jinja2 prompt templates with the stage-history context, checks and populates the response cache, calls the LLM adapter, runs the configured scorers, creates StageResult and Score records, and publishes progress events via Redis pub/sub or the in-process EventBus. Includes 21 passing tests covering the main execution paths.
513 lines
17 KiB
Python
513 lines
17 KiB
Python
"""Tests for the run execution engine."""
|
|
|
|
import asyncio
|
|
import uuid
|
|
from datetime import datetime, timezone
|
|
from typing import Any
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
import pytest
|
|
from sqlalchemy import create_engine
|
|
from sqlalchemy.orm import Session
|
|
|
|
from engine.adapters.base import AdapterResponse, BaseAdapter
|
|
from engine.cache import ResponseCacheLayer, compute_config_hash
|
|
from engine.runner import EventBus, render_prompt, run_single
|
|
from models import Base, Run, RunStatus, Score, StageResult
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Fixtures / helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _engine():
    """Return a fresh in-memory SQLite engine with every model table created."""
    eng = create_engine("sqlite:///:memory:")
    Base.metadata.create_all(bind=eng)
    return eng
|
|
|
|
|
|
def _session(engine):
    """Open an ORM Session bound to *engine*; the caller is responsible for closing it."""
    return Session(bind=engine)
|
|
|
|
|
|
def _make_run(db: Session, config: dict, experiment_id: uuid.UUID | None = None) -> Run:
    """Insert a pending Run for *config* and return it refreshed from the database.

    The config hash is derived from the top-level ``prompt``/``model``/``params``
    keys (empty defaults when absent). A random experiment id is generated when
    *experiment_id* is not supplied.
    """
    prompt = config.get("prompt", "")
    model = config.get("model", "")
    params = config.get("params", {})

    new_run = Run(
        experiment_id=experiment_id or uuid.uuid4(),
        config_hash=compute_config_hash(prompt, model, params),
        config=config,
        status=RunStatus.pending,
    )
    db.add(new_run)
    db.commit()
    db.refresh(new_run)
    return new_run
|
|
|
|
|
|
class MockAdapter(BaseAdapter):
    """Spy adapter: records every ``complete`` call and returns a canned response.

    Args:
        response_text: Text returned by every ``complete`` call.
        tokens_in: Prompt-token count reported in each response.
        tokens_out: Completion-token count reported in each response.
    """

    def __init__(self, response_text: str = "mock response", tokens_in: int = 10, tokens_out: int = 5):
        # Run BaseAdapter's initialiser so any state it sets up also exists on
        # the mock (other test adapters here rely on BaseAdapter() taking no args).
        super().__init__()
        self.response_text = response_text
        self.tokens_in = tokens_in
        self.tokens_out = tokens_out
        # One (prompt, model, params) tuple per complete() call, in call order.
        self.calls: list[tuple[str, str, dict]] = []

    async def complete(self, prompt: str, model: str, params: dict[str, Any]) -> AdapterResponse:
        """Record the call and return the configured response with a fixed 42 ms latency."""
        self.calls.append((prompt, model, params))
        return AdapterResponse(
            text=self.response_text,
            tokens_in=self.tokens_in,
            tokens_out=self.tokens_out,
            latency_ms=42.0,
            model=model,
        )

    async def list_models(self) -> list[str]:
        """Report a single fake model name."""
        return ["mock-model"]

    async def test_connection(self) -> bool:
        """Always report a healthy connection."""
        return True
|
|
|
|
|
|
class MockScorer:
    """Synchronous scorer stub whose ``score`` ignores its inputs and returns a constant."""

    def __init__(self, name: str = "mock_scorer", value: float = 0.85):
        self.name, self._value = name, value

    def score(self, input_data: Any, output: str, context: dict) -> float:
        """Return the constant supplied at construction time."""
        fixed_value = self._value
        return fixed_value
|
|
|
|
|
|
class AsyncMockScorer:
    """Coroutine-based scorer stub; awaiting ``score`` yields a constant."""

    def __init__(self, name: str = "async_scorer", value: float = 0.9):
        self.name, self._value = name, value

    async def score(self, input_data: Any, output: str, context: dict) -> float:
        """Resolve to the constant supplied at construction time."""
        fixed_value = self._value
        return fixed_value
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# render_prompt tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestRenderPrompt:
    """Behaviour of ``render_prompt`` on plain Jinja2 templates."""

    def test_simple_variable(self):
        rendered = render_prompt("Hello {{ name }}!", {"name": "world"})
        assert rendered == "Hello world!"

    def test_prev_output(self):
        template = "Summarize: {{ prev_output }}"
        context = {"prev_output": "The quick brown fox"}
        assert render_prompt(template, context) == "Summarize: The quick brown fox"

    def test_no_variables(self):
        # A template without placeholders comes back unchanged.
        assert render_prompt("Static prompt", {}) == "Static prompt"

    def test_nested_context(self):
        # Dotted lookup reaches into a nested dict in the context.
        context = {"config": {"model": "gpt-4"}}
        assert render_prompt("Model: {{ config.model }}", context) == "Model: gpt-4"

    def test_conditional(self):
        template = "{% if input %}Input: {{ input }}{% else %}No input{% endif %}"
        rendered = render_prompt(template, {"input": "test data"})
        assert rendered == "Input: test data"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# run_single tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _run_coro(coro):
    """Run *coro* to completion on a private event loop and return its result.

    Replaces ``asyncio.get_event_loop().run_until_complete(...)``: calling
    ``get_event_loop()`` with no running loop is deprecated (it warns on
    recent Pythons and raises RuntimeError on 3.14+). A dedicated loop is
    used instead of ``asyncio.run`` because ``asyncio.run`` resets the
    process-wide current loop on exit, which would break other test modules
    that still call ``asyncio.get_event_loop()``.
    """
    loop = asyncio.new_event_loop()
    try:
        return loop.run_until_complete(coro)
    finally:
        loop.close()


class _FailingAdapter(BaseAdapter):
    """Adapter whose ``complete`` always raises ``RuntimeError(message)``."""

    def __init__(self, message: str = "boom"):
        super().__init__()
        self.message = message

    async def complete(self, prompt, model, params):
        raise RuntimeError(self.message)

    async def list_models(self):
        return []

    async def test_connection(self):
        return False


class TestRunSingle:
    """Tests for ``run_single``: stage execution, caching, scoring and events."""

    @pytest.fixture
    def db(self):
        """Yield a Session on a fresh in-memory database, then dispose of its engine."""
        engine = _engine()
        try:
            with _session(engine) as session:
                yield session
        finally:
            # Release the engine's connection pool instead of leaving it to GC.
            engine.dispose()

    def test_single_stage_run(self, db):
        config = {
            "prompt": "Tell me a joke",
            "model": "test-model",
            "params": {"temperature": 0.7},
        }
        run = _make_run(db, config)
        adapter = MockAdapter(response_text="Why did the chicken...")
        cache = ResponseCacheLayer()

        result = _run_coro(run_single(db, run, adapter, cache))

        assert result.status == RunStatus.completed
        assert result.tokens_in == 10
        assert result.tokens_out == 5
        assert result.duration_ms is not None
        assert result.started_at is not None
        assert result.completed_at is not None

        # Verify stage result was created
        stages = db.query(StageResult).filter_by(run_id=run.id).all()
        assert len(stages) == 1
        assert stages[0].response_raw == "Why did the chicken..."
        assert stages[0].model_used == "test-model"
        assert stages[0].stage_index == 0

    def test_multi_stage_pipeline(self, db):
        config = {
            "pipeline_stages": [
                {
                    "prompt_template": "Generate a question about {{ input }}",
                    "model": "model-a",
                    "params": {},
                },
                {
                    "prompt_template": "Answer this: {{ prev_output }}",
                    "model": "model-b",
                    "params": {},
                },
            ],
            "input_data": "Python programming",
        }
        run = _make_run(db, config)

        call_count = 0

        class StagedAdapter(BaseAdapter):
            """Returns ``stage-<n>-output``, where n counts calls made so far."""

            async def complete(self, prompt, model, params):
                nonlocal call_count
                call_count += 1
                return AdapterResponse(
                    text=f"stage-{call_count}-output",
                    tokens_in=5,
                    tokens_out=3,
                    latency_ms=20.0,
                    model=model,
                )

            async def list_models(self):
                return []

            async def test_connection(self):
                return True

        adapter = StagedAdapter()
        cache = ResponseCacheLayer()

        result = _run_coro(run_single(db, run, adapter, cache))

        assert result.status == RunStatus.completed
        assert result.tokens_in == 10  # 5 + 5
        assert result.tokens_out == 6  # 3 + 3

        stages = (
            db.query(StageResult)
            .filter_by(run_id=run.id)
            .order_by(StageResult.stage_index)
            .all()
        )
        assert len(stages) == 2
        assert stages[0].model_used == "model-a"
        assert stages[1].model_used == "model-b"
        # Second stage prompt should contain first stage output
        assert "stage-1-output" in stages[1].prompt_sent

    def test_cache_hit_skips_adapter(self, db):
        config = {
            "prompt": "cached prompt",
            "model": "test-model",
            "params": {},
        }
        run = _make_run(db, config)
        adapter = MockAdapter()
        cache = ResponseCacheLayer()

        # Pre-populate the cache under the key run_single is expected to look up.
        # NOTE(review): this render context and hash call are assumed to mirror
        # what engine.runner builds for stage 0 -- keep them in sync.
        rendered = render_prompt("cached prompt", {
            "input": "",
            "config": config,
            "stages": [],
            "prev_output": "",
            "stage_index": 0,
        })
        h = compute_config_hash(rendered, "test-model", {}, None)
        cache.put(db, h, response="cached!", model="test-model", tokens_in=1, tokens_out=1)

        result = _run_coro(run_single(db, run, adapter, cache))

        assert result.status == RunStatus.completed
        # Adapter should NOT have been called
        assert len(adapter.calls) == 0

        stages = db.query(StageResult).filter_by(run_id=run.id).all()
        assert len(stages) == 1
        assert stages[0].response_raw == "cached!"

    def test_scorer_creates_score_records(self, db):
        config = {"prompt": "test", "model": "m", "params": {}}
        run = _make_run(db, config)
        adapter = MockAdapter()
        cache = ResponseCacheLayer()
        scorers = [MockScorer("quality", 0.9), MockScorer("relevance", 0.7)]

        _run_coro(run_single(db, run, adapter, cache, scorers=scorers))

        scores = db.query(Score).filter_by(run_id=run.id).all()
        assert len(scores) == 2
        score_map = {s.scorer_name: s.value for s in scores}
        assert score_map["quality"] == pytest.approx(0.9)
        assert score_map["relevance"] == pytest.approx(0.7)

    def test_async_scorer(self, db):
        config = {"prompt": "test", "model": "m", "params": {}}
        run = _make_run(db, config)
        adapter = MockAdapter()
        cache = ResponseCacheLayer()
        scorers = [AsyncMockScorer("async_quality", 0.95)]

        _run_coro(run_single(db, run, adapter, cache, scorers=scorers))

        scores = db.query(Score).filter_by(run_id=run.id).all()
        assert len(scores) == 1
        assert scores[0].value == pytest.approx(0.95)
        assert scores[0].scorer_name == "async_quality"

    def test_score_clamped_to_range(self, db):
        config = {"prompt": "test", "model": "m", "params": {}}
        run = _make_run(db, config)
        adapter = MockAdapter()
        cache = ResponseCacheLayer()
        scorers = [MockScorer("over", 1.5), MockScorer("under", -0.3)]

        _run_coro(run_single(db, run, adapter, cache, scorers=scorers))

        scores = db.query(Score).filter_by(run_id=run.id).all()
        score_map = {s.scorer_name: s.value for s in scores}
        assert score_map["over"] == pytest.approx(1.0)
        assert score_map["under"] == pytest.approx(0.0)

    def test_failed_scorer_does_not_crash_run(self, db):
        config = {"prompt": "test", "model": "m", "params": {}}
        run = _make_run(db, config)
        adapter = MockAdapter()
        cache = ResponseCacheLayer()

        class BrokenScorer:
            name = "broken"

            def score(self, input_data, output, context):
                raise ValueError("scorer exploded")

        scorers = [BrokenScorer(), MockScorer("ok", 0.8)]

        _run_coro(run_single(db, run, adapter, cache, scorers=scorers))

        assert run.status == RunStatus.completed
        scores = db.query(Score).filter_by(run_id=run.id).all()
        # Only the working scorer should have a record
        assert len(scores) == 1
        assert scores[0].scorer_name == "ok"

    def test_adapter_failure_marks_run_failed(self, db):
        config = {"prompt": "test", "model": "m", "params": {}}
        run = _make_run(db, config)
        cache = ResponseCacheLayer()
        adapter = _FailingAdapter("LLM endpoint down")

        with pytest.raises(RuntimeError, match="LLM endpoint down"):
            _run_coro(run_single(db, run, adapter, cache))

        assert run.status == RunStatus.failed
        assert run.completed_at is not None

    def test_event_bus_receives_events(self, db):
        config = {"prompt": "test", "model": "m", "params": {}}
        run = _make_run(db, config)
        adapter = MockAdapter()
        cache = ResponseCacheLayer()

        events: list[dict] = []
        bus = EventBus()
        bus.add_listener(events.append)

        _run_coro(run_single(db, run, adapter, cache, event_bus=bus))

        event_types = [e["type"] for e in events]
        assert "run.started" in event_types
        assert "run.stage_completed" in event_types
        assert "run.completed" in event_types

    def test_event_bus_on_failure(self, db):
        config = {"prompt": "test", "model": "m", "params": {}}
        run = _make_run(db, config)
        cache = ResponseCacheLayer()

        events: list[dict] = []
        bus = EventBus()
        bus.add_listener(events.append)

        with pytest.raises(RuntimeError):
            _run_coro(run_single(db, run, _FailingAdapter("boom"), cache, event_bus=bus))

        event_types = [e["type"] for e in events]
        assert "run.started" in event_types
        assert "run.failed" in event_types

    def test_jinja2_template_with_stage_history(self, db):
        config = {
            "pipeline_stages": [
                {"prompt_template": "First: {{ input }}", "model": "m", "params": {}},
                {
                    "prompt_template": "Stage 0 said: {{ stages[0].output }}. Now continue.",
                    "model": "m",
                    "params": {},
                },
            ],
            "input_data": "hello",
        }
        run = _make_run(db, config)
        adapter = MockAdapter(response_text="stage output")
        cache = ResponseCacheLayer()

        _run_coro(run_single(db, run, adapter, cache))

        stages = (
            db.query(StageResult)
            .filter_by(run_id=run.id)
            .order_by(StageResult.stage_index)
            .all()
        )
        assert "hello" in stages[0].prompt_sent
        assert "stage output" in stages[1].prompt_sent

    def test_cache_stores_response_after_llm_call(self, db):
        config = {"prompt": "store me", "model": "m", "params": {}}
        run = _make_run(db, config)
        adapter = MockAdapter(response_text="fresh response")
        cache = ResponseCacheLayer()

        _run_coro(run_single(db, run, adapter, cache))

        # Verify the response was cached
        stats = cache.cache_stats(db)
        assert stats.total_entries == 1
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# EventBus tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestEventBus:
    """Tests for ``EventBus`` delivery to in-process listeners and to Redis."""

    def test_in_process_listener(self):
        events = []
        bus = EventBus()
        bus.add_listener(events.append)

        bus.publish({"type": "test", "data": 42})

        assert len(events) == 1
        assert events[0]["type"] == "test"

    def test_redis_publish(self):
        mock_redis = MagicMock()
        bus = EventBus(redis_client=mock_redis)

        bus.publish({"type": "test"})

        mock_redis.publish.assert_called_once()

    def test_broken_listener_does_not_crash(self):
        def broken(event):
            raise ValueError("boom")

        events = []
        bus = EventBus()
        bus.add_listener(broken)
        bus.add_listener(events.append)

        bus.publish({"type": "test"})

        # Second listener still receives the event
        assert len(events) == 1

    def test_redis_failure_does_not_crash(self):
        mock_redis = MagicMock()
        mock_redis.publish.side_effect = ConnectionError("Redis down")
        bus = EventBus(redis_client=mock_redis)

        # Should not raise
        bus.publish({"type": "test"})

        # Guard against a vacuous pass: the failing client must actually have
        # been used, otherwise "did not raise" proves nothing.
        assert mock_redis.publish.called
|