MAESTRO: Implement sweep orchestration engine with grid, random, and guided sweep types

Adds backend/engine/sweep.py with three sweep strategies:
- GridSweep: exhaustive enumeration of all parameter combinations
- RandomSweep: N random samples from parameter ranges (list, min/max, step)
- GuidedSweep: top-K exploitation + random exploration from previous results

Features: bounded parallelism via asyncio.Semaphore, token budget enforcement,
Redis-based pause/resume/stop control flags, sweep-level event publishing.
36 tests in test_sweep.py covering config generation, helpers, and full sweep execution.
This commit is contained in:
John Lightner 2026-04-07 02:53:30 -05:00
parent e8ce2f016b
commit ba8cb7e2c6
3 changed files with 1166 additions and 1 deletions

View file

@ -11,7 +11,8 @@ Implement the core experiment execution engine: LLM adapters, response caching,
- [x] Implement backend/engine/runner.py for individual run execution. The run_single function should: (1) iterate through pipeline stages, (2) render prompt templates with Jinja2 (allowing previous stage output as context), (3) check cache before calling LLM, (4) call the LLM adapter if cache miss, (5) store response in cache, (6) create StageResult records, (7) run all configured scorers, (8) create Score records, (9) update Run status and timing, (10) publish progress events via Redis pub/sub (or in-process event bus). - [x] Implement backend/engine/runner.py for individual run execution. The run_single function should: (1) iterate through pipeline stages, (2) render prompt templates with Jinja2 (allowing previous stage output as context), (3) check cache before calling LLM, (4) call the LLM adapter if cache miss, (5) store response in cache, (6) create StageResult records, (7) run all configured scorers, (8) create Score records, (9) update Run status and timing, (10) publish progress events via Redis pub/sub (or in-process event bus).
<!-- Completed: Implemented run_single with all 10 requirements, EventBus (Redis + in-process), Jinja2 templating. 21 tests in test_runner.py, all passing. --> <!-- Completed: Implemented run_single with all 10 requirements, EventBus (Redis + in-process), Jinja2 templating. 21 tests in test_runner.py, all passing. -->
- [ ] Implement backend/engine/sweep.py for sweep orchestration. Support three sweep types: GridSweep (enumerate all combinations from parameter_space), RandomSweep (sample N random configs from parameter ranges), GuidedSweep (use previous results to inform next config — start with top-K exploitation + random exploration). The sweep runner should: respect MAX_CONCURRENT_RUNS for parallelism, track token budget and stop at MAX_TOKENS_PER_SWEEP, emit WebSocket events for each run completion, handle pause/resume/stop via Redis flags. - [x] Implement backend/engine/sweep.py for sweep orchestration. Support three sweep types: GridSweep (enumerate all combinations from parameter_space), RandomSweep (sample N random configs from parameter ranges), GuidedSweep (use previous results to inform next config — start with top-K exploitation + random exploration). The sweep runner should: respect MAX_CONCURRENT_RUNS for parallelism, track token budget and stop at MAX_TOKENS_PER_SWEEP, emit WebSocket events for each run completion, handle pause/resume/stop via Redis flags.
<!-- Completed: Implemented all 3 sweep types (grid/random/guided), bounded parallelism via asyncio.Semaphore, token budget enforcement, Redis-based pause/resume/stop flags, sweep-level events. 36 tests in test_sweep.py, all passing. -->
- [ ] Implement backend/engine/scorers/base.py defining the BaseScorer abstract class with: name property, score(input_data, output, context) → float (0.0 to 1.0), and an optional async variant. The context dict should include the experiment config, stage results, and any reference data. - [ ] Implement backend/engine/scorers/base.py defining the BaseScorer abstract class with: name property, score(input_data, output, context) → float (0.0 to 1.0), and an optional async variant. The context dict should include the experiment config, stage results, and any reference data.

528
backend/engine/sweep.py Normal file
View file

@ -0,0 +1,528 @@
"""Sweep orchestration for PromptLooper.
Supports three sweep types:
- GridSweep: enumerate all combinations from parameter_space
- RandomSweep: sample N random configs from parameter ranges
- GuidedSweep: use previous results to pick next configs (top-K exploitation + exploration)
The sweep runner manages parallelism, token budgets, pause/resume/stop via
Redis flags, and emits events for each run completion.
"""
import asyncio
import itertools
import json
import logging
import random
import uuid
from dataclasses import dataclass, field
from typing import Any
from sqlalchemy.orm import Session
from engine.adapters.base import BaseAdapter
from engine.cache import ResponseCacheLayer, compute_config_hash
from engine.runner import EventBus, run_single
from models import Experiment, ExperimentStatus, Run, RunStatus, Score
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Sweep configuration dataclasses
# ---------------------------------------------------------------------------
@dataclass
class SweepConfig:
"""Parsed sweep configuration from experiment.parameter_space."""
sweep_type: str # "grid", "random", "guided"
params: dict[str, Any] = field(default_factory=dict)
n_trials: int = 100 # for random/guided
top_k: int = 5 # for guided: exploit top K results
explore_ratio: float = 0.3 # for guided: fraction of random exploration
@dataclass
class SweepResult:
"""Summary returned after a sweep completes or is stopped."""
experiment_id: uuid.UUID
total_runs: int
completed_runs: int
failed_runs: int
total_tokens_in: int
total_tokens_out: int
stopped_reason: str | None = None # "completed", "stopped", "paused", "token_budget"
# ---------------------------------------------------------------------------
# Redis flag keys for sweep control
# ---------------------------------------------------------------------------
_SWEEP_STOP_KEY = "promptlooper:sweep:{experiment_id}:stop"
_SWEEP_PAUSE_KEY = "promptlooper:sweep:{experiment_id}:pause"
def _stop_key(experiment_id: uuid.UUID) -> str:
return _SWEEP_STOP_KEY.format(experiment_id=experiment_id)
def _pause_key(experiment_id: uuid.UUID) -> str:
return _SWEEP_PAUSE_KEY.format(experiment_id=experiment_id)
# ---------------------------------------------------------------------------
# Configuration generation
# ---------------------------------------------------------------------------
def generate_grid_configs(
base_config: dict[str, Any],
param_space: dict[str, list],
) -> list[dict[str, Any]]:
"""Enumerate all combinations from parameter_space grid values.
Args:
base_config: Base run config (pipeline_stages, input_data, etc.)
param_space: Dict mapping param names to lists of values.
Supports top-level keys "model" and nested "params.*" keys.
Returns:
List of fully-formed run configs, one per combination.
"""
if not param_space:
return [base_config.copy()]
keys = sorted(param_space.keys())
value_lists = [param_space[k] for k in keys]
configs = []
for combo in itertools.product(*value_lists):
cfg = _deep_copy_config(base_config)
for key, value in zip(keys, combo):
_set_config_value(cfg, key, value)
configs.append(cfg)
return configs
def generate_random_configs(
base_config: dict[str, Any],
param_space: dict[str, Any],
n_trials: int,
) -> list[dict[str, Any]]:
"""Sample N random configs from parameter ranges.
Each param_space value can be:
- A list: sample uniformly from the list
- A dict with "min"/"max": sample uniformly from the continuous range
- A dict with "min"/"max"/"step": sample from discrete steps
"""
configs = []
for _ in range(n_trials):
cfg = _deep_copy_config(base_config)
for key, spec in param_space.items():
value = _sample_param(spec)
_set_config_value(cfg, key, value)
configs.append(cfg)
return configs
def generate_guided_configs(
base_config: dict[str, Any],
param_space: dict[str, Any],
previous_results: list[dict[str, Any]],
n_trials: int,
top_k: int = 5,
explore_ratio: float = 0.3,
) -> list[dict[str, Any]]:
"""Generate configs using top-K exploitation + random exploration.
Args:
base_config: Base run config.
param_space: Parameter search space.
previous_results: List of dicts with "config" and "score" keys
from previous runs, sorted by score descending.
n_trials: Total number of configs to generate.
top_k: Number of top results to exploit.
explore_ratio: Fraction of trials devoted to random exploration.
Returns:
List of run configs mixing exploitation and exploration.
"""
n_explore = max(1, int(n_trials * explore_ratio))
n_exploit = n_trials - n_explore
configs = []
# Exploitation: perturb top-K configs
if previous_results and n_exploit > 0:
top_results = previous_results[:top_k]
for i in range(n_exploit):
base_result = top_results[i % len(top_results)]
source_cfg = base_result.get("config", base_config)
cfg = _deep_copy_config(source_cfg)
# Perturb one random parameter
if param_space:
perturb_key = random.choice(list(param_space.keys()))
value = _sample_param(param_space[perturb_key])
_set_config_value(cfg, perturb_key, value)
configs.append(cfg)
else:
# No previous results — fall back to all random
n_explore = n_trials
# Exploration: random configs
configs.extend(
generate_random_configs(base_config, param_space, n_explore)
)
return configs[:n_trials]
# ---------------------------------------------------------------------------
# Sweep runner
# ---------------------------------------------------------------------------
async def run_sweep(
db: Session,
experiment: Experiment,
adapter: BaseAdapter,
cache: ResponseCacheLayer,
scorers: list[Any] | None = None,
event_bus: EventBus | None = None,
redis_client: Any | None = None,
max_concurrent: int = 4,
max_tokens: int = 0,
) -> SweepResult:
"""Execute a full sweep for an experiment.
Generates run configs based on experiment.parameter_space, creates Run
records, and executes them with bounded parallelism.
Args:
db: SQLAlchemy session.
experiment: Experiment ORM object.
adapter: LLM adapter for completions.
cache: Response cache layer.
scorers: Optional scorer instances.
event_bus: Optional event publisher.
redis_client: Optional Redis client for pause/stop flags.
max_concurrent: Maximum parallel runs.
max_tokens: Token budget (0 = unlimited).
Returns:
SweepResult summary.
"""
scorers = scorers or []
sweep_config = _parse_sweep_config(experiment.parameter_space or {})
base_config = _build_base_config(experiment)
# Generate configs based on sweep type
if sweep_config.sweep_type == "grid":
configs = generate_grid_configs(base_config, sweep_config.params)
elif sweep_config.sweep_type == "random":
configs = generate_random_configs(
base_config, sweep_config.params, sweep_config.n_trials
)
elif sweep_config.sweep_type == "guided":
# For guided, first gather any previous completed results
previous = _get_previous_results(db, experiment.id)
configs = generate_guided_configs(
base_config,
sweep_config.params,
previous,
sweep_config.n_trials,
sweep_config.top_k,
sweep_config.explore_ratio,
)
else:
raise ValueError(f"Unknown sweep type: {sweep_config.sweep_type}")
# Create Run records
runs = _create_run_records(db, experiment, configs)
# Update experiment status
experiment.status = ExperimentStatus.running
db.commit()
if event_bus:
event_bus.publish({
"type": "sweep.started",
"experiment_id": str(experiment.id),
"total_runs": len(runs),
"sweep_type": sweep_config.sweep_type,
})
# Execute runs with bounded parallelism
result = await _execute_runs(
db=db,
experiment=experiment,
runs=runs,
adapter=adapter,
cache=cache,
scorers=scorers,
event_bus=event_bus,
redis_client=redis_client,
max_concurrent=max_concurrent,
max_tokens=max_tokens,
)
# Update experiment status
if result.stopped_reason == "paused":
experiment.status = ExperimentStatus.paused
else:
experiment.status = ExperimentStatus.completed
db.commit()
if event_bus:
event_bus.publish({
"type": "sweep.completed",
"experiment_id": str(experiment.id),
"total_runs": result.total_runs,
"completed_runs": result.completed_runs,
"failed_runs": result.failed_runs,
"stopped_reason": result.stopped_reason,
})
return result
async def _execute_runs(
db: Session,
experiment: Experiment,
runs: list[Run],
adapter: BaseAdapter,
cache: ResponseCacheLayer,
scorers: list[Any],
event_bus: EventBus | None,
redis_client: Any | None,
max_concurrent: int,
max_tokens: int,
) -> SweepResult:
"""Execute runs with bounded concurrency, respecting stop/pause/token budget."""
semaphore = asyncio.Semaphore(max_concurrent)
completed = 0
failed = 0
total_tokens_in = 0
total_tokens_out = 0
stopped_reason: str | None = "completed"
async def _run_one(run: Run) -> None:
nonlocal completed, failed, total_tokens_in, total_tokens_out
async with semaphore:
try:
result = await run_single(
db, run, adapter, cache, scorers=scorers, event_bus=event_bus
)
completed += 1
total_tokens_in += result.tokens_in or 0
total_tokens_out += result.tokens_out or 0
if event_bus:
event_bus.publish({
"type": "sweep.run_completed",
"experiment_id": str(experiment.id),
"run_id": str(run.id),
"completed": completed,
"total": len(runs),
})
except Exception:
failed += 1
logger.warning(
"Sweep run %s failed", run.id, exc_info=True
)
pending_runs = list(runs)
while pending_runs:
# Check stop/pause flags
if redis_client is not None:
if redis_client.get(_stop_key(experiment.id)):
stopped_reason = "stopped"
# Mark remaining runs as failed
for r in pending_runs:
if r.status == RunStatus.pending:
r.status = RunStatus.failed
db.commit()
break
if redis_client.get(_pause_key(experiment.id)):
stopped_reason = "paused"
break
# Check token budget
if max_tokens > 0 and (total_tokens_in + total_tokens_out) >= max_tokens:
stopped_reason = "token_budget"
break
# Take a batch up to max_concurrent
batch_size = min(max_concurrent, len(pending_runs))
batch = pending_runs[:batch_size]
pending_runs = pending_runs[batch_size:]
# Run batch concurrently
await asyncio.gather(*[_run_one(run) for run in batch])
return SweepResult(
experiment_id=experiment.id,
total_runs=len(runs),
completed_runs=completed,
failed_runs=failed,
total_tokens_in=total_tokens_in,
total_tokens_out=total_tokens_out,
stopped_reason=stopped_reason,
)
# ---------------------------------------------------------------------------
# Sweep control helpers
# ---------------------------------------------------------------------------
def request_stop(redis_client: Any, experiment_id: uuid.UUID) -> None:
"""Set the stop flag for a running sweep."""
redis_client.set(_stop_key(experiment_id), "1", ex=3600)
def request_pause(redis_client: Any, experiment_id: uuid.UUID) -> None:
"""Set the pause flag for a running sweep."""
redis_client.set(_pause_key(experiment_id), "1", ex=3600)
def request_resume(redis_client: Any, experiment_id: uuid.UUID) -> None:
"""Clear the pause flag to allow a sweep to continue."""
redis_client.delete(_pause_key(experiment_id))
def clear_sweep_flags(redis_client: Any, experiment_id: uuid.UUID) -> None:
"""Clear all control flags for an experiment."""
redis_client.delete(_stop_key(experiment_id))
redis_client.delete(_pause_key(experiment_id))
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
def _parse_sweep_config(parameter_space: dict[str, Any]) -> SweepConfig:
"""Parse experiment.parameter_space into a SweepConfig."""
return SweepConfig(
sweep_type=parameter_space.get("type", "grid"),
params=parameter_space.get("params", {}),
n_trials=parameter_space.get("n_trials", 100),
top_k=parameter_space.get("top_k", 5),
explore_ratio=parameter_space.get("explore_ratio", 0.3),
)
def _build_base_config(experiment: Experiment) -> dict[str, Any]:
"""Build the base run config from experiment fields."""
config: dict[str, Any] = {}
if experiment.pipeline_stages:
config["pipeline_stages"] = experiment.pipeline_stages
if experiment.sample_data:
config["input_data"] = experiment.sample_data
if experiment.scoring_config:
config["scoring_config"] = experiment.scoring_config
return config
def _create_run_records(
db: Session,
experiment: Experiment,
configs: list[dict[str, Any]],
) -> list[Run]:
"""Create Run records for each config and persist them."""
runs = []
for cfg in configs:
config_hash = compute_config_hash(
cfg.get("prompt", cfg.get("pipeline_stages", [{}])[0].get("prompt_template", "")),
cfg.get("model", ""),
cfg.get("params", {}),
cfg.get("input_data"),
)
run = Run(
experiment_id=experiment.id,
config_hash=config_hash,
config=cfg,
status=RunStatus.pending,
)
db.add(run)
runs.append(run)
db.commit()
for run in runs:
db.refresh(run)
return runs
def _get_previous_results(
db: Session, experiment_id: uuid.UUID
) -> list[dict[str, Any]]:
"""Get previous completed runs with their average scores, sorted by score desc."""
runs = (
db.query(Run)
.filter(
Run.experiment_id == experiment_id,
Run.status == RunStatus.completed,
)
.all()
)
results = []
for run in runs:
scores = db.query(Score).filter(Score.run_id == run.id).all()
avg_score = sum(s.value for s in scores) / len(scores) if scores else 0.0
results.append({
"config": run.config,
"score": avg_score,
"run_id": str(run.id),
})
results.sort(key=lambda r: r["score"], reverse=True)
return results
def _sample_param(spec: Any) -> Any:
"""Sample a single parameter value from its specification.
Spec can be:
- A list: uniform choice
- A dict with "min"/"max": uniform float
- A dict with "min"/"max"/"step": discrete steps
"""
if isinstance(spec, list):
return random.choice(spec)
if isinstance(spec, dict):
lo = spec.get("min", 0.0)
hi = spec.get("max", 1.0)
step = spec.get("step")
if step is not None:
# Discrete steps
steps = []
v = lo
while v <= hi + 1e-9:
steps.append(round(v, 10))
v += step
return random.choice(steps) if steps else lo
return random.uniform(lo, hi)
# Scalar fallback — return as-is
return spec
def _set_config_value(config: dict[str, Any], key: str, value: Any) -> None:
"""Set a value in the config dict, supporting dotted keys like 'params.temperature'."""
parts = key.split(".")
target = config
for part in parts[:-1]:
if part not in target:
target[part] = {}
target = target[part]
target[parts[-1]] = value
def _deep_copy_config(config: dict[str, Any]) -> dict[str, Any]:
"""Deep-copy a config dict using JSON round-trip."""
return json.loads(json.dumps(config, default=str))

636
backend/tests/test_sweep.py Normal file
View file

@ -0,0 +1,636 @@
"""Tests for sweep orchestration engine."""
import asyncio
import uuid
from typing import Any
from unittest.mock import MagicMock
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import Session
from engine.adapters.base import AdapterResponse, BaseAdapter
from engine.cache import ResponseCacheLayer
from engine.runner import EventBus
from engine.sweep import (
SweepConfig,
SweepResult,
_parse_sweep_config,
_sample_param,
_set_config_value,
generate_grid_configs,
generate_guided_configs,
generate_random_configs,
run_sweep,
)
from models import Base, Experiment, ExperimentStatus, Project, Run, RunStatus, Score, User
# ---------------------------------------------------------------------------
# Fixtures / helpers
# ---------------------------------------------------------------------------
def _engine():
engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)
return engine
def _session(engine):
return Session(engine)
def _make_experiment(
db: Session,
parameter_space: dict | None = None,
pipeline_stages: dict | None = None,
sample_data: dict | None = None,
) -> Experiment:
"""Create a full User → Project → Experiment chain for testing."""
user = User(username="testuser", password_hash="fakehash")
db.add(user)
db.flush()
project = Project(name="Test Project", owner_id=user.id)
db.add(project)
db.flush()
experiment = Experiment(
project_id=project.id,
name="Test Experiment",
parameter_space=parameter_space,
pipeline_stages=pipeline_stages,
sample_data=sample_data,
status=ExperimentStatus.draft,
)
db.add(experiment)
db.commit()
db.refresh(experiment)
return experiment
class MockAdapter(BaseAdapter):
"""Test adapter that returns a fixed response."""
def __init__(self, response_text: str = "mock response", tokens_in: int = 10, tokens_out: int = 5):
self.response_text = response_text
self.tokens_in = tokens_in
self.tokens_out = tokens_out
self.calls: list[tuple[str, str, dict]] = []
async def complete(self, prompt: str, model: str, params: dict[str, Any]) -> AdapterResponse:
self.calls.append((prompt, model, params))
return AdapterResponse(
text=self.response_text,
tokens_in=self.tokens_in,
tokens_out=self.tokens_out,
latency_ms=42.0,
model=model,
)
async def list_models(self) -> list[str]:
return ["mock-model"]
async def test_connection(self) -> bool:
return True
# ---------------------------------------------------------------------------
# Grid config generation
# ---------------------------------------------------------------------------
class TestGenerateGridConfigs:
def test_single_param(self):
base = {"prompt": "test", "model": "m", "params": {}}
configs = generate_grid_configs(base, {"params.temperature": [0.1, 0.5, 0.9]})
assert len(configs) == 3
temps = [c["params"]["temperature"] for c in configs]
assert sorted(temps) == [0.1, 0.5, 0.9]
def test_two_params(self):
base = {"prompt": "test", "model": "m", "params": {}}
configs = generate_grid_configs(
base,
{
"model": ["gpt-4", "claude"],
"params.temperature": [0.1, 0.9],
},
)
# 2 models × 2 temperatures = 4 combos
assert len(configs) == 4
def test_three_params(self):
base = {"prompt": "test", "model": "m", "params": {}}
configs = generate_grid_configs(
base,
{
"model": ["a", "b"],
"params.temperature": [0.1, 0.9],
"params.max_tokens": [100, 200],
},
)
assert len(configs) == 8 # 2 × 2 × 2
def test_empty_param_space(self):
base = {"prompt": "test"}
configs = generate_grid_configs(base, {})
assert len(configs) == 1
assert configs[0]["prompt"] == "test"
def test_base_config_not_mutated(self):
base = {"prompt": "test", "params": {"temperature": 0.5}}
generate_grid_configs(base, {"params.temperature": [0.1, 0.9]})
assert base["params"]["temperature"] == 0.5
def test_model_override(self):
base = {"prompt": "test", "model": "original"}
configs = generate_grid_configs(base, {"model": ["a", "b"]})
models = {c["model"] for c in configs}
assert models == {"a", "b"}
# ---------------------------------------------------------------------------
# Random config generation
# ---------------------------------------------------------------------------
class TestGenerateRandomConfigs:
def test_correct_count(self):
base = {"prompt": "test", "params": {}}
configs = generate_random_configs(
base, {"params.temperature": [0.1, 0.5, 0.9]}, n_trials=10
)
assert len(configs) == 10
def test_list_spec(self):
base = {"prompt": "test", "params": {}}
configs = generate_random_configs(
base, {"params.temperature": [0.1, 0.5, 0.9]}, n_trials=50
)
temps = {c["params"]["temperature"] for c in configs}
assert temps.issubset({0.1, 0.5, 0.9})
def test_range_spec(self):
base = {"prompt": "test", "params": {}}
configs = generate_random_configs(
base, {"params.temperature": {"min": 0.0, "max": 1.0}}, n_trials=20
)
for c in configs:
assert 0.0 <= c["params"]["temperature"] <= 1.0
def test_step_spec(self):
base = {"prompt": "test", "params": {}}
configs = generate_random_configs(
base, {"params.temperature": {"min": 0.0, "max": 1.0, "step": 0.5}}, n_trials=50
)
valid_values = {0.0, 0.5, 1.0}
for c in configs:
assert c["params"]["temperature"] in valid_values
# ---------------------------------------------------------------------------
# Guided config generation
# ---------------------------------------------------------------------------
class TestGenerateGuidedConfigs:
def test_with_previous_results(self):
base = {"prompt": "test", "model": "m", "params": {"temperature": 0.5}}
previous = [
{"config": {"prompt": "test", "model": "m", "params": {"temperature": 0.7}}, "score": 0.9},
{"config": {"prompt": "test", "model": "m", "params": {"temperature": 0.3}}, "score": 0.8},
]
configs = generate_guided_configs(
base,
{"params.temperature": [0.1, 0.3, 0.5, 0.7, 0.9]},
previous,
n_trials=10,
top_k=2,
explore_ratio=0.3,
)
assert len(configs) == 10
def test_no_previous_results_falls_back_to_random(self):
base = {"prompt": "test", "params": {}}
configs = generate_guided_configs(
base,
{"params.temperature": [0.1, 0.5, 0.9]},
[],
n_trials=5,
)
assert len(configs) == 5
def test_respects_n_trials_limit(self):
base = {"prompt": "test", "params": {}}
previous = [{"config": base, "score": 0.9}] * 20
configs = generate_guided_configs(
base,
{"params.temperature": [0.1, 0.5, 0.9]},
previous,
n_trials=7,
)
assert len(configs) == 7
# ---------------------------------------------------------------------------
# Helper functions
# ---------------------------------------------------------------------------
class TestSampleParam:
def test_list(self):
for _ in range(20):
v = _sample_param([1, 2, 3])
assert v in [1, 2, 3]
def test_range(self):
for _ in range(20):
v = _sample_param({"min": 0.0, "max": 1.0})
assert 0.0 <= v <= 1.0
def test_step(self):
for _ in range(20):
v = _sample_param({"min": 0.0, "max": 1.0, "step": 0.25})
assert v in [0.0, 0.25, 0.5, 0.75, 1.0]
def test_scalar(self):
assert _sample_param(42) == 42
assert _sample_param("gpt-4") == "gpt-4"
class TestSetConfigValue:
def test_simple_key(self):
cfg = {}
_set_config_value(cfg, "model", "gpt-4")
assert cfg["model"] == "gpt-4"
def test_dotted_key(self):
cfg = {"params": {}}
_set_config_value(cfg, "params.temperature", 0.7)
assert cfg["params"]["temperature"] == 0.7
def test_creates_intermediate_dicts(self):
cfg = {}
_set_config_value(cfg, "params.temperature", 0.7)
assert cfg["params"]["temperature"] == 0.7
def test_deeply_nested(self):
cfg = {}
_set_config_value(cfg, "a.b.c", "deep")
assert cfg["a"]["b"]["c"] == "deep"
class TestParseSweepConfig:
def test_grid(self):
sc = _parse_sweep_config({"type": "grid", "params": {"model": ["a", "b"]}})
assert sc.sweep_type == "grid"
assert sc.params == {"model": ["a", "b"]}
def test_random(self):
sc = _parse_sweep_config({"type": "random", "params": {}, "n_trials": 50})
assert sc.sweep_type == "random"
assert sc.n_trials == 50
def test_guided(self):
sc = _parse_sweep_config({
"type": "guided",
"params": {},
"n_trials": 20,
"top_k": 3,
"explore_ratio": 0.4,
})
assert sc.sweep_type == "guided"
assert sc.top_k == 3
assert sc.explore_ratio == 0.4
def test_defaults(self):
sc = _parse_sweep_config({})
assert sc.sweep_type == "grid"
assert sc.n_trials == 100
assert sc.top_k == 5
assert sc.explore_ratio == 0.3
# ---------------------------------------------------------------------------
# Full sweep execution
# ---------------------------------------------------------------------------
class TestRunSweep:
def test_grid_sweep_creates_and_runs_all(self):
engine = _engine()
with _session(engine) as db:
experiment = _make_experiment(
db,
parameter_space={
"type": "grid",
"params": {
"model": ["model-a", "model-b"],
"params.temperature": [0.1, 0.9],
},
},
pipeline_stages=[
{"prompt_template": "Hello", "model": "default", "params": {}},
],
)
adapter = MockAdapter()
cache = ResponseCacheLayer()
result = asyncio.get_event_loop().run_until_complete(
run_sweep(db, experiment, adapter, cache, max_concurrent=2)
)
assert result.total_runs == 4 # 2 models × 2 temps
assert result.completed_runs == 4
assert result.failed_runs == 0
assert result.stopped_reason == "completed"
# Verify experiment status updated
db.refresh(experiment)
assert experiment.status == ExperimentStatus.completed
# Verify Run records exist
runs = db.query(Run).filter_by(experiment_id=experiment.id).all()
assert len(runs) == 4
assert all(r.status == RunStatus.completed for r in runs)
def test_random_sweep(self):
engine = _engine()
with _session(engine) as db:
experiment = _make_experiment(
db,
parameter_space={
"type": "random",
"params": {"params.temperature": [0.1, 0.5, 0.9]},
"n_trials": 3,
},
pipeline_stages=[
{"prompt_template": "Hello", "model": "m", "params": {}},
],
)
adapter = MockAdapter()
cache = ResponseCacheLayer()
result = asyncio.get_event_loop().run_until_complete(
run_sweep(db, experiment, adapter, cache)
)
assert result.total_runs == 3
assert result.completed_runs == 3
def test_sweep_events_published(self):
engine = _engine()
with _session(engine) as db:
experiment = _make_experiment(
db,
parameter_space={
"type": "grid",
"params": {"model": ["a"]},
},
pipeline_stages=[
{"prompt_template": "Hello", "model": "m", "params": {}},
],
)
adapter = MockAdapter()
cache = ResponseCacheLayer()
events: list[dict] = []
bus = EventBus()
bus.add_listener(lambda e: events.append(e))
asyncio.get_event_loop().run_until_complete(
run_sweep(db, experiment, adapter, cache, event_bus=bus)
)
event_types = [e["type"] for e in events]
assert "sweep.started" in event_types
assert "sweep.completed" in event_types
assert "run.started" in event_types
assert "run.completed" in event_types
def test_sweep_stop_flag(self):
engine = _engine()
with _session(engine) as db:
experiment = _make_experiment(
db,
parameter_space={
"type": "grid",
"params": {"model": ["a", "b", "c", "d"]},
},
pipeline_stages=[
{"prompt_template": "Hello", "model": "m", "params": {}},
],
)
# Mock Redis that always returns stop flag
mock_redis = MagicMock()
mock_redis.get = MagicMock(side_effect=lambda key: "1" if "stop" in key else None)
adapter = MockAdapter()
cache = ResponseCacheLayer()
result = asyncio.get_event_loop().run_until_complete(
run_sweep(
db, experiment, adapter, cache,
redis_client=mock_redis, max_concurrent=1,
)
)
assert result.stopped_reason == "stopped"
# Not all runs should have completed
assert result.completed_runs < result.total_runs
def test_sweep_pause_flag(self):
engine = _engine()
with _session(engine) as db:
experiment = _make_experiment(
db,
parameter_space={
"type": "grid",
"params": {"model": ["a", "b", "c"]},
},
pipeline_stages=[
{"prompt_template": "Hello", "model": "m", "params": {}},
],
)
mock_redis = MagicMock()
mock_redis.get = MagicMock(side_effect=lambda key: "1" if "pause" in key else None)
adapter = MockAdapter()
cache = ResponseCacheLayer()
result = asyncio.get_event_loop().run_until_complete(
run_sweep(
db, experiment, adapter, cache,
redis_client=mock_redis, max_concurrent=1,
)
)
assert result.stopped_reason == "paused"
db.refresh(experiment)
assert experiment.status == ExperimentStatus.paused
def test_sweep_token_budget(self):
engine = _engine()
with _session(engine) as db:
experiment = _make_experiment(
db,
parameter_space={
"type": "grid",
"params": {"model": ["a", "b", "c", "d", "e"]},
},
pipeline_stages=[
{"prompt_template": "Hello", "model": "m", "params": {}},
],
)
# Each run uses 10 tokens_in + 5 tokens_out = 15
# Budget of 20 should stop after first batch
adapter = MockAdapter(tokens_in=10, tokens_out=5)
cache = ResponseCacheLayer()
result = asyncio.get_event_loop().run_until_complete(
run_sweep(
db, experiment, adapter, cache,
max_concurrent=1, max_tokens=20,
)
)
assert result.stopped_reason == "token_budget"
assert result.completed_runs < result.total_runs
def test_sweep_with_scorers(self):
engine = _engine()
with _session(engine) as db:
experiment = _make_experiment(
db,
parameter_space={
"type": "grid",
"params": {"model": ["a"]},
},
pipeline_stages=[
{"prompt_template": "Hello", "model": "m", "params": {}},
],
)
adapter = MockAdapter()
cache = ResponseCacheLayer()
class SimpleScorer:
name = "quality"
def score(self, input_data, output, context):
return 0.85
result = asyncio.get_event_loop().run_until_complete(
run_sweep(
db, experiment, adapter, cache,
scorers=[SimpleScorer()],
)
)
assert result.completed_runs == 1
scores = db.query(Score).all()
assert len(scores) == 1
assert scores[0].value == pytest.approx(0.85)
def test_sweep_failed_run_doesnt_stop_sweep(self):
engine = _engine()
with _session(engine) as db:
experiment = _make_experiment(
db,
parameter_space={
"type": "grid",
"params": {"params.fail": [False, True]},
},
pipeline_stages=[
{"prompt_template": "Hello", "model": "m"},
],
)
cache = ResponseCacheLayer()
call_count = 0
class SelectiveAdapter(BaseAdapter):
async def complete(self, prompt, model, params):
nonlocal call_count
call_count += 1
if params.get("fail"):
raise RuntimeError("Intentional failure")
return AdapterResponse(
text="ok", tokens_in=5, tokens_out=3,
latency_ms=10.0, model=model,
)
async def list_models(self):
return []
async def test_connection(self):
return True
result = asyncio.get_event_loop().run_until_complete(
run_sweep(db, experiment, SelectiveAdapter(), cache, max_concurrent=1)
)
assert result.total_runs == 2
assert result.completed_runs == 1
assert result.failed_runs == 1
assert result.stopped_reason == "completed"
def test_guided_sweep(self):
engine = _engine()
with _session(engine) as db:
experiment = _make_experiment(
db,
parameter_space={
"type": "guided",
"params": {"params.temperature": [0.1, 0.3, 0.5, 0.7, 0.9]},
"n_trials": 5,
"top_k": 2,
"explore_ratio": 0.4,
},
pipeline_stages=[
{"prompt_template": "Hello", "model": "m", "params": {}},
],
)
adapter = MockAdapter()
cache = ResponseCacheLayer()
result = asyncio.get_event_loop().run_until_complete(
run_sweep(db, experiment, adapter, cache)
)
assert result.total_runs == 5
assert result.completed_runs == 5
def test_unknown_sweep_type_raises(self):
engine = _engine()
with _session(engine) as db:
experiment = _make_experiment(
db,
parameter_space={"type": "bayesian", "params": {}},
)
adapter = MockAdapter()
cache = ResponseCacheLayer()
with pytest.raises(ValueError, match="Unknown sweep type"):
asyncio.get_event_loop().run_until_complete(
run_sweep(db, experiment, adapter, cache)
)
def test_sweep_tokens_tracked(self):
engine = _engine()
with _session(engine) as db:
experiment = _make_experiment(
db,
parameter_space={
"type": "grid",
"params": {"model": ["a", "b"]},
},
pipeline_stages=[
{"prompt_template": "Hello", "model": "m", "params": {}},
],
)
adapter = MockAdapter(tokens_in=100, tokens_out=50)
cache = ResponseCacheLayer()
result = asyncio.get_event_loop().run_until_complete(
run_sweep(db, experiment, adapter, cache)
)
assert result.total_tokens_in == 200 # 100 × 2
assert result.total_tokens_out == 100 # 50 × 2