Add backend/engine/sweep.py with three sweep strategies:
- GridSweep: exhaustive enumeration of all parameter combinations.
- RandomSweep: N random samples drawn from parameter specs (a list of choices, a min/max range, or a min/max/step grid).
- GuidedSweep: top-K exploitation combined with random exploration, seeded from previous results.

Features: bounded parallelism via asyncio.Semaphore, token-budget enforcement, Redis-based pause/resume/stop control flags, and sweep-level event publishing. Includes 36 tests in test_sweep.py covering config generation, helpers, and full sweep execution.
636 lines · 21 KiB · Python
"""Tests for sweep orchestration engine."""
|
||
|
||
import asyncio
|
||
import uuid
|
||
from typing import Any
|
||
from unittest.mock import MagicMock
|
||
|
||
import pytest
|
||
from sqlalchemy import create_engine
|
||
from sqlalchemy.orm import Session
|
||
|
||
from engine.adapters.base import AdapterResponse, BaseAdapter
|
||
from engine.cache import ResponseCacheLayer
|
||
from engine.runner import EventBus
|
||
from engine.sweep import (
|
||
SweepConfig,
|
||
SweepResult,
|
||
_parse_sweep_config,
|
||
_sample_param,
|
||
_set_config_value,
|
||
generate_grid_configs,
|
||
generate_guided_configs,
|
||
generate_random_configs,
|
||
run_sweep,
|
||
)
|
||
from models import Base, Experiment, ExperimentStatus, Project, Run, RunStatus, Score, User
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Fixtures / helpers
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
def _engine():
    """Build a throwaway in-memory SQLite engine with all model tables created."""
    sqlite_engine = create_engine("sqlite:///:memory:")
    Base.metadata.create_all(bind=sqlite_engine)
    return sqlite_engine
|
||
|
||
|
||
def _session(engine):
    """Open a new ORM session bound to *engine*; callers use it as a context manager."""
    return Session(bind=engine)
|
||
|
||
|
||
def _make_experiment(
    db: Session,
    parameter_space: dict | None = None,
    pipeline_stages: list[dict] | dict | None = None,
    sample_data: dict | None = None,
) -> Experiment:
    """Create a full User → Project → Experiment chain for testing.

    Args:
        db: Session the new rows are added to; it is committed before returning.
        parameter_space: Sweep specification stored on the experiment,
            e.g. ``{"type": "grid", "params": {...}}``.
        pipeline_stages: Stage definitions stored on the experiment. The tests
            in this module pass a list of stage dicts.
        sample_data: Optional sample input stored on the experiment.

    Returns:
        The committed and refreshed ``Experiment``, in ``draft`` status.
    """
    # Unique username so the helper can be called more than once per session
    # without colliding on a possible uniqueness constraint on User.username.
    user = User(username=f"testuser-{uuid.uuid4().hex[:8]}", password_hash="fakehash")
    db.add(user)
    db.flush()  # populate user.id before the project references it

    project = Project(name="Test Project", owner_id=user.id)
    db.add(project)
    db.flush()  # populate project.id before the experiment references it

    experiment = Experiment(
        project_id=project.id,
        name="Test Experiment",
        parameter_space=parameter_space,
        pipeline_stages=pipeline_stages,
        sample_data=sample_data,
        status=ExperimentStatus.draft,
    )
    db.add(experiment)
    db.commit()
    db.refresh(experiment)
    return experiment
|
||
|
||
|
||
class MockAdapter(BaseAdapter):
    """Test adapter that returns a fixed response and records every call.

    Each ``complete`` invocation is appended to ``calls`` as a
    ``(prompt, model, params)`` tuple so tests can inspect what was sent.
    """

    def __init__(self, response_text: str = "mock response", tokens_in: int = 10, tokens_out: int = 5):
        self.response_text, self.tokens_in, self.tokens_out = response_text, tokens_in, tokens_out
        self.calls: list[tuple[str, str, dict]] = []

    async def complete(self, prompt: str, model: str, params: dict[str, Any]) -> AdapterResponse:
        """Record the call and return the canned response, echoing *model* back."""
        self.calls.append((prompt, model, params))
        canned = dict(
            text=self.response_text,
            tokens_in=self.tokens_in,
            tokens_out=self.tokens_out,
            latency_ms=42.0,
            model=model,
        )
        return AdapterResponse(**canned)

    async def list_models(self) -> list[str]:
        """Advertise a single fake model."""
        return ["mock-model"]

    async def test_connection(self) -> bool:
        """Always report a healthy connection."""
        return True
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Grid config generation
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestGenerateGridConfigs:
    """``generate_grid_configs`` must emit every parameter combination exactly once."""

    def test_single_param(self):
        base = {"prompt": "test", "model": "m", "params": {}}
        configs = generate_grid_configs(base, {"params.temperature": [0.1, 0.5, 0.9]})
        assert len(configs) == 3
        temps = [c["params"]["temperature"] for c in configs]
        assert sorted(temps) == [0.1, 0.5, 0.9]

    def test_two_params(self):
        base = {"prompt": "test", "model": "m", "params": {}}
        configs = generate_grid_configs(
            base,
            {
                "model": ["gpt-4", "claude"],
                "params.temperature": [0.1, 0.9],
            },
        )
        # 2 models × 2 temperatures = 4 combos, each appearing exactly once.
        # Checking the set (not just the count) catches duplicated or missing combos.
        assert len(configs) == 4
        combos = {(c["model"], c["params"]["temperature"]) for c in configs}
        assert combos == {
            ("gpt-4", 0.1),
            ("gpt-4", 0.9),
            ("claude", 0.1),
            ("claude", 0.9),
        }
        # Keys outside the parameter space are carried over from the base config.
        assert all(c["prompt"] == "test" for c in configs)

    def test_three_params(self):
        base = {"prompt": "test", "model": "m", "params": {}}
        configs = generate_grid_configs(
            base,
            {
                "model": ["a", "b"],
                "params.temperature": [0.1, 0.9],
                "params.max_tokens": [100, 200],
            },
        )
        assert len(configs) == 8  # 2 × 2 × 2
        combos = {
            (c["model"], c["params"]["temperature"], c["params"]["max_tokens"])
            for c in configs
        }
        expected = {
            (m, t, k)
            for m in ("a", "b")
            for t in (0.1, 0.9)
            for k in (100, 200)
        }
        assert combos == expected

    def test_empty_param_space(self):
        base = {"prompt": "test"}
        configs = generate_grid_configs(base, {})
        assert len(configs) == 1
        assert configs[0]["prompt"] == "test"

    def test_base_config_not_mutated(self):
        base = {"prompt": "test", "params": {"temperature": 0.5}}
        generate_grid_configs(base, {"params.temperature": [0.1, 0.9]})
        assert base["params"]["temperature"] == 0.5

    def test_model_override(self):
        base = {"prompt": "test", "model": "original"}
        configs = generate_grid_configs(base, {"model": ["a", "b"]})
        models = {c["model"] for c in configs}
        assert models == {"a", "b"}
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Random config generation
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestGenerateRandomConfigs:
    """``generate_random_configs`` draws ``n_trials`` samples from list, range, and step specs."""

    def test_correct_count(self):
        sampled = generate_random_configs(
            {"prompt": "test", "params": {}},
            {"params.temperature": [0.1, 0.5, 0.9]},
            n_trials=10,
        )
        assert len(sampled) == 10

    def test_list_spec(self):
        sampled = generate_random_configs(
            {"prompt": "test", "params": {}},
            {"params.temperature": [0.1, 0.5, 0.9]},
            n_trials=50,
        )
        seen = {cfg["params"]["temperature"] for cfg in sampled}
        assert seen <= {0.1, 0.5, 0.9}

    def test_range_spec(self):
        sampled = generate_random_configs(
            {"prompt": "test", "params": {}},
            {"params.temperature": {"min": 0.0, "max": 1.0}},
            n_trials=20,
        )
        assert all(0.0 <= cfg["params"]["temperature"] <= 1.0 for cfg in sampled)

    def test_step_spec(self):
        sampled = generate_random_configs(
            {"prompt": "test", "params": {}},
            {"params.temperature": {"min": 0.0, "max": 1.0, "step": 0.5}},
            n_trials=50,
        )
        allowed = {0.0, 0.5, 1.0}
        assert all(cfg["params"]["temperature"] in allowed for cfg in sampled)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Guided config generation
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestGenerateGuidedConfigs:
    """``generate_guided_configs`` mixes exploitation of prior results with exploration."""

    def test_with_previous_results(self):
        history = [
            {"config": {"prompt": "test", "model": "m", "params": {"temperature": 0.7}}, "score": 0.9},
            {"config": {"prompt": "test", "model": "m", "params": {"temperature": 0.3}}, "score": 0.8},
        ]
        produced = generate_guided_configs(
            {"prompt": "test", "model": "m", "params": {"temperature": 0.5}},
            {"params.temperature": [0.1, 0.3, 0.5, 0.7, 0.9]},
            history,
            n_trials=10,
            top_k=2,
            explore_ratio=0.3,
        )
        assert len(produced) == 10

    def test_no_previous_results_falls_back_to_random(self):
        produced = generate_guided_configs(
            {"prompt": "test", "params": {}},
            {"params.temperature": [0.1, 0.5, 0.9]},
            [],
            n_trials=5,
        )
        assert len(produced) == 5

    def test_respects_n_trials_limit(self):
        # The same base dict is both the sweep base and every history entry's config.
        base = {"prompt": "test", "params": {}}
        history = [{"config": base, "score": 0.9}] * 20
        produced = generate_guided_configs(
            base,
            {"params.temperature": [0.1, 0.5, 0.9]},
            history,
            n_trials=7,
        )
        assert len(produced) == 7
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Helper functions
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestSampleParam:
    """``_sample_param`` picks one value from a list, range, or stepped range; scalars pass through."""

    def test_list(self):
        draws = [_sample_param([1, 2, 3]) for _ in range(20)]
        assert all(d in [1, 2, 3] for d in draws)

    def test_range(self):
        draws = [_sample_param({"min": 0.0, "max": 1.0}) for _ in range(20)]
        assert all(0.0 <= d <= 1.0 for d in draws)

    def test_step(self):
        grid = [0.0, 0.25, 0.5, 0.75, 1.0]
        draws = [_sample_param({"min": 0.0, "max": 1.0, "step": 0.25}) for _ in range(20)]
        assert all(d in grid for d in draws)

    def test_scalar(self):
        for scalar in (42, "gpt-4"):
            assert _sample_param(scalar) == scalar
|
||
|
||
|
||
class TestSetConfigValue:
    """``_set_config_value`` writes a value at a plain or dotted key path, in place."""

    def test_simple_key(self):
        target = {}
        _set_config_value(target, "model", "gpt-4")
        assert target["model"] == "gpt-4"

    def test_dotted_key(self):
        target = {"params": {}}
        _set_config_value(target, "params.temperature", 0.7)
        assert target["params"]["temperature"] == 0.7

    def test_creates_intermediate_dicts(self):
        target = {}
        _set_config_value(target, "params.temperature", 0.7)
        assert target["params"]["temperature"] == 0.7

    def test_deeply_nested(self):
        target = {}
        _set_config_value(target, "a.b.c", "deep")
        assert target["a"]["b"]["c"] == "deep"
|
||
|
||
|
||
class TestParseSweepConfig:
    """``_parse_sweep_config`` maps a raw parameter_space dict onto a sweep config object."""

    def test_grid(self):
        parsed = _parse_sweep_config({"type": "grid", "params": {"model": ["a", "b"]}})
        assert parsed.sweep_type == "grid"
        assert parsed.params == {"model": ["a", "b"]}

    def test_random(self):
        parsed = _parse_sweep_config({"type": "random", "params": {}, "n_trials": 50})
        assert (parsed.sweep_type, parsed.n_trials) == ("random", 50)

    def test_guided(self):
        raw = {
            "type": "guided",
            "params": {},
            "n_trials": 20,
            "top_k": 3,
            "explore_ratio": 0.4,
        }
        parsed = _parse_sweep_config(raw)
        assert (parsed.sweep_type, parsed.top_k, parsed.explore_ratio) == ("guided", 3, 0.4)

    def test_defaults(self):
        parsed = _parse_sweep_config({})
        assert (
            parsed.sweep_type,
            parsed.n_trials,
            parsed.top_k,
            parsed.explore_ratio,
        ) == ("grid", 100, 5, 0.3)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Full sweep execution
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestRunSweep:
    """End-to-end ``run_sweep`` tests against a fresh in-memory database.

    Each sweep coroutine is driven with ``asyncio.run``, which creates and
    closes its own event loop. ``asyncio.get_event_loop()`` called with no
    running loop is deprecated, raises ``RuntimeError`` on Python 3.14+ when
    no loop is set, and also fails after an earlier ``asyncio.run`` has
    cleared the current loop, so it must not be used from these sync tests.
    """

    def test_grid_sweep_creates_and_runs_all(self):
        engine = _engine()
        with _session(engine) as db:
            experiment = _make_experiment(
                db,
                parameter_space={
                    "type": "grid",
                    "params": {
                        "model": ["model-a", "model-b"],
                        "params.temperature": [0.1, 0.9],
                    },
                },
                pipeline_stages=[
                    {"prompt_template": "Hello", "model": "default", "params": {}},
                ],
            )
            adapter = MockAdapter()
            cache = ResponseCacheLayer()

            result = asyncio.run(
                run_sweep(db, experiment, adapter, cache, max_concurrent=2)
            )

            assert result.total_runs == 4  # 2 models × 2 temps
            assert result.completed_runs == 4
            assert result.failed_runs == 0
            assert result.stopped_reason == "completed"

            # Verify experiment status updated
            db.refresh(experiment)
            assert experiment.status == ExperimentStatus.completed

            # Verify Run records exist
            runs = db.query(Run).filter_by(experiment_id=experiment.id).all()
            assert len(runs) == 4
            assert all(r.status == RunStatus.completed for r in runs)

    def test_random_sweep(self):
        engine = _engine()
        with _session(engine) as db:
            experiment = _make_experiment(
                db,
                parameter_space={
                    "type": "random",
                    "params": {"params.temperature": [0.1, 0.5, 0.9]},
                    "n_trials": 3,
                },
                pipeline_stages=[
                    {"prompt_template": "Hello", "model": "m", "params": {}},
                ],
            )
            adapter = MockAdapter()
            cache = ResponseCacheLayer()

            result = asyncio.run(run_sweep(db, experiment, adapter, cache))

            assert result.total_runs == 3
            assert result.completed_runs == 3

    def test_sweep_events_published(self):
        engine = _engine()
        with _session(engine) as db:
            experiment = _make_experiment(
                db,
                parameter_space={
                    "type": "grid",
                    "params": {"model": ["a"]},
                },
                pipeline_stages=[
                    {"prompt_template": "Hello", "model": "m", "params": {}},
                ],
            )
            adapter = MockAdapter()
            cache = ResponseCacheLayer()
            events: list[dict] = []
            bus = EventBus()
            bus.add_listener(events.append)

            asyncio.run(run_sweep(db, experiment, adapter, cache, event_bus=bus))

            event_types = {e["type"] for e in events}
            assert "sweep.started" in event_types
            assert "sweep.completed" in event_types
            assert "run.started" in event_types
            assert "run.completed" in event_types

    def test_sweep_stop_flag(self):
        engine = _engine()
        with _session(engine) as db:
            experiment = _make_experiment(
                db,
                parameter_space={
                    "type": "grid",
                    "params": {"model": ["a", "b", "c", "d"]},
                },
                pipeline_stages=[
                    {"prompt_template": "Hello", "model": "m", "params": {}},
                ],
            )
            # Mock Redis that returns the flag for any key containing "stop"
            mock_redis = MagicMock()
            mock_redis.get = MagicMock(side_effect=lambda key: "1" if "stop" in key else None)
            adapter = MockAdapter()
            cache = ResponseCacheLayer()

            result = asyncio.run(
                run_sweep(
                    db, experiment, adapter, cache,
                    redis_client=mock_redis, max_concurrent=1,
                )
            )

            assert result.stopped_reason == "stopped"
            # Not all runs should have completed
            assert result.completed_runs < result.total_runs

    def test_sweep_pause_flag(self):
        engine = _engine()
        with _session(engine) as db:
            experiment = _make_experiment(
                db,
                parameter_space={
                    "type": "grid",
                    "params": {"model": ["a", "b", "c"]},
                },
                pipeline_stages=[
                    {"prompt_template": "Hello", "model": "m", "params": {}},
                ],
            )
            # Mock Redis that returns the flag for any key containing "pause"
            mock_redis = MagicMock()
            mock_redis.get = MagicMock(side_effect=lambda key: "1" if "pause" in key else None)
            adapter = MockAdapter()
            cache = ResponseCacheLayer()

            result = asyncio.run(
                run_sweep(
                    db, experiment, adapter, cache,
                    redis_client=mock_redis, max_concurrent=1,
                )
            )

            assert result.stopped_reason == "paused"
            db.refresh(experiment)
            assert experiment.status == ExperimentStatus.paused

    def test_sweep_token_budget(self):
        engine = _engine()
        with _session(engine) as db:
            experiment = _make_experiment(
                db,
                parameter_space={
                    "type": "grid",
                    "params": {"model": ["a", "b", "c", "d", "e"]},
                },
                pipeline_stages=[
                    {"prompt_template": "Hello", "model": "m", "params": {}},
                ],
            )
            # Each run uses 10 tokens_in + 5 tokens_out = 15, so all five runs
            # would need 75 tokens; a budget of 20 must stop the sweep early.
            adapter = MockAdapter(tokens_in=10, tokens_out=5)
            cache = ResponseCacheLayer()

            result = asyncio.run(
                run_sweep(
                    db, experiment, adapter, cache,
                    max_concurrent=1, max_tokens=20,
                )
            )

            assert result.stopped_reason == "token_budget"
            assert result.completed_runs < result.total_runs

    def test_sweep_with_scorers(self):
        engine = _engine()
        with _session(engine) as db:
            experiment = _make_experiment(
                db,
                parameter_space={
                    "type": "grid",
                    "params": {"model": ["a"]},
                },
                pipeline_stages=[
                    {"prompt_template": "Hello", "model": "m", "params": {}},
                ],
            )
            adapter = MockAdapter()
            cache = ResponseCacheLayer()

            class SimpleScorer:
                """Scorer stub that gives every output the same score."""

                name = "quality"

                def score(self, input_data, output, context):
                    return 0.85

            result = asyncio.run(
                run_sweep(
                    db, experiment, adapter, cache,
                    scorers=[SimpleScorer()],
                )
            )

            assert result.completed_runs == 1
            scores = db.query(Score).all()
            assert len(scores) == 1
            assert scores[0].value == pytest.approx(0.85)

    def test_sweep_failed_run_doesnt_stop_sweep(self):
        engine = _engine()
        with _session(engine) as db:
            experiment = _make_experiment(
                db,
                parameter_space={
                    "type": "grid",
                    "params": {"params.fail": [False, True]},
                },
                pipeline_stages=[
                    {"prompt_template": "Hello", "model": "m"},
                ],
            )
            cache = ResponseCacheLayer()

            class SelectiveAdapter(BaseAdapter):
                """Adapter that raises whenever ``params["fail"]`` is truthy."""

                async def complete(self, prompt, model, params):
                    if params.get("fail"):
                        raise RuntimeError("Intentional failure")
                    return AdapterResponse(
                        text="ok", tokens_in=5, tokens_out=3,
                        latency_ms=10.0, model=model,
                    )

                async def list_models(self):
                    return []

                async def test_connection(self):
                    return True

            result = asyncio.run(
                run_sweep(db, experiment, SelectiveAdapter(), cache, max_concurrent=1)
            )

            assert result.total_runs == 2
            assert result.completed_runs == 1
            assert result.failed_runs == 1
            assert result.stopped_reason == "completed"

    def test_guided_sweep(self):
        engine = _engine()
        with _session(engine) as db:
            experiment = _make_experiment(
                db,
                parameter_space={
                    "type": "guided",
                    "params": {"params.temperature": [0.1, 0.3, 0.5, 0.7, 0.9]},
                    "n_trials": 5,
                    "top_k": 2,
                    "explore_ratio": 0.4,
                },
                pipeline_stages=[
                    {"prompt_template": "Hello", "model": "m", "params": {}},
                ],
            )
            adapter = MockAdapter()
            cache = ResponseCacheLayer()

            result = asyncio.run(run_sweep(db, experiment, adapter, cache))

            assert result.total_runs == 5
            assert result.completed_runs == 5

    def test_unknown_sweep_type_raises(self):
        engine = _engine()
        with _session(engine) as db:
            experiment = _make_experiment(
                db,
                parameter_space={"type": "bayesian", "params": {}},
            )
            adapter = MockAdapter()
            cache = ResponseCacheLayer()

            with pytest.raises(ValueError, match="Unknown sweep type"):
                asyncio.run(run_sweep(db, experiment, adapter, cache))

    def test_sweep_tokens_tracked(self):
        engine = _engine()
        with _session(engine) as db:
            experiment = _make_experiment(
                db,
                parameter_space={
                    "type": "grid",
                    "params": {"model": ["a", "b"]},
                },
                pipeline_stages=[
                    {"prompt_template": "Hello", "model": "m", "params": {}},
                ],
            )
            adapter = MockAdapter(tokens_in=100, tokens_out=50)
            cache = ResponseCacheLayer()

            result = asyncio.run(run_sweep(db, experiment, adapter, cache))

            assert result.total_tokens_in == 200  # 100 × 2
            assert result.total_tokens_out == 100  # 50 × 2
|