"""Experiments router — CRUD plus sweep control.

POST /experiments/{id}/sweep builds a sweep config from the request (falling
back to the experiment's stored ``parameter_space``) and hands it to
``engine.tasks.dispatch_sweep`` (Celery, or a synchronous fallback). Run
record creation is not visible in this module; presumably the engine does it.
Pause/resume/stop set Redis flags that the sweep runner checks between runs;
when no Redis client is available, pause/stop update the database directly.
"""

import uuid

from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import func
from sqlalchemy.orm import Session

from auth import get_current_user
from engine.sweep import clear_sweep_flags, request_pause, request_resume, request_stop
from engine.tasks import dispatch_sweep
from main import get_db, get_redis
from models import Experiment, ExperimentStatus, Project, Run, RunStatus, User
from schemas import (
    ExperimentCreate,
    ExperimentListResponse,
    ExperimentResponse,
    ExperimentUpdate,
    SweepRequest,
    SweepStatusResponse,
)

router = APIRouter()

# ExperimentUpdate fields copied 1:1 onto the Experiment model by
# update_experiment, in this order; fields left as None are not touched.
_UPDATABLE_FIELDS = (
    "name",
    "description",
    "sample_data",
    "pipeline_stages",
    "scoring_config",
    "parameter_space",
    "status",
)


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _get_experiment_or_404(db: Session, experiment_id: uuid.UUID) -> Experiment:
    """Load an experiment by id.

    Raises:
        HTTPException: 404 if no experiment has that id.
    """
    experiment = db.query(Experiment).filter(Experiment.id == experiment_id).first()
    if experiment is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Experiment not found")
    return experiment


def _to_response(experiment: Experiment) -> ExperimentResponse:
    """Convert an Experiment ORM object into its API response schema."""
    return ExperimentResponse.model_validate(experiment)


# ---------------------------------------------------------------------------
# CRUD
# ---------------------------------------------------------------------------


@router.get("/", response_model=ExperimentListResponse)
def list_experiments(
    project_id: uuid.UUID | None = Query(None),
    db: Session = Depends(get_db),
    _user: User = Depends(get_current_user),
) -> ExperimentListResponse:
    """List experiments, newest first, optionally filtered by project.

    ``total`` is the number of items returned; there is no pagination.
    """
    query = db.query(Experiment)
    if project_id is not None:
        query = query.filter(Experiment.project_id == project_id)
    experiments = query.order_by(Experiment.created_at.desc()).all()
    return ExperimentListResponse(
        items=[_to_response(exp) for exp in experiments],
        total=len(experiments),
    )


@router.post("/", response_model=ExperimentResponse, status_code=status.HTTP_201_CREATED)
def create_experiment(
    body: ExperimentCreate,
    project_id: uuid.UUID = Query(...),
    db: Session = Depends(get_db),
    _user: User = Depends(get_current_user),
) -> ExperimentResponse:
    """Create a new experiment within an existing project.

    Raises:
        HTTPException: 404 if ``project_id`` does not match a project.
    """
    # Verify the parent project exists before creating the child row.
    project = db.query(Project).filter(Project.id == project_id).first()
    if project is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Project not found")

    experiment = Experiment(
        project_id=project_id,
        name=body.name,
        description=body.description,
        sample_data=body.sample_data,
        pipeline_stages=body.pipeline_stages,
        scoring_config=body.scoring_config,
        parameter_space=body.parameter_space,
    )
    db.add(experiment)
    db.commit()
    db.refresh(experiment)
    return _to_response(experiment)


@router.get("/{experiment_id}", response_model=ExperimentResponse)
def get_experiment(
    experiment_id: uuid.UUID,
    db: Session = Depends(get_db),
    _user: User = Depends(get_current_user),
) -> ExperimentResponse:
    """Get a single experiment (404 if missing)."""
    experiment = _get_experiment_or_404(db, experiment_id)
    return _to_response(experiment)


@router.put("/{experiment_id}", response_model=ExperimentResponse)
def update_experiment(
    experiment_id: uuid.UUID,
    body: ExperimentUpdate,
    db: Session = Depends(get_db),
    _user: User = Depends(get_current_user),
) -> ExperimentResponse:
    """Partially update an experiment's configuration.

    Only fields that are not None in ``body`` are applied, so a field cannot
    be cleared to None through this endpoint.
    """
    experiment = _get_experiment_or_404(db, experiment_id)
    # NOTE(review): this also lets clients set ``status`` directly, which
    # bypasses the sweep-control endpoints below — confirm that is intended.
    for field in _UPDATABLE_FIELDS:
        value = getattr(body, field)
        if value is not None:
            setattr(experiment, field, value)
    db.commit()
    db.refresh(experiment)
    return _to_response(experiment)


@router.delete("/{experiment_id}", status_code=status.HTTP_204_NO_CONTENT)
def delete_experiment(
    experiment_id: uuid.UUID,
    db: Session = Depends(get_db),
    _user: User = Depends(get_current_user),
) -> None:
    """Delete an experiment.

    Removal of its runs presumably relies on an ORM/DB cascade defined in
    ``models`` — confirm there.
    """
    experiment = _get_experiment_or_404(db, experiment_id)
    # NOTE(review): a running/paused sweep is not stopped before deletion.
    db.delete(experiment)
    db.commit()


# ---------------------------------------------------------------------------
# Sweep control
# ---------------------------------------------------------------------------


def _build_sweep_config(body: SweepRequest, experiment: Experiment) -> dict:
    """Assemble the sweep config dict passed to ``dispatch_sweep``.

    Non-empty ``body.params`` always win. Otherwise fall back to
    ``experiment.parameter_space["params"]``, or ``{}`` when the experiment
    has no parameter space or no "params" entry.
    """
    # Explicit branches: the former one-line ``a or b if c else d`` parsed as
    # ``(a or b) if c else d`` and dropped request params whenever the
    # experiment had no parameter_space.
    if body.params:
        params = body.params
    else:
        params = (experiment.parameter_space or {}).get("params", {})
    return {
        "type": body.sweep_type,
        "params": params,
        "n_trials": body.n_trials,
        "top_k": body.top_k,
        "explore_ratio": body.explore_ratio,
    }


def _sweep_status_response(db: Session, experiment: Experiment) -> SweepStatusResponse:
    """Build a SweepStatusResponse from the experiment's status and run counts."""
    return SweepStatusResponse(
        experiment_id=experiment.id,
        status=experiment.status,
        **_get_run_counts(db, experiment.id),
    )


@router.post("/{experiment_id}/sweep", response_model=SweepStatusResponse)
def start_sweep(
    experiment_id: uuid.UUID,
    body: SweepRequest,
    db: Session = Depends(get_db),
    _user: User = Depends(get_current_user),
) -> SweepStatusResponse:
    """Start a sweep and dispatch it via ``dispatch_sweep``.

    The sweep_config from the request body overrides the experiment's
    parameter_space for this sweep execution only; the stored experiment is
    not modified. Request validation is performed by the ``SweepRequest``
    schema.

    Raises:
        HTTPException: 404 if the experiment is missing; 409 if it is already
            running.
    """
    experiment = _get_experiment_or_404(db, experiment_id)
    if experiment.status == ExperimentStatus.running:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Experiment already has a running sweep",
        )
    # NOTE(review): a *paused* experiment is not rejected, and
    # clear_sweep_flags (imported above) is never called, so a stale
    # pause/stop flag from a previous sweep may still be set in Redis —
    # confirm whether dispatch_sweep clears them.

    sweep_config = _build_sweep_config(body, experiment)

    # Dispatch via Celery / sync fallback
    dispatch_sweep(str(experiment.id), sweep_config=sweep_config)

    # Refresh to pick up any status change written by dispatch_sweep.
    db.refresh(experiment)
    return _sweep_status_response(db, experiment)


@router.get("/{experiment_id}/sweep/status", response_model=SweepStatusResponse)
def sweep_status(
    experiment_id: uuid.UUID,
    db: Session = Depends(get_db),
    _user: User = Depends(get_current_user),
) -> SweepStatusResponse:
    """Get the current sweep status and run counts for an experiment."""
    experiment = _get_experiment_or_404(db, experiment_id)
    return _sweep_status_response(db, experiment)


@router.post("/{experiment_id}/pause", status_code=status.HTTP_200_OK)
def pause_sweep(
    experiment_id: uuid.UUID,
    db: Session = Depends(get_db),
    _user: User = Depends(get_current_user),
) -> dict:
    """Pause a running sweep.

    With Redis, only the pause flag is set (the status change is left to the
    sweep runner). Without Redis, the status is set to paused directly.

    Raises:
        HTTPException: 409 if the experiment is not running.
    """
    experiment = _get_experiment_or_404(db, experiment_id)
    if experiment.status != ExperimentStatus.running:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Experiment is not currently running",
        )

    redis_client = get_redis()
    if redis_client is not None:
        request_pause(redis_client, experiment.id)
    else:
        # In single-container mode, directly set paused status
        experiment.status = ExperimentStatus.paused
        db.commit()

    return {"experiment_id": str(experiment.id), "action": "pause"}


@router.post("/{experiment_id}/resume", status_code=status.HTTP_200_OK)
def resume_sweep(
    experiment_id: uuid.UUID,
    db: Session = Depends(get_db),
    _user: User = Depends(get_current_user),
) -> dict:
    """Resume a paused sweep.

    Clears the Redis pause flag when Redis is available, marks the experiment
    running, and re-dispatches the sweep.

    Raises:
        HTTPException: 409 if the experiment is not paused.
    """
    experiment = _get_experiment_or_404(db, experiment_id)
    if experiment.status != ExperimentStatus.paused:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Experiment is not currently paused",
        )

    redis_client = get_redis()
    if redis_client is not None:
        request_resume(redis_client, experiment.id)

    # Re-dispatch the sweep to continue.
    # NOTE(review): no sweep_config is passed, so request-level overrides
    # from start_sweep are presumably not reapplied — confirm in engine.tasks.
    experiment.status = ExperimentStatus.running
    db.commit()
    dispatch_sweep(str(experiment.id))

    return {"experiment_id": str(experiment.id), "action": "resume"}


@router.post("/{experiment_id}/stop", status_code=status.HTTP_200_OK)
def stop_sweep(
    experiment_id: uuid.UUID,
    db: Session = Depends(get_db),
    _user: User = Depends(get_current_user),
) -> dict:
    """Stop a running or paused sweep.

    With Redis, only the stop flag is set. Without Redis, pending runs are
    marked failed and the experiment is marked completed immediately.

    Raises:
        HTTPException: 409 if the experiment is neither running nor paused.
    """
    experiment = _get_experiment_or_404(db, experiment_id)
    if experiment.status not in (ExperimentStatus.running, ExperimentStatus.paused):
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Experiment is not running or paused",
        )

    redis_client = get_redis()
    if redis_client is not None:
        request_stop(redis_client, experiment.id)
    else:
        # Mark remaining pending runs as failed
        pending_runs = (
            db.query(Run)
            .filter(Run.experiment_id == experiment.id, Run.status == RunStatus.pending)
            .all()
        )
        for run in pending_runs:
            run.status = RunStatus.failed
        experiment.status = ExperimentStatus.completed
        db.commit()

    return {"experiment_id": str(experiment.id), "action": "stop"}


# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------


def _get_run_counts(db: Session, experiment_id: uuid.UUID) -> dict:
    """Return run counts by status for a sweep status response.

    Counts are aggregated in the database with a single GROUP BY query
    rather than loading every Run row into memory.

    Returns:
        dict with ``total_runs`` (runs in any status), ``completed_runs``,
        ``failed_runs`` and ``pending_runs``.
    """
    rows = (
        db.query(Run.status, func.count())
        .filter(Run.experiment_id == experiment_id)
        .group_by(Run.status)
        .all()
    )

    def _count(target: RunStatus) -> int:
        # Compare with == (as the per-row version did) instead of a dict
        # lookup, so enum vs. raw-string status values still match.
        return sum(n for run_status, n in rows if run_status == target)

    return {
        "total_runs": sum(n for _, n in rows),
        "completed_runs": _count(RunStatus.completed),
        "failed_runs": _count(RunStatus.failed),
        "pending_runs": _count(RunStatus.pending),
    }