promptlooper/backend/routers/experiments.py
John Lightner 82e97e9dba MAESTRO: Implement experiments router with full CRUD and sweep control endpoints
Add complete experiments API: list (with project filter), get, create, update,
delete, plus sweep lifecycle (start/pause/resume/stop/status). Adds
SweepRequest and SweepStatusResponse schemas. Sweep dispatch routes through
Celery with synchronous fallback for single-container mode. Redis flags control
pause/resume/stop; direct DB updates used when Redis unavailable. 34 tests.
2026-04-07 03:19:43 -05:00

315 lines
10 KiB
Python

"""Experiments router — CRUD plus sweep control.
POST /experiments/{id}/sweep validates sweep config, creates Run records,
and dispatches to Celery. Pause/resume/stop set Redis flags that the sweep
runner checks between runs.
"""
import uuid

from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import func
from sqlalchemy.orm import Session

from auth import get_current_user
from engine.sweep import clear_sweep_flags, request_pause, request_resume, request_stop
from engine.tasks import dispatch_sweep
from main import get_db, get_redis
from models import Experiment, ExperimentStatus, Project, Run, RunStatus, User
from schemas import (
    ExperimentCreate,
    ExperimentListResponse,
    ExperimentResponse,
    ExperimentUpdate,
    SweepRequest,
    SweepStatusResponse,
)
# Router holding the experiment CRUD and sweep-control endpoints below.
# Paths here are relative ("/", "/{experiment_id}", ...); the module docstring
# refers to them as /experiments/..., so the prefix is presumably applied where
# this router is included — confirm in the application setup.
router = APIRouter()
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _get_experiment_or_404(db: Session, experiment_id: uuid.UUID) -> Experiment:
    """Fetch an experiment by primary key, aborting the request if it is absent.

    Raises:
        HTTPException: 404 "Experiment not found" when no row has that id.
    """
    found = db.query(Experiment).filter(Experiment.id == experiment_id).first()
    if found is not None:
        return found
    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Experiment not found")
def _to_response(experiment: Experiment) -> ExperimentResponse:
    """Convert an ORM ``Experiment`` into the ``ExperimentResponse`` API schema."""
    serialized = ExperimentResponse.model_validate(experiment)
    return serialized
# ---------------------------------------------------------------------------
# CRUD
# ---------------------------------------------------------------------------
@router.get("/", response_model=ExperimentListResponse)
def list_experiments(
    project_id: uuid.UUID | None = Query(None),
    db: Session = Depends(get_db),
    _user: User = Depends(get_current_user),
) -> ExperimentListResponse:
    """Return experiments ordered newest first.

    If ``project_id`` is given, only that project's experiments are listed.
    ``total`` equals the number of returned items (no pagination is applied).
    """
    base_query = db.query(Experiment)
    scoped_query = (
        base_query
        if project_id is None
        else base_query.filter(Experiment.project_id == project_id)
    )
    rows = scoped_query.order_by(Experiment.created_at.desc()).all()
    items = list(map(_to_response, rows))
    return ExperimentListResponse(items=items, total=len(items))
@router.post("/", response_model=ExperimentResponse, status_code=status.HTTP_201_CREATED)
def create_experiment(
    body: ExperimentCreate,
    project_id: uuid.UUID = Query(...),
    db: Session = Depends(get_db),
    _user: User = Depends(get_current_user),
) -> ExperimentResponse:
    """Create an experiment under the project identified by ``project_id``.

    Raises:
        HTTPException: 404 "Project not found" if the parent project is missing.
    """
    parent = db.query(Project).filter(Project.id == project_id).first()
    if parent is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Project not found")

    # Copy the configurable fields straight from the request body.
    copied_fields = {
        field_name: getattr(body, field_name)
        for field_name in (
            "name",
            "description",
            "sample_data",
            "pipeline_stages",
            "scoring_config",
            "parameter_space",
        )
    }
    new_experiment = Experiment(project_id=project_id, **copied_fields)

    db.add(new_experiment)
    db.commit()
    # Reload server-populated columns (id, timestamps, defaults) before serializing.
    db.refresh(new_experiment)
    return _to_response(new_experiment)
@router.get("/{experiment_id}", response_model=ExperimentResponse)
def get_experiment(
    experiment_id: uuid.UUID,
    db: Session = Depends(get_db),
    _user: User = Depends(get_current_user),
) -> ExperimentResponse:
    """Return one experiment by id (404 if it does not exist)."""
    return _to_response(_get_experiment_or_404(db, experiment_id))
@router.put("/{experiment_id}", response_model=ExperimentResponse)
def update_experiment(
    experiment_id: uuid.UUID,
    body: ExperimentUpdate,
    db: Session = Depends(get_db),
    _user: User = Depends(get_current_user),
) -> ExperimentResponse:
    """Apply a partial update to an experiment's configuration.

    Only fields whose value in ``body`` is not ``None`` are written, so a field
    cannot be cleared to ``None`` through this endpoint. ``status`` is copied
    as-is, without any sweep-lifecycle validation.
    """
    experiment = _get_experiment_or_404(db, experiment_id)
    for field_name in (
        "name",
        "description",
        "sample_data",
        "pipeline_stages",
        "scoring_config",
        "parameter_space",
        "status",
    ):
        new_value = getattr(body, field_name)
        if new_value is not None:
            setattr(experiment, field_name, new_value)
    db.commit()
    db.refresh(experiment)
    return _to_response(experiment)
@router.delete("/{experiment_id}", status_code=status.HTTP_204_NO_CONTENT)
def delete_experiment(
    experiment_id: uuid.UUID,
    db: Session = Depends(get_db),
    _user: User = Depends(get_current_user),
) -> None:
    """Delete an experiment (404 if it does not exist).

    Its runs are expected to be removed too; that relies on the ORM/database
    cascade configured on the models (not visible here — confirm).
    """
    doomed = _get_experiment_or_404(db, experiment_id)
    db.delete(doomed)
    db.commit()
# ---------------------------------------------------------------------------
# Sweep control
# ---------------------------------------------------------------------------
@router.post("/{experiment_id}/sweep", response_model=SweepStatusResponse)
def start_sweep(
    experiment_id: uuid.UUID,
    body: SweepRequest,
    db: Session = Depends(get_db),
    _user: User = Depends(get_current_user),
) -> SweepStatusResponse:
    """Start a sweep: build the sweep config and hand it to ``dispatch_sweep``.

    ``body.params`` (when non-empty) overrides the ``"params"`` entry of the
    experiment's stored ``parameter_space`` for this sweep execution. When
    ``body.params`` is empty/None, the stored params are used, falling back to
    ``{}`` if the experiment has no parameter space. Run records are presumably
    created by the engine behind ``dispatch_sweep`` (not visible here).

    Returns:
        The experiment's status and run counts after dispatch.

    Raises:
        HTTPException: 404 if the experiment does not exist, 409 if it is
            already running.
    """
    experiment = _get_experiment_or_404(db, experiment_id)
    if experiment.status == ExperimentStatus.running:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Experiment already has a running sweep",
        )
    # NOTE(review): a *paused* experiment is not rejected, so a new sweep can be
    # dispatched while the paused one still has pending runs — confirm intent.

    # Resolve params in two explicit steps so a request-supplied value always
    # wins, even when the experiment has no stored parameter_space. (A one-line
    # `a or b if c else {}` parses as `(a or b) if c else {}` and drops `a`.)
    stored_params = (experiment.parameter_space or {}).get("params", {})
    sweep_config = {
        "type": body.sweep_type,
        "params": body.params or stored_params,
        "n_trials": body.n_trials,
        "top_k": body.top_k,
        "explore_ratio": body.explore_ratio,
    }

    # Dispatch via Celery / sync fallback
    dispatch_sweep(str(experiment.id), sweep_config=sweep_config)
    # Refresh to get the updated status (dispatch_sweep may have set it)
    db.refresh(experiment)

    run_counts = _get_run_counts(db, experiment.id)
    return SweepStatusResponse(
        experiment_id=experiment.id,
        status=experiment.status,
        **run_counts,
    )
@router.get("/{experiment_id}/sweep/status", response_model=SweepStatusResponse)
def sweep_status(
    experiment_id: uuid.UUID,
    db: Session = Depends(get_db),
    _user: User = Depends(get_current_user),
) -> SweepStatusResponse:
    """Report an experiment's current status alongside its per-status run counts."""
    experiment = _get_experiment_or_404(db, experiment_id)
    counts = _get_run_counts(db, experiment.id)
    return SweepStatusResponse(experiment_id=experiment.id, status=experiment.status, **counts)
@router.post("/{experiment_id}/pause", status_code=status.HTTP_200_OK)
def pause_sweep(
    experiment_id: uuid.UUID,
    db: Session = Depends(get_db),
    _user: User = Depends(get_current_user),
) -> dict:
    """Request that a running sweep pause.

    With Redis available only the pause flag is set; the experiment's status is
    left unchanged here (presumably the sweep runner updates it when it sees the
    flag). Without Redis the status is set to ``paused`` directly.

    Raises:
        HTTPException: 404 if missing, 409 if the experiment is not running.
    """
    experiment = _get_experiment_or_404(db, experiment_id)
    if experiment.status != ExperimentStatus.running:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Experiment is not currently running",
        )

    redis_client = get_redis()
    if redis_client is None:
        # Single-container mode: no flag channel, so flip the status directly.
        experiment.status = ExperimentStatus.paused
        db.commit()
    else:
        request_pause(redis_client, experiment.id)
    return {"experiment_id": str(experiment.id), "action": "pause"}
@router.post("/{experiment_id}/resume", status_code=status.HTTP_200_OK)
def resume_sweep(
    experiment_id: uuid.UUID,
    db: Session = Depends(get_db),
    _user: User = Depends(get_current_user),
) -> dict:
    """Resume a paused sweep.

    Clears the Redis pause flag when Redis is available, marks the experiment
    ``running``, then re-dispatches the sweep (with no explicit sweep_config)
    so execution continues.

    Raises:
        HTTPException: 404 if missing, 409 if the experiment is not paused.
    """
    experiment = _get_experiment_or_404(db, experiment_id)
    if experiment.status != ExperimentStatus.paused:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Experiment is not currently paused",
        )

    redis_client = get_redis()
    if redis_client is not None:
        request_resume(redis_client, experiment.id)

    # Persist the running status before dispatching the continuation.
    experiment.status = ExperimentStatus.running
    db.commit()
    experiment_key = str(experiment.id)
    dispatch_sweep(experiment_key)
    return {"experiment_id": experiment_key, "action": "resume"}
@router.post("/{experiment_id}/stop", status_code=status.HTTP_200_OK)
def stop_sweep(
    experiment_id: uuid.UUID,
    db: Session = Depends(get_db),
    _user: User = Depends(get_current_user),
) -> dict:
    """Stop a running or paused sweep.

    When Redis is available the stop flag is set so an active sweep runner
    halts between runs. The experiment is finalized directly in the database
    (pending runs -> ``failed``, experiment -> ``completed``) when Redis is
    unavailable, or when the sweep is paused. ``resume_sweep`` re-dispatches
    a paused sweep, which implies no runner remains to observe the flag, so
    without the direct update a paused experiment would stay ``paused``.

    Raises:
        HTTPException: 404 if missing, 409 if neither running nor paused.
    """
    experiment = _get_experiment_or_404(db, experiment_id)
    if experiment.status not in (ExperimentStatus.running, ExperimentStatus.paused):
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Experiment is not running or paused",
        )
    was_paused = experiment.status == ExperimentStatus.paused

    redis_client = get_redis()
    if redis_client is not None:
        # Still set the flag for a paused sweep, in case a runner is alive.
        request_stop(redis_client, experiment.id)

    if redis_client is None or was_paused:
        # Mark remaining pending runs as failed and close out the experiment.
        pending_runs = (
            db.query(Run)
            .filter(Run.experiment_id == experiment.id, Run.status == RunStatus.pending)
            .all()
        )
        for run in pending_runs:
            run.status = RunStatus.failed
        experiment.status = ExperimentStatus.completed
        db.commit()
    return {"experiment_id": str(experiment.id), "action": "stop"}
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
def _get_run_counts(db: Session, experiment_id: uuid.UUID) -> dict:
    """Return run counts by status for a sweep status response.

    Counts are aggregated in the database with one ``GROUP BY status`` query,
    so full ``Run`` rows (and whatever data they carry) are never loaded.

    Returns:
        dict with ``total_runs``, ``completed_runs``, ``failed_runs`` and
        ``pending_runs``. ``total_runs`` counts runs in every status, so it can
        exceed the sum of the three named buckets if other statuses exist.
    """
    per_status = (
        db.query(Run.status, func.count())
        .filter(Run.experiment_id == experiment_id)
        .group_by(Run.status)
        .all()
    )

    def _count(wanted: RunStatus) -> int:
        # Match with == (not a dict lookup) so stored values compare against
        # RunStatus members exactly as a direct equality check would.
        return sum(n for run_status, n in per_status if run_status == wanted)

    return {
        "total_runs": sum(n for _, n in per_status),
        "completed_runs": _count(RunStatus.completed),
        "failed_runs": _count(RunStatus.failed),
        "pending_runs": _count(RunStatus.pending),
    }