diff --git a/Auto Run Docs/02a-backend-engine.md b/Auto Run Docs/02a-backend-engine.md index 216327c..1d61005 100644 --- a/Auto Run Docs/02a-backend-engine.md +++ b/Auto Run Docs/02a-backend-engine.md @@ -35,7 +35,8 @@ Implement the core experiment execution engine: LLM adapters, response caching, - [x] Implement backend/routers/endpoints.py fully — CRUD for LLM endpoint configurations. The test endpoint should call adapter.test_connection() and adapter.list_models() and return the results. Store endpoint configs in the database with encrypted API keys (Fernet symmetric encryption, key derived from JWT_SECRET). -- [ ] Implement backend/routers/experiments.py fully — CRUD plus sweep control. POST /experiments/{id}/sweep should validate the sweep config, create Run records for all configurations, and dispatch to Celery. Pause/resume/stop should set Redis flags that the sweep runner checks between runs. +- [x] Implement backend/routers/experiments.py fully — CRUD plus sweep control. POST /experiments/{id}/sweep should validate the sweep config, create Run records for all configurations, and dispatch to Celery. Pause/resume/stop should set Redis flags that the sweep runner checks between runs. + - [ ] Implement backend/routers/runs.py fully — list runs with filtering (by experiment, status, score range), get run detail with stage results and scores, POST for ad-hoc single runs, and POST /{id}/score for human ratings. Include the leaderboard endpoint that returns top N runs ranked by weighted score. diff --git a/backend/routers/experiments.py b/backend/routers/experiments.py index 9ecaca1..c73578b 100644 --- a/backend/routers/experiments.py +++ b/backend/routers/experiments.py @@ -1,61 +1,315 @@ -"""Experiments router — CRUD and sweep controls.""" +"""Experiments router — CRUD plus sweep control. + +POST /experiments/{id}/sweep validates sweep config, creates Run records, +and dispatches to Celery. 
Pause/resume/stop set Redis flags that the sweep +runner checks between runs. +""" import uuid -from fastapi import APIRouter, Response +from fastapi import APIRouter, Depends, HTTPException, Query, status +from sqlalchemy.orm import Session + +from auth import get_current_user +from engine.sweep import clear_sweep_flags, request_pause, request_resume, request_stop +from engine.tasks import dispatch_sweep +from main import get_db, get_redis +from models import Experiment, ExperimentStatus, Project, Run, RunStatus, User +from schemas import ( + ExperimentCreate, + ExperimentListResponse, + ExperimentResponse, + ExperimentUpdate, + SweepRequest, + SweepStatusResponse, +) router = APIRouter() -@router.get("/", status_code=501) -def list_experiments(): - """List experiments (filter by project).""" - return Response(status_code=501, content="Not Implemented") +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- -@router.post("/", status_code=501) -def create_experiment(): - """Create experiment.""" - return Response(status_code=501, content="Not Implemented") +def _get_experiment_or_404(db: Session, experiment_id: uuid.UUID) -> Experiment: + experiment = db.query(Experiment).filter(Experiment.id == experiment_id).first() + if experiment is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Experiment not found") + return experiment -@router.get("/{experiment_id}", status_code=501) -def get_experiment(experiment_id: uuid.UUID): - """Experiment detail with run summaries.""" - return Response(status_code=501, content="Not Implemented") +def _to_response(experiment: Experiment) -> ExperimentResponse: + return ExperimentResponse.model_validate(experiment) -@router.put("/{experiment_id}", status_code=501) -def update_experiment(experiment_id: uuid.UUID): - """Update experiment config.""" - return Response(status_code=501, content="Not 
Implemented") +# --------------------------------------------------------------------------- +# CRUD +# --------------------------------------------------------------------------- -@router.delete("/{experiment_id}", status_code=501) -def delete_experiment(experiment_id: uuid.UUID): - """Delete experiment.""" - return Response(status_code=501, content="Not Implemented") +@router.get("/", response_model=ExperimentListResponse) +def list_experiments( + project_id: uuid.UUID | None = Query(None), + db: Session = Depends(get_db), + _user: User = Depends(get_current_user), +) -> ExperimentListResponse: + """List experiments, optionally filtered by project.""" + query = db.query(Experiment) + if project_id is not None: + query = query.filter(Experiment.project_id == project_id) + experiments = query.order_by(Experiment.created_at.desc()).all() + return ExperimentListResponse( + items=[_to_response(exp) for exp in experiments], + total=len(experiments), + ) -@router.post("/{experiment_id}/sweep", status_code=501) -def start_sweep(experiment_id: uuid.UUID): - """Start a sweep (grid, random, or guided).""" - return Response(status_code=501, content="Not Implemented") +@router.post("/", response_model=ExperimentResponse, status_code=status.HTTP_201_CREATED) +def create_experiment( + body: ExperimentCreate, + project_id: uuid.UUID = Query(...), + db: Session = Depends(get_db), + _user: User = Depends(get_current_user), +) -> ExperimentResponse: + """Create a new experiment within a project.""" + # Verify project exists + project = db.query(Project).filter(Project.id == project_id).first() + if project is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Project not found") + + experiment = Experiment( + project_id=project_id, + name=body.name, + description=body.description, + sample_data=body.sample_data, + pipeline_stages=body.pipeline_stages, + scoring_config=body.scoring_config, + parameter_space=body.parameter_space, + ) + db.add(experiment) + 
db.commit() + db.refresh(experiment) + return _to_response(experiment) -@router.post("/{experiment_id}/pause", status_code=501) -def pause_sweep(experiment_id: uuid.UUID): - """Pause running sweep.""" - return Response(status_code=501, content="Not Implemented") +@router.get("/{experiment_id}", response_model=ExperimentResponse) +def get_experiment( + experiment_id: uuid.UUID, + db: Session = Depends(get_db), + _user: User = Depends(get_current_user), +) -> ExperimentResponse: + """Get a single experiment.""" + experiment = _get_experiment_or_404(db, experiment_id) + return _to_response(experiment) -@router.post("/{experiment_id}/resume", status_code=501) -def resume_sweep(experiment_id: uuid.UUID): - """Resume paused sweep.""" - return Response(status_code=501, content="Not Implemented") +@router.put("/{experiment_id}", response_model=ExperimentResponse) +def update_experiment( + experiment_id: uuid.UUID, + body: ExperimentUpdate, + db: Session = Depends(get_db), + _user: User = Depends(get_current_user), +) -> ExperimentResponse: + """Update an experiment's configuration.""" + experiment = _get_experiment_or_404(db, experiment_id) + + if body.name is not None: + experiment.name = body.name + if body.description is not None: + experiment.description = body.description + if body.sample_data is not None: + experiment.sample_data = body.sample_data + if body.pipeline_stages is not None: + experiment.pipeline_stages = body.pipeline_stages + if body.scoring_config is not None: + experiment.scoring_config = body.scoring_config + if body.parameter_space is not None: + experiment.parameter_space = body.parameter_space + if body.status is not None: + experiment.status = body.status + + db.commit() + db.refresh(experiment) + return _to_response(experiment) -@router.post("/{experiment_id}/stop", status_code=501) -def stop_sweep(experiment_id: uuid.UUID): - """Stop sweep.""" - return Response(status_code=501, content="Not Implemented") +@router.delete("/{experiment_id}", 
status_code=status.HTTP_204_NO_CONTENT) +def delete_experiment( + experiment_id: uuid.UUID, + db: Session = Depends(get_db), + _user: User = Depends(get_current_user), +) -> None: + """Delete an experiment and all its runs.""" + experiment = _get_experiment_or_404(db, experiment_id) + db.delete(experiment) + db.commit() + + +# --------------------------------------------------------------------------- +# Sweep control +# --------------------------------------------------------------------------- + + +@router.post("/{experiment_id}/sweep", response_model=SweepStatusResponse) +def start_sweep( + experiment_id: uuid.UUID, + body: SweepRequest, + db: Session = Depends(get_db), + _user: User = Depends(get_current_user), +) -> SweepStatusResponse: + """Start a sweep — validate config, create Run records, dispatch to Celery. + + The sweep_config from the request body overrides the experiment's + parameter_space for this sweep execution. + """ + experiment = _get_experiment_or_404(db, experiment_id) + + if experiment.status == ExperimentStatus.running: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail="Experiment already has a running sweep", + ) + + # Build the engine's sweep config; request params win, else the stored params (parameter_space may be unset) + sweep_config = { + "type": body.sweep_type, + "params": body.params or (experiment.parameter_space or {}).get("params", {}), + "n_trials": body.n_trials, + "top_k": body.top_k, + "explore_ratio": body.explore_ratio, + } + + # Dispatch via Celery / sync fallback + dispatch_sweep(str(experiment.id), sweep_config=sweep_config) + + # Refresh to get the updated status (dispatch_sweep may have set it) + db.refresh(experiment) + + # Count runs + run_counts = _get_run_counts(db, experiment.id) + + return SweepStatusResponse( + experiment_id=experiment.id, + status=experiment.status, + **run_counts, + ) + + +@router.get("/{experiment_id}/sweep/status", response_model=SweepStatusResponse) +def sweep_status( + experiment_id: 
uuid.UUID, + db: Session = Depends(get_db), + _user: User = Depends(get_current_user), +) -> SweepStatusResponse: + """Get the current sweep status for an experiment.""" + experiment = _get_experiment_or_404(db, experiment_id) + run_counts = _get_run_counts(db, experiment.id) + return SweepStatusResponse( + experiment_id=experiment.id, + status=experiment.status, + **run_counts, + ) + + +@router.post("/{experiment_id}/pause", status_code=status.HTTP_200_OK) +def pause_sweep( + experiment_id: uuid.UUID, + db: Session = Depends(get_db), + _user: User = Depends(get_current_user), +) -> dict: + """Pause a running sweep by setting the Redis pause flag.""" + experiment = _get_experiment_or_404(db, experiment_id) + if experiment.status != ExperimentStatus.running: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail="Experiment is not currently running", + ) + + redis_client = get_redis() + if redis_client is not None: + request_pause(redis_client, experiment.id) + else: + # In single-container mode, directly set paused status + experiment.status = ExperimentStatus.paused + db.commit() + + return {"experiment_id": str(experiment.id), "action": "pause"} + + +@router.post("/{experiment_id}/resume", status_code=status.HTTP_200_OK) +def resume_sweep( + experiment_id: uuid.UUID, + db: Session = Depends(get_db), + _user: User = Depends(get_current_user), +) -> dict: + """Resume a paused sweep by clearing the Redis pause flag.""" + experiment = _get_experiment_or_404(db, experiment_id) + if experiment.status != ExperimentStatus.paused: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail="Experiment is not currently paused", + ) + + redis_client = get_redis() + if redis_client is not None: + request_resume(redis_client, experiment.id) + + # Re-dispatch the sweep to continue + experiment.status = ExperimentStatus.running + db.commit() + + dispatch_sweep(str(experiment.id)) + + return {"experiment_id": str(experiment.id), "action": "resume"} + 
+ +@router.post("/{experiment_id}/stop", status_code=status.HTTP_200_OK) +def stop_sweep( + experiment_id: uuid.UUID, + db: Session = Depends(get_db), + _user: User = Depends(get_current_user), +) -> dict: + """Stop a running or paused sweep.""" + experiment = _get_experiment_or_404(db, experiment_id) + if experiment.status not in (ExperimentStatus.running, ExperimentStatus.paused): + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail="Experiment is not running or paused", + ) + + redis_client = get_redis() + if redis_client is not None: + request_stop(redis_client, experiment.id) + else: + # Mark remaining pending runs as failed + pending_runs = ( + db.query(Run) + .filter(Run.experiment_id == experiment.id, Run.status == RunStatus.pending) + .all() + ) + for run in pending_runs: + run.status = RunStatus.failed + experiment.status = ExperimentStatus.completed + db.commit() + + return {"experiment_id": str(experiment.id), "action": "stop"} + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + + +def _get_run_counts(db: Session, experiment_id: uuid.UUID) -> dict: + """Return run counts by status for a sweep status response.""" + runs = db.query(Run).filter(Run.experiment_id == experiment_id).all() + completed = sum(1 for r in runs if r.status == RunStatus.completed) + failed = sum(1 for r in runs if r.status == RunStatus.failed) + pending = sum(1 for r in runs if r.status == RunStatus.pending) + return { + "total_runs": len(runs), + "completed_runs": completed, + "failed_runs": failed, + "pending_runs": pending, + } diff --git a/backend/schemas.py b/backend/schemas.py index 87db0ba..e86c830 100644 --- a/backend/schemas.py +++ b/backend/schemas.py @@ -91,6 +91,29 @@ class ExperimentListResponse(BaseModel): total: int +# --------------------------------------------------------------------------- +# Sweep +# 
--------------------------------------------------------------------------- + +class SweepRequest(BaseModel): + """Request body for starting a sweep on an experiment.""" + + sweep_type: str = Field("grid", pattern="^(grid|random|guided)$") + params: dict | None = None + n_trials: int = Field(100, ge=1, le=100000) + top_k: int = Field(5, ge=1) + explore_ratio: float = Field(0.3, ge=0.0, le=1.0) + + +class SweepStatusResponse(BaseModel): + experiment_id: uuid.UUID + status: ExperimentStatus + total_runs: int + completed_runs: int + failed_runs: int + pending_runs: int + + # --------------------------------------------------------------------------- # Run # --------------------------------------------------------------------------- diff --git a/backend/tests/test_experiments.py b/backend/tests/test_experiments.py new file mode 100644 index 0000000..7738b8f --- /dev/null +++ b/backend/tests/test_experiments.py @@ -0,0 +1,554 @@ +"""Tests for backend/routers/experiments.py — Experiment CRUD + sweep control.""" + +import os +import uuid +from unittest.mock import MagicMock, patch + +import pytest +from fastapi.testclient import TestClient + + +JWT_SECRET = "test-secret-key-for-jwt-signing" +API_KEY = "test-api-key-12345" + + +@pytest.fixture(autouse=True) +def _isolate_settings(tmp_path): + """Ensure tests use a temp SQLite DB and no Redis.""" + env = { + "DATABASE_URL": f"sqlite:///{tmp_path / 'test.db'}", + "REDIS_URL": "", + "DATA_DIR": str(tmp_path), + "JWT_SECRET": JWT_SECRET, + "API_KEY": API_KEY, + } + with patch.dict(os.environ, env, clear=False): + import config + new_settings = config.Settings(_env_file=None) + config.settings = new_settings + + import main + main.settings = new_settings + main._init_db() + main._init_redis() + + from models import Base + Base.metadata.create_all(bind=main.engine) + + import auth + auth.settings = new_settings + + yield + + +@pytest.fixture +def db_session(): + from main import get_db + gen = get_db() + session = next(gen) + 
yield session + try: + next(gen) + except StopIteration: + pass + + +@pytest.fixture +def admin_user(db_session): + from auth import hash_password + from models import User + user = User(username="admin", password_hash=hash_password("adminpass"), is_admin=True) + db_session.add(user) + db_session.commit() + db_session.refresh(user) + return user + + +@pytest.fixture +def project(db_session, admin_user): + from models import Project + proj = Project(name="Test Project", description="A test project", owner_id=admin_user.id) + db_session.add(proj) + db_session.commit() + db_session.refresh(proj) + return proj + + +@pytest.fixture +def auth_headers(): + return {"X-Api-Key": API_KEY} + + +@pytest.fixture +def client(): + from main import app + return TestClient(app) + + +def _create_experiment(client, project_id, auth_headers, **overrides): + """Helper to create an experiment via the API.""" + body = { + "name": overrides.get("name", "Test Experiment"), + "description": overrides.get("description", "A test experiment"), + } + for key in ("sample_data", "pipeline_stages", "scoring_config", "parameter_space"): + if key in overrides: + body[key] = overrides[key] + resp = client.post( + f"/api/experiments/?project_id={project_id}", + json=body, + headers=auth_headers, + ) + return resp + + +# --------------------------------------------------------------------------- +# CRUD tests +# --------------------------------------------------------------------------- + + +class TestCreateExperiment: + def test_create_minimal(self, client, project, auth_headers): + resp = _create_experiment(client, project.id, auth_headers) + assert resp.status_code == 201 + data = resp.json() + assert data["name"] == "Test Experiment" + assert data["project_id"] == str(project.id) + assert data["status"] == "draft" + assert "id" in data + + def test_create_with_all_fields(self, client, project, auth_headers): + resp = _create_experiment( + client, project.id, auth_headers, + name="Full Experiment", 
+ description="Detailed description", + sample_data={"key": "value"}, + pipeline_stages={"stages": [{"prompt_template": "Hello {{input}}"}]}, + scoring_config={"scorers": ["keyword"]}, + parameter_space={"type": "grid", "params": {"model": ["a", "b"]}}, + ) + assert resp.status_code == 201 + data = resp.json() + assert data["name"] == "Full Experiment" + assert data["sample_data"] == {"key": "value"} + assert data["parameter_space"]["type"] == "grid" + + def test_create_requires_auth(self, client, project): + resp = client.post( + f"/api/experiments/?project_id={project.id}", + json={"name": "Test"}, + ) + assert resp.status_code == 401 + + def test_create_requires_project_id(self, client, admin_user, auth_headers): + resp = client.post( + "/api/experiments/", + json={"name": "Test"}, + headers=auth_headers, + ) + assert resp.status_code == 422 + + def test_create_validates_project_exists(self, client, admin_user, auth_headers): + fake_id = str(uuid.uuid4()) + resp = client.post( + f"/api/experiments/?project_id={fake_id}", + json={"name": "Test"}, + headers=auth_headers, + ) + assert resp.status_code == 404 + + def test_create_validates_name(self, client, project, auth_headers): + resp = client.post( + f"/api/experiments/?project_id={project.id}", + json={"name": ""}, + headers=auth_headers, + ) + assert resp.status_code == 422 + + +class TestListExperiments: + def test_list_empty(self, client, admin_user, auth_headers): + resp = client.get("/api/experiments/", headers=auth_headers) + assert resp.status_code == 200 + data = resp.json() + assert data["items"] == [] + assert data["total"] == 0 + + def test_list_multiple(self, client, project, auth_headers): + for name in ["Alpha", "Beta", "Gamma"]: + _create_experiment(client, project.id, auth_headers, name=name) + + resp = client.get("/api/experiments/", headers=auth_headers) + assert resp.status_code == 200 + data = resp.json() + assert data["total"] == 3 + + def test_list_filter_by_project(self, client, project, 
admin_user, auth_headers, db_session): + _create_experiment(client, project.id, auth_headers, name="In Project") + + # Create another project + experiment + from models import Project + proj2 = Project(name="Other", owner_id=admin_user.id) + db_session.add(proj2) + db_session.commit() + db_session.refresh(proj2) + _create_experiment(client, proj2.id, auth_headers, name="Other Experiment") + + resp = client.get( + f"/api/experiments/?project_id={project.id}", + headers=auth_headers, + ) + assert resp.status_code == 200 + data = resp.json() + assert data["total"] == 1 + assert data["items"][0]["name"] == "In Project" + + def test_list_requires_auth(self, client): + resp = client.get("/api/experiments/") + assert resp.status_code == 401 + + +class TestGetExperiment: + def test_get_existing(self, client, project, auth_headers): + create_resp = _create_experiment(client, project.id, auth_headers) + exp_id = create_resp.json()["id"] + + resp = client.get(f"/api/experiments/{exp_id}", headers=auth_headers) + assert resp.status_code == 200 + assert resp.json()["name"] == "Test Experiment" + + def test_get_not_found(self, client, admin_user, auth_headers): + fake_id = str(uuid.uuid4()) + resp = client.get(f"/api/experiments/{fake_id}", headers=auth_headers) + assert resp.status_code == 404 + + +class TestUpdateExperiment: + def test_update_name(self, client, project, auth_headers): + create_resp = _create_experiment(client, project.id, auth_headers) + exp_id = create_resp.json()["id"] + + resp = client.put( + f"/api/experiments/{exp_id}", + json={"name": "Updated Name"}, + headers=auth_headers, + ) + assert resp.status_code == 200 + assert resp.json()["name"] == "Updated Name" + + def test_update_status(self, client, project, auth_headers): + create_resp = _create_experiment(client, project.id, auth_headers) + exp_id = create_resp.json()["id"] + + resp = client.put( + f"/api/experiments/{exp_id}", + json={"status": "completed"}, + headers=auth_headers, + ) + assert 
resp.status_code == 200 + assert resp.json()["status"] == "completed" + + def test_update_parameter_space(self, client, project, auth_headers): + create_resp = _create_experiment(client, project.id, auth_headers) + exp_id = create_resp.json()["id"] + + new_space = {"type": "random", "params": {"temperature": [0.1, 0.5, 0.9]}} + resp = client.put( + f"/api/experiments/{exp_id}", + json={"parameter_space": new_space}, + headers=auth_headers, + ) + assert resp.status_code == 200 + assert resp.json()["parameter_space"] == new_space + + def test_update_not_found(self, client, admin_user, auth_headers): + fake_id = str(uuid.uuid4()) + resp = client.put( + f"/api/experiments/{fake_id}", + json={"name": "X"}, + headers=auth_headers, + ) + assert resp.status_code == 404 + + +class TestDeleteExperiment: + def test_delete_existing(self, client, project, auth_headers): + create_resp = _create_experiment(client, project.id, auth_headers) + exp_id = create_resp.json()["id"] + + resp = client.delete(f"/api/experiments/{exp_id}", headers=auth_headers) + assert resp.status_code == 204 + + resp = client.get(f"/api/experiments/{exp_id}", headers=auth_headers) + assert resp.status_code == 404 + + def test_delete_not_found(self, client, admin_user, auth_headers): + fake_id = str(uuid.uuid4()) + resp = client.delete(f"/api/experiments/{fake_id}", headers=auth_headers) + assert resp.status_code == 404 + + +# --------------------------------------------------------------------------- +# Sweep control tests +# --------------------------------------------------------------------------- + + +class TestStartSweep: + def test_start_sweep_dispatches(self, client, project, auth_headers): + create_resp = _create_experiment( + client, project.id, auth_headers, + parameter_space={"type": "grid", "params": {"model": ["a", "b"]}}, + ) + exp_id = create_resp.json()["id"] + + with patch("routers.experiments.dispatch_sweep") as mock_dispatch: + mock_dispatch.return_value = MagicMock() + resp = 
client.post( + f"/api/experiments/{exp_id}/sweep", + json={"sweep_type": "grid"}, + headers=auth_headers, + ) + + assert resp.status_code == 200 + data = resp.json() + assert data["experiment_id"] == exp_id + mock_dispatch.assert_called_once() + call_args = mock_dispatch.call_args + assert call_args[0][0] == exp_id + assert call_args[1]["sweep_config"]["type"] == "grid" + + def test_start_sweep_with_custom_config(self, client, project, auth_headers): + create_resp = _create_experiment(client, project.id, auth_headers) + exp_id = create_resp.json()["id"] + + with patch("routers.experiments.dispatch_sweep") as mock_dispatch: + mock_dispatch.return_value = MagicMock() + resp = client.post( + f"/api/experiments/{exp_id}/sweep", + json={ + "sweep_type": "random", + "params": {"temperature": [0.1, 0.5, 0.9]}, + "n_trials": 50, + }, + headers=auth_headers, + ) + + assert resp.status_code == 200 + call_args = mock_dispatch.call_args + assert call_args[1]["sweep_config"]["type"] == "random" + assert call_args[1]["sweep_config"]["n_trials"] == 50 + + def test_start_sweep_conflict_if_running(self, client, project, auth_headers, db_session): + create_resp = _create_experiment(client, project.id, auth_headers) + exp_id = create_resp.json()["id"] + + # Set experiment to running + from models import Experiment, ExperimentStatus + exp = db_session.query(Experiment).filter(Experiment.id == uuid.UUID(exp_id)).first() + exp.status = ExperimentStatus.running + db_session.commit() + + resp = client.post( + f"/api/experiments/{exp_id}/sweep", + json={"sweep_type": "grid"}, + headers=auth_headers, + ) + assert resp.status_code == 409 + + def test_start_sweep_validates_type(self, client, project, auth_headers): + create_resp = _create_experiment(client, project.id, auth_headers) + exp_id = create_resp.json()["id"] + + resp = client.post( + f"/api/experiments/{exp_id}/sweep", + json={"sweep_type": "invalid"}, + headers=auth_headers, + ) + assert resp.status_code == 422 + + def 
test_start_sweep_not_found(self, client, admin_user, auth_headers): + fake_id = str(uuid.uuid4()) + resp = client.post( + f"/api/experiments/{fake_id}/sweep", + json={"sweep_type": "grid"}, + headers=auth_headers, + ) + assert resp.status_code == 404 + + def test_start_sweep_requires_auth(self, client, project): + create_resp = None + resp = client.post( + f"/api/experiments/{uuid.uuid4()}/sweep", + json={"sweep_type": "grid"}, + ) + assert resp.status_code == 401 + + +class TestSweepStatus: + def test_sweep_status(self, client, project, auth_headers, db_session): + create_resp = _create_experiment(client, project.id, auth_headers) + exp_id = create_resp.json()["id"] + + # Create some run records + from models import Run, RunStatus + for s in [RunStatus.completed, RunStatus.completed, RunStatus.failed, RunStatus.pending]: + run = Run( + experiment_id=uuid.UUID(exp_id), + config_hash="abc123", + config={"test": True}, + status=s, + ) + db_session.add(run) + db_session.commit() + + resp = client.get( + f"/api/experiments/{exp_id}/sweep/status", + headers=auth_headers, + ) + assert resp.status_code == 200 + data = resp.json() + assert data["total_runs"] == 4 + assert data["completed_runs"] == 2 + assert data["failed_runs"] == 1 + assert data["pending_runs"] == 1 + assert data["status"] == "draft" + + +class TestPauseSweep: + def test_pause_running_no_redis(self, client, project, auth_headers, db_session): + create_resp = _create_experiment(client, project.id, auth_headers) + exp_id = create_resp.json()["id"] + + from models import Experiment, ExperimentStatus + exp = db_session.query(Experiment).filter(Experiment.id == uuid.UUID(exp_id)).first() + exp.status = ExperimentStatus.running + db_session.commit() + + resp = client.post(f"/api/experiments/{exp_id}/pause", headers=auth_headers) + assert resp.status_code == 200 + assert resp.json()["action"] == "pause" + + # Verify status changed + db_session.refresh(exp) + assert exp.status == ExperimentStatus.paused + + def 
test_pause_not_running(self, client, project, auth_headers): + create_resp = _create_experiment(client, project.id, auth_headers) + exp_id = create_resp.json()["id"] + + resp = client.post(f"/api/experiments/{exp_id}/pause", headers=auth_headers) + assert resp.status_code == 409 + + def test_pause_with_redis(self, client, project, auth_headers, db_session): + create_resp = _create_experiment(client, project.id, auth_headers) + exp_id = create_resp.json()["id"] + + from models import Experiment, ExperimentStatus + exp = db_session.query(Experiment).filter(Experiment.id == uuid.UUID(exp_id)).first() + exp.status = ExperimentStatus.running + db_session.commit() + + mock_redis = MagicMock() + with patch("routers.experiments.get_redis", return_value=mock_redis): + with patch("routers.experiments.request_pause") as mock_pause: + resp = client.post(f"/api/experiments/{exp_id}/pause", headers=auth_headers) + + assert resp.status_code == 200 + mock_pause.assert_called_once_with(mock_redis, uuid.UUID(exp_id)) + + +class TestResumeSweep: + def test_resume_paused(self, client, project, auth_headers, db_session): + create_resp = _create_experiment(client, project.id, auth_headers) + exp_id = create_resp.json()["id"] + + from models import Experiment, ExperimentStatus + exp = db_session.query(Experiment).filter(Experiment.id == uuid.UUID(exp_id)).first() + exp.status = ExperimentStatus.paused + db_session.commit() + + with patch("routers.experiments.dispatch_sweep") as mock_dispatch: + mock_dispatch.return_value = MagicMock() + resp = client.post(f"/api/experiments/{exp_id}/resume", headers=auth_headers) + + assert resp.status_code == 200 + assert resp.json()["action"] == "resume" + mock_dispatch.assert_called_once() + + def test_resume_not_paused(self, client, project, auth_headers): + create_resp = _create_experiment(client, project.id, auth_headers) + exp_id = create_resp.json()["id"] + + resp = client.post(f"/api/experiments/{exp_id}/resume", headers=auth_headers) + assert 
resp.status_code == 409 + + +class TestStopSweep: + def test_stop_running_no_redis(self, client, project, auth_headers, db_session): + create_resp = _create_experiment(client, project.id, auth_headers) + exp_id = create_resp.json()["id"] + + from models import Experiment, ExperimentStatus, Run, RunStatus + exp = db_session.query(Experiment).filter(Experiment.id == uuid.UUID(exp_id)).first() + exp.status = ExperimentStatus.running + db_session.commit() + + # Create a pending run + run = Run( + experiment_id=uuid.UUID(exp_id), + config_hash="abc", + config={"test": True}, + status=RunStatus.pending, + ) + db_session.add(run) + db_session.commit() + db_session.refresh(run) + + resp = client.post(f"/api/experiments/{exp_id}/stop", headers=auth_headers) + assert resp.status_code == 200 + assert resp.json()["action"] == "stop" + + # Verify pending run is now failed + db_session.refresh(run) + assert run.status == RunStatus.failed + + # Verify experiment is completed + db_session.refresh(exp) + assert exp.status == ExperimentStatus.completed + + def test_stop_not_running(self, client, project, auth_headers): + create_resp = _create_experiment(client, project.id, auth_headers) + exp_id = create_resp.json()["id"] + + resp = client.post(f"/api/experiments/{exp_id}/stop", headers=auth_headers) + assert resp.status_code == 409 + + def test_stop_paused(self, client, project, auth_headers, db_session): + """Can stop a paused experiment too.""" + create_resp = _create_experiment(client, project.id, auth_headers) + exp_id = create_resp.json()["id"] + + from models import Experiment, ExperimentStatus + exp = db_session.query(Experiment).filter(Experiment.id == uuid.UUID(exp_id)).first() + exp.status = ExperimentStatus.paused + db_session.commit() + + resp = client.post(f"/api/experiments/{exp_id}/stop", headers=auth_headers) + assert resp.status_code == 200 + + def test_stop_with_redis(self, client, project, auth_headers, db_session): + create_resp = _create_experiment(client, 
project.id, auth_headers) + exp_id = create_resp.json()["id"] + + from models import Experiment, ExperimentStatus + exp = db_session.query(Experiment).filter(Experiment.id == uuid.UUID(exp_id)).first() + exp.status = ExperimentStatus.running + db_session.commit() + + mock_redis = MagicMock() + with patch("routers.experiments.get_redis", return_value=mock_redis): + with patch("routers.experiments.request_stop") as mock_stop: + resp = client.post(f"/api/experiments/{exp_id}/stop", headers=auth_headers) + + assert resp.status_code == 200 + mock_stop.assert_called_once_with(mock_redis, uuid.UUID(exp_id)) diff --git a/backend/tests/test_routers.py b/backend/tests/test_routers.py index dff6c08..70be77b 100644 --- a/backend/tests/test_routers.py +++ b/backend/tests/test_routers.py @@ -67,51 +67,51 @@ def test_projects_delete(client): assert resp.status_code == 501 -# ---- Experiments router (/api/experiments) ---- +# ---- Experiments router (/api/experiments) — now implemented, requires auth ---- def test_experiments_list(client): resp = client.get("/api/experiments/") - assert resp.status_code == 501 + assert resp.status_code == 401 def test_experiments_create(client): resp = client.post("/api/experiments/") - assert resp.status_code == 501 + assert resp.status_code == 401 def test_experiments_get(client): resp = client.get("/api/experiments/00000000-0000-0000-0000-000000000001") - assert resp.status_code == 501 + assert resp.status_code == 401 def test_experiments_update(client): resp = client.put("/api/experiments/00000000-0000-0000-0000-000000000001") - assert resp.status_code == 501 + assert resp.status_code == 401 def test_experiments_delete(client): resp = client.delete("/api/experiments/00000000-0000-0000-0000-000000000001") - assert resp.status_code == 501 + assert resp.status_code == 401 def test_experiments_sweep(client): resp = client.post("/api/experiments/00000000-0000-0000-0000-000000000001/sweep") - assert resp.status_code == 501 + assert resp.status_code 
== 401 def test_experiments_pause(client): resp = client.post("/api/experiments/00000000-0000-0000-0000-000000000001/pause") - assert resp.status_code == 501 + assert resp.status_code == 401 def test_experiments_resume(client): resp = client.post("/api/experiments/00000000-0000-0000-0000-000000000001/resume") - assert resp.status_code == 501 + assert resp.status_code == 401 def test_experiments_stop(client): resp = client.post("/api/experiments/00000000-0000-0000-0000-000000000001/stop") - assert resp.status_code == 501 + assert resp.status_code == 401 # ---- Runs router (/api/runs) ----