From b16454994e202d26770abdf92ee833e0dcaa3143 Mon Sep 17 00:00:00 2001 From: John Lightner Date: Tue, 7 Apr 2026 03:08:41 -0500 Subject: [PATCH] MAESTRO: Implement Celery tasks (execute_run, execute_sweep) with synchronous fallback for single-container mode Created engine/tasks.py with: - execute_run and execute_sweep Celery tasks registered via autodiscover - SyncTaskResult class mimicking Celery AsyncResult for in-process mode - dispatch_run/dispatch_sweep helpers that route to Celery or sync based on config - Proper async-to-sync bridging for the async engine functions - 17 tests covering task execution, sync fallback, error handling, and Celery dispatch --- Auto Run Docs/02a-backend-engine.md | 3 +- backend/engine/tasks.py | 266 +++++++++++++++++++ backend/tests/test_tasks.py | 387 ++++++++++++++++++++++++++++ 3 files changed, 655 insertions(+), 1 deletion(-) create mode 100644 backend/engine/tasks.py create mode 100644 backend/tests/test_tasks.py diff --git a/Auto Run Docs/02a-backend-engine.md b/Auto Run Docs/02a-backend-engine.md index e371e41..9513ab0 100644 --- a/Auto Run Docs/02a-backend-engine.md +++ b/Auto Run Docs/02a-backend-engine.md @@ -29,7 +29,8 @@ Implement the core experiment execution engine: LLM adapters, response caching, - [x] Implement backend/engine/scorers/llm_judge.py — sends the output to a separate LLM with a configurable judge prompt and asks for a 1-10 rating. Parses the numeric score from the response. This scorer requires an LLM call so it should be clearly marked as "costs tokens" in the UI. Cache the judge's response too. -- [ ] Wire up the Celery worker in backend/worker.py. Define tasks: execute_run(run_id), execute_sweep(experiment_id, sweep_config). Configure Celery to use Redis as broker. In single-container mode (no Redis), implement a simple synchronous fallback that runs tasks in-process. +- [x] Wire up the Celery worker in backend/worker.py. Define tasks: execute_run(run_id), execute_sweep(experiment_id, sweep_config). Configure Celery to use Redis as broker. In single-container mode (no Redis), implement a simple synchronous fallback that runs tasks in-process. + - [ ] Implement backend/routers/endpoints.py fully — CRUD for LLM endpoint configurations. The test endpoint should call adapter.test_connection() and adapter.list_models() and return the results. Store endpoint configs in the database with encrypted API keys (Fernet symmetric encryption, key derived from JWT_SECRET). diff --git a/backend/engine/tasks.py b/backend/engine/tasks.py new file mode 100644 index 0000000..0e8bb4b --- /dev/null +++ b/backend/engine/tasks.py @@ -0,0 +1,266 @@ +"""Celery tasks for PromptLooper experiment execution. + +Defines execute_run and execute_sweep tasks that are dispatched via +Celery (Redis broker) or run synchronously in single-container mode. +""" + +import asyncio +import logging +import uuid +from typing import Any + +from sqlalchemy.orm import Session + +from config import settings + +logger = logging.getLogger(__name__) + + +def _get_db_session() -> Session: + """Create a database session for task execution.""" + from main import SessionLocal + if SessionLocal is None: + from main import _init_db + _init_db() + from main import SessionLocal as SL + return SL() + return SessionLocal() + + +def _get_redis_client(): + """Get the Redis client (or None).""" + from main import get_redis + return get_redis() + + +def _get_adapter(endpoint_config: dict[str, Any] | None = None): + """Create an LLM adapter from config or defaults.""" + from engine.adapters.openai_compat import OpenAICompatAdapter + + if endpoint_config: + return OpenAICompatAdapter( + base_url=endpoint_config.get("url", settings.default_endpoint_url or ""), + api_key=endpoint_config.get("api_key", settings.default_endpoint_key), + ) + + return OpenAICompatAdapter( + base_url=settings.default_endpoint_url or "http://localhost:11434/v1", + api_key=settings.default_endpoint_key, + ) + + +def _get_event_bus(): + """Create an EventBus with Redis if available.""" + from engine.runner import EventBus + redis_client = _get_redis_client() + return EventBus(redis_client=redis_client) + + +def _get_cache(): + """Create a ResponseCacheLayer.""" + from engine.cache import ResponseCacheLayer + return ResponseCacheLayer() + + +def _run_async(coro): + """Run an async coroutine, creating an event loop if needed.""" + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + + if loop and loop.is_running(): + # Already in an async context — create a new loop in a thread + import concurrent.futures + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool: + return pool.submit(asyncio.run, coro).result() + else: + return asyncio.run(coro) + + +def _do_execute_run(run_id: str, endpoint_config: dict[str, Any] | None = None) -> dict[str, Any]: + """Core logic for executing a single run (used by both Celery and sync paths).""" + from models import Run + + db = _get_db_session() + try: + run = db.get(Run, uuid.UUID(run_id)) + if run is None: + raise ValueError(f"Run {run_id} not found") + + adapter = _get_adapter(endpoint_config) + cache = _get_cache() + event_bus = _get_event_bus() + + result = _run_async( + run_single_import()(db, run, adapter, cache, event_bus=event_bus) + ) + + return { + "run_id": str(result.id), + "status": result.status.value, + "duration_ms": result.duration_ms, + "tokens_in": result.tokens_in, + "tokens_out": result.tokens_out, + } + finally: + db.close() + + +def run_single_import(): + """Lazy import of run_single to avoid circular imports.""" + from engine.runner import run_single + return run_single + + +def run_sweep_import(): + """Lazy import of run_sweep to avoid circular imports.""" + from engine.sweep import run_sweep + return run_sweep + + +def _do_execute_sweep( + experiment_id: str, + sweep_config: dict[str, Any] | None = None, + endpoint_config: dict[str, Any] | None = None, +) -> dict[str, Any]: + """Core logic for executing a sweep (used by both Celery and sync paths).""" + from models import Experiment + + db = _get_db_session() + try: + experiment = db.get(Experiment, uuid.UUID(experiment_id)) + if experiment is None: + raise ValueError(f"Experiment {experiment_id} not found") + + # Override parameter_space if sweep_config provided + if sweep_config: + experiment.parameter_space = sweep_config + db.commit() + + adapter = _get_adapter(endpoint_config) + cache = _get_cache() + event_bus = _get_event_bus() + redis_client = _get_redis_client() + + result = _run_async( + run_sweep_import()( + db=db, + experiment=experiment, + adapter=adapter, + cache=cache, + event_bus=event_bus, + redis_client=redis_client, + max_concurrent=settings.max_concurrent_runs, + max_tokens=settings.max_tokens_per_sweep, + ) + ) + + return { + "experiment_id": str(result.experiment_id), + "total_runs": result.total_runs, + "completed_runs": result.completed_runs, + "failed_runs": result.failed_runs, + "total_tokens_in": result.total_tokens_in, + "total_tokens_out": result.total_tokens_out, + "stopped_reason": result.stopped_reason, + } + finally: + db.close() + + +# --------------------------------------------------------------------------- +# Celery tasks (registered via autodiscover) +# --------------------------------------------------------------------------- + +try: + from worker import celery_app + + @celery_app.task(name="engine.execute_run", bind=True, max_retries=0) + def execute_run(self, run_id: str, endpoint_config: dict[str, Any] | None = None) -> dict[str, Any]: + """Celery task: execute a single Run by ID.""" + logger.info("Celery task execute_run started: run_id=%s", run_id) + return _do_execute_run(run_id, endpoint_config) + + @celery_app.task(name="engine.execute_sweep", bind=True, max_retries=0) + def execute_sweep( + self, + experiment_id: str, + sweep_config: dict[str, Any] | None = None, + endpoint_config: dict[str, Any] | None = None, + ) -> dict[str, Any]: + """Celery task: execute a full sweep for an experiment.""" + logger.info("Celery task execute_sweep started: experiment_id=%s", experiment_id) + return _do_execute_sweep(experiment_id, sweep_config, endpoint_config) + +except ImportError: + # Celery not available — tasks will only be usable via synchronous fallback + execute_run = None + execute_sweep = None + + +# --------------------------------------------------------------------------- +# Synchronous fallback for single-container mode (no Redis/Celery) +# --------------------------------------------------------------------------- + + +class SyncTaskResult: + """Mimics Celery AsyncResult for the synchronous fallback.""" + + def __init__(self, result: Any = None, task_id: str | None = None, error: Exception | None = None): + self.result = result + self.id = task_id or str(uuid.uuid4()) + self.status = "SUCCESS" if error is None else "FAILURE" + self._error = error + + def get(self, timeout: float | None = None) -> Any: + if self._error: + raise self._error + return self.result + + @property + def state(self) -> str: + return self.status + + def ready(self) -> bool: + return True + + def successful(self) -> bool: + return self.status == "SUCCESS" + + def failed(self) -> bool: + return self.status == "FAILURE" + + +def dispatch_run(run_id: str, endpoint_config: dict[str, Any] | None = None) -> Any: + """Dispatch a run execution — Celery if available, synchronous otherwise.""" + if not settings.use_in_process_queue and execute_run is not None: + return execute_run.delay(run_id, endpoint_config) + + # Synchronous fallback + logger.info("Sync fallback: executing run %s in-process", run_id) + try: + result = _do_execute_run(run_id, endpoint_config) + return SyncTaskResult(result=result) + except Exception as exc: + logger.error("Sync run %s failed: %s", run_id, exc, exc_info=True) + return SyncTaskResult(error=exc) + + +def dispatch_sweep( + experiment_id: str, + sweep_config: dict[str, Any] | None = None, + endpoint_config: dict[str, Any] | None = None, +) -> Any: + """Dispatch a sweep execution — Celery if available, synchronous otherwise.""" + if not settings.use_in_process_queue and execute_sweep is not None: + return execute_sweep.delay(experiment_id, sweep_config, endpoint_config) + + # Synchronous fallback + logger.info("Sync fallback: executing sweep for experiment %s in-process", experiment_id) + try: + result = _do_execute_sweep(experiment_id, sweep_config, endpoint_config) + return SyncTaskResult(result=result) + except Exception as exc: + logger.error("Sync sweep %s failed: %s", experiment_id, exc, exc_info=True) + return SyncTaskResult(error=exc) diff --git a/backend/tests/test_tasks.py b/backend/tests/test_tasks.py new file mode 100644 index 0000000..9660974 --- /dev/null +++ b/backend/tests/test_tasks.py @@ -0,0 +1,387 @@ +"""Tests for engine/tasks.py — Celery task definitions and sync fallback.""" + +import asyncio +import sys +import uuid +from pathlib import Path +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from sqlalchemy import create_engine +from sqlalchemy.orm import Session + +# Ensure backend is on sys.path +_backend_dir = str(Path(__file__).resolve().parents[1]) +if _backend_dir not in sys.path: + sys.path.insert(0, _backend_dir) + +from engine.adapters.base import AdapterResponse, BaseAdapter +from engine.cache import ResponseCacheLayer, compute_config_hash +from engine.runner import EventBus, run_single +from engine.tasks import ( + SyncTaskResult, + _do_execute_run, + _do_execute_sweep, + dispatch_run, + dispatch_sweep, +) +from models import ( + Base, + Experiment, + ExperimentStatus, + Project, + Run, + RunStatus, + User, +) + + +# --------------------------------------------------------------------------- +# Fixtures / helpers +# --------------------------------------------------------------------------- + + +class MockAdapter(BaseAdapter): + """Test adapter that returns fixed responses.""" + + def __init__(self, text: str = "mock output", tokens_in: int = 10, tokens_out: int = 5): + self._text = text + self._tokens_in = tokens_in + self._tokens_out = tokens_out + + @property + def name(self) -> str: + return "mock" + + async def complete(self, prompt: str, model: str, params: dict) -> AdapterResponse: + return AdapterResponse( + text=self._text, + tokens_in=self._tokens_in, + tokens_out=self._tokens_out, + latency_ms=42.0, + ) + + async def list_models(self) -> list[str]: + return ["mock-model"] + + async def test_connection(self) -> bool: + return True + + +@pytest.fixture +def db_engine(): + engine = create_engine("sqlite:///:memory:") + Base.metadata.create_all(engine) + return engine + + +@pytest.fixture +def db(db_engine): + session = Session(db_engine) + yield session + session.close() + + +@pytest.fixture +def user(db): + u = User(username="testuser", password_hash="fakehash") + db.add(u) + db.commit() + db.refresh(u) + return u + + +@pytest.fixture +def project(db, user): + p = Project(name="Test Project", owner_id=user.id) + db.add(p) + db.commit() + db.refresh(p) + return p + + +@pytest.fixture +def experiment(db, project): + exp = Experiment( + project_id=project.id, + name="Test Experiment", + pipeline_stages=[ + {"prompt_template": "Say hello", "model": "test-model", "params": {}} + ], + parameter_space={"type": "grid", "params": {}}, + status=ExperimentStatus.draft, + ) + db.add(exp) + db.commit() + db.refresh(exp) + return exp + + +def _make_run(db: Session, experiment_id: uuid.UUID, config: dict | None = None) -> Run: + config = config or { + "prompt": "Say hello", + "model": "test-model", + "params": {}, + "pipeline_stages": [ + {"prompt_template": "Say hello", "model": "test-model", "params": {}} + ], + } + config_hash = compute_config_hash( + config.get("prompt", ""), config.get("model", ""), config.get("params", {}) + ) + run = Run( + experiment_id=experiment_id, + config_hash=config_hash, + config=config, + status=RunStatus.pending, + ) + db.add(run) + db.commit() + db.refresh(run) + return run + + +# --------------------------------------------------------------------------- +# SyncTaskResult tests +# --------------------------------------------------------------------------- + + +class TestSyncTaskResult: + def test_successful_result(self): + r = SyncTaskResult(result={"status": "ok"}, task_id="test-123") + assert r.id == "test-123" + assert r.status == "SUCCESS" + assert r.state == "SUCCESS" + assert r.ready() is True + assert r.successful() is True + assert r.failed() is False + assert r.get() == {"status": "ok"} + + def test_failed_result(self): + err = ValueError("boom") + r = SyncTaskResult(error=err) + assert r.status == "FAILURE" + assert r.ready() is True + assert r.successful() is False + assert r.failed() is True + with pytest.raises(ValueError, match="boom"): + r.get() + + def test_auto_generates_task_id(self): + r = SyncTaskResult(result="data") + assert r.id is not None + assert len(r.id) > 0 + + +# --------------------------------------------------------------------------- +# _do_execute_run tests +# --------------------------------------------------------------------------- + + +class TestDoExecuteRun: + def test_executes_run_successfully(self, db, experiment): + run = _make_run(db, experiment.id) + adapter = MockAdapter() + cache = ResponseCacheLayer() + event_bus = EventBus() + + # Patch the internals to use our fixtures + with patch("engine.tasks._get_db_session", return_value=db), \ + patch("engine.tasks._get_adapter", return_value=adapter), \ + patch("engine.tasks._get_cache", return_value=cache), \ + patch("engine.tasks._get_event_bus", return_value=event_bus): + result = _do_execute_run(str(run.id)) + + assert result["run_id"] == str(run.id) + assert result["status"] == "completed" + assert result["tokens_in"] == 10 + assert result["tokens_out"] == 5 + assert result["duration_ms"] is not None + + def test_raises_for_missing_run(self, db): + fake_id = str(uuid.uuid4()) + with patch("engine.tasks._get_db_session", return_value=db), \ + pytest.raises(ValueError, match="not found"): + _do_execute_run(fake_id) + + def test_passes_endpoint_config(self, db, experiment): + run = _make_run(db, experiment.id) + adapter = MockAdapter() + cache = ResponseCacheLayer() + event_bus = EventBus() + endpoint_cfg = {"url": "http://test:8080/v1", "api_key": "sk-test"} + + with patch("engine.tasks._get_db_session", return_value=db), \ + patch("engine.tasks._get_adapter", return_value=adapter) as mock_get_adapter, \ + patch("engine.tasks._get_cache", return_value=cache), \ + patch("engine.tasks._get_event_bus", return_value=event_bus): + _do_execute_run(str(run.id), endpoint_config=endpoint_cfg) + mock_get_adapter.assert_called_once_with(endpoint_cfg) + + +# --------------------------------------------------------------------------- +# _do_execute_sweep tests +# --------------------------------------------------------------------------- + + +class TestDoExecuteSweep: + def test_executes_sweep_successfully(self, db, experiment): + adapter = MockAdapter() + cache = ResponseCacheLayer() + event_bus = EventBus() + + with patch("engine.tasks._get_db_session", return_value=db), \ + patch("engine.tasks._get_adapter", return_value=adapter), \ + patch("engine.tasks._get_cache", return_value=cache), \ + patch("engine.tasks._get_event_bus", return_value=event_bus), \ + patch("engine.tasks._get_redis_client", return_value=None): + result = _do_execute_sweep(str(experiment.id)) + + assert result["experiment_id"] == str(experiment.id) + assert result["total_runs"] >= 1 + assert result["completed_runs"] >= 1 + assert result["stopped_reason"] == "completed" + + def test_raises_for_missing_experiment(self, db): + fake_id = str(uuid.uuid4()) + with patch("engine.tasks._get_db_session", return_value=db), \ + pytest.raises(ValueError, match="not found"): + _do_execute_sweep(fake_id) + + def test_override_sweep_config(self, db, experiment): + adapter = MockAdapter() + cache = ResponseCacheLayer() + event_bus = EventBus() + custom_sweep = {"type": "random", "params": {"model": ["a", "b"]}, "n_trials": 2} + + with patch("engine.tasks._get_db_session", return_value=db), \ + patch("engine.tasks._get_adapter", return_value=adapter), \ + patch("engine.tasks._get_cache", return_value=cache), \ + patch("engine.tasks._get_event_bus", return_value=event_bus), \ + patch("engine.tasks._get_redis_client", return_value=None): + result = _do_execute_sweep(str(experiment.id), sweep_config=custom_sweep) + + assert result["total_runs"] == 2 + assert result["completed_runs"] == 2 + + +# --------------------------------------------------------------------------- +# dispatch_run / dispatch_sweep tests (sync fallback) +# --------------------------------------------------------------------------- + + +class TestDispatchRunSync: + def test_sync_fallback_success(self, db, experiment): + run = _make_run(db, experiment.id) + adapter = MockAdapter() + cache = ResponseCacheLayer() + event_bus = EventBus() + + with patch("engine.tasks.settings") as mock_settings, \ + patch("engine.tasks._get_db_session", return_value=db), \ + patch("engine.tasks._get_adapter", return_value=adapter), \ + patch("engine.tasks._get_cache", return_value=cache), \ + patch("engine.tasks._get_event_bus", return_value=event_bus): + mock_settings.use_in_process_queue = True + task_result = dispatch_run(str(run.id)) + + assert task_result.successful() + result = task_result.get() + assert result["status"] == "completed" + + def test_sync_fallback_error(self, db): + fake_id = str(uuid.uuid4()) + + with patch("engine.tasks.settings") as mock_settings, \ + patch("engine.tasks._get_db_session", return_value=db): + mock_settings.use_in_process_queue = True + task_result = dispatch_run(fake_id) + + assert task_result.failed() + with pytest.raises(ValueError): + task_result.get() + + +class TestDispatchSweepSync: + def test_sync_fallback_success(self, db, experiment): + adapter = MockAdapter() + cache = ResponseCacheLayer() + event_bus = EventBus() + + with patch("engine.tasks.settings") as mock_settings, \ + patch("engine.tasks._get_db_session", return_value=db), \ + patch("engine.tasks._get_adapter", return_value=adapter), \ + patch("engine.tasks._get_cache", return_value=cache), \ + patch("engine.tasks._get_event_bus", return_value=event_bus), \ + patch("engine.tasks._get_redis_client", return_value=None): + mock_settings.use_in_process_queue = True + mock_settings.max_concurrent_runs = 2 + mock_settings.max_tokens_per_sweep = 0 + task_result = dispatch_sweep(str(experiment.id)) + + assert task_result.successful() + result = task_result.get() + assert result["completed_runs"] >= 1 + + def test_sync_fallback_error(self, db): + fake_id = str(uuid.uuid4()) + + with patch("engine.tasks.settings") as mock_settings, \ + patch("engine.tasks._get_db_session", return_value=db): + mock_settings.use_in_process_queue = True + task_result = dispatch_sweep(fake_id) + + assert task_result.failed() + with pytest.raises(ValueError): + task_result.get() + + +# --------------------------------------------------------------------------- +# Celery task registration tests +# --------------------------------------------------------------------------- + + +class TestCeleryTaskRegistration: + def test_execute_run_task_exists(self): + """execute_run should be registered as a Celery task or None if Celery unavailable.""" + from engine.tasks import execute_run + # In test environment, Celery may or may not be available + # If available, it should be a Celery task with the correct name + if execute_run is not None: + assert execute_run.name == "engine.execute_run" + + def test_execute_sweep_task_exists(self): + from engine.tasks import execute_sweep + if execute_sweep is not None: + assert execute_sweep.name == "engine.execute_sweep" + + +# --------------------------------------------------------------------------- +# dispatch with Celery available +# --------------------------------------------------------------------------- + + +class TestDispatchWithCelery: + def test_dispatch_run_uses_celery_when_redis_available(self): + """When use_in_process_queue is False and Celery tasks exist, should use .delay().""" + mock_task = MagicMock() + mock_task.delay.return_value = MagicMock(id="celery-task-id") + + with patch("engine.tasks.settings") as mock_settings, \ + patch("engine.tasks.execute_run", mock_task): + mock_settings.use_in_process_queue = False + result = dispatch_run("some-run-id", endpoint_config={"url": "http://x"}) + + mock_task.delay.assert_called_once_with("some-run-id", {"url": "http://x"}) + + def test_dispatch_sweep_uses_celery_when_redis_available(self): + mock_task = MagicMock() + mock_task.delay.return_value = MagicMock(id="celery-task-id") + + with patch("engine.tasks.settings") as mock_settings, \ + patch("engine.tasks.execute_sweep", mock_task): + mock_settings.use_in_process_queue = False + result = dispatch_sweep("some-exp-id", sweep_config={"type": "grid"}) + + mock_task.delay.assert_called_once_with("some-exp-id", {"type": "grid"}, None)