"""Tests for sweep orchestration engine."""

import asyncio
import uuid
from typing import Any
from unittest.mock import MagicMock

import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import Session

from engine.adapters.base import AdapterResponse, BaseAdapter
from engine.cache import ResponseCacheLayer
from engine.runner import EventBus
from engine.sweep import (
    SweepConfig,
    SweepResult,
    _parse_sweep_config,
    _sample_param,
    _set_config_value,
    generate_grid_configs,
    generate_guided_configs,
    generate_random_configs,
    run_sweep,
)
from models import Base, Experiment, ExperimentStatus, Project, Run, RunStatus, Score, User


# ---------------------------------------------------------------------------
# Fixtures / helpers
# ---------------------------------------------------------------------------


def _engine():
    """Return a fresh in-memory SQLite engine with all model tables created."""
    engine = create_engine("sqlite:///:memory:")
    Base.metadata.create_all(engine)
    return engine


def _session(engine):
    """Open an ORM ``Session`` bound to *engine* (callers use it as a context manager)."""
    return Session(engine)


def _make_experiment(
    db: Session,
    parameter_space: dict | None = None,
    pipeline_stages: list[dict] | None = None,
    sample_data: dict | None = None,
) -> Experiment:
    """Create a full User → Project → Experiment chain for testing.

    Args:
        db: Session used to persist the rows; it is committed before returning.
        parameter_space: Sweep definition stored on the experiment, e.g.
            ``{"type": "grid", "params": {...}}``.
        pipeline_stages: List of stage dicts (``prompt_template``, ``model``,
            optionally ``params``), as the sweep tests pass it.
        sample_data: Optional sample data stored on the experiment.

    Returns:
        The committed and refreshed ``Experiment``, in ``draft`` status.
    """
    user = User(username="testuser", password_hash="fakehash")
    db.add(user)
    db.flush()  # assigns user.id for the project's foreign key

    project = Project(name="Test Project", owner_id=user.id)
    db.add(project)
    db.flush()  # assigns project.id for the experiment's foreign key

    experiment = Experiment(
        project_id=project.id,
        name="Test Experiment",
        parameter_space=parameter_space,
        pipeline_stages=pipeline_stages,
        sample_data=sample_data,
        status=ExperimentStatus.draft,
    )
    db.add(experiment)
    db.commit()
    db.refresh(experiment)
    return experiment


class MockAdapter(BaseAdapter):
    """Test adapter that returns a fixed response and records every call."""

    def __init__(self, response_text: str = "mock response", tokens_in: int = 10, tokens_out: int = 5):
        self.response_text = response_text
        self.tokens_in = tokens_in
        self.tokens_out = tokens_out
        # One (prompt, model, params) tuple per complete() call, in call order.
        self.calls: list[tuple[str, str, dict]] = []

    async def complete(self, prompt: str, model: str, params: dict[str, Any]) -> AdapterResponse:
        """Record the call and return the canned response, echoing *model*."""
        self.calls.append((prompt, model, params))
        return AdapterResponse(
            text=self.response_text,
            tokens_in=self.tokens_in,
            tokens_out=self.tokens_out,
            latency_ms=42.0,
            model=model,
        )

    async def list_models(self) -> list[str]:
        return ["mock-model"]

    async def test_connection(self) -> bool:
        return True


# ---------------------------------------------------------------------------
# Grid config generation
# ---------------------------------------------------------------------------


class TestGenerateGridConfigs:
    """``generate_grid_configs`` must emit the full Cartesian product of the space."""

    def test_single_param(self):
        base = {"prompt": "test", "model": "m", "params": {}}
        configs = generate_grid_configs(base, {"params.temperature": [0.1, 0.5, 0.9]})
        assert len(configs) == 3
        temps = [c["params"]["temperature"] for c in configs]
        assert sorted(temps) == [0.1, 0.5, 0.9]

    def test_two_params(self):
        base = {"prompt": "test", "model": "m", "params": {}}
        configs = generate_grid_configs(
            base,
            {
                "model": ["gpt-4", "claude"],
                "params.temperature": [0.1, 0.9],
            },
        )
        # 2 models × 2 temperatures = 4 combos, each appearing exactly once
        assert len(configs) == 4
        combos = {(c["model"], c["params"]["temperature"]) for c in configs}
        assert combos == {
            ("gpt-4", 0.1),
            ("gpt-4", 0.9),
            ("claude", 0.1),
            ("claude", 0.9),
        }

    def test_three_params(self):
        base = {"prompt": "test", "model": "m", "params": {}}
        configs = generate_grid_configs(
            base,
            {
                "model": ["a", "b"],
                "params.temperature": [0.1, 0.9],
                "params.max_tokens": [100, 200],
            },
        )
        assert len(configs) == 8  # 2 × 2 × 2
        # Distinct tuples also catch configs that share one nested params dict.
        combos = {
            (c["model"], c["params"]["temperature"], c["params"]["max_tokens"])
            for c in configs
        }
        assert len(combos) == 8

    def test_empty_param_space(self):
        base = {"prompt": "test"}
        configs = generate_grid_configs(base, {})
        assert len(configs) == 1
        assert configs[0]["prompt"] == "test"

    def test_base_config_not_mutated(self):
        base = {"prompt": "test", "params": {"temperature": 0.5}}
        generate_grid_configs(base, {"params.temperature": [0.1, 0.9]})
        assert base["params"]["temperature"] == 0.5

    def test_model_override(self):
        base = {"prompt": "test", "model": "original"}
        configs = generate_grid_configs(base, {"model": ["a", "b"]})
        models = {c["model"] for c in configs}
        assert models == {"a", "b"}
# ---------------------------------------------------------------------------
# Random config generation
# ---------------------------------------------------------------------------


class TestGenerateRandomConfigs:
    """``generate_random_configs``: trial count and per-spec sampling bounds."""

    def test_correct_count(self):
        base = {"prompt": "test", "params": {}}
        configs = generate_random_configs(
            base, {"params.temperature": [0.1, 0.5, 0.9]}, n_trials=10
        )
        assert len(configs) == 10

    def test_list_spec(self):
        base = {"prompt": "test", "params": {}}
        configs = generate_random_configs(
            base, {"params.temperature": [0.1, 0.5, 0.9]}, n_trials=50
        )
        temps = {c["params"]["temperature"] for c in configs}
        assert temps.issubset({0.1, 0.5, 0.9})

    def test_range_spec(self):
        base = {"prompt": "test", "params": {}}
        configs = generate_random_configs(
            base, {"params.temperature": {"min": 0.0, "max": 1.0}}, n_trials=20
        )
        for c in configs:
            assert 0.0 <= c["params"]["temperature"] <= 1.0

    def test_step_spec(self):
        base = {"prompt": "test", "params": {}}
        configs = generate_random_configs(
            base, {"params.temperature": {"min": 0.0, "max": 1.0, "step": 0.5}}, n_trials=50
        )
        valid_values = {0.0, 0.5, 1.0}
        for c in configs:
            assert c["params"]["temperature"] in valid_values


# ---------------------------------------------------------------------------
# Guided config generation
# ---------------------------------------------------------------------------


class TestGenerateGuidedConfigs:
    """``generate_guided_configs``: trial count with and without prior results."""

    def test_with_previous_results(self):
        base = {"prompt": "test", "model": "m", "params": {"temperature": 0.5}}
        previous = [
            {"config": {"prompt": "test", "model": "m", "params": {"temperature": 0.7}}, "score": 0.9},
            {"config": {"prompt": "test", "model": "m", "params": {"temperature": 0.3}}, "score": 0.8},
        ]
        configs = generate_guided_configs(
            base,
            {"params.temperature": [0.1, 0.3, 0.5, 0.7, 0.9]},
            previous,
            n_trials=10,
            top_k=2,
            explore_ratio=0.3,
        )
        assert len(configs) == 10

    def test_no_previous_results_falls_back_to_random(self):
        base = {"prompt": "test", "params": {}}
        configs = generate_guided_configs(
            base,
            {"params.temperature": [0.1, 0.5, 0.9]},
            [],
            n_trials=5,
        )
        assert len(configs) == 5

    def test_respects_n_trials_limit(self):
        base = {"prompt": "test", "params": {}}
        previous = [{"config": base, "score": 0.9}] * 20
        configs = generate_guided_configs(
            base,
            {"params.temperature": [0.1, 0.5, 0.9]},
            previous,
            n_trials=7,
        )
        assert len(configs) == 7


# ---------------------------------------------------------------------------
# Helper functions
# ---------------------------------------------------------------------------


class TestSampleParam:
    """``_sample_param`` for list, range, stepped-range, and scalar specs."""

    def test_list(self):
        for _ in range(20):
            v = _sample_param([1, 2, 3])
            assert v in [1, 2, 3]

    def test_range(self):
        for _ in range(20):
            v = _sample_param({"min": 0.0, "max": 1.0})
            assert 0.0 <= v <= 1.0

    def test_step(self):
        for _ in range(20):
            v = _sample_param({"min": 0.0, "max": 1.0, "step": 0.25})
            assert v in [0.0, 0.25, 0.5, 0.75, 1.0]

    def test_scalar(self):
        assert _sample_param(42) == 42
        assert _sample_param("gpt-4") == "gpt-4"


class TestSetConfigValue:
    """``_set_config_value`` with plain and dotted keys."""

    def test_simple_key(self):
        cfg = {}
        _set_config_value(cfg, "model", "gpt-4")
        assert cfg["model"] == "gpt-4"

    def test_dotted_key(self):
        cfg = {"params": {}}
        _set_config_value(cfg, "params.temperature", 0.7)
        assert cfg["params"]["temperature"] == 0.7

    def test_creates_intermediate_dicts(self):
        cfg = {}
        _set_config_value(cfg, "params.temperature", 0.7)
        assert cfg["params"]["temperature"] == 0.7

    def test_deeply_nested(self):
        cfg = {}
        _set_config_value(cfg, "a.b.c", "deep")
        assert cfg["a"]["b"]["c"] == "deep"


class TestParseSweepConfig:
    """``_parse_sweep_config`` field mapping and defaults."""

    def test_grid(self):
        sc = _parse_sweep_config({"type": "grid", "params": {"model": ["a", "b"]}})
        assert sc.sweep_type == "grid"
        assert sc.params == {"model": ["a", "b"]}

    def test_random(self):
        sc = _parse_sweep_config({"type": "random", "params": {}, "n_trials": 50})
        assert sc.sweep_type == "random"
        assert sc.n_trials == 50

    def test_guided(self):
        sc = _parse_sweep_config({
            "type": "guided",
            "params": {},
            "n_trials": 20,
            "top_k": 3,
            "explore_ratio": 0.4,
        })
        assert sc.sweep_type == "guided"
        assert sc.top_k == 3
        assert sc.explore_ratio == 0.4

    def test_defaults(self):
        sc = _parse_sweep_config({})
        assert sc.sweep_type == "grid"
        assert sc.n_trials == 100
        assert sc.top_k == 5
        assert sc.explore_ratio == 0.3


# ---------------------------------------------------------------------------
# Full sweep execution
# ---------------------------------------------------------------------------


class TestRunSweep:
    """End-to-end ``run_sweep`` against an in-memory database.

    Each test drives the coroutine with ``asyncio.run``, which gives it a fresh
    event loop. ``asyncio.get_event_loop()`` outside a running loop is
    deprecated since Python 3.12 and is an error on newer releases.
    """

    def test_grid_sweep_creates_and_runs_all(self):
        engine = _engine()
        with _session(engine) as db:
            experiment = _make_experiment(
                db,
                parameter_space={
                    "type": "grid",
                    "params": {
                        "model": ["model-a", "model-b"],
                        "params.temperature": [0.1, 0.9],
                    },
                },
                pipeline_stages=[
                    {"prompt_template": "Hello", "model": "default", "params": {}},
                ],
            )
            adapter = MockAdapter()
            cache = ResponseCacheLayer()

            result = asyncio.run(
                run_sweep(db, experiment, adapter, cache, max_concurrent=2)
            )

            assert result.total_runs == 4  # 2 models × 2 temps
            assert result.completed_runs == 4
            assert result.failed_runs == 0
            assert result.stopped_reason == "completed"

            # Verify experiment status updated
            db.refresh(experiment)
            assert experiment.status == ExperimentStatus.completed

            # Verify Run records exist
            runs = db.query(Run).filter_by(experiment_id=experiment.id).all()
            assert len(runs) == 4
            assert all(r.status == RunStatus.completed for r in runs)

    def test_random_sweep(self):
        engine = _engine()
        with _session(engine) as db:
            experiment = _make_experiment(
                db,
                parameter_space={
                    "type": "random",
                    "params": {"params.temperature": [0.1, 0.5, 0.9]},
                    "n_trials": 3,
                },
                pipeline_stages=[
                    {"prompt_template": "Hello", "model": "m", "params": {}},
                ],
            )
            adapter = MockAdapter()
            cache = ResponseCacheLayer()

            result = asyncio.run(run_sweep(db, experiment, adapter, cache))

            assert result.total_runs == 3
            assert result.completed_runs == 3

    def test_sweep_events_published(self):
        engine = _engine()
        with _session(engine) as db:
            experiment = _make_experiment(
                db,
                parameter_space={
                    "type": "grid",
                    "params": {"model": ["a"]},
                },
                pipeline_stages=[
                    {"prompt_template": "Hello", "model": "m", "params": {}},
                ],
            )
            adapter = MockAdapter()
            cache = ResponseCacheLayer()
            events: list[dict] = []
            bus = EventBus()
            bus.add_listener(events.append)

            asyncio.run(run_sweep(db, experiment, adapter, cache, event_bus=bus))

            event_types = [e["type"] for e in events]
            assert "sweep.started" in event_types
            assert "sweep.completed" in event_types
            assert "run.started" in event_types
            assert "run.completed" in event_types

    def test_sweep_stop_flag(self):
        engine = _engine()
        with _session(engine) as db:
            experiment = _make_experiment(
                db,
                parameter_space={
                    "type": "grid",
                    "params": {"model": ["a", "b", "c", "d"]},
                },
                pipeline_stages=[
                    {"prompt_template": "Hello", "model": "m", "params": {}},
                ],
            )

            # Mock Redis whose get() returns "1" for any key containing "stop",
            # so the stop flag is always set.
            mock_redis = MagicMock()
            mock_redis.get = MagicMock(side_effect=lambda key: "1" if "stop" in key else None)

            adapter = MockAdapter()
            cache = ResponseCacheLayer()

            result = asyncio.run(
                run_sweep(
                    db, experiment, adapter, cache,
                    redis_client=mock_redis,
                    max_concurrent=1,
                )
            )

            assert result.stopped_reason == "stopped"
            # Not all runs should have completed
            assert result.completed_runs < result.total_runs

    def test_sweep_pause_flag(self):
        engine = _engine()
        with _session(engine) as db:
            experiment = _make_experiment(
                db,
                parameter_space={
                    "type": "grid",
                    "params": {"model": ["a", "b", "c"]},
                },
                pipeline_stages=[
                    {"prompt_template": "Hello", "model": "m", "params": {}},
                ],
            )

            # Mock Redis whose get() returns "1" for any key containing "pause".
            mock_redis = MagicMock()
            mock_redis.get = MagicMock(side_effect=lambda key: "1" if "pause" in key else None)

            adapter = MockAdapter()
            cache = ResponseCacheLayer()

            result = asyncio.run(
                run_sweep(
                    db, experiment, adapter, cache,
                    redis_client=mock_redis,
                    max_concurrent=1,
                )
            )

            assert result.stopped_reason == "paused"
            db.refresh(experiment)
            assert experiment.status == ExperimentStatus.paused

    def test_sweep_token_budget(self):
        engine = _engine()
        with _session(engine) as db:
            experiment = _make_experiment(
                db,
                parameter_space={
                    "type": "grid",
                    "params": {"model": ["a", "b", "c", "d", "e"]},
                },
                pipeline_stages=[
                    {"prompt_template": "Hello", "model": "m", "params": {}},
                ],
            )

            # Each run uses 10 tokens_in + 5 tokens_out = 15. Two runs (30)
            # already exceed the 20-token budget, so the sweep must stop well
            # before all five runs complete.
            adapter = MockAdapter(tokens_in=10, tokens_out=5)
            cache = ResponseCacheLayer()

            result = asyncio.run(
                run_sweep(
                    db, experiment, adapter, cache,
                    max_concurrent=1,
                    max_tokens=20,
                )
            )

            assert result.stopped_reason == "token_budget"
            assert result.completed_runs < result.total_runs

    def test_sweep_with_scorers(self):
        engine = _engine()
        with _session(engine) as db:
            experiment = _make_experiment(
                db,
                parameter_space={
                    "type": "grid",
                    "params": {"model": ["a"]},
                },
                pipeline_stages=[
                    {"prompt_template": "Hello", "model": "m", "params": {}},
                ],
            )
            adapter = MockAdapter()
            cache = ResponseCacheLayer()

            class SimpleScorer:
                """Scorer stub that gives every output a constant 0.85."""

                name = "quality"

                def score(self, input_data, output, context):
                    return 0.85

            result = asyncio.run(
                run_sweep(
                    db, experiment, adapter, cache,
                    scorers=[SimpleScorer()],
                )
            )

            assert result.completed_runs == 1
            scores = db.query(Score).all()
            assert len(scores) == 1
            assert scores[0].value == pytest.approx(0.85)

    def test_sweep_failed_run_doesnt_stop_sweep(self):
        engine = _engine()
        with _session(engine) as db:
            experiment = _make_experiment(
                db,
                parameter_space={
                    "type": "grid",
                    "params": {"params.fail": [False, True]},
                },
                pipeline_stages=[
                    {"prompt_template": "Hello", "model": "m"},
                ],
            )
            cache = ResponseCacheLayer()
            call_count = 0

            class SelectiveAdapter(BaseAdapter):
                """Adapter that raises whenever the run's params set ``fail``."""

                async def complete(self, prompt, model, params):
                    nonlocal call_count
                    call_count += 1
                    if params.get("fail"):
                        raise RuntimeError("Intentional failure")
                    return AdapterResponse(
                        text="ok", tokens_in=5, tokens_out=3, latency_ms=10.0, model=model,
                    )

                async def list_models(self):
                    return []

                async def test_connection(self):
                    return True

            result = asyncio.run(
                run_sweep(db, experiment, SelectiveAdapter(), cache, max_concurrent=1)
            )

            assert result.total_runs == 2
            assert result.completed_runs == 1
            assert result.failed_runs == 1
            assert result.stopped_reason == "completed"
            # Both runs must have reached the adapter, i.e. the failure came from
            # the intentional raise. ">=" rather than "==" in case run_sweep
            # retries failed calls.
            assert call_count >= 2

    def test_guided_sweep(self):
        engine = _engine()
        with _session(engine) as db:
            experiment = _make_experiment(
                db,
                parameter_space={
                    "type": "guided",
                    "params": {"params.temperature": [0.1, 0.3, 0.5, 0.7, 0.9]},
                    "n_trials": 5,
                    "top_k": 2,
                    "explore_ratio": 0.4,
                },
                pipeline_stages=[
                    {"prompt_template": "Hello", "model": "m", "params": {}},
                ],
            )
            adapter = MockAdapter()
            cache = ResponseCacheLayer()

            result = asyncio.run(run_sweep(db, experiment, adapter, cache))

            assert result.total_runs == 5
            assert result.completed_runs == 5

    def test_unknown_sweep_type_raises(self):
        engine = _engine()
        with _session(engine) as db:
            experiment = _make_experiment(
                db,
                parameter_space={"type": "bayesian", "params": {}},
            )
            adapter = MockAdapter()
            cache = ResponseCacheLayer()

            with pytest.raises(ValueError, match="Unknown sweep type"):
                asyncio.run(run_sweep(db, experiment, adapter, cache))

    def test_sweep_tokens_tracked(self):
        engine = _engine()
        with _session(engine) as db:
            experiment = _make_experiment(
                db,
                parameter_space={
                    "type": "grid",
                    "params": {"model": ["a", "b"]},
                },
                pipeline_stages=[
                    {"prompt_template": "Hello", "model": "m", "params": {}},
                ],
            )
            adapter = MockAdapter(tokens_in=100, tokens_out=50)
            cache = ResponseCacheLayer()

            result = asyncio.run(run_sweep(db, experiment, adapter, cache))

            assert result.total_tokens_in == 200  # 100 × 2
            assert result.total_tokens_out == 100  # 50 × 2