Full webhook system: CRUD endpoints (list/filter/get/create/update/delete), WebhookDelivery model for delivery audit trail, dispatch engine with 3-attempt retry and exponential backoff, Celery task integration with sync fallback, and webhook firing hooks in runner.py and sweep.py event paths.
304 lines
10 KiB
Python
304 lines
10 KiB
Python
"""Celery tasks for PromptLooper experiment execution.
|
|
|
|
Defines the execute_run, execute_sweep, and dispatch_webhooks tasks, which are
dispatched via Celery (Redis broker) or run synchronously in single-container mode.
|
|
"""
|
|
|
|
import asyncio
|
|
import logging
|
|
import uuid
|
|
from typing import Any
|
|
|
|
from sqlalchemy.orm import Session
|
|
|
|
from config import settings
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _get_db_session() -> Session:
    """Open a new database session for use inside a task.

    The session factory is read from ``main.SessionLocal`` at call time. When
    that name is still ``None``, ``main._init_db()`` is invoked and the factory
    is imported again before a session is created.

    Returns:
        A fresh session. Callers in this module close it in a ``finally``.
    """
    from main import SessionLocal as session_factory

    if session_factory is not None:
        return session_factory()

    # Factory not available yet: initialise, then re-import so we see whatever
    # ``main.SessionLocal`` is bound to after ``_init_db()`` has run.
    from main import _init_db

    _init_db()
    from main import SessionLocal as session_factory

    return session_factory()
|
|
|
|
|
|
def _get_redis_client():
    """Return the result of ``main.get_redis()`` (a Redis client, or None)."""
    # Imported lazily, like the other ``main`` lookups in this module.
    from main import get_redis as fetch_redis

    client = fetch_redis()
    return client
|
|
|
|
|
|
def _get_adapter(endpoint_config: dict[str, Any] | None = None):
    """Build an ``OpenAICompatAdapter`` from an endpoint config or from settings.

    Args:
        endpoint_config: Optional mapping. Its ``"url"`` and ``"api_key"``
            entries are used when present; missing entries fall back to
            ``settings.default_endpoint_url`` (or ``""``) and
            ``settings.default_endpoint_key``.

    Returns:
        The adapter instance. With no (or an empty) ``endpoint_config`` the
        adapter targets ``settings.default_endpoint_url`` or, if that is
        falsy, ``http://localhost:11434/v1``.
    """
    from engine.adapters.openai_compat import OpenAICompatAdapter

    if not endpoint_config:
        # No per-request endpoint: use configured defaults, with a local URL
        # as the last resort.
        return OpenAICompatAdapter(
            base_url=settings.default_endpoint_url or "http://localhost:11434/v1",
            api_key=settings.default_endpoint_key,
        )

    # NOTE(review): unlike the branch above, a config without "url" falls back
    # to "" rather than the localhost URL — confirm this asymmetry is intended.
    fallback_url = settings.default_endpoint_url or ""
    base_url = endpoint_config.get("url", fallback_url)
    api_key = endpoint_config.get("api_key", settings.default_endpoint_key)
    return OpenAICompatAdapter(base_url=base_url, api_key=api_key)
|
|
|
|
|
|
def _get_event_bus():
    """Build an ``EventBus`` wired to the current Redis client.

    The client comes from ``_get_redis_client()`` and may be ``None``; it is
    passed through to ``EventBus`` unchanged.
    """
    from engine.runner import EventBus

    return EventBus(redis_client=_get_redis_client())
|
|
|
|
|
|
def _do_dispatch_webhooks(event_type: str, payload: dict[str, Any]) -> dict[str, Any]:
    """Deliver webhooks for one event; shared by the Celery and sync paths.

    Args:
        event_type: Event name forwarded to ``engine.webhooks.dispatch_webhooks``.
        payload: Event payload forwarded unchanged.

    Returns:
        ``{"event_type": ..., "dispatched": ...}`` where ``dispatched`` is the
        value returned by ``dispatch_webhooks``.
    """
    from engine.webhooks import dispatch_webhooks as deliver

    session = _get_db_session()
    try:
        delivered = deliver(session, event_type, payload)
        summary = {"event_type": event_type, "dispatched": delivered}
    finally:
        # Always release the session, whether or not delivery raised.
        session.close()
    return summary
|
|
|
|
|
|
def _get_cache():
    """Return a new ``ResponseCacheLayer`` built with its default arguments."""
    from engine.cache import ResponseCacheLayer as cache_cls

    return cache_cls()
|
|
|
|
|
|
def _run_async(coro):
    """Run ``coro`` to completion from synchronous code and return its result.

    If the calling thread has no running event loop, the coroutine is run
    directly with ``asyncio.run``. If a loop is already running in this
    thread, ``asyncio.run`` is executed in a one-off worker thread instead and
    this call blocks until that thread finishes.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop in this thread: safe to start one here.
        return asyncio.run(coro)

    # A loop is already running here, and asyncio.run() refuses to nest, so
    # give the coroutine its own loop on a separate thread.
    from concurrent.futures import ThreadPoolExecutor

    with ThreadPoolExecutor(max_workers=1) as executor:
        pending = executor.submit(asyncio.run, coro)
        return pending.result()
|
|
|
|
|
|
def _do_execute_run(run_id: str, endpoint_config: dict[str, Any] | None = None) -> dict[str, Any]:
    """Execute one run by ID; shared by the Celery and sync paths.

    Args:
        run_id: String form of the run's UUID.
        endpoint_config: Optional endpoint overrides passed to ``_get_adapter``.

    Returns:
        A dict with ``run_id``, ``status``, ``duration_ms``, ``tokens_in`` and
        ``tokens_out`` read from the object returned by ``run_single``.

    Raises:
        ValueError: If ``run_id`` is not a valid UUID string or no matching
            ``Run`` row exists.
    """
    from models import Run

    session = _get_db_session()
    try:
        run_record = session.get(Run, uuid.UUID(run_id))
        if run_record is None:
            raise ValueError(f"Run {run_id} not found")

        adapter = _get_adapter(endpoint_config)
        cache = _get_cache()
        bus = _get_event_bus()

        run_single = run_single_import()
        finished = _run_async(run_single(session, run_record, adapter, cache, event_bus=bus))

        # Build the summary while the session is still open.
        summary = {"run_id": str(finished.id), "status": finished.status.value}
        for field in ("duration_ms", "tokens_in", "tokens_out"):
            summary[field] = getattr(finished, field)
    finally:
        session.close()
    return summary
|
|
|
|
|
|
def run_single_import():
    """Return ``engine.runner.run_single``, imported at call time.

    Deferring the import until the function is called avoids a circular
    import between this module and ``engine.runner``.
    """
    from engine.runner import run_single as runner_fn

    return runner_fn
|
|
|
|
|
|
def run_sweep_import():
    """Return ``engine.sweep.run_sweep``, imported at call time.

    Deferring the import until the function is called avoids a circular
    import between this module and ``engine.sweep``.
    """
    from engine.sweep import run_sweep as sweep_fn

    return sweep_fn
|
|
|
|
|
|
def _do_execute_sweep(
    experiment_id: str,
    sweep_config: dict[str, Any] | None = None,
    endpoint_config: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Execute a sweep for one experiment; shared by the Celery and sync paths.

    Args:
        experiment_id: String form of the experiment's UUID.
        sweep_config: When truthy, replaces ``experiment.parameter_space`` and
            is committed before the sweep starts.
        endpoint_config: Optional endpoint overrides passed to ``_get_adapter``.

    Returns:
        A dict with ``experiment_id``, ``total_runs``, ``completed_runs``,
        ``failed_runs``, ``total_tokens_in``, ``total_tokens_out`` and
        ``stopped_reason`` read from the object returned by ``run_sweep``.

    Raises:
        ValueError: If ``experiment_id`` is not a valid UUID string or no
            matching ``Experiment`` row exists.
    """
    from models import Experiment

    session = _get_db_session()
    try:
        target = session.get(Experiment, uuid.UUID(experiment_id))
        if target is None:
            raise ValueError(f"Experiment {experiment_id} not found")

        if sweep_config:
            # Caller-supplied config overrides the stored parameter space.
            target.parameter_space = sweep_config
            session.commit()

        adapter = _get_adapter(endpoint_config)
        cache = _get_cache()
        bus = _get_event_bus()
        redis_client = _get_redis_client()

        run_sweep = run_sweep_import()
        sweep_coro = run_sweep(
            db=session,
            experiment=target,
            adapter=adapter,
            cache=cache,
            event_bus=bus,
            redis_client=redis_client,
            max_concurrent=settings.max_concurrent_runs,
            max_tokens=settings.max_tokens_per_sweep,
        )
        outcome = _run_async(sweep_coro)

        # Build the summary while the session is still open.
        summary = {"experiment_id": str(outcome.experiment_id)}
        passthrough_fields = (
            "total_runs",
            "completed_runs",
            "failed_runs",
            "total_tokens_in",
            "total_tokens_out",
            "stopped_reason",
        )
        for field in passthrough_fields:
            summary[field] = getattr(outcome, field)
    finally:
        session.close()
    return summary
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Celery tasks (registered via autodiscover)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
try:
    from worker import celery_app

    @celery_app.task(name="engine.execute_run", bind=True, max_retries=0)
    def execute_run(self, run_id: str, endpoint_config: dict[str, Any] | None = None) -> dict[str, Any]:
        """Celery entry point: run a single Run identified by ``run_id``.

        Logs the start of the task and delegates to ``_do_execute_run``.
        Registered with ``max_retries=0``.
        """
        logger.info("Celery task execute_run started: run_id=%s", run_id)
        outcome = _do_execute_run(run_id, endpoint_config)
        return outcome

    @celery_app.task(name="engine.execute_sweep", bind=True, max_retries=0)
    def execute_sweep(
        self,
        experiment_id: str,
        sweep_config: dict[str, Any] | None = None,
        endpoint_config: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Celery entry point: run a full sweep for ``experiment_id``.

        Logs the start of the task and delegates to ``_do_execute_sweep``.
        Registered with ``max_retries=0``.
        """
        logger.info("Celery task execute_sweep started: experiment_id=%s", experiment_id)
        outcome = _do_execute_sweep(experiment_id, sweep_config, endpoint_config)
        return outcome

    @celery_app.task(name="engine.dispatch_webhooks", bind=True, max_retries=0)
    def dispatch_webhooks_task(self, event_type: str, payload: dict[str, Any]) -> dict[str, Any]:
        """Celery entry point: deliver webhooks for ``event_type``.

        Logs the start of the task and delegates to ``_do_dispatch_webhooks``.
        Registered with ``max_retries=0``.
        """
        logger.info("Celery task dispatch_webhooks started: event_type=%s", event_type)
        outcome = _do_dispatch_webhooks(event_type, payload)
        return outcome

except ImportError:
    # The worker module (and hence Celery) could not be imported. Leave the
    # task names bound to None so the dispatch helpers below take the
    # synchronous in-process path.
    execute_run = execute_sweep = dispatch_webhooks_task = None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Synchronous fallback for single-container mode (no Redis/Celery)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class SyncTaskResult:
|
|
"""Mimics Celery AsyncResult for the synchronous fallback."""
|
|
|
|
def __init__(self, result: Any = None, task_id: str | None = None, error: Exception | None = None):
|
|
self.result = result
|
|
self.id = task_id or str(uuid.uuid4())
|
|
self.status = "SUCCESS" if error is None else "FAILURE"
|
|
self._error = error
|
|
|
|
def get(self, timeout: float | None = None) -> Any:
|
|
if self._error:
|
|
raise self._error
|
|
return self.result
|
|
|
|
@property
|
|
def state(self) -> str:
|
|
return self.status
|
|
|
|
def ready(self) -> bool:
|
|
return True
|
|
|
|
def successful(self) -> bool:
|
|
return self.status == "SUCCESS"
|
|
|
|
def failed(self) -> bool:
|
|
return self.status == "FAILURE"
|
|
|
|
|
|
def dispatch_run(run_id: str, endpoint_config: dict[str, Any] | None = None) -> Any:
    """Start execution of a run, via Celery when possible.

    The Celery task is used unless ``settings.use_in_process_queue`` is set or
    the task could not be registered (``execute_run`` is ``None``). Otherwise
    the run is executed in-process before this function returns.

    Returns:
        The value of ``execute_run.delay(...)`` on the Celery path, or a
        ``SyncTaskResult`` carrying either the result dict or the exception
        raised by the run (the exception is logged, not propagated).
    """
    if settings.use_in_process_queue or execute_run is None:
        logger.info("Sync fallback: executing run %s in-process", run_id)
        try:
            return SyncTaskResult(result=_do_execute_run(run_id, endpoint_config))
        except Exception as exc:
            # Failures are captured in the result object rather than raised.
            logger.exception("Sync run %s failed: %s", run_id, exc)
            return SyncTaskResult(error=exc)

    return execute_run.delay(run_id, endpoint_config)
|
|
|
|
|
|
def dispatch_sweep(
    experiment_id: str,
    sweep_config: dict[str, Any] | None = None,
    endpoint_config: dict[str, Any] | None = None,
) -> Any:
    """Start execution of a sweep, via Celery when possible.

    The Celery task is used unless ``settings.use_in_process_queue`` is set or
    the task could not be registered (``execute_sweep`` is ``None``).
    Otherwise the sweep is executed in-process before this function returns.

    Returns:
        The value of ``execute_sweep.delay(...)`` on the Celery path, or a
        ``SyncTaskResult`` carrying either the result dict or the exception
        raised by the sweep (the exception is logged, not propagated).
    """
    if settings.use_in_process_queue or execute_sweep is None:
        logger.info("Sync fallback: executing sweep for experiment %s in-process", experiment_id)
        try:
            return SyncTaskResult(
                result=_do_execute_sweep(experiment_id, sweep_config, endpoint_config)
            )
        except Exception as exc:
            # Failures are captured in the result object rather than raised.
            logger.exception("Sync sweep %s failed: %s", experiment_id, exc)
            return SyncTaskResult(error=exc)

    return execute_sweep.delay(experiment_id, sweep_config, endpoint_config)
|
|
|
|
|
|
def fire_webhooks(event_type: str, payload: dict[str, Any]) -> Any:
    """Trigger webhook delivery for an event, via Celery when possible.

    Intended to be called from runner.py / sweep.py whenever an event occurs
    that should trigger webhooks.

    The Celery task is used unless ``settings.use_in_process_queue`` is set or
    the task could not be registered (``dispatch_webhooks_task`` is ``None``).
    Otherwise delivery happens in-process before this function returns.

    Returns:
        The value of ``dispatch_webhooks_task.delay(...)`` on the Celery path,
        or a ``SyncTaskResult`` carrying either the dispatch summary or the
        exception raised during delivery (logged, not propagated).
    """
    if settings.use_in_process_queue or dispatch_webhooks_task is None:
        logger.info("Sync fallback: dispatching webhooks for event_type=%s", event_type)
        try:
            return SyncTaskResult(result=_do_dispatch_webhooks(event_type, payload))
        except Exception as exc:
            # Failures are captured in the result object rather than raised.
            logger.exception("Sync webhook dispatch failed: %s", exc)
            return SyncTaskResult(error=exc)

    return dispatch_webhooks_task.delay(event_type, payload)
|