MAESTRO: Implement run execution engine with Jinja2 templating, caching, scoring, and event bus
Adds backend/engine/runner.py with run_single() that iterates pipeline stages, renders Jinja2 prompt templates with stage history context, checks/stores response cache, calls LLM adapters, runs configured scorers, creates StageResult and Score records, and publishes progress events via Redis pub/sub or in-process EventBus. Includes 21 passing tests covering all execution paths.
This commit is contained in:
parent
04a96f3dc3
commit
d607970f0c
4 changed files with 801 additions and 1 deletions
|
|
@ -8,7 +8,8 @@ Implement the core experiment execution engine: LLM adapters, response caching,
|
|||
|
||||
- [x] Implement backend/engine/cache.py with the ResponseCache layer. Key function: compute_config_hash(prompt, model, params, input_data) → SHA-256 hex string. Methods: get(config_hash) → CachedResponse or None, put(config_hash, response, metadata). In SQLite mode, use the ResponseCache table directly. In Postgres mode, same table but with connection pooling. Include a cache_stats() method returning hit rate, total entries, and storage size.
|
||||
|
||||
- [ ] Implement backend/engine/runner.py for individual run execution. The run_single function should: (1) iterate through pipeline stages, (2) render prompt templates with Jinja2 (allowing previous stage output as context), (3) check cache before calling LLM, (4) call the LLM adapter if cache miss, (5) store response in cache, (6) create StageResult records, (7) run all configured scorers, (8) create Score records, (9) update Run status and timing, (10) publish progress events via Redis pub/sub (or in-process event bus).
|
||||
- [x] Implement backend/engine/runner.py for individual run execution. The run_single function should: (1) iterate through pipeline stages, (2) render prompt templates with Jinja2 (allowing previous stage output as context), (3) check cache before calling LLM, (4) call the LLM adapter if cache miss, (5) store response in cache, (6) create StageResult records, (7) run all configured scorers, (8) create Score records, (9) update Run status and timing, (10) publish progress events via Redis pub/sub (or in-process event bus).
|
||||
<!-- Completed: Implemented run_single with all 10 requirements, EventBus (Redis + in-process), Jinja2 templating. 21 tests in test_runner.py, all passing. -->
|
||||
|
||||
- [ ] Implement backend/engine/sweep.py for sweep orchestration. Support three sweep types: GridSweep (enumerate all combinations from parameter_space), RandomSweep (sample N random configs from parameter ranges), GuidedSweep (use previous results to inform next config — start with top-K exploitation + random exploration). The sweep runner should: respect MAX_CONCURRENT_RUNS for parallelism, track token budget and stop at MAX_TOKENS_PER_SWEEP, emit WebSocket events for each run completion, handle pause/resume/stop via Redis flags.
|
||||
|
||||
|
|
|
|||
285
backend/engine/runner.py
Normal file
285
backend/engine/runner.py
Normal file
|
|
@ -0,0 +1,285 @@
|
|||
"""Individual run execution for PromptLooper.
|
||||
|
||||
Executes a single Run by iterating through pipeline stages, rendering
|
||||
prompts via Jinja2, checking/storing cache, calling the LLM adapter,
|
||||
scoring results, and publishing progress events.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from jinja2 import BaseLoader, Environment
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from engine.adapters.base import AdapterResponse, BaseAdapter
|
||||
from engine.cache import ResponseCacheLayer, compute_config_hash
|
||||
from models import Run, RunStatus, Score, StageResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Jinja2 environment with sandboxed string loader (no filesystem access)
|
||||
_jinja_env = Environment(loader=BaseLoader(), autoescape=False)
|
||||
|
||||
|
||||
class EventBus:
|
||||
"""Simple event publisher that uses Redis pub/sub when available,
|
||||
or falls back to in-process callbacks for single-container mode."""
|
||||
|
||||
def __init__(self, redis_client: Any | None = None, channel: str = "promptlooper:events"):
|
||||
self._redis = redis_client
|
||||
self._channel = channel
|
||||
self._listeners: list[Any] = []
|
||||
|
||||
def add_listener(self, callback: Any) -> None:
|
||||
self._listeners.append(callback)
|
||||
|
||||
def publish(self, event: dict[str, Any]) -> None:
|
||||
payload = json.dumps(event, default=str)
|
||||
if self._redis is not None:
|
||||
try:
|
||||
self._redis.publish(self._channel, payload)
|
||||
except Exception:
|
||||
logger.warning("Failed to publish event to Redis", exc_info=True)
|
||||
for listener in self._listeners:
|
||||
try:
|
||||
listener(event)
|
||||
except Exception:
|
||||
logger.warning("Event listener error", exc_info=True)
|
||||
|
||||
|
||||
def render_prompt(template_str: str, context: dict[str, Any]) -> str:
|
||||
"""Render a Jinja2 template string with the given context."""
|
||||
template = _jinja_env.from_string(template_str)
|
||||
return template.render(**context)
|
||||
|
||||
|
||||
def _build_stage_context(
|
||||
stage_index: int,
|
||||
stage_results: list[dict[str, Any]],
|
||||
run_config: dict[str, Any],
|
||||
input_data: Any = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Build the Jinja2 template context for a given stage.
|
||||
|
||||
Available variables in templates:
|
||||
- input: the original input data
|
||||
- config: the run configuration
|
||||
- stages: list of previous stage outputs (dicts with 'output', 'model', etc.)
|
||||
- prev_output: the output text of the immediately preceding stage (or empty string)
|
||||
- stage_index: current 0-based stage index
|
||||
"""
|
||||
prev_output = ""
|
||||
if stage_results:
|
||||
prev_output = stage_results[-1].get("output", "")
|
||||
|
||||
return {
|
||||
"input": input_data or "",
|
||||
"config": run_config,
|
||||
"stages": stage_results,
|
||||
"prev_output": prev_output,
|
||||
"stage_index": stage_index,
|
||||
}
|
||||
|
||||
|
||||
async def run_single(
|
||||
db: Session,
|
||||
run: Run,
|
||||
adapter: BaseAdapter,
|
||||
cache: ResponseCacheLayer,
|
||||
scorers: list[Any] | None = None,
|
||||
event_bus: EventBus | None = None,
|
||||
) -> Run:
|
||||
"""Execute a single Run through all its pipeline stages.
|
||||
|
||||
Args:
|
||||
db: SQLAlchemy session.
|
||||
run: The Run ORM object (must already be persisted with config).
|
||||
adapter: LLM adapter to use for completions.
|
||||
cache: Response cache layer.
|
||||
scorers: Optional list of scorer instances (must have .name and .score()).
|
||||
event_bus: Optional event publisher for progress updates.
|
||||
|
||||
Returns:
|
||||
The updated Run object.
|
||||
"""
|
||||
scorers = scorers or []
|
||||
|
||||
# Extract pipeline stages from config
|
||||
config = run.config or {}
|
||||
stages = config.get("pipeline_stages", [])
|
||||
if not stages:
|
||||
# Single-stage fallback: use prompt + model + params from config directly
|
||||
stages = [
|
||||
{
|
||||
"prompt_template": config.get("prompt", config.get("prompt_template", "")),
|
||||
"model": config.get("model", ""),
|
||||
"params": config.get("params", {}),
|
||||
}
|
||||
]
|
||||
|
||||
input_data = config.get("input_data")
|
||||
|
||||
# Mark run as running
|
||||
run.status = RunStatus.running
|
||||
run.started_at = datetime.now(timezone.utc)
|
||||
db.commit()
|
||||
|
||||
if event_bus:
|
||||
event_bus.publish({
|
||||
"type": "run.started",
|
||||
"run_id": str(run.id),
|
||||
"experiment_id": str(run.experiment_id),
|
||||
})
|
||||
|
||||
total_tokens_in = 0
|
||||
total_tokens_out = 0
|
||||
completed_stages: list[dict[str, Any]] = []
|
||||
t_start = time.perf_counter()
|
||||
|
||||
try:
|
||||
for stage_index, stage_def in enumerate(stages):
|
||||
prompt_template = stage_def.get("prompt_template", stage_def.get("prompt", ""))
|
||||
model = stage_def.get("model", config.get("model", ""))
|
||||
params = stage_def.get("params", config.get("params", {}))
|
||||
|
||||
# Build context and render prompt
|
||||
template_ctx = _build_stage_context(
|
||||
stage_index, completed_stages, config, input_data
|
||||
)
|
||||
rendered_prompt = render_prompt(prompt_template, template_ctx)
|
||||
|
||||
# Check cache
|
||||
config_hash = compute_config_hash(rendered_prompt, model, params, input_data)
|
||||
cached = cache.get(db, config_hash)
|
||||
|
||||
if cached is not None:
|
||||
# Cache hit
|
||||
response_text = cached.response
|
||||
tokens_in = cached.tokens_in or 0
|
||||
tokens_out = cached.tokens_out or 0
|
||||
latency_ms = cached.latency_ms or 0
|
||||
else:
|
||||
# Cache miss — call LLM
|
||||
adapter_resp: AdapterResponse = await adapter.complete(
|
||||
rendered_prompt, model, params
|
||||
)
|
||||
response_text = adapter_resp.text
|
||||
tokens_in = adapter_resp.tokens_in
|
||||
tokens_out = adapter_resp.tokens_out
|
||||
latency_ms = int(adapter_resp.latency_ms)
|
||||
|
||||
# Store in cache
|
||||
cache.put(
|
||||
db,
|
||||
config_hash=config_hash,
|
||||
response=response_text,
|
||||
model=model,
|
||||
tokens_in=tokens_in,
|
||||
tokens_out=tokens_out,
|
||||
latency_ms=latency_ms,
|
||||
)
|
||||
|
||||
total_tokens_in += tokens_in
|
||||
total_tokens_out += tokens_out
|
||||
|
||||
# Create StageResult record
|
||||
stage_result = StageResult(
|
||||
run_id=run.id,
|
||||
stage_index=stage_index,
|
||||
prompt_sent=rendered_prompt,
|
||||
response_raw=response_text,
|
||||
model_used=model,
|
||||
parameters=params,
|
||||
tokens_in=tokens_in,
|
||||
tokens_out=tokens_out,
|
||||
latency_ms=latency_ms,
|
||||
)
|
||||
db.add(stage_result)
|
||||
db.flush()
|
||||
|
||||
completed_stages.append({
|
||||
"output": response_text,
|
||||
"model": model,
|
||||
"tokens_in": tokens_in,
|
||||
"tokens_out": tokens_out,
|
||||
"latency_ms": latency_ms,
|
||||
"stage_index": stage_index,
|
||||
})
|
||||
|
||||
if event_bus:
|
||||
event_bus.publish({
|
||||
"type": "run.stage_completed",
|
||||
"run_id": str(run.id),
|
||||
"experiment_id": str(run.experiment_id),
|
||||
"stage_index": stage_index,
|
||||
"total_stages": len(stages),
|
||||
})
|
||||
|
||||
# Run scorers on the final output
|
||||
final_output = completed_stages[-1]["output"] if completed_stages else ""
|
||||
scorer_context = {
|
||||
"config": config,
|
||||
"stages": completed_stages,
|
||||
"input_data": input_data,
|
||||
}
|
||||
|
||||
for scorer in scorers:
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(getattr(scorer, "score", None)):
|
||||
score_value = await scorer.score(input_data, final_output, scorer_context)
|
||||
else:
|
||||
score_value = scorer.score(input_data, final_output, scorer_context)
|
||||
|
||||
score_value = max(0.0, min(1.0, float(score_value)))
|
||||
|
||||
score_record = Score(
|
||||
run_id=run.id,
|
||||
scorer_name=scorer.name,
|
||||
value=score_value,
|
||||
)
|
||||
db.add(score_record)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Scorer %s failed for run %s", scorer.name, run.id, exc_info=True
|
||||
)
|
||||
|
||||
# Update run status and timing
|
||||
duration_ms = int((time.perf_counter() - t_start) * 1000)
|
||||
run.status = RunStatus.completed
|
||||
run.completed_at = datetime.now(timezone.utc)
|
||||
run.duration_ms = duration_ms
|
||||
run.tokens_in = total_tokens_in
|
||||
run.tokens_out = total_tokens_out
|
||||
db.commit()
|
||||
|
||||
if event_bus:
|
||||
event_bus.publish({
|
||||
"type": "run.completed",
|
||||
"run_id": str(run.id),
|
||||
"experiment_id": str(run.experiment_id),
|
||||
"duration_ms": duration_ms,
|
||||
"tokens_in": total_tokens_in,
|
||||
"tokens_out": total_tokens_out,
|
||||
})
|
||||
|
||||
except Exception as exc:
|
||||
logger.error("Run %s failed: %s", run.id, exc, exc_info=True)
|
||||
run.status = RunStatus.failed
|
||||
run.completed_at = datetime.now(timezone.utc)
|
||||
run.duration_ms = int((time.perf_counter() - t_start) * 1000)
|
||||
db.commit()
|
||||
|
||||
if event_bus:
|
||||
event_bus.publish({
|
||||
"type": "run.failed",
|
||||
"run_id": str(run.id),
|
||||
"experiment_id": str(run.experiment_id),
|
||||
"error": str(exc),
|
||||
})
|
||||
raise
|
||||
|
||||
return run
|
||||
|
|
@ -14,3 +14,4 @@ websockets>=13.0,<14.0
|
|||
psycopg2-binary>=2.9,<3.0
|
||||
aiosqlite>=0.20,<1.0
|
||||
python-multipart>=0.0.9
|
||||
jinja2>=3.1,<4.0
|
||||
|
|
|
|||
513
backend/tests/test_runner.py
Normal file
513
backend/tests/test_runner.py
Normal file
|
|
@ -0,0 +1,513 @@
|
|||
"""Tests for the run execution engine."""
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from engine.adapters.base import AdapterResponse, BaseAdapter
|
||||
from engine.cache import ResponseCacheLayer, compute_config_hash
|
||||
from engine.runner import EventBus, render_prompt, run_single
|
||||
from models import Base, Run, RunStatus, Score, StageResult
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures / helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _engine():
|
||||
engine = create_engine("sqlite:///:memory:")
|
||||
Base.metadata.create_all(engine)
|
||||
return engine
|
||||
|
||||
|
||||
def _session(engine):
|
||||
return Session(engine)
|
||||
|
||||
|
||||
def _make_run(db: Session, config: dict, experiment_id: uuid.UUID | None = None) -> Run:
|
||||
"""Create and persist a Run for testing."""
|
||||
exp_id = experiment_id or uuid.uuid4()
|
||||
config_hash = compute_config_hash(
|
||||
config.get("prompt", ""), config.get("model", ""), config.get("params", {})
|
||||
)
|
||||
run = Run(
|
||||
experiment_id=exp_id,
|
||||
config_hash=config_hash,
|
||||
config=config,
|
||||
status=RunStatus.pending,
|
||||
)
|
||||
db.add(run)
|
||||
db.commit()
|
||||
db.refresh(run)
|
||||
return run
|
||||
|
||||
|
||||
class MockAdapter(BaseAdapter):
|
||||
"""Test adapter that returns a fixed response."""
|
||||
|
||||
def __init__(self, response_text: str = "mock response", tokens_in: int = 10, tokens_out: int = 5):
|
||||
self.response_text = response_text
|
||||
self.tokens_in = tokens_in
|
||||
self.tokens_out = tokens_out
|
||||
self.calls: list[tuple[str, str, dict]] = []
|
||||
|
||||
async def complete(self, prompt: str, model: str, params: dict[str, Any]) -> AdapterResponse:
|
||||
self.calls.append((prompt, model, params))
|
||||
return AdapterResponse(
|
||||
text=self.response_text,
|
||||
tokens_in=self.tokens_in,
|
||||
tokens_out=self.tokens_out,
|
||||
latency_ms=42.0,
|
||||
model=model,
|
||||
)
|
||||
|
||||
async def list_models(self) -> list[str]:
|
||||
return ["mock-model"]
|
||||
|
||||
async def test_connection(self) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
class MockScorer:
|
||||
"""Test scorer that returns a fixed value."""
|
||||
|
||||
def __init__(self, name: str = "mock_scorer", value: float = 0.85):
|
||||
self.name = name
|
||||
self._value = value
|
||||
|
||||
def score(self, input_data: Any, output: str, context: dict) -> float:
|
||||
return self._value
|
||||
|
||||
|
||||
class AsyncMockScorer:
|
||||
"""Test async scorer."""
|
||||
|
||||
def __init__(self, name: str = "async_scorer", value: float = 0.9):
|
||||
self.name = name
|
||||
self._value = value
|
||||
|
||||
async def score(self, input_data: Any, output: str, context: dict) -> float:
|
||||
return self._value
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# render_prompt tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRenderPrompt:
|
||||
def test_simple_variable(self):
|
||||
result = render_prompt("Hello {{ name }}!", {"name": "world"})
|
||||
assert result == "Hello world!"
|
||||
|
||||
def test_prev_output(self):
|
||||
result = render_prompt(
|
||||
"Summarize: {{ prev_output }}",
|
||||
{"prev_output": "The quick brown fox"},
|
||||
)
|
||||
assert result == "Summarize: The quick brown fox"
|
||||
|
||||
def test_no_variables(self):
|
||||
result = render_prompt("Static prompt", {})
|
||||
assert result == "Static prompt"
|
||||
|
||||
def test_nested_context(self):
|
||||
result = render_prompt(
|
||||
"Model: {{ config.model }}",
|
||||
{"config": {"model": "gpt-4"}},
|
||||
)
|
||||
assert result == "Model: gpt-4"
|
||||
|
||||
def test_conditional(self):
|
||||
result = render_prompt(
|
||||
"{% if input %}Input: {{ input }}{% else %}No input{% endif %}",
|
||||
{"input": "test data"},
|
||||
)
|
||||
assert result == "Input: test data"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# run_single tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunSingle:
|
||||
def test_single_stage_run(self):
|
||||
engine = _engine()
|
||||
with _session(engine) as db:
|
||||
config = {
|
||||
"prompt": "Tell me a joke",
|
||||
"model": "test-model",
|
||||
"params": {"temperature": 0.7},
|
||||
}
|
||||
run = _make_run(db, config)
|
||||
adapter = MockAdapter(response_text="Why did the chicken...")
|
||||
cache = ResponseCacheLayer()
|
||||
|
||||
result = asyncio.get_event_loop().run_until_complete(
|
||||
run_single(db, run, adapter, cache)
|
||||
)
|
||||
|
||||
assert result.status == RunStatus.completed
|
||||
assert result.tokens_in == 10
|
||||
assert result.tokens_out == 5
|
||||
assert result.duration_ms is not None
|
||||
assert result.started_at is not None
|
||||
assert result.completed_at is not None
|
||||
|
||||
# Verify stage result was created
|
||||
stages = db.query(StageResult).filter_by(run_id=run.id).all()
|
||||
assert len(stages) == 1
|
||||
assert stages[0].response_raw == "Why did the chicken..."
|
||||
assert stages[0].model_used == "test-model"
|
||||
assert stages[0].stage_index == 0
|
||||
|
||||
def test_multi_stage_pipeline(self):
|
||||
engine = _engine()
|
||||
with _session(engine) as db:
|
||||
config = {
|
||||
"pipeline_stages": [
|
||||
{
|
||||
"prompt_template": "Generate a question about {{ input }}",
|
||||
"model": "model-a",
|
||||
"params": {},
|
||||
},
|
||||
{
|
||||
"prompt_template": "Answer this: {{ prev_output }}",
|
||||
"model": "model-b",
|
||||
"params": {},
|
||||
},
|
||||
],
|
||||
"input_data": "Python programming",
|
||||
}
|
||||
run = _make_run(db, config)
|
||||
|
||||
call_count = 0
|
||||
|
||||
class StagedAdapter(BaseAdapter):
|
||||
async def complete(self, prompt, model, params):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return AdapterResponse(
|
||||
text=f"stage-{call_count}-output",
|
||||
tokens_in=5,
|
||||
tokens_out=3,
|
||||
latency_ms=20.0,
|
||||
model=model,
|
||||
)
|
||||
|
||||
async def list_models(self):
|
||||
return []
|
||||
|
||||
async def test_connection(self):
|
||||
return True
|
||||
|
||||
adapter = StagedAdapter()
|
||||
cache = ResponseCacheLayer()
|
||||
|
||||
result = asyncio.get_event_loop().run_until_complete(
|
||||
run_single(db, run, adapter, cache)
|
||||
)
|
||||
|
||||
assert result.status == RunStatus.completed
|
||||
assert result.tokens_in == 10 # 5 + 5
|
||||
assert result.tokens_out == 6 # 3 + 3
|
||||
|
||||
stages = (
|
||||
db.query(StageResult)
|
||||
.filter_by(run_id=run.id)
|
||||
.order_by(StageResult.stage_index)
|
||||
.all()
|
||||
)
|
||||
assert len(stages) == 2
|
||||
assert stages[0].model_used == "model-a"
|
||||
assert stages[1].model_used == "model-b"
|
||||
# Second stage prompt should contain first stage output
|
||||
assert "stage-1-output" in stages[1].prompt_sent
|
||||
|
||||
def test_cache_hit_skips_adapter(self):
|
||||
engine = _engine()
|
||||
with _session(engine) as db:
|
||||
config = {
|
||||
"prompt": "cached prompt",
|
||||
"model": "test-model",
|
||||
"params": {},
|
||||
}
|
||||
run = _make_run(db, config)
|
||||
adapter = MockAdapter()
|
||||
cache = ResponseCacheLayer()
|
||||
|
||||
# Pre-populate cache
|
||||
rendered = render_prompt("cached prompt", {
|
||||
"input": "",
|
||||
"config": config,
|
||||
"stages": [],
|
||||
"prev_output": "",
|
||||
"stage_index": 0,
|
||||
})
|
||||
h = compute_config_hash(rendered, "test-model", {}, None)
|
||||
cache.put(db, h, response="cached!", model="test-model", tokens_in=1, tokens_out=1)
|
||||
|
||||
result = asyncio.get_event_loop().run_until_complete(
|
||||
run_single(db, run, adapter, cache)
|
||||
)
|
||||
|
||||
assert result.status == RunStatus.completed
|
||||
# Adapter should NOT have been called
|
||||
assert len(adapter.calls) == 0
|
||||
|
||||
stages = db.query(StageResult).filter_by(run_id=run.id).all()
|
||||
assert stages[0].response_raw == "cached!"
|
||||
|
||||
def test_scorer_creates_score_records(self):
|
||||
engine = _engine()
|
||||
with _session(engine) as db:
|
||||
config = {"prompt": "test", "model": "m", "params": {}}
|
||||
run = _make_run(db, config)
|
||||
adapter = MockAdapter()
|
||||
cache = ResponseCacheLayer()
|
||||
scorers = [MockScorer("quality", 0.9), MockScorer("relevance", 0.7)]
|
||||
|
||||
asyncio.get_event_loop().run_until_complete(
|
||||
run_single(db, run, adapter, cache, scorers=scorers)
|
||||
)
|
||||
|
||||
scores = db.query(Score).filter_by(run_id=run.id).all()
|
||||
assert len(scores) == 2
|
||||
score_map = {s.scorer_name: s.value for s in scores}
|
||||
assert score_map["quality"] == pytest.approx(0.9)
|
||||
assert score_map["relevance"] == pytest.approx(0.7)
|
||||
|
||||
def test_async_scorer(self):
|
||||
engine = _engine()
|
||||
with _session(engine) as db:
|
||||
config = {"prompt": "test", "model": "m", "params": {}}
|
||||
run = _make_run(db, config)
|
||||
adapter = MockAdapter()
|
||||
cache = ResponseCacheLayer()
|
||||
scorers = [AsyncMockScorer("async_quality", 0.95)]
|
||||
|
||||
asyncio.get_event_loop().run_until_complete(
|
||||
run_single(db, run, adapter, cache, scorers=scorers)
|
||||
)
|
||||
|
||||
scores = db.query(Score).filter_by(run_id=run.id).all()
|
||||
assert len(scores) == 1
|
||||
assert scores[0].value == pytest.approx(0.95)
|
||||
assert scores[0].scorer_name == "async_quality"
|
||||
|
||||
def test_score_clamped_to_range(self):
|
||||
engine = _engine()
|
||||
with _session(engine) as db:
|
||||
config = {"prompt": "test", "model": "m", "params": {}}
|
||||
run = _make_run(db, config)
|
||||
adapter = MockAdapter()
|
||||
cache = ResponseCacheLayer()
|
||||
scorers = [MockScorer("over", 1.5), MockScorer("under", -0.3)]
|
||||
|
||||
asyncio.get_event_loop().run_until_complete(
|
||||
run_single(db, run, adapter, cache, scorers=scorers)
|
||||
)
|
||||
|
||||
scores = db.query(Score).filter_by(run_id=run.id).all()
|
||||
score_map = {s.scorer_name: s.value for s in scores}
|
||||
assert score_map["over"] == pytest.approx(1.0)
|
||||
assert score_map["under"] == pytest.approx(0.0)
|
||||
|
||||
def test_failed_scorer_does_not_crash_run(self):
|
||||
engine = _engine()
|
||||
with _session(engine) as db:
|
||||
config = {"prompt": "test", "model": "m", "params": {}}
|
||||
run = _make_run(db, config)
|
||||
adapter = MockAdapter()
|
||||
cache = ResponseCacheLayer()
|
||||
|
||||
class BrokenScorer:
|
||||
name = "broken"
|
||||
|
||||
def score(self, input_data, output, context):
|
||||
raise ValueError("scorer exploded")
|
||||
|
||||
scorers = [BrokenScorer(), MockScorer("ok", 0.8)]
|
||||
|
||||
asyncio.get_event_loop().run_until_complete(
|
||||
run_single(db, run, adapter, cache, scorers=scorers)
|
||||
)
|
||||
|
||||
assert run.status == RunStatus.completed
|
||||
scores = db.query(Score).filter_by(run_id=run.id).all()
|
||||
# Only the working scorer should have a record
|
||||
assert len(scores) == 1
|
||||
assert scores[0].scorer_name == "ok"
|
||||
|
||||
def test_adapter_failure_marks_run_failed(self):
|
||||
engine = _engine()
|
||||
with _session(engine) as db:
|
||||
config = {"prompt": "test", "model": "m", "params": {}}
|
||||
run = _make_run(db, config)
|
||||
cache = ResponseCacheLayer()
|
||||
|
||||
class FailAdapter(BaseAdapter):
|
||||
async def complete(self, prompt, model, params):
|
||||
raise RuntimeError("LLM endpoint down")
|
||||
|
||||
async def list_models(self):
|
||||
return []
|
||||
|
||||
async def test_connection(self):
|
||||
return False
|
||||
|
||||
adapter = FailAdapter()
|
||||
|
||||
with pytest.raises(RuntimeError, match="LLM endpoint down"):
|
||||
asyncio.get_event_loop().run_until_complete(
|
||||
run_single(db, run, adapter, cache)
|
||||
)
|
||||
|
||||
assert run.status == RunStatus.failed
|
||||
assert run.completed_at is not None
|
||||
|
||||
def test_event_bus_receives_events(self):
|
||||
engine = _engine()
|
||||
with _session(engine) as db:
|
||||
config = {"prompt": "test", "model": "m", "params": {}}
|
||||
run = _make_run(db, config)
|
||||
adapter = MockAdapter()
|
||||
cache = ResponseCacheLayer()
|
||||
|
||||
events: list[dict] = []
|
||||
bus = EventBus()
|
||||
bus.add_listener(lambda e: events.append(e))
|
||||
|
||||
asyncio.get_event_loop().run_until_complete(
|
||||
run_single(db, run, adapter, cache, event_bus=bus)
|
||||
)
|
||||
|
||||
event_types = [e["type"] for e in events]
|
||||
assert "run.started" in event_types
|
||||
assert "run.stage_completed" in event_types
|
||||
assert "run.completed" in event_types
|
||||
|
||||
def test_event_bus_on_failure(self):
|
||||
engine = _engine()
|
||||
with _session(engine) as db:
|
||||
config = {"prompt": "test", "model": "m", "params": {}}
|
||||
run = _make_run(db, config)
|
||||
cache = ResponseCacheLayer()
|
||||
|
||||
class FailAdapter(BaseAdapter):
|
||||
async def complete(self, prompt, model, params):
|
||||
raise RuntimeError("boom")
|
||||
|
||||
async def list_models(self):
|
||||
return []
|
||||
|
||||
async def test_connection(self):
|
||||
return False
|
||||
|
||||
events: list[dict] = []
|
||||
bus = EventBus()
|
||||
bus.add_listener(lambda e: events.append(e))
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
asyncio.get_event_loop().run_until_complete(
|
||||
run_single(db, run, FailAdapter(), cache, event_bus=bus)
|
||||
)
|
||||
|
||||
event_types = [e["type"] for e in events]
|
||||
assert "run.started" in event_types
|
||||
assert "run.failed" in event_types
|
||||
|
||||
def test_jinja2_template_with_stage_history(self):
|
||||
engine = _engine()
|
||||
with _session(engine) as db:
|
||||
config = {
|
||||
"pipeline_stages": [
|
||||
{"prompt_template": "First: {{ input }}", "model": "m", "params": {}},
|
||||
{
|
||||
"prompt_template": "Stage 0 said: {{ stages[0].output }}. Now continue.",
|
||||
"model": "m",
|
||||
"params": {},
|
||||
},
|
||||
],
|
||||
"input_data": "hello",
|
||||
}
|
||||
run = _make_run(db, config)
|
||||
adapter = MockAdapter(response_text="stage output")
|
||||
cache = ResponseCacheLayer()
|
||||
|
||||
asyncio.get_event_loop().run_until_complete(
|
||||
run_single(db, run, adapter, cache)
|
||||
)
|
||||
|
||||
stages = (
|
||||
db.query(StageResult)
|
||||
.filter_by(run_id=run.id)
|
||||
.order_by(StageResult.stage_index)
|
||||
.all()
|
||||
)
|
||||
assert "hello" in stages[0].prompt_sent
|
||||
assert "stage output" in stages[1].prompt_sent
|
||||
|
||||
def test_cache_stores_response_after_llm_call(self):
|
||||
engine = _engine()
|
||||
with _session(engine) as db:
|
||||
config = {"prompt": "store me", "model": "m", "params": {}}
|
||||
run = _make_run(db, config)
|
||||
adapter = MockAdapter(response_text="fresh response")
|
||||
cache = ResponseCacheLayer()
|
||||
|
||||
asyncio.get_event_loop().run_until_complete(
|
||||
run_single(db, run, adapter, cache)
|
||||
)
|
||||
|
||||
# Verify the response was cached
|
||||
stats = cache.cache_stats(db)
|
||||
assert stats.total_entries == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# EventBus tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEventBus:
|
||||
def test_in_process_listener(self):
|
||||
events = []
|
||||
bus = EventBus()
|
||||
bus.add_listener(lambda e: events.append(e))
|
||||
bus.publish({"type": "test", "data": 42})
|
||||
assert len(events) == 1
|
||||
assert events[0]["type"] == "test"
|
||||
|
||||
def test_redis_publish(self):
|
||||
mock_redis = MagicMock()
|
||||
bus = EventBus(redis_client=mock_redis)
|
||||
bus.publish({"type": "test"})
|
||||
mock_redis.publish.assert_called_once()
|
||||
|
||||
def test_broken_listener_does_not_crash(self):
|
||||
def broken(event):
|
||||
raise ValueError("boom")
|
||||
|
||||
events = []
|
||||
bus = EventBus()
|
||||
bus.add_listener(broken)
|
||||
bus.add_listener(lambda e: events.append(e))
|
||||
bus.publish({"type": "test"})
|
||||
# Second listener still receives the event
|
||||
assert len(events) == 1
|
||||
|
||||
def test_redis_failure_does_not_crash(self):
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.publish.side_effect = ConnectionError("Redis down")
|
||||
bus = EventBus(redis_client=mock_redis)
|
||||
# Should not raise
|
||||
bus.publish({"type": "test"})
|
||||
Loading…
Add table
Reference in a new issue