Add complete experiments API: list (with project filter), get, create, update, delete, plus sweep lifecycle (start/pause/resume/stop/status). Adds SweepRequest and SweepStatusResponse schemas. Sweep dispatch routes through Celery with synchronous fallback for single-container mode. Redis flags control pause/resume/stop; direct DB updates used when Redis unavailable. 34 tests.
315 lines
10 KiB
Python
315 lines
10 KiB
Python
"""Experiments router — CRUD plus sweep control.
|
|
|
|
POST /experiments/{id}/sweep validates sweep config, creates Run records,
|
|
and dispatches to Celery. Pause/resume/stop set Redis flags that the sweep
|
|
runner checks between runs.
|
|
"""
|
|
|
|
import uuid
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
|
from sqlalchemy.orm import Session
|
|
|
|
from auth import get_current_user
|
|
from engine.sweep import clear_sweep_flags, request_pause, request_resume, request_stop
|
|
from engine.tasks import dispatch_sweep
|
|
from main import get_db, get_redis
|
|
from models import Experiment, ExperimentStatus, Project, Run, RunStatus, User
|
|
from schemas import (
|
|
ExperimentCreate,
|
|
ExperimentListResponse,
|
|
ExperimentResponse,
|
|
ExperimentUpdate,
|
|
SweepRequest,
|
|
SweepStatusResponse,
|
|
)
|
|
|
|
# FastAPI router for the experiment CRUD and sweep-control endpoints below;
# any URL prefix is applied where this router is included (not visible here).
router = APIRouter()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _get_experiment_or_404(db: Session, experiment_id: uuid.UUID) -> Experiment:
    """Fetch the experiment with ``experiment_id`` or abort the request.

    Raises:
        HTTPException: 404 when no experiment row has that id.
    """
    found = (
        db.query(Experiment)
        .filter(Experiment.id == experiment_id)
        .first()
    )
    if found is not None:
        return found
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail="Experiment not found",
    )
|
|
|
|
|
|
def _to_response(experiment: Experiment) -> ExperimentResponse:
    """Serialize an ORM ``Experiment`` into its API response schema."""
    response = ExperimentResponse.model_validate(experiment)
    return response
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# CRUD
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@router.get("/", response_model=ExperimentListResponse)
def list_experiments(
    project_id: uuid.UUID | None = Query(None),
    db: Session = Depends(get_db),
    _user: User = Depends(get_current_user),
) -> ExperimentListResponse:
    """List experiments, newest first.

    When ``project_id`` is supplied, only experiments belonging to that
    project are returned. ``total`` is the number of items in this response.
    """
    base = db.query(Experiment)
    scoped = (
        base
        if project_id is None
        else base.filter(Experiment.project_id == project_id)
    )
    rows = scoped.order_by(Experiment.created_at.desc()).all()
    items = [_to_response(row) for row in rows]
    return ExperimentListResponse(items=items, total=len(items))
|
|
|
|
|
|
@router.post("/", response_model=ExperimentResponse, status_code=status.HTTP_201_CREATED)
def create_experiment(
    body: ExperimentCreate,
    project_id: uuid.UUID = Query(...),
    db: Session = Depends(get_db),
    _user: User = Depends(get_current_user),
) -> ExperimentResponse:
    """Create a new experiment within a project.

    The owning project is given by the ``project_id`` query parameter; the
    experiment definition comes from the request body.

    Raises:
        HTTPException: 404 if ``project_id`` does not match an existing project.
    """
    owner = db.query(Project).filter(Project.id == project_id).first()
    if owner is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Project not found")

    # Copy the body's fields onto the new row by name.
    copied_fields = {
        field: getattr(body, field)
        for field in (
            "name",
            "description",
            "sample_data",
            "pipeline_stages",
            "scoring_config",
            "parameter_space",
        )
    }
    new_experiment = Experiment(project_id=project_id, **copied_fields)

    db.add(new_experiment)
    db.commit()
    # Reload so any database-generated values are present before serializing.
    db.refresh(new_experiment)
    return _to_response(new_experiment)
|
|
|
|
|
|
@router.get("/{experiment_id}", response_model=ExperimentResponse)
def get_experiment(
    experiment_id: uuid.UUID,
    db: Session = Depends(get_db),
    _user: User = Depends(get_current_user),
) -> ExperimentResponse:
    """Return a single experiment by id (404 if it does not exist)."""
    return _to_response(_get_experiment_or_404(db, experiment_id))
|
|
|
|
|
|
@router.put("/{experiment_id}", response_model=ExperimentResponse)
def update_experiment(
    experiment_id: uuid.UUID,
    body: ExperimentUpdate,
    db: Session = Depends(get_db),
    _user: User = Depends(get_current_user),
) -> ExperimentResponse:
    """Update an experiment's configuration.

    Only fields that are not None in the body are written; a field sent as
    None (or omitted) keeps its stored value, so this endpoint cannot clear a
    field back to None.
    """
    target = _get_experiment_or_404(db, experiment_id)

    # NOTE(review): ``status`` is writable here directly, which bypasses the
    # state checks in the sweep-control endpoints — confirm this is intended.
    for field in (
        "name",
        "description",
        "sample_data",
        "pipeline_stages",
        "scoring_config",
        "parameter_space",
        "status",
    ):
        new_value = getattr(body, field)
        if new_value is not None:
            setattr(target, field, new_value)

    db.commit()
    db.refresh(target)
    return _to_response(target)
|
|
|
|
|
|
@router.delete("/{experiment_id}", status_code=status.HTTP_204_NO_CONTENT)
def delete_experiment(
    experiment_id: uuid.UUID,
    db: Session = Depends(get_db),
    _user: User = Depends(get_current_user),
) -> None:
    """Delete an experiment (404 if it does not exist).

    Whether its runs are removed as well depends on the cascade configured on
    the ``Experiment``/``Run`` models, which is not visible here — presumably
    they are deleted with it; confirm against the model definition.
    """
    doomed = _get_experiment_or_404(db, experiment_id)
    db.delete(doomed)
    db.commit()
    return None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Sweep control
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _build_sweep_config(body: SweepRequest, parameter_space: dict | None) -> dict:
    """Assemble the sweep-config dict handed to ``dispatch_sweep``.

    ``body.params`` wins when truthy. Otherwise the ``"params"`` entry of the
    experiment's stored ``parameter_space`` is used, or ``{}`` if neither is
    set.
    """
    # Compute the fallback separately. A one-line ``a or b if c else d``
    # parses as ``(a or b) if c else d``, which would drop body.params
    # whenever the experiment has no stored parameter_space.
    stored_params = (parameter_space or {}).get("params", {})
    return {
        "type": body.sweep_type,
        "params": body.params or stored_params,
        "n_trials": body.n_trials,
        "top_k": body.top_k,
        "explore_ratio": body.explore_ratio,
    }


@router.post("/{experiment_id}/sweep", response_model=SweepStatusResponse)
def start_sweep(
    experiment_id: uuid.UUID,
    body: SweepRequest,
    db: Session = Depends(get_db),
    _user: User = Depends(get_current_user),
) -> SweepStatusResponse:
    """Start a sweep: build the sweep config and dispatch it to Celery.

    ``params`` from the request body override the ``"params"`` entry of the
    experiment's stored ``parameter_space`` for this sweep execution. When the
    body omits them, the stored params are used.

    Raises:
        HTTPException: 404 if the experiment does not exist; 409 if it is
            already ``running``.
    """
    experiment = _get_experiment_or_404(db, experiment_id)

    # NOTE(review): only ``running`` is rejected; a ``paused`` experiment can
    # start a fresh sweep here — confirm that is intended.
    if experiment.status == ExperimentStatus.running:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Experiment already has a running sweep",
        )

    sweep_config = _build_sweep_config(body, experiment.parameter_space)

    # Dispatch via Celery, or the synchronous fallback in single-container mode.
    dispatch_sweep(str(experiment.id), sweep_config=sweep_config)

    # dispatch_sweep may have changed the row (e.g. its status) outside this
    # session; reload it before reporting.
    db.refresh(experiment)

    run_counts = _get_run_counts(db, experiment.id)

    return SweepStatusResponse(
        experiment_id=experiment.id,
        status=experiment.status,
        **run_counts,
    )
|
|
|
|
|
|
@router.get("/{experiment_id}/sweep/status", response_model=SweepStatusResponse)
def sweep_status(
    experiment_id: uuid.UUID,
    db: Session = Depends(get_db),
    _user: User = Depends(get_current_user),
) -> SweepStatusResponse:
    """Report the experiment's status together with its per-status run counts."""
    target = _get_experiment_or_404(db, experiment_id)
    payload = {"experiment_id": target.id, "status": target.status}
    payload.update(_get_run_counts(db, target.id))
    return SweepStatusResponse(**payload)
|
|
|
|
|
|
@router.post("/{experiment_id}/pause", status_code=status.HTTP_200_OK)
def pause_sweep(
    experiment_id: uuid.UUID,
    db: Session = Depends(get_db),
    _user: User = Depends(get_current_user),
) -> dict:
    """Pause a running sweep.

    With Redis available this only sets the pause flag for the sweep runner.
    Without Redis the experiment is marked ``paused`` directly in the database.

    Raises:
        HTTPException: 404 if the experiment does not exist; 409 if it is not
            currently ``running``.
    """
    target = _get_experiment_or_404(db, experiment_id)
    if target.status != ExperimentStatus.running:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Experiment is not currently running",
        )

    client = get_redis()
    if client is None:
        # Single-container mode: there is no flag to signal, so record the
        # pause on the row itself.
        target.status = ExperimentStatus.paused
        db.commit()
    else:
        # The runner checks this flag between runs; presumably it updates the
        # experiment status itself once it observes it.
        request_pause(client, target.id)

    return {"experiment_id": str(target.id), "action": "pause"}
|
|
|
|
|
|
@router.post("/{experiment_id}/resume", status_code=status.HTTP_200_OK)
def resume_sweep(
    experiment_id: uuid.UUID,
    db: Session = Depends(get_db),
    _user: User = Depends(get_current_user),
) -> dict:
    """Resume a paused sweep.

    The Redis pause flag is cleared (when Redis is available). The experiment
    is marked ``running`` and the sweep is dispatched again.

    Raises:
        HTTPException: 404 if the experiment does not exist; 409 if it is not
            currently ``paused``.
    """
    target = _get_experiment_or_404(db, experiment_id)
    if target.status != ExperimentStatus.paused:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Experiment is not currently paused",
        )

    client = get_redis()
    if client is not None:
        request_resume(client, target.id)

    target.status = ExperimentStatus.running
    db.commit()

    experiment_key = str(target.id)
    # NOTE(review): no sweep_config is passed, so dispatch_sweep falls back to
    # its own default rather than the config given to start_sweep. Also, if a
    # paused runner is still alive when Redis is used, this may start a second
    # one. Confirm both against engine.tasks / engine.sweep.
    dispatch_sweep(experiment_key)

    return {"experiment_id": experiment_key, "action": "resume"}
|
|
|
|
|
|
@router.post("/{experiment_id}/stop", status_code=status.HTTP_200_OK)
def stop_sweep(
    experiment_id: uuid.UUID,
    db: Session = Depends(get_db),
    _user: User = Depends(get_current_user),
) -> dict:
    """Stop a running or paused sweep.

    With Redis available this only sets the stop flag for the sweep runner.
    Without Redis the database is updated directly: every still-``pending``
    run becomes ``failed`` and the experiment is recorded as ``completed``.

    Raises:
        HTTPException: 404 if the experiment does not exist; 409 if it is
            neither ``running`` nor ``paused``.
    """
    target = _get_experiment_or_404(db, experiment_id)
    stoppable = (ExperimentStatus.running, ExperimentStatus.paused)
    if target.status not in stoppable:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Experiment is not running or paused",
        )

    client = get_redis()
    if client is None:
        # No runner to signal: cancel the outstanding work in the database.
        leftover = (
            db.query(Run)
            .filter(Run.experiment_id == target.id, Run.status == RunStatus.pending)
            .all()
        )
        for leftover_run in leftover:
            leftover_run.status = RunStatus.failed
        target.status = ExperimentStatus.completed
        db.commit()
    else:
        request_stop(client, target.id)

    return {"experiment_id": str(target.id), "action": "stop"}
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Internal helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _get_run_counts(db: Session, experiment_id: uuid.UUID) -> dict:
    """Return run counts by status for a sweep status response.

    Counting is done in the database with one ``GROUP BY status`` query,
    instead of loading every ``Run`` row just to count it.

    Returns:
        A dict with keys ``total_runs``, ``completed_runs``, ``failed_runs``
        and ``pending_runs``, suitable for ``SweepStatusResponse(**counts)``.
        ``total_runs`` counts runs in every status, not only the three that
        are broken out.
    """
    # Local import so this block needs no change to the file's import list.
    from sqlalchemy import func

    rows = (
        db.query(Run.status, func.count())
        .filter(Run.experiment_id == experiment_id)
        .group_by(Run.status)
        .all()
    )

    def _count(wanted: RunStatus) -> int:
        # Compare with == (as the per-row version did) rather than using a dict
        # lookup, so str-valued and enum-valued status columns both match.
        return sum(n for run_status, n in rows if run_status == wanted)

    return {
        "total_runs": sum(n for _, n in rows),
        "completed_runs": _count(RunStatus.completed),
        "failed_runs": _count(RunStatus.failed),
        "pending_runs": _count(RunStatus.pending),
    }
|