Created engine/tasks.py with:
- execute_run and execute_sweep Celery tasks registered via autodiscover
- SyncTaskResult class mimicking Celery AsyncResult for in-process mode
- dispatch_run/dispatch_sweep helpers that route to Celery or sync based on config
- Proper async-to-sync bridging for the async engine functions
- 17 tests covering task execution, sync fallback, error handling, and Celery dispatch
266 lines
8.6 KiB
Python
266 lines
8.6 KiB
Python
"""Celery tasks for PromptLooper experiment execution.
|
|
|
|
Defines execute_run and execute_sweep tasks that are dispatched via
|
|
Celery (Redis broker) or run synchronously in single-container mode.
|
|
"""
|
|
|
|
import asyncio
|
|
import logging
|
|
import uuid
|
|
from typing import Any
|
|
|
|
from sqlalchemy.orm import Session
|
|
|
|
from config import settings
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _get_db_session() -> Session:
    """Open a new database session for a task.

    Uses ``main.SessionLocal`` as the session factory. When that factory is
    still ``None``, ``main._init_db()`` is called first and the factory is
    imported again before a session is created.

    Returns:
        A freshly created session; the caller is responsible for closing it.
    """
    # Imported lazily, at call time, rather than at module import.
    from main import SessionLocal as session_factory

    if session_factory is not None:
        return session_factory()

    # The factory has not been set up yet: initialise, then re-read the name
    # from ``main`` because the binding imported above is still ``None``.
    from main import _init_db as initialise_db

    initialise_db()

    from main import SessionLocal as fresh_factory

    return fresh_factory()
|
|
|
|
|
|
def _get_redis_client():
    """Return whatever ``main.get_redis()`` yields (a Redis client, or None)."""
    # Lazy import: resolved when the helper is called, not at module import.
    from main import get_redis as redis_factory

    client = redis_factory()
    return client
|
|
|
|
|
|
def _get_adapter(endpoint_config: dict[str, Any] | None = None):
    """Build an ``OpenAICompatAdapter`` from an endpoint config or settings.

    Args:
        endpoint_config: Optional mapping; the keys ``"url"`` and
            ``"api_key"`` are read from it. A missing or empty mapping means
            the adapter is built purely from ``settings``.

    Returns:
        A new ``OpenAICompatAdapter`` instance.
    """
    from engine.adapters.openai_compat import OpenAICompatAdapter

    if not endpoint_config:
        # No per-call endpoint: use the configured default, or the local
        # fallback URL when no default is configured.
        return OpenAICompatAdapter(
            base_url=settings.default_endpoint_url or "http://localhost:11434/v1",
            api_key=settings.default_endpoint_key,
        )

    # NOTE(review): when "url" is absent here the fallback is
    # ``settings.default_endpoint_url or ""`` rather than the localhost URL
    # used above -- confirm that the difference is intentional.
    chosen_url = endpoint_config.get("url", settings.default_endpoint_url or "")
    chosen_key = endpoint_config.get("api_key", settings.default_endpoint_key)
    return OpenAICompatAdapter(base_url=chosen_url, api_key=chosen_key)
|
|
|
|
|
|
def _get_event_bus():
    """Construct an ``EventBus`` around the client from ``_get_redis_client``.

    The client is passed through as ``redis_client`` unchanged, including when
    it is ``None``.
    """
    from engine.runner import EventBus

    return EventBus(redis_client=_get_redis_client())
|
|
|
|
|
|
def _get_cache():
    """Return a new ``ResponseCacheLayer`` built with its default arguments."""
    from engine.cache import ResponseCacheLayer as cache_cls

    return cache_cls()
|
|
|
|
|
|
def _run_async(coro):
    """Drive *coro* to completion from synchronous code and return its result.

    With no event loop running in the current thread the coroutine is handed
    straight to ``asyncio.run``. When a loop is already running here,
    ``asyncio.run`` cannot be called in this thread, so the coroutine is run
    on a single worker thread and this call blocks until it finishes.

    Args:
        coro: The coroutine object to execute.

    Returns:
        Whatever the coroutine returns; its exceptions propagate to the caller.
    """
    try:
        active_loop = asyncio.get_running_loop()
    except RuntimeError:
        # No loop is running in this thread: the simple path applies.
        return asyncio.run(coro)

    if not active_loop.is_running():
        return asyncio.run(coro)

    # Already inside an async context: run a fresh loop on a helper thread.
    import concurrent.futures

    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
        pending = executor.submit(asyncio.run, coro)
        return pending.result()
|
|
|
|
|
|
def _do_execute_run(run_id: str, endpoint_config: dict[str, Any] | None = None) -> dict[str, Any]:
    """Execute one run by ID; shared by the Celery task and the sync path.

    Args:
        run_id: String form of the run's UUID.
        endpoint_config: Optional endpoint settings forwarded to
            ``_get_adapter``.

    Returns:
        A dict with the keys ``run_id``, ``status``, ``duration_ms``,
        ``tokens_in`` and ``tokens_out`` taken from the finished run.

    Raises:
        ValueError: If ``run_id`` is not a valid UUID string, or no run with
            that ID exists.
    """
    from models import Run

    session = _get_db_session()
    try:
        record = session.get(Run, uuid.UUID(run_id))
        if record is None:
            raise ValueError(f"Run {run_id} not found")

        llm_adapter = _get_adapter(endpoint_config)
        response_cache = _get_cache()
        bus = _get_event_bus()

        # Build the coroutine first, then bridge it into synchronous code.
        run_single = run_single_import()
        pending = run_single(session, record, llm_adapter, response_cache, event_bus=bus)
        finished = _run_async(pending)

        summary: dict[str, Any] = {
            "run_id": str(finished.id),
            "status": finished.status.value,
        }
        # The remaining fields are copied across under their own names.
        for field in ("duration_ms", "tokens_in", "tokens_out"):
            summary[field] = getattr(finished, field)
        return summary
    finally:
        # Always release the session, whether the run succeeded or raised.
        session.close()
|
|
|
|
|
|
def run_single_import():
    """Return ``engine.runner.run_single``, imported at call time.

    The import is deferred to avoid a circular import at module load.
    """
    from engine.runner import run_single as runner_fn

    return runner_fn
|
|
|
|
|
|
def run_sweep_import():
    """Return ``engine.sweep.run_sweep``, imported at call time.

    The import is deferred to avoid a circular import at module load.
    """
    from engine.sweep import run_sweep as sweep_fn

    return sweep_fn
|
|
|
|
|
|
def _do_execute_sweep(
    experiment_id: str,
    sweep_config: dict[str, Any] | None = None,
    endpoint_config: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Execute a sweep for an experiment; shared by Celery and the sync path.

    Args:
        experiment_id: String form of the experiment's UUID.
        sweep_config: When truthy, replaces the experiment's
            ``parameter_space`` and is committed before the sweep starts.
        endpoint_config: Optional endpoint settings forwarded to
            ``_get_adapter``.

    Returns:
        A dict with the keys ``experiment_id``, ``total_runs``,
        ``completed_runs``, ``failed_runs``, ``total_tokens_in``,
        ``total_tokens_out`` and ``stopped_reason`` from the sweep result.

    Raises:
        ValueError: If ``experiment_id`` is not a valid UUID string, or no
            experiment with that ID exists.
    """
    from models import Experiment

    session = _get_db_session()
    try:
        target = session.get(Experiment, uuid.UUID(experiment_id))
        if target is None:
            raise ValueError(f"Experiment {experiment_id} not found")

        # A supplied sweep config overrides the stored parameter space.
        if sweep_config:
            target.parameter_space = sweep_config
            session.commit()

        llm_adapter = _get_adapter(endpoint_config)
        response_cache = _get_cache()
        bus = _get_event_bus()
        redis_conn = _get_redis_client()

        # Build the coroutine first, then bridge it into synchronous code.
        sweep_fn = run_sweep_import()
        pending = sweep_fn(
            db=session,
            experiment=target,
            adapter=llm_adapter,
            cache=response_cache,
            event_bus=bus,
            redis_client=redis_conn,
            max_concurrent=settings.max_concurrent_runs,
            max_tokens=settings.max_tokens_per_sweep,
        )
        outcome = _run_async(pending)

        report: dict[str, Any] = {"experiment_id": str(outcome.experiment_id)}
        # The remaining fields are copied across under their own names.
        for field in (
            "total_runs",
            "completed_runs",
            "failed_runs",
            "total_tokens_in",
            "total_tokens_out",
            "stopped_reason",
        ):
            report[field] = getattr(outcome, field)
        return report
    finally:
        # Always release the session, whether the sweep succeeded or raised.
        session.close()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Celery tasks (registered via autodiscover)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
try:
    from worker import celery_app

    @celery_app.task(name="engine.execute_run", bind=True, max_retries=0)
    def execute_run(
        self,
        run_id: str,
        endpoint_config: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Celery entry point: log the start, then run a single Run by ID.

        The work itself is delegated to ``_do_execute_run``; its summary dict
        is returned as the task result.
        """
        logger.info("Celery task execute_run started: run_id=%s", run_id)
        summary = _do_execute_run(run_id, endpoint_config)
        return summary

    @celery_app.task(name="engine.execute_sweep", bind=True, max_retries=0)
    def execute_sweep(
        self,
        experiment_id: str,
        sweep_config: dict[str, Any] | None = None,
        endpoint_config: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Celery entry point: log the start, then run a sweep for an experiment.

        The work itself is delegated to ``_do_execute_sweep``; its summary
        dict is returned as the task result.
        """
        logger.info("Celery task execute_sweep started: experiment_id=%s", experiment_id)
        report = _do_execute_sweep(experiment_id, sweep_config, endpoint_config)
        return report

except ImportError:
    # The Celery app could not be imported: leave both task names defined but
    # empty so the dispatch helpers can detect this and run synchronously.
    execute_run = execute_sweep = None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Synchronous fallback for single-container mode (no Redis/Celery)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class SyncTaskResult:
|
|
"""Mimics Celery AsyncResult for the synchronous fallback."""
|
|
|
|
def __init__(self, result: Any = None, task_id: str | None = None, error: Exception | None = None):
|
|
self.result = result
|
|
self.id = task_id or str(uuid.uuid4())
|
|
self.status = "SUCCESS" if error is None else "FAILURE"
|
|
self._error = error
|
|
|
|
def get(self, timeout: float | None = None) -> Any:
|
|
if self._error:
|
|
raise self._error
|
|
return self.result
|
|
|
|
@property
|
|
def state(self) -> str:
|
|
return self.status
|
|
|
|
def ready(self) -> bool:
|
|
return True
|
|
|
|
def successful(self) -> bool:
|
|
return self.status == "SUCCESS"
|
|
|
|
def failed(self) -> bool:
|
|
return self.status == "FAILURE"
|
|
|
|
|
|
def dispatch_run(run_id: str, endpoint_config: dict[str, Any] | None = None) -> Any:
    """Start a run through Celery when possible, otherwise run it in-process.

    Args:
        run_id: String form of the run's UUID.
        endpoint_config: Optional endpoint settings passed through to the run.

    Returns:
        The value returned by ``execute_run.delay`` when the Celery path is
        taken; otherwise a ``SyncTaskResult`` carrying either the run summary
        or the exception raised while executing it.
    """
    if settings.use_in_process_queue or execute_run is None:
        # In-process mode was requested, or the Celery task is unavailable.
        logger.info("Sync fallback: executing run %s in-process", run_id)
        try:
            summary = _do_execute_run(run_id, endpoint_config)
            return SyncTaskResult(result=summary)
        except Exception as exc:
            # Failures are captured in the result object instead of raised, so
            # the caller inspects it the same way as a Celery result.
            logger.error("Sync run %s failed: %s", run_id, exc, exc_info=True)
            return SyncTaskResult(error=exc)

    return execute_run.delay(run_id, endpoint_config)
|
|
|
|
|
|
def dispatch_sweep(
    experiment_id: str,
    sweep_config: dict[str, Any] | None = None,
    endpoint_config: dict[str, Any] | None = None,
) -> Any:
    """Start a sweep through Celery when possible, otherwise run it in-process.

    Args:
        experiment_id: String form of the experiment's UUID.
        sweep_config: Optional parameter-space override passed to the sweep.
        endpoint_config: Optional endpoint settings passed to the sweep.

    Returns:
        The value returned by ``execute_sweep.delay`` when the Celery path is
        taken; otherwise a ``SyncTaskResult`` carrying either the sweep
        summary or the exception raised while executing it.
    """
    if settings.use_in_process_queue or execute_sweep is None:
        # In-process mode was requested, or the Celery task is unavailable.
        logger.info("Sync fallback: executing sweep for experiment %s in-process", experiment_id)
        try:
            report = _do_execute_sweep(experiment_id, sweep_config, endpoint_config)
            return SyncTaskResult(result=report)
        except Exception as exc:
            # Failures are captured in the result object instead of raised, so
            # the caller inspects it the same way as a Celery result.
            logger.error("Sync sweep %s failed: %s", experiment_id, exc, exc_info=True)
            return SyncTaskResult(error=exc)

    return execute_sweep.delay(experiment_id, sweep_config, endpoint_config)
|