MAESTRO: Implement experiments router with full CRUD and sweep control endpoints
Add complete experiments API: list (with project filter), get, create, update, delete, plus sweep lifecycle (start/pause/resume/stop/status). Adds SweepRequest and SweepStatusResponse schemas. Sweep dispatch routes through Celery with synchronous fallback for single-container mode. Redis flags control pause/resume/stop; direct DB updates used when Redis unavailable. 34 tests.
This commit is contained in:
parent
59f18a11c3
commit
82e97e9dba
5 changed files with 881 additions and 49 deletions
|
|
@ -35,7 +35,8 @@ Implement the core experiment execution engine: LLM adapters, response caching,
|
||||||
- [x] Implement backend/routers/endpoints.py fully — CRUD for LLM endpoint configurations. The test endpoint should call adapter.test_connection() and adapter.list_models() and return the results. Store endpoint configs in the database with encrypted API keys (Fernet symmetric encryption, key derived from JWT_SECRET).
|
- [x] Implement backend/routers/endpoints.py fully — CRUD for LLM endpoint configurations. The test endpoint should call adapter.test_connection() and adapter.list_models() and return the results. Store endpoint configs in the database with encrypted API keys (Fernet symmetric encryption, key derived from JWT_SECRET).
|
||||||
<!-- Completed: Full CRUD (list/get/create/update/delete) + test_connection endpoint. LLMEndpoint model added to models.py. Fernet encryption via encryption.py (PBKDF2 key derivation from JWT_SECRET). API keys never exposed in responses; has_api_key boolean flag added to EndpointResponse. 25 tests in test_endpoints.py, all passing. -->
|
<!-- Completed: Full CRUD (list/get/create/update/delete) + test_connection endpoint. LLMEndpoint model added to models.py. Fernet encryption via encryption.py (PBKDF2 key derivation from JWT_SECRET). API keys never exposed in responses; has_api_key boolean flag added to EndpointResponse. 25 tests in test_endpoints.py, all passing. -->
|
||||||
|
|
||||||
- [ ] Implement backend/routers/experiments.py fully — CRUD plus sweep control. POST /experiments/{id}/sweep should validate the sweep config, create Run records for all configurations, and dispatch to Celery. Pause/resume/stop should set Redis flags that the sweep runner checks between runs.
|
- [x] Implement backend/routers/experiments.py fully — CRUD plus sweep control. POST /experiments/{id}/sweep should validate the sweep config, create Run records for all configurations, and dispatch to Celery. Pause/resume/stop should set Redis flags that the sweep runner checks between runs.
|
||||||
|
<!-- Completed: Full CRUD (list with project filter, get, create, update, delete) + sweep control (start/pause/resume/stop + status). SweepRequest/SweepStatusResponse schemas added. Sweep dispatch via Celery/sync fallback. Redis flags for pause/resume/stop, with single-container mode fallback. 34 tests in test_experiments.py, all passing. -->
|
||||||
|
|
||||||
- [ ] Implement backend/routers/runs.py fully — list runs with filtering (by experiment, status, score range), get run detail with stage results and scores, POST for ad-hoc single runs, and POST /{id}/score for human ratings. Include the leaderboard endpoint that returns top N runs ranked by weighted score.
|
- [ ] Implement backend/routers/runs.py fully — list runs with filtering (by experiment, status, score range), get run detail with stage results and scores, POST for ad-hoc single runs, and POST /{id}/score for human ratings. Include the leaderboard endpoint that returns top N runs ranked by weighted score.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,61 +1,315 @@
|
||||||
"""Experiments router — CRUD and sweep controls."""
|
"""Experiments router — CRUD plus sweep control.
|
||||||
|
|
||||||
|
POST /experiments/{id}/sweep validates sweep config, creates Run records,
|
||||||
|
and dispatches to Celery. Pause/resume/stop set Redis flags that the sweep
|
||||||
|
runner checks between runs.
|
||||||
|
"""
|
||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from fastapi import APIRouter, Response
|
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from auth import get_current_user
|
||||||
|
from engine.sweep import clear_sweep_flags, request_pause, request_resume, request_stop
|
||||||
|
from engine.tasks import dispatch_sweep
|
||||||
|
from main import get_db, get_redis
|
||||||
|
from models import Experiment, ExperimentStatus, Project, Run, RunStatus, User
|
||||||
|
from schemas import (
|
||||||
|
ExperimentCreate,
|
||||||
|
ExperimentListResponse,
|
||||||
|
ExperimentResponse,
|
||||||
|
ExperimentUpdate,
|
||||||
|
SweepRequest,
|
||||||
|
SweepStatusResponse,
|
||||||
|
)
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
@router.get("/", status_code=501)
|
# ---------------------------------------------------------------------------
|
||||||
def list_experiments():
|
# Helpers
|
||||||
"""List experiments (filter by project)."""
|
# ---------------------------------------------------------------------------
|
||||||
return Response(status_code=501, content="Not Implemented")
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/", status_code=501)
|
def _get_experiment_or_404(db: Session, experiment_id: uuid.UUID) -> Experiment:
|
||||||
def create_experiment():
|
experiment = db.query(Experiment).filter(Experiment.id == experiment_id).first()
|
||||||
"""Create experiment."""
|
if experiment is None:
|
||||||
return Response(status_code=501, content="Not Implemented")
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Experiment not found")
|
||||||
|
return experiment
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{experiment_id}", status_code=501)
|
def _to_response(experiment: Experiment) -> ExperimentResponse:
|
||||||
def get_experiment(experiment_id: uuid.UUID):
|
return ExperimentResponse.model_validate(experiment)
|
||||||
"""Experiment detail with run summaries."""
|
|
||||||
return Response(status_code=501, content="Not Implemented")
|
|
||||||
|
|
||||||
|
|
||||||
@router.put("/{experiment_id}", status_code=501)
|
# ---------------------------------------------------------------------------
|
||||||
def update_experiment(experiment_id: uuid.UUID):
|
# CRUD
|
||||||
"""Update experiment config."""
|
# ---------------------------------------------------------------------------
|
||||||
return Response(status_code=501, content="Not Implemented")
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/{experiment_id}", status_code=501)
|
@router.get("/", response_model=ExperimentListResponse)
|
||||||
def delete_experiment(experiment_id: uuid.UUID):
|
def list_experiments(
|
||||||
"""Delete experiment."""
|
project_id: uuid.UUID | None = Query(None),
|
||||||
return Response(status_code=501, content="Not Implemented")
|
db: Session = Depends(get_db),
|
||||||
|
_user: User = Depends(get_current_user),
|
||||||
|
) -> ExperimentListResponse:
|
||||||
|
"""List experiments, optionally filtered by project."""
|
||||||
|
query = db.query(Experiment)
|
||||||
|
if project_id is not None:
|
||||||
|
query = query.filter(Experiment.project_id == project_id)
|
||||||
|
experiments = query.order_by(Experiment.created_at.desc()).all()
|
||||||
|
return ExperimentListResponse(
|
||||||
|
items=[_to_response(exp) for exp in experiments],
|
||||||
|
total=len(experiments),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{experiment_id}/sweep", status_code=501)
|
@router.post("/", response_model=ExperimentResponse, status_code=status.HTTP_201_CREATED)
|
||||||
def start_sweep(experiment_id: uuid.UUID):
|
def create_experiment(
|
||||||
"""Start a sweep (grid, random, or guided)."""
|
body: ExperimentCreate,
|
||||||
return Response(status_code=501, content="Not Implemented")
|
project_id: uuid.UUID = Query(...),
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
_user: User = Depends(get_current_user),
|
||||||
|
) -> ExperimentResponse:
|
||||||
|
"""Create a new experiment within a project."""
|
||||||
|
# Verify project exists
|
||||||
|
project = db.query(Project).filter(Project.id == project_id).first()
|
||||||
|
if project is None:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Project not found")
|
||||||
|
|
||||||
|
experiment = Experiment(
|
||||||
|
project_id=project_id,
|
||||||
|
name=body.name,
|
||||||
|
description=body.description,
|
||||||
|
sample_data=body.sample_data,
|
||||||
|
pipeline_stages=body.pipeline_stages,
|
||||||
|
scoring_config=body.scoring_config,
|
||||||
|
parameter_space=body.parameter_space,
|
||||||
|
)
|
||||||
|
db.add(experiment)
|
||||||
|
db.commit()
|
||||||
|
db.refresh(experiment)
|
||||||
|
return _to_response(experiment)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{experiment_id}/pause", status_code=501)
|
@router.get("/{experiment_id}", response_model=ExperimentResponse)
|
||||||
def pause_sweep(experiment_id: uuid.UUID):
|
def get_experiment(
|
||||||
"""Pause running sweep."""
|
experiment_id: uuid.UUID,
|
||||||
return Response(status_code=501, content="Not Implemented")
|
db: Session = Depends(get_db),
|
||||||
|
_user: User = Depends(get_current_user),
|
||||||
|
) -> ExperimentResponse:
|
||||||
|
"""Get a single experiment."""
|
||||||
|
experiment = _get_experiment_or_404(db, experiment_id)
|
||||||
|
return _to_response(experiment)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{experiment_id}/resume", status_code=501)
|
@router.put("/{experiment_id}", response_model=ExperimentResponse)
|
||||||
def resume_sweep(experiment_id: uuid.UUID):
|
def update_experiment(
|
||||||
"""Resume paused sweep."""
|
experiment_id: uuid.UUID,
|
||||||
return Response(status_code=501, content="Not Implemented")
|
body: ExperimentUpdate,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
_user: User = Depends(get_current_user),
|
||||||
|
) -> ExperimentResponse:
|
||||||
|
"""Update an experiment's configuration."""
|
||||||
|
experiment = _get_experiment_or_404(db, experiment_id)
|
||||||
|
|
||||||
|
if body.name is not None:
|
||||||
|
experiment.name = body.name
|
||||||
|
if body.description is not None:
|
||||||
|
experiment.description = body.description
|
||||||
|
if body.sample_data is not None:
|
||||||
|
experiment.sample_data = body.sample_data
|
||||||
|
if body.pipeline_stages is not None:
|
||||||
|
experiment.pipeline_stages = body.pipeline_stages
|
||||||
|
if body.scoring_config is not None:
|
||||||
|
experiment.scoring_config = body.scoring_config
|
||||||
|
if body.parameter_space is not None:
|
||||||
|
experiment.parameter_space = body.parameter_space
|
||||||
|
if body.status is not None:
|
||||||
|
experiment.status = body.status
|
||||||
|
|
||||||
|
db.commit()
|
||||||
|
db.refresh(experiment)
|
||||||
|
return _to_response(experiment)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{experiment_id}/stop", status_code=501)
|
@router.delete("/{experiment_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||||
def stop_sweep(experiment_id: uuid.UUID):
|
def delete_experiment(
|
||||||
"""Stop sweep."""
|
experiment_id: uuid.UUID,
|
||||||
return Response(status_code=501, content="Not Implemented")
|
db: Session = Depends(get_db),
|
||||||
|
_user: User = Depends(get_current_user),
|
||||||
|
) -> None:
|
||||||
|
"""Delete an experiment and all its runs."""
|
||||||
|
experiment = _get_experiment_or_404(db, experiment_id)
|
||||||
|
db.delete(experiment)
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Sweep control
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{experiment_id}/sweep", response_model=SweepStatusResponse)
|
||||||
|
def start_sweep(
|
||||||
|
experiment_id: uuid.UUID,
|
||||||
|
body: SweepRequest,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
_user: User = Depends(get_current_user),
|
||||||
|
) -> SweepStatusResponse:
|
||||||
|
"""Start a sweep — validate config, create Run records, dispatch to Celery.
|
||||||
|
|
||||||
|
The sweep_config from the request body overrides the experiment's
|
||||||
|
parameter_space for this sweep execution.
|
||||||
|
"""
|
||||||
|
experiment = _get_experiment_or_404(db, experiment_id)
|
||||||
|
|
||||||
|
if experiment.status == ExperimentStatus.running:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_409_CONFLICT,
|
||||||
|
detail="Experiment already has a running sweep",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build the sweep config dict that the engine expects
|
||||||
|
sweep_config = {
|
||||||
|
"type": body.sweep_type,
|
||||||
|
"params": body.params or experiment.parameter_space.get("params", {}) if experiment.parameter_space else {},
|
||||||
|
"n_trials": body.n_trials,
|
||||||
|
"top_k": body.top_k,
|
||||||
|
"explore_ratio": body.explore_ratio,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Dispatch via Celery / sync fallback
|
||||||
|
dispatch_sweep(str(experiment.id), sweep_config=sweep_config)
|
||||||
|
|
||||||
|
# Refresh to get the updated status (dispatch_sweep may have set it)
|
||||||
|
db.refresh(experiment)
|
||||||
|
|
||||||
|
# Count runs
|
||||||
|
run_counts = _get_run_counts(db, experiment.id)
|
||||||
|
|
||||||
|
return SweepStatusResponse(
|
||||||
|
experiment_id=experiment.id,
|
||||||
|
status=experiment.status,
|
||||||
|
**run_counts,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{experiment_id}/sweep/status", response_model=SweepStatusResponse)
|
||||||
|
def sweep_status(
|
||||||
|
experiment_id: uuid.UUID,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
_user: User = Depends(get_current_user),
|
||||||
|
) -> SweepStatusResponse:
|
||||||
|
"""Get the current sweep status for an experiment."""
|
||||||
|
experiment = _get_experiment_or_404(db, experiment_id)
|
||||||
|
run_counts = _get_run_counts(db, experiment.id)
|
||||||
|
return SweepStatusResponse(
|
||||||
|
experiment_id=experiment.id,
|
||||||
|
status=experiment.status,
|
||||||
|
**run_counts,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{experiment_id}/pause", status_code=status.HTTP_200_OK)
|
||||||
|
def pause_sweep(
|
||||||
|
experiment_id: uuid.UUID,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
_user: User = Depends(get_current_user),
|
||||||
|
) -> dict:
|
||||||
|
"""Pause a running sweep by setting the Redis pause flag."""
|
||||||
|
experiment = _get_experiment_or_404(db, experiment_id)
|
||||||
|
if experiment.status != ExperimentStatus.running:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_409_CONFLICT,
|
||||||
|
detail="Experiment is not currently running",
|
||||||
|
)
|
||||||
|
|
||||||
|
redis_client = get_redis()
|
||||||
|
if redis_client is not None:
|
||||||
|
request_pause(redis_client, experiment.id)
|
||||||
|
else:
|
||||||
|
# In single-container mode, directly set paused status
|
||||||
|
experiment.status = ExperimentStatus.paused
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
return {"experiment_id": str(experiment.id), "action": "pause"}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{experiment_id}/resume", status_code=status.HTTP_200_OK)
|
||||||
|
def resume_sweep(
|
||||||
|
experiment_id: uuid.UUID,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
_user: User = Depends(get_current_user),
|
||||||
|
) -> dict:
|
||||||
|
"""Resume a paused sweep by clearing the Redis pause flag."""
|
||||||
|
experiment = _get_experiment_or_404(db, experiment_id)
|
||||||
|
if experiment.status != ExperimentStatus.paused:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_409_CONFLICT,
|
||||||
|
detail="Experiment is not currently paused",
|
||||||
|
)
|
||||||
|
|
||||||
|
redis_client = get_redis()
|
||||||
|
if redis_client is not None:
|
||||||
|
request_resume(redis_client, experiment.id)
|
||||||
|
|
||||||
|
# Re-dispatch the sweep to continue
|
||||||
|
experiment.status = ExperimentStatus.running
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
dispatch_sweep(str(experiment.id))
|
||||||
|
|
||||||
|
return {"experiment_id": str(experiment.id), "action": "resume"}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{experiment_id}/stop", status_code=status.HTTP_200_OK)
|
||||||
|
def stop_sweep(
|
||||||
|
experiment_id: uuid.UUID,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
_user: User = Depends(get_current_user),
|
||||||
|
) -> dict:
|
||||||
|
"""Stop a running or paused sweep."""
|
||||||
|
experiment = _get_experiment_or_404(db, experiment_id)
|
||||||
|
if experiment.status not in (ExperimentStatus.running, ExperimentStatus.paused):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_409_CONFLICT,
|
||||||
|
detail="Experiment is not running or paused",
|
||||||
|
)
|
||||||
|
|
||||||
|
redis_client = get_redis()
|
||||||
|
if redis_client is not None:
|
||||||
|
request_stop(redis_client, experiment.id)
|
||||||
|
else:
|
||||||
|
# Mark remaining pending runs as failed
|
||||||
|
pending_runs = (
|
||||||
|
db.query(Run)
|
||||||
|
.filter(Run.experiment_id == experiment.id, Run.status == RunStatus.pending)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
for run in pending_runs:
|
||||||
|
run.status = RunStatus.failed
|
||||||
|
experiment.status = ExperimentStatus.completed
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
return {"experiment_id": str(experiment.id), "action": "stop"}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Internal helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _get_run_counts(db: Session, experiment_id: uuid.UUID) -> dict:
|
||||||
|
"""Return run counts by status for a sweep status response."""
|
||||||
|
runs = db.query(Run).filter(Run.experiment_id == experiment_id).all()
|
||||||
|
completed = sum(1 for r in runs if r.status == RunStatus.completed)
|
||||||
|
failed = sum(1 for r in runs if r.status == RunStatus.failed)
|
||||||
|
pending = sum(1 for r in runs if r.status == RunStatus.pending)
|
||||||
|
return {
|
||||||
|
"total_runs": len(runs),
|
||||||
|
"completed_runs": completed,
|
||||||
|
"failed_runs": failed,
|
||||||
|
"pending_runs": pending,
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -91,6 +91,29 @@ class ExperimentListResponse(BaseModel):
|
||||||
total: int
|
total: int
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Sweep
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class SweepRequest(BaseModel):
|
||||||
|
"""Request body for starting a sweep on an experiment."""
|
||||||
|
|
||||||
|
sweep_type: str = Field("grid", pattern="^(grid|random|guided)$")
|
||||||
|
params: dict | None = None
|
||||||
|
n_trials: int = Field(100, ge=1, le=100000)
|
||||||
|
top_k: int = Field(5, ge=1)
|
||||||
|
explore_ratio: float = Field(0.3, ge=0.0, le=1.0)
|
||||||
|
|
||||||
|
|
||||||
|
class SweepStatusResponse(BaseModel):
|
||||||
|
experiment_id: uuid.UUID
|
||||||
|
status: ExperimentStatus
|
||||||
|
total_runs: int
|
||||||
|
completed_runs: int
|
||||||
|
failed_runs: int
|
||||||
|
pending_runs: int
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Run
|
# Run
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
|
||||||
554
backend/tests/test_experiments.py
Normal file
554
backend/tests/test_experiments.py
Normal file
|
|
@ -0,0 +1,554 @@
|
||||||
|
"""Tests for backend/routers/experiments.py — Experiment CRUD + sweep control."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
|
||||||
|
JWT_SECRET = "test-secret-key-for-jwt-signing"
|
||||||
|
API_KEY = "test-api-key-12345"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _isolate_settings(tmp_path):
|
||||||
|
"""Ensure tests use a temp SQLite DB and no Redis."""
|
||||||
|
env = {
|
||||||
|
"DATABASE_URL": f"sqlite:///{tmp_path / 'test.db'}",
|
||||||
|
"REDIS_URL": "",
|
||||||
|
"DATA_DIR": str(tmp_path),
|
||||||
|
"JWT_SECRET": JWT_SECRET,
|
||||||
|
"API_KEY": API_KEY,
|
||||||
|
}
|
||||||
|
with patch.dict(os.environ, env, clear=False):
|
||||||
|
import config
|
||||||
|
new_settings = config.Settings(_env_file=None)
|
||||||
|
config.settings = new_settings
|
||||||
|
|
||||||
|
import main
|
||||||
|
main.settings = new_settings
|
||||||
|
main._init_db()
|
||||||
|
main._init_redis()
|
||||||
|
|
||||||
|
from models import Base
|
||||||
|
Base.metadata.create_all(bind=main.engine)
|
||||||
|
|
||||||
|
import auth
|
||||||
|
auth.settings = new_settings
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def db_session():
|
||||||
|
from main import get_db
|
||||||
|
gen = get_db()
|
||||||
|
session = next(gen)
|
||||||
|
yield session
|
||||||
|
try:
|
||||||
|
next(gen)
|
||||||
|
except StopIteration:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def admin_user(db_session):
|
||||||
|
from auth import hash_password
|
||||||
|
from models import User
|
||||||
|
user = User(username="admin", password_hash=hash_password("adminpass"), is_admin=True)
|
||||||
|
db_session.add(user)
|
||||||
|
db_session.commit()
|
||||||
|
db_session.refresh(user)
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def project(db_session, admin_user):
|
||||||
|
from models import Project
|
||||||
|
proj = Project(name="Test Project", description="A test project", owner_id=admin_user.id)
|
||||||
|
db_session.add(proj)
|
||||||
|
db_session.commit()
|
||||||
|
db_session.refresh(proj)
|
||||||
|
return proj
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def auth_headers():
|
||||||
|
return {"X-Api-Key": API_KEY}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def client():
|
||||||
|
from main import app
|
||||||
|
return TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
def _create_experiment(client, project_id, auth_headers, **overrides):
|
||||||
|
"""Helper to create an experiment via the API."""
|
||||||
|
body = {
|
||||||
|
"name": overrides.get("name", "Test Experiment"),
|
||||||
|
"description": overrides.get("description", "A test experiment"),
|
||||||
|
}
|
||||||
|
for key in ("sample_data", "pipeline_stages", "scoring_config", "parameter_space"):
|
||||||
|
if key in overrides:
|
||||||
|
body[key] = overrides[key]
|
||||||
|
resp = client.post(
|
||||||
|
f"/api/experiments/?project_id={project_id}",
|
||||||
|
json=body,
|
||||||
|
headers=auth_headers,
|
||||||
|
)
|
||||||
|
return resp
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# CRUD tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestCreateExperiment:
|
||||||
|
def test_create_minimal(self, client, project, auth_headers):
|
||||||
|
resp = _create_experiment(client, project.id, auth_headers)
|
||||||
|
assert resp.status_code == 201
|
||||||
|
data = resp.json()
|
||||||
|
assert data["name"] == "Test Experiment"
|
||||||
|
assert data["project_id"] == str(project.id)
|
||||||
|
assert data["status"] == "draft"
|
||||||
|
assert "id" in data
|
||||||
|
|
||||||
|
def test_create_with_all_fields(self, client, project, auth_headers):
|
||||||
|
resp = _create_experiment(
|
||||||
|
client, project.id, auth_headers,
|
||||||
|
name="Full Experiment",
|
||||||
|
description="Detailed description",
|
||||||
|
sample_data={"key": "value"},
|
||||||
|
pipeline_stages={"stages": [{"prompt_template": "Hello {{input}}"}]},
|
||||||
|
scoring_config={"scorers": ["keyword"]},
|
||||||
|
parameter_space={"type": "grid", "params": {"model": ["a", "b"]}},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 201
|
||||||
|
data = resp.json()
|
||||||
|
assert data["name"] == "Full Experiment"
|
||||||
|
assert data["sample_data"] == {"key": "value"}
|
||||||
|
assert data["parameter_space"]["type"] == "grid"
|
||||||
|
|
||||||
|
def test_create_requires_auth(self, client, project):
|
||||||
|
resp = client.post(
|
||||||
|
f"/api/experiments/?project_id={project.id}",
|
||||||
|
json={"name": "Test"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
def test_create_requires_project_id(self, client, admin_user, auth_headers):
|
||||||
|
resp = client.post(
|
||||||
|
"/api/experiments/",
|
||||||
|
json={"name": "Test"},
|
||||||
|
headers=auth_headers,
|
||||||
|
)
|
||||||
|
assert resp.status_code == 422
|
||||||
|
|
||||||
|
def test_create_validates_project_exists(self, client, admin_user, auth_headers):
|
||||||
|
fake_id = str(uuid.uuid4())
|
||||||
|
resp = client.post(
|
||||||
|
f"/api/experiments/?project_id={fake_id}",
|
||||||
|
json={"name": "Test"},
|
||||||
|
headers=auth_headers,
|
||||||
|
)
|
||||||
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
def test_create_validates_name(self, client, project, auth_headers):
|
||||||
|
resp = client.post(
|
||||||
|
f"/api/experiments/?project_id={project.id}",
|
||||||
|
json={"name": ""},
|
||||||
|
headers=auth_headers,
|
||||||
|
)
|
||||||
|
assert resp.status_code == 422
|
||||||
|
|
||||||
|
|
||||||
|
class TestListExperiments:
|
||||||
|
def test_list_empty(self, client, admin_user, auth_headers):
|
||||||
|
resp = client.get("/api/experiments/", headers=auth_headers)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert data["items"] == []
|
||||||
|
assert data["total"] == 0
|
||||||
|
|
||||||
|
def test_list_multiple(self, client, project, auth_headers):
|
||||||
|
for name in ["Alpha", "Beta", "Gamma"]:
|
||||||
|
_create_experiment(client, project.id, auth_headers, name=name)
|
||||||
|
|
||||||
|
resp = client.get("/api/experiments/", headers=auth_headers)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert data["total"] == 3
|
||||||
|
|
||||||
|
def test_list_filter_by_project(self, client, project, admin_user, auth_headers, db_session):
|
||||||
|
_create_experiment(client, project.id, auth_headers, name="In Project")
|
||||||
|
|
||||||
|
# Create another project + experiment
|
||||||
|
from models import Project
|
||||||
|
proj2 = Project(name="Other", owner_id=admin_user.id)
|
||||||
|
db_session.add(proj2)
|
||||||
|
db_session.commit()
|
||||||
|
db_session.refresh(proj2)
|
||||||
|
_create_experiment(client, proj2.id, auth_headers, name="Other Experiment")
|
||||||
|
|
||||||
|
resp = client.get(
|
||||||
|
f"/api/experiments/?project_id={project.id}",
|
||||||
|
headers=auth_headers,
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert data["total"] == 1
|
||||||
|
assert data["items"][0]["name"] == "In Project"
|
||||||
|
|
||||||
|
def test_list_requires_auth(self, client):
|
||||||
|
resp = client.get("/api/experiments/")
|
||||||
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetExperiment:
|
||||||
|
def test_get_existing(self, client, project, auth_headers):
|
||||||
|
create_resp = _create_experiment(client, project.id, auth_headers)
|
||||||
|
exp_id = create_resp.json()["id"]
|
||||||
|
|
||||||
|
resp = client.get(f"/api/experiments/{exp_id}", headers=auth_headers)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json()["name"] == "Test Experiment"
|
||||||
|
|
||||||
|
def test_get_not_found(self, client, admin_user, auth_headers):
|
||||||
|
fake_id = str(uuid.uuid4())
|
||||||
|
resp = client.get(f"/api/experiments/{fake_id}", headers=auth_headers)
|
||||||
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
class TestUpdateExperiment:
|
||||||
|
def test_update_name(self, client, project, auth_headers):
|
||||||
|
create_resp = _create_experiment(client, project.id, auth_headers)
|
||||||
|
exp_id = create_resp.json()["id"]
|
||||||
|
|
||||||
|
resp = client.put(
|
||||||
|
f"/api/experiments/{exp_id}",
|
||||||
|
json={"name": "Updated Name"},
|
||||||
|
headers=auth_headers,
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json()["name"] == "Updated Name"
|
||||||
|
|
||||||
|
def test_update_status(self, client, project, auth_headers):
|
||||||
|
create_resp = _create_experiment(client, project.id, auth_headers)
|
||||||
|
exp_id = create_resp.json()["id"]
|
||||||
|
|
||||||
|
resp = client.put(
|
||||||
|
f"/api/experiments/{exp_id}",
|
||||||
|
json={"status": "completed"},
|
||||||
|
headers=auth_headers,
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json()["status"] == "completed"
|
||||||
|
|
||||||
|
def test_update_parameter_space(self, client, project, auth_headers):
|
||||||
|
create_resp = _create_experiment(client, project.id, auth_headers)
|
||||||
|
exp_id = create_resp.json()["id"]
|
||||||
|
|
||||||
|
new_space = {"type": "random", "params": {"temperature": [0.1, 0.5, 0.9]}}
|
||||||
|
resp = client.put(
|
||||||
|
f"/api/experiments/{exp_id}",
|
||||||
|
json={"parameter_space": new_space},
|
||||||
|
headers=auth_headers,
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json()["parameter_space"] == new_space
|
||||||
|
|
||||||
|
def test_update_not_found(self, client, admin_user, auth_headers):
|
||||||
|
fake_id = str(uuid.uuid4())
|
||||||
|
resp = client.put(
|
||||||
|
f"/api/experiments/{fake_id}",
|
||||||
|
json={"name": "X"},
|
||||||
|
headers=auth_headers,
|
||||||
|
)
|
||||||
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
class TestDeleteExperiment:
|
||||||
|
def test_delete_existing(self, client, project, auth_headers):
|
||||||
|
create_resp = _create_experiment(client, project.id, auth_headers)
|
||||||
|
exp_id = create_resp.json()["id"]
|
||||||
|
|
||||||
|
resp = client.delete(f"/api/experiments/{exp_id}", headers=auth_headers)
|
||||||
|
assert resp.status_code == 204
|
||||||
|
|
||||||
|
resp = client.get(f"/api/experiments/{exp_id}", headers=auth_headers)
|
||||||
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
def test_delete_not_found(self, client, admin_user, auth_headers):
|
||||||
|
fake_id = str(uuid.uuid4())
|
||||||
|
resp = client.delete(f"/api/experiments/{fake_id}", headers=auth_headers)
|
||||||
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Sweep control tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestStartSweep:
|
||||||
|
def test_start_sweep_dispatches(self, client, project, auth_headers):
|
||||||
|
create_resp = _create_experiment(
|
||||||
|
client, project.id, auth_headers,
|
||||||
|
parameter_space={"type": "grid", "params": {"model": ["a", "b"]}},
|
||||||
|
)
|
||||||
|
exp_id = create_resp.json()["id"]
|
||||||
|
|
||||||
|
with patch("routers.experiments.dispatch_sweep") as mock_dispatch:
|
||||||
|
mock_dispatch.return_value = MagicMock()
|
||||||
|
resp = client.post(
|
||||||
|
f"/api/experiments/{exp_id}/sweep",
|
||||||
|
json={"sweep_type": "grid"},
|
||||||
|
headers=auth_headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert data["experiment_id"] == exp_id
|
||||||
|
mock_dispatch.assert_called_once()
|
||||||
|
call_args = mock_dispatch.call_args
|
||||||
|
assert call_args[0][0] == exp_id
|
||||||
|
assert call_args[1]["sweep_config"]["type"] == "grid"
|
||||||
|
|
||||||
|
def test_start_sweep_with_custom_config(self, client, project, auth_headers):
|
||||||
|
create_resp = _create_experiment(client, project.id, auth_headers)
|
||||||
|
exp_id = create_resp.json()["id"]
|
||||||
|
|
||||||
|
with patch("routers.experiments.dispatch_sweep") as mock_dispatch:
|
||||||
|
mock_dispatch.return_value = MagicMock()
|
||||||
|
resp = client.post(
|
||||||
|
f"/api/experiments/{exp_id}/sweep",
|
||||||
|
json={
|
||||||
|
"sweep_type": "random",
|
||||||
|
"params": {"temperature": [0.1, 0.5, 0.9]},
|
||||||
|
"n_trials": 50,
|
||||||
|
},
|
||||||
|
headers=auth_headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
call_args = mock_dispatch.call_args
|
||||||
|
assert call_args[1]["sweep_config"]["type"] == "random"
|
||||||
|
assert call_args[1]["sweep_config"]["n_trials"] == 50
|
||||||
|
|
||||||
|
def test_start_sweep_conflict_if_running(self, client, project, auth_headers, db_session):
|
||||||
|
create_resp = _create_experiment(client, project.id, auth_headers)
|
||||||
|
exp_id = create_resp.json()["id"]
|
||||||
|
|
||||||
|
# Set experiment to running
|
||||||
|
from models import Experiment, ExperimentStatus
|
||||||
|
exp = db_session.query(Experiment).filter(Experiment.id == uuid.UUID(exp_id)).first()
|
||||||
|
exp.status = ExperimentStatus.running
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
resp = client.post(
|
||||||
|
f"/api/experiments/{exp_id}/sweep",
|
||||||
|
json={"sweep_type": "grid"},
|
||||||
|
headers=auth_headers,
|
||||||
|
)
|
||||||
|
assert resp.status_code == 409
|
||||||
|
|
||||||
|
def test_start_sweep_validates_type(self, client, project, auth_headers):
|
||||||
|
create_resp = _create_experiment(client, project.id, auth_headers)
|
||||||
|
exp_id = create_resp.json()["id"]
|
||||||
|
|
||||||
|
resp = client.post(
|
||||||
|
f"/api/experiments/{exp_id}/sweep",
|
||||||
|
json={"sweep_type": "invalid"},
|
||||||
|
headers=auth_headers,
|
||||||
|
)
|
||||||
|
assert resp.status_code == 422
|
||||||
|
|
||||||
|
def test_start_sweep_not_found(self, client, admin_user, auth_headers):
|
||||||
|
fake_id = str(uuid.uuid4())
|
||||||
|
resp = client.post(
|
||||||
|
f"/api/experiments/{fake_id}/sweep",
|
||||||
|
json={"sweep_type": "grid"},
|
||||||
|
headers=auth_headers,
|
||||||
|
)
|
||||||
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
def test_start_sweep_requires_auth(self, client, project):
|
||||||
|
create_resp = None
|
||||||
|
resp = client.post(
|
||||||
|
f"/api/experiments/{uuid.uuid4()}/sweep",
|
||||||
|
json={"sweep_type": "grid"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
class TestSweepStatus:
|
||||||
|
def test_sweep_status(self, client, project, auth_headers, db_session):
|
||||||
|
create_resp = _create_experiment(client, project.id, auth_headers)
|
||||||
|
exp_id = create_resp.json()["id"]
|
||||||
|
|
||||||
|
# Create some run records
|
||||||
|
from models import Run, RunStatus
|
||||||
|
for s in [RunStatus.completed, RunStatus.completed, RunStatus.failed, RunStatus.pending]:
|
||||||
|
run = Run(
|
||||||
|
experiment_id=uuid.UUID(exp_id),
|
||||||
|
config_hash="abc123",
|
||||||
|
config={"test": True},
|
||||||
|
status=s,
|
||||||
|
)
|
||||||
|
db_session.add(run)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
resp = client.get(
|
||||||
|
f"/api/experiments/{exp_id}/sweep/status",
|
||||||
|
headers=auth_headers,
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert data["total_runs"] == 4
|
||||||
|
assert data["completed_runs"] == 2
|
||||||
|
assert data["failed_runs"] == 1
|
||||||
|
assert data["pending_runs"] == 1
|
||||||
|
assert data["status"] == "draft"
|
||||||
|
|
||||||
|
|
||||||
|
class TestPauseSweep:
|
||||||
|
def test_pause_running_no_redis(self, client, project, auth_headers, db_session):
|
||||||
|
create_resp = _create_experiment(client, project.id, auth_headers)
|
||||||
|
exp_id = create_resp.json()["id"]
|
||||||
|
|
||||||
|
from models import Experiment, ExperimentStatus
|
||||||
|
exp = db_session.query(Experiment).filter(Experiment.id == uuid.UUID(exp_id)).first()
|
||||||
|
exp.status = ExperimentStatus.running
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
resp = client.post(f"/api/experiments/{exp_id}/pause", headers=auth_headers)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json()["action"] == "pause"
|
||||||
|
|
||||||
|
# Verify status changed
|
||||||
|
db_session.refresh(exp)
|
||||||
|
assert exp.status == ExperimentStatus.paused
|
||||||
|
|
||||||
|
def test_pause_not_running(self, client, project, auth_headers):
|
||||||
|
create_resp = _create_experiment(client, project.id, auth_headers)
|
||||||
|
exp_id = create_resp.json()["id"]
|
||||||
|
|
||||||
|
resp = client.post(f"/api/experiments/{exp_id}/pause", headers=auth_headers)
|
||||||
|
assert resp.status_code == 409
|
||||||
|
|
||||||
|
def test_pause_with_redis(self, client, project, auth_headers, db_session):
|
||||||
|
create_resp = _create_experiment(client, project.id, auth_headers)
|
||||||
|
exp_id = create_resp.json()["id"]
|
||||||
|
|
||||||
|
from models import Experiment, ExperimentStatus
|
||||||
|
exp = db_session.query(Experiment).filter(Experiment.id == uuid.UUID(exp_id)).first()
|
||||||
|
exp.status = ExperimentStatus.running
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
mock_redis = MagicMock()
|
||||||
|
with patch("routers.experiments.get_redis", return_value=mock_redis):
|
||||||
|
with patch("routers.experiments.request_pause") as mock_pause:
|
||||||
|
resp = client.post(f"/api/experiments/{exp_id}/pause", headers=auth_headers)
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
mock_pause.assert_called_once_with(mock_redis, uuid.UUID(exp_id))
|
||||||
|
|
||||||
|
|
||||||
|
class TestResumeSweep:
|
||||||
|
def test_resume_paused(self, client, project, auth_headers, db_session):
|
||||||
|
create_resp = _create_experiment(client, project.id, auth_headers)
|
||||||
|
exp_id = create_resp.json()["id"]
|
||||||
|
|
||||||
|
from models import Experiment, ExperimentStatus
|
||||||
|
exp = db_session.query(Experiment).filter(Experiment.id == uuid.UUID(exp_id)).first()
|
||||||
|
exp.status = ExperimentStatus.paused
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
with patch("routers.experiments.dispatch_sweep") as mock_dispatch:
|
||||||
|
mock_dispatch.return_value = MagicMock()
|
||||||
|
resp = client.post(f"/api/experiments/{exp_id}/resume", headers=auth_headers)
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json()["action"] == "resume"
|
||||||
|
mock_dispatch.assert_called_once()
|
||||||
|
|
||||||
|
def test_resume_not_paused(self, client, project, auth_headers):
|
||||||
|
create_resp = _create_experiment(client, project.id, auth_headers)
|
||||||
|
exp_id = create_resp.json()["id"]
|
||||||
|
|
||||||
|
resp = client.post(f"/api/experiments/{exp_id}/resume", headers=auth_headers)
|
||||||
|
assert resp.status_code == 409
|
||||||
|
|
||||||
|
|
||||||
|
class TestStopSweep:
|
||||||
|
def test_stop_running_no_redis(self, client, project, auth_headers, db_session):
|
||||||
|
create_resp = _create_experiment(client, project.id, auth_headers)
|
||||||
|
exp_id = create_resp.json()["id"]
|
||||||
|
|
||||||
|
from models import Experiment, ExperimentStatus, Run, RunStatus
|
||||||
|
exp = db_session.query(Experiment).filter(Experiment.id == uuid.UUID(exp_id)).first()
|
||||||
|
exp.status = ExperimentStatus.running
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
# Create a pending run
|
||||||
|
run = Run(
|
||||||
|
experiment_id=uuid.UUID(exp_id),
|
||||||
|
config_hash="abc",
|
||||||
|
config={"test": True},
|
||||||
|
status=RunStatus.pending,
|
||||||
|
)
|
||||||
|
db_session.add(run)
|
||||||
|
db_session.commit()
|
||||||
|
db_session.refresh(run)
|
||||||
|
|
||||||
|
resp = client.post(f"/api/experiments/{exp_id}/stop", headers=auth_headers)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json()["action"] == "stop"
|
||||||
|
|
||||||
|
# Verify pending run is now failed
|
||||||
|
db_session.refresh(run)
|
||||||
|
assert run.status == RunStatus.failed
|
||||||
|
|
||||||
|
# Verify experiment is completed
|
||||||
|
db_session.refresh(exp)
|
||||||
|
assert exp.status == ExperimentStatus.completed
|
||||||
|
|
||||||
|
def test_stop_not_running(self, client, project, auth_headers):
|
||||||
|
create_resp = _create_experiment(client, project.id, auth_headers)
|
||||||
|
exp_id = create_resp.json()["id"]
|
||||||
|
|
||||||
|
resp = client.post(f"/api/experiments/{exp_id}/stop", headers=auth_headers)
|
||||||
|
assert resp.status_code == 409
|
||||||
|
|
||||||
|
def test_stop_paused(self, client, project, auth_headers, db_session):
|
||||||
|
"""Can stop a paused experiment too."""
|
||||||
|
create_resp = _create_experiment(client, project.id, auth_headers)
|
||||||
|
exp_id = create_resp.json()["id"]
|
||||||
|
|
||||||
|
from models import Experiment, ExperimentStatus
|
||||||
|
exp = db_session.query(Experiment).filter(Experiment.id == uuid.UUID(exp_id)).first()
|
||||||
|
exp.status = ExperimentStatus.paused
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
resp = client.post(f"/api/experiments/{exp_id}/stop", headers=auth_headers)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
def test_stop_with_redis(self, client, project, auth_headers, db_session):
|
||||||
|
create_resp = _create_experiment(client, project.id, auth_headers)
|
||||||
|
exp_id = create_resp.json()["id"]
|
||||||
|
|
||||||
|
from models import Experiment, ExperimentStatus
|
||||||
|
exp = db_session.query(Experiment).filter(Experiment.id == uuid.UUID(exp_id)).first()
|
||||||
|
exp.status = ExperimentStatus.running
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
mock_redis = MagicMock()
|
||||||
|
with patch("routers.experiments.get_redis", return_value=mock_redis):
|
||||||
|
with patch("routers.experiments.request_stop") as mock_stop:
|
||||||
|
resp = client.post(f"/api/experiments/{exp_id}/stop", headers=auth_headers)
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
mock_stop.assert_called_once_with(mock_redis, uuid.UUID(exp_id))
|
||||||
|
|
@ -67,51 +67,51 @@ def test_projects_delete(client):
|
||||||
assert resp.status_code == 501
|
assert resp.status_code == 501
|
||||||
|
|
||||||
|
|
||||||
# ---- Experiments router (/api/experiments) ----
|
# ---- Experiments router (/api/experiments) — now implemented, requires auth ----
|
||||||
|
|
||||||
def test_experiments_list(client):
|
def test_experiments_list(client):
|
||||||
resp = client.get("/api/experiments/")
|
resp = client.get("/api/experiments/")
|
||||||
assert resp.status_code == 501
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
def test_experiments_create(client):
|
def test_experiments_create(client):
|
||||||
resp = client.post("/api/experiments/")
|
resp = client.post("/api/experiments/")
|
||||||
assert resp.status_code == 501
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
def test_experiments_get(client):
|
def test_experiments_get(client):
|
||||||
resp = client.get("/api/experiments/00000000-0000-0000-0000-000000000001")
|
resp = client.get("/api/experiments/00000000-0000-0000-0000-000000000001")
|
||||||
assert resp.status_code == 501
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
def test_experiments_update(client):
|
def test_experiments_update(client):
|
||||||
resp = client.put("/api/experiments/00000000-0000-0000-0000-000000000001")
|
resp = client.put("/api/experiments/00000000-0000-0000-0000-000000000001")
|
||||||
assert resp.status_code == 501
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
def test_experiments_delete(client):
|
def test_experiments_delete(client):
|
||||||
resp = client.delete("/api/experiments/00000000-0000-0000-0000-000000000001")
|
resp = client.delete("/api/experiments/00000000-0000-0000-0000-000000000001")
|
||||||
assert resp.status_code == 501
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
def test_experiments_sweep(client):
|
def test_experiments_sweep(client):
|
||||||
resp = client.post("/api/experiments/00000000-0000-0000-0000-000000000001/sweep")
|
resp = client.post("/api/experiments/00000000-0000-0000-0000-000000000001/sweep")
|
||||||
assert resp.status_code == 501
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
def test_experiments_pause(client):
|
def test_experiments_pause(client):
|
||||||
resp = client.post("/api/experiments/00000000-0000-0000-0000-000000000001/pause")
|
resp = client.post("/api/experiments/00000000-0000-0000-0000-000000000001/pause")
|
||||||
assert resp.status_code == 501
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
def test_experiments_resume(client):
|
def test_experiments_resume(client):
|
||||||
resp = client.post("/api/experiments/00000000-0000-0000-0000-000000000001/resume")
|
resp = client.post("/api/experiments/00000000-0000-0000-0000-000000000001/resume")
|
||||||
assert resp.status_code == 501
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
def test_experiments_stop(client):
|
def test_experiments_stop(client):
|
||||||
resp = client.post("/api/experiments/00000000-0000-0000-0000-000000000001/stop")
|
resp = client.post("/api/experiments/00000000-0000-0000-0000-000000000001/stop")
|
||||||
assert resp.status_code == 501
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
# ---- Runs router (/api/runs) ----
|
# ---- Runs router (/api/runs) ----
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue