promptlooper/backend/tests/test_experiments.py
John Lightner 82e97e9dba MAESTRO: Implement experiments router with full CRUD and sweep control endpoints
Add complete experiments API: list (with project filter), get, create, update,
delete, plus sweep lifecycle (start/pause/resume/stop/status). Adds
SweepRequest and SweepStatusResponse schemas. Sweep dispatch routes through
Celery with synchronous fallback for single-container mode. Redis flags control
pause/resume/stop; direct DB updates used when Redis unavailable. 34 tests.
2026-04-07 03:19:43 -05:00

554 lines
20 KiB
Python

"""Tests for backend/routers/experiments.py — Experiment CRUD + sweep control."""
import os
import uuid
from unittest.mock import MagicMock, patch
import pytest
from fastapi.testclient import TestClient
JWT_SECRET = "test-secret-key-for-jwt-signing"
API_KEY = "test-api-key-12345"
@pytest.fixture(autouse=True)
def _isolate_settings(tmp_path):
    """Run every test against a per-test SQLite file with REDIS_URL blanked.

    Environment overrides are applied, a fresh ``config.Settings`` is built
    with ``_env_file=None``, and that instance is assigned onto the
    ``config``, ``main`` and ``auth`` modules. The database and Redis
    initialisers in ``main`` are then re-run and the schema is created on
    ``main.engine``.
    """
    overrides = dict(
        DATABASE_URL=f"sqlite:///{tmp_path / 'test.db'}",
        REDIS_URL="",
        DATA_DIR=str(tmp_path),
        JWT_SECRET=JWT_SECRET,
        API_KEY=API_KEY,
    )

    # NOTE(review): the imports are kept inside the patched block, presumably
    # so that anything reading the environment sees the overrides - confirm.
    with patch.dict(os.environ, overrides):
        import config

        fresh = config.Settings(_env_file=None)
        config.settings = fresh

        import main

        main.settings = fresh
        main._init_db()
        main._init_redis()

        from models import Base

        Base.metadata.create_all(bind=main.engine)

        import auth

        auth.settings = fresh

        # The test body runs here, while the environment is still patched.
        yield
@pytest.fixture
def db_session():
    """Yield a session taken from ``main.get_db`` and finish the generator afterwards."""
    from main import get_db

    session_source = get_db()
    yield next(session_source)

    # Advance the generator once more so whatever follows its ``yield`` runs;
    # the default value absorbs the expected StopIteration.
    next(session_source, None)
@pytest.fixture
def admin_user(db_session):
    """Persist and return a ``User`` named "admin" with ``is_admin=True``."""
    from auth import hash_password
    from models import User

    account = User(
        username="admin",
        password_hash=hash_password("adminpass"),
        is_admin=True,
    )
    db_session.add(account)
    db_session.commit()
    # Reload the row so the returned object carries its database-side values.
    db_session.refresh(account)
    return account
@pytest.fixture
def project(db_session, admin_user):
    """Persist and return a ``Project`` owned by the ``admin_user`` fixture."""
    from models import Project

    record = Project(
        name="Test Project",
        description="A test project",
        owner_id=admin_user.id,
    )
    db_session.add(record)
    db_session.commit()
    # Reload the row so the returned object carries its database-side values.
    db_session.refresh(record)
    return record
@pytest.fixture
def auth_headers():
    """Return the ``X-Api-Key`` header carrying the test API key."""
    headers = {"X-Api-Key": API_KEY}
    return headers
@pytest.fixture
def client():
    """Return a ``TestClient`` wrapping the ``app`` object defined in ``main``."""
    import main

    return TestClient(main.app)
def _create_experiment(client, project_id, auth_headers, **overrides):
    """POST a new experiment under *project_id* and return the raw response.

    ``name`` and ``description`` fall back to fixed defaults. The keys
    ``sample_data``, ``pipeline_stages``, ``scoring_config`` and
    ``parameter_space`` are sent only when supplied in *overrides*; any
    other override key is ignored.
    """
    optional_keys = ("sample_data", "pipeline_stages", "scoring_config", "parameter_space")
    payload = {
        "name": overrides.get("name", "Test Experiment"),
        "description": overrides.get("description", "A test experiment"),
        **{key: overrides[key] for key in optional_keys if key in overrides},
    }
    return client.post(
        f"/api/experiments/?project_id={project_id}",
        json=payload,
        headers=auth_headers,
    )
# ---------------------------------------------------------------------------
# CRUD tests
# ---------------------------------------------------------------------------
class TestCreateExperiment:
    """POST /api/experiments/ - creation, authentication and validation."""

    def test_create_minimal(self, client, project, auth_headers):
        """A defaults-only payload returns 201 and a draft experiment."""
        response = _create_experiment(client, project.id, auth_headers)
        assert response.status_code == 201

        body = response.json()
        assert body["name"] == "Test Experiment"
        assert body["project_id"] == str(project.id)
        assert body["status"] == "draft"
        assert "id" in body

    def test_create_with_all_fields(self, client, project, auth_headers):
        """Optional JSON fields are accepted and echoed back."""
        fields = dict(
            name="Full Experiment",
            description="Detailed description",
            sample_data={"key": "value"},
            pipeline_stages={"stages": [{"prompt_template": "Hello {{input}}"}]},
            scoring_config={"scorers": ["keyword"]},
            parameter_space={"type": "grid", "params": {"model": ["a", "b"]}},
        )
        response = _create_experiment(client, project.id, auth_headers, **fields)
        assert response.status_code == 201

        body = response.json()
        assert body["name"] == "Full Experiment"
        assert body["sample_data"] == {"key": "value"}
        assert body["parameter_space"]["type"] == "grid"

    def test_create_requires_auth(self, client, project):
        """Posting without the API-key header is rejected with 401."""
        url = f"/api/experiments/?project_id={project.id}"
        response = client.post(url, json={"name": "Test"})
        assert response.status_code == 401

    def test_create_requires_project_id(self, client, admin_user, auth_headers):
        """Omitting the ``project_id`` query parameter yields 422."""
        response = client.post(
            "/api/experiments/",
            headers=auth_headers,
            json={"name": "Test"},
        )
        assert response.status_code == 422

    def test_create_validates_project_exists(self, client, admin_user, auth_headers):
        """A random, non-existent project id yields 404."""
        missing_id = str(uuid.uuid4())
        response = client.post(
            f"/api/experiments/?project_id={missing_id}",
            headers=auth_headers,
            json={"name": "Test"},
        )
        assert response.status_code == 404

    def test_create_validates_name(self, client, project, auth_headers):
        """An empty name is rejected with 422."""
        response = client.post(
            f"/api/experiments/?project_id={project.id}",
            headers=auth_headers,
            json={"name": ""},
        )
        assert response.status_code == 422
class TestListExperiments:
    """GET /api/experiments/ - listing and project filtering."""

    def test_list_empty(self, client, admin_user, auth_headers):
        """With nothing created, the listing is empty and ``total`` is 0."""
        response = client.get("/api/experiments/", headers=auth_headers)
        assert response.status_code == 200

        listing = response.json()
        assert listing["items"] == []
        assert listing["total"] == 0

    def test_list_multiple(self, client, project, auth_headers):
        """Three created experiments are reflected in ``total``."""
        for label in ("Alpha", "Beta", "Gamma"):
            _create_experiment(client, project.id, auth_headers, name=label)

        response = client.get("/api/experiments/", headers=auth_headers)
        assert response.status_code == 200
        assert response.json()["total"] == 3

    def test_list_filter_by_project(self, client, project, admin_user, auth_headers, db_session):
        """``project_id`` narrows the listing to that project's experiments."""
        from models import Project

        _create_experiment(client, project.id, auth_headers, name="In Project")

        # A second project with its own experiment, which the filter must exclude.
        other_project = Project(name="Other", owner_id=admin_user.id)
        db_session.add(other_project)
        db_session.commit()
        db_session.refresh(other_project)
        _create_experiment(client, other_project.id, auth_headers, name="Other Experiment")

        response = client.get(
            f"/api/experiments/?project_id={project.id}",
            headers=auth_headers,
        )
        assert response.status_code == 200

        listing = response.json()
        assert listing["total"] == 1
        assert listing["items"][0]["name"] == "In Project"

    def test_list_requires_auth(self, client):
        """Listing without the API-key header is rejected with 401."""
        assert client.get("/api/experiments/").status_code == 401
class TestGetExperiment:
    """GET /api/experiments/{id} - single-experiment lookup."""

    def test_get_existing(self, client, project, auth_headers):
        """A created experiment can be fetched back by its id."""
        experiment_id = _create_experiment(client, project.id, auth_headers).json()["id"]

        response = client.get(f"/api/experiments/{experiment_id}", headers=auth_headers)
        assert response.status_code == 200
        assert response.json()["name"] == "Test Experiment"

    def test_get_not_found(self, client, admin_user, auth_headers):
        """A random, non-existent id yields 404."""
        missing_id = str(uuid.uuid4())
        response = client.get(f"/api/experiments/{missing_id}", headers=auth_headers)
        assert response.status_code == 404
class TestUpdateExperiment:
    """PUT /api/experiments/{id} - partial updates."""

    def test_update_name(self, client, project, auth_headers):
        """Sending only ``name`` changes the name."""
        experiment_id = _create_experiment(client, project.id, auth_headers).json()["id"]

        response = client.put(
            f"/api/experiments/{experiment_id}",
            headers=auth_headers,
            json={"name": "Updated Name"},
        )
        assert response.status_code == 200
        assert response.json()["name"] == "Updated Name"

    def test_update_status(self, client, project, auth_headers):
        """Sending only ``status`` changes the status."""
        experiment_id = _create_experiment(client, project.id, auth_headers).json()["id"]

        response = client.put(
            f"/api/experiments/{experiment_id}",
            headers=auth_headers,
            json={"status": "completed"},
        )
        assert response.status_code == 200
        assert response.json()["status"] == "completed"

    def test_update_parameter_space(self, client, project, auth_headers):
        """A replacement ``parameter_space`` is stored and echoed back verbatim."""
        experiment_id = _create_experiment(client, project.id, auth_headers).json()["id"]
        replacement = {"type": "random", "params": {"temperature": [0.1, 0.5, 0.9]}}

        response = client.put(
            f"/api/experiments/{experiment_id}",
            headers=auth_headers,
            json={"parameter_space": replacement},
        )
        assert response.status_code == 200
        assert response.json()["parameter_space"] == replacement

    def test_update_not_found(self, client, admin_user, auth_headers):
        """Updating a random, non-existent id yields 404."""
        missing_id = str(uuid.uuid4())
        response = client.put(
            f"/api/experiments/{missing_id}",
            headers=auth_headers,
            json={"name": "X"},
        )
        assert response.status_code == 404
class TestDeleteExperiment:
    """DELETE /api/experiments/{id} - removal."""

    def test_delete_existing(self, client, project, auth_headers):
        """Deleting returns 204 and a follow-up GET returns 404."""
        experiment_id = _create_experiment(client, project.id, auth_headers).json()["id"]
        url = f"/api/experiments/{experiment_id}"

        deleted = client.delete(url, headers=auth_headers)
        assert deleted.status_code == 204

        fetched = client.get(url, headers=auth_headers)
        assert fetched.status_code == 404

    def test_delete_not_found(self, client, admin_user, auth_headers):
        """Deleting a random, non-existent id yields 404."""
        missing_id = str(uuid.uuid4())
        response = client.delete(f"/api/experiments/{missing_id}", headers=auth_headers)
        assert response.status_code == 404
# ---------------------------------------------------------------------------
# Sweep control tests
# ---------------------------------------------------------------------------
class TestStartSweep:
    """POST /api/experiments/{id}/sweep - starting a sweep.

    ``routers.experiments.dispatch_sweep`` is patched wherever a sweep would
    actually be dispatched, so these tests only check what the router passes
    to it.
    """

    def test_start_sweep_dispatches(self, client, project, auth_headers):
        """A grid sweep request calls ``dispatch_sweep`` once with the experiment id."""
        create_resp = _create_experiment(
            client, project.id, auth_headers,
            parameter_space={"type": "grid", "params": {"model": ["a", "b"]}},
        )
        exp_id = create_resp.json()["id"]

        with patch("routers.experiments.dispatch_sweep") as mock_dispatch:
            mock_dispatch.return_value = MagicMock()
            resp = client.post(
                f"/api/experiments/{exp_id}/sweep",
                json={"sweep_type": "grid"},
                headers=auth_headers,
            )

        assert resp.status_code == 200
        data = resp.json()
        assert data["experiment_id"] == exp_id

        mock_dispatch.assert_called_once()
        # call_args is an (args, kwargs) pair describing the last call.
        call_args = mock_dispatch.call_args
        assert call_args[0][0] == exp_id
        assert call_args[1]["sweep_config"]["type"] == "grid"

    def test_start_sweep_with_custom_config(self, client, project, auth_headers):
        """Request-level sweep settings are forwarded in ``sweep_config``."""
        create_resp = _create_experiment(client, project.id, auth_headers)
        exp_id = create_resp.json()["id"]

        with patch("routers.experiments.dispatch_sweep") as mock_dispatch:
            mock_dispatch.return_value = MagicMock()
            resp = client.post(
                f"/api/experiments/{exp_id}/sweep",
                json={
                    "sweep_type": "random",
                    "params": {"temperature": [0.1, 0.5, 0.9]},
                    "n_trials": 50,
                },
                headers=auth_headers,
            )

        assert resp.status_code == 200
        call_args = mock_dispatch.call_args
        assert call_args[1]["sweep_config"]["type"] == "random"
        assert call_args[1]["sweep_config"]["n_trials"] == 50

    def test_start_sweep_conflict_if_running(self, client, project, auth_headers, db_session):
        """Starting a sweep on an experiment already running yields 409."""
        create_resp = _create_experiment(client, project.id, auth_headers)
        exp_id = create_resp.json()["id"]

        # Put the experiment into the running state directly in the database.
        from models import Experiment, ExperimentStatus
        exp = db_session.query(Experiment).filter(Experiment.id == uuid.UUID(exp_id)).first()
        exp.status = ExperimentStatus.running
        db_session.commit()

        resp = client.post(
            f"/api/experiments/{exp_id}/sweep",
            json={"sweep_type": "grid"},
            headers=auth_headers,
        )
        assert resp.status_code == 409

    def test_start_sweep_validates_type(self, client, project, auth_headers):
        """An unrecognised ``sweep_type`` is rejected with 422."""
        create_resp = _create_experiment(client, project.id, auth_headers)
        exp_id = create_resp.json()["id"]

        resp = client.post(
            f"/api/experiments/{exp_id}/sweep",
            json={"sweep_type": "invalid"},
            headers=auth_headers,
        )
        assert resp.status_code == 422

    def test_start_sweep_not_found(self, client, admin_user, auth_headers):
        """A random, non-existent experiment id yields 404."""
        fake_id = str(uuid.uuid4())
        resp = client.post(
            f"/api/experiments/{fake_id}/sweep",
            json={"sweep_type": "grid"},
            headers=auth_headers,
        )
        assert resp.status_code == 404

    def test_start_sweep_requires_auth(self, client, project):
        """Posting without the API-key header is rejected with 401."""
        # No experiment is created here: the request carries no credentials,
        # so a random id is enough to exercise the authentication check.
        resp = client.post(
            f"/api/experiments/{uuid.uuid4()}/sweep",
            json={"sweep_type": "grid"},
        )
        assert resp.status_code == 401
class TestSweepStatus:
    """GET /api/experiments/{id}/sweep/status - run counters."""

    def test_sweep_status(self, client, project, auth_headers, db_session):
        """Run counts are reported per status alongside the experiment status."""
        from models import Run, RunStatus

        experiment_id = _create_experiment(client, project.id, auth_headers).json()["id"]

        # Seed four runs: two completed, one failed, one pending.
        seeded_statuses = (
            RunStatus.completed,
            RunStatus.completed,
            RunStatus.failed,
            RunStatus.pending,
        )
        for run_status in seeded_statuses:
            db_session.add(
                Run(
                    experiment_id=uuid.UUID(experiment_id),
                    config_hash="abc123",
                    config={"test": True},
                    status=run_status,
                )
            )
        db_session.commit()

        response = client.get(
            f"/api/experiments/{experiment_id}/sweep/status",
            headers=auth_headers,
        )
        assert response.status_code == 200

        summary = response.json()
        assert summary["total_runs"] == 4
        assert summary["completed_runs"] == 2
        assert summary["failed_runs"] == 1
        assert summary["pending_runs"] == 1
        assert summary["status"] == "draft"
class TestPauseSweep:
    """POST /api/experiments/{id}/pause - pausing a sweep."""

    def test_pause_running_no_redis(self, client, project, auth_headers, db_session):
        """With no Redis patched in, pausing moves a running experiment to paused."""
        from models import Experiment, ExperimentStatus

        experiment_id = _create_experiment(client, project.id, auth_headers).json()["id"]

        lookup = db_session.query(Experiment).filter(Experiment.id == uuid.UUID(experiment_id))
        experiment = lookup.first()
        experiment.status = ExperimentStatus.running
        db_session.commit()

        response = client.post(f"/api/experiments/{experiment_id}/pause", headers=auth_headers)
        assert response.status_code == 200
        assert response.json()["action"] == "pause"

        # Re-read the row to observe the status written by the endpoint.
        db_session.refresh(experiment)
        assert experiment.status == ExperimentStatus.paused

    def test_pause_not_running(self, client, project, auth_headers):
        """Pausing a freshly created experiment yields 409."""
        experiment_id = _create_experiment(client, project.id, auth_headers).json()["id"]

        response = client.post(f"/api/experiments/{experiment_id}/pause", headers=auth_headers)
        assert response.status_code == 409

    def test_pause_with_redis(self, client, project, auth_headers, db_session):
        """When ``get_redis`` returns a client, ``request_pause`` is called with it."""
        from models import Experiment, ExperimentStatus

        experiment_id = _create_experiment(client, project.id, auth_headers).json()["id"]

        lookup = db_session.query(Experiment).filter(Experiment.id == uuid.UUID(experiment_id))
        experiment = lookup.first()
        experiment.status = ExperimentStatus.running
        db_session.commit()

        fake_redis = MagicMock()
        with patch("routers.experiments.get_redis", return_value=fake_redis), \
                patch("routers.experiments.request_pause") as pause_spy:
            response = client.post(f"/api/experiments/{experiment_id}/pause", headers=auth_headers)

        assert response.status_code == 200
        pause_spy.assert_called_once_with(fake_redis, uuid.UUID(experiment_id))
class TestResumeSweep:
    """POST /api/experiments/{id}/resume - resuming a paused sweep."""

    def test_resume_paused(self, client, project, auth_headers, db_session):
        """Resuming a paused experiment dispatches the sweep exactly once."""
        from models import Experiment, ExperimentStatus

        experiment_id = _create_experiment(client, project.id, auth_headers).json()["id"]

        lookup = db_session.query(Experiment).filter(Experiment.id == uuid.UUID(experiment_id))
        experiment = lookup.first()
        experiment.status = ExperimentStatus.paused
        db_session.commit()

        with patch("routers.experiments.dispatch_sweep") as dispatch_spy:
            dispatch_spy.return_value = MagicMock()
            response = client.post(f"/api/experiments/{experiment_id}/resume", headers=auth_headers)

        assert response.status_code == 200
        assert response.json()["action"] == "resume"
        dispatch_spy.assert_called_once()

    def test_resume_not_paused(self, client, project, auth_headers):
        """Resuming a freshly created experiment yields 409."""
        experiment_id = _create_experiment(client, project.id, auth_headers).json()["id"]

        response = client.post(f"/api/experiments/{experiment_id}/resume", headers=auth_headers)
        assert response.status_code == 409
class TestStopSweep:
    """POST /api/experiments/{id}/stop - stopping a sweep."""

    def test_stop_running_no_redis(self, client, project, auth_headers, db_session):
        """With no Redis patched in, stopping fails pending runs and completes the experiment."""
        from models import Experiment, ExperimentStatus, Run, RunStatus

        experiment_id = _create_experiment(client, project.id, auth_headers).json()["id"]

        lookup = db_session.query(Experiment).filter(Experiment.id == uuid.UUID(experiment_id))
        experiment = lookup.first()
        experiment.status = ExperimentStatus.running
        db_session.commit()

        # One pending run that the stop request is expected to mark as failed.
        pending_run = Run(
            experiment_id=uuid.UUID(experiment_id),
            config_hash="abc",
            config={"test": True},
            status=RunStatus.pending,
        )
        db_session.add(pending_run)
        db_session.commit()
        db_session.refresh(pending_run)

        response = client.post(f"/api/experiments/{experiment_id}/stop", headers=auth_headers)
        assert response.status_code == 200
        assert response.json()["action"] == "stop"

        # Re-read both rows to observe what the endpoint wrote.
        db_session.refresh(pending_run)
        assert pending_run.status == RunStatus.failed

        db_session.refresh(experiment)
        assert experiment.status == ExperimentStatus.completed

    def test_stop_not_running(self, client, project, auth_headers):
        """Stopping a freshly created experiment yields 409."""
        experiment_id = _create_experiment(client, project.id, auth_headers).json()["id"]

        response = client.post(f"/api/experiments/{experiment_id}/stop", headers=auth_headers)
        assert response.status_code == 409

    def test_stop_paused(self, client, project, auth_headers, db_session):
        """A paused experiment can be stopped as well."""
        from models import Experiment, ExperimentStatus

        experiment_id = _create_experiment(client, project.id, auth_headers).json()["id"]

        lookup = db_session.query(Experiment).filter(Experiment.id == uuid.UUID(experiment_id))
        experiment = lookup.first()
        experiment.status = ExperimentStatus.paused
        db_session.commit()

        response = client.post(f"/api/experiments/{experiment_id}/stop", headers=auth_headers)
        assert response.status_code == 200

    def test_stop_with_redis(self, client, project, auth_headers, db_session):
        """When ``get_redis`` returns a client, ``request_stop`` is called with it."""
        from models import Experiment, ExperimentStatus

        experiment_id = _create_experiment(client, project.id, auth_headers).json()["id"]

        lookup = db_session.query(Experiment).filter(Experiment.id == uuid.UUID(experiment_id))
        experiment = lookup.first()
        experiment.status = ExperimentStatus.running
        db_session.commit()

        fake_redis = MagicMock()
        with patch("routers.experiments.get_redis", return_value=fake_redis), \
                patch("routers.experiments.request_stop") as stop_spy:
            response = client.post(f"/api/experiments/{experiment_id}/stop", headers=auth_headers)

        assert response.status_code == 200
        stop_spy.assert_called_once_with(fake_redis, uuid.UUID(experiment_id))