Created engine/tasks.py with: execute_run and execute_sweep Celery tasks, registered via autodiscover; a SyncTaskResult class that mimics Celery's AsyncResult for in-process mode; dispatch_run/dispatch_sweep helpers that route to Celery or run synchronously based on config; async-to-sync bridging for the async engine functions; and 17 tests covering task execution, sync fallback, error handling, and Celery dispatch.
387 lines
13 KiB
Python
387 lines
13 KiB
Python
"""Tests for engine/tasks.py — Celery task definitions and sync fallback."""
|
|
|
|
import asyncio
|
|
import sys
|
|
import uuid
|
|
from pathlib import Path
|
|
from typing import Any
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
from sqlalchemy import create_engine
|
|
from sqlalchemy.orm import Session
|
|
|
|
# Put the backend root (the parent of this test file's directory) at the front
# of sys.path so ``engine`` and ``models`` import no matter which working
# directory pytest is launched from. Only insert when absent to avoid duplicates.
_backend_dir = str(Path(__file__).resolve().parent.parent)
if _backend_dir not in sys.path:
    sys.path.insert(0, _backend_dir)
|
|
|
|
from engine.adapters.base import AdapterResponse, BaseAdapter
|
|
from engine.cache import ResponseCacheLayer, compute_config_hash
|
|
from engine.runner import EventBus, run_single
|
|
from engine.tasks import (
|
|
SyncTaskResult,
|
|
_do_execute_run,
|
|
_do_execute_sweep,
|
|
dispatch_run,
|
|
dispatch_sweep,
|
|
)
|
|
from models import (
|
|
Base,
|
|
Experiment,
|
|
ExperimentStatus,
|
|
Project,
|
|
Run,
|
|
RunStatus,
|
|
User,
|
|
)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Fixtures / helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class MockAdapter(BaseAdapter):
    """Deterministic adapter stub: every completion returns the same canned response.

    Args:
        text: Completion text returned by :meth:`complete`.
        tokens_in: Prompt token count reported in every response.
        tokens_out: Completion token count reported in every response.
    """

    def __init__(self, text: str = "mock output", tokens_in: int = 10, tokens_out: int = 5):
        self._text, self._tokens_in, self._tokens_out = text, tokens_in, tokens_out

    @property
    def name(self) -> str:
        """Adapter identifier."""
        return "mock"

    async def complete(self, prompt: str, model: str, params: dict) -> AdapterResponse:
        """Ignore the request and return the canned response with a fixed 42 ms latency."""
        canned = dict(
            text=self._text,
            tokens_in=self._tokens_in,
            tokens_out=self._tokens_out,
            latency_ms=42.0,
        )
        return AdapterResponse(**canned)

    async def list_models(self) -> list[str]:
        """Advertise a single model name."""
        return ["mock-model"]

    async def test_connection(self) -> bool:
        """Always report a healthy connection."""
        return True
|
|
|
|
|
|
@pytest.fixture
def db_engine():
    """Yield a fresh in-memory SQLite engine with every model table created.

    The engine is disposed after the test (after the dependent ``db`` fixture
    has closed its session), so its pooled connection and the in-memory
    database it holds are released deterministically instead of lingering
    until garbage collection.
    """
    engine = create_engine("sqlite:///:memory:")
    Base.metadata.create_all(engine)
    yield engine
    engine.dispose()
|
|
|
|
|
|
@pytest.fixture
def db(db_engine):
    """Provide an ORM ``Session`` bound to the per-test engine; closed at teardown."""
    orm_session = Session(db_engine)
    try:
        yield orm_session
    finally:
        orm_session.close()
|
|
|
|
|
|
@pytest.fixture
def user(db):
    """Insert and return a ``User`` row to own test projects."""
    account = User(username="testuser", password_hash="fakehash")
    db.add(account)
    db.commit()
    # Eagerly reload attributes expired by the commit (e.g. the generated id).
    db.refresh(account)
    return account
|
|
|
|
|
|
@pytest.fixture
def project(db, user):
    """Insert and return a ``Project`` owned by the ``user`` fixture."""
    owned_project = Project(name="Test Project", owner_id=user.id)
    db.add(owned_project)
    db.commit()
    # Eagerly reload attributes expired by the commit (e.g. the generated id).
    db.refresh(owned_project)
    return owned_project
|
|
|
|
|
|
@pytest.fixture
def experiment(db, project):
    """Insert and return a draft, single-stage ``Experiment`` with an empty grid space."""
    hello_stage = {"prompt_template": "Say hello", "model": "test-model", "params": {}}
    record = Experiment(
        project_id=project.id,
        name="Test Experiment",
        pipeline_stages=[hello_stage],
        parameter_space={"type": "grid", "params": {}},
        status=ExperimentStatus.draft,
    )
    db.add(record)
    db.commit()
    # Eagerly reload attributes expired by the commit (e.g. the generated id).
    db.refresh(record)
    return record
|
|
|
|
|
|
def _make_run(db: Session, experiment_id: uuid.UUID, config: dict | None = None) -> Run:
    """Insert a pending ``Run`` for *experiment_id* and return it refreshed.

    Args:
        db: Session used to persist the run.
        experiment_id: Primary key of the owning experiment.
        config: Run configuration. ``None`` selects a default single-stage
            "Say hello" config; any other value, including an empty dict, is
            used exactly as given.

    Returns:
        The committed and refreshed ``Run`` with status ``pending``.
    """
    # ``is None`` rather than ``or`` so an explicit empty config is honoured
    # instead of being silently swapped for the default.
    if config is None:
        config = {
            "prompt": "Say hello",
            "model": "test-model",
            "params": {},
            "pipeline_stages": [
                {"prompt_template": "Say hello", "model": "test-model", "params": {}}
            ],
        }
    # Hash (prompt, model, params) via the cache module's helper; missing keys
    # fall back to empty values.
    config_hash = compute_config_hash(
        config.get("prompt", ""), config.get("model", ""), config.get("params", {})
    )
    run = Run(
        experiment_id=experiment_id,
        config_hash=config_hash,
        config=config,
        status=RunStatus.pending,
    )
    db.add(run)
    db.commit()
    db.refresh(run)
    return run
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# SyncTaskResult tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestSyncTaskResult:
    """``SyncTaskResult`` mirrors the parts of Celery's ``AsyncResult`` API callers rely on."""

    def test_successful_result(self):
        res = SyncTaskResult(result={"status": "ok"}, task_id="test-123")

        assert res.id == "test-123"
        # Both the Celery-style ``status`` and its ``state`` alias report success.
        for attr in ("status", "state"):
            assert getattr(res, attr) == "SUCCESS"
        assert res.ready() is True
        assert res.successful() is True
        assert res.failed() is False
        assert res.get() == {"status": "ok"}

    def test_failed_result(self):
        res = SyncTaskResult(error=ValueError("boom"))

        assert res.status == "FAILURE"
        assert res.ready() is True
        assert res.successful() is False
        assert res.failed() is True
        # get() re-raises the captured exception, as AsyncResult.get() does.
        with pytest.raises(ValueError, match="boom"):
            res.get()

    def test_auto_generates_task_id(self):
        res = SyncTaskResult(result="data")
        assert res.id is not None
        assert len(res.id) > 0
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# _do_execute_run tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestDoExecuteRun:
    """``_do_execute_run`` with the DB, adapter, cache, and event-bus factories patched."""

    def test_executes_run_successfully(self, db, experiment):
        run = _make_run(db, experiment.id)
        # Capture the id up front: if the task commits and closes the shared
        # session, reading ``run.id`` afterwards could hit an expired, detached
        # instance.
        run_id = str(run.id)
        adapter = MockAdapter()
        cache = ResponseCacheLayer()
        event_bus = EventBus()

        # Patch the internals to use our fixtures
        with patch("engine.tasks._get_db_session", return_value=db), \
                patch("engine.tasks._get_adapter", return_value=adapter), \
                patch("engine.tasks._get_cache", return_value=cache), \
                patch("engine.tasks._get_event_bus", return_value=event_bus):
            result = _do_execute_run(run_id)

        assert result["run_id"] == run_id
        assert result["status"] == "completed"
        # Token counts come straight from MockAdapter's defaults (10 in / 5 out).
        assert result["tokens_in"] == 10
        assert result["tokens_out"] == 5
        assert result["duration_ms"] is not None

    def test_raises_for_missing_run(self, db):
        fake_id = str(uuid.uuid4())
        with patch("engine.tasks._get_db_session", return_value=db), \
                pytest.raises(ValueError, match="not found"):
            _do_execute_run(fake_id)

    def test_passes_endpoint_config(self, db, experiment):
        run = _make_run(db, experiment.id)
        run_id = str(run.id)
        adapter = MockAdapter()
        cache = ResponseCacheLayer()
        event_bus = EventBus()
        endpoint_cfg = {"url": "http://test:8080/v1", "api_key": "sk-test"}

        with patch("engine.tasks._get_db_session", return_value=db), \
                patch("engine.tasks._get_adapter", return_value=adapter) as mock_get_adapter, \
                patch("engine.tasks._get_cache", return_value=cache), \
                patch("engine.tasks._get_event_bus", return_value=event_bus):
            _do_execute_run(run_id, endpoint_config=endpoint_cfg)
            # The endpoint config must be forwarded verbatim to the adapter factory.
            mock_get_adapter.assert_called_once_with(endpoint_cfg)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# _do_execute_sweep tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestDoExecuteSweep:
    """``_do_execute_sweep`` against the fixture experiment, with Redis patched to ``None``."""

    def test_executes_sweep_successfully(self, db, experiment):
        fake_adapter, response_cache, bus = MockAdapter(), ResponseCacheLayer(), EventBus()

        with patch("engine.tasks._get_db_session", return_value=db), \
                patch("engine.tasks._get_adapter", return_value=fake_adapter), \
                patch("engine.tasks._get_cache", return_value=response_cache), \
                patch("engine.tasks._get_event_bus", return_value=bus), \
                patch("engine.tasks._get_redis_client", return_value=None):
            summary = _do_execute_sweep(str(experiment.id))

        assert summary["experiment_id"] == str(experiment.id)
        # The fixture's grid has no parameter axes; only a lower bound is asserted.
        assert summary["total_runs"] >= 1
        assert summary["completed_runs"] >= 1
        assert summary["stopped_reason"] == "completed"

    def test_raises_for_missing_experiment(self, db):
        unknown_id = str(uuid.uuid4())
        with patch("engine.tasks._get_db_session", return_value=db), \
                pytest.raises(ValueError, match="not found"):
            _do_execute_sweep(unknown_id)

    def test_override_sweep_config(self, db, experiment):
        fake_adapter, response_cache, bus = MockAdapter(), ResponseCacheLayer(), EventBus()
        # An explicit sweep_config should take precedence over the experiment's
        # stored parameter space: expect exactly ``n_trials`` runs.
        random_sweep = {"type": "random", "params": {"model": ["a", "b"]}, "n_trials": 2}

        with patch("engine.tasks._get_db_session", return_value=db), \
                patch("engine.tasks._get_adapter", return_value=fake_adapter), \
                patch("engine.tasks._get_cache", return_value=response_cache), \
                patch("engine.tasks._get_event_bus", return_value=bus), \
                patch("engine.tasks._get_redis_client", return_value=None):
            summary = _do_execute_sweep(str(experiment.id), sweep_config=random_sweep)

        assert summary["total_runs"] == 2
        assert summary["completed_runs"] == 2
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# dispatch_run / dispatch_sweep tests (sync fallback)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestDispatchRunSync:
    """``dispatch_run`` with ``use_in_process_queue`` on runs synchronously and returns a result handle."""

    def test_sync_fallback_success(self, db, experiment):
        run = _make_run(db, experiment.id)
        run_id = str(run.id)
        adapter = MockAdapter()
        cache = ResponseCacheLayer()
        event_bus = EventBus()

        with patch("engine.tasks.settings") as mock_settings, \
                patch("engine.tasks._get_db_session", return_value=db), \
                patch("engine.tasks._get_adapter", return_value=adapter), \
                patch("engine.tasks._get_cache", return_value=cache), \
                patch("engine.tasks._get_event_bus", return_value=event_bus):
            mock_settings.use_in_process_queue = True
            task_result = dispatch_run(run_id)

        assert task_result.successful()
        result = task_result.get()
        assert result["status"] == "completed"
        # The sync path must report the same run it was asked to execute.
        assert result["run_id"] == run_id

    def test_sync_fallback_error(self, db):
        fake_id = str(uuid.uuid4())

        with patch("engine.tasks.settings") as mock_settings, \
                patch("engine.tasks._get_db_session", return_value=db):
            mock_settings.use_in_process_queue = True
            task_result = dispatch_run(fake_id)

        # The lookup failure is captured in the handle, then re-raised by get();
        # pin the message so an unrelated ValueError cannot satisfy the test.
        assert task_result.failed()
        with pytest.raises(ValueError, match="not found"):
            task_result.get()
|
|
|
|
|
|
class TestDispatchSweepSync:
    """``dispatch_sweep`` with ``use_in_process_queue`` on runs synchronously and returns a result handle."""

    def test_sync_fallback_success(self, db, experiment):
        adapter = MockAdapter()
        cache = ResponseCacheLayer()
        event_bus = EventBus()

        with patch("engine.tasks.settings") as mock_settings, \
                patch("engine.tasks._get_db_session", return_value=db), \
                patch("engine.tasks._get_adapter", return_value=adapter), \
                patch("engine.tasks._get_cache", return_value=cache), \
                patch("engine.tasks._get_event_bus", return_value=event_bus), \
                patch("engine.tasks._get_redis_client", return_value=None):
            mock_settings.use_in_process_queue = True
            # settings is a MagicMock, so numeric limits must be set explicitly.
            mock_settings.max_concurrent_runs = 2
            mock_settings.max_tokens_per_sweep = 0
            task_result = dispatch_sweep(str(experiment.id))

        assert task_result.successful()
        result = task_result.get()
        assert result["completed_runs"] >= 1

    def test_sync_fallback_error(self, db):
        fake_id = str(uuid.uuid4())

        with patch("engine.tasks.settings") as mock_settings, \
                patch("engine.tasks._get_db_session", return_value=db):
            mock_settings.use_in_process_queue = True
            task_result = dispatch_sweep(fake_id)

        # The lookup failure is captured in the handle, then re-raised by get();
        # pin the message so an unrelated ValueError cannot satisfy the test.
        assert task_result.failed()
        with pytest.raises(ValueError, match="not found"):
            task_result.get()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Celery task registration tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestCeleryTaskRegistration:
    """Celery task objects are registered under their expected names.

    ``engine.tasks`` may expose ``execute_run``/``execute_sweep`` as ``None``
    when Celery is unavailable. In that case the tests are skipped, so the
    report shows that nothing was verified instead of recording a vacuous pass.
    """

    def test_execute_run_task_exists(self):
        """execute_run should be registered as a Celery task or None if Celery unavailable."""
        from engine.tasks import execute_run

        if execute_run is None:
            pytest.skip("Celery unavailable: engine.tasks.execute_run is None")
        assert execute_run.name == "engine.execute_run"

    def test_execute_sweep_task_exists(self):
        """execute_sweep should be registered as a Celery task or None if Celery unavailable."""
        from engine.tasks import execute_sweep

        if execute_sweep is None:
            pytest.skip("Celery unavailable: engine.tasks.execute_sweep is None")
        assert execute_sweep.name == "engine.execute_sweep"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# dispatch with Celery available
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestDispatchWithCelery:
    """With the in-process queue disabled, dispatch helpers hand work to Celery's ``.delay()``."""

    def test_dispatch_run_uses_celery_when_redis_available(self):
        """When use_in_process_queue is False and Celery tasks exist, should use .delay()."""
        mock_task = MagicMock()
        mock_task.delay.return_value = MagicMock(id="celery-task-id")

        with patch("engine.tasks.settings") as mock_settings, \
                patch("engine.tasks.execute_run", mock_task):
            mock_settings.use_in_process_queue = False
            result = dispatch_run("some-run-id", endpoint_config={"url": "http://x"})

        mock_task.delay.assert_called_once_with("some-run-id", {"url": "http://x"})
        # The handle produced by .delay() must reach the caller so the task id can be tracked.
        assert result.id == "celery-task-id"

    def test_dispatch_sweep_uses_celery_when_redis_available(self):
        """When use_in_process_queue is False, dispatch_sweep should enqueue via .delay()."""
        mock_task = MagicMock()
        mock_task.delay.return_value = MagicMock(id="celery-task-id")

        with patch("engine.tasks.settings") as mock_settings, \
                patch("engine.tasks.execute_sweep", mock_task):
            mock_settings.use_in_process_queue = False
            result = dispatch_sweep("some-exp-id", sweep_config={"type": "grid"})

        # Third positional argument is the (omitted) endpoint config.
        mock_task.delay.assert_called_once_with("some-exp-id", {"type": "grid"}, None)
        # The handle produced by .delay() must reach the caller so the task id can be tracked.
        assert result.id == "celery-task-id"
|