From ba8cb7e2c669c3ec4ee645e3b4676e33ffa73343 Mon Sep 17 00:00:00 2001 From: John Lightner Date: Tue, 7 Apr 2026 02:53:30 -0500 Subject: [PATCH] MAESTRO: Implement sweep orchestration engine with grid, random, and guided sweep types Adds backend/engine/sweep.py with three sweep strategies: - GridSweep: exhaustive enumeration of all parameter combinations - RandomSweep: N random samples from parameter ranges (list, min/max, step) - GuidedSweep: top-K exploitation + random exploration from previous results Features: bounded parallelism via asyncio.Semaphore, token budget enforcement, Redis-based pause/resume/stop control flags, sweep-level event publishing. 36 tests in test_sweep.py covering config generation, helpers, and full sweep execution. --- Auto Run Docs/02a-backend-engine.md | 3 +- backend/engine/sweep.py | 528 +++++++++++++++++++++++ backend/tests/test_sweep.py | 636 ++++++++++++++++++++++++++++ 3 files changed, 1166 insertions(+), 1 deletion(-) create mode 100644 backend/engine/sweep.py create mode 100644 backend/tests/test_sweep.py diff --git a/Auto Run Docs/02a-backend-engine.md b/Auto Run Docs/02a-backend-engine.md index 539da74..df5697d 100644 --- a/Auto Run Docs/02a-backend-engine.md +++ b/Auto Run Docs/02a-backend-engine.md @@ -11,7 +11,8 @@ Implement the core experiment execution engine: LLM adapters, response caching, - [x] Implement backend/engine/runner.py for individual run execution. The run_single function should: (1) iterate through pipeline stages, (2) render prompt templates with Jinja2 (allowing previous stage output as context), (3) check cache before calling LLM, (4) call the LLM adapter if cache miss, (5) store response in cache, (6) create StageResult records, (7) run all configured scorers, (8) create Score records, (9) update Run status and timing, (10) publish progress events via Redis pub/sub (or in-process event bus). -- [ ] Implement backend/engine/sweep.py for sweep orchestration. Support three sweep types: GridSweep (enumerate all combinations from parameter_space), RandomSweep (sample N random configs from parameter ranges), GuidedSweep (use previous results to inform next config — start with top-K exploitation + random exploration). The sweep runner should: respect MAX_CONCURRENT_RUNS for parallelism, track token budget and stop at MAX_TOKENS_PER_SWEEP, emit WebSocket events for each run completion, handle pause/resume/stop via Redis flags. +- [x] Implement backend/engine/sweep.py for sweep orchestration. Support three sweep types: GridSweep (enumerate all combinations from parameter_space), RandomSweep (sample N random configs from parameter ranges), GuidedSweep (use previous results to inform next config — start with top-K exploitation + random exploration). The sweep runner should: respect MAX_CONCURRENT_RUNS for parallelism, track token budget and stop at MAX_TOKENS_PER_SWEEP, emit WebSocket events for each run completion, handle pause/resume/stop via Redis flags. + - [ ] Implement backend/engine/scorers/base.py defining the BaseScorer abstract class with: name property, score(input_data, output, context) → float (0.0 to 1.0), and an optional async variant. The context dict should include the experiment config, stage results, and any reference data. diff --git a/backend/engine/sweep.py b/backend/engine/sweep.py new file mode 100644 index 0000000..b6d6786 --- /dev/null +++ b/backend/engine/sweep.py @@ -0,0 +1,528 @@ +"""Sweep orchestration for PromptLooper. + +Supports three sweep types: + - GridSweep: enumerate all combinations from parameter_space + - RandomSweep: sample N random configs from parameter ranges + - GuidedSweep: use previous results to pick next configs (top-K exploitation + exploration) + +The sweep runner manages parallelism, token budgets, pause/resume/stop via +Redis flags, and emits events for each run completion. +""" + +import asyncio +import itertools +import json +import logging +import random +import uuid +from dataclasses import dataclass, field +from typing import Any + +from sqlalchemy.orm import Session + +from engine.adapters.base import BaseAdapter +from engine.cache import ResponseCacheLayer, compute_config_hash +from engine.runner import EventBus, run_single +from models import Experiment, ExperimentStatus, Run, RunStatus, Score + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Sweep configuration dataclasses +# --------------------------------------------------------------------------- + + +@dataclass +class SweepConfig: + """Parsed sweep configuration from experiment.parameter_space.""" + + sweep_type: str # "grid", "random", "guided" + params: dict[str, Any] = field(default_factory=dict) + n_trials: int = 100 # for random/guided + top_k: int = 5 # for guided: exploit top K results + explore_ratio: float = 0.3 # for guided: fraction of random exploration + + +@dataclass +class SweepResult: + """Summary returned after a sweep completes or is stopped.""" + + experiment_id: uuid.UUID + total_runs: int + completed_runs: int + failed_runs: int + total_tokens_in: int + total_tokens_out: int + stopped_reason: str | None = None # "completed", "stopped", "paused", "token_budget" + + +# --------------------------------------------------------------------------- +# Redis flag keys for sweep control +# --------------------------------------------------------------------------- + +_SWEEP_STOP_KEY = "promptlooper:sweep:{experiment_id}:stop" +_SWEEP_PAUSE_KEY = "promptlooper:sweep:{experiment_id}:pause" + + +def _stop_key(experiment_id: uuid.UUID) -> str: + return _SWEEP_STOP_KEY.format(experiment_id=experiment_id) + + +def _pause_key(experiment_id: uuid.UUID) -> str: + return _SWEEP_PAUSE_KEY.format(experiment_id=experiment_id) + + +# --------------------------------------------------------------------------- +# Configuration generation +# --------------------------------------------------------------------------- + + +def generate_grid_configs( + base_config: dict[str, Any], + param_space: dict[str, list], +) -> list[dict[str, Any]]: + """Enumerate all combinations from parameter_space grid values. + + Args: + base_config: Base run config (pipeline_stages, input_data, etc.) + param_space: Dict mapping param names to lists of values. + Supports top-level keys "model" and nested "params.*" keys. + + Returns: + List of fully-formed run configs, one per combination. + """ + if not param_space: + return [base_config.copy()] + + keys = sorted(param_space.keys()) + value_lists = [param_space[k] for k in keys] + configs = [] + + for combo in itertools.product(*value_lists): + cfg = _deep_copy_config(base_config) + for key, value in zip(keys, combo): + _set_config_value(cfg, key, value) + configs.append(cfg) + + return configs + + +def generate_random_configs( + base_config: dict[str, Any], + param_space: dict[str, Any], + n_trials: int, +) -> list[dict[str, Any]]: + """Sample N random configs from parameter ranges. + + Each param_space value can be: + - A list: sample uniformly from the list + - A dict with "min"/"max": sample uniformly from the continuous range + - A dict with "min"/"max"/"step": sample from discrete steps + """ + configs = [] + for _ in range(n_trials): + cfg = _deep_copy_config(base_config) + for key, spec in param_space.items(): + value = _sample_param(spec) + _set_config_value(cfg, key, value) + configs.append(cfg) + return configs + + +def generate_guided_configs( + base_config: dict[str, Any], + param_space: dict[str, Any], + previous_results: list[dict[str, Any]], + n_trials: int, + top_k: int = 5, + explore_ratio: float = 0.3, +) -> list[dict[str, Any]]: + """Generate configs using top-K exploitation + random exploration. + + Args: + base_config: Base run config. + param_space: Parameter search space. + previous_results: List of dicts with "config" and "score" keys + from previous runs, sorted by score descending. + n_trials: Total number of configs to generate. + top_k: Number of top results to exploit. + explore_ratio: Fraction of trials devoted to random exploration. + + Returns: + List of run configs mixing exploitation and exploration. + """ + n_explore = max(1, int(n_trials * explore_ratio)) + n_exploit = n_trials - n_explore + + configs = [] + + # Exploitation: perturb top-K configs + if previous_results and n_exploit > 0: + top_results = previous_results[:top_k] + for i in range(n_exploit): + base_result = top_results[i % len(top_results)] + source_cfg = base_result.get("config", base_config) + cfg = _deep_copy_config(source_cfg) + # Perturb one random parameter + if param_space: + perturb_key = random.choice(list(param_space.keys())) + value = _sample_param(param_space[perturb_key]) + _set_config_value(cfg, perturb_key, value) + configs.append(cfg) + else: + # No previous results — fall back to all random + n_explore = n_trials + + # Exploration: random configs + configs.extend( + generate_random_configs(base_config, param_space, n_explore) + ) + + return configs[:n_trials] + + +# --------------------------------------------------------------------------- +# Sweep runner +# --------------------------------------------------------------------------- + + +async def run_sweep( + db: Session, + experiment: Experiment, + adapter: BaseAdapter, + cache: ResponseCacheLayer, + scorers: list[Any] | None = None, + event_bus: EventBus | None = None, + redis_client: Any | None = None, + max_concurrent: int = 4, + max_tokens: int = 0, +) -> SweepResult: + """Execute a full sweep for an experiment. + + Generates run configs based on experiment.parameter_space, creates Run + records, and executes them with bounded parallelism. + + Args: + db: SQLAlchemy session. + experiment: Experiment ORM object. + adapter: LLM adapter for completions. + cache: Response cache layer. + scorers: Optional scorer instances. + event_bus: Optional event publisher. + redis_client: Optional Redis client for pause/stop flags. + max_concurrent: Maximum parallel runs. + max_tokens: Token budget (0 = unlimited). + + Returns: + SweepResult summary. + """ + scorers = scorers or [] + sweep_config = _parse_sweep_config(experiment.parameter_space or {}) + base_config = _build_base_config(experiment) + + # Generate configs based on sweep type + if sweep_config.sweep_type == "grid": + configs = generate_grid_configs(base_config, sweep_config.params) + elif sweep_config.sweep_type == "random": + configs = generate_random_configs( + base_config, sweep_config.params, sweep_config.n_trials + ) + elif sweep_config.sweep_type == "guided": + # For guided, first gather any previous completed results + previous = _get_previous_results(db, experiment.id) + configs = generate_guided_configs( + base_config, + sweep_config.params, + previous, + sweep_config.n_trials, + sweep_config.top_k, + sweep_config.explore_ratio, + ) + else: + raise ValueError(f"Unknown sweep type: {sweep_config.sweep_type}") + + # Create Run records + runs = _create_run_records(db, experiment, configs) + + # Update experiment status + experiment.status = ExperimentStatus.running + db.commit() + + if event_bus: + event_bus.publish({ + "type": "sweep.started", + "experiment_id": str(experiment.id), + "total_runs": len(runs), + "sweep_type": sweep_config.sweep_type, + }) + + # Execute runs with bounded parallelism + result = await _execute_runs( + db=db, + experiment=experiment, + runs=runs, + adapter=adapter, + cache=cache, + scorers=scorers, + event_bus=event_bus, + redis_client=redis_client, + max_concurrent=max_concurrent, + max_tokens=max_tokens, + ) + + # Update experiment status + if result.stopped_reason == "paused": + experiment.status = ExperimentStatus.paused + else: + experiment.status = ExperimentStatus.completed + db.commit() + + if event_bus: + event_bus.publish({ + "type": "sweep.completed", + "experiment_id": str(experiment.id), + "total_runs": result.total_runs, + "completed_runs": result.completed_runs, + "failed_runs": result.failed_runs, + "stopped_reason": result.stopped_reason, + }) + + return result + + +async def _execute_runs( + db: Session, + experiment: Experiment, + runs: list[Run], + adapter: BaseAdapter, + cache: ResponseCacheLayer, + scorers: list[Any], + event_bus: EventBus | None, + redis_client: Any | None, + max_concurrent: int, + max_tokens: int, +) -> SweepResult: + """Execute runs with bounded concurrency, respecting stop/pause/token budget.""" + semaphore = asyncio.Semaphore(max_concurrent) + completed = 0 + failed = 0 + total_tokens_in = 0 + total_tokens_out = 0 + stopped_reason: str | None = "completed" + + async def _run_one(run: Run) -> None: + nonlocal completed, failed, total_tokens_in, total_tokens_out + async with semaphore: + try: + result = await run_single( + db, run, adapter, cache, scorers=scorers, event_bus=event_bus + ) + completed += 1 + total_tokens_in += result.tokens_in or 0 + total_tokens_out += result.tokens_out or 0 + + if event_bus: + event_bus.publish({ + "type": "sweep.run_completed", + "experiment_id": str(experiment.id), + "run_id": str(run.id), + "completed": completed, + "total": len(runs), + }) + except Exception: + failed += 1 + logger.warning( + "Sweep run %s failed", run.id, exc_info=True + ) + + pending_runs = list(runs) + while pending_runs: + # Check stop/pause flags + if redis_client is not None: + if redis_client.get(_stop_key(experiment.id)): + stopped_reason = "stopped" + # Mark remaining runs as failed + for r in pending_runs: + if r.status == RunStatus.pending: + r.status = RunStatus.failed + db.commit() + break + if redis_client.get(_pause_key(experiment.id)): + stopped_reason = "paused" + break + + # Check token budget + if max_tokens > 0 and (total_tokens_in + total_tokens_out) >= max_tokens: + stopped_reason = "token_budget" + break + + # Take a batch up to max_concurrent + batch_size = min(max_concurrent, len(pending_runs)) + batch = pending_runs[:batch_size] + pending_runs = pending_runs[batch_size:] + + # Run batch concurrently + await asyncio.gather(*[_run_one(run) for run in batch]) + + return SweepResult( + experiment_id=experiment.id, + total_runs=len(runs), + completed_runs=completed, + failed_runs=failed, + total_tokens_in=total_tokens_in, + total_tokens_out=total_tokens_out, + stopped_reason=stopped_reason, + ) + + +# --------------------------------------------------------------------------- +# Sweep control helpers +# --------------------------------------------------------------------------- + + +def request_stop(redis_client: Any, experiment_id: uuid.UUID) -> None: + """Set the stop flag for a running sweep.""" + redis_client.set(_stop_key(experiment_id), "1", ex=3600) + + +def request_pause(redis_client: Any, experiment_id: uuid.UUID) -> None: + """Set the pause flag for a running sweep.""" + redis_client.set(_pause_key(experiment_id), "1", ex=3600) + + +def request_resume(redis_client: Any, experiment_id: uuid.UUID) -> None: + """Clear the pause flag to allow a sweep to continue.""" + redis_client.delete(_pause_key(experiment_id)) + + +def clear_sweep_flags(redis_client: Any, experiment_id: uuid.UUID) -> None: + """Clear all control flags for an experiment.""" + redis_client.delete(_stop_key(experiment_id)) + redis_client.delete(_pause_key(experiment_id)) + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + + +def _parse_sweep_config(parameter_space: dict[str, Any]) -> SweepConfig: + """Parse experiment.parameter_space into a SweepConfig.""" + return SweepConfig( + sweep_type=parameter_space.get("type", "grid"), + params=parameter_space.get("params", {}), + n_trials=parameter_space.get("n_trials", 100), + top_k=parameter_space.get("top_k", 5), + explore_ratio=parameter_space.get("explore_ratio", 0.3), + ) + + +def _build_base_config(experiment: Experiment) -> dict[str, Any]: + """Build the base run config from experiment fields.""" + config: dict[str, Any] = {} + if experiment.pipeline_stages: + config["pipeline_stages"] = experiment.pipeline_stages + if experiment.sample_data: + config["input_data"] = experiment.sample_data + if experiment.scoring_config: + config["scoring_config"] = experiment.scoring_config + return config + + +def _create_run_records( + db: Session, + experiment: Experiment, + configs: list[dict[str, Any]], +) -> list[Run]: + """Create Run records for each config and persist them.""" + runs = [] + for cfg in configs: + config_hash = compute_config_hash( + cfg.get("prompt", cfg.get("pipeline_stages", [{}])[0].get("prompt_template", "")), + cfg.get("model", ""), + cfg.get("params", {}), + cfg.get("input_data"), + ) + run = Run( + experiment_id=experiment.id, + config_hash=config_hash, + config=cfg, + status=RunStatus.pending, + ) + db.add(run) + runs.append(run) + db.commit() + for run in runs: + db.refresh(run) + return runs + + +def _get_previous_results( + db: Session, experiment_id: uuid.UUID +) -> list[dict[str, Any]]: + """Get previous completed runs with their average scores, sorted by score desc.""" + runs = ( + db.query(Run) + .filter( + Run.experiment_id == experiment_id, + Run.status == RunStatus.completed, + ) + .all() + ) + + results = [] + for run in runs: + scores = db.query(Score).filter(Score.run_id == run.id).all() + avg_score = sum(s.value for s in scores) / len(scores) if scores else 0.0 + results.append({ + "config": run.config, + "score": avg_score, + "run_id": str(run.id), + }) + + results.sort(key=lambda r: r["score"], reverse=True) + return results + + +def _sample_param(spec: Any) -> Any: + """Sample a single parameter value from its specification. + + Spec can be: + - A list: uniform choice + - A dict with "min"/"max": uniform float + - A dict with "min"/"max"/"step": discrete steps + """ + if isinstance(spec, list): + return random.choice(spec) + if isinstance(spec, dict): + lo = spec.get("min", 0.0) + hi = spec.get("max", 1.0) + step = spec.get("step") + if step is not None: + # Discrete steps + steps = [] + v = lo + while v <= hi + 1e-9: + steps.append(round(v, 10)) + v += step + return random.choice(steps) if steps else lo + return random.uniform(lo, hi) + # Scalar fallback — return as-is + return spec + + +def _set_config_value(config: dict[str, Any], key: str, value: Any) -> None: + """Set a value in the config dict, supporting dotted keys like 'params.temperature'.""" + parts = key.split(".") + target = config + for part in parts[:-1]: + if part not in target: + target[part] = {} + target = target[part] + target[parts[-1]] = value + + +def _deep_copy_config(config: dict[str, Any]) -> dict[str, Any]: + """Deep-copy a config dict using JSON round-trip.""" + return json.loads(json.dumps(config, default=str)) diff --git a/backend/tests/test_sweep.py b/backend/tests/test_sweep.py new file mode 100644 index 0000000..a4f9cb2 --- /dev/null +++ b/backend/tests/test_sweep.py @@ -0,0 +1,636 @@ +"""Tests for sweep orchestration engine.""" + +import asyncio +import uuid +from typing import Any +from unittest.mock import MagicMock + +import pytest +from sqlalchemy import create_engine +from sqlalchemy.orm import Session + +from engine.adapters.base import AdapterResponse, BaseAdapter +from engine.cache import ResponseCacheLayer +from engine.runner import EventBus +from engine.sweep import ( + SweepConfig, + SweepResult, + _parse_sweep_config, + _sample_param, + _set_config_value, + generate_grid_configs, + generate_guided_configs, + generate_random_configs, + run_sweep, +) +from models import Base, Experiment, ExperimentStatus, Project, Run, RunStatus, Score, User + + +# --------------------------------------------------------------------------- +# Fixtures / helpers +# --------------------------------------------------------------------------- + + +def _engine(): + engine = create_engine("sqlite:///:memory:") + Base.metadata.create_all(engine) + return engine + + +def _session(engine): + return Session(engine) + + +def _make_experiment( + db: Session, + parameter_space: dict | None = None, + pipeline_stages: dict | None = None, + sample_data: dict | None = None, +) -> Experiment: + """Create a full User → Project → Experiment chain for testing.""" + user = User(username="testuser", password_hash="fakehash") + db.add(user) + db.flush() + + project = Project(name="Test Project", owner_id=user.id) + db.add(project) + db.flush() + + experiment = Experiment( + project_id=project.id, + name="Test Experiment", + parameter_space=parameter_space, + pipeline_stages=pipeline_stages, + sample_data=sample_data, + status=ExperimentStatus.draft, + ) + db.add(experiment) + db.commit() + db.refresh(experiment) + return experiment + + +class MockAdapter(BaseAdapter): + """Test adapter that returns a fixed response.""" + + def __init__(self, response_text: str = "mock response", tokens_in: int = 10, tokens_out: int = 5): + self.response_text = response_text + self.tokens_in = tokens_in + self.tokens_out = tokens_out + self.calls: list[tuple[str, str, dict]] = [] + + async def complete(self, prompt: str, model: str, params: dict[str, Any]) -> AdapterResponse: + self.calls.append((prompt, model, params)) + return AdapterResponse( + text=self.response_text, + tokens_in=self.tokens_in, + tokens_out=self.tokens_out, + latency_ms=42.0, + model=model, + ) + + async def list_models(self) -> list[str]: + return ["mock-model"] + + async def test_connection(self) -> bool: + return True + + +# --------------------------------------------------------------------------- +# Grid config generation +# --------------------------------------------------------------------------- + + +class TestGenerateGridConfigs: + def test_single_param(self): + base = {"prompt": "test", "model": "m", "params": {}} + configs = generate_grid_configs(base, {"params.temperature": [0.1, 0.5, 0.9]}) + assert len(configs) == 3 + temps = [c["params"]["temperature"] for c in configs] + assert sorted(temps) == [0.1, 0.5, 0.9] + + def test_two_params(self): + base = {"prompt": "test", "model": "m", "params": {}} + configs = generate_grid_configs( + base, + { + "model": ["gpt-4", "claude"], + "params.temperature": [0.1, 0.9], + }, + ) + # 2 models × 2 temperatures = 4 combos + assert len(configs) == 4 + + def test_three_params(self): + base = {"prompt": "test", "model": "m", "params": {}} + configs = generate_grid_configs( + base, + { + "model": ["a", "b"], + "params.temperature": [0.1, 0.9], + "params.max_tokens": [100, 200], + }, + ) + assert len(configs) == 8 # 2 × 2 × 2 + + def test_empty_param_space(self): + base = {"prompt": "test"} + configs = generate_grid_configs(base, {}) + assert len(configs) == 1 + assert configs[0]["prompt"] == "test" + + def test_base_config_not_mutated(self): + base = {"prompt": "test", "params": {"temperature": 0.5}} + generate_grid_configs(base, {"params.temperature": [0.1, 0.9]}) + assert base["params"]["temperature"] == 0.5 + + def test_model_override(self): + base = {"prompt": "test", "model": "original"} + configs = generate_grid_configs(base, {"model": ["a", "b"]}) + models = {c["model"] for c in configs} + assert models == {"a", "b"} + + +# --------------------------------------------------------------------------- +# Random config generation +# --------------------------------------------------------------------------- + + +class TestGenerateRandomConfigs: + def test_correct_count(self): + base = {"prompt": "test", "params": {}} + configs = generate_random_configs( + base, {"params.temperature": [0.1, 0.5, 0.9]}, n_trials=10 + ) + assert len(configs) == 10 + + def test_list_spec(self): + base = {"prompt": "test", "params": {}} + configs = generate_random_configs( + base, {"params.temperature": [0.1, 0.5, 0.9]}, n_trials=50 + ) + temps = {c["params"]["temperature"] for c in configs} + assert temps.issubset({0.1, 0.5, 0.9}) + + def test_range_spec(self): + base = {"prompt": "test", "params": {}} + configs = generate_random_configs( + base, {"params.temperature": {"min": 0.0, "max": 1.0}}, n_trials=20 + ) + for c in configs: + assert 0.0 <= c["params"]["temperature"] <= 1.0 + + def test_step_spec(self): + base = {"prompt": "test", "params": {}} + configs = generate_random_configs( + base, {"params.temperature": {"min": 0.0, "max": 1.0, "step": 0.5}}, n_trials=50 + ) + valid_values = {0.0, 0.5, 1.0} + for c in configs: + assert c["params"]["temperature"] in valid_values + + +# --------------------------------------------------------------------------- +# Guided config generation +# --------------------------------------------------------------------------- + + +class TestGenerateGuidedConfigs: + def test_with_previous_results(self): + base = {"prompt": "test", "model": "m", "params": {"temperature": 0.5}} + previous = [ + {"config": {"prompt": "test", "model": "m", "params": {"temperature": 0.7}}, "score": 0.9}, + {"config": {"prompt": "test", "model": "m", "params": {"temperature": 0.3}}, "score": 0.8}, + ] + configs = generate_guided_configs( + base, + {"params.temperature": [0.1, 0.3, 0.5, 0.7, 0.9]}, + previous, + n_trials=10, + top_k=2, + explore_ratio=0.3, + ) + assert len(configs) == 10 + + def test_no_previous_results_falls_back_to_random(self): + base = {"prompt": "test", "params": {}} + configs = generate_guided_configs( + base, + {"params.temperature": [0.1, 0.5, 0.9]}, + [], + n_trials=5, + ) + assert len(configs) == 5 + + def test_respects_n_trials_limit(self): + base = {"prompt": "test", "params": {}} + previous = [{"config": base, "score": 0.9}] * 20 + configs = generate_guided_configs( + base, + {"params.temperature": [0.1, 0.5, 0.9]}, + previous, + n_trials=7, + ) + assert len(configs) == 7 + + +# --------------------------------------------------------------------------- +# Helper functions +# --------------------------------------------------------------------------- + + +class TestSampleParam: + def test_list(self): + for _ in range(20): + v = _sample_param([1, 2, 3]) + assert v in [1, 2, 3] + + def test_range(self): + for _ in range(20): + v = _sample_param({"min": 0.0, "max": 1.0}) + assert 0.0 <= v <= 1.0 + + def test_step(self): + for _ in range(20): + v = _sample_param({"min": 0.0, "max": 1.0, "step": 0.25}) + assert v in [0.0, 0.25, 0.5, 0.75, 1.0] + + def test_scalar(self): + assert _sample_param(42) == 42 + assert _sample_param("gpt-4") == "gpt-4" + + +class TestSetConfigValue: + def test_simple_key(self): + cfg = {} + _set_config_value(cfg, "model", "gpt-4") + assert cfg["model"] == "gpt-4" + + def test_dotted_key(self): + cfg = {"params": {}} + _set_config_value(cfg, "params.temperature", 0.7) + assert cfg["params"]["temperature"] == 0.7 + + def test_creates_intermediate_dicts(self): + cfg = {} + _set_config_value(cfg, "params.temperature", 0.7) + assert cfg["params"]["temperature"] == 0.7 + + def test_deeply_nested(self): + cfg = {} + _set_config_value(cfg, "a.b.c", "deep") + assert cfg["a"]["b"]["c"] == "deep" + + +class TestParseSweepConfig: + def test_grid(self): + sc = _parse_sweep_config({"type": "grid", "params": {"model": ["a", "b"]}}) + assert sc.sweep_type == "grid" + assert sc.params == {"model": ["a", "b"]} + + def test_random(self): + sc = _parse_sweep_config({"type": "random", "params": {}, "n_trials": 50}) + assert sc.sweep_type == "random" + assert sc.n_trials == 50 + + def test_guided(self): + sc = _parse_sweep_config({ + "type": "guided", + "params": {}, + "n_trials": 20, + "top_k": 3, + "explore_ratio": 0.4, + }) + assert sc.sweep_type == "guided" + assert sc.top_k == 3 + assert sc.explore_ratio == 0.4 + + def test_defaults(self): + sc = _parse_sweep_config({}) + assert sc.sweep_type == "grid" + assert sc.n_trials == 100 + assert sc.top_k == 5 + assert sc.explore_ratio == 0.3 + + +# --------------------------------------------------------------------------- +# Full sweep execution +# --------------------------------------------------------------------------- + + +class TestRunSweep: + def test_grid_sweep_creates_and_runs_all(self): + engine = _engine() + with _session(engine) as db: + experiment = _make_experiment( + db, + parameter_space={ + "type": "grid", + "params": { + "model": ["model-a", "model-b"], + "params.temperature": [0.1, 0.9], + }, + }, + pipeline_stages=[ + {"prompt_template": "Hello", "model": "default", "params": {}}, + ], + ) + adapter = MockAdapter() + cache = ResponseCacheLayer() + + result = asyncio.get_event_loop().run_until_complete( + run_sweep(db, experiment, adapter, cache, max_concurrent=2) + ) + + assert result.total_runs == 4 # 2 models × 2 temps + assert result.completed_runs == 4 + assert result.failed_runs == 0 + assert result.stopped_reason == "completed" + + # Verify experiment status updated + db.refresh(experiment) + assert experiment.status == ExperimentStatus.completed + + # Verify Run records exist + runs = db.query(Run).filter_by(experiment_id=experiment.id).all() + assert len(runs) == 4 + assert all(r.status == RunStatus.completed for r in runs) + + def test_random_sweep(self): + engine = _engine() + with _session(engine) as db: + experiment = _make_experiment( + db, + parameter_space={ + "type": "random", + "params": {"params.temperature": [0.1, 0.5, 0.9]}, + "n_trials": 3, + }, + pipeline_stages=[ + {"prompt_template": "Hello", "model": "m", "params": {}}, + ], + ) + adapter = MockAdapter() + cache = ResponseCacheLayer() + + result = asyncio.get_event_loop().run_until_complete( + run_sweep(db, experiment, adapter, cache) + ) + + assert result.total_runs == 3 + assert result.completed_runs == 3 + + def test_sweep_events_published(self): + engine = _engine() + with _session(engine) as db: + experiment = _make_experiment( + db, + parameter_space={ + "type": "grid", + "params": {"model": ["a"]}, + }, + pipeline_stages=[ + {"prompt_template": "Hello", "model": "m", "params": {}}, + ], + ) + adapter = MockAdapter() + cache = ResponseCacheLayer() + events: list[dict] = [] + bus = EventBus() + bus.add_listener(lambda e: events.append(e)) + + asyncio.get_event_loop().run_until_complete( + run_sweep(db, experiment, adapter, cache, event_bus=bus) + ) + + event_types = [e["type"] for e in events] + assert "sweep.started" in event_types + assert "sweep.completed" in event_types + assert "run.started" in event_types + assert "run.completed" in event_types + + def test_sweep_stop_flag(self): + engine = _engine() + with _session(engine) as db: + experiment = _make_experiment( + db, + parameter_space={ + "type": "grid", + "params": {"model": ["a", "b", "c", "d"]}, + }, + pipeline_stages=[ + {"prompt_template": "Hello", "model": "m", "params": {}}, + ], + ) + # Mock Redis that always returns stop flag + mock_redis = MagicMock() + mock_redis.get = MagicMock(side_effect=lambda key: "1" if "stop" in key else None) + adapter = MockAdapter() + cache = ResponseCacheLayer() + + result = asyncio.get_event_loop().run_until_complete( + run_sweep( + db, experiment, adapter, cache, + redis_client=mock_redis, max_concurrent=1, + ) + ) + + assert result.stopped_reason == "stopped" + # Not all runs should have completed + assert result.completed_runs < result.total_runs + + def test_sweep_pause_flag(self): + engine = _engine() + with _session(engine) as db: + experiment = _make_experiment( + db, + parameter_space={ + "type": "grid", + "params": {"model": ["a", "b", "c"]}, + }, + pipeline_stages=[ + {"prompt_template": "Hello", "model": "m", "params": {}}, + ], + ) + mock_redis = MagicMock() + mock_redis.get = MagicMock(side_effect=lambda key: "1" if "pause" in key else None) + adapter = MockAdapter() + cache = ResponseCacheLayer() + + result = asyncio.get_event_loop().run_until_complete( + run_sweep( + db, experiment, adapter, cache, + redis_client=mock_redis, max_concurrent=1, + ) + ) + + assert result.stopped_reason == "paused" + db.refresh(experiment) + assert experiment.status == ExperimentStatus.paused + + def test_sweep_token_budget(self): + engine = _engine() + with _session(engine) as db: + experiment = _make_experiment( + db, + parameter_space={ + "type": "grid", + "params": {"model": ["a", "b", "c", "d", "e"]}, + }, + pipeline_stages=[ + {"prompt_template": "Hello", "model": "m", "params": {}}, + ], + ) + # Each run uses 10 tokens_in + 5 tokens_out = 15 + # Budget of 20 should stop after first batch + adapter = MockAdapter(tokens_in=10, tokens_out=5) + cache = ResponseCacheLayer() + + result = asyncio.get_event_loop().run_until_complete( + run_sweep( + db, experiment, adapter, cache, + max_concurrent=1, max_tokens=20, + ) + ) + + assert result.stopped_reason == "token_budget" + assert result.completed_runs < result.total_runs + + def test_sweep_with_scorers(self): + engine = _engine() + with _session(engine) as db: + experiment = _make_experiment( + db, + parameter_space={ + "type": "grid", + "params": {"model": ["a"]}, + }, + pipeline_stages=[ + {"prompt_template": "Hello", "model": "m", "params": {}}, + ], + ) + adapter = MockAdapter() + cache = ResponseCacheLayer() + + class SimpleScorer: + name = "quality" + def score(self, input_data, output, context): + return 0.85 + + result = asyncio.get_event_loop().run_until_complete( + run_sweep( + db, experiment, adapter, cache, + scorers=[SimpleScorer()], + ) + ) + + assert result.completed_runs == 1 + scores = db.query(Score).all() + assert len(scores) == 1 + assert scores[0].value == pytest.approx(0.85) + + def test_sweep_failed_run_doesnt_stop_sweep(self): + engine = _engine() + with _session(engine) as db: + experiment = _make_experiment( + db, + parameter_space={ + "type": "grid", + "params": {"params.fail": [False, True]}, + }, + pipeline_stages=[ + {"prompt_template": "Hello", "model": "m"}, + ], + ) + cache = ResponseCacheLayer() + call_count = 0 + + class SelectiveAdapter(BaseAdapter): + async def complete(self, prompt, model, params): + nonlocal call_count + call_count += 1 + if params.get("fail"): + raise RuntimeError("Intentional failure") + return AdapterResponse( + text="ok", tokens_in=5, tokens_out=3, + latency_ms=10.0, model=model, + ) + + async def list_models(self): + return [] + + async def test_connection(self): + return True + + result = asyncio.get_event_loop().run_until_complete( + run_sweep(db, experiment, SelectiveAdapter(), cache, max_concurrent=1) + ) + + assert result.total_runs == 2 + assert result.completed_runs == 1 + assert result.failed_runs == 1 + assert result.stopped_reason == "completed" + + def test_guided_sweep(self): + engine = _engine() + with _session(engine) as db: + experiment = _make_experiment( + db, + parameter_space={ + "type": "guided", + "params": {"params.temperature": [0.1, 0.3, 0.5, 0.7, 0.9]}, + "n_trials": 5, + "top_k": 2, + "explore_ratio": 0.4, + }, + pipeline_stages=[ + {"prompt_template": "Hello", "model": "m", "params": {}}, + ], + ) + adapter = MockAdapter() + cache = ResponseCacheLayer() + + result = asyncio.get_event_loop().run_until_complete( + run_sweep(db, experiment, adapter, cache) + ) + + assert result.total_runs == 5 + assert result.completed_runs == 5 + + def test_unknown_sweep_type_raises(self): + engine = _engine() + with _session(engine) as db: + experiment = _make_experiment( + db, + parameter_space={"type": "bayesian", "params": {}}, + ) + adapter = MockAdapter() + cache = ResponseCacheLayer() + + with pytest.raises(ValueError, match="Unknown sweep type"): + asyncio.get_event_loop().run_until_complete( + run_sweep(db, experiment, adapter, cache) + ) + + def test_sweep_tokens_tracked(self): + engine = _engine() + with _session(engine) as db: + experiment = _make_experiment( + db, + parameter_space={ + "type": "grid", + "params": {"model": ["a", "b"]}, + }, + pipeline_stages=[ + {"prompt_template": "Hello", "model": "m", "params": {}}, + ], + ) + adapter = MockAdapter(tokens_in=100, tokens_out=50) + cache = ResponseCacheLayer() + + result = asyncio.get_event_loop().run_until_complete( + run_sweep(db, experiment, adapter, cache) + ) + + assert result.total_tokens_in == 200 # 100 × 2 + assert result.total_tokens_out == 100 # 50 × 2