MAESTRO: Implement Celery tasks (execute_run, execute_sweep) with synchronous fallback for single-container mode
Created engine/tasks.py with: - execute_run and execute_sweep Celery tasks registered via autodiscover - SyncTaskResult class mimicking Celery AsyncResult for in-process mode - dispatch_run/dispatch_sweep helpers that route to Celery or sync based on config - Proper async-to-sync bridging for the async engine functions - 17 tests covering task execution, sync fallback, error handling, and Celery dispatch
This commit is contained in:
parent
16c56b13f2
commit
b16454994e
3 changed files with 655 additions and 1 deletions
|
|
@ -29,7 +29,8 @@ Implement the core experiment execution engine: LLM adapters, response caching,
|
|||
- [x] Implement backend/engine/scorers/llm_judge.py — sends the output to a separate LLM with a configurable judge prompt and asks for a 1-10 rating. Parses the numeric score from the response. This scorer requires an LLM call so it should be clearly marked as "costs tokens" in the UI. Cache the judge's response too.
|
||||
<!-- Completed: LLMJudgeScorer with configurable judge prompt, 1-10 rating parsing via regex, normalized to 0.0-1.0. COSTS_TOKENS class marker for UI. Optional ResponseCacheLayer integration for caching judge responses. Retries with exponential backoff. 36 tests in test_scorer_llm_judge.py, all passing. -->
|
||||
|
||||
- [ ] Wire up the Celery worker in backend/worker.py. Define tasks: execute_run(run_id), execute_sweep(experiment_id, sweep_config). Configure Celery to use Redis as broker. In single-container mode (no Redis), implement a simple synchronous fallback that runs tasks in-process.
|
||||
- [x] Wire up the Celery worker in backend/worker.py. Define tasks: execute_run(run_id), execute_sweep(experiment_id, sweep_config). Configure Celery to use Redis as broker. In single-container mode (no Redis), implement a simple synchronous fallback that runs tasks in-process.
|
||||
<!-- Completed: Created engine/tasks.py with execute_run and execute_sweep Celery tasks (autodiscovered via worker.py). SyncTaskResult class mimics AsyncResult for fallback. dispatch_run/dispatch_sweep helpers route to Celery or sync execution based on settings.use_in_process_queue. 17 tests in test_tasks.py, all passing. -->
|
||||
|
||||
- [ ] Implement backend/routers/endpoints.py fully — CRUD for LLM endpoint configurations. The test endpoint should call adapter.test_connection() and adapter.list_models() and return the results. Store endpoint configs in the database with encrypted API keys (Fernet symmetric encryption, key derived from JWT_SECRET).
|
||||
|
||||
|
|
|
|||
266
backend/engine/tasks.py
Normal file
266
backend/engine/tasks.py
Normal file
|
|
@ -0,0 +1,266 @@
|
|||
"""Celery tasks for PromptLooper experiment execution.
|
||||
|
||||
Defines execute_run and execute_sweep tasks that are dispatched via
|
||||
Celery (Redis broker) or run synchronously in single-container mode.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_db_session() -> Session:
|
||||
"""Create a database session for task execution."""
|
||||
from main import SessionLocal
|
||||
if SessionLocal is None:
|
||||
from main import _init_db
|
||||
_init_db()
|
||||
from main import SessionLocal as SL
|
||||
return SL()
|
||||
return SessionLocal()
|
||||
|
||||
|
||||
def _get_redis_client():
|
||||
"""Get the Redis client (or None)."""
|
||||
from main import get_redis
|
||||
return get_redis()
|
||||
|
||||
|
||||
def _get_adapter(endpoint_config: dict[str, Any] | None = None):
|
||||
"""Create an LLM adapter from config or defaults."""
|
||||
from engine.adapters.openai_compat import OpenAICompatAdapter
|
||||
|
||||
if endpoint_config:
|
||||
return OpenAICompatAdapter(
|
||||
base_url=endpoint_config.get("url", settings.default_endpoint_url or ""),
|
||||
api_key=endpoint_config.get("api_key", settings.default_endpoint_key),
|
||||
)
|
||||
|
||||
return OpenAICompatAdapter(
|
||||
base_url=settings.default_endpoint_url or "http://localhost:11434/v1",
|
||||
api_key=settings.default_endpoint_key,
|
||||
)
|
||||
|
||||
|
||||
def _get_event_bus():
|
||||
"""Create an EventBus with Redis if available."""
|
||||
from engine.runner import EventBus
|
||||
redis_client = _get_redis_client()
|
||||
return EventBus(redis_client=redis_client)
|
||||
|
||||
|
||||
def _get_cache():
|
||||
"""Create a ResponseCacheLayer."""
|
||||
from engine.cache import ResponseCacheLayer
|
||||
return ResponseCacheLayer()
|
||||
|
||||
|
||||
def _run_async(coro):
|
||||
"""Run an async coroutine, creating an event loop if needed."""
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
loop = None
|
||||
|
||||
if loop and loop.is_running():
|
||||
# Already in an async context — create a new loop in a thread
|
||||
import concurrent.futures
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
|
||||
return pool.submit(asyncio.run, coro).result()
|
||||
else:
|
||||
return asyncio.run(coro)
|
||||
|
||||
|
||||
def _do_execute_run(run_id: str, endpoint_config: dict[str, Any] | None = None) -> dict[str, Any]:
|
||||
"""Core logic for executing a single run (used by both Celery and sync paths)."""
|
||||
from models import Run
|
||||
|
||||
db = _get_db_session()
|
||||
try:
|
||||
run = db.get(Run, uuid.UUID(run_id))
|
||||
if run is None:
|
||||
raise ValueError(f"Run {run_id} not found")
|
||||
|
||||
adapter = _get_adapter(endpoint_config)
|
||||
cache = _get_cache()
|
||||
event_bus = _get_event_bus()
|
||||
|
||||
result = _run_async(
|
||||
run_single_import()(db, run, adapter, cache, event_bus=event_bus)
|
||||
)
|
||||
|
||||
return {
|
||||
"run_id": str(result.id),
|
||||
"status": result.status.value,
|
||||
"duration_ms": result.duration_ms,
|
||||
"tokens_in": result.tokens_in,
|
||||
"tokens_out": result.tokens_out,
|
||||
}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def run_single_import():
|
||||
"""Lazy import of run_single to avoid circular imports."""
|
||||
from engine.runner import run_single
|
||||
return run_single
|
||||
|
||||
|
||||
def run_sweep_import():
|
||||
"""Lazy import of run_sweep to avoid circular imports."""
|
||||
from engine.sweep import run_sweep
|
||||
return run_sweep
|
||||
|
||||
|
||||
def _do_execute_sweep(
|
||||
experiment_id: str,
|
||||
sweep_config: dict[str, Any] | None = None,
|
||||
endpoint_config: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Core logic for executing a sweep (used by both Celery and sync paths)."""
|
||||
from models import Experiment
|
||||
|
||||
db = _get_db_session()
|
||||
try:
|
||||
experiment = db.get(Experiment, uuid.UUID(experiment_id))
|
||||
if experiment is None:
|
||||
raise ValueError(f"Experiment {experiment_id} not found")
|
||||
|
||||
# Override parameter_space if sweep_config provided
|
||||
if sweep_config:
|
||||
experiment.parameter_space = sweep_config
|
||||
db.commit()
|
||||
|
||||
adapter = _get_adapter(endpoint_config)
|
||||
cache = _get_cache()
|
||||
event_bus = _get_event_bus()
|
||||
redis_client = _get_redis_client()
|
||||
|
||||
result = _run_async(
|
||||
run_sweep_import()(
|
||||
db=db,
|
||||
experiment=experiment,
|
||||
adapter=adapter,
|
||||
cache=cache,
|
||||
event_bus=event_bus,
|
||||
redis_client=redis_client,
|
||||
max_concurrent=settings.max_concurrent_runs,
|
||||
max_tokens=settings.max_tokens_per_sweep,
|
||||
)
|
||||
)
|
||||
|
||||
return {
|
||||
"experiment_id": str(result.experiment_id),
|
||||
"total_runs": result.total_runs,
|
||||
"completed_runs": result.completed_runs,
|
||||
"failed_runs": result.failed_runs,
|
||||
"total_tokens_in": result.total_tokens_in,
|
||||
"total_tokens_out": result.total_tokens_out,
|
||||
"stopped_reason": result.stopped_reason,
|
||||
}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Celery tasks (registered via autodiscover)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
try:
|
||||
from worker import celery_app
|
||||
|
||||
@celery_app.task(name="engine.execute_run", bind=True, max_retries=0)
|
||||
def execute_run(self, run_id: str, endpoint_config: dict[str, Any] | None = None) -> dict[str, Any]:
|
||||
"""Celery task: execute a single Run by ID."""
|
||||
logger.info("Celery task execute_run started: run_id=%s", run_id)
|
||||
return _do_execute_run(run_id, endpoint_config)
|
||||
|
||||
@celery_app.task(name="engine.execute_sweep", bind=True, max_retries=0)
|
||||
def execute_sweep(
|
||||
self,
|
||||
experiment_id: str,
|
||||
sweep_config: dict[str, Any] | None = None,
|
||||
endpoint_config: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Celery task: execute a full sweep for an experiment."""
|
||||
logger.info("Celery task execute_sweep started: experiment_id=%s", experiment_id)
|
||||
return _do_execute_sweep(experiment_id, sweep_config, endpoint_config)
|
||||
|
||||
except ImportError:
|
||||
# Celery not available — tasks will only be usable via synchronous fallback
|
||||
execute_run = None
|
||||
execute_sweep = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Synchronous fallback for single-container mode (no Redis/Celery)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class SyncTaskResult:
|
||||
"""Mimics Celery AsyncResult for the synchronous fallback."""
|
||||
|
||||
def __init__(self, result: Any = None, task_id: str | None = None, error: Exception | None = None):
|
||||
self.result = result
|
||||
self.id = task_id or str(uuid.uuid4())
|
||||
self.status = "SUCCESS" if error is None else "FAILURE"
|
||||
self._error = error
|
||||
|
||||
def get(self, timeout: float | None = None) -> Any:
|
||||
if self._error:
|
||||
raise self._error
|
||||
return self.result
|
||||
|
||||
@property
|
||||
def state(self) -> str:
|
||||
return self.status
|
||||
|
||||
def ready(self) -> bool:
|
||||
return True
|
||||
|
||||
def successful(self) -> bool:
|
||||
return self.status == "SUCCESS"
|
||||
|
||||
def failed(self) -> bool:
|
||||
return self.status == "FAILURE"
|
||||
|
||||
|
||||
def dispatch_run(run_id: str, endpoint_config: dict[str, Any] | None = None) -> Any:
|
||||
"""Dispatch a run execution — Celery if available, synchronous otherwise."""
|
||||
if not settings.use_in_process_queue and execute_run is not None:
|
||||
return execute_run.delay(run_id, endpoint_config)
|
||||
|
||||
# Synchronous fallback
|
||||
logger.info("Sync fallback: executing run %s in-process", run_id)
|
||||
try:
|
||||
result = _do_execute_run(run_id, endpoint_config)
|
||||
return SyncTaskResult(result=result)
|
||||
except Exception as exc:
|
||||
logger.error("Sync run %s failed: %s", run_id, exc, exc_info=True)
|
||||
return SyncTaskResult(error=exc)
|
||||
|
||||
|
||||
def dispatch_sweep(
|
||||
experiment_id: str,
|
||||
sweep_config: dict[str, Any] | None = None,
|
||||
endpoint_config: dict[str, Any] | None = None,
|
||||
) -> Any:
|
||||
"""Dispatch a sweep execution — Celery if available, synchronous otherwise."""
|
||||
if not settings.use_in_process_queue and execute_sweep is not None:
|
||||
return execute_sweep.delay(experiment_id, sweep_config, endpoint_config)
|
||||
|
||||
# Synchronous fallback
|
||||
logger.info("Sync fallback: executing sweep for experiment %s in-process", experiment_id)
|
||||
try:
|
||||
result = _do_execute_sweep(experiment_id, sweep_config, endpoint_config)
|
||||
return SyncTaskResult(result=result)
|
||||
except Exception as exc:
|
||||
logger.error("Sync sweep %s failed: %s", experiment_id, exc, exc_info=True)
|
||||
return SyncTaskResult(error=exc)
|
||||
387
backend/tests/test_tasks.py
Normal file
387
backend/tests/test_tasks.py
Normal file
|
|
@ -0,0 +1,387 @@
|
|||
"""Tests for engine/tasks.py — Celery task definitions and sync fallback."""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Ensure backend is on sys.path
|
||||
_backend_dir = str(Path(__file__).resolve().parents[1])
|
||||
if _backend_dir not in sys.path:
|
||||
sys.path.insert(0, _backend_dir)
|
||||
|
||||
from engine.adapters.base import AdapterResponse, BaseAdapter
|
||||
from engine.cache import ResponseCacheLayer, compute_config_hash
|
||||
from engine.runner import EventBus, run_single
|
||||
from engine.tasks import (
|
||||
SyncTaskResult,
|
||||
_do_execute_run,
|
||||
_do_execute_sweep,
|
||||
dispatch_run,
|
||||
dispatch_sweep,
|
||||
)
|
||||
from models import (
|
||||
Base,
|
||||
Experiment,
|
||||
ExperimentStatus,
|
||||
Project,
|
||||
Run,
|
||||
RunStatus,
|
||||
User,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures / helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class MockAdapter(BaseAdapter):
|
||||
"""Test adapter that returns fixed responses."""
|
||||
|
||||
def __init__(self, text: str = "mock output", tokens_in: int = 10, tokens_out: int = 5):
|
||||
self._text = text
|
||||
self._tokens_in = tokens_in
|
||||
self._tokens_out = tokens_out
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "mock"
|
||||
|
||||
async def complete(self, prompt: str, model: str, params: dict) -> AdapterResponse:
|
||||
return AdapterResponse(
|
||||
text=self._text,
|
||||
tokens_in=self._tokens_in,
|
||||
tokens_out=self._tokens_out,
|
||||
latency_ms=42.0,
|
||||
)
|
||||
|
||||
async def list_models(self) -> list[str]:
|
||||
return ["mock-model"]
|
||||
|
||||
async def test_connection(self) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db_engine():
|
||||
engine = create_engine("sqlite:///:memory:")
|
||||
Base.metadata.create_all(engine)
|
||||
return engine
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db(db_engine):
|
||||
session = Session(db_engine)
|
||||
yield session
|
||||
session.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def user(db):
|
||||
u = User(username="testuser", password_hash="fakehash")
|
||||
db.add(u)
|
||||
db.commit()
|
||||
db.refresh(u)
|
||||
return u
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def project(db, user):
|
||||
p = Project(name="Test Project", owner_id=user.id)
|
||||
db.add(p)
|
||||
db.commit()
|
||||
db.refresh(p)
|
||||
return p
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def experiment(db, project):
|
||||
exp = Experiment(
|
||||
project_id=project.id,
|
||||
name="Test Experiment",
|
||||
pipeline_stages=[
|
||||
{"prompt_template": "Say hello", "model": "test-model", "params": {}}
|
||||
],
|
||||
parameter_space={"type": "grid", "params": {}},
|
||||
status=ExperimentStatus.draft,
|
||||
)
|
||||
db.add(exp)
|
||||
db.commit()
|
||||
db.refresh(exp)
|
||||
return exp
|
||||
|
||||
|
||||
def _make_run(db: Session, experiment_id: uuid.UUID, config: dict | None = None) -> Run:
|
||||
config = config or {
|
||||
"prompt": "Say hello",
|
||||
"model": "test-model",
|
||||
"params": {},
|
||||
"pipeline_stages": [
|
||||
{"prompt_template": "Say hello", "model": "test-model", "params": {}}
|
||||
],
|
||||
}
|
||||
config_hash = compute_config_hash(
|
||||
config.get("prompt", ""), config.get("model", ""), config.get("params", {})
|
||||
)
|
||||
run = Run(
|
||||
experiment_id=experiment_id,
|
||||
config_hash=config_hash,
|
||||
config=config,
|
||||
status=RunStatus.pending,
|
||||
)
|
||||
db.add(run)
|
||||
db.commit()
|
||||
db.refresh(run)
|
||||
return run
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SyncTaskResult tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSyncTaskResult:
|
||||
def test_successful_result(self):
|
||||
r = SyncTaskResult(result={"status": "ok"}, task_id="test-123")
|
||||
assert r.id == "test-123"
|
||||
assert r.status == "SUCCESS"
|
||||
assert r.state == "SUCCESS"
|
||||
assert r.ready() is True
|
||||
assert r.successful() is True
|
||||
assert r.failed() is False
|
||||
assert r.get() == {"status": "ok"}
|
||||
|
||||
def test_failed_result(self):
|
||||
err = ValueError("boom")
|
||||
r = SyncTaskResult(error=err)
|
||||
assert r.status == "FAILURE"
|
||||
assert r.ready() is True
|
||||
assert r.successful() is False
|
||||
assert r.failed() is True
|
||||
with pytest.raises(ValueError, match="boom"):
|
||||
r.get()
|
||||
|
||||
def test_auto_generates_task_id(self):
|
||||
r = SyncTaskResult(result="data")
|
||||
assert r.id is not None
|
||||
assert len(r.id) > 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _do_execute_run tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDoExecuteRun:
|
||||
def test_executes_run_successfully(self, db, experiment):
|
||||
run = _make_run(db, experiment.id)
|
||||
adapter = MockAdapter()
|
||||
cache = ResponseCacheLayer()
|
||||
event_bus = EventBus()
|
||||
|
||||
# Patch the internals to use our fixtures
|
||||
with patch("engine.tasks._get_db_session", return_value=db), \
|
||||
patch("engine.tasks._get_adapter", return_value=adapter), \
|
||||
patch("engine.tasks._get_cache", return_value=cache), \
|
||||
patch("engine.tasks._get_event_bus", return_value=event_bus):
|
||||
result = _do_execute_run(str(run.id))
|
||||
|
||||
assert result["run_id"] == str(run.id)
|
||||
assert result["status"] == "completed"
|
||||
assert result["tokens_in"] == 10
|
||||
assert result["tokens_out"] == 5
|
||||
assert result["duration_ms"] is not None
|
||||
|
||||
def test_raises_for_missing_run(self, db):
|
||||
fake_id = str(uuid.uuid4())
|
||||
with patch("engine.tasks._get_db_session", return_value=db), \
|
||||
pytest.raises(ValueError, match="not found"):
|
||||
_do_execute_run(fake_id)
|
||||
|
||||
def test_passes_endpoint_config(self, db, experiment):
|
||||
run = _make_run(db, experiment.id)
|
||||
adapter = MockAdapter()
|
||||
cache = ResponseCacheLayer()
|
||||
event_bus = EventBus()
|
||||
endpoint_cfg = {"url": "http://test:8080/v1", "api_key": "sk-test"}
|
||||
|
||||
with patch("engine.tasks._get_db_session", return_value=db), \
|
||||
patch("engine.tasks._get_adapter", return_value=adapter) as mock_get_adapter, \
|
||||
patch("engine.tasks._get_cache", return_value=cache), \
|
||||
patch("engine.tasks._get_event_bus", return_value=event_bus):
|
||||
_do_execute_run(str(run.id), endpoint_config=endpoint_cfg)
|
||||
mock_get_adapter.assert_called_once_with(endpoint_cfg)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _do_execute_sweep tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDoExecuteSweep:
|
||||
def test_executes_sweep_successfully(self, db, experiment):
|
||||
adapter = MockAdapter()
|
||||
cache = ResponseCacheLayer()
|
||||
event_bus = EventBus()
|
||||
|
||||
with patch("engine.tasks._get_db_session", return_value=db), \
|
||||
patch("engine.tasks._get_adapter", return_value=adapter), \
|
||||
patch("engine.tasks._get_cache", return_value=cache), \
|
||||
patch("engine.tasks._get_event_bus", return_value=event_bus), \
|
||||
patch("engine.tasks._get_redis_client", return_value=None):
|
||||
result = _do_execute_sweep(str(experiment.id))
|
||||
|
||||
assert result["experiment_id"] == str(experiment.id)
|
||||
assert result["total_runs"] >= 1
|
||||
assert result["completed_runs"] >= 1
|
||||
assert result["stopped_reason"] == "completed"
|
||||
|
||||
def test_raises_for_missing_experiment(self, db):
|
||||
fake_id = str(uuid.uuid4())
|
||||
with patch("engine.tasks._get_db_session", return_value=db), \
|
||||
pytest.raises(ValueError, match="not found"):
|
||||
_do_execute_sweep(fake_id)
|
||||
|
||||
def test_override_sweep_config(self, db, experiment):
|
||||
adapter = MockAdapter()
|
||||
cache = ResponseCacheLayer()
|
||||
event_bus = EventBus()
|
||||
custom_sweep = {"type": "random", "params": {"model": ["a", "b"]}, "n_trials": 2}
|
||||
|
||||
with patch("engine.tasks._get_db_session", return_value=db), \
|
||||
patch("engine.tasks._get_adapter", return_value=adapter), \
|
||||
patch("engine.tasks._get_cache", return_value=cache), \
|
||||
patch("engine.tasks._get_event_bus", return_value=event_bus), \
|
||||
patch("engine.tasks._get_redis_client", return_value=None):
|
||||
result = _do_execute_sweep(str(experiment.id), sweep_config=custom_sweep)
|
||||
|
||||
assert result["total_runs"] == 2
|
||||
assert result["completed_runs"] == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# dispatch_run / dispatch_sweep tests (sync fallback)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDispatchRunSync:
|
||||
def test_sync_fallback_success(self, db, experiment):
|
||||
run = _make_run(db, experiment.id)
|
||||
adapter = MockAdapter()
|
||||
cache = ResponseCacheLayer()
|
||||
event_bus = EventBus()
|
||||
|
||||
with patch("engine.tasks.settings") as mock_settings, \
|
||||
patch("engine.tasks._get_db_session", return_value=db), \
|
||||
patch("engine.tasks._get_adapter", return_value=adapter), \
|
||||
patch("engine.tasks._get_cache", return_value=cache), \
|
||||
patch("engine.tasks._get_event_bus", return_value=event_bus):
|
||||
mock_settings.use_in_process_queue = True
|
||||
task_result = dispatch_run(str(run.id))
|
||||
|
||||
assert task_result.successful()
|
||||
result = task_result.get()
|
||||
assert result["status"] == "completed"
|
||||
|
||||
def test_sync_fallback_error(self, db):
|
||||
fake_id = str(uuid.uuid4())
|
||||
|
||||
with patch("engine.tasks.settings") as mock_settings, \
|
||||
patch("engine.tasks._get_db_session", return_value=db):
|
||||
mock_settings.use_in_process_queue = True
|
||||
task_result = dispatch_run(fake_id)
|
||||
|
||||
assert task_result.failed()
|
||||
with pytest.raises(ValueError):
|
||||
task_result.get()
|
||||
|
||||
|
||||
class TestDispatchSweepSync:
|
||||
def test_sync_fallback_success(self, db, experiment):
|
||||
adapter = MockAdapter()
|
||||
cache = ResponseCacheLayer()
|
||||
event_bus = EventBus()
|
||||
|
||||
with patch("engine.tasks.settings") as mock_settings, \
|
||||
patch("engine.tasks._get_db_session", return_value=db), \
|
||||
patch("engine.tasks._get_adapter", return_value=adapter), \
|
||||
patch("engine.tasks._get_cache", return_value=cache), \
|
||||
patch("engine.tasks._get_event_bus", return_value=event_bus), \
|
||||
patch("engine.tasks._get_redis_client", return_value=None):
|
||||
mock_settings.use_in_process_queue = True
|
||||
mock_settings.max_concurrent_runs = 2
|
||||
mock_settings.max_tokens_per_sweep = 0
|
||||
task_result = dispatch_sweep(str(experiment.id))
|
||||
|
||||
assert task_result.successful()
|
||||
result = task_result.get()
|
||||
assert result["completed_runs"] >= 1
|
||||
|
||||
def test_sync_fallback_error(self, db):
|
||||
fake_id = str(uuid.uuid4())
|
||||
|
||||
with patch("engine.tasks.settings") as mock_settings, \
|
||||
patch("engine.tasks._get_db_session", return_value=db):
|
||||
mock_settings.use_in_process_queue = True
|
||||
task_result = dispatch_sweep(fake_id)
|
||||
|
||||
assert task_result.failed()
|
||||
with pytest.raises(ValueError):
|
||||
task_result.get()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Celery task registration tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCeleryTaskRegistration:
|
||||
def test_execute_run_task_exists(self):
|
||||
"""execute_run should be registered as a Celery task or None if Celery unavailable."""
|
||||
from engine.tasks import execute_run
|
||||
# In test environment, Celery may or may not be available
|
||||
# If available, it should be a Celery task with the correct name
|
||||
if execute_run is not None:
|
||||
assert execute_run.name == "engine.execute_run"
|
||||
|
||||
def test_execute_sweep_task_exists(self):
|
||||
from engine.tasks import execute_sweep
|
||||
if execute_sweep is not None:
|
||||
assert execute_sweep.name == "engine.execute_sweep"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# dispatch with Celery available
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDispatchWithCelery:
|
||||
def test_dispatch_run_uses_celery_when_redis_available(self):
|
||||
"""When use_in_process_queue is False and Celery tasks exist, should use .delay()."""
|
||||
mock_task = MagicMock()
|
||||
mock_task.delay.return_value = MagicMock(id="celery-task-id")
|
||||
|
||||
with patch("engine.tasks.settings") as mock_settings, \
|
||||
patch("engine.tasks.execute_run", mock_task):
|
||||
mock_settings.use_in_process_queue = False
|
||||
result = dispatch_run("some-run-id", endpoint_config={"url": "http://x"})
|
||||
|
||||
mock_task.delay.assert_called_once_with("some-run-id", {"url": "http://x"})
|
||||
|
||||
def test_dispatch_sweep_uses_celery_when_redis_available(self):
|
||||
mock_task = MagicMock()
|
||||
mock_task.delay.return_value = MagicMock(id="celery-task-id")
|
||||
|
||||
with patch("engine.tasks.settings") as mock_settings, \
|
||||
patch("engine.tasks.execute_sweep", mock_task):
|
||||
mock_settings.use_in_process_queue = False
|
||||
result = dispatch_sweep("some-exp-id", sweep_config={"type": "grid"})
|
||||
|
||||
mock_task.delay.assert_called_once_with("some-exp-id", {"type": "grid"}, None)
|
||||
Loading…
Add table
Reference in a new issue