"""Tests for backend/routers/experiments.py — Experiment CRUD + sweep control.""" import os import uuid from unittest.mock import MagicMock, patch import pytest from fastapi.testclient import TestClient JWT_SECRET = "test-secret-key-for-jwt-signing" API_KEY = "test-api-key-12345" @pytest.fixture(autouse=True) def _isolate_settings(tmp_path): """Ensure tests use a temp SQLite DB and no Redis.""" env = { "DATABASE_URL": f"sqlite:///{tmp_path / 'test.db'}", "REDIS_URL": "", "DATA_DIR": str(tmp_path), "JWT_SECRET": JWT_SECRET, "API_KEY": API_KEY, } with patch.dict(os.environ, env, clear=False): import config new_settings = config.Settings(_env_file=None) config.settings = new_settings import main main.settings = new_settings main._init_db() main._init_redis() from models import Base Base.metadata.create_all(bind=main.engine) import auth auth.settings = new_settings yield @pytest.fixture def db_session(): from main import get_db gen = get_db() session = next(gen) yield session try: next(gen) except StopIteration: pass @pytest.fixture def admin_user(db_session): from auth import hash_password from models import User user = User(username="admin", password_hash=hash_password("adminpass"), is_admin=True) db_session.add(user) db_session.commit() db_session.refresh(user) return user @pytest.fixture def project(db_session, admin_user): from models import Project proj = Project(name="Test Project", description="A test project", owner_id=admin_user.id) db_session.add(proj) db_session.commit() db_session.refresh(proj) return proj @pytest.fixture def auth_headers(): return {"X-Api-Key": API_KEY} @pytest.fixture def client(): from main import app return TestClient(app) def _create_experiment(client, project_id, auth_headers, **overrides): """Helper to create an experiment via the API.""" body = { "name": overrides.get("name", "Test Experiment"), "description": overrides.get("description", "A test experiment"), } for key in ("sample_data", "pipeline_stages", "scoring_config", 
"parameter_space"): if key in overrides: body[key] = overrides[key] resp = client.post( f"/api/experiments/?project_id={project_id}", json=body, headers=auth_headers, ) return resp # --------------------------------------------------------------------------- # CRUD tests # --------------------------------------------------------------------------- class TestCreateExperiment: def test_create_minimal(self, client, project, auth_headers): resp = _create_experiment(client, project.id, auth_headers) assert resp.status_code == 201 data = resp.json() assert data["name"] == "Test Experiment" assert data["project_id"] == str(project.id) assert data["status"] == "draft" assert "id" in data def test_create_with_all_fields(self, client, project, auth_headers): resp = _create_experiment( client, project.id, auth_headers, name="Full Experiment", description="Detailed description", sample_data={"key": "value"}, pipeline_stages={"stages": [{"prompt_template": "Hello {{input}}"}]}, scoring_config={"scorers": ["keyword"]}, parameter_space={"type": "grid", "params": {"model": ["a", "b"]}}, ) assert resp.status_code == 201 data = resp.json() assert data["name"] == "Full Experiment" assert data["sample_data"] == {"key": "value"} assert data["parameter_space"]["type"] == "grid" def test_create_requires_auth(self, client, project): resp = client.post( f"/api/experiments/?project_id={project.id}", json={"name": "Test"}, ) assert resp.status_code == 401 def test_create_requires_project_id(self, client, admin_user, auth_headers): resp = client.post( "/api/experiments/", json={"name": "Test"}, headers=auth_headers, ) assert resp.status_code == 422 def test_create_validates_project_exists(self, client, admin_user, auth_headers): fake_id = str(uuid.uuid4()) resp = client.post( f"/api/experiments/?project_id={fake_id}", json={"name": "Test"}, headers=auth_headers, ) assert resp.status_code == 404 def test_create_validates_name(self, client, project, auth_headers): resp = client.post( 
f"/api/experiments/?project_id={project.id}", json={"name": ""}, headers=auth_headers, ) assert resp.status_code == 422 class TestListExperiments: def test_list_empty(self, client, admin_user, auth_headers): resp = client.get("/api/experiments/", headers=auth_headers) assert resp.status_code == 200 data = resp.json() assert data["items"] == [] assert data["total"] == 0 def test_list_multiple(self, client, project, auth_headers): for name in ["Alpha", "Beta", "Gamma"]: _create_experiment(client, project.id, auth_headers, name=name) resp = client.get("/api/experiments/", headers=auth_headers) assert resp.status_code == 200 data = resp.json() assert data["total"] == 3 def test_list_filter_by_project(self, client, project, admin_user, auth_headers, db_session): _create_experiment(client, project.id, auth_headers, name="In Project") # Create another project + experiment from models import Project proj2 = Project(name="Other", owner_id=admin_user.id) db_session.add(proj2) db_session.commit() db_session.refresh(proj2) _create_experiment(client, proj2.id, auth_headers, name="Other Experiment") resp = client.get( f"/api/experiments/?project_id={project.id}", headers=auth_headers, ) assert resp.status_code == 200 data = resp.json() assert data["total"] == 1 assert data["items"][0]["name"] == "In Project" def test_list_requires_auth(self, client): resp = client.get("/api/experiments/") assert resp.status_code == 401 class TestGetExperiment: def test_get_existing(self, client, project, auth_headers): create_resp = _create_experiment(client, project.id, auth_headers) exp_id = create_resp.json()["id"] resp = client.get(f"/api/experiments/{exp_id}", headers=auth_headers) assert resp.status_code == 200 assert resp.json()["name"] == "Test Experiment" def test_get_not_found(self, client, admin_user, auth_headers): fake_id = str(uuid.uuid4()) resp = client.get(f"/api/experiments/{fake_id}", headers=auth_headers) assert resp.status_code == 404 class TestUpdateExperiment: def 
test_update_name(self, client, project, auth_headers): create_resp = _create_experiment(client, project.id, auth_headers) exp_id = create_resp.json()["id"] resp = client.put( f"/api/experiments/{exp_id}", json={"name": "Updated Name"}, headers=auth_headers, ) assert resp.status_code == 200 assert resp.json()["name"] == "Updated Name" def test_update_status(self, client, project, auth_headers): create_resp = _create_experiment(client, project.id, auth_headers) exp_id = create_resp.json()["id"] resp = client.put( f"/api/experiments/{exp_id}", json={"status": "completed"}, headers=auth_headers, ) assert resp.status_code == 200 assert resp.json()["status"] == "completed" def test_update_parameter_space(self, client, project, auth_headers): create_resp = _create_experiment(client, project.id, auth_headers) exp_id = create_resp.json()["id"] new_space = {"type": "random", "params": {"temperature": [0.1, 0.5, 0.9]}} resp = client.put( f"/api/experiments/{exp_id}", json={"parameter_space": new_space}, headers=auth_headers, ) assert resp.status_code == 200 assert resp.json()["parameter_space"] == new_space def test_update_not_found(self, client, admin_user, auth_headers): fake_id = str(uuid.uuid4()) resp = client.put( f"/api/experiments/{fake_id}", json={"name": "X"}, headers=auth_headers, ) assert resp.status_code == 404 class TestDeleteExperiment: def test_delete_existing(self, client, project, auth_headers): create_resp = _create_experiment(client, project.id, auth_headers) exp_id = create_resp.json()["id"] resp = client.delete(f"/api/experiments/{exp_id}", headers=auth_headers) assert resp.status_code == 204 resp = client.get(f"/api/experiments/{exp_id}", headers=auth_headers) assert resp.status_code == 404 def test_delete_not_found(self, client, admin_user, auth_headers): fake_id = str(uuid.uuid4()) resp = client.delete(f"/api/experiments/{fake_id}", headers=auth_headers) assert resp.status_code == 404 # 
# ---------------------------------------------------------------------------
# Sweep control tests
# ---------------------------------------------------------------------------


def _set_experiment_status(db_session, exp_id, status):
    """Force experiment *exp_id* (a UUID string) into *status* directly in the DB.

    The change is committed immediately, so subsequent API calls see the new
    status. Returns the ORM row so callers can ``refresh`` it after an API call.
    """
    from models import Experiment

    exp = db_session.query(Experiment).filter(Experiment.id == uuid.UUID(exp_id)).first()
    exp.status = status
    db_session.commit()
    return exp


class TestStartSweep:
    def test_start_sweep_dispatches(self, client, project, auth_headers):
        create_resp = _create_experiment(
            client,
            project.id,
            auth_headers,
            parameter_space={"type": "grid", "params": {"model": ["a", "b"]}},
        )
        exp_id = create_resp.json()["id"]

        with patch("routers.experiments.dispatch_sweep") as mock_dispatch:
            mock_dispatch.return_value = MagicMock()
            resp = client.post(
                f"/api/experiments/{exp_id}/sweep",
                json={"sweep_type": "grid"},
                headers=auth_headers,
            )

        assert resp.status_code == 200
        data = resp.json()
        assert data["experiment_id"] == exp_id
        mock_dispatch.assert_called_once()
        call_args = mock_dispatch.call_args
        assert call_args.args[0] == exp_id
        assert call_args.kwargs["sweep_config"]["type"] == "grid"

    def test_start_sweep_with_custom_config(self, client, project, auth_headers):
        create_resp = _create_experiment(client, project.id, auth_headers)
        exp_id = create_resp.json()["id"]

        with patch("routers.experiments.dispatch_sweep") as mock_dispatch:
            mock_dispatch.return_value = MagicMock()
            resp = client.post(
                f"/api/experiments/{exp_id}/sweep",
                json={
                    "sweep_type": "random",
                    "params": {"temperature": [0.1, 0.5, 0.9]},
                    "n_trials": 50,
                },
                headers=auth_headers,
            )

        assert resp.status_code == 200
        call_args = mock_dispatch.call_args
        assert call_args.kwargs["sweep_config"]["type"] == "random"
        assert call_args.kwargs["sweep_config"]["n_trials"] == 50

    def test_start_sweep_conflict_if_running(self, client, project, auth_headers, db_session):
        create_resp = _create_experiment(client, project.id, auth_headers)
        exp_id = create_resp.json()["id"]

        # Set experiment to running
        from models import ExperimentStatus

        _set_experiment_status(db_session, exp_id, ExperimentStatus.running)

        # Patch dispatch so a regression cannot launch a real sweep, and so we
        # can verify the rejected request dispatched nothing.
        with patch("routers.experiments.dispatch_sweep") as mock_dispatch:
            resp = client.post(
                f"/api/experiments/{exp_id}/sweep",
                json={"sweep_type": "grid"},
                headers=auth_headers,
            )
        assert resp.status_code == 409
        mock_dispatch.assert_not_called()

    def test_start_sweep_validates_type(self, client, project, auth_headers):
        create_resp = _create_experiment(client, project.id, auth_headers)
        exp_id = create_resp.json()["id"]

        with patch("routers.experiments.dispatch_sweep") as mock_dispatch:
            resp = client.post(
                f"/api/experiments/{exp_id}/sweep",
                json={"sweep_type": "invalid"},
                headers=auth_headers,
            )
        assert resp.status_code == 422
        mock_dispatch.assert_not_called()

    def test_start_sweep_not_found(self, client, admin_user, auth_headers):
        fake_id = str(uuid.uuid4())

        with patch("routers.experiments.dispatch_sweep") as mock_dispatch:
            resp = client.post(
                f"/api/experiments/{fake_id}/sweep",
                json={"sweep_type": "grid"},
                headers=auth_headers,
            )
        assert resp.status_code == 404
        mock_dispatch.assert_not_called()

    def test_start_sweep_requires_auth(self, client, project):
        resp = client.post(
            f"/api/experiments/{uuid.uuid4()}/sweep",
            json={"sweep_type": "grid"},
        )
        assert resp.status_code == 401


class TestSweepStatus:
    def test_sweep_status(self, client, project, auth_headers, db_session):
        create_resp = _create_experiment(client, project.id, auth_headers)
        exp_id = create_resp.json()["id"]

        # Create some run records
        from models import Run, RunStatus

        for run_status in [RunStatus.completed, RunStatus.completed, RunStatus.failed, RunStatus.pending]:
            db_session.add(
                Run(
                    experiment_id=uuid.UUID(exp_id),
                    config_hash="abc123",
                    config={"test": True},
                    status=run_status,
                )
            )
        db_session.commit()

        resp = client.get(
            f"/api/experiments/{exp_id}/sweep/status",
            headers=auth_headers,
        )
        assert resp.status_code == 200
        data = resp.json()
        assert data["total_runs"] == 4
        assert data["completed_runs"] == 2
        assert data["failed_runs"] == 1
        assert data["pending_runs"] == 1
        assert data["status"] == "draft"


class TestPauseSweep:
    def test_pause_running_no_redis(self, client, project, auth_headers, db_session):
        create_resp = _create_experiment(client, project.id, auth_headers)
        exp_id = create_resp.json()["id"]

        from models import ExperimentStatus

        exp = _set_experiment_status(db_session, exp_id, ExperimentStatus.running)

        resp = client.post(f"/api/experiments/{exp_id}/pause", headers=auth_headers)
        assert resp.status_code == 200
        assert resp.json()["action"] == "pause"

        # Verify status changed
        db_session.refresh(exp)
        assert exp.status == ExperimentStatus.paused

    def test_pause_not_running(self, client, project, auth_headers):
        create_resp = _create_experiment(client, project.id, auth_headers)
        exp_id = create_resp.json()["id"]
        resp = client.post(f"/api/experiments/{exp_id}/pause", headers=auth_headers)
        assert resp.status_code == 409

    def test_pause_with_redis(self, client, project, auth_headers, db_session):
        create_resp = _create_experiment(client, project.id, auth_headers)
        exp_id = create_resp.json()["id"]

        from models import ExperimentStatus

        _set_experiment_status(db_session, exp_id, ExperimentStatus.running)

        mock_redis = MagicMock()
        with patch("routers.experiments.get_redis", return_value=mock_redis):
            with patch("routers.experiments.request_pause") as mock_pause:
                resp = client.post(f"/api/experiments/{exp_id}/pause", headers=auth_headers)
                assert resp.status_code == 200
                mock_pause.assert_called_once_with(mock_redis, uuid.UUID(exp_id))


class TestResumeSweep:
    def test_resume_paused(self, client, project, auth_headers, db_session):
        create_resp = _create_experiment(client, project.id, auth_headers)
        exp_id = create_resp.json()["id"]

        from models import ExperimentStatus

        _set_experiment_status(db_session, exp_id, ExperimentStatus.paused)

        with patch("routers.experiments.dispatch_sweep") as mock_dispatch:
            mock_dispatch.return_value = MagicMock()
            resp = client.post(f"/api/experiments/{exp_id}/resume", headers=auth_headers)

        assert resp.status_code == 200
        assert resp.json()["action"] == "resume"
        mock_dispatch.assert_called_once()

    def test_resume_not_paused(self, client, project, auth_headers):
        create_resp = _create_experiment(client, project.id, auth_headers)
        exp_id = create_resp.json()["id"]

        # A rejected resume must not re-dispatch the sweep.
        with patch("routers.experiments.dispatch_sweep") as mock_dispatch:
            resp = client.post(f"/api/experiments/{exp_id}/resume", headers=auth_headers)
        assert resp.status_code == 409
        mock_dispatch.assert_not_called()


class TestStopSweep:
    def test_stop_running_no_redis(self, client, project, auth_headers, db_session):
        create_resp = _create_experiment(client, project.id, auth_headers)
        exp_id = create_resp.json()["id"]

        from models import ExperimentStatus, Run, RunStatus

        exp = _set_experiment_status(db_session, exp_id, ExperimentStatus.running)

        # Create a pending run
        run = Run(
            experiment_id=uuid.UUID(exp_id),
            config_hash="abc",
            config={"test": True},
            status=RunStatus.pending,
        )
        db_session.add(run)
        db_session.commit()
        db_session.refresh(run)

        resp = client.post(f"/api/experiments/{exp_id}/stop", headers=auth_headers)
        assert resp.status_code == 200
        assert resp.json()["action"] == "stop"

        # Verify pending run is now failed
        db_session.refresh(run)
        assert run.status == RunStatus.failed

        # Verify experiment is completed
        db_session.refresh(exp)
        assert exp.status == ExperimentStatus.completed

    def test_stop_not_running(self, client, project, auth_headers):
        create_resp = _create_experiment(client, project.id, auth_headers)
        exp_id = create_resp.json()["id"]
        resp = client.post(f"/api/experiments/{exp_id}/stop", headers=auth_headers)
        assert resp.status_code == 409

    def test_stop_paused(self, client, project, auth_headers, db_session):
        """Can stop a paused experiment too."""
        create_resp = _create_experiment(client, project.id, auth_headers)
        exp_id = create_resp.json()["id"]

        from models import ExperimentStatus

        _set_experiment_status(db_session, exp_id, ExperimentStatus.paused)

        resp = client.post(f"/api/experiments/{exp_id}/stop", headers=auth_headers)
        assert resp.status_code == 200

    def test_stop_with_redis(self, client, project, auth_headers, db_session):
        create_resp = _create_experiment(client, project.id, auth_headers)
        exp_id = create_resp.json()["id"]

        from models import ExperimentStatus

        _set_experiment_status(db_session, exp_id, ExperimentStatus.running)

        mock_redis = MagicMock()
        with patch("routers.experiments.get_redis", return_value=mock_redis):
            with patch("routers.experiments.request_stop") as mock_stop:
                resp = client.post(f"/api/experiments/{exp_id}/stop", headers=auth_headers)
                assert resp.status_code == 200
                mock_stop.assert_called_once_with(mock_redis, uuid.UUID(exp_id))