promptlooper/backend/engine/runner.py
John Lightner d607970f0c MAESTRO: Implement run execution engine with Jinja2 templating, caching, scoring, and event bus
Adds backend/engine/runner.py with run_single() that iterates pipeline stages,
renders Jinja2 prompt templates with stage history context, checks/stores response
cache, calls LLM adapters, runs configured scorers, creates StageResult and Score
records, and publishes progress events via Redis pub/sub or in-process EventBus.
Includes 21 passing tests covering all execution paths.
2026-04-07 02:48:20 -05:00

285 lines
9.5 KiB
Python

"""Individual run execution for PromptLooper.
Executes a single Run by iterating through pipeline stages, rendering
prompts via Jinja2, checking/storing cache, calling the LLM adapter,
scoring results, and publishing progress events.
"""
import asyncio
import inspect
import json
import logging
import math
import time
from datetime import datetime, timezone
from typing import Any

from jinja2 import BaseLoader, Environment
from jinja2.sandbox import SandboxedEnvironment
from sqlalchemy.orm import Session

from engine.adapters.base import AdapterResponse, BaseAdapter
from engine.cache import ResponseCacheLayer, compute_config_hash
from models import Run, RunStatus, Score, StageResult
logger = logging.getLogger(__name__)

# Prompt templates come from run configuration, so they are rendered in
# Jinja2's sandbox, which rejects access to unsafe attributes and methods from
# template expressions. BaseLoader cannot locate any template source, so
# {% include %} / {% extends %} have nothing to load (no filesystem access).
# Autoescaping stays off because the output is plain prompt text, not HTML.
_jinja_env = SandboxedEnvironment(loader=BaseLoader(), autoescape=False)
class EventBus:
"""Simple event publisher that uses Redis pub/sub when available,
or falls back to in-process callbacks for single-container mode."""
def __init__(self, redis_client: Any | None = None, channel: str = "promptlooper:events"):
self._redis = redis_client
self._channel = channel
self._listeners: list[Any] = []
def add_listener(self, callback: Any) -> None:
self._listeners.append(callback)
def publish(self, event: dict[str, Any]) -> None:
payload = json.dumps(event, default=str)
if self._redis is not None:
try:
self._redis.publish(self._channel, payload)
except Exception:
logger.warning("Failed to publish event to Redis", exc_info=True)
for listener in self._listeners:
try:
listener(event)
except Exception:
logger.warning("Event listener error", exc_info=True)
def render_prompt(template_str: str, context: dict[str, Any]) -> str:
    """Compile ``template_str`` with the module's Jinja2 environment and render it.

    Each entry of ``context`` is handed to the template as a keyword variable.
    """
    return _jinja_env.from_string(template_str).render(**context)
def _build_stage_context(
stage_index: int,
stage_results: list[dict[str, Any]],
run_config: dict[str, Any],
input_data: Any = None,
) -> dict[str, Any]:
"""Build the Jinja2 template context for a given stage.
Available variables in templates:
- input: the original input data
- config: the run configuration
- stages: list of previous stage outputs (dicts with 'output', 'model', etc.)
- prev_output: the output text of the immediately preceding stage (or empty string)
- stage_index: current 0-based stage index
"""
prev_output = ""
if stage_results:
prev_output = stage_results[-1].get("output", "")
return {
"input": input_data or "",
"config": run_config,
"stages": stage_results,
"prev_output": prev_output,
"stage_index": stage_index,
}
def _resolve_stages(config: dict[str, Any]) -> list[dict[str, Any]]:
    """Return the configured pipeline stages, or a one-stage fallback built from top-level keys."""
    stages = config.get("pipeline_stages", [])
    if stages:
        return stages
    # Single-stage fallback: use prompt + model + params from config directly
    return [
        {
            "prompt_template": config.get("prompt", config.get("prompt_template", "")),
            "model": config.get("model", ""),
            "params": config.get("params", {}),
        }
    ]


async def _score_run(
    db: Session,
    run_id: Any,
    scorers: list[Any],
    input_data: Any,
    final_output: str,
    scorer_context: dict[str, Any],
) -> None:
    """Run every scorer on the final output and add one Score row per scorer that succeeds."""
    for scorer in scorers:
        # Resolved up front so the failure log below cannot itself raise when a
        # scorer object has no usable ``name``.
        label = getattr(scorer, "name", None) or type(scorer).__name__
        try:
            result = scorer.score(input_data, final_output, scorer_context)
            # Supports both sync and async scorers: await whatever is awaitable.
            if inspect.isawaitable(result):
                result = await result
            value = float(result)
            if math.isnan(value):
                # max(0.0, min(1.0, nan)) evaluates to 1.0, which would record
                # a perfect score; treat NaN as a scorer failure instead.
                raise ValueError("scorer returned NaN")
            value = max(0.0, min(1.0, value))  # clamp to [0, 1]
            db.add(Score(run_id=run_id, scorer_name=scorer.name, value=value))
        except Exception:
            # A failing scorer is logged and skipped; it does not fail the run.
            logger.warning("Scorer %s failed for run %s", label, run_id, exc_info=True)


def _mark_run_failed(db: Session, run: Run, run_id: Any, duration_ms: int) -> None:
    """Persist the failed status without letting a persistence error mask the run's real error."""

    def stamp() -> None:
        run.status = RunStatus.failed
        run.completed_at = datetime.now(timezone.utc)
        run.duration_ms = duration_ms

    try:
        stamp()
        db.commit()
        return
    except Exception:
        # Typically the session needs a rollback because the run failed during
        # a flush/commit; retry once on a clean transaction.
        logger.warning(
            "Could not persist failed status for run %s; rolling back and retrying",
            run_id,
            exc_info=True,
        )
    try:
        db.rollback()
        stamp()
        db.commit()
    except Exception:
        logger.error("Giving up persisting failed status for run %s", run_id, exc_info=True)


async def run_single(
    db: Session,
    run: Run,
    adapter: BaseAdapter,
    cache: ResponseCacheLayer,
    scorers: list[Any] | None = None,
    event_bus: EventBus | None = None,
) -> Run:
    """Execute a single Run through all its pipeline stages.

    For each stage the prompt template is rendered with the outputs of the
    earlier stages, the response cache is consulted, the adapter is called on
    a miss (and the response cached), and a StageResult row is added. After
    the last stage the scorers run on the final output and the run is marked
    completed. Progress events are published when an event bus is given.

    Args:
        db: SQLAlchemy session.
        run: The Run ORM object (must already be persisted with config).
        adapter: LLM adapter to use for completions.
        cache: Response cache layer.
        scorers: Optional list of scorer instances (must have .name and
            .score(); .score() may be sync or async).
        event_bus: Optional event publisher for progress updates.

    Returns:
        The updated Run object.

    Raises:
        Exception: Any error raised while executing the stages or finalising
            the run is re-raised after the run has been marked failed and a
            ``run.failed`` event published. Scorer errors are logged and
            skipped instead.
    """
    scorers = scorers or []
    config = run.config or {}
    stages = _resolve_stages(config)
    input_data = config.get("input_data")

    # Mark run as running
    run.status = RunStatus.running
    run.started_at = datetime.now(timezone.utc)
    db.commit()

    # Read the identifiers once so the failure path never has to touch ORM
    # attributes on a session that may need a rollback.
    run_id = run.id
    experiment_id = run.experiment_id

    def emit(event_type: str, **fields: Any) -> None:
        if event_bus:
            event_bus.publish({
                "type": event_type,
                "run_id": str(run_id),
                "experiment_id": str(experiment_id),
                **fields,
            })

    emit("run.started")

    total_tokens_in = 0
    total_tokens_out = 0
    completed_stages: list[dict[str, Any]] = []
    t_start = time.perf_counter()
    try:
        for stage_index, stage_def in enumerate(stages):
            # Stage-level settings win; run-level config is the fallback.
            prompt_template = stage_def.get("prompt_template", stage_def.get("prompt", ""))
            model = stage_def.get("model", config.get("model", ""))
            params = stage_def.get("params", config.get("params", {}))

            # Build context and render prompt
            template_ctx = _build_stage_context(
                stage_index, completed_stages, config, input_data
            )
            rendered_prompt = render_prompt(prompt_template, template_ctx)

            # Check cache
            config_hash = compute_config_hash(rendered_prompt, model, params, input_data)
            cached = cache.get(db, config_hash)
            if cached is not None:
                # Cache hit
                response_text = cached.response
                tokens_in = cached.tokens_in or 0
                tokens_out = cached.tokens_out or 0
                latency_ms = cached.latency_ms or 0
            else:
                # Cache miss — call LLM
                adapter_resp: AdapterResponse = await adapter.complete(
                    rendered_prompt, model, params
                )
                response_text = adapter_resp.text
                # Same None-guards as the cache-hit branch, so an adapter that
                # reports no usage cannot break the running totals.
                tokens_in = adapter_resp.tokens_in or 0
                tokens_out = adapter_resp.tokens_out or 0
                latency_ms = int(adapter_resp.latency_ms or 0)
                # Store in cache
                cache.put(
                    db,
                    config_hash=config_hash,
                    response=response_text,
                    model=model,
                    tokens_in=tokens_in,
                    tokens_out=tokens_out,
                    latency_ms=latency_ms,
                )

            total_tokens_in += tokens_in
            total_tokens_out += tokens_out

            # Create StageResult record
            stage_result = StageResult(
                run_id=run_id,
                stage_index=stage_index,
                prompt_sent=rendered_prompt,
                response_raw=response_text,
                model_used=model,
                parameters=params,
                tokens_in=tokens_in,
                tokens_out=tokens_out,
                latency_ms=latency_ms,
            )
            db.add(stage_result)
            db.flush()

            # Exposed to later stages' templates as ``stages`` / ``prev_output``.
            completed_stages.append({
                "output": response_text,
                "model": model,
                "tokens_in": tokens_in,
                "tokens_out": tokens_out,
                "latency_ms": latency_ms,
                "stage_index": stage_index,
            })
            emit(
                "run.stage_completed",
                stage_index=stage_index,
                total_stages=len(stages),
            )

        # Run scorers on the final output
        final_output = completed_stages[-1]["output"] if completed_stages else ""
        scorer_context = {
            "config": config,
            "stages": completed_stages,
            "input_data": input_data,
        }
        await _score_run(db, run_id, scorers, input_data, final_output, scorer_context)

        # Update run status and timing
        duration_ms = int((time.perf_counter() - t_start) * 1000)
        run.status = RunStatus.completed
        run.completed_at = datetime.now(timezone.utc)
        run.duration_ms = duration_ms
        run.tokens_in = total_tokens_in
        run.tokens_out = total_tokens_out
        db.commit()
        emit(
            "run.completed",
            duration_ms=duration_ms,
            tokens_in=total_tokens_in,
            tokens_out=total_tokens_out,
        )
    except Exception as exc:
        logger.error("Run %s failed: %s", run_id, exc, exc_info=True)
        _mark_run_failed(db, run, run_id, int((time.perf_counter() - t_start) * 1000))
        emit("run.failed", error=str(exc))
        raise
    return run