promptlooper/backend/tests/test_tasks.py
John Lightner b16454994e MAESTRO: Implement Celery tasks (execute_run, execute_sweep) with synchronous fallback for single-container mode
Created engine/tasks.py with:
- execute_run and execute_sweep Celery tasks registered via autodiscover
- SyncTaskResult class mimicking Celery AsyncResult for in-process mode
- dispatch_run/dispatch_sweep helpers that route to Celery or sync based on config
- Proper async-to-sync bridging for the async engine functions
- 17 tests covering task execution, sync fallback, error handling, and Celery dispatch
2026-04-07 03:08:41 -05:00

387 lines
13 KiB
Python

"""Tests for engine/tasks.py — Celery task definitions and sync fallback."""
import asyncio
import sys
import uuid
from pathlib import Path
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import Session
# Put the backend root (the directory above tests/) at the front of sys.path
# so `engine` and `models` import the same way no matter where pytest is run.
_backend_dir = str(Path(__file__).resolve().parents[1])
if _backend_dir not in sys.path:
    sys.path.insert(0, _backend_dir)
from engine.adapters.base import AdapterResponse, BaseAdapter
from engine.cache import ResponseCacheLayer, compute_config_hash
from engine.runner import EventBus, run_single
from engine.tasks import (
SyncTaskResult,
_do_execute_run,
_do_execute_sweep,
dispatch_run,
dispatch_sweep,
)
from models import (
Base,
Experiment,
ExperimentStatus,
Project,
Run,
RunStatus,
User,
)
# ---------------------------------------------------------------------------
# Fixtures / helpers
# ---------------------------------------------------------------------------
class MockAdapter(BaseAdapter):
    """Canned-response adapter that drives the engine without a real model backend.

    Every ``complete`` call ignores its arguments and returns the text and
    token counts supplied at construction time.
    """

    def __init__(
        self,
        text: str = "mock output",
        tokens_in: int = 10,
        tokens_out: int = 5,
    ):
        # Held privately; only surfaced through complete().
        self._text = text
        self._tokens_in = tokens_in
        self._tokens_out = tokens_out

    @property
    def name(self) -> str:
        """Identifier reported for this adapter."""
        return "mock"

    async def list_models(self) -> list[str]:
        """Advertise the single fake model this adapter claims to serve."""
        return ["mock-model"]

    async def test_connection(self) -> bool:
        """Always report a healthy connection."""
        return True

    async def complete(self, prompt: str, model: str, params: dict) -> AdapterResponse:
        """Return the canned response with a fixed ``latency_ms`` of 42.0."""
        payload = {
            "text": self._text,
            "tokens_in": self._tokens_in,
            "tokens_out": self._tokens_out,
            "latency_ms": 42.0,
        }
        return AdapterResponse(**payload)
@pytest.fixture
def db_engine():
    """Provide a fresh in-memory SQLite engine with every model table created.

    The engine is yielded rather than returned so that its connection pool can
    be disposed at teardown. That closes the connection(s) backing the
    in-memory database deterministically instead of leaving them for garbage
    collection.
    """
    engine = create_engine("sqlite:///:memory:")
    Base.metadata.create_all(engine)
    yield engine
    # Runs after dependent fixtures (e.g. `db`) have torn down.
    engine.dispose()
@pytest.fixture
def db(db_engine):
    """Yield an ORM session bound to the per-test engine; it is closed at teardown."""
    session = Session(db_engine)
    try:
        yield session
    finally:
        session.close()
@pytest.fixture
def user(db):
    """Persist a throwaway ``User`` and return it reloaded from the database."""
    account = User(username="testuser", password_hash="fakehash")
    db.add(account)
    db.commit()
    # Reload after commit so attributes such as ``id`` are populated for callers.
    db.refresh(account)
    return account
@pytest.fixture
def project(db, user):
    """Persist a ``Project`` owned by the ``user`` fixture and return it reloaded."""
    proj = Project(name="Test Project", owner_id=user.id)
    db.add(proj)
    db.commit()
    db.refresh(proj)
    return proj
@pytest.fixture
def experiment(db, project):
    """Persist a draft, single-stage ``Experiment`` under ``project`` and return it.

    The parameter space is a grid with no parameters. Presumably it expands to
    a single configuration when swept (the sweep tests only assert >= 1 run).
    """
    hello_stage = {"prompt_template": "Say hello", "model": "test-model", "params": {}}
    draft = Experiment(
        project_id=project.id,
        name="Test Experiment",
        pipeline_stages=[hello_stage],
        parameter_space={"type": "grid", "params": {}},
        status=ExperimentStatus.draft,
    )
    db.add(draft)
    db.commit()
    db.refresh(draft)
    return draft
def _make_run(db: Session, experiment_id: uuid.UUID, config: dict | None = None) -> Run:
    """Insert a pending ``Run`` for ``experiment_id`` and return it reloaded.

    Args:
        db: Session through which the run is added and committed.
        experiment_id: Primary key of the parent experiment.
        config: Run configuration. When omitted (``None``), a single-stage
            "Say hello" config targeting ``test-model`` is used. An explicitly
            passed empty dict is kept as-is rather than being silently
            replaced by the default.

    Returns:
        The committed ``Run`` with status ``RunStatus.pending`` and a
        ``config_hash`` derived from the config's prompt, model and params.
    """
    # `is None` rather than truthiness: `{}` is a legitimate (if minimal) config.
    if config is None:
        config = {
            "prompt": "Say hello",
            "model": "test-model",
            "params": {},
            "pipeline_stages": [
                {"prompt_template": "Say hello", "model": "test-model", "params": {}}
            ],
        }
    # Absent keys fall back to empty values so partial configs still hash.
    config_hash = compute_config_hash(
        config.get("prompt", ""), config.get("model", ""), config.get("params", {})
    )
    run = Run(
        experiment_id=experiment_id,
        config_hash=config_hash,
        config=config,
        status=RunStatus.pending,
    )
    db.add(run)
    db.commit()
    db.refresh(run)
    return run
# ---------------------------------------------------------------------------
# SyncTaskResult tests
# ---------------------------------------------------------------------------
class TestSyncTaskResult:
    """Behavior of the in-process stand-in for Celery's ``AsyncResult``."""

    def test_successful_result(self):
        """A result-carrying instance reports SUCCESS and hands back its payload."""
        res = SyncTaskResult(result={"status": "ok"}, task_id="test-123")
        assert res.id == "test-123"
        assert res.status == "SUCCESS"
        assert res.state == "SUCCESS"
        assert res.ready() is True
        assert res.successful() is True
        assert res.failed() is False
        assert res.get() == {"status": "ok"}

    def test_failed_result(self):
        """An error-carrying instance reports FAILURE and re-raises from get()."""
        res = SyncTaskResult(error=ValueError("boom"))
        assert res.status == "FAILURE"
        assert res.ready() is True
        assert res.successful() is False
        assert res.failed() is True
        with pytest.raises(ValueError, match="boom"):
            res.get()

    def test_auto_generates_task_id(self):
        """Omitting ``task_id`` still produces a non-empty id."""
        res = SyncTaskResult(result="data")
        generated = res.id
        assert generated is not None
        assert len(generated) > 0
# ---------------------------------------------------------------------------
# _do_execute_run tests
# ---------------------------------------------------------------------------
class TestDoExecuteRun:
    """``_do_execute_run`` with its DB session, adapter, cache and event bus patched."""

    def test_executes_run_successfully(self, db, experiment):
        """A pending run completes and reports MockAdapter's canned token counts."""
        run = _make_run(db, experiment.id)
        adapter, cache, bus = MockAdapter(), ResponseCacheLayer(), EventBus()
        # Route every dependency lookup in engine.tasks to the local fixtures.
        with (
            patch("engine.tasks._get_db_session", return_value=db),
            patch("engine.tasks._get_adapter", return_value=adapter),
            patch("engine.tasks._get_cache", return_value=cache),
            patch("engine.tasks._get_event_bus", return_value=bus),
        ):
            outcome = _do_execute_run(str(run.id))
        assert outcome["run_id"] == str(run.id)
        assert outcome["status"] == "completed"
        assert outcome["tokens_in"] == 10
        assert outcome["tokens_out"] == 5
        assert outcome["duration_ms"] is not None

    def test_raises_for_missing_run(self, db):
        """An unknown run id surfaces as ValueError mentioning 'not found'."""
        missing_id = str(uuid.uuid4())
        with (
            patch("engine.tasks._get_db_session", return_value=db),
            pytest.raises(ValueError, match="not found"),
        ):
            _do_execute_run(missing_id)

    def test_passes_endpoint_config(self, db, experiment):
        """The endpoint config is forwarded verbatim to the adapter factory."""
        run = _make_run(db, experiment.id)
        adapter, cache, bus = MockAdapter(), ResponseCacheLayer(), EventBus()
        endpoint_cfg = {"url": "http://test:8080/v1", "api_key": "sk-test"}
        with (
            patch("engine.tasks._get_db_session", return_value=db),
            patch("engine.tasks._get_adapter", return_value=adapter) as fake_get_adapter,
            patch("engine.tasks._get_cache", return_value=cache),
            patch("engine.tasks._get_event_bus", return_value=bus),
        ):
            _do_execute_run(str(run.id), endpoint_config=endpoint_cfg)
        fake_get_adapter.assert_called_once_with(endpoint_cfg)
# ---------------------------------------------------------------------------
# _do_execute_sweep tests
# ---------------------------------------------------------------------------
class TestDoExecuteSweep:
    """``_do_execute_sweep`` with engine dependencies patched and the Redis client stubbed to None."""

    def test_executes_sweep_successfully(self, db, experiment):
        """Sweeping the fixture experiment's own parameter space finishes normally."""
        adapter, cache, bus = MockAdapter(), ResponseCacheLayer(), EventBus()
        with (
            patch("engine.tasks._get_db_session", return_value=db),
            patch("engine.tasks._get_adapter", return_value=adapter),
            patch("engine.tasks._get_cache", return_value=cache),
            patch("engine.tasks._get_event_bus", return_value=bus),
            # Presumably exercises the "no Redis available" path.
            patch("engine.tasks._get_redis_client", return_value=None),
        ):
            summary = _do_execute_sweep(str(experiment.id))
        assert summary["experiment_id"] == str(experiment.id)
        assert summary["total_runs"] >= 1
        assert summary["completed_runs"] >= 1
        assert summary["stopped_reason"] == "completed"

    def test_raises_for_missing_experiment(self, db):
        """An unknown experiment id surfaces as ValueError mentioning 'not found'."""
        missing_id = str(uuid.uuid4())
        with (
            patch("engine.tasks._get_db_session", return_value=db),
            pytest.raises(ValueError, match="not found"),
        ):
            _do_execute_sweep(missing_id)

    def test_override_sweep_config(self, db, experiment):
        """An explicit sweep config replaces the experiment's stored parameter space."""
        adapter, cache, bus = MockAdapter(), ResponseCacheLayer(), EventBus()
        override = {"type": "random", "params": {"model": ["a", "b"]}, "n_trials": 2}
        with (
            patch("engine.tasks._get_db_session", return_value=db),
            patch("engine.tasks._get_adapter", return_value=adapter),
            patch("engine.tasks._get_cache", return_value=cache),
            patch("engine.tasks._get_event_bus", return_value=bus),
            patch("engine.tasks._get_redis_client", return_value=None),
        ):
            summary = _do_execute_sweep(str(experiment.id), sweep_config=override)
        assert summary["total_runs"] == 2
        assert summary["completed_runs"] == 2
# ---------------------------------------------------------------------------
# dispatch_run / dispatch_sweep tests (sync fallback)
# ---------------------------------------------------------------------------
class TestDispatchRunSync:
    """``dispatch_run`` when settings select the in-process (synchronous) queue."""

    def test_sync_fallback_success(self, db, experiment):
        """A successful in-process run is wrapped in a successful task result."""
        run = _make_run(db, experiment.id)
        adapter, cache, bus = MockAdapter(), ResponseCacheLayer(), EventBus()
        with (
            patch("engine.tasks.settings") as fake_settings,
            patch("engine.tasks._get_db_session", return_value=db),
            patch("engine.tasks._get_adapter", return_value=adapter),
            patch("engine.tasks._get_cache", return_value=cache),
            patch("engine.tasks._get_event_bus", return_value=bus),
        ):
            fake_settings.use_in_process_queue = True
            task_result = dispatch_run(str(run.id))
            assert task_result.successful()
            assert task_result.get()["status"] == "completed"

    def test_sync_fallback_error(self, db):
        """A failing in-process run is captured and re-raised by get()."""
        missing_id = str(uuid.uuid4())
        with (
            patch("engine.tasks.settings") as fake_settings,
            patch("engine.tasks._get_db_session", return_value=db),
        ):
            fake_settings.use_in_process_queue = True
            task_result = dispatch_run(missing_id)
            assert task_result.failed()
            with pytest.raises(ValueError):
                task_result.get()
class TestDispatchSweepSync:
    """``dispatch_sweep`` when settings select the in-process (synchronous) queue."""

    def test_sync_fallback_success(self, db, experiment):
        """A successful in-process sweep is wrapped in a successful task result."""
        adapter, cache, bus = MockAdapter(), ResponseCacheLayer(), EventBus()
        with (
            patch("engine.tasks.settings") as fake_settings,
            patch("engine.tasks._get_db_session", return_value=db),
            patch("engine.tasks._get_adapter", return_value=adapter),
            patch("engine.tasks._get_cache", return_value=cache),
            patch("engine.tasks._get_event_bus", return_value=bus),
            patch("engine.tasks._get_redis_client", return_value=None),
        ):
            # `settings` is a MagicMock here, so numeric settings the sweep
            # reads are pinned explicitly. NOTE(review): 0 presumably means
            # "no token budget" — confirm against engine.tasks.
            fake_settings.use_in_process_queue = True
            fake_settings.max_concurrent_runs = 2
            fake_settings.max_tokens_per_sweep = 0
            task_result = dispatch_sweep(str(experiment.id))
            assert task_result.successful()
            assert task_result.get()["completed_runs"] >= 1

    def test_sync_fallback_error(self, db):
        """A failing in-process sweep is captured and re-raised by get()."""
        missing_id = str(uuid.uuid4())
        with (
            patch("engine.tasks.settings") as fake_settings,
            patch("engine.tasks._get_db_session", return_value=db),
        ):
            fake_settings.use_in_process_queue = True
            task_result = dispatch_sweep(missing_id)
            assert task_result.failed()
            with pytest.raises(ValueError):
                task_result.get()
# ---------------------------------------------------------------------------
# Celery task registration tests
# ---------------------------------------------------------------------------
class TestCeleryTaskRegistration:
    """Celery task objects exported by ``engine.tasks`` carry their registered names.

    ``execute_run`` / ``execute_sweep`` are ``None`` when Celery is unavailable.
    In that case the name check cannot run, so the test is reported as skipped
    instead of passing vacuously.
    """

    def test_execute_run_task_exists(self):
        """execute_run should be registered as a Celery task or None if Celery unavailable."""
        from engine.tasks import execute_run

        if execute_run is None:
            pytest.skip("Celery unavailable: engine.tasks.execute_run is None")
        assert execute_run.name == "engine.execute_run"

    def test_execute_sweep_task_exists(self):
        """execute_sweep should be registered under 'engine.execute_sweep' when Celery is present."""
        from engine.tasks import execute_sweep

        if execute_sweep is None:
            pytest.skip("Celery unavailable: engine.tasks.execute_sweep is None")
        assert execute_sweep.name == "engine.execute_sweep"
# ---------------------------------------------------------------------------
# dispatch with Celery available
# ---------------------------------------------------------------------------
class TestDispatchWithCelery:
    """Dispatch helpers enqueue through ``.delay()`` when the in-process queue is disabled."""

    def test_dispatch_run_uses_celery_when_redis_available(self):
        """When use_in_process_queue is False and Celery tasks exist, should use .delay()."""
        mock_task = MagicMock()
        mock_task.delay.return_value = MagicMock(id="celery-task-id")
        with (
            patch("engine.tasks.settings") as mock_settings,
            patch("engine.tasks.execute_run", mock_task),
        ):
            mock_settings.use_in_process_queue = False
            # Only the enqueue call is under test; the returned handle is not inspected.
            dispatch_run("some-run-id", endpoint_config={"url": "http://x"})
        mock_task.delay.assert_called_once_with("some-run-id", {"url": "http://x"})

    def test_dispatch_sweep_uses_celery_when_redis_available(self):
        """dispatch_sweep forwards its arguments positionally to execute_sweep.delay()."""
        mock_task = MagicMock()
        mock_task.delay.return_value = MagicMock(id="celery-task-id")
        with (
            patch("engine.tasks.settings") as mock_settings,
            patch("engine.tasks.execute_sweep", mock_task),
        ):
            mock_settings.use_in_process_queue = False
            dispatch_sweep("some-exp-id", sweep_config={"type": "grid"})
        # NOTE(review): the trailing None is the third, omitted dispatch_sweep
        # argument (presumably an endpoint config) — confirm in engine.tasks.
        mock_task.delay.assert_called_once_with("some-exp-id", {"type": "grid"}, None)