MAESTRO: Implement experiments router with full CRUD and sweep control endpoints

Add complete experiments API: list (with project filter), get, create, update,
delete, plus sweep lifecycle (start/pause/resume/stop/status). Adds
SweepRequest and SweepStatusResponse schemas. Sweep dispatch routes through
Celery with synchronous fallback for single-container mode. Redis flags control
pause/resume/stop; direct DB updates used when Redis unavailable. 34 tests.
This commit is contained in:
John Lightner 2026-04-07 03:19:43 -05:00
parent 59f18a11c3
commit 82e97e9dba
5 changed files with 881 additions and 49 deletions

View file

@ -35,7 +35,8 @@ Implement the core experiment execution engine: LLM adapters, response caching,
- [x] Implement backend/routers/endpoints.py fully — CRUD for LLM endpoint configurations. The test endpoint should call adapter.test_connection() and adapter.list_models() and return the results. Store endpoint configs in the database with encrypted API keys (Fernet symmetric encryption, key derived from JWT_SECRET). - [x] Implement backend/routers/endpoints.py fully — CRUD for LLM endpoint configurations. The test endpoint should call adapter.test_connection() and adapter.list_models() and return the results. Store endpoint configs in the database with encrypted API keys (Fernet symmetric encryption, key derived from JWT_SECRET).
<!-- Completed: Full CRUD (list/get/create/update/delete) + test_connection endpoint. LLMEndpoint model added to models.py. Fernet encryption via encryption.py (PBKDF2 key derivation from JWT_SECRET). API keys never exposed in responses; has_api_key boolean flag added to EndpointResponse. 25 tests in test_endpoints.py, all passing. --> <!-- Completed: Full CRUD (list/get/create/update/delete) + test_connection endpoint. LLMEndpoint model added to models.py. Fernet encryption via encryption.py (PBKDF2 key derivation from JWT_SECRET). API keys never exposed in responses; has_api_key boolean flag added to EndpointResponse. 25 tests in test_endpoints.py, all passing. -->
- [ ] Implement backend/routers/experiments.py fully — CRUD plus sweep control. POST /experiments/{id}/sweep should validate the sweep config, create Run records for all configurations, and dispatch to Celery. Pause/resume/stop should set Redis flags that the sweep runner checks between runs. - [x] Implement backend/routers/experiments.py fully — CRUD plus sweep control. POST /experiments/{id}/sweep should validate the sweep config, create Run records for all configurations, and dispatch to Celery. Pause/resume/stop should set Redis flags that the sweep runner checks between runs.
<!-- Completed: Full CRUD (list with project filter, get, create, update, delete) + sweep control (start/pause/resume/stop + status). SweepRequest/SweepStatusResponse schemas added. Sweep dispatch via Celery/sync fallback. Redis flags for pause/resume/stop, with single-container mode fallback. 34 tests in test_experiments.py, all passing. -->
- [ ] Implement backend/routers/runs.py fully — list runs with filtering (by experiment, status, score range), get run detail with stage results and scores, POST for ad-hoc single runs, and POST /{id}/score for human ratings. Include the leaderboard endpoint that returns top N runs ranked by weighted score. - [ ] Implement backend/routers/runs.py fully — list runs with filtering (by experiment, status, score range), get run detail with stage results and scores, POST for ad-hoc single runs, and POST /{id}/score for human ratings. Include the leaderboard endpoint that returns top N runs ranked by weighted score.

View file

@ -1,61 +1,315 @@
"""Experiments router — CRUD and sweep controls.""" """Experiments router — CRUD plus sweep control.
POST /experiments/{id}/sweep validates sweep config, creates Run records,
and dispatches to Celery. Pause/resume/stop set Redis flags that the sweep
runner checks between runs.
"""
import uuid import uuid
from fastapi import APIRouter, Response from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy.orm import Session
from auth import get_current_user
from engine.sweep import clear_sweep_flags, request_pause, request_resume, request_stop
from engine.tasks import dispatch_sweep
from main import get_db, get_redis
from models import Experiment, ExperimentStatus, Project, Run, RunStatus, User
from schemas import (
ExperimentCreate,
ExperimentListResponse,
ExperimentResponse,
ExperimentUpdate,
SweepRequest,
SweepStatusResponse,
)
router = APIRouter() router = APIRouter()
@router.get("/", status_code=501) # ---------------------------------------------------------------------------
def list_experiments(): # Helpers
"""List experiments (filter by project).""" # ---------------------------------------------------------------------------
return Response(status_code=501, content="Not Implemented")
@router.post("/", status_code=501) def _get_experiment_or_404(db: Session, experiment_id: uuid.UUID) -> Experiment:
def create_experiment(): experiment = db.query(Experiment).filter(Experiment.id == experiment_id).first()
"""Create experiment.""" if experiment is None:
return Response(status_code=501, content="Not Implemented") raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Experiment not found")
return experiment
@router.get("/{experiment_id}", status_code=501) def _to_response(experiment: Experiment) -> ExperimentResponse:
def get_experiment(experiment_id: uuid.UUID): return ExperimentResponse.model_validate(experiment)
"""Experiment detail with run summaries."""
return Response(status_code=501, content="Not Implemented")
@router.put("/{experiment_id}", status_code=501) # ---------------------------------------------------------------------------
def update_experiment(experiment_id: uuid.UUID): # CRUD
"""Update experiment config.""" # ---------------------------------------------------------------------------
return Response(status_code=501, content="Not Implemented")
@router.delete("/{experiment_id}", status_code=501) @router.get("/", response_model=ExperimentListResponse)
def delete_experiment(experiment_id: uuid.UUID): def list_experiments(
"""Delete experiment.""" project_id: uuid.UUID | None = Query(None),
return Response(status_code=501, content="Not Implemented") db: Session = Depends(get_db),
_user: User = Depends(get_current_user),
) -> ExperimentListResponse:
"""List experiments, optionally filtered by project."""
query = db.query(Experiment)
if project_id is not None:
query = query.filter(Experiment.project_id == project_id)
experiments = query.order_by(Experiment.created_at.desc()).all()
return ExperimentListResponse(
items=[_to_response(exp) for exp in experiments],
total=len(experiments),
)
@router.post("/{experiment_id}/sweep", status_code=501) @router.post("/", response_model=ExperimentResponse, status_code=status.HTTP_201_CREATED)
def start_sweep(experiment_id: uuid.UUID): def create_experiment(
"""Start a sweep (grid, random, or guided).""" body: ExperimentCreate,
return Response(status_code=501, content="Not Implemented") project_id: uuid.UUID = Query(...),
db: Session = Depends(get_db),
_user: User = Depends(get_current_user),
) -> ExperimentResponse:
"""Create a new experiment within a project."""
# Verify project exists
project = db.query(Project).filter(Project.id == project_id).first()
if project is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Project not found")
experiment = Experiment(
project_id=project_id,
name=body.name,
description=body.description,
sample_data=body.sample_data,
pipeline_stages=body.pipeline_stages,
scoring_config=body.scoring_config,
parameter_space=body.parameter_space,
)
db.add(experiment)
db.commit()
db.refresh(experiment)
return _to_response(experiment)
@router.post("/{experiment_id}/pause", status_code=501) @router.get("/{experiment_id}", response_model=ExperimentResponse)
def pause_sweep(experiment_id: uuid.UUID): def get_experiment(
"""Pause running sweep.""" experiment_id: uuid.UUID,
return Response(status_code=501, content="Not Implemented") db: Session = Depends(get_db),
_user: User = Depends(get_current_user),
) -> ExperimentResponse:
"""Get a single experiment."""
experiment = _get_experiment_or_404(db, experiment_id)
return _to_response(experiment)
@router.post("/{experiment_id}/resume", status_code=501) @router.put("/{experiment_id}", response_model=ExperimentResponse)
def resume_sweep(experiment_id: uuid.UUID): def update_experiment(
"""Resume paused sweep.""" experiment_id: uuid.UUID,
return Response(status_code=501, content="Not Implemented") body: ExperimentUpdate,
db: Session = Depends(get_db),
_user: User = Depends(get_current_user),
) -> ExperimentResponse:
"""Update an experiment's configuration."""
experiment = _get_experiment_or_404(db, experiment_id)
if body.name is not None:
experiment.name = body.name
if body.description is not None:
experiment.description = body.description
if body.sample_data is not None:
experiment.sample_data = body.sample_data
if body.pipeline_stages is not None:
experiment.pipeline_stages = body.pipeline_stages
if body.scoring_config is not None:
experiment.scoring_config = body.scoring_config
if body.parameter_space is not None:
experiment.parameter_space = body.parameter_space
if body.status is not None:
experiment.status = body.status
db.commit()
db.refresh(experiment)
return _to_response(experiment)
@router.post("/{experiment_id}/stop", status_code=501) @router.delete("/{experiment_id}", status_code=status.HTTP_204_NO_CONTENT)
def stop_sweep(experiment_id: uuid.UUID): def delete_experiment(
"""Stop sweep.""" experiment_id: uuid.UUID,
return Response(status_code=501, content="Not Implemented") db: Session = Depends(get_db),
_user: User = Depends(get_current_user),
) -> None:
"""Delete an experiment and all its runs."""
experiment = _get_experiment_or_404(db, experiment_id)
db.delete(experiment)
db.commit()
# ---------------------------------------------------------------------------
# Sweep control
# ---------------------------------------------------------------------------
@router.post("/{experiment_id}/sweep", response_model=SweepStatusResponse)
def start_sweep(
    experiment_id: uuid.UUID,
    body: SweepRequest,
    db: Session = Depends(get_db),
    _user: User = Depends(get_current_user),
) -> SweepStatusResponse:
    """Start a sweep for an experiment and return its current run counts.

    Builds the sweep config dict from the request body and hands it to
    ``dispatch_sweep``. Parameters supplied in the request take precedence;
    when the request carries none, the ``params`` entry of the experiment's
    stored ``parameter_space`` is used, and an empty dict if neither exists.

    Args:
        experiment_id: Experiment to sweep.
        body: Sweep type, optional parameter grid and sampling settings.
        db: Request-scoped database session.
        _user: Authenticated caller (only required for access control).

    Returns:
        The experiment's status and run counts after dispatch.

    Raises:
        HTTPException: 404 if the experiment does not exist, 409 if the
            experiment is already running.
    """
    experiment = _get_experiment_or_404(db, experiment_id)
    if experiment.status == ExperimentStatus.running:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Experiment already has a running sweep",
        )

    # Spelled out as branches on purpose: a conditional expression binds
    # looser than `or`, so `a or b if c else d` means `(a or b) if c else d`
    # and would discard the request's params whenever the experiment has no
    # stored parameter_space.
    if body.params:
        params = body.params
    elif experiment.parameter_space:
        params = experiment.parameter_space.get("params", {})
    else:
        params = {}

    # Build the sweep config dict that the engine expects
    sweep_config = {
        "type": body.sweep_type,
        "params": params,
        "n_trials": body.n_trials,
        "top_k": body.top_k,
        "explore_ratio": body.explore_ratio,
    }

    # Dispatch via Celery / sync fallback
    dispatch_sweep(str(experiment.id), sweep_config=sweep_config)

    # Refresh to get the updated status (dispatch_sweep may have set it)
    db.refresh(experiment)

    run_counts = _get_run_counts(db, experiment.id)
    return SweepStatusResponse(
        experiment_id=experiment.id,
        status=experiment.status,
        **run_counts,
    )
@router.get("/{experiment_id}/sweep/status", response_model=SweepStatusResponse)
def sweep_status(
    experiment_id: uuid.UUID,
    db: Session = Depends(get_db),
    _user: User = Depends(get_current_user),
) -> SweepStatusResponse:
    """Report an experiment's status together with its per-status run counts.

    Responds with 404 (raised by the lookup helper) when the experiment
    does not exist.
    """
    target = _get_experiment_or_404(db, experiment_id)
    counts = _get_run_counts(db, target.id)
    response_fields = {
        "experiment_id": target.id,
        "status": target.status,
    }
    return SweepStatusResponse(**response_fields, **counts)
@router.post("/{experiment_id}/pause", status_code=status.HTTP_200_OK)
def pause_sweep(
    experiment_id: uuid.UUID,
    db: Session = Depends(get_db),
    _user: User = Depends(get_current_user),
) -> dict:
    """Ask a running sweep to pause.

    With a Redis client available the pause is requested through
    ``request_pause``; without one the experiment row is marked paused
    directly. Responds 409 unless the experiment is currently running.
    """
    target = _get_experiment_or_404(db, experiment_id)
    if target.status != ExperimentStatus.running:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Experiment is not currently running",
        )

    flag_store = get_redis()
    if flag_store is None:
        # Single-container mode: no flag store, so flip the status ourselves.
        target.status = ExperimentStatus.paused
        db.commit()
    else:
        request_pause(flag_store, target.id)

    return {"experiment_id": str(target.id), "action": "pause"}
@router.post("/{experiment_id}/resume", status_code=status.HTTP_200_OK)
def resume_sweep(
    experiment_id: uuid.UUID,
    db: Session = Depends(get_db),
    _user: User = Depends(get_current_user),
) -> dict:
    """Resume a paused sweep and dispatch it again.

    Calls ``request_resume`` when a Redis client is available, then (in
    every mode) marks the experiment running, commits, and re-dispatches
    the sweep. Responds 409 unless the experiment is currently paused.
    """
    target = _get_experiment_or_404(db, experiment_id)
    if target.status != ExperimentStatus.paused:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Experiment is not currently paused",
        )

    flag_store = get_redis()
    if flag_store is not None:
        request_resume(flag_store, target.id)

    # The status flip and re-dispatch happen with or without Redis.
    target.status = ExperimentStatus.running
    db.commit()
    dispatch_sweep(str(target.id))

    return {"experiment_id": str(target.id), "action": "resume"}
@router.post("/{experiment_id}/stop", status_code=status.HTTP_200_OK)
def stop_sweep(
    experiment_id: uuid.UUID,
    db: Session = Depends(get_db),
    _user: User = Depends(get_current_user),
) -> dict:
    """Stop a sweep that is running or paused.

    With a Redis client available the stop is requested through
    ``request_stop``. Without one, every still-pending run of the
    experiment is marked failed and the experiment is marked completed.
    Responds 409 when the experiment is neither running nor paused.
    """
    target = _get_experiment_or_404(db, experiment_id)
    stoppable_states = (ExperimentStatus.running, ExperimentStatus.paused)
    if target.status not in stoppable_states:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Experiment is not running or paused",
        )

    flag_store = get_redis()
    if flag_store is None:
        # No flag store to signal through: close out the pending runs here.
        unstarted_query = db.query(Run).filter(
            Run.experiment_id == target.id,
            Run.status == RunStatus.pending,
        )
        for unstarted in unstarted_query.all():
            unstarted.status = RunStatus.failed
        target.status = ExperimentStatus.completed
        db.commit()
    else:
        request_stop(flag_store, target.id)

    return {"experiment_id": str(target.id), "action": "stop"}
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
def _get_run_counts(db: Session, experiment_id: uuid.UUID) -> dict:
    """Tally an experiment's runs by status.

    Returns a dict with ``total_runs``, ``completed_runs``, ``failed_runs``
    and ``pending_runs``; runs in any other status count only toward the
    total.
    """
    rows = db.query(Run).filter(Run.experiment_id == experiment_id).all()

    n_completed = 0
    n_failed = 0
    n_pending = 0
    for row in rows:
        state = row.status
        if state == RunStatus.completed:
            n_completed += 1
        elif state == RunStatus.failed:
            n_failed += 1
        elif state == RunStatus.pending:
            n_pending += 1

    return {
        "total_runs": len(rows),
        "completed_runs": n_completed,
        "failed_runs": n_failed,
        "pending_runs": n_pending,
    }

View file

@ -91,6 +91,29 @@ class ExperimentListResponse(BaseModel):
total: int total: int
# ---------------------------------------------------------------------------
# Sweep
# ---------------------------------------------------------------------------
class SweepRequest(BaseModel):
    """Request body for starting a sweep on an experiment."""
    # Sweep strategy; the pattern restricts it to grid, random or guided.
    sweep_type: str = Field("grid", pattern="^(grid|random|guided)$")
    # Optional parameter mapping for this sweep; None when the caller sends none.
    params: dict | None = None
    # Bounded to 1..100000; defaults to 100.
    n_trials: int = Field(100, ge=1, le=100000)
    # At least 1; defaults to 5.
    top_k: int = Field(5, ge=1)
    # Fraction in [0.0, 1.0]; defaults to 0.3.
    explore_ratio: float = Field(0.3, ge=0.0, le=1.0)
class SweepStatusResponse(BaseModel):
    # Response body for the sweep endpoints: the experiment's status plus
    # run counts broken down by run status.
    experiment_id: uuid.UUID
    status: ExperimentStatus
    # Number of Run rows belonging to the experiment, in any status.
    total_runs: int
    completed_runs: int
    failed_runs: int
    pending_runs: int
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Run # Run
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------

View file

@ -0,0 +1,554 @@
"""Tests for backend/routers/experiments.py — Experiment CRUD + sweep control."""
import os
import uuid
from unittest.mock import MagicMock, patch
import pytest
from fastapi.testclient import TestClient
JWT_SECRET = "test-secret-key-for-jwt-signing"
API_KEY = "test-api-key-12345"
@pytest.fixture(autouse=True)
def _isolate_settings(tmp_path):
    """Point every test at a throwaway SQLite database with Redis disabled.

    Patches the environment, rebuilds ``config.Settings`` from it, and
    pushes the fresh settings object into the ``config``, ``main`` and
    ``auth`` modules before re-initialising the database and Redis handles.
    """
    # NOTE(review): the flattened source does not show how much of this body
    # sits inside the `with`; it is reconstructed with everything (including
    # the yield) inside so the patched environment lasts for the whole test.
    # Confirm against the original file.
    env = {
        "DATABASE_URL": f"sqlite:///{tmp_path / 'test.db'}",
        "REDIS_URL": "",
        "DATA_DIR": str(tmp_path),
        "JWT_SECRET": JWT_SECRET,
        "API_KEY": API_KEY,
    }
    with patch.dict(os.environ, env, clear=False):
        import config
        # _env_file=None keeps a local .env file from leaking into the tests.
        new_settings = config.Settings(_env_file=None)
        config.settings = new_settings
        import main
        # main holds its own reference to settings, so it must be swapped
        # before the DB/Redis handles are rebuilt from it.
        main.settings = new_settings
        main._init_db()
        main._init_redis()
        from models import Base
        Base.metadata.create_all(bind=main.engine)
        import auth
        auth.settings = new_settings
        yield
@pytest.fixture
def db_session():
    """Yield a session from the app's ``get_db`` generator, then finish it."""
    from main import get_db

    session_source = get_db()
    yield next(session_source)
    # Advance the generator once more so its own teardown runs; the default
    # argument absorbs the StopIteration raised when it finishes.
    next(session_source, None)
@pytest.fixture
def admin_user(db_session):
    """Persist and return an admin user named ``admin``."""
    from auth import hash_password
    from models import User

    account_fields = {
        "username": "admin",
        "password_hash": hash_password("adminpass"),
        "is_admin": True,
    }
    account = User(**account_fields)
    db_session.add(account)
    db_session.commit()
    db_session.refresh(account)
    return account
@pytest.fixture
def project(db_session, admin_user):
    """Persist and return a project owned by the admin user."""
    from models import Project

    record = Project(
        name="Test Project",
        description="A test project",
        owner_id=admin_user.id,
    )
    db_session.add(record)
    db_session.commit()
    db_session.refresh(record)
    return record
@pytest.fixture
def auth_headers():
    """Request headers carrying the test API key."""
    headers = {"X-Api-Key": API_KEY}
    return headers
@pytest.fixture
def client():
    """A ``TestClient`` bound to the application object from ``main``."""
    import main

    return TestClient(main.app)
def _create_experiment(client, project_id, auth_headers, **overrides):
"""Helper to create an experiment via the API."""
body = {
"name": overrides.get("name", "Test Experiment"),
"description": overrides.get("description", "A test experiment"),
}
for key in ("sample_data", "pipeline_stages", "scoring_config", "parameter_space"):
if key in overrides:
body[key] = overrides[key]
resp = client.post(
f"/api/experiments/?project_id={project_id}",
json=body,
headers=auth_headers,
)
return resp
# ---------------------------------------------------------------------------
# CRUD tests
# ---------------------------------------------------------------------------
class TestCreateExperiment:
    """POST /api/experiments/ — creation, authentication and validation."""

    def test_create_minimal(self, client, project, auth_headers):
        """Name/description alone yields a draft experiment in the project."""
        response = _create_experiment(client, project.id, auth_headers)
        assert response.status_code == 201
        created = response.json()
        assert created["name"] == "Test Experiment"
        assert created["project_id"] == str(project.id)
        assert created["status"] == "draft"
        assert "id" in created

    def test_create_with_all_fields(self, client, project, auth_headers):
        """Optional config fields are stored and echoed back."""
        fields = {
            "name": "Full Experiment",
            "description": "Detailed description",
            "sample_data": {"key": "value"},
            "pipeline_stages": {"stages": [{"prompt_template": "Hello {{input}}"}]},
            "scoring_config": {"scorers": ["keyword"]},
            "parameter_space": {"type": "grid", "params": {"model": ["a", "b"]}},
        }
        response = _create_experiment(client, project.id, auth_headers, **fields)
        assert response.status_code == 201
        created = response.json()
        assert created["name"] == "Full Experiment"
        assert created["sample_data"] == {"key": "value"}
        assert created["parameter_space"]["type"] == "grid"

    def test_create_requires_auth(self, client, project):
        """No API key header means 401."""
        url = f"/api/experiments/?project_id={project.id}"
        response = client.post(url, json={"name": "Test"})
        assert response.status_code == 401

    def test_create_requires_project_id(self, client, admin_user, auth_headers):
        """The project_id query parameter is mandatory."""
        response = client.post(
            "/api/experiments/",
            json={"name": "Test"},
            headers=auth_headers,
        )
        assert response.status_code == 422

    def test_create_validates_project_exists(self, client, admin_user, auth_headers):
        """An unknown project id is rejected with 404."""
        missing_project = str(uuid.uuid4())
        url = f"/api/experiments/?project_id={missing_project}"
        response = client.post(url, json={"name": "Test"}, headers=auth_headers)
        assert response.status_code == 404

    def test_create_validates_name(self, client, project, auth_headers):
        """An empty name fails request validation."""
        url = f"/api/experiments/?project_id={project.id}"
        response = client.post(url, json={"name": ""}, headers=auth_headers)
        assert response.status_code == 422
class TestListExperiments:
    """GET /api/experiments/ — listing and the project filter."""

    def test_list_empty(self, client, admin_user, auth_headers):
        """With nothing created the list is empty and total is zero."""
        response = client.get("/api/experiments/", headers=auth_headers)
        assert response.status_code == 200
        listing = response.json()
        assert listing["items"] == []
        assert listing["total"] == 0

    def test_list_multiple(self, client, project, auth_headers):
        """Every created experiment is counted."""
        for label in ("Alpha", "Beta", "Gamma"):
            _create_experiment(client, project.id, auth_headers, name=label)
        response = client.get("/api/experiments/", headers=auth_headers)
        assert response.status_code == 200
        assert response.json()["total"] == 3

    def test_list_filter_by_project(self, client, project, admin_user, auth_headers, db_session):
        """project_id narrows the listing to that project's experiments."""
        from models import Project

        _create_experiment(client, project.id, auth_headers, name="In Project")

        # A second project with its own experiment, which must be filtered out.
        other_project = Project(name="Other", owner_id=admin_user.id)
        db_session.add(other_project)
        db_session.commit()
        db_session.refresh(other_project)
        _create_experiment(client, other_project.id, auth_headers, name="Other Experiment")

        url = f"/api/experiments/?project_id={project.id}"
        response = client.get(url, headers=auth_headers)
        assert response.status_code == 200
        listing = response.json()
        assert listing["total"] == 1
        assert listing["items"][0]["name"] == "In Project"

    def test_list_requires_auth(self, client):
        """No API key header means 401."""
        response = client.get("/api/experiments/")
        assert response.status_code == 401
class TestGetExperiment:
    """GET /api/experiments/{id}."""

    def test_get_existing(self, client, project, auth_headers):
        """A created experiment can be fetched by id."""
        experiment_id = _create_experiment(client, project.id, auth_headers).json()["id"]
        response = client.get(f"/api/experiments/{experiment_id}", headers=auth_headers)
        assert response.status_code == 200
        assert response.json()["name"] == "Test Experiment"

    def test_get_not_found(self, client, admin_user, auth_headers):
        """An unknown id yields 404."""
        missing_id = str(uuid.uuid4())
        response = client.get(f"/api/experiments/{missing_id}", headers=auth_headers)
        assert response.status_code == 404
class TestUpdateExperiment:
    """PUT /api/experiments/{id} — partial updates."""

    def test_update_name(self, client, project, auth_headers):
        """Sending only a name changes the name."""
        experiment_id = _create_experiment(client, project.id, auth_headers).json()["id"]
        url = f"/api/experiments/{experiment_id}"
        response = client.put(url, json={"name": "Updated Name"}, headers=auth_headers)
        assert response.status_code == 200
        assert response.json()["name"] == "Updated Name"

    def test_update_status(self, client, project, auth_headers):
        """The status field can be set through the update endpoint."""
        experiment_id = _create_experiment(client, project.id, auth_headers).json()["id"]
        url = f"/api/experiments/{experiment_id}"
        response = client.put(url, json={"status": "completed"}, headers=auth_headers)
        assert response.status_code == 200
        assert response.json()["status"] == "completed"

    def test_update_parameter_space(self, client, project, auth_headers):
        """A new parameter_space is stored and returned unchanged."""
        experiment_id = _create_experiment(client, project.id, auth_headers).json()["id"]
        replacement = {"type": "random", "params": {"temperature": [0.1, 0.5, 0.9]}}
        response = client.put(
            f"/api/experiments/{experiment_id}",
            json={"parameter_space": replacement},
            headers=auth_headers,
        )
        assert response.status_code == 200
        assert response.json()["parameter_space"] == replacement

    def test_update_not_found(self, client, admin_user, auth_headers):
        """An unknown id yields 404."""
        missing_id = str(uuid.uuid4())
        url = f"/api/experiments/{missing_id}"
        response = client.put(url, json={"name": "X"}, headers=auth_headers)
        assert response.status_code == 404
class TestDeleteExperiment:
    """DELETE /api/experiments/{id}."""

    def test_delete_existing(self, client, project, auth_headers):
        """Deleting returns 204 and the experiment is gone afterwards."""
        experiment_id = _create_experiment(client, project.id, auth_headers).json()["id"]
        url = f"/api/experiments/{experiment_id}"

        deletion = client.delete(url, headers=auth_headers)
        assert deletion.status_code == 204

        lookup = client.get(url, headers=auth_headers)
        assert lookup.status_code == 404

    def test_delete_not_found(self, client, admin_user, auth_headers):
        """An unknown id yields 404."""
        missing_id = str(uuid.uuid4())
        response = client.delete(f"/api/experiments/{missing_id}", headers=auth_headers)
        assert response.status_code == 404
# ---------------------------------------------------------------------------
# Sweep control tests
# ---------------------------------------------------------------------------
class TestStartSweep:
    """POST /api/experiments/{id}/sweep — dispatch, conflicts and validation."""

    def test_start_sweep_dispatches(self, client, project, auth_headers):
        """A valid request calls dispatch_sweep once with the experiment id."""
        create_resp = _create_experiment(
            client, project.id, auth_headers,
            parameter_space={"type": "grid", "params": {"model": ["a", "b"]}},
        )
        exp_id = create_resp.json()["id"]
        # dispatch_sweep is patched so no sweep actually runs.
        with patch("routers.experiments.dispatch_sweep") as mock_dispatch:
            mock_dispatch.return_value = MagicMock()
            resp = client.post(
                f"/api/experiments/{exp_id}/sweep",
                json={"sweep_type": "grid"},
                headers=auth_headers,
            )
        assert resp.status_code == 200
        data = resp.json()
        assert data["experiment_id"] == exp_id
        mock_dispatch.assert_called_once()
        call_args = mock_dispatch.call_args
        assert call_args[0][0] == exp_id
        assert call_args[1]["sweep_config"]["type"] == "grid"

    def test_start_sweep_with_custom_config(self, client, project, auth_headers):
        """Sweep type and n_trials from the body reach the dispatched config."""
        create_resp = _create_experiment(client, project.id, auth_headers)
        exp_id = create_resp.json()["id"]
        with patch("routers.experiments.dispatch_sweep") as mock_dispatch:
            mock_dispatch.return_value = MagicMock()
            resp = client.post(
                f"/api/experiments/{exp_id}/sweep",
                json={
                    "sweep_type": "random",
                    "params": {"temperature": [0.1, 0.5, 0.9]},
                    "n_trials": 50,
                },
                headers=auth_headers,
            )
        assert resp.status_code == 200
        call_args = mock_dispatch.call_args
        assert call_args[1]["sweep_config"]["type"] == "random"
        assert call_args[1]["sweep_config"]["n_trials"] == 50

    def test_start_sweep_conflict_if_running(self, client, project, auth_headers, db_session):
        """Starting a sweep on a running experiment yields 409."""
        create_resp = _create_experiment(client, project.id, auth_headers)
        exp_id = create_resp.json()["id"]
        # Set experiment to running
        from models import Experiment, ExperimentStatus
        exp = db_session.query(Experiment).filter(Experiment.id == uuid.UUID(exp_id)).first()
        exp.status = ExperimentStatus.running
        db_session.commit()
        resp = client.post(
            f"/api/experiments/{exp_id}/sweep",
            json={"sweep_type": "grid"},
            headers=auth_headers,
        )
        assert resp.status_code == 409

    def test_start_sweep_validates_type(self, client, project, auth_headers):
        """A sweep_type outside grid/random/guided fails validation."""
        create_resp = _create_experiment(client, project.id, auth_headers)
        exp_id = create_resp.json()["id"]
        resp = client.post(
            f"/api/experiments/{exp_id}/sweep",
            json={"sweep_type": "invalid"},
            headers=auth_headers,
        )
        assert resp.status_code == 422

    def test_start_sweep_not_found(self, client, admin_user, auth_headers):
        """An unknown experiment id yields 404."""
        fake_id = str(uuid.uuid4())
        resp = client.post(
            f"/api/experiments/{fake_id}/sweep",
            json={"sweep_type": "grid"},
            headers=auth_headers,
        )
        assert resp.status_code == 404

    def test_start_sweep_requires_auth(self, client, project):
        """No API key header means 401, even for a random experiment id."""
        resp = client.post(
            f"/api/experiments/{uuid.uuid4()}/sweep",
            json={"sweep_type": "grid"},
        )
        assert resp.status_code == 401
class TestSweepStatus:
    """GET /api/experiments/{id}/sweep/status."""

    def test_sweep_status(self, client, project, auth_headers, db_session):
        """Run counts are broken down by status; a fresh experiment is draft."""
        from models import Run, RunStatus

        experiment_id = _create_experiment(client, project.id, auth_headers).json()["id"]

        # Two completed, one failed, one pending.
        seeded_states = (
            RunStatus.completed,
            RunStatus.completed,
            RunStatus.failed,
            RunStatus.pending,
        )
        db_session.add_all(
            [
                Run(
                    experiment_id=uuid.UUID(experiment_id),
                    config_hash="abc123",
                    config={"test": True},
                    status=state,
                )
                for state in seeded_states
            ]
        )
        db_session.commit()

        response = client.get(
            f"/api/experiments/{experiment_id}/sweep/status",
            headers=auth_headers,
        )
        assert response.status_code == 200
        summary = response.json()
        assert summary["total_runs"] == 4
        assert summary["completed_runs"] == 2
        assert summary["failed_runs"] == 1
        assert summary["pending_runs"] == 1
        assert summary["status"] == "draft"
class TestPauseSweep:
    """POST /api/experiments/{id}/pause."""

    def test_pause_running_no_redis(self, client, project, auth_headers, db_session):
        """Without Redis the endpoint marks the experiment paused itself."""
        from models import Experiment, ExperimentStatus

        experiment_id = _create_experiment(client, project.id, auth_headers).json()["id"]
        row = (
            db_session.query(Experiment)
            .filter(Experiment.id == uuid.UUID(experiment_id))
            .first()
        )
        row.status = ExperimentStatus.running
        db_session.commit()

        response = client.post(f"/api/experiments/{experiment_id}/pause", headers=auth_headers)
        assert response.status_code == 200
        assert response.json()["action"] == "pause"

        # Re-read the row to see the endpoint's change.
        db_session.refresh(row)
        assert row.status == ExperimentStatus.paused

    def test_pause_not_running(self, client, project, auth_headers):
        """Pausing an experiment that is not running yields 409."""
        experiment_id = _create_experiment(client, project.id, auth_headers).json()["id"]
        response = client.post(f"/api/experiments/{experiment_id}/pause", headers=auth_headers)
        assert response.status_code == 409

    def test_pause_with_redis(self, client, project, auth_headers, db_session):
        """With a Redis client the pause goes through request_pause."""
        from models import Experiment, ExperimentStatus

        experiment_id = _create_experiment(client, project.id, auth_headers).json()["id"]
        row = (
            db_session.query(Experiment)
            .filter(Experiment.id == uuid.UUID(experiment_id))
            .first()
        )
        row.status = ExperimentStatus.running
        db_session.commit()

        fake_redis = MagicMock()
        with patch("routers.experiments.get_redis", return_value=fake_redis), \
                patch("routers.experiments.request_pause") as pause_spy:
            response = client.post(
                f"/api/experiments/{experiment_id}/pause", headers=auth_headers
            )
        assert response.status_code == 200
        pause_spy.assert_called_once_with(fake_redis, uuid.UUID(experiment_id))
class TestResumeSweep:
    """POST /api/experiments/{id}/resume."""

    def test_resume_paused(self, client, project, auth_headers, db_session):
        """Resuming a paused experiment re-dispatches the sweep once."""
        from models import Experiment, ExperimentStatus

        experiment_id = _create_experiment(client, project.id, auth_headers).json()["id"]
        row = (
            db_session.query(Experiment)
            .filter(Experiment.id == uuid.UUID(experiment_id))
            .first()
        )
        row.status = ExperimentStatus.paused
        db_session.commit()

        with patch("routers.experiments.dispatch_sweep") as dispatch_spy:
            dispatch_spy.return_value = MagicMock()
            response = client.post(
                f"/api/experiments/{experiment_id}/resume", headers=auth_headers
            )
        assert response.status_code == 200
        assert response.json()["action"] == "resume"
        dispatch_spy.assert_called_once()

    def test_resume_not_paused(self, client, project, auth_headers):
        """Resuming an experiment that is not paused yields 409."""
        experiment_id = _create_experiment(client, project.id, auth_headers).json()["id"]
        response = client.post(f"/api/experiments/{experiment_id}/resume", headers=auth_headers)
        assert response.status_code == 409
class TestStopSweep:
    """POST /api/experiments/{id}/stop."""

    def test_stop_running_no_redis(self, client, project, auth_headers, db_session):
        """Without Redis, pending runs become failed and the experiment completed."""
        from models import Experiment, ExperimentStatus, Run, RunStatus

        experiment_id = _create_experiment(client, project.id, auth_headers).json()["id"]
        row = (
            db_session.query(Experiment)
            .filter(Experiment.id == uuid.UUID(experiment_id))
            .first()
        )
        row.status = ExperimentStatus.running
        db_session.commit()

        # One run that has not started yet.
        queued_run = Run(
            experiment_id=uuid.UUID(experiment_id),
            config_hash="abc",
            config={"test": True},
            status=RunStatus.pending,
        )
        db_session.add(queued_run)
        db_session.commit()
        db_session.refresh(queued_run)

        response = client.post(f"/api/experiments/{experiment_id}/stop", headers=auth_headers)
        assert response.status_code == 200
        assert response.json()["action"] == "stop"

        # The queued run must have been closed out as failed.
        db_session.refresh(queued_run)
        assert queued_run.status == RunStatus.failed

        # The experiment itself must be marked completed.
        db_session.refresh(row)
        assert row.status == ExperimentStatus.completed

    def test_stop_not_running(self, client, project, auth_headers):
        """Stopping a draft experiment yields 409."""
        experiment_id = _create_experiment(client, project.id, auth_headers).json()["id"]
        response = client.post(f"/api/experiments/{experiment_id}/stop", headers=auth_headers)
        assert response.status_code == 409

    def test_stop_paused(self, client, project, auth_headers, db_session):
        """Can stop a paused experiment too."""
        from models import Experiment, ExperimentStatus

        experiment_id = _create_experiment(client, project.id, auth_headers).json()["id"]
        row = (
            db_session.query(Experiment)
            .filter(Experiment.id == uuid.UUID(experiment_id))
            .first()
        )
        row.status = ExperimentStatus.paused
        db_session.commit()

        response = client.post(f"/api/experiments/{experiment_id}/stop", headers=auth_headers)
        assert response.status_code == 200

    def test_stop_with_redis(self, client, project, auth_headers, db_session):
        """With a Redis client the stop goes through request_stop."""
        from models import Experiment, ExperimentStatus

        experiment_id = _create_experiment(client, project.id, auth_headers).json()["id"]
        row = (
            db_session.query(Experiment)
            .filter(Experiment.id == uuid.UUID(experiment_id))
            .first()
        )
        row.status = ExperimentStatus.running
        db_session.commit()

        fake_redis = MagicMock()
        with patch("routers.experiments.get_redis", return_value=fake_redis), \
                patch("routers.experiments.request_stop") as stop_spy:
            response = client.post(
                f"/api/experiments/{experiment_id}/stop", headers=auth_headers
            )
        assert response.status_code == 200
        stop_spy.assert_called_once_with(fake_redis, uuid.UUID(experiment_id))

View file

@ -67,51 +67,51 @@ def test_projects_delete(client):
assert resp.status_code == 501 assert resp.status_code == 501
# ---- Experiments router (/api/experiments) ---- # ---- Experiments router (/api/experiments) — now implemented, requires auth ----
def test_experiments_list(client): def test_experiments_list(client):
resp = client.get("/api/experiments/") resp = client.get("/api/experiments/")
assert resp.status_code == 501 assert resp.status_code == 401
def test_experiments_create(client): def test_experiments_create(client):
resp = client.post("/api/experiments/") resp = client.post("/api/experiments/")
assert resp.status_code == 501 assert resp.status_code == 401
def test_experiments_get(client): def test_experiments_get(client):
resp = client.get("/api/experiments/00000000-0000-0000-0000-000000000001") resp = client.get("/api/experiments/00000000-0000-0000-0000-000000000001")
assert resp.status_code == 501 assert resp.status_code == 401
def test_experiments_update(client): def test_experiments_update(client):
resp = client.put("/api/experiments/00000000-0000-0000-0000-000000000001") resp = client.put("/api/experiments/00000000-0000-0000-0000-000000000001")
assert resp.status_code == 501 assert resp.status_code == 401
def test_experiments_delete(client): def test_experiments_delete(client):
resp = client.delete("/api/experiments/00000000-0000-0000-0000-000000000001") resp = client.delete("/api/experiments/00000000-0000-0000-0000-000000000001")
assert resp.status_code == 501 assert resp.status_code == 401
def test_experiments_sweep(client): def test_experiments_sweep(client):
resp = client.post("/api/experiments/00000000-0000-0000-0000-000000000001/sweep") resp = client.post("/api/experiments/00000000-0000-0000-0000-000000000001/sweep")
assert resp.status_code == 501 assert resp.status_code == 401
def test_experiments_pause(client): def test_experiments_pause(client):
resp = client.post("/api/experiments/00000000-0000-0000-0000-000000000001/pause") resp = client.post("/api/experiments/00000000-0000-0000-0000-000000000001/pause")
assert resp.status_code == 501 assert resp.status_code == 401
def test_experiments_resume(client): def test_experiments_resume(client):
resp = client.post("/api/experiments/00000000-0000-0000-0000-000000000001/resume") resp = client.post("/api/experiments/00000000-0000-0000-0000-000000000001/resume")
assert resp.status_code == 501 assert resp.status_code == 401
def test_experiments_stop(client): def test_experiments_stop(client):
resp = client.post("/api/experiments/00000000-0000-0000-0000-000000000001/stop") resp = client.post("/api/experiments/00000000-0000-0000-0000-000000000001/stop")
assert resp.status_code == 501 assert resp.status_code == 401
# ---- Runs router (/api/runs) ---- # ---- Runs router (/api/runs) ----