MAESTRO: Implement Celery tasks (execute_run, execute_sweep) with synchronous fallback for single-container mode

Created engine/tasks.py with:
- execute_run and execute_sweep Celery tasks registered via autodiscover
- SyncTaskResult class mimicking Celery AsyncResult for in-process mode
- dispatch_run/dispatch_sweep helpers that route to Celery or sync based on config
- Proper async-to-sync bridging for the async engine functions
- 17 tests covering task execution, sync fallback, error handling, and Celery dispatch
This commit is contained in:
John Lightner 2026-04-07 03:08:41 -05:00
parent 16c56b13f2
commit b16454994e
3 changed files with 655 additions and 1 deletions

View file

@ -29,7 +29,8 @@ Implement the core experiment execution engine: LLM adapters, response caching,
- [x] Implement backend/engine/scorers/llm_judge.py — sends the output to a separate LLM with a configurable judge prompt and asks for a 1-10 rating. Parses the numeric score from the response. This scorer requires an LLM call so it should be clearly marked as "costs tokens" in the UI. Cache the judge's response too. - [x] Implement backend/engine/scorers/llm_judge.py — sends the output to a separate LLM with a configurable judge prompt and asks for a 1-10 rating. Parses the numeric score from the response. This scorer requires an LLM call so it should be clearly marked as "costs tokens" in the UI. Cache the judge's response too.
<!-- Completed: LLMJudgeScorer with configurable judge prompt, 1-10 rating parsing via regex, normalized to 0.0-1.0. COSTS_TOKENS class marker for UI. Optional ResponseCacheLayer integration for caching judge responses. Retries with exponential backoff. 36 tests in test_scorer_llm_judge.py, all passing. --> <!-- Completed: LLMJudgeScorer with configurable judge prompt, 1-10 rating parsing via regex, normalized to 0.0-1.0. COSTS_TOKENS class marker for UI. Optional ResponseCacheLayer integration for caching judge responses. Retries with exponential backoff. 36 tests in test_scorer_llm_judge.py, all passing. -->
- [ ] Wire up the Celery worker in backend/worker.py. Define tasks: execute_run(run_id), execute_sweep(experiment_id, sweep_config). Configure Celery to use Redis as broker. In single-container mode (no Redis), implement a simple synchronous fallback that runs tasks in-process. - [x] Wire up the Celery worker in backend/worker.py. Define tasks: execute_run(run_id), execute_sweep(experiment_id, sweep_config). Configure Celery to use Redis as broker. In single-container mode (no Redis), implement a simple synchronous fallback that runs tasks in-process.
<!-- Completed: Created engine/tasks.py with execute_run and execute_sweep Celery tasks (autodiscovered via worker.py). SyncTaskResult class mimics AsyncResult for fallback. dispatch_run/dispatch_sweep helpers route to Celery or sync execution based on settings.use_in_process_queue. 17 tests in test_tasks.py, all passing. -->
- [ ] Implement backend/routers/endpoints.py fully — CRUD for LLM endpoint configurations. The test endpoint should call adapter.test_connection() and adapter.list_models() and return the results. Store endpoint configs in the database with encrypted API keys (Fernet symmetric encryption, key derived from JWT_SECRET). - [ ] Implement backend/routers/endpoints.py fully — CRUD for LLM endpoint configurations. The test endpoint should call adapter.test_connection() and adapter.list_models() and return the results. Store endpoint configs in the database with encrypted API keys (Fernet symmetric encryption, key derived from JWT_SECRET).

266
backend/engine/tasks.py Normal file
View file

@ -0,0 +1,266 @@
"""Celery tasks for PromptLooper experiment execution.
Defines execute_run and execute_sweep tasks that are dispatched via
Celery (Redis broker) or run synchronously in single-container mode.
"""
import asyncio
import logging
import uuid
from typing import Any
from sqlalchemy.orm import Session
from config import settings
logger = logging.getLogger(__name__)
def _get_db_session() -> Session:
    """Open a new database session for a task to work with.

    The session factory is read from ``main.SessionLocal`` at call time.
    If that factory is still ``None``, ``main._init_db()`` is called and
    the factory is imported again before a session is created.

    Returns:
        A new session produced by the ``main.SessionLocal`` factory.
    """
    from main import SessionLocal as session_factory

    if session_factory is not None:
        return session_factory()

    from main import _init_db

    _init_db()
    # Re-import so the local name picks up whatever main.SessionLocal is now;
    # the binding taken above still refers to the earlier ``None``.
    from main import SessionLocal as initialised_factory

    return initialised_factory()
def _get_redis_client():
    """Return whatever ``main.get_redis()`` yields (a Redis client or None)."""
    from main import get_redis as redis_provider

    client = redis_provider()
    return client
def _get_adapter(endpoint_config: dict[str, Any] | None = None):
    """Build an OpenAI-compatible LLM adapter from overrides or settings.

    Args:
        endpoint_config: Optional per-call overrides. The keys read are
            ``"url"`` and ``"api_key"``; anything else is ignored.

    Returns:
        An ``OpenAICompatAdapter`` instance.

    The base URL is resolved as: override ``url`` ->
    ``settings.default_endpoint_url`` -> ``http://localhost:11434/v1``.
    Previously a config without a usable ``url`` produced an empty (or
    ``None``) base URL instead of reaching the localhost default that the
    no-config path uses; both paths now share one fallback chain.
    """
    from engine.adapters.openai_compat import OpenAICompatAdapter

    overrides = endpoint_config or {}
    base_url = (
        overrides.get("url")
        or settings.default_endpoint_url
        or "http://localhost:11434/v1"
    )
    # An explicitly supplied api_key (even None) wins over the default key,
    # exactly as before; only a missing key falls back to settings.
    api_key = overrides.get("api_key", settings.default_endpoint_key)
    return OpenAICompatAdapter(base_url=base_url, api_key=api_key)
def _get_event_bus():
    """Build an ``EventBus`` wired to the current Redis client.

    The client comes from ``_get_redis_client()`` and is passed through
    unchanged, so it may be ``None``.
    """
    from engine.runner import EventBus

    return EventBus(redis_client=_get_redis_client())
def _get_cache():
    """Return a new ``ResponseCacheLayer`` built with its default arguments."""
    from engine.cache import ResponseCacheLayer as cache_cls

    layer = cache_cls()
    return layer
def _run_async(coro):
    """Drive *coro* to completion from synchronous code and return its result.

    With no event loop running in the calling thread the coroutine is run
    directly via ``asyncio.run``. If a loop is already running here,
    ``asyncio.run`` is executed on a single worker thread instead and the
    caller blocks until that thread finishes.
    """
    try:
        active_loop = asyncio.get_running_loop()
    except RuntimeError:
        active_loop = None

    if active_loop is None or not active_loop.is_running():
        return asyncio.run(coro)

    # A loop is live in this thread, so asyncio.run() cannot be used here;
    # give the coroutine its own loop on a worker thread.
    from concurrent.futures import ThreadPoolExecutor

    with ThreadPoolExecutor(max_workers=1) as executor:
        pending = executor.submit(asyncio.run, coro)
        return pending.result()
def _do_execute_run(run_id: str, endpoint_config: dict[str, Any] | None = None) -> dict[str, Any]:
    """Execute one run and summarise it (shared by the Celery and sync paths).

    Args:
        run_id: String form of the run's UUID.
        endpoint_config: Optional adapter overrides, handed to
            ``_get_adapter`` unchanged.

    Returns:
        A dict with ``run_id``, ``status``, ``duration_ms``, ``tokens_in``
        and ``tokens_out`` taken from the object ``run_single`` returns.

    Raises:
        ValueError: If no run with that ID exists (``uuid.UUID`` also raises
            ``ValueError`` for a malformed ID string).
    """
    from models import Run

    session = _get_db_session()
    try:
        run_record = session.get(Run, uuid.UUID(run_id))
        if run_record is None:
            raise ValueError(f"Run {run_id} not found")

        llm_adapter = _get_adapter(endpoint_config)
        response_cache = _get_cache()
        bus = _get_event_bus()

        run_single = run_single_import()
        finished = _run_async(
            run_single(session, run_record, llm_adapter, response_cache, event_bus=bus)
        )

        summary: dict[str, Any] = {}
        summary["run_id"] = str(finished.id)
        summary["status"] = finished.status.value
        summary["duration_ms"] = finished.duration_ms
        summary["tokens_in"] = finished.tokens_in
        summary["tokens_out"] = finished.tokens_out
        return summary
    finally:
        # The session is always released, whether the run succeeded or not.
        session.close()
def run_single_import():
    """Return ``engine.runner.run_single``, imported lazily.

    The import happens at call time rather than module import time to
    avoid circular imports.
    """
    from engine.runner import run_single
    return run_single
def run_sweep_import():
    """Return ``engine.sweep.run_sweep``, imported lazily.

    The import happens at call time rather than module import time to
    avoid circular imports.
    """
    from engine.sweep import run_sweep
    return run_sweep
def _do_execute_sweep(
    experiment_id: str,
    sweep_config: dict[str, Any] | None = None,
    endpoint_config: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Execute a sweep and summarise it (shared by the Celery and sync paths).

    Args:
        experiment_id: String form of the experiment's UUID.
        sweep_config: When truthy, replaces ``experiment.parameter_space``
            and is committed before the sweep starts.
        endpoint_config: Optional adapter overrides, handed to
            ``_get_adapter`` unchanged.

    Returns:
        A dict with ``experiment_id``, the run and token counters, and
        ``stopped_reason``, copied from the object ``run_sweep`` returns.

    Raises:
        ValueError: If no experiment with that ID exists (``uuid.UUID`` also
            raises ``ValueError`` for a malformed ID string).
    """
    from models import Experiment

    session = _get_db_session()
    try:
        target = session.get(Experiment, uuid.UUID(experiment_id))
        if target is None:
            raise ValueError(f"Experiment {experiment_id} not found")

        # Override parameter_space if sweep_config provided
        if sweep_config:
            target.parameter_space = sweep_config
            session.commit()

        sweep_kwargs = {
            "db": session,
            "experiment": target,
            "adapter": _get_adapter(endpoint_config),
            "cache": _get_cache(),
            "event_bus": _get_event_bus(),
            "redis_client": _get_redis_client(),
            "max_concurrent": settings.max_concurrent_runs,
            "max_tokens": settings.max_tokens_per_sweep,
        }
        outcome = _run_async(run_sweep_import()(**sweep_kwargs))

        copied_fields = (
            "total_runs",
            "completed_runs",
            "failed_runs",
            "total_tokens_in",
            "total_tokens_out",
            "stopped_reason",
        )
        report: dict[str, Any] = {"experiment_id": str(outcome.experiment_id)}
        for field_name in copied_fields:
            report[field_name] = getattr(outcome, field_name)
        return report
    finally:
        # The session is always released, whether the sweep succeeded or not.
        session.close()
# ---------------------------------------------------------------------------
# Celery tasks (registered via autodiscover)
# ---------------------------------------------------------------------------
try:
    from worker import celery_app

    @celery_app.task(name="engine.execute_run", bind=True, max_retries=0)
    def execute_run(self, run_id: str, endpoint_config: dict[str, Any] | None = None) -> dict[str, Any]:
        """Celery task: execute a single Run by ID.

        Delegates to ``_do_execute_run`` and returns its summary dict.
        Registered with ``max_retries=0``.
        """
        logger.info("Celery task execute_run started: run_id=%s", run_id)
        return _do_execute_run(run_id, endpoint_config)

    @celery_app.task(name="engine.execute_sweep", bind=True, max_retries=0)
    def execute_sweep(
        self,
        experiment_id: str,
        sweep_config: dict[str, Any] | None = None,
        endpoint_config: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Celery task: execute a full sweep for an experiment.

        Delegates to ``_do_execute_sweep`` and returns its summary dict.
        Registered with ``max_retries=0``.
        """
        logger.info("Celery task execute_sweep started: experiment_id=%s", experiment_id)
        return _do_execute_sweep(experiment_id, sweep_config, endpoint_config)

except ImportError as exc:
    # Celery not available — tasks will only be usable via synchronous fallback.
    # Record why, so a broken worker import is not mistaken for an intentional
    # single-container deployment.
    logger.info("Celery tasks not registered (%s); using synchronous fallback only", exc)
    execute_run = None
    execute_sweep = None
# ---------------------------------------------------------------------------
# Synchronous fallback for single-container mode (no Redis/Celery)
# ---------------------------------------------------------------------------
class SyncTaskResult:
"""Mimics Celery AsyncResult for the synchronous fallback."""
def __init__(self, result: Any = None, task_id: str | None = None, error: Exception | None = None):
self.result = result
self.id = task_id or str(uuid.uuid4())
self.status = "SUCCESS" if error is None else "FAILURE"
self._error = error
def get(self, timeout: float | None = None) -> Any:
if self._error:
raise self._error
return self.result
@property
def state(self) -> str:
return self.status
def ready(self) -> bool:
return True
def successful(self) -> bool:
return self.status == "SUCCESS"
def failed(self) -> bool:
return self.status == "FAILURE"
def dispatch_run(run_id: str, endpoint_config: dict[str, Any] | None = None) -> Any:
    """Start a run through Celery when possible, otherwise run it in-process.

    Returns:
        The handle from ``execute_run.delay(...)`` on the Celery path, or a
        ``SyncTaskResult`` holding the outcome (or the caught exception)
        on the synchronous path.
    """
    if settings.use_in_process_queue or execute_run is None:
        # Synchronous fallback
        logger.info("Sync fallback: executing run %s in-process", run_id)
        try:
            outcome = _do_execute_run(run_id, endpoint_config)
        except Exception as exc:
            # Failures are captured in the result object instead of raised,
            # so callers handle both dispatch modes the same way.
            logger.exception("Sync run %s failed: %s", run_id, exc)
            return SyncTaskResult(error=exc)
        return SyncTaskResult(result=outcome)

    return execute_run.delay(run_id, endpoint_config)
def dispatch_sweep(
    experiment_id: str,
    sweep_config: dict[str, Any] | None = None,
    endpoint_config: dict[str, Any] | None = None,
) -> Any:
    """Start a sweep through Celery when possible, otherwise run it in-process.

    Returns:
        The handle from ``execute_sweep.delay(...)`` on the Celery path, or
        a ``SyncTaskResult`` holding the outcome (or the caught exception)
        on the synchronous path.
    """
    if settings.use_in_process_queue or execute_sweep is None:
        # Synchronous fallback
        logger.info("Sync fallback: executing sweep for experiment %s in-process", experiment_id)
        try:
            outcome = _do_execute_sweep(experiment_id, sweep_config, endpoint_config)
        except Exception as exc:
            # Failures are captured in the result object instead of raised,
            # so callers handle both dispatch modes the same way.
            logger.exception("Sync sweep %s failed: %s", experiment_id, exc)
            return SyncTaskResult(error=exc)
        return SyncTaskResult(result=outcome)

    return execute_sweep.delay(experiment_id, sweep_config, endpoint_config)

387
backend/tests/test_tasks.py Normal file
View file

@ -0,0 +1,387 @@
"""Tests for engine/tasks.py — Celery task definitions and sync fallback."""
import asyncio
import sys
import uuid
from pathlib import Path
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import Session
# Ensure backend is on sys.path
_backend_dir = str(Path(__file__).resolve().parents[1])
if _backend_dir not in sys.path:
sys.path.insert(0, _backend_dir)
from engine.adapters.base import AdapterResponse, BaseAdapter
from engine.cache import ResponseCacheLayer, compute_config_hash
from engine.runner import EventBus, run_single
from engine.tasks import (
SyncTaskResult,
_do_execute_run,
_do_execute_sweep,
dispatch_run,
dispatch_sweep,
)
from models import (
Base,
Experiment,
ExperimentStatus,
Project,
Run,
RunStatus,
User,
)
# ---------------------------------------------------------------------------
# Fixtures / helpers
# ---------------------------------------------------------------------------
class MockAdapter(BaseAdapter):
    """Adapter stub whose completion text and token counts are fixed up front."""

    def __init__(self, text: str = "mock output", tokens_in: int = 10, tokens_out: int = 5):
        self._text, self._tokens_in, self._tokens_out = text, tokens_in, tokens_out

    @property
    def name(self) -> str:
        return "mock"

    async def complete(self, prompt: str, model: str, params: dict) -> AdapterResponse:
        # The arguments are ignored; every call yields the same canned reply.
        canned = {
            "text": self._text,
            "tokens_in": self._tokens_in,
            "tokens_out": self._tokens_out,
            "latency_ms": 42.0,
        }
        return AdapterResponse(**canned)

    async def list_models(self) -> list[str]:
        available = ["mock-model"]
        return available

    async def test_connection(self) -> bool:
        return True
@pytest.fixture
def db_engine():
    """In-memory SQLite engine with every ``Base`` table created."""
    memory_engine = create_engine("sqlite:///:memory:")
    Base.metadata.create_all(memory_engine)
    return memory_engine
@pytest.fixture
def db(db_engine):
    """Yield a session bound to ``db_engine`` and close it after the test."""
    test_session = Session(db_engine)
    yield test_session
    test_session.close()
@pytest.fixture
def user(db):
    """Persist and return a throwaway ``User`` row."""
    account = User(username="testuser", password_hash="fakehash")
    db.add(account)
    db.commit()
    # Reload so attributes populated on insert are available to the test.
    db.refresh(account)
    return account
@pytest.fixture
def project(db, user):
    """Persist and return a ``Project`` owned by the ``user`` fixture."""
    owned_project = Project(name="Test Project", owner_id=user.id)
    db.add(owned_project)
    db.commit()
    db.refresh(owned_project)
    return owned_project
@pytest.fixture
def experiment(db, project):
    """Persist and return a draft ``Experiment`` with one pipeline stage."""
    single_stage = {"prompt_template": "Say hello", "model": "test-model", "params": {}}
    empty_grid = {"type": "grid", "params": {}}

    draft = Experiment(
        project_id=project.id,
        name="Test Experiment",
        pipeline_stages=[single_stage],
        parameter_space=empty_grid,
        status=ExperimentStatus.draft,
    )
    db.add(draft)
    db.commit()
    db.refresh(draft)
    return draft
def _make_run(db: Session, experiment_id: uuid.UUID, config: dict | None = None) -> Run:
    """Persist and return a pending ``Run`` for *experiment_id*.

    A falsy *config* is replaced by a default single-stage "Say hello"
    configuration. The run's ``config_hash`` is computed from the config's
    ``prompt``, ``model`` and ``params`` entries.
    """
    if not config:
        config = {
            "prompt": "Say hello",
            "model": "test-model",
            "params": {},
            "pipeline_stages": [
                {"prompt_template": "Say hello", "model": "test-model", "params": {}}
            ],
        }

    prompt_text = config.get("prompt", "")
    model_name = config.get("model", "")
    model_params = config.get("params", {})
    digest = compute_config_hash(prompt_text, model_name, model_params)

    pending_run = Run(
        experiment_id=experiment_id,
        config_hash=digest,
        config=config,
        status=RunStatus.pending,
    )
    db.add(pending_run)
    db.commit()
    db.refresh(pending_run)
    return pending_run
# ---------------------------------------------------------------------------
# SyncTaskResult tests
# ---------------------------------------------------------------------------
class TestSyncTaskResult:
    """Behaviour of the in-process stand-in for a Celery result handle."""

    def test_successful_result(self):
        payload = {"status": "ok"}
        outcome = SyncTaskResult(result=payload, task_id="test-123")

        assert outcome.id == "test-123"
        assert outcome.status == "SUCCESS"
        assert outcome.state == "SUCCESS"
        assert outcome.ready() is True
        assert outcome.successful() is True
        assert outcome.failed() is False
        assert outcome.get() == {"status": "ok"}

    def test_failed_result(self):
        outcome = SyncTaskResult(error=ValueError("boom"))

        assert outcome.status == "FAILURE"
        assert outcome.ready() is True
        assert outcome.successful() is False
        assert outcome.failed() is True
        # get() must surface the stored exception rather than return a value.
        with pytest.raises(ValueError, match="boom"):
            outcome.get()

    def test_auto_generates_task_id(self):
        generated_id = SyncTaskResult(result="data").id

        assert generated_id is not None
        assert len(generated_id) > 0
# ---------------------------------------------------------------------------
# _do_execute_run tests
# ---------------------------------------------------------------------------
class TestDoExecuteRun:
    """``_do_execute_run`` with its collaborator factories patched out."""

    def test_executes_run_successfully(self, db, experiment):
        pending = _make_run(db, experiment.id)

        # Patch the internals to use our fixtures
        with (
            patch("engine.tasks._get_db_session", return_value=db),
            patch("engine.tasks._get_adapter", return_value=MockAdapter()),
            patch("engine.tasks._get_cache", return_value=ResponseCacheLayer()),
            patch("engine.tasks._get_event_bus", return_value=EventBus()),
        ):
            summary = _do_execute_run(str(pending.id))

        assert summary["run_id"] == str(pending.id)
        assert summary["status"] == "completed"
        # Token counts are the MockAdapter defaults.
        assert summary["tokens_in"] == 10
        assert summary["tokens_out"] == 5
        assert summary["duration_ms"] is not None

    def test_raises_for_missing_run(self, db):
        unknown_id = str(uuid.uuid4())

        with (
            patch("engine.tasks._get_db_session", return_value=db),
            pytest.raises(ValueError, match="not found"),
        ):
            _do_execute_run(unknown_id)

    def test_passes_endpoint_config(self, db, experiment):
        pending = _make_run(db, experiment.id)
        endpoint_cfg = {"url": "http://test:8080/v1", "api_key": "sk-test"}

        with (
            patch("engine.tasks._get_db_session", return_value=db),
            patch("engine.tasks._get_adapter", return_value=MockAdapter()) as adapter_factory,
            patch("engine.tasks._get_cache", return_value=ResponseCacheLayer()),
            patch("engine.tasks._get_event_bus", return_value=EventBus()),
        ):
            _do_execute_run(str(pending.id), endpoint_config=endpoint_cfg)

        adapter_factory.assert_called_once_with(endpoint_cfg)
# ---------------------------------------------------------------------------
# _do_execute_sweep tests
# ---------------------------------------------------------------------------
class TestDoExecuteSweep:
    """``_do_execute_sweep`` with its collaborator factories patched out."""

    def test_executes_sweep_successfully(self, db, experiment):
        with (
            patch("engine.tasks._get_db_session", return_value=db),
            patch("engine.tasks._get_adapter", return_value=MockAdapter()),
            patch("engine.tasks._get_cache", return_value=ResponseCacheLayer()),
            patch("engine.tasks._get_event_bus", return_value=EventBus()),
            patch("engine.tasks._get_redis_client", return_value=None),
        ):
            summary = _do_execute_sweep(str(experiment.id))

        assert summary["experiment_id"] == str(experiment.id)
        assert summary["total_runs"] >= 1
        assert summary["completed_runs"] >= 1
        assert summary["stopped_reason"] == "completed"

    def test_raises_for_missing_experiment(self, db):
        unknown_id = str(uuid.uuid4())

        with (
            patch("engine.tasks._get_db_session", return_value=db),
            pytest.raises(ValueError, match="not found"),
        ):
            _do_execute_sweep(unknown_id)

    def test_override_sweep_config(self, db, experiment):
        # Two random trials over two candidate models.
        custom_sweep = {"type": "random", "params": {"model": ["a", "b"]}, "n_trials": 2}

        with (
            patch("engine.tasks._get_db_session", return_value=db),
            patch("engine.tasks._get_adapter", return_value=MockAdapter()),
            patch("engine.tasks._get_cache", return_value=ResponseCacheLayer()),
            patch("engine.tasks._get_event_bus", return_value=EventBus()),
            patch("engine.tasks._get_redis_client", return_value=None),
        ):
            summary = _do_execute_sweep(str(experiment.id), sweep_config=custom_sweep)

        assert summary["total_runs"] == 2
        assert summary["completed_runs"] == 2
# ---------------------------------------------------------------------------
# dispatch_run / dispatch_sweep tests (sync fallback)
# ---------------------------------------------------------------------------
class TestDispatchRunSync:
    """``dispatch_run`` when the in-process queue is forced on."""

    def test_sync_fallback_success(self, db, experiment):
        pending = _make_run(db, experiment.id)

        with (
            patch("engine.tasks.settings") as fake_settings,
            patch("engine.tasks._get_db_session", return_value=db),
            patch("engine.tasks._get_adapter", return_value=MockAdapter()),
            patch("engine.tasks._get_cache", return_value=ResponseCacheLayer()),
            patch("engine.tasks._get_event_bus", return_value=EventBus()),
        ):
            fake_settings.use_in_process_queue = True
            handle = dispatch_run(str(pending.id))

        assert handle.successful()
        summary = handle.get()
        assert summary["status"] == "completed"

    def test_sync_fallback_error(self, db):
        unknown_id = str(uuid.uuid4())

        with (
            patch("engine.tasks.settings") as fake_settings,
            patch("engine.tasks._get_db_session", return_value=db),
        ):
            fake_settings.use_in_process_queue = True
            handle = dispatch_run(unknown_id)

        # The failure is carried by the handle and only raised by get().
        assert handle.failed()
        with pytest.raises(ValueError):
            handle.get()
class TestDispatchSweepSync:
    """``dispatch_sweep`` when the in-process queue is forced on."""

    def test_sync_fallback_success(self, db, experiment):
        with (
            patch("engine.tasks.settings") as fake_settings,
            patch("engine.tasks._get_db_session", return_value=db),
            patch("engine.tasks._get_adapter", return_value=MockAdapter()),
            patch("engine.tasks._get_cache", return_value=ResponseCacheLayer()),
            patch("engine.tasks._get_event_bus", return_value=EventBus()),
            patch("engine.tasks._get_redis_client", return_value=None),
        ):
            # settings is fully mocked, so the limits the sweep reads must be
            # given concrete values as well.
            fake_settings.use_in_process_queue = True
            fake_settings.max_concurrent_runs = 2
            fake_settings.max_tokens_per_sweep = 0
            handle = dispatch_sweep(str(experiment.id))

        assert handle.successful()
        summary = handle.get()
        assert summary["completed_runs"] >= 1

    def test_sync_fallback_error(self, db):
        unknown_id = str(uuid.uuid4())

        with (
            patch("engine.tasks.settings") as fake_settings,
            patch("engine.tasks._get_db_session", return_value=db),
        ):
            fake_settings.use_in_process_queue = True
            handle = dispatch_sweep(unknown_id)

        # The failure is carried by the handle and only raised by get().
        assert handle.failed()
        with pytest.raises(ValueError):
            handle.get()
# ---------------------------------------------------------------------------
# Celery task registration tests
# ---------------------------------------------------------------------------
class TestCeleryTaskRegistration:
    """Task names registered with Celery, when Celery could be imported.

    ``engine.tasks`` sets both task objects to ``None`` if its Celery import
    fails. In that case these tests are reported as skipped instead of
    passing without having asserted anything.
    """

    def test_execute_run_task_exists(self):
        """execute_run should be registered as a Celery task under its explicit name."""
        from engine.tasks import execute_run

        if execute_run is None:
            pytest.skip("Celery unavailable: execute_run was not registered")
        assert execute_run.name == "engine.execute_run"

    def test_execute_sweep_task_exists(self):
        """execute_sweep should be registered as a Celery task under its explicit name."""
        from engine.tasks import execute_sweep

        if execute_sweep is None:
            pytest.skip("Celery unavailable: execute_sweep was not registered")
        assert execute_sweep.name == "engine.execute_sweep"
# ---------------------------------------------------------------------------
# dispatch with Celery available
# ---------------------------------------------------------------------------
class TestDispatchWithCelery:
    """Dispatch helpers when the Celery path is selected.

    The task objects are replaced with mocks, so no broker is contacted;
    the tests check the arguments given to ``.delay()`` and that its return
    value is handed back to the caller unchanged.
    """

    def test_dispatch_run_uses_celery_when_redis_available(self):
        """When use_in_process_queue is False and Celery tasks exist, should use .delay()."""
        mock_task = MagicMock()
        mock_task.delay.return_value = MagicMock(id="celery-task-id")

        with patch("engine.tasks.settings") as mock_settings, \
                patch("engine.tasks.execute_run", mock_task):
            mock_settings.use_in_process_queue = False
            result = dispatch_run("some-run-id", endpoint_config={"url": "http://x"})

        mock_task.delay.assert_called_once_with("some-run-id", {"url": "http://x"})
        # The dispatcher must return the Celery handle itself, not a wrapper.
        assert result is mock_task.delay.return_value
        assert result.id == "celery-task-id"

    def test_dispatch_sweep_uses_celery_when_redis_available(self):
        """Sweep dispatch forwards all three positional arguments to .delay()."""
        mock_task = MagicMock()
        mock_task.delay.return_value = MagicMock(id="celery-task-id")

        with patch("engine.tasks.settings") as mock_settings, \
                patch("engine.tasks.execute_sweep", mock_task):
            mock_settings.use_in_process_queue = False
            result = dispatch_sweep("some-exp-id", sweep_config={"type": "grid"})

        # endpoint_config was not supplied, so it is forwarded as None.
        mock_task.delay.assert_called_once_with("some-exp-id", {"type": "grid"}, None)
        assert result is mock_task.delay.return_value
        assert result.id == "celery-task-id"