"""Celery tasks for PromptLooper experiment execution.

Defines execute_run and execute_sweep tasks that are dispatched via Celery
(Redis broker) or run synchronously in single-container mode.
"""

import asyncio
import logging
import uuid
from typing import Any

from sqlalchemy.orm import Session

from config import settings

logger = logging.getLogger(__name__)

# Endpoint used when neither the per-call config nor the settings name one.
_FALLBACK_ENDPOINT_URL = "http://localhost:11434/v1"


def _get_db_session() -> Session:
    """Create a database session for task execution.

    Returns:
        A new session produced by ``main.SessionLocal``.

    Raises:
        RuntimeError: If ``main.SessionLocal`` is still ``None`` after
            ``main._init_db()`` has been called.
    """
    # Imported lazily, like every other project import in this module.
    from main import SessionLocal

    if SessionLocal is not None:
        return SessionLocal()

    from main import _init_db

    _init_db()
    # The name bound above still refers to the old ``None``; re-import to pick
    # up whatever _init_db() assigned (presumably the configured factory).
    from main import SessionLocal as session_factory

    if session_factory is None:
        raise RuntimeError(
            "Database session factory is not initialised after _init_db()"
        )
    return session_factory()


def _get_redis_client():
    """Get the Redis client (or None)."""
    from main import get_redis

    return get_redis()


def _get_adapter(endpoint_config: dict[str, Any] | None = None):
    """Create an LLM adapter from config or defaults.

    The base URL is resolved in order: ``endpoint_config["url"]``,
    ``settings.default_endpoint_url``, then the built-in localhost fallback.
    Missing, ``None`` and empty values all fall through to the next source, so
    the adapter is never constructed with an empty base URL.

    Args:
        endpoint_config: Optional mapping that may carry ``"url"`` and
            ``"api_key"`` overrides.

    Returns:
        An ``OpenAICompatAdapter`` instance.
    """
    from engine.adapters.openai_compat import OpenAICompatAdapter

    overrides = endpoint_config or {}
    base_url = (
        overrides.get("url")
        or settings.default_endpoint_url
        or _FALLBACK_ENDPOINT_URL
    )
    # An explicitly supplied api_key (even None) wins over the settings value.
    api_key = overrides.get("api_key", settings.default_endpoint_key)
    return OpenAICompatAdapter(base_url=base_url, api_key=api_key)


def _get_event_bus():
    """Create an EventBus with Redis if available."""
    from engine.runner import EventBus

    return EventBus(redis_client=_get_redis_client())


def _get_cache():
    """Create a ResponseCacheLayer."""
    from engine.cache import ResponseCacheLayer

    return ResponseCacheLayer()


def _run_async(coro):
    """Run an async coroutine to completion and return its result.

    When no event loop is running in the current thread the coroutine is run
    with ``asyncio.run``. When a loop is already running, ``asyncio.run``
    cannot be used in this thread, so the coroutine is run on a fresh loop in
    a single worker thread and this call blocks until it finishes.

    Args:
        coro: The coroutine object to execute.

    Returns:
        Whatever the coroutine returns.
    """
    try:
        # Raises RuntimeError when this thread has no running loop.
        asyncio.get_running_loop()
    except RuntimeError:
        return asyncio.run(coro)

    # Already in an async context — create a new loop in a thread.
    import concurrent.futures

    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, coro).result()


def run_single_import():
    """Lazy import of run_single to avoid circular imports."""
    from engine.runner import run_single

    return run_single


def run_sweep_import():
    """Lazy import of run_sweep to avoid circular imports."""
    from engine.sweep import run_sweep

    return run_sweep


def _do_execute_run(
    run_id: str,
    endpoint_config: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Core logic for executing a single run (used by both Celery and sync paths).

    Args:
        run_id: String form of the Run's UUID primary key.
        endpoint_config: Optional adapter overrides, see ``_get_adapter``.

    Returns:
        A dict with ``run_id``, ``status``, ``duration_ms``, ``tokens_in``
        and ``tokens_out`` taken from the executed run.

    Raises:
        ValueError: If ``run_id`` is not a valid UUID string or no Run with
            that ID exists.
    """
    from models import Run

    db = _get_db_session()
    try:
        run = db.get(Run, uuid.UUID(run_id))
        if run is None:
            raise ValueError(f"Run {run_id} not found")

        adapter = _get_adapter(endpoint_config)
        cache = _get_cache()
        event_bus = _get_event_bus()

        run_single = run_single_import()
        result = _run_async(
            run_single(db, run, adapter, cache, event_bus=event_bus)
        )

        return {
            "run_id": str(result.id),
            "status": result.status.value,
            "duration_ms": result.duration_ms,
            "tokens_in": result.tokens_in,
            "tokens_out": result.tokens_out,
        }
    finally:
        # Always release the session, whether the run succeeded or raised.
        db.close()


def _do_execute_sweep(
    experiment_id: str,
    sweep_config: dict[str, Any] | None = None,
    endpoint_config: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Core logic for executing a sweep (used by both Celery and sync paths).

    Args:
        experiment_id: String form of the Experiment's UUID primary key.
        sweep_config: Optional replacement for the experiment's
            ``parameter_space``. When non-empty it is stored on the
            experiment and committed before the sweep starts.
        endpoint_config: Optional adapter overrides, see ``_get_adapter``.

    Returns:
        A dict with ``experiment_id``, ``total_runs``, ``completed_runs``,
        ``failed_runs``, ``total_tokens_in``, ``total_tokens_out`` and
        ``stopped_reason`` taken from the sweep result.

    Raises:
        ValueError: If ``experiment_id`` is not a valid UUID string or no
            Experiment with that ID exists.
    """
    from models import Experiment

    db = _get_db_session()
    try:
        experiment = db.get(Experiment, uuid.UUID(experiment_id))
        if experiment is None:
            raise ValueError(f"Experiment {experiment_id} not found")

        # Override parameter_space if sweep_config provided
        if sweep_config:
            experiment.parameter_space = sweep_config
            db.commit()

        adapter = _get_adapter(endpoint_config)
        cache = _get_cache()
        event_bus = _get_event_bus()
        redis_client = _get_redis_client()

        run_sweep = run_sweep_import()
        result = _run_async(
            run_sweep(
                db=db,
                experiment=experiment,
                adapter=adapter,
                cache=cache,
                event_bus=event_bus,
                redis_client=redis_client,
                max_concurrent=settings.max_concurrent_runs,
                max_tokens=settings.max_tokens_per_sweep,
            )
        )

        return {
            "experiment_id": str(result.experiment_id),
            "total_runs": result.total_runs,
            "completed_runs": result.completed_runs,
            "failed_runs": result.failed_runs,
            "total_tokens_in": result.total_tokens_in,
            "total_tokens_out": result.total_tokens_out,
            "stopped_reason": result.stopped_reason,
        }
    finally:
        # Always release the session, whether the sweep succeeded or raised.
        db.close()


# ---------------------------------------------------------------------------
# Celery tasks (registered via autodiscover)
# ---------------------------------------------------------------------------

try:
    from worker import celery_app

    @celery_app.task(name="engine.execute_run", bind=True, max_retries=0)
    def execute_run(
        self,
        run_id: str,
        endpoint_config: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Celery task: execute a single Run by ID."""
        logger.info("Celery task execute_run started: run_id=%s", run_id)
        return _do_execute_run(run_id, endpoint_config)

    @celery_app.task(name="engine.execute_sweep", bind=True, max_retries=0)
    def execute_sweep(
        self,
        experiment_id: str,
        sweep_config: dict[str, Any] | None = None,
        endpoint_config: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Celery task: execute a full sweep for an experiment."""
        logger.info("Celery task execute_sweep started: experiment_id=%s", experiment_id)
        return _do_execute_sweep(experiment_id, sweep_config, endpoint_config)

except ImportError:
    # Celery not available — tasks will only be usable via synchronous fallback
    execute_run = None
    execute_sweep = None


# ---------------------------------------------------------------------------
# Synchronous fallback for single-container mode (no Redis/Celery)
# ---------------------------------------------------------------------------


class SyncTaskResult:
    """Mimics Celery AsyncResult for the synchronous fallback.

    The work has already finished by the time an instance exists, so the
    result is always "ready" and is either a success or a failure.
    """

    def __init__(
        self,
        result: Any = None,
        task_id: str | None = None,
        error: Exception | None = None,
    ):
        """Store the outcome of an already-completed in-process task.

        Args:
            result: Value returned by the task, if it succeeded.
            task_id: Identifier to expose as ``id``; a random UUID string is
                generated when omitted.
            error: Exception raised by the task, if it failed.
        """
        self.result = result
        self.id = task_id or str(uuid.uuid4())
        self.status = "SUCCESS" if error is None else "FAILURE"
        self._error = error

    def get(self, timeout: float | None = None) -> Any:
        """Return the task result, re-raising the stored error on failure.

        Args:
            timeout: Accepted for signature compatibility; unused because the
                result is already available.
        """
        # Identity check (not truthiness) so this agrees with how ``status``
        # is derived, even for exception types that evaluate as falsy.
        if self._error is not None:
            raise self._error
        return self.result

    @property
    def state(self) -> str:
        """Alias for ``status``."""
        return self.status

    def ready(self) -> bool:
        """Always True: synchronous results are complete on creation."""
        return True

    def successful(self) -> bool:
        """Return True when the task finished without an error."""
        return self.status == "SUCCESS"

    def failed(self) -> bool:
        """Return True when the task raised an error."""
        return self.status == "FAILURE"


def dispatch_run(run_id: str, endpoint_config: dict[str, Any] | None = None) -> Any:
    """Dispatch a run execution — Celery if available, synchronous otherwise.

    Args:
        run_id: String form of the Run's UUID.
        endpoint_config: Optional adapter overrides.

    Returns:
        The object returned by ``execute_run.delay`` when dispatched through
        Celery, otherwise a ``SyncTaskResult``. In the synchronous path any
        exception is captured in the returned result instead of propagating.
    """
    if not settings.use_in_process_queue and execute_run is not None:
        return execute_run.delay(run_id, endpoint_config)

    # Synchronous fallback
    logger.info("Sync fallback: executing run %s in-process", run_id)
    try:
        result = _do_execute_run(run_id, endpoint_config)
    except Exception as exc:
        # Boundary handler: record the traceback and hand the error back to
        # the caller through the result object.
        logger.exception("Sync run %s failed: %s", run_id, exc)
        return SyncTaskResult(error=exc)
    return SyncTaskResult(result=result)


def dispatch_sweep(
    experiment_id: str,
    sweep_config: dict[str, Any] | None = None,
    endpoint_config: dict[str, Any] | None = None,
) -> Any:
    """Dispatch a sweep execution — Celery if available, synchronous otherwise.

    Args:
        experiment_id: String form of the Experiment's UUID.
        sweep_config: Optional replacement parameter space.
        endpoint_config: Optional adapter overrides.

    Returns:
        The object returned by ``execute_sweep.delay`` when dispatched
        through Celery, otherwise a ``SyncTaskResult``. In the synchronous
        path any exception is captured in the returned result instead of
        propagating.
    """
    if not settings.use_in_process_queue and execute_sweep is not None:
        return execute_sweep.delay(experiment_id, sweep_config, endpoint_config)

    # Synchronous fallback
    logger.info("Sync fallback: executing sweep for experiment %s in-process", experiment_id)
    try:
        result = _do_execute_sweep(experiment_id, sweep_config, endpoint_config)
    except Exception as exc:
        # Boundary handler: record the traceback and hand the error back to
        # the caller through the result object.
        logger.exception("Sync sweep %s failed: %s", experiment_id, exc)
        return SyncTaskResult(error=exc)
    return SyncTaskResult(result=result)