Adds backend/engine/sweep.py with three sweep strategies:
- GridSweep: exhaustive enumeration of all parameter combinations
- RandomSweep: N random samples from parameter ranges (list, min/max, step)
- GuidedSweep: top-K exploitation + random exploration from previous results

Features: bounded parallelism via asyncio.Semaphore, token budget enforcement, Redis-based pause/resume/stop control flags, and sweep-level event publishing. Includes 36 tests in test_sweep.py covering config generation, helpers, and full sweep execution.
528 lines
17 KiB
Python
528 lines
17 KiB
Python
"""Sweep orchestration for PromptLooper.
|
|
|
|
Supports three sweep types:
|
|
- GridSweep: enumerate all combinations from parameter_space
|
|
- RandomSweep: sample N random configs from parameter ranges
|
|
- GuidedSweep: use previous results to pick next configs (top-K exploitation + exploration)
|
|
|
|
The sweep runner manages parallelism, token budgets, pause/resume/stop via
|
|
Redis flags, and emits events for each run completion.
|
|
"""
|
|
|
|
import asyncio
|
|
import itertools
|
|
import json
|
|
import logging
|
|
import random
|
|
import uuid
|
|
from dataclasses import dataclass, field
|
|
from typing import Any
|
|
|
|
from sqlalchemy.orm import Session
|
|
|
|
from engine.adapters.base import BaseAdapter
|
|
from engine.cache import ResponseCacheLayer, compute_config_hash
|
|
from engine.runner import EventBus, run_single
|
|
from models import Experiment, ExperimentStatus, Run, RunStatus, Score
|
|
|
|
# Module-level logger; failed sweep runs are reported through it as warnings.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Sweep configuration dataclasses
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@dataclass
class SweepConfig:
    """Sweep settings parsed from ``experiment.parameter_space``.

    Attributes:
        sweep_type: Which generator to use: "grid", "random" or "guided".
        params: Parameter search space keyed by (possibly dotted) config key.
        n_trials: How many configs random and guided sweeps generate.
        top_k: Guided sweeps exploit this many of the best previous results.
        explore_ratio: Fraction of guided trials spent on random exploration.
    """

    sweep_type: str
    params: dict[str, Any] = field(default_factory=dict)
    n_trials: int = field(default=100)
    top_k: int = field(default=5)
    explore_ratio: float = field(default=0.3)
|
|
|
|
|
@dataclass
|
|
class SweepResult:
|
|
"""Summary returned after a sweep completes or is stopped."""
|
|
|
|
experiment_id: uuid.UUID
|
|
total_runs: int
|
|
completed_runs: int
|
|
failed_runs: int
|
|
total_tokens_in: int
|
|
total_tokens_out: int
|
|
stopped_reason: str | None = None # "completed", "stopped", "paused", "token_budget"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Redis flag keys for sweep control
|
|
# ---------------------------------------------------------------------------
|
|
|
|
_SWEEP_STOP_KEY = "promptlooper:sweep:{experiment_id}:stop"
|
|
_SWEEP_PAUSE_KEY = "promptlooper:sweep:{experiment_id}:pause"
|
|
|
|
|
|
def _stop_key(experiment_id: uuid.UUID) -> str:
|
|
return _SWEEP_STOP_KEY.format(experiment_id=experiment_id)
|
|
|
|
|
|
def _pause_key(experiment_id: uuid.UUID) -> str:
|
|
return _SWEEP_PAUSE_KEY.format(experiment_id=experiment_id)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Configuration generation
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def generate_grid_configs(
    base_config: dict[str, Any],
    param_space: dict[str, list],
) -> list[dict[str, Any]]:
    """Enumerate all combinations from parameter_space grid values.

    Keys are processed in sorted order, so the output order is deterministic
    for a given ``param_space``. Every returned config is an independent deep
    copy of ``base_config`` (via a JSON round-trip), including the single
    config returned when ``param_space`` is empty.

    Args:
        base_config: Base run config (pipeline_stages, input_data, etc.)
        param_space: Dict mapping param names to lists of values.
            Supports top-level keys "model" and nested "params.*" keys.

    Returns:
        List of fully-formed run configs, one per combination. A key whose
        value list is empty yields no combinations at all.
    """
    if not param_space:
        # Deep-copy like the main path below; a shallow copy would share the
        # nested pipeline_stages/input_data objects with base_config.
        return [_deep_copy_config(base_config)]

    keys = sorted(param_space.keys())
    value_lists = [param_space[k] for k in keys]
    configs = []

    for combo in itertools.product(*value_lists):
        cfg = _deep_copy_config(base_config)
        for key, value in zip(keys, combo):
            _set_config_value(cfg, key, value)
        configs.append(cfg)

    return configs
|
|
|
|
|
|
def generate_random_configs(
    base_config: dict[str, Any],
    param_space: dict[str, Any],
    n_trials: int,
) -> list[dict[str, Any]]:
    """Draw ``n_trials`` independent random configs from ``param_space``.

    Each trial starts from a fresh deep copy of ``base_config``. Then every
    entry of ``param_space`` is sampled, in dict order, and written into that
    copy. An entry may be:

      - a list: one element is picked uniformly;
      - a dict with "min"/"max": a uniform float from that range;
      - a dict with "min"/"max"/"step": one of the discrete steps in the range.
    """

    def _draw_one() -> dict[str, Any]:
        trial_cfg = _deep_copy_config(base_config)
        for name, spec in param_space.items():
            _set_config_value(trial_cfg, name, _sample_param(spec))
        return trial_cfg

    return [_draw_one() for _ in range(n_trials)]
|
|
|
|
|
|
def generate_guided_configs(
    base_config: dict[str, Any],
    param_space: dict[str, Any],
    previous_results: list[dict[str, Any]],
    n_trials: int,
    top_k: int = 5,
    explore_ratio: float = 0.3,
) -> list[dict[str, Any]]:
    """Generate configs using top-K exploitation + random exploration.

    Exploitation configs are deep copies of the best previous configs. They
    cycle through the first ``top_k`` entries of ``previous_results``, and
    each has one randomly chosen parameter re-sampled from ``param_space``.
    Exploration configs are fresh random samples. When there is nothing to
    exploit, every trial is random. That happens when there are no previous
    results, when ``top_k`` selects none of them, or when no exploitation
    slots remain.

    Args:
        base_config: Base run config.
        param_space: Parameter search space.
        previous_results: List of dicts with "config" and "score" keys from
            previous runs, expected to be sorted by score descending. An
            entry whose "config" is missing, None or empty falls back to
            ``base_config``.
        n_trials: Total number of configs to generate.
        top_k: Number of leading ``previous_results`` entries to exploit.
        explore_ratio: Fraction of trials devoted to random exploration; at
            least one exploration slot is always reserved.

    Returns:
        At most ``n_trials`` run configs, exploitation configs first.
    """
    n_explore = max(1, int(n_trials * explore_ratio))
    n_exploit = n_trials - n_explore

    # With top_k == 0 this slice is empty; checking it below (rather than
    # only previous_results) avoids a ZeroDivisionError in the modulo.
    top_results = previous_results[:top_k] if previous_results else []
    perturb_keys = list(param_space.keys()) if param_space else []

    configs: list[dict[str, Any]] = []

    # Exploitation: perturb top-K configs
    if top_results and n_exploit > 0:
        for i in range(n_exploit):
            seed_result = top_results[i % len(top_results)]
            cfg = _deep_copy_config(seed_result.get("config") or base_config)
            # Perturb one random parameter
            if perturb_keys:
                perturb_key = random.choice(perturb_keys)
                value = _sample_param(param_space[perturb_key])
                _set_config_value(cfg, perturb_key, value)
            configs.append(cfg)
    else:
        # Nothing to exploit — fall back to all random
        n_explore = n_trials

    # Exploration: random configs
    configs.extend(
        generate_random_configs(base_config, param_space, n_explore)
    )

    return configs[:n_trials]
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Sweep runner
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _generate_sweep_configs(
    db: Session,
    experiment: Experiment,
    sweep_config: SweepConfig,
    base_config: dict[str, Any],
) -> list[dict[str, Any]]:
    """Produce the run configs for ``sweep_config.sweep_type``.

    Guided sweeps also read the experiment's previously completed runs from
    ``db`` to seed exploitation.

    Raises:
        ValueError: If the sweep type is not "grid", "random" or "guided".
    """
    sweep_type = sweep_config.sweep_type
    space = sweep_config.params

    if sweep_type == "grid":
        return generate_grid_configs(base_config, space)

    if sweep_type == "random":
        return generate_random_configs(base_config, space, sweep_config.n_trials)

    if sweep_type == "guided":
        history = _get_previous_results(db, experiment.id)
        return generate_guided_configs(
            base_config,
            space,
            history,
            sweep_config.n_trials,
            sweep_config.top_k,
            sweep_config.explore_ratio,
        )

    raise ValueError(f"Unknown sweep type: {sweep_config.sweep_type}")


async def run_sweep(
    db: Session,
    experiment: Experiment,
    adapter: BaseAdapter,
    cache: ResponseCacheLayer,
    scorers: list[Any] | None = None,
    event_bus: EventBus | None = None,
    redis_client: Any | None = None,
    max_concurrent: int = 4,
    max_tokens: int = 0,
) -> SweepResult:
    """Execute a full sweep for an experiment.

    The sweep proceeds in these steps:

    1. Build run configs from ``experiment.parameter_space``. An unknown sweep
       type raises ``ValueError`` before anything is written.
    2. Persist one pending Run per config and mark the experiment running.
    3. Publish ``sweep.started``.
    4. Execute the runs with bounded parallelism.
    5. Mark the experiment paused (if execution was paused) or completed
       (any other outcome).
    6. Publish ``sweep.completed``.

    Every call creates a fresh set of Run records. Runs left pending by an
    earlier paused sweep are not picked up again.

    Args:
        db: SQLAlchemy session.
        experiment: Experiment ORM object.
        adapter: LLM adapter for completions.
        cache: Response cache layer.
        scorers: Optional scorer instances.
        event_bus: Optional event publisher.
        redis_client: Optional Redis client for pause/stop flags.
        max_concurrent: Maximum parallel runs.
        max_tokens: Token budget (0 = unlimited).

    Returns:
        SweepResult summary.
    """
    active_scorers = scorers or []
    sweep_config = _parse_sweep_config(experiment.parameter_space or {})
    run_configs = _generate_sweep_configs(
        db, experiment, sweep_config, _build_base_config(experiment)
    )

    runs = _create_run_records(db, experiment, run_configs)
    experiment.status = ExperimentStatus.running
    db.commit()

    if event_bus:
        event_bus.publish({
            "type": "sweep.started",
            "experiment_id": str(experiment.id),
            "total_runs": len(runs),
            "sweep_type": sweep_config.sweep_type,
        })

    outcome = await _execute_runs(
        db=db,
        experiment=experiment,
        runs=runs,
        adapter=adapter,
        cache=cache,
        scorers=active_scorers,
        event_bus=event_bus,
        redis_client=redis_client,
        max_concurrent=max_concurrent,
        max_tokens=max_tokens,
    )

    # Only a pause is reflected as its own status; stop, token budget and
    # normal completion all end up as "completed".
    experiment.status = (
        ExperimentStatus.paused
        if outcome.stopped_reason == "paused"
        else ExperimentStatus.completed
    )
    db.commit()

    if event_bus:
        event_bus.publish({
            "type": "sweep.completed",
            "experiment_id": str(experiment.id),
            "total_runs": outcome.total_runs,
            "completed_runs": outcome.completed_runs,
            "failed_runs": outcome.failed_runs,
            "stopped_reason": outcome.stopped_reason,
        })

    return outcome
|
|
|
|
|
|
async def _execute_runs(
    db: Session,
    experiment: Experiment,
    runs: list[Run],
    adapter: BaseAdapter,
    cache: ResponseCacheLayer,
    scorers: list[Any],
    event_bus: EventBus | None,
    redis_client: Any | None,
    max_concurrent: int,
    max_tokens: int,
) -> SweepResult:
    """Execute runs with bounded concurrency, respecting stop/pause/token budget.

    Runs are taken in order, in batches of at most ``max_concurrent``. Each
    batch is awaited with ``asyncio.gather`` before the next one starts.

    Before every batch, two checks happen:

    - the Redis stop/pause flags, when ``redis_client`` is given;
    - the token budget.

    So a stop, pause or budget overrun takes effect at the next batch
    boundary, and the budget can be exceeded by up to one batch.

    Outcomes for runs that have not started yet:

    - On stop, those still in ``RunStatus.pending`` are marked failed and
      committed. They are not counted in ``failed_runs``.
    - On pause or budget exhaustion, they are left untouched.

    Args:
        db: SQLAlchemy session shared by all runs.
        experiment: Experiment the runs belong to.
        runs: Run records to execute, in order.
        adapter: LLM adapter for completions.
        cache: Response cache layer.
        scorers: Scorer instances passed to each run.
        event_bus: Optional publisher for ``sweep.run_completed`` events.
        redis_client: Optional client whose ``get`` is polled for the flags.
        max_concurrent: Maximum runs per batch. Values below 1 are treated
            as 1.
        max_tokens: Token budget over tokens_in + tokens_out (0 = unlimited).

    Returns:
        SweepResult with ``stopped_reason`` set to "completed", "stopped",
        "paused" or "token_budget".
    """
    if max_concurrent < 1:
        # A batch size of 0 never shrinks pending_runs, so the loop would spin
        # forever (awaiting an empty gather does not yield). asyncio.Semaphore
        # also rejects negative values. Degrade to sequential execution.
        logger.warning(
            "max_concurrent=%s is not positive; running sweep for experiment %s with 1",
            max_concurrent,
            experiment.id,
        )
        max_concurrent = 1

    semaphore = asyncio.Semaphore(max_concurrent)
    completed = 0
    failed = 0
    total_tokens_in = 0
    total_tokens_out = 0
    stopped_reason: str | None = "completed"

    async def _run_one(run: Run) -> None:
        nonlocal completed, failed, total_tokens_in, total_tokens_out
        async with semaphore:
            try:
                result = await run_single(
                    db, run, adapter, cache, scorers=scorers, event_bus=event_bus
                )
                completed += 1
                total_tokens_in += result.tokens_in or 0
                total_tokens_out += result.tokens_out or 0

                if event_bus:
                    event_bus.publish({
                        "type": "sweep.run_completed",
                        "experiment_id": str(experiment.id),
                        "run_id": str(run.id),
                        "completed": completed,
                        "total": len(runs),
                    })
            except Exception:
                # One failing run must not abort the rest of the sweep.
                failed += 1
                logger.warning(
                    "Sweep run %s failed", run.id, exc_info=True
                )

    pending_runs = list(runs)
    while pending_runs:
        # Check stop/pause flags
        if redis_client is not None:
            if redis_client.get(_stop_key(experiment.id)):
                stopped_reason = "stopped"
                # Mark remaining runs as failed
                for r in pending_runs:
                    if r.status == RunStatus.pending:
                        r.status = RunStatus.failed
                db.commit()
                break
            if redis_client.get(_pause_key(experiment.id)):
                stopped_reason = "paused"
                break

        # Check token budget
        if max_tokens > 0 and (total_tokens_in + total_tokens_out) >= max_tokens:
            stopped_reason = "token_budget"
            break

        # Take a batch up to max_concurrent
        batch_size = min(max_concurrent, len(pending_runs))
        batch = pending_runs[:batch_size]
        pending_runs = pending_runs[batch_size:]

        # Run batch concurrently
        await asyncio.gather(*[_run_one(run) for run in batch])

    return SweepResult(
        experiment_id=experiment.id,
        total_runs=len(runs),
        completed_runs=completed,
        failed_runs=failed,
        total_tokens_in=total_tokens_in,
        total_tokens_out=total_tokens_out,
        stopped_reason=stopped_reason,
    )
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Sweep control helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def request_stop(redis_client: Any, experiment_id: uuid.UUID) -> None:
    """Ask a running sweep to stop.

    Writes "1" to the experiment's stop key with ``ex=3600``, which is a
    one-hour expiry for redis-py style clients. The sweep loop checks this
    key before each batch.
    """
    flag_key = _stop_key(experiment_id)
    redis_client.set(flag_key, "1", ex=3600)
|
|
|
|
|
|
def request_pause(redis_client: Any, experiment_id: uuid.UUID) -> None:
    """Ask a running sweep to pause.

    Writes "1" to the experiment's pause key with ``ex=3600``, which is a
    one-hour expiry for redis-py style clients. The sweep loop checks this
    key before each batch.
    """
    flag_key = _pause_key(experiment_id)
    redis_client.set(flag_key, "1", ex=3600)
|
|
|
|
|
|
def request_resume(redis_client: Any, experiment_id: uuid.UUID) -> None:
    """Remove the pause flag for an experiment.

    This only deletes the flag; it does not itself restart any sweep.
    """
    flag_key = _pause_key(experiment_id)
    redis_client.delete(flag_key)
|
|
|
|
|
|
def clear_sweep_flags(redis_client: Any, experiment_id: uuid.UUID) -> None:
    """Delete both control flags (stop, then pause) for an experiment."""
    for flag_key in (_stop_key(experiment_id), _pause_key(experiment_id)):
        redis_client.delete(flag_key)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Internal helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _parse_sweep_config(parameter_space: dict[str, Any]) -> SweepConfig:
    """Parse experiment.parameter_space into a SweepConfig.

    Recognised keys and their defaults:

    - "type": "grid"
    - "params": {}
    - "n_trials": 100
    - "top_k": 5
    - "explore_ratio": 0.3

    A key that is missing *or explicitly null* gets its default. Previously a
    null "params" crashed later during config generation. Numeric fields are
    coerced so JSON values such as ``10.0`` work with ``range()`` and list
    slicing:

    - n_trials and top_k are converted with ``int()``;
    - explore_ratio is converted with ``float()``.

    Raises:
        ValueError / TypeError: If a numeric field cannot be converted.
    """

    def _get(key: str, default: Any) -> Any:
        value = parameter_space.get(key)
        return default if value is None else value

    return SweepConfig(
        sweep_type=_get("type", "grid"),
        params=_get("params", {}),
        n_trials=int(_get("n_trials", 100)),
        top_k=int(_get("top_k", 5)),
        explore_ratio=float(_get("explore_ratio", 0.3)),
    )
|
|
|
|
|
|
def _build_base_config(experiment: Experiment) -> dict[str, Any]:
    """Assemble the shared starting config for every run of ``experiment``.

    Three experiment attributes are copied into the config:

    - ``pipeline_stages``;
    - ``sample_data``, stored under the run-config key "input_data";
    - ``scoring_config``.

    Attributes that are empty or falsy are skipped. The values are
    referenced, not copied.
    """
    # (run-config key, Experiment attribute name), in insertion order.
    sources = (
        ("pipeline_stages", "pipeline_stages"),
        ("input_data", "sample_data"),
        ("scoring_config", "scoring_config"),
    )
    base: dict[str, Any] = {}
    for config_key, attr_name in sources:
        value = getattr(experiment, attr_name)
        if value:
            base[config_key] = value
    return base
|
|
|
|
|
|
def _primary_prompt(cfg: dict[str, Any]) -> Any:
    """Return the prompt a run config is hashed under.

    Uses the first of these that applies:

    1. A top-level "prompt", if present.
    2. The first pipeline stage's "prompt_template".
    3. ``""``.

    Missing, None or empty "pipeline_stages" give ``""`` rather than raising.
    """
    if "prompt" in cfg:
        return cfg["prompt"]
    stages = cfg.get("pipeline_stages") or [{}]
    return stages[0].get("prompt_template", "")


def _create_run_records(
    db: Session,
    experiment: Experiment,
    configs: list[dict[str, Any]],
) -> list[Run]:
    """Create Run records for each config and persist them.

    Each config becomes a pending Run. Its ``config_hash`` is computed from the
    prompt, the "model" (default ``""``), the "params" (default ``{}``) and the
    "input_data". All runs are committed in one transaction and then refreshed.

    Returns:
        The created runs, in the same order as ``configs``.
    """
    runs = []
    for cfg in configs:
        # Resolve the prompt lazily; the old inline default indexed
        # pipeline_stages[0] even when "prompt" was present, and crashed on an
        # empty stage list.
        config_hash = compute_config_hash(
            _primary_prompt(cfg),
            cfg.get("model", ""),
            cfg.get("params", {}),
            cfg.get("input_data"),
        )
        run = Run(
            experiment_id=experiment.id,
            config_hash=config_hash,
            config=cfg,
            status=RunStatus.pending,
        )
        db.add(run)
        runs.append(run)
    db.commit()
    # Reload after commit, presumably so DB-populated attributes such as
    # ``id`` are available to callers (which read run.id) — confirm.
    for run in runs:
        db.refresh(run)
    return runs
|
|
|
|
|
|
def _get_previous_results(
    db: Session, experiment_id: uuid.UUID
) -> list[dict[str, Any]]:
    """Get previous completed runs with their average scores, sorted by score desc.

    Each entry has the shape
    ``{"config": run.config, "score": mean, "run_id": str(run.id)}``.
    Here ``mean`` is the average ``Score.value`` recorded for the run, or
    ``0.0`` when the run has no scores. The sort is stable, so runs with equal
    scores keep their query order.
    """
    runs = (
        db.query(Run)
        .filter(
            Run.experiment_id == experiment_id,
            Run.status == RunStatus.completed,
        )
        .all()
    )
    if not runs:
        return []

    # Fetch the scores of all those runs in one joined query instead of one
    # query per run (N+1). Keys are str()-ed so they match str(run.id).
    score_values: dict[str, list[Any]] = {}
    scores = (
        db.query(Score)
        .join(Run, Score.run_id == Run.id)
        .filter(
            Run.experiment_id == experiment_id,
            Run.status == RunStatus.completed,
        )
        .all()
    )
    for score in scores:
        score_values.setdefault(str(score.run_id), []).append(score.value)

    results = []
    for run in runs:
        values = score_values.get(str(run.id), [])
        avg_score = sum(values) / len(values) if values else 0.0
        results.append({
            "config": run.config,
            "score": avg_score,
            "run_id": str(run.id),
        })

    results.sort(key=lambda r: r["score"], reverse=True)
    return results
|
|
|
|
|
|
def _sample_param(spec: Any) -> Any:
    """Sample a single parameter value from its specification.

    Spec can be:
      - A list: uniform choice among its elements.
      - A dict with "min"/"max" (defaults 0.0 / 1.0): uniform float in range.
      - A dict with "min"/"max"/"step": uniform choice among
        min, min + step, min + 2*step, ... up to max. Max is inclusive within
        a small float tolerance. The result is rounded to 10 decimal places.
        If min > max, min is returned.
      - Anything else: returned unchanged.

    Raises:
        ValueError: If "step" is zero or negative (the range cannot be walked).
        IndexError: If a list spec is empty (raised by ``random.choice``).
    """
    if isinstance(spec, list):
        return random.choice(spec)

    if not isinstance(spec, dict):
        # Scalar fallback — return as-is
        return spec

    lo = spec.get("min", 0.0)
    hi = spec.get("max", 1.0)
    step = spec.get("step")

    if step is None:
        return random.uniform(lo, hi)

    if step <= 0:
        # A non-positive step previously made the stepping loop run forever.
        raise ValueError(f"'step' must be positive, got {step!r}")

    if lo > hi + 1e-9:
        # Empty range: return the minimum, as before.
        return lo

    # Count the grid points instead of materialising them, which is huge for
    # fine steps over wide ranges. Build the chosen point as lo + k*step so
    # float error does not accumulate step by step. randrange(n) consumes
    # randomness exactly like choice() over an n-element list.
    n_points = int((hi - lo) / step + 1e-9) + 1
    return round(lo + random.randrange(n_points) * step, 10)
|
|
|
|
|
|
def _set_config_value(config: dict[str, Any], key: str, value: Any) -> None:
    """Write ``value`` into ``config`` in place at a dotted path.

    The key "params.temperature" sets ``config["params"]["temperature"]``.
    Missing intermediate levels are created as empty dicts; existing ones
    are reused.
    """
    *parents, leaf = key.split(".")
    node = config
    for segment in parents:
        if segment not in node:
            node[segment] = {}
        node = node[segment]
    node[leaf] = value
|
|
|
|
|
|
def _deep_copy_config(config: dict[str, Any]) -> dict[str, Any]:
    """Return an independent copy of ``config`` made by a JSON round-trip.

    The round-trip normalises some values:

    - values JSON cannot encode (e.g. UUIDs) are converted with ``str()``;
    - tuples come back as lists.
    """
    serialized = json.dumps(config, default=str)
    return json.loads(serialized)
|