promptlooper/backend/engine/sweep.py
John Lightner ba8cb7e2c6 MAESTRO: Implement sweep orchestration engine with grid, random, and guided sweep types
Adds backend/engine/sweep.py with three sweep strategies:
- GridSweep: exhaustive enumeration of all parameter combinations
- RandomSweep: N random samples from parameter ranges (list, min/max, step)
- GuidedSweep: top-K exploitation + random exploration from previous results

Features: bounded parallelism via asyncio.Semaphore, token budget enforcement,
Redis-based pause/resume/stop control flags, sweep-level event publishing.
36 tests in test_sweep.py covering config generation, helpers, and full sweep execution.
2026-04-07 02:53:30 -05:00

528 lines
17 KiB
Python

"""Sweep orchestration for PromptLooper.
Supports three sweep types:
- GridSweep: enumerate all combinations from parameter_space
- RandomSweep: sample N random configs from parameter ranges
- GuidedSweep: use previous results to pick next configs (top-K exploitation + exploration)
The sweep runner manages bounded parallelism, token budgets, and pause/resume/stop
via Redis flags, and publishes events when a sweep starts, for each successfully
completed run, and when the sweep finishes.
"""
import asyncio
import itertools
import json
import logging
import random
import uuid
from dataclasses import dataclass, field
from typing import Any
from sqlalchemy.orm import Session
from engine.adapters.base import BaseAdapter
from engine.cache import ResponseCacheLayer, compute_config_hash
from engine.runner import EventBus, run_single
from models import Experiment, ExperimentStatus, Run, RunStatus, Score
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Sweep configuration dataclasses
# ---------------------------------------------------------------------------
@dataclass
class SweepConfig:
    """Sweep settings extracted from ``experiment.parameter_space``.

    Only ``sweep_type`` is required; every other field falls back to the
    default declared below.
    """

    # Strategy name: "grid", "random", or "guided".
    sweep_type: str
    # Search space: parameter key (e.g. "model" or a dotted "params.*" key)
    # mapped to its candidate values or range spec.
    params: dict[str, Any] = field(default_factory=dict)
    # How many configs to generate (random and guided sweeps).
    n_trials: int = field(default=100)
    # Guided sweeps: number of best previous results to exploit.
    top_k: int = field(default=5)
    # Guided sweeps: fraction of trials reserved for random exploration.
    explore_ratio: float = field(default=0.3)
@dataclass
class SweepResult:
"""Summary returned after a sweep completes or is stopped."""
experiment_id: uuid.UUID
total_runs: int
completed_runs: int
failed_runs: int
total_tokens_in: int
total_tokens_out: int
stopped_reason: str | None = None # "completed", "stopped", "paused", "token_budget"
# ---------------------------------------------------------------------------
# Redis flag keys for sweep control
# ---------------------------------------------------------------------------
_SWEEP_STOP_KEY = "promptlooper:sweep:{experiment_id}:stop"
_SWEEP_PAUSE_KEY = "promptlooper:sweep:{experiment_id}:pause"


def _stop_key(experiment_id: uuid.UUID) -> str:
    """Return the Redis key used as the stop flag for one experiment's sweep."""
    return _SWEEP_STOP_KEY.format_map({"experiment_id": experiment_id})


def _pause_key(experiment_id: uuid.UUID) -> str:
    """Return the Redis key used as the pause flag for one experiment's sweep."""
    return _SWEEP_PAUSE_KEY.format_map({"experiment_id": experiment_id})
# ---------------------------------------------------------------------------
# Configuration generation
# ---------------------------------------------------------------------------
def generate_grid_configs(
    base_config: dict[str, Any],
    param_space: dict[str, list],
) -> list[dict[str, Any]]:
    """Enumerate every combination of the grid values in ``param_space``.

    Args:
        base_config: Base run config (pipeline_stages, input_data, etc.).
        param_space: Dict mapping param names to lists of values.
            Supports top-level keys such as "model" and dotted "params.*" keys.

    Returns:
        One independent run config per combination. Keys are applied in
        sorted order, so the output order is deterministic. If any value list
        is empty the grid has no combinations and ``[]`` is returned. An empty
        (or None) ``param_space`` yields a single copy of ``base_config``.
    """
    if not param_space:
        # Deep-copy (not dict.copy()) so this single config never shares
        # nested lists/dicts with base_config -- same as the path below.
        return [_deep_copy_config(base_config)]

    keys = sorted(param_space)
    configs = []
    for combo in itertools.product(*(param_space[k] for k in keys)):
        cfg = _deep_copy_config(base_config)
        for key, value in zip(keys, combo):
            _set_config_value(cfg, key, value)
        configs.append(cfg)
    return configs
def generate_random_configs(
    base_config: dict[str, Any],
    param_space: dict[str, Any],
    n_trials: int,
) -> list[dict[str, Any]]:
    """Draw ``n_trials`` independent random configs from ``param_space``.

    Each trial starts from a fresh copy of ``base_config``; every parameter
    spec is then sampled once via ``_sample_param``:
    - list: uniform choice among its items
    - dict with "min"/"max": uniform value in the continuous range
    - dict with "min"/"max"/"step": uniform choice among the discrete steps
    """

    def _one_trial() -> dict[str, Any]:
        trial_cfg = _deep_copy_config(base_config)
        for param_key, param_spec in param_space.items():
            _set_config_value(trial_cfg, param_key, _sample_param(param_spec))
        return trial_cfg

    return [_one_trial() for _ in range(n_trials)]
def generate_guided_configs(
    base_config: dict[str, Any],
    param_space: dict[str, Any],
    previous_results: list[dict[str, Any]],
    n_trials: int,
    top_k: int = 5,
    explore_ratio: float = 0.3,
) -> list[dict[str, Any]]:
    """Generate configs using top-K exploitation + random exploration.

    Exploitation copies the best previous configs round-robin and re-samples
    one randomly chosen parameter of each copy. Exploration draws fully random
    configs from ``param_space``. At least one exploration trial is always
    generated, and the combined list is truncated to ``n_trials``.

    Args:
        base_config: Base run config.
        param_space: Parameter search space.
        previous_results: List of dicts with "config" and "score" keys
            from previous runs, expected sorted by score descending (the
            first ``top_k`` entries are treated as the best).
        n_trials: Total number of configs to generate.
        top_k: Number of top results to exploit. Values <= 0 disable
            exploitation (all trials become random exploration).
        explore_ratio: Fraction of trials devoted to random exploration.

    Returns:
        At most ``n_trials`` run configs, exploitation configs first.
    """
    n_explore = max(1, int(n_trials * explore_ratio))
    n_exploit = n_trials - n_explore
    # With top_k <= 0 there is nothing to exploit, and the round-robin
    # `i % len(top_results)` below would raise ZeroDivisionError.
    top_results = previous_results[:top_k] if top_k > 0 else []
    configs: list[dict[str, Any]] = []

    if top_results and n_exploit > 0:
        perturbable = list(param_space.keys()) if param_space else []
        for i in range(n_exploit):
            chosen = top_results[i % len(top_results)]
            source_cfg = chosen.get("config")
            if source_cfg is None:
                # A missing or null stored config would deep-copy to None and
                # crash in _set_config_value; start from the base config.
                source_cfg = base_config
            cfg = _deep_copy_config(source_cfg)
            # Perturb one random parameter of the exploited config.
            if perturbable:
                perturb_key = random.choice(perturbable)
                _set_config_value(
                    cfg, perturb_key, _sample_param(param_space[perturb_key])
                )
            configs.append(cfg)
    else:
        # Nothing to exploit -- fall back to all random.
        n_explore = n_trials

    configs.extend(generate_random_configs(base_config, param_space, n_explore))
    return configs[:n_trials]
# ---------------------------------------------------------------------------
# Sweep runner
# ---------------------------------------------------------------------------
async def run_sweep(
db: Session,
experiment: Experiment,
adapter: BaseAdapter,
cache: ResponseCacheLayer,
scorers: list[Any] | None = None,
event_bus: EventBus | None = None,
redis_client: Any | None = None,
max_concurrent: int = 4,
max_tokens: int = 0,
) -> SweepResult:
"""Execute a full sweep for an experiment.
Generates run configs based on experiment.parameter_space, creates Run
records, and executes them with bounded parallelism.
Args:
db: SQLAlchemy session.
experiment: Experiment ORM object.
adapter: LLM adapter for completions.
cache: Response cache layer.
scorers: Optional scorer instances.
event_bus: Optional event publisher.
redis_client: Optional Redis client for pause/stop flags.
max_concurrent: Maximum parallel runs.
max_tokens: Token budget (0 = unlimited).
Returns:
SweepResult summary.
"""
scorers = scorers or []
sweep_config = _parse_sweep_config(experiment.parameter_space or {})
base_config = _build_base_config(experiment)
# Generate configs based on sweep type
if sweep_config.sweep_type == "grid":
configs = generate_grid_configs(base_config, sweep_config.params)
elif sweep_config.sweep_type == "random":
configs = generate_random_configs(
base_config, sweep_config.params, sweep_config.n_trials
)
elif sweep_config.sweep_type == "guided":
# For guided, first gather any previous completed results
previous = _get_previous_results(db, experiment.id)
configs = generate_guided_configs(
base_config,
sweep_config.params,
previous,
sweep_config.n_trials,
sweep_config.top_k,
sweep_config.explore_ratio,
)
else:
raise ValueError(f"Unknown sweep type: {sweep_config.sweep_type}")
# Create Run records
runs = _create_run_records(db, experiment, configs)
# Update experiment status
experiment.status = ExperimentStatus.running
db.commit()
if event_bus:
event_bus.publish({
"type": "sweep.started",
"experiment_id": str(experiment.id),
"total_runs": len(runs),
"sweep_type": sweep_config.sweep_type,
})
# Execute runs with bounded parallelism
result = await _execute_runs(
db=db,
experiment=experiment,
runs=runs,
adapter=adapter,
cache=cache,
scorers=scorers,
event_bus=event_bus,
redis_client=redis_client,
max_concurrent=max_concurrent,
max_tokens=max_tokens,
)
# Update experiment status
if result.stopped_reason == "paused":
experiment.status = ExperimentStatus.paused
else:
experiment.status = ExperimentStatus.completed
db.commit()
if event_bus:
event_bus.publish({
"type": "sweep.completed",
"experiment_id": str(experiment.id),
"total_runs": result.total_runs,
"completed_runs": result.completed_runs,
"failed_runs": result.failed_runs,
"stopped_reason": result.stopped_reason,
})
return result
async def _execute_runs(
    db: Session,
    experiment: Experiment,
    runs: list[Run],
    adapter: BaseAdapter,
    cache: ResponseCacheLayer,
    scorers: list[Any],
    event_bus: EventBus | None,
    redis_client: Any | None,
    max_concurrent: int,
    max_tokens: int,
) -> SweepResult:
    """Execute ``runs`` in batches of at most ``max_concurrent`` runs.

    Before each batch, the following checks run:
    - stop flag (Redis, if a client is given): runs not yet started that are
      still pending are marked failed and committed; the reason is "stopped".
    - pause flag: the loop ends, leaving unstarted runs pending ("paused").
    - token budget (``max_tokens > 0``): the loop ends ("token_budget").
      The budget is only checked between batches, so the last batch may
      overshoot it.

    A run whose ``run_single`` call raises is logged and counted as failed;
    the sweep continues with the remaining runs.

    Returns:
        SweepResult with run counts, token totals, and the stop reason
        ("completed" if every batch ran).
    """
    # A non-positive limit used to break the sweep: with 0 every batch was
    # empty, so pending_runs never shrank and the loop spun forever; a
    # negative value made asyncio.Semaphore raise. Run at least one at a time.
    max_concurrent = max(1, max_concurrent)
    # Batches never exceed max_concurrent, so the semaphore is a secondary cap.
    semaphore = asyncio.Semaphore(max_concurrent)
    completed = 0
    failed = 0
    total_tokens_in = 0
    total_tokens_out = 0
    stopped_reason: str | None = "completed"

    async def _run_one(run: Run) -> None:
        """Execute one run and update the shared counters; swallows Exception."""
        nonlocal completed, failed, total_tokens_in, total_tokens_out
        async with semaphore:
            try:
                result = await run_single(
                    db, run, adapter, cache, scorers=scorers, event_bus=event_bus
                )
                completed += 1
                total_tokens_in += result.tokens_in or 0
                total_tokens_out += result.tokens_out or 0
                if event_bus:
                    event_bus.publish({
                        "type": "sweep.run_completed",
                        "experiment_id": str(experiment.id),
                        "run_id": str(run.id),
                        "completed": completed,
                        "total": len(runs),
                    })
            except Exception:
                failed += 1
                logger.warning(
                    "Sweep run %s failed", run.id, exc_info=True
                )

    pending_runs = list(runs)
    while pending_runs:
        # Check stop/pause flags
        if redis_client is not None:
            if redis_client.get(_stop_key(experiment.id)):
                stopped_reason = "stopped"
                # Mark remaining (never started) runs as failed
                for r in pending_runs:
                    if r.status == RunStatus.pending:
                        r.status = RunStatus.failed
                db.commit()
                break
            if redis_client.get(_pause_key(experiment.id)):
                stopped_reason = "paused"
                break

        # Check token budget
        if max_tokens > 0 and (total_tokens_in + total_tokens_out) >= max_tokens:
            stopped_reason = "token_budget"
            break

        # Take a batch up to max_concurrent (always >= 1 run, so the loop advances)
        batch_size = min(max_concurrent, len(pending_runs))
        batch = pending_runs[:batch_size]
        pending_runs = pending_runs[batch_size:]

        # Run batch concurrently
        await asyncio.gather(*[_run_one(run) for run in batch])

    return SweepResult(
        experiment_id=experiment.id,
        total_runs=len(runs),
        completed_runs=completed,
        failed_runs=failed,
        total_tokens_in=total_tokens_in,
        total_tokens_out=total_tokens_out,
        stopped_reason=stopped_reason,
    )
# ---------------------------------------------------------------------------
# Sweep control helpers
# ---------------------------------------------------------------------------
def request_stop(redis_client: Any, experiment_id: uuid.UUID) -> None:
    """Ask a running sweep to stop by raising its Redis stop flag.

    The flag is written with ``ex=3600`` (a one-hour expiry in redis-py).
    """
    stop_flag = _stop_key(experiment_id)
    redis_client.set(stop_flag, "1", ex=3600)
def request_pause(redis_client: Any, experiment_id: uuid.UUID) -> None:
    """Ask a running sweep to pause by raising its Redis pause flag.

    The flag is written with ``ex=3600`` (a one-hour expiry in redis-py).
    """
    pause_flag = _pause_key(experiment_id)
    redis_client.set(pause_flag, "1", ex=3600)
def request_resume(redis_client: Any, experiment_id: uuid.UUID) -> None:
    """Lift a pause request by deleting the experiment's Redis pause flag."""
    pause_flag = _pause_key(experiment_id)
    redis_client.delete(pause_flag)
def clear_sweep_flags(redis_client: Any, experiment_id: uuid.UUID) -> None:
    """Delete both control flags (stop, then pause) for an experiment."""
    for flag_key in (_stop_key(experiment_id), _pause_key(experiment_id)):
        redis_client.delete(flag_key)
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
def _parse_sweep_config(parameter_space: dict[str, Any]) -> SweepConfig:
    """Build a SweepConfig from an experiment's ``parameter_space`` dict.

    Recognized keys and their defaults: "type" ("grid"), "params" ({}),
    "n_trials" (100), "top_k" (5), "explore_ratio" (0.3).

    A key that is present but null (JSON ``null``) is treated like a missing
    key. Without this, a stored ``"params": null`` reached the generators as
    ``None`` and crashed random/guided generation on ``.items()``.
    """

    def _setting(name: str, default: Any) -> Any:
        value = parameter_space.get(name)
        return default if value is None else value

    return SweepConfig(
        sweep_type=_setting("type", "grid"),
        params=_setting("params", {}),
        n_trials=_setting("n_trials", 100),
        top_k=_setting("top_k", 5),
        explore_ratio=_setting("explore_ratio", 0.3),
    )
def _build_base_config(experiment: Experiment) -> dict[str, Any]:
    """Collect the experiment fields shared by every run into a config dict.

    Only truthy attributes are included:
    ``pipeline_stages`` -> "pipeline_stages", ``sample_data`` -> "input_data",
    ``scoring_config`` -> "scoring_config". Values are referenced, not copied.
    """
    field_map = (
        ("pipeline_stages", "pipeline_stages"),
        ("input_data", "sample_data"),
        ("scoring_config", "scoring_config"),
    )
    config: dict[str, Any] = {}
    for config_key, attr_name in field_map:
        attr_value = getattr(experiment, attr_name)
        if attr_value:
            config[config_key] = attr_value
    return config
def _create_run_records(
    db: Session,
    experiment: Experiment,
    configs: list[dict[str, Any]],
) -> list[Run]:
    """Persist one pending Run per config and return the runs.

    Each run's ``config_hash`` is computed from the config's prompt (its
    "prompt" key, else the first pipeline stage's "prompt_template"), its
    "model", "params" and "input_data". All runs are committed together and
    then refreshed from the database.
    """
    runs = []
    for cfg in configs:
        config_hash = compute_config_hash(
            _prompt_for_config_hash(cfg),
            cfg.get("model", ""),
            cfg.get("params", {}),
            cfg.get("input_data"),
        )
        run = Run(
            experiment_id=experiment.id,
            config_hash=config_hash,
            config=cfg,
            status=RunStatus.pending,
        )
        db.add(run)
        runs.append(run)
    db.commit()
    for run in runs:
        db.refresh(run)
    return runs


def _prompt_for_config_hash(cfg: dict[str, Any]) -> Any:
    """Return the prompt used to hash ``cfg``.

    Uses the "prompt" key if present, else the first pipeline stage's
    "prompt_template", else "".
    """
    if "prompt" in cfg:
        return cfg["prompt"]
    # The original inline expression evaluated this fallback eagerly, so an
    # empty or null "pipeline_stages" raised even when "prompt" was present.
    stages = cfg.get("pipeline_stages") or [{}]
    first_stage = stages[0]
    if isinstance(first_stage, dict):
        return first_stage.get("prompt_template", "")
    return ""
def _get_previous_results(
    db: Session, experiment_id: uuid.UUID
) -> list[dict[str, Any]]:
    """Return the experiment's completed runs with their mean score, best first.

    Each entry is ``{"config": run.config, "score": float, "run_id": str}``.
    Runs without any Score rows get a score of 0.0.
    """
    runs = (
        db.query(Run)
        .filter(
            Run.experiment_id == experiment_id,
            Run.status == RunStatus.completed,
        )
        .all()
    )
    if not runs:
        return []

    # Fetch the scores of all these runs in one joined query instead of one
    # query per run (the previous N+1 pattern).
    score_rows = (
        db.query(Score)
        .join(Run, Score.run_id == Run.id)
        .filter(
            Run.experiment_id == experiment_id,
            Run.status == RunStatus.completed,
        )
        .all()
    )
    values_by_run: dict[Any, list[Any]] = {}
    for score in score_rows:
        values_by_run.setdefault(score.run_id, []).append(score.value)

    results = []
    for run in runs:
        values = values_by_run.get(run.id)
        avg_score = sum(values) / len(values) if values else 0.0
        results.append({
            "config": run.config,
            "score": avg_score,
            "run_id": str(run.id),
        })
    results.sort(key=lambda r: r["score"], reverse=True)
    return results
def _sample_param(spec: Any) -> Any:
    """Draw one value for a parameter from its specification.

    Spec can be:
    - A list: uniform choice among its items (an empty list raises IndexError)
    - A dict with "min"/"max": ``random.uniform(min, max)``; min/max default
      to 0.0/1.0
    - A dict with "min"/"max"/"step": uniform choice among min, min+step, ...
      up to max (inclusive, with a 1e-9 tolerance), each rounded to 10
      decimals; returns min unchanged if max < min
    - Anything else: returned as-is

    Raises:
        ValueError: If "step" is zero or negative (the step grid would never
            terminate).
    """
    if isinstance(spec, list):
        return random.choice(spec)
    if isinstance(spec, dict):
        lo = spec.get("min", 0.0)
        hi = spec.get("max", 1.0)
        step = spec.get("step")
        if step is not None:
            if step <= 0:
                raise ValueError(f"Parameter step must be positive, got {step!r}")
            span = hi - lo + 1e-9
            if span < 0:
                # Empty range: nothing between min and max.
                return lo
            # Pick a step index and compute lo + i*step directly instead of
            # building every step by repeated `v += step`, which accumulates
            # float error and materializes the whole list.
            n_steps = int(span / step) + 1
            return round(lo + random.randrange(n_steps) * step, 10)
        return random.uniform(lo, hi)
    # Scalar fallback — return as-is
    return spec
def _set_config_value(config: dict[str, Any], key: str, value: Any) -> None:
    """Assign ``value`` at ``key`` inside ``config``, mutating it in place.

    A dotted key such as "params.temperature" walks the intermediate levels,
    creating any missing one as an empty dict, before setting the last segment.
    """
    *parents, leaf = key.split(".")
    node = config
    for name in parents:
        if name not in node:
            node[name] = {}
        node = node[name]
    node[leaf] = value
def _deep_copy_config(config: dict[str, Any]) -> dict[str, Any]:
    """Return an independent copy of ``config`` via a JSON round-trip.

    Because the copy goes through JSON it is lossy for non-JSON data: values
    json cannot encode are replaced by ``str(value)`` (``default=str``),
    tuples become lists, and int/float/bool/None keys become strings.
    """
    serialized = json.dumps(config, default=str)
    return json.loads(serialized)