"""Sweep orchestration for PromptLooper. Supports three sweep types: - GridSweep: enumerate all combinations from parameter_space - RandomSweep: sample N random configs from parameter ranges - GuidedSweep: use previous results to pick next configs (top-K exploitation + exploration) The sweep runner manages parallelism, token budgets, pause/resume/stop via Redis flags, and emits events for each run completion. """ import asyncio import itertools import json import logging import random import uuid from dataclasses import dataclass, field from typing import Any from sqlalchemy.orm import Session from engine.adapters.base import BaseAdapter from engine.cache import ResponseCacheLayer, compute_config_hash from engine.runner import EventBus, run_single from models import Experiment, ExperimentStatus, Run, RunStatus, Score logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- # Sweep configuration dataclasses # --------------------------------------------------------------------------- @dataclass class SweepConfig: """Parsed sweep configuration from experiment.parameter_space.""" sweep_type: str # "grid", "random", "guided" params: dict[str, Any] = field(default_factory=dict) n_trials: int = 100 # for random/guided top_k: int = 5 # for guided: exploit top K results explore_ratio: float = 0.3 # for guided: fraction of random exploration @dataclass class SweepResult: """Summary returned after a sweep completes or is stopped.""" experiment_id: uuid.UUID total_runs: int completed_runs: int failed_runs: int total_tokens_in: int total_tokens_out: int stopped_reason: str | None = None # "completed", "stopped", "paused", "token_budget" # --------------------------------------------------------------------------- # Redis flag keys for sweep control # --------------------------------------------------------------------------- _SWEEP_STOP_KEY = "promptlooper:sweep:{experiment_id}:stop" _SWEEP_PAUSE_KEY = "promptlooper:sweep:{experiment_id}:pause" def _stop_key(experiment_id: uuid.UUID) -> str: return _SWEEP_STOP_KEY.format(experiment_id=experiment_id) def _pause_key(experiment_id: uuid.UUID) -> str: return _SWEEP_PAUSE_KEY.format(experiment_id=experiment_id) # --------------------------------------------------------------------------- # Configuration generation # --------------------------------------------------------------------------- def generate_grid_configs( base_config: dict[str, Any], param_space: dict[str, list], ) -> list[dict[str, Any]]: """Enumerate all combinations from parameter_space grid values. Args: base_config: Base run config (pipeline_stages, input_data, etc.) param_space: Dict mapping param names to lists of values. Supports top-level keys "model" and nested "params.*" keys. Returns: List of fully-formed run configs, one per combination. """ if not param_space: return [base_config.copy()] keys = sorted(param_space.keys()) value_lists = [param_space[k] for k in keys] configs = [] for combo in itertools.product(*value_lists): cfg = _deep_copy_config(base_config) for key, value in zip(keys, combo): _set_config_value(cfg, key, value) configs.append(cfg) return configs def generate_random_configs( base_config: dict[str, Any], param_space: dict[str, Any], n_trials: int, ) -> list[dict[str, Any]]: """Sample N random configs from parameter ranges. Each param_space value can be: - A list: sample uniformly from the list - A dict with "min"/"max": sample uniformly from the continuous range - A dict with "min"/"max"/"step": sample from discrete steps """ configs = [] for _ in range(n_trials): cfg = _deep_copy_config(base_config) for key, spec in param_space.items(): value = _sample_param(spec) _set_config_value(cfg, key, value) configs.append(cfg) return configs def generate_guided_configs( base_config: dict[str, Any], param_space: dict[str, Any], previous_results: list[dict[str, Any]], n_trials: int, top_k: int = 5, explore_ratio: float = 0.3, ) -> list[dict[str, Any]]: """Generate configs using top-K exploitation + random exploration. Args: base_config: Base run config. param_space: Parameter search space. previous_results: List of dicts with "config" and "score" keys from previous runs, sorted by score descending. n_trials: Total number of configs to generate. top_k: Number of top results to exploit. explore_ratio: Fraction of trials devoted to random exploration. Returns: List of run configs mixing exploitation and exploration. """ n_explore = max(1, int(n_trials * explore_ratio)) n_exploit = n_trials - n_explore configs = [] # Exploitation: perturb top-K configs if previous_results and n_exploit > 0: top_results = previous_results[:top_k] for i in range(n_exploit): base_result = top_results[i % len(top_results)] source_cfg = base_result.get("config", base_config) cfg = _deep_copy_config(source_cfg) # Perturb one random parameter if param_space: perturb_key = random.choice(list(param_space.keys())) value = _sample_param(param_space[perturb_key]) _set_config_value(cfg, perturb_key, value) configs.append(cfg) else: # No previous results — fall back to all random n_explore = n_trials # Exploration: random configs configs.extend( generate_random_configs(base_config, param_space, n_explore) ) return configs[:n_trials] # --------------------------------------------------------------------------- # Sweep runner # --------------------------------------------------------------------------- async def run_sweep( db: Session, experiment: Experiment, adapter: BaseAdapter, cache: ResponseCacheLayer, scorers: list[Any] | None = None, event_bus: EventBus | None = None, redis_client: Any | None = None, max_concurrent: int = 4, max_tokens: int = 0, ) -> SweepResult: """Execute a full sweep for an experiment. Generates run configs based on experiment.parameter_space, creates Run records, and executes them with bounded parallelism. Args: db: SQLAlchemy session. experiment: Experiment ORM object. adapter: LLM adapter for completions. cache: Response cache layer. scorers: Optional scorer instances. event_bus: Optional event publisher. redis_client: Optional Redis client for pause/stop flags. max_concurrent: Maximum parallel runs. max_tokens: Token budget (0 = unlimited). Returns: SweepResult summary. """ scorers = scorers or [] sweep_config = _parse_sweep_config(experiment.parameter_space or {}) base_config = _build_base_config(experiment) # Generate configs based on sweep type if sweep_config.sweep_type == "grid": configs = generate_grid_configs(base_config, sweep_config.params) elif sweep_config.sweep_type == "random": configs = generate_random_configs( base_config, sweep_config.params, sweep_config.n_trials ) elif sweep_config.sweep_type == "guided": # For guided, first gather any previous completed results previous = _get_previous_results(db, experiment.id) configs = generate_guided_configs( base_config, sweep_config.params, previous, sweep_config.n_trials, sweep_config.top_k, sweep_config.explore_ratio, ) else: raise ValueError(f"Unknown sweep type: {sweep_config.sweep_type}") # Create Run records runs = _create_run_records(db, experiment, configs) # Update experiment status experiment.status = ExperimentStatus.running db.commit() if event_bus: event_bus.publish({ "type": "sweep.started", "experiment_id": str(experiment.id), "total_runs": len(runs), "sweep_type": sweep_config.sweep_type, }) # Execute runs with bounded parallelism result = await _execute_runs( db=db, experiment=experiment, runs=runs, adapter=adapter, cache=cache, scorers=scorers, event_bus=event_bus, redis_client=redis_client, max_concurrent=max_concurrent, max_tokens=max_tokens, ) # Update experiment status if result.stopped_reason == "paused": experiment.status = ExperimentStatus.paused else: experiment.status = ExperimentStatus.completed db.commit() sweep_completed_event = { "type": "sweep.completed", "experiment_id": str(experiment.id), "total_runs": result.total_runs, "completed_runs": result.completed_runs, "failed_runs": result.failed_runs, "stopped_reason": result.stopped_reason, } if event_bus: event_bus.publish(sweep_completed_event) # Fire webhooks asynchronously try: from engine.tasks import fire_webhooks fire_webhooks("sweep.completed", sweep_completed_event) except Exception: logger.warning("Failed to dispatch sweep.completed webhooks", exc_info=True) return result async def _execute_runs( db: Session, experiment: Experiment, runs: list[Run], adapter: BaseAdapter, cache: ResponseCacheLayer, scorers: list[Any], event_bus: EventBus | None, redis_client: Any | None, max_concurrent: int, max_tokens: int, ) -> SweepResult: """Execute runs with bounded concurrency, respecting stop/pause/token budget.""" semaphore = asyncio.Semaphore(max_concurrent) completed = 0 failed = 0 total_tokens_in = 0 total_tokens_out = 0 stopped_reason: str | None = "completed" async def _run_one(run: Run) -> None: nonlocal completed, failed, total_tokens_in, total_tokens_out async with semaphore: try: result = await run_single( db, run, adapter, cache, scorers=scorers, event_bus=event_bus ) completed += 1 total_tokens_in += result.tokens_in or 0 total_tokens_out += result.tokens_out or 0 if event_bus: event_bus.publish({ "type": "sweep.run_completed", "experiment_id": str(experiment.id), "run_id": str(run.id), "completed": completed, "total": len(runs), }) except Exception: failed += 1 logger.warning( "Sweep run %s failed", run.id, exc_info=True ) pending_runs = list(runs) while pending_runs: # Check stop/pause flags if redis_client is not None: if redis_client.get(_stop_key(experiment.id)): stopped_reason = "stopped" # Mark remaining runs as failed for r in pending_runs: if r.status == RunStatus.pending: r.status = RunStatus.failed db.commit() break if redis_client.get(_pause_key(experiment.id)): stopped_reason = "paused" break # Check token budget if max_tokens > 0 and (total_tokens_in + total_tokens_out) >= max_tokens: stopped_reason = "token_budget" break # Take a batch up to max_concurrent batch_size = min(max_concurrent, len(pending_runs)) batch = pending_runs[:batch_size] pending_runs = pending_runs[batch_size:] # Run batch concurrently await asyncio.gather(*[_run_one(run) for run in batch]) return SweepResult( experiment_id=experiment.id, total_runs=len(runs), completed_runs=completed, failed_runs=failed, total_tokens_in=total_tokens_in, total_tokens_out=total_tokens_out, stopped_reason=stopped_reason, ) # --------------------------------------------------------------------------- # Sweep control helpers # --------------------------------------------------------------------------- def request_stop(redis_client: Any, experiment_id: uuid.UUID) -> None: """Set the stop flag for a running sweep.""" redis_client.set(_stop_key(experiment_id), "1", ex=3600) def request_pause(redis_client: Any, experiment_id: uuid.UUID) -> None: """Set the pause flag for a running sweep.""" redis_client.set(_pause_key(experiment_id), "1", ex=3600) def request_resume(redis_client: Any, experiment_id: uuid.UUID) -> None: """Clear the pause flag to allow a sweep to continue.""" redis_client.delete(_pause_key(experiment_id)) def clear_sweep_flags(redis_client: Any, experiment_id: uuid.UUID) -> None: """Clear all control flags for an experiment.""" redis_client.delete(_stop_key(experiment_id)) redis_client.delete(_pause_key(experiment_id)) # --------------------------------------------------------------------------- # Internal helpers # --------------------------------------------------------------------------- def _parse_sweep_config(parameter_space: dict[str, Any]) -> SweepConfig: """Parse experiment.parameter_space into a SweepConfig.""" return SweepConfig( sweep_type=parameter_space.get("type", "grid"), params=parameter_space.get("params", {}), n_trials=parameter_space.get("n_trials", 100), top_k=parameter_space.get("top_k", 5), explore_ratio=parameter_space.get("explore_ratio", 0.3), ) def _build_base_config(experiment: Experiment) -> dict[str, Any]: """Build the base run config from experiment fields.""" config: dict[str, Any] = {} if experiment.pipeline_stages: config["pipeline_stages"] = experiment.pipeline_stages if experiment.sample_data: config["input_data"] = experiment.sample_data if experiment.scoring_config: config["scoring_config"] = experiment.scoring_config return config def _create_run_records( db: Session, experiment: Experiment, configs: list[dict[str, Any]], ) -> list[Run]: """Create Run records for each config and persist them.""" runs = [] for cfg in configs: config_hash = compute_config_hash( cfg.get("prompt", cfg.get("pipeline_stages", [{}])[0].get("prompt_template", "")), cfg.get("model", ""), cfg.get("params", {}), cfg.get("input_data"), ) run = Run( experiment_id=experiment.id, config_hash=config_hash, config=cfg, status=RunStatus.pending, ) db.add(run) runs.append(run) db.commit() for run in runs: db.refresh(run) return runs def _get_previous_results( db: Session, experiment_id: uuid.UUID ) -> list[dict[str, Any]]: """Get previous completed runs with their average scores, sorted by score desc.""" runs = ( db.query(Run) .filter( Run.experiment_id == experiment_id, Run.status == RunStatus.completed, ) .all() ) results = [] for run in runs: scores = db.query(Score).filter(Score.run_id == run.id).all() avg_score = sum(s.value for s in scores) / len(scores) if scores else 0.0 results.append({ "config": run.config, "score": avg_score, "run_id": str(run.id), }) results.sort(key=lambda r: r["score"], reverse=True) return results def _sample_param(spec: Any) -> Any: """Sample a single parameter value from its specification. Spec can be: - A list: uniform choice - A dict with "min"/"max": uniform float - A dict with "min"/"max"/"step": discrete steps """ if isinstance(spec, list): return random.choice(spec) if isinstance(spec, dict): lo = spec.get("min", 0.0) hi = spec.get("max", 1.0) step = spec.get("step") if step is not None: # Discrete steps steps = [] v = lo while v <= hi + 1e-9: steps.append(round(v, 10)) v += step return random.choice(steps) if steps else lo return random.uniform(lo, hi) # Scalar fallback — return as-is return spec def _set_config_value(config: dict[str, Any], key: str, value: Any) -> None: """Set a value in the config dict, supporting dotted keys like 'params.temperature'.""" parts = key.split(".") target = config for part in parts[:-1]: if part not in target: target[part] = {} target = target[part] target[parts[-1]] = value def _deep_copy_config(config: dict[str, Any]) -> dict[str, Any]: """Deep-copy a config dict using JSON round-trip.""" return json.loads(json.dumps(config, default=str))