Add the complete experiments API: list (with an optional project filter), get, create, update and delete, plus the sweep lifecycle (start/pause/resume/stop/status). Adds the SweepRequest and SweepStatusResponse schemas. Sweep dispatch goes through Celery, with a synchronous fallback for single-container mode. Redis flags control pause/resume/stop; when Redis is unavailable, the database is updated directly. 34 tests.
554 lines · 20 KiB · Python
"""Tests for backend/routers/experiments.py — Experiment CRUD + sweep control."""
|
|
|
|
import os
|
|
import uuid
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
from fastapi.testclient import TestClient
|
|
|
|
|
|
# Fixed credentials injected into the environment by ``_isolate_settings``;
# ``API_KEY`` is also the value ``auth_headers`` sends on authenticated calls.
JWT_SECRET: str = "test-secret-key-for-jwt-signing"
API_KEY: str = "test-api-key-12345"
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def _isolate_settings(tmp_path):
    """Point the app at a throwaway SQLite DB under ``tmp_path``, with no Redis.

    Runs automatically for every test. It patches the environment, rebuilds
    ``config.Settings`` from it and installs the new settings object in
    ``config``, ``main`` and ``auth``. It then re-runs ``main``'s DB and Redis
    initialisers and creates all ORM tables. On teardown the engine created
    for this test is disposed, so its pooled connections to the temporary
    database file do not leak from one test to the next.
    """
    env = {
        "DATABASE_URL": f"sqlite:///{tmp_path / 'test.db'}",
        "REDIS_URL": "",  # empty URL: tests run without Redis
        "DATA_DIR": str(tmp_path),
        "JWT_SECRET": JWT_SECRET,
        "API_KEY": API_KEY,
    }
    with patch.dict(os.environ, env, clear=False):
        import config

        # _env_file=None keeps a local .env file from overriding the patched
        # environment (assumes Settings is a pydantic-settings BaseSettings — confirm).
        new_settings = config.Settings(_env_file=None)
        config.settings = new_settings

        import main

        # main keeps its own ``settings`` reference (presumably read by
        # _init_db/_init_redis), so it is replaced before they run.
        main.settings = new_settings
        main._init_db()
        main._init_redis()

        from models import Base

        engine = main.engine
        Base.metadata.create_all(bind=engine)

        import auth

        auth.settings = new_settings

        try:
            yield
        finally:
            # Each test gets a fresh engine; release its pooled connections
            # to this test's SQLite file instead of leaving them to the GC.
            engine.dispose()
|
|
|
|
|
|
@pytest.fixture
def db_session():
    """Yield a session obtained from the app's own ``get_db`` dependency.

    On teardown the dependency generator is resumed once more, so whatever
    ``get_db`` does after its ``yield`` (presumably closing the session) runs.
    """
    from main import get_db

    session_gen = get_db()
    session = next(session_gen)
    yield session
    # The default argument swallows the StopIteration raised when get_db finishes.
    next(session_gen, None)
|
|
|
|
|
|
@pytest.fixture
def admin_user(db_session):
    """Persist and return an admin ``User`` whose password is ``adminpass``."""
    from auth import hash_password
    from models import User

    admin = User(
        username="admin",
        password_hash=hash_password("adminpass"),
        is_admin=True,
    )
    db_session.add(admin)
    db_session.commit()
    # Reload DB-populated columns (e.g. the primary key used as owner_id).
    db_session.refresh(admin)
    return admin
|
|
|
|
|
|
@pytest.fixture
def project(db_session, admin_user):
    """Persist and return a ``Project`` owned by the ``admin_user`` fixture."""
    from models import Project

    new_project = Project(
        name="Test Project",
        description="A test project",
        owner_id=admin_user.id,
    )
    db_session.add(new_project)
    db_session.commit()
    db_session.refresh(new_project)
    return new_project
|
|
|
|
|
|
@pytest.fixture
def auth_headers():
    """Request headers that authenticate with the test API key."""
    headers = {"X-Api-Key": API_KEY}
    return headers
|
|
|
|
|
|
@pytest.fixture
def client():
    """Return a ``TestClient`` around the FastAPI ``app`` exported by ``main``.

    The client is not entered as a context manager, so app lifespan handlers
    do not run; ``_isolate_settings`` initialises the DB and Redis directly.
    """
    from main import app as fastapi_app

    return TestClient(fastapi_app)
|
|
|
|
|
|
def _create_experiment(client, project_id, auth_headers, **overrides):
|
|
"""Helper to create an experiment via the API."""
|
|
body = {
|
|
"name": overrides.get("name", "Test Experiment"),
|
|
"description": overrides.get("description", "A test experiment"),
|
|
}
|
|
for key in ("sample_data", "pipeline_stages", "scoring_config", "parameter_space"):
|
|
if key in overrides:
|
|
body[key] = overrides[key]
|
|
resp = client.post(
|
|
f"/api/experiments/?project_id={project_id}",
|
|
json=body,
|
|
headers=auth_headers,
|
|
)
|
|
return resp
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# CRUD tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestCreateExperiment:
    """POST /api/experiments/ — creation, validation and authentication."""

    def test_create_minimal(self, client, project, auth_headers):
        """With only name/description, the experiment starts as a draft."""
        response = _create_experiment(client, project.id, auth_headers)
        assert response.status_code == 201

        created = response.json()
        assert created["name"] == "Test Experiment"
        assert created["project_id"] == str(project.id)
        assert created["status"] == "draft"
        assert "id" in created

    def test_create_with_all_fields(self, client, project, auth_headers):
        """Optional JSON fields are stored and echoed back."""
        response = _create_experiment(
            client,
            project.id,
            auth_headers,
            name="Full Experiment",
            description="Detailed description",
            sample_data={"key": "value"},
            pipeline_stages={"stages": [{"prompt_template": "Hello {{input}}"}]},
            scoring_config={"scorers": ["keyword"]},
            parameter_space={"type": "grid", "params": {"model": ["a", "b"]}},
        )
        assert response.status_code == 201

        created = response.json()
        assert created["name"] == "Full Experiment"
        assert created["sample_data"] == {"key": "value"}
        assert created["parameter_space"]["type"] == "grid"

    def test_create_requires_auth(self, client, project):
        """A request without credentials is rejected with 401."""
        url = f"/api/experiments/?project_id={project.id}"
        response = client.post(url, json={"name": "Test"})
        assert response.status_code == 401

    def test_create_requires_project_id(self, client, admin_user, auth_headers):
        """Omitting the ``project_id`` query parameter yields 422."""
        response = client.post("/api/experiments/", json={"name": "Test"}, headers=auth_headers)
        assert response.status_code == 422

    def test_create_validates_project_exists(self, client, admin_user, auth_headers):
        """An unknown ``project_id`` yields 404."""
        missing_project_id = str(uuid.uuid4())
        url = f"/api/experiments/?project_id={missing_project_id}"
        response = client.post(url, json={"name": "Test"}, headers=auth_headers)
        assert response.status_code == 404

    def test_create_validates_name(self, client, project, auth_headers):
        """An empty name fails request validation (422)."""
        url = f"/api/experiments/?project_id={project.id}"
        response = client.post(url, json={"name": ""}, headers=auth_headers)
        assert response.status_code == 422
|
|
|
|
|
|
class TestListExperiments:
    """GET /api/experiments/ — listing, totals and project filtering."""

    def test_list_empty(self, client, admin_user, auth_headers):
        """With no experiments, the listing is empty and the total is zero."""
        response = client.get("/api/experiments/", headers=auth_headers)
        assert response.status_code == 200

        listing = response.json()
        assert listing["items"] == []
        assert listing["total"] == 0

    def test_list_multiple(self, client, project, auth_headers):
        """Every created experiment is counted in ``total``."""
        names = ("Alpha", "Beta", "Gamma")
        for experiment_name in names:
            _create_experiment(client, project.id, auth_headers, name=experiment_name)

        response = client.get("/api/experiments/", headers=auth_headers)
        assert response.status_code == 200
        assert response.json()["total"] == len(names)

    def test_list_filter_by_project(self, client, project, admin_user, auth_headers, db_session):
        """``project_id`` restricts the listing to that project's experiments."""
        from models import Project

        _create_experiment(client, project.id, auth_headers, name="In Project")

        # A second project with its own experiment, which the filter must exclude.
        other_project = Project(name="Other", owner_id=admin_user.id)
        db_session.add(other_project)
        db_session.commit()
        db_session.refresh(other_project)
        _create_experiment(client, other_project.id, auth_headers, name="Other Experiment")

        url = f"/api/experiments/?project_id={project.id}"
        response = client.get(url, headers=auth_headers)
        assert response.status_code == 200

        listing = response.json()
        assert listing["total"] == 1
        assert listing["items"][0]["name"] == "In Project"

    def test_list_requires_auth(self, client):
        """Listing without credentials is rejected with 401."""
        assert client.get("/api/experiments/").status_code == 401
|
|
|
|
|
|
class TestGetExperiment:
    """GET /api/experiments/{id} — single-experiment retrieval."""

    def test_get_existing(self, client, project, auth_headers):
        """A created experiment can be fetched by id."""
        exp_id = _create_experiment(client, project.id, auth_headers).json()["id"]

        response = client.get(f"/api/experiments/{exp_id}", headers=auth_headers)
        assert response.status_code == 200
        assert response.json()["name"] == "Test Experiment"

    def test_get_not_found(self, client, admin_user, auth_headers):
        """An unknown id yields 404."""
        response = client.get(f"/api/experiments/{uuid.uuid4()}", headers=auth_headers)
        assert response.status_code == 404
|
|
|
|
|
|
class TestUpdateExperiment:
    """PUT /api/experiments/{id} — partial field updates."""

    @staticmethod
    def _put(client, exp_id, payload, headers):
        """Send ``payload`` as a PUT to the experiment's URL."""
        return client.put(f"/api/experiments/{exp_id}", json=payload, headers=headers)

    def test_update_name(self, client, project, auth_headers):
        """The name can be changed on its own."""
        exp_id = _create_experiment(client, project.id, auth_headers).json()["id"]

        response = self._put(client, exp_id, {"name": "Updated Name"}, auth_headers)
        assert response.status_code == 200
        assert response.json()["name"] == "Updated Name"

    def test_update_status(self, client, project, auth_headers):
        """The status can be set directly through the update endpoint."""
        exp_id = _create_experiment(client, project.id, auth_headers).json()["id"]

        response = self._put(client, exp_id, {"status": "completed"}, auth_headers)
        assert response.status_code == 200
        assert response.json()["status"] == "completed"

    def test_update_parameter_space(self, client, project, auth_headers):
        """A replacement parameter space is stored verbatim."""
        exp_id = _create_experiment(client, project.id, auth_headers).json()["id"]
        replacement = {"type": "random", "params": {"temperature": [0.1, 0.5, 0.9]}}

        response = self._put(client, exp_id, {"parameter_space": replacement}, auth_headers)
        assert response.status_code == 200
        assert response.json()["parameter_space"] == replacement

    def test_update_not_found(self, client, admin_user, auth_headers):
        """Updating an unknown id yields 404."""
        response = self._put(client, str(uuid.uuid4()), {"name": "X"}, auth_headers)
        assert response.status_code == 404
|
|
|
|
|
|
class TestDeleteExperiment:
    """DELETE /api/experiments/{id}."""

    def test_delete_existing(self, client, project, auth_headers):
        """Deleting returns 204, and the experiment is gone afterwards."""
        exp_id = _create_experiment(client, project.id, auth_headers).json()["id"]
        url = f"/api/experiments/{exp_id}"

        assert client.delete(url, headers=auth_headers).status_code == 204
        assert client.get(url, headers=auth_headers).status_code == 404

    def test_delete_not_found(self, client, admin_user, auth_headers):
        """Deleting an unknown id yields 404."""
        url = f"/api/experiments/{uuid.uuid4()}"
        assert client.delete(url, headers=auth_headers).status_code == 404
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Sweep control tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestStartSweep:
    """POST /api/experiments/{id}/sweep — dispatch and request validation.

    Every test patches ``routers.experiments.dispatch_sweep``, so no test can
    start a real sweep. Tests for rejected requests also assert that the
    dispatcher was never called.
    """

    def test_start_sweep_dispatches(self, client, project, auth_headers):
        create_resp = _create_experiment(
            client, project.id, auth_headers,
            parameter_space={"type": "grid", "params": {"model": ["a", "b"]}},
        )
        exp_id = create_resp.json()["id"]

        with patch("routers.experiments.dispatch_sweep") as mock_dispatch:
            mock_dispatch.return_value = MagicMock()
            resp = client.post(
                f"/api/experiments/{exp_id}/sweep",
                json={"sweep_type": "grid"},
                headers=auth_headers,
            )

        assert resp.status_code == 200
        assert resp.json()["experiment_id"] == exp_id
        mock_dispatch.assert_called_once()
        # The experiment id is passed positionally; the sweep config by keyword.
        assert mock_dispatch.call_args.args[0] == exp_id
        assert mock_dispatch.call_args.kwargs["sweep_config"]["type"] == "grid"

    def test_start_sweep_with_custom_config(self, client, project, auth_headers):
        create_resp = _create_experiment(client, project.id, auth_headers)
        exp_id = create_resp.json()["id"]

        with patch("routers.experiments.dispatch_sweep") as mock_dispatch:
            mock_dispatch.return_value = MagicMock()
            resp = client.post(
                f"/api/experiments/{exp_id}/sweep",
                json={
                    "sweep_type": "random",
                    "params": {"temperature": [0.1, 0.5, 0.9]},
                    "n_trials": 50,
                },
                headers=auth_headers,
            )

        assert resp.status_code == 200
        sweep_config = mock_dispatch.call_args.kwargs["sweep_config"]
        assert sweep_config["type"] == "random"
        assert sweep_config["n_trials"] == 50

    def test_start_sweep_conflict_if_running(self, client, project, auth_headers, db_session):
        from models import Experiment, ExperimentStatus

        create_resp = _create_experiment(client, project.id, auth_headers)
        exp_id = create_resp.json()["id"]

        # Force the experiment into the running state directly in the DB.
        exp = db_session.query(Experiment).filter(Experiment.id == uuid.UUID(exp_id)).first()
        exp.status = ExperimentStatus.running
        db_session.commit()

        with patch("routers.experiments.dispatch_sweep") as mock_dispatch:
            resp = client.post(
                f"/api/experiments/{exp_id}/sweep",
                json={"sweep_type": "grid"},
                headers=auth_headers,
            )

        assert resp.status_code == 409
        mock_dispatch.assert_not_called()

    def test_start_sweep_validates_type(self, client, project, auth_headers):
        create_resp = _create_experiment(client, project.id, auth_headers)
        exp_id = create_resp.json()["id"]

        with patch("routers.experiments.dispatch_sweep") as mock_dispatch:
            resp = client.post(
                f"/api/experiments/{exp_id}/sweep",
                json={"sweep_type": "invalid"},
                headers=auth_headers,
            )

        assert resp.status_code == 422
        mock_dispatch.assert_not_called()

    def test_start_sweep_not_found(self, client, admin_user, auth_headers):
        fake_id = str(uuid.uuid4())

        with patch("routers.experiments.dispatch_sweep") as mock_dispatch:
            resp = client.post(
                f"/api/experiments/{fake_id}/sweep",
                json={"sweep_type": "grid"},
                headers=auth_headers,
            )

        assert resp.status_code == 404
        mock_dispatch.assert_not_called()

    def test_start_sweep_requires_auth(self, client):
        # Like test_list_requires_auth: no fixtures needed, no credentials sent.
        with patch("routers.experiments.dispatch_sweep") as mock_dispatch:
            resp = client.post(
                f"/api/experiments/{uuid.uuid4()}/sweep",
                json={"sweep_type": "grid"},
            )

        assert resp.status_code == 401
        mock_dispatch.assert_not_called()
|
|
|
|
|
|
class TestSweepStatus:
    """GET /api/experiments/{id}/sweep/status — per-status run counts."""

    def test_sweep_status(self, client, project, auth_headers, db_session):
        """Counts reflect runs inserted directly into the DB."""
        from models import Run, RunStatus

        exp_id = _create_experiment(client, project.id, auth_headers).json()["id"]

        # Seed two completed runs, one failed run and one pending run.
        seeded_runs = [
            Run(
                experiment_id=uuid.UUID(exp_id),
                config_hash="abc123",
                config={"test": True},
                status=run_status,
            )
            for run_status in (
                RunStatus.completed,
                RunStatus.completed,
                RunStatus.failed,
                RunStatus.pending,
            )
        ]
        for seeded in seeded_runs:
            db_session.add(seeded)
        db_session.commit()

        response = client.get(f"/api/experiments/{exp_id}/sweep/status", headers=auth_headers)
        assert response.status_code == 200

        summary = response.json()
        assert summary["total_runs"] == 4
        assert summary["completed_runs"] == 2
        assert summary["failed_runs"] == 1
        assert summary["pending_runs"] == 1
        # Seeding runs directly leaves the experiment's own status untouched.
        assert summary["status"] == "draft"
|
|
|
|
|
|
class TestPauseSweep:
    """POST /api/experiments/{id}/pause — with and without Redis."""

    @staticmethod
    def _force_status(db_session, exp_id, status):
        """Write ``status`` straight to the experiment row and return the row."""
        from models import Experiment

        experiment = (
            db_session.query(Experiment)
            .filter(Experiment.id == uuid.UUID(exp_id))
            .first()
        )
        experiment.status = status
        db_session.commit()
        return experiment

    def test_pause_running_no_redis(self, client, project, auth_headers, db_session):
        """Without Redis, pausing updates the experiment row directly."""
        from models import ExperimentStatus

        exp_id = _create_experiment(client, project.id, auth_headers).json()["id"]
        experiment = self._force_status(db_session, exp_id, ExperimentStatus.running)

        response = client.post(f"/api/experiments/{exp_id}/pause", headers=auth_headers)
        assert response.status_code == 200
        assert response.json()["action"] == "pause"

        db_session.refresh(experiment)
        assert experiment.status == ExperimentStatus.paused

    def test_pause_not_running(self, client, project, auth_headers):
        """A freshly created (draft) experiment cannot be paused."""
        exp_id = _create_experiment(client, project.id, auth_headers).json()["id"]

        response = client.post(f"/api/experiments/{exp_id}/pause", headers=auth_headers)
        assert response.status_code == 409

    def test_pause_with_redis(self, client, project, auth_headers, db_session):
        """With Redis available, the pause request goes through ``request_pause``."""
        from models import ExperimentStatus

        exp_id = _create_experiment(client, project.id, auth_headers).json()["id"]
        self._force_status(db_session, exp_id, ExperimentStatus.running)

        redis_stub = MagicMock()
        redis_patch = patch("routers.experiments.get_redis", return_value=redis_stub)
        pause_patch = patch("routers.experiments.request_pause")
        with redis_patch, pause_patch as mock_pause:
            response = client.post(f"/api/experiments/{exp_id}/pause", headers=auth_headers)

        assert response.status_code == 200
        mock_pause.assert_called_once_with(redis_stub, uuid.UUID(exp_id))
|
|
|
|
|
|
class TestResumeSweep:
    """POST /api/experiments/{id}/resume — only valid from ``paused``."""

    def test_resume_paused(self, client, project, auth_headers, db_session):
        """Resuming a paused experiment re-dispatches the sweep."""
        from models import Experiment, ExperimentStatus

        exp_id = _create_experiment(client, project.id, auth_headers).json()["id"]
        experiment = (
            db_session.query(Experiment)
            .filter(Experiment.id == uuid.UUID(exp_id))
            .first()
        )
        experiment.status = ExperimentStatus.paused
        db_session.commit()

        # Patch the dispatcher so resuming never starts a real sweep.
        with patch("routers.experiments.dispatch_sweep") as mock_dispatch:
            mock_dispatch.return_value = MagicMock()
            response = client.post(f"/api/experiments/{exp_id}/resume", headers=auth_headers)

        assert response.status_code == 200
        assert response.json()["action"] == "resume"
        mock_dispatch.assert_called_once()

    def test_resume_not_paused(self, client, project, auth_headers):
        """A freshly created (draft) experiment cannot be resumed."""
        exp_id = _create_experiment(client, project.id, auth_headers).json()["id"]

        response = client.post(f"/api/experiments/{exp_id}/resume", headers=auth_headers)
        assert response.status_code == 409
|
|
|
|
|
|
class TestStopSweep:
    """POST /api/experiments/{id}/stop — allowed from running or paused."""

    @staticmethod
    def _force_status(db_session, exp_id, status):
        """Write ``status`` straight to the experiment row and return the row."""
        from models import Experiment

        experiment = (
            db_session.query(Experiment)
            .filter(Experiment.id == uuid.UUID(exp_id))
            .first()
        )
        experiment.status = status
        db_session.commit()
        return experiment

    def test_stop_running_no_redis(self, client, project, auth_headers, db_session):
        """Without Redis, pending runs fail and the experiment completes."""
        from models import ExperimentStatus, Run, RunStatus

        exp_id = _create_experiment(client, project.id, auth_headers).json()["id"]
        experiment = self._force_status(db_session, exp_id, ExperimentStatus.running)

        # One run still waiting to execute when the stop arrives.
        pending_run = Run(
            experiment_id=uuid.UUID(exp_id),
            config_hash="abc",
            config={"test": True},
            status=RunStatus.pending,
        )
        db_session.add(pending_run)
        db_session.commit()
        db_session.refresh(pending_run)

        response = client.post(f"/api/experiments/{exp_id}/stop", headers=auth_headers)
        assert response.status_code == 200
        assert response.json()["action"] == "stop"

        db_session.refresh(pending_run)
        assert pending_run.status == RunStatus.failed

        db_session.refresh(experiment)
        assert experiment.status == ExperimentStatus.completed

    def test_stop_not_running(self, client, project, auth_headers):
        """A freshly created (draft) experiment cannot be stopped."""
        exp_id = _create_experiment(client, project.id, auth_headers).json()["id"]

        response = client.post(f"/api/experiments/{exp_id}/stop", headers=auth_headers)
        assert response.status_code == 409

    def test_stop_paused(self, client, project, auth_headers, db_session):
        """Can stop a paused experiment too."""
        from models import ExperimentStatus

        exp_id = _create_experiment(client, project.id, auth_headers).json()["id"]
        self._force_status(db_session, exp_id, ExperimentStatus.paused)

        response = client.post(f"/api/experiments/{exp_id}/stop", headers=auth_headers)
        assert response.status_code == 200

    def test_stop_with_redis(self, client, project, auth_headers, db_session):
        """With Redis available, the stop request goes through ``request_stop``."""
        from models import ExperimentStatus

        exp_id = _create_experiment(client, project.id, auth_headers).json()["id"]
        self._force_status(db_session, exp_id, ExperimentStatus.running)

        redis_stub = MagicMock()
        redis_patch = patch("routers.experiments.get_redis", return_value=redis_stub)
        stop_patch = patch("routers.experiments.request_stop")
        with redis_patch, stop_patch as mock_stop:
            response = client.post(f"/api/experiments/{exp_id}/stop", headers=auth_headers)

        assert response.status_code == 200
        mock_stop.assert_called_once_with(redis_stub, uuid.UUID(exp_id))
|