Full webhook system: CRUD endpoints (list/filter/get/create/update/delete), WebhookDelivery model for delivery audit trail, dispatch engine with 3-attempt retry and exponential backoff, Celery task integration with sync fallback, and webhook firing hooks in runner.py and sweep.py event paths.
536 lines
17 KiB
Python
536 lines
17 KiB
Python
"""Sweep orchestration for PromptLooper.
|
|
|
|
Supports three sweep types, selected by ``parameter_space["type"]``:

- ``"grid"``: enumerate all combinations from the parameter space
- ``"random"``: sample N random configs from parameter ranges
- ``"guided"``: use previous results to pick the next configs
  (top-K exploitation + random exploration)

The sweep runner manages parallelism, token budgets, and pause/resume/stop
via Redis flags, and emits events for each run completion.
|
|
"""
|
|
|
|
import asyncio
|
|
import itertools
|
|
import json
|
|
import logging
|
|
import random
|
|
import uuid
|
|
from dataclasses import dataclass, field
|
|
from typing import Any
|
|
|
|
from sqlalchemy.orm import Session
|
|
|
|
from engine.adapters.base import BaseAdapter
|
|
from engine.cache import ResponseCacheLayer, compute_config_hash
|
|
from engine.runner import EventBus, run_single
|
|
from models import Experiment, ExperimentStatus, Run, RunStatus, Score
|
|
|
|
# Module-level logger; used to report failed sweep runs and failed
# sweep.completed webhook dispatches.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Sweep configuration dataclasses
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@dataclass
class SweepConfig:
    """Sweep settings parsed from ``experiment.parameter_space``.

    Attributes:
        sweep_type: One of "grid", "random" or "guided".
        params: Parameter search space keyed by config key.
        n_trials: Number of configs to generate (random/guided sweeps).
        top_k: Guided sweeps only -- how many of the best previous results
            to exploit.
        explore_ratio: Guided sweeps only -- fraction of trials spent on
            random exploration.
    """

    sweep_type: str
    params: dict[str, Any] = field(default_factory=dict)
    n_trials: int = field(default=100)
    top_k: int = field(default=5)
    explore_ratio: float = field(default=0.3)
|
|
|
|
|
|
@dataclass
|
|
class SweepResult:
|
|
"""Summary returned after a sweep completes or is stopped."""
|
|
|
|
experiment_id: uuid.UUID
|
|
total_runs: int
|
|
completed_runs: int
|
|
failed_runs: int
|
|
total_tokens_in: int
|
|
total_tokens_out: int
|
|
stopped_reason: str | None = None # "completed", "stopped", "paused", "token_budget"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Redis flag keys for sweep control
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Redis key templates for the per-experiment sweep control flags; the
# "{experiment_id}" placeholder is filled in with str.format by the key helpers.
_SWEEP_KEY_PREFIX = "promptlooper:sweep:{experiment_id}"
_SWEEP_STOP_KEY = _SWEEP_KEY_PREFIX + ":stop"
_SWEEP_PAUSE_KEY = _SWEEP_KEY_PREFIX + ":pause"
|
|
|
|
|
|
def _stop_key(experiment_id: uuid.UUID) -> str:
    """Return the Redis key that holds the stop flag for *experiment_id*."""
    template = _SWEEP_STOP_KEY
    return template.format_map({"experiment_id": experiment_id})
|
|
|
|
|
|
def _pause_key(experiment_id: uuid.UUID) -> str:
    """Return the Redis key that holds the pause flag for *experiment_id*."""
    template = _SWEEP_PAUSE_KEY
    return template.format_map({"experiment_id": experiment_id})
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Configuration generation
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def generate_grid_configs(
    base_config: dict[str, Any],
    param_space: dict[str, list],
) -> list[dict[str, Any]]:
    """Enumerate every combination of the grid values in *param_space*.

    Args:
        base_config: Base run config (pipeline_stages, input_data, etc.).
            It is never mutated; every returned config is an independent
            deep copy.
        param_space: Mapping of config keys to candidate values. Keys may be
            top-level ("model") or dotted ("params.temperature"). Each value
            may be:
              - a list/tuple (or other non-string iterable) of candidates;
              - a dict with "min"/"max"/"step", expanded into discrete steps;
              - any other scalar (including a string), used as the single
                fixed value for that key.

    Returns:
        One fully-formed run config per combination, ordered by
        ``itertools.product`` over the sorted keys. A key whose candidate
        list is empty yields no configs at all.

    Raises:
        ValueError: If a dict spec has no "step" (a continuous range cannot
            be enumerated) or has a non-positive step.
    """
    if not param_space:
        # Deep copy (not dict.copy()) so nested values such as pipeline_stages
        # are not shared with base_config, matching the non-empty path below.
        return [_deep_copy_config(base_config)]

    keys = sorted(param_space)
    value_lists = [_grid_values(k, param_space[k]) for k in keys]

    configs = []
    for combo in itertools.product(*value_lists):
        cfg = _deep_copy_config(base_config)
        for key, value in zip(keys, combo):
            _set_config_value(cfg, key, value)
        configs.append(cfg)

    return configs


def _grid_values(key: str, spec: Any) -> list[Any]:
    """Expand one grid parameter spec into its list of candidate values."""
    if isinstance(spec, dict):
        step = spec.get("step")
        if step is None:
            raise ValueError(
                f"Grid parameter {key!r} needs a list of values or a 'step'; "
                "a continuous min/max range cannot be enumerated"
            )
        if step <= 0:
            raise ValueError(f"Grid parameter {key!r} has non-positive step {step!r}")
        lo = spec.get("min", 0.0)
        hi = spec.get("max", 1.0)
        # Index-based steps avoid accumulating float error; +1e-9 keeps an
        # endpoint that lands exactly on hi.
        count = int((hi - lo + 1e-9) // step) + 1
        return [round(lo + i * step, 10) for i in range(count)] or [lo]
    if isinstance(spec, (str, bytes)):
        # A bare string (e.g. a model name) is one value; iterating it would
        # split it into characters.
        return [spec]
    try:
        return list(spec)
    except TypeError:
        # Non-iterable scalar (number, bool, None): a single fixed value.
        return [spec]
|
|
|
|
|
|
def generate_random_configs(
    base_config: dict[str, Any],
    param_space: dict[str, Any],
    n_trials: int,
) -> list[dict[str, Any]]:
    """Build *n_trials* configs, sampling every parameter independently per trial.

    Each config starts as a deep copy of *base_config*. Every key of
    *param_space* is then set to a value drawn by ``_sample_param`` from its
    spec:

    - a list: one element chosen uniformly
    - a dict with "min"/"max": a uniform value in the continuous range
    - a dict with "min"/"max"/"step": one of the discrete steps
    """

    def _draw_trial() -> dict[str, Any]:
        trial = _deep_copy_config(base_config)
        for name, spec in param_space.items():
            _set_config_value(trial, name, _sample_param(spec))
        return trial

    return [_draw_trial() for _ in range(n_trials)]
|
|
|
|
|
|
def generate_guided_configs(
    base_config: dict[str, Any],
    param_space: dict[str, Any],
    previous_results: list[dict[str, Any]],
    n_trials: int,
    top_k: int = 5,
    explore_ratio: float = 0.3,
) -> list[dict[str, Any]]:
    """Generate configs using top-K exploitation + random exploration.

    Exploitation cycles through the ``top_k`` best previous configs. Each one
    is copied, and a single randomly chosen key of *param_space* is
    re-sampled. Exploration draws fresh configs with
    ``generate_random_configs``. Whenever exploitation happens, at least one
    trial is still exploration.

    Args:
        base_config: Base run config. It is also used in place of a previous
            result whose "config" is missing or None.
        param_space: Parameter search space.
        previous_results: Dicts with "config" and "score" keys from previous
            runs, sorted by score descending.
        n_trials: Total number of configs to generate.
        top_k: Number of top results to exploit. Values below 1 disable
            exploitation.
        explore_ratio: Fraction of trials devoted to random exploration.

    Returns:
        ``n_trials`` configs (none if ``n_trials <= 0``), with exploitation
        configs first and exploration configs after. When there is nothing to
        exploit, every config is random.
    """
    n_explore = max(1, int(n_trials * explore_ratio))
    n_exploit = n_trials - n_explore

    # A non-positive top_k would give an empty (or, for negative values,
    # oddly truncated) slice and then a ZeroDivisionError in the modulo
    # below, so it simply disables exploitation.
    top_results = previous_results[:top_k] if previous_results and top_k > 0 else []

    configs: list[dict[str, Any]] = []

    if top_results and n_exploit > 0:
        perturb_keys = list(param_space) if param_space else []
        for i in range(n_exploit):
            source_cfg = top_results[i % len(top_results)].get("config")
            if source_cfg is None:
                # A stored run config may be absent/null; start from the base.
                source_cfg = base_config
            cfg = _deep_copy_config(source_cfg)
            # Perturb one random parameter around the good config.
            if perturb_keys:
                perturb_key = random.choice(perturb_keys)
                _set_config_value(
                    cfg, perturb_key, _sample_param(param_space[perturb_key])
                )
            configs.append(cfg)
    else:
        # Nothing to exploit -- fall back to all random.
        n_explore = n_trials

    configs.extend(generate_random_configs(base_config, param_space, n_explore))

    return configs[:n_trials]
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Sweep runner
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
async def run_sweep(
|
|
db: Session,
|
|
experiment: Experiment,
|
|
adapter: BaseAdapter,
|
|
cache: ResponseCacheLayer,
|
|
scorers: list[Any] | None = None,
|
|
event_bus: EventBus | None = None,
|
|
redis_client: Any | None = None,
|
|
max_concurrent: int = 4,
|
|
max_tokens: int = 0,
|
|
) -> SweepResult:
|
|
"""Execute a full sweep for an experiment.
|
|
|
|
Generates run configs based on experiment.parameter_space, creates Run
|
|
records, and executes them with bounded parallelism.
|
|
|
|
Args:
|
|
db: SQLAlchemy session.
|
|
experiment: Experiment ORM object.
|
|
adapter: LLM adapter for completions.
|
|
cache: Response cache layer.
|
|
scorers: Optional scorer instances.
|
|
event_bus: Optional event publisher.
|
|
redis_client: Optional Redis client for pause/stop flags.
|
|
max_concurrent: Maximum parallel runs.
|
|
max_tokens: Token budget (0 = unlimited).
|
|
|
|
Returns:
|
|
SweepResult summary.
|
|
"""
|
|
scorers = scorers or []
|
|
sweep_config = _parse_sweep_config(experiment.parameter_space or {})
|
|
base_config = _build_base_config(experiment)
|
|
|
|
# Generate configs based on sweep type
|
|
if sweep_config.sweep_type == "grid":
|
|
configs = generate_grid_configs(base_config, sweep_config.params)
|
|
elif sweep_config.sweep_type == "random":
|
|
configs = generate_random_configs(
|
|
base_config, sweep_config.params, sweep_config.n_trials
|
|
)
|
|
elif sweep_config.sweep_type == "guided":
|
|
# For guided, first gather any previous completed results
|
|
previous = _get_previous_results(db, experiment.id)
|
|
configs = generate_guided_configs(
|
|
base_config,
|
|
sweep_config.params,
|
|
previous,
|
|
sweep_config.n_trials,
|
|
sweep_config.top_k,
|
|
sweep_config.explore_ratio,
|
|
)
|
|
else:
|
|
raise ValueError(f"Unknown sweep type: {sweep_config.sweep_type}")
|
|
|
|
# Create Run records
|
|
runs = _create_run_records(db, experiment, configs)
|
|
|
|
# Update experiment status
|
|
experiment.status = ExperimentStatus.running
|
|
db.commit()
|
|
|
|
if event_bus:
|
|
event_bus.publish({
|
|
"type": "sweep.started",
|
|
"experiment_id": str(experiment.id),
|
|
"total_runs": len(runs),
|
|
"sweep_type": sweep_config.sweep_type,
|
|
})
|
|
|
|
# Execute runs with bounded parallelism
|
|
result = await _execute_runs(
|
|
db=db,
|
|
experiment=experiment,
|
|
runs=runs,
|
|
adapter=adapter,
|
|
cache=cache,
|
|
scorers=scorers,
|
|
event_bus=event_bus,
|
|
redis_client=redis_client,
|
|
max_concurrent=max_concurrent,
|
|
max_tokens=max_tokens,
|
|
)
|
|
|
|
# Update experiment status
|
|
if result.stopped_reason == "paused":
|
|
experiment.status = ExperimentStatus.paused
|
|
else:
|
|
experiment.status = ExperimentStatus.completed
|
|
db.commit()
|
|
|
|
sweep_completed_event = {
|
|
"type": "sweep.completed",
|
|
"experiment_id": str(experiment.id),
|
|
"total_runs": result.total_runs,
|
|
"completed_runs": result.completed_runs,
|
|
"failed_runs": result.failed_runs,
|
|
"stopped_reason": result.stopped_reason,
|
|
}
|
|
if event_bus:
|
|
event_bus.publish(sweep_completed_event)
|
|
|
|
# Fire webhooks asynchronously
|
|
try:
|
|
from engine.tasks import fire_webhooks
|
|
fire_webhooks("sweep.completed", sweep_completed_event)
|
|
except Exception:
|
|
logger.warning("Failed to dispatch sweep.completed webhooks", exc_info=True)
|
|
|
|
return result
|
|
|
|
|
|
async def _execute_runs(
    db: Session,
    experiment: Experiment,
    runs: list[Run],
    adapter: BaseAdapter,
    cache: ResponseCacheLayer,
    scorers: list[Any],
    event_bus: EventBus | None,
    redis_client: Any | None,
    max_concurrent: int,
    max_tokens: int,
) -> SweepResult:
    """Execute runs in bounded batches, respecting stop/pause flags and the token budget.

    Runs are taken from the front of ``runs`` in batches of at most
    ``max_concurrent``, and each batch is awaited with ``asyncio.gather``.
    Two checks happen before every batch: the Redis stop/pause flags (when
    ``redis_client`` is given) and the token budget. An in-flight batch
    therefore always finishes, and the budget can be exceeded by up to one
    batch.

    Args:
        db: SQLAlchemy session shared by all runs.
        experiment: Experiment whose runs are executed.
        runs: Run records to execute, in order.
        adapter: LLM adapter passed to ``run_single``.
        cache: Response cache passed to ``run_single``.
        scorers: Scorer instances passed to ``run_single``.
        event_bus: Optional publisher for "sweep.run_completed" events.
        redis_client: Optional Redis client to read the stop/pause flags from.
        max_concurrent: Batch size; values below 1 are treated as 1.
        max_tokens: Combined input+output token budget; 0 disables it.

    Returns:
        SweepResult whose ``stopped_reason`` is "completed", "stopped",
        "paused" or "token_budget". Runs cancelled by a stop request are set
        to ``RunStatus.failed`` in the database, but ``failed_runs`` does not
        include them; it counts only runs whose execution raised.
    """
    if max_concurrent < 1:
        # A batch size of 0 never drains pending_runs (an endless loop), and a
        # negative value makes asyncio.Semaphore raise, so run serially.
        logger.warning(
            "max_concurrent=%s is invalid for experiment %s; using 1",
            max_concurrent,
            experiment.id,
        )
        max_concurrent = 1

    semaphore = asyncio.Semaphore(max_concurrent)
    completed = 0
    failed = 0
    total_tokens_in = 0
    total_tokens_out = 0
    stopped_reason: str | None = "completed"

    async def _run_one(run: Run) -> None:
        # Execute one run. Errors are logged and counted instead of raised,
        # so one bad run does not abort the rest of its batch.
        nonlocal completed, failed, total_tokens_in, total_tokens_out
        async with semaphore:
            try:
                result = await run_single(
                    db, run, adapter, cache, scorers=scorers, event_bus=event_bus
                )
                completed += 1
                total_tokens_in += result.tokens_in or 0
                total_tokens_out += result.tokens_out or 0

                if event_bus:
                    event_bus.publish({
                        "type": "sweep.run_completed",
                        "experiment_id": str(experiment.id),
                        "run_id": str(run.id),
                        "completed": completed,
                        "total": len(runs),
                    })
            except Exception:
                failed += 1
                logger.warning(
                    "Sweep run %s failed", run.id, exc_info=True
                )

    pending_runs = list(runs)
    while pending_runs:
        # Check stop/pause flags
        if redis_client is not None:
            if redis_client.get(_stop_key(experiment.id)):
                stopped_reason = "stopped"
                # Mark remaining runs as failed
                for r in pending_runs:
                    if r.status == RunStatus.pending:
                        r.status = RunStatus.failed
                db.commit()
                break
            if redis_client.get(_pause_key(experiment.id)):
                # Unstarted runs stay pending.
                stopped_reason = "paused"
                break

        # Check token budget
        if max_tokens > 0 and (total_tokens_in + total_tokens_out) >= max_tokens:
            stopped_reason = "token_budget"
            break

        # Take a batch up to max_concurrent
        batch_size = min(max_concurrent, len(pending_runs))
        batch = pending_runs[:batch_size]
        pending_runs = pending_runs[batch_size:]

        # Run batch concurrently
        await asyncio.gather(*[_run_one(run) for run in batch])

    return SweepResult(
        experiment_id=experiment.id,
        total_runs=len(runs),
        completed_runs=completed,
        failed_runs=failed,
        total_tokens_in=total_tokens_in,
        total_tokens_out=total_tokens_out,
        stopped_reason=stopped_reason,
    )
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Sweep control helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def request_stop(redis_client: Any, experiment_id: uuid.UUID) -> None:
    """Ask a running sweep to stop.

    Sets the experiment's stop flag with a 3600-second expiry (``ex``).
    ``_execute_runs`` checks the flag before each batch; once it sees it,
    the runs that have not started are marked failed.
    """
    flag_key = _stop_key(experiment_id)
    ttl_seconds = 3600
    redis_client.set(flag_key, "1", ex=ttl_seconds)
|
|
|
|
|
|
def request_pause(redis_client: Any, experiment_id: uuid.UUID) -> None:
    """Ask a running sweep to pause.

    Sets the experiment's pause flag with a 3600-second expiry (``ex``).
    ``_execute_runs`` checks the flag before each batch and returns with
    ``stopped_reason="paused"``, leaving unstarted runs pending.
    """
    flag_key = _pause_key(experiment_id)
    ttl_seconds = 3600
    redis_client.set(flag_key, "1", ex=ttl_seconds)
|
|
|
|
|
|
def request_resume(redis_client: Any, experiment_id: uuid.UUID) -> None:
    """Remove the experiment's pause flag.

    This only deletes the flag; it does not itself restart any execution.
    """
    pause_flag = _pause_key(experiment_id)
    redis_client.delete(pause_flag)
|
|
|
|
|
|
def clear_sweep_flags(redis_client: Any, experiment_id: uuid.UUID) -> None:
    """Delete both the stop and the pause flag for *experiment_id* (stop first)."""
    for flag_key in (_stop_key(experiment_id), _pause_key(experiment_id)):
        redis_client.delete(flag_key)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Internal helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _parse_sweep_config(parameter_space: dict[str, Any]) -> SweepConfig:
    """Parse ``experiment.parameter_space`` into a SweepConfig.

    Recognized keys are "type", "params", "n_trials", "top_k" and
    "explore_ratio". A missing key and an explicit ``None`` (JSON ``null``)
    both fall back to the default. Numeric fields are coerced with
    ``int``/``float``, so values stored as strings (e.g. "50") also work.

    Raises:
        ValueError / TypeError: If a numeric field cannot be converted.
    """

    def _setting(key: str, default: Any) -> Any:
        # Treat an explicit null like a missing key; dict.get(key, default)
        # alone would pass None through and crash later (e.g. range(None)).
        value = parameter_space.get(key)
        return default if value is None else value

    return SweepConfig(
        sweep_type=_setting("type", "grid"),
        params=_setting("params", {}),
        n_trials=int(_setting("n_trials", 100)),
        top_k=int(_setting("top_k", 5)),
        explore_ratio=float(_setting("explore_ratio", 0.3)),
    )
|
|
|
|
|
|
def _build_base_config(experiment: Experiment) -> dict[str, Any]:
    """Collect the experiment-level settings every generated run config starts from.

    Only truthy fields are included. ``experiment.sample_data`` is stored
    under the "input_data" key.
    """
    candidates = (
        ("pipeline_stages", experiment.pipeline_stages),
        ("input_data", experiment.sample_data),
        ("scoring_config", experiment.scoring_config),
    )
    return {name: value for name, value in candidates if value}
|
|
|
|
|
|
def _create_run_records(
    db: Session,
    experiment: Experiment,
    configs: list[dict[str, Any]],
) -> list[Run]:
    """Create one pending Run per config, commit them, and return them refreshed.

    The config hash is computed by ``compute_config_hash`` from four inputs:
    the primary prompt (see ``_primary_prompt``), "model" (default ""),
    "params" (default {}) and "input_data".
    """
    runs = []
    for cfg in configs:
        config_hash = compute_config_hash(
            _primary_prompt(cfg),
            cfg.get("model", ""),
            cfg.get("params", {}),
            cfg.get("input_data"),
        )
        run = Run(
            experiment_id=experiment.id,
            config_hash=config_hash,
            config=cfg,
            status=RunStatus.pending,
        )
        db.add(run)
        runs.append(run)

    db.commit()
    for run in runs:
        db.refresh(run)
    return runs


def _primary_prompt(cfg: dict[str, Any]) -> Any:
    """Return cfg["prompt"] if present, else the first pipeline stage's "prompt_template", else ""."""
    if "prompt" in cfg:
        return cfg["prompt"]
    # Guard the fallback: the previous inline expression was evaluated even
    # when "prompt" was set, and raised on empty/None pipeline_stages.
    stages = cfg.get("pipeline_stages")
    if isinstance(stages, (list, tuple)) and stages and isinstance(stages[0], dict):
        return stages[0].get("prompt_template", "")
    return ""
|
|
|
|
|
|
def _get_previous_results(
    db: Session, experiment_id: uuid.UUID
) -> list[dict[str, Any]]:
    """Return the experiment's completed runs with their mean score, best first.

    Each entry has the form
    ``{"config": run.config, "score": <mean Score.value>, "run_id": str(run.id)}``.
    A run with no Score rows gets a score of 0.0. Because the sort is stable,
    ties keep the order in which the runs query returned them.

    Scores come from a single query joined on Run, not one query per run.
    """
    runs = (
        db.query(Run)
        .filter(
            Run.experiment_id == experiment_id,
            Run.status == RunStatus.completed,
        )
        .all()
    )
    if not runs:
        return []

    # Same run filter as above, joined so that every relevant Score arrives
    # in one round trip instead of one query per run.
    score_rows = (
        db.query(Score)
        .join(Run, Score.run_id == Run.id)
        .filter(
            Run.experiment_id == experiment_id,
            Run.status == RunStatus.completed,
        )
        .all()
    )
    values_by_run: dict[Any, list[Any]] = {}
    for score in score_rows:
        values_by_run.setdefault(score.run_id, []).append(score.value)

    results = []
    for run in runs:
        values = values_by_run.get(run.id, [])
        avg_score = sum(values) / len(values) if values else 0.0
        results.append({
            "config": run.config,
            "score": avg_score,
            "run_id": str(run.id),
        })

    results.sort(key=lambda r: r["score"], reverse=True)
    return results
|
|
|
|
|
|
def _sample_param(spec: Any) -> Any:
    """Draw one value for a parameter from its search-space spec.

    Spec can be:
        - a list: one element chosen uniformly;
        - a dict with "min"/"max" (defaults 0.0/1.0): a uniform value in the range;
        - a dict with "min"/"max"/"step": one of min, min+step, ... up to max
          (with a 1e-9 tolerance), each rounded to 10 decimals. If that grid
          is empty (max < min), "min" is returned;
        - anything else: returned unchanged.

    Raises:
        ValueError: If "step" is zero or negative; such a grid would never end.
    """
    if isinstance(spec, list):
        return random.choice(spec)
    if not isinstance(spec, dict):
        # Scalar fallback -- a fixed value.
        return spec

    lo = spec.get("min", 0.0)
    hi = spec.get("max", 1.0)
    step = spec.get("step")
    if step is None:
        return random.uniform(lo, hi)
    if step <= 0:
        raise ValueError(f"step must be positive, got {step!r}")

    # Compute the steps by index rather than repeated addition, so float
    # error does not accumulate; +1e-9 keeps an endpoint that lands on hi.
    count = int((hi - lo + 1e-9) // step) + 1
    steps = [round(lo + i * step, 10) for i in range(count)]
    return random.choice(steps) if steps else lo
|
|
|
|
|
|
def _set_config_value(config: dict[str, Any], key: str, value: Any) -> None:
    """Set *value* in *config* in place, supporting dotted keys like 'params.temperature'.

    Intermediate dicts along a dotted key are created when they are missing
    or ``None``. An explicit null is common in JSON-sourced configs.

    Raises:
        TypeError: If an intermediate value exists but is not a dict, e.g.
            ``{"params": "x"}`` with key ``"params.temperature"``.
    """
    *parents, leaf = key.split(".")
    target = config
    for depth, part in enumerate(parents):
        child = target.get(part)
        if child is None:
            # Treat a null intermediate exactly like a missing one.
            child = target[part] = {}
        elif not isinstance(child, dict):
            path = ".".join(parents[: depth + 1])
            raise TypeError(
                f"Cannot set {key!r}: {path!r} is a {type(child).__name__}, not a dict"
            )
        target = child
    target[leaf] = value
|
|
|
|
|
|
def _deep_copy_config(config: dict[str, Any]) -> dict[str, Any]:
    """Return an independent copy of *config* made by a JSON round-trip.

    Values that JSON cannot encode are converted with ``str`` (``default=str``),
    and tuples come back as lists.
    """
    as_text = json.dumps(config, default=str)
    return json.loads(as_text)
|