Adds backend/engine/runner.py with run_single() that iterates pipeline stages, renders Jinja2 prompt templates with stage history context, checks/stores response cache, calls LLM adapters, runs configured scorers, creates StageResult and Score records, and publishes progress events via Redis pub/sub or in-process EventBus. Includes 21 passing tests covering all execution paths.
285 lines
9.5 KiB
Python
285 lines
9.5 KiB
Python
"""Individual run execution for PromptLooper.
|
|
|
|
Executes a single Run by iterating through pipeline stages, rendering
|
|
prompts via Jinja2, checking/storing cache, calling the LLM adapter,
|
|
scoring results, and publishing progress events.
|
|
"""
|
|
|
|
import asyncio
import inspect
import json
import logging
import math
import time
from datetime import datetime, timezone
from typing import Any

from jinja2 import BaseLoader, Environment
from jinja2.sandbox import SandboxedEnvironment
from sqlalchemy.orm import Session

from engine.adapters.base import AdapterResponse, BaseAdapter
from engine.cache import ResponseCacheLayer, compute_config_hash
from models import Run, RunStatus, Score, StageResult
|
|
|
|
logger = logging.getLogger(__name__)

# Jinja2 environment used to render prompt templates from strings.
# Prompt templates come straight from run configuration, so they are treated
# as untrusted: SandboxedEnvironment refuses access to unsafe attributes such
# as ``__class__`` / ``__globals__``, which a plain Environment would expose
# (server-side template injection). BaseLoader has no template source, so
# templates cannot pull files from disk via include/extends.
# autoescape is off because the output is a plain-text LLM prompt, not HTML.
_jinja_env = SandboxedEnvironment(loader=BaseLoader(), autoescape=False)
|
|
|
|
|
|
class EventBus:
|
|
"""Simple event publisher that uses Redis pub/sub when available,
|
|
or falls back to in-process callbacks for single-container mode."""
|
|
|
|
def __init__(self, redis_client: Any | None = None, channel: str = "promptlooper:events"):
|
|
self._redis = redis_client
|
|
self._channel = channel
|
|
self._listeners: list[Any] = []
|
|
|
|
def add_listener(self, callback: Any) -> None:
|
|
self._listeners.append(callback)
|
|
|
|
def publish(self, event: dict[str, Any]) -> None:
|
|
payload = json.dumps(event, default=str)
|
|
if self._redis is not None:
|
|
try:
|
|
self._redis.publish(self._channel, payload)
|
|
except Exception:
|
|
logger.warning("Failed to publish event to Redis", exc_info=True)
|
|
for listener in self._listeners:
|
|
try:
|
|
listener(event)
|
|
except Exception:
|
|
logger.warning("Event listener error", exc_info=True)
|
|
|
|
|
|
def render_prompt(template_str: str, context: dict[str, Any]) -> str:
    """Compile ``template_str`` with the module Jinja2 environment and render it.

    Every key of ``context`` is exposed to the template as a top-level variable.
    """
    return _jinja_env.from_string(template_str).render(**context)
|
|
|
|
|
|
def _build_stage_context(
|
|
stage_index: int,
|
|
stage_results: list[dict[str, Any]],
|
|
run_config: dict[str, Any],
|
|
input_data: Any = None,
|
|
) -> dict[str, Any]:
|
|
"""Build the Jinja2 template context for a given stage.
|
|
|
|
Available variables in templates:
|
|
- input: the original input data
|
|
- config: the run configuration
|
|
- stages: list of previous stage outputs (dicts with 'output', 'model', etc.)
|
|
- prev_output: the output text of the immediately preceding stage (or empty string)
|
|
- stage_index: current 0-based stage index
|
|
"""
|
|
prev_output = ""
|
|
if stage_results:
|
|
prev_output = stage_results[-1].get("output", "")
|
|
|
|
return {
|
|
"input": input_data or "",
|
|
"config": run_config,
|
|
"stages": stage_results,
|
|
"prev_output": prev_output,
|
|
"stage_index": stage_index,
|
|
}
|
|
|
|
|
|
def _resolve_stages(config: dict[str, Any]) -> list[dict[str, Any]]:
    """Return the stage definitions to execute for a run config.

    Uses ``config["pipeline_stages"]`` when it is present and non-empty;
    otherwise builds a single stage from the top-level ``prompt`` (or
    ``prompt_template``), ``model`` and ``params`` keys.
    """
    stages = config.get("pipeline_stages", [])
    if stages:
        return stages
    # Single-stage fallback: use prompt + model + params from config directly
    return [
        {
            "prompt_template": config.get("prompt", config.get("prompt_template", "")),
            "model": config.get("model", ""),
            "params": config.get("params", {}),
        }
    ]


async def _execute_stage(
    *,
    db: Session,
    run_id: Any,
    adapter: BaseAdapter,
    cache: ResponseCacheLayer,
    config: dict[str, Any],
    stage_index: int,
    stage_def: dict[str, Any],
    completed_stages: list[dict[str, Any]],
    input_data: Any,
) -> dict[str, Any]:
    """Execute one pipeline stage and add/flush its StageResult row.

    Renders the stage's prompt template (with access to earlier stage outputs
    via ``completed_stages``). If ``cache`` has an entry for the rendered
    prompt/model/params/input, that response is used; otherwise
    ``adapter.complete`` is called and the response is stored in the cache.

    Returns:
        A summary dict with ``output``, ``model``, ``tokens_in``,
        ``tokens_out``, ``latency_ms`` and ``stage_index``; later stages'
        templates and the scorers see these via ``stages``.
    """
    # Per-stage values take precedence over the run-level config defaults.
    prompt_template = stage_def.get("prompt_template", stage_def.get("prompt", ""))
    model = stage_def.get("model", config.get("model", ""))
    params = stage_def.get("params", config.get("params", {}))

    template_ctx = _build_stage_context(stage_index, completed_stages, config, input_data)
    rendered_prompt = render_prompt(prompt_template, template_ctx)

    # The cache key is derived from the *rendered* prompt, so it reflects any
    # upstream values the template interpolated.
    config_hash = compute_config_hash(rendered_prompt, model, params, input_data)
    cached = cache.get(db, config_hash)

    if cached is not None:
        response_text = cached.response
        tokens_in = cached.tokens_in or 0
        tokens_out = cached.tokens_out or 0
        latency_ms = cached.latency_ms or 0
    else:
        adapter_resp: AdapterResponse = await adapter.complete(rendered_prompt, model, params)
        response_text = adapter_resp.text
        # Same None -> 0 defaulting as the cache-hit path, so the run totals
        # below never see None.
        tokens_in = adapter_resp.tokens_in or 0
        tokens_out = adapter_resp.tokens_out or 0
        latency_ms = int(adapter_resp.latency_ms)

        cache.put(
            db,
            config_hash=config_hash,
            response=response_text,
            model=model,
            tokens_in=tokens_in,
            tokens_out=tokens_out,
            latency_ms=latency_ms,
        )

    db.add(
        StageResult(
            run_id=run_id,
            stage_index=stage_index,
            prompt_sent=rendered_prompt,
            response_raw=response_text,
            model_used=model,
            parameters=params,
            tokens_in=tokens_in,
            tokens_out=tokens_out,
            latency_ms=latency_ms,
        )
    )
    db.flush()

    return {
        "output": response_text,
        "model": model,
        "tokens_in": tokens_in,
        "tokens_out": tokens_out,
        "latency_ms": latency_ms,
        "stage_index": stage_index,
    }


async def _apply_scorers(
    db: Session,
    run_id: Any,
    scorers: list[Any],
    input_data: Any,
    final_output: str,
    scorer_context: dict[str, Any],
) -> None:
    """Score the final output with each scorer and add one Score row per success.

    ``scorer.score()`` may be sync or async; any awaitable result is awaited.
    Values are clamped to [0, 1]. NaN is rejected because clamping it yields a
    perfect 1.0. A failing scorer is logged and skipped, so it does not fail
    the run.
    """
    for scorer in scorers:
        try:
            result = scorer.score(input_data, final_output, scorer_context)
            if inspect.isawaitable(result):
                result = await result
            value = float(result)
            if math.isnan(value):
                raise ValueError(f"scorer returned NaN: {result!r}")
            db.add(
                Score(
                    run_id=run_id,
                    scorer_name=scorer.name,
                    value=max(0.0, min(1.0, value)),
                )
            )
        except Exception:
            # getattr: a scorer without `.name` must not raise again here and
            # escalate into a run failure.
            logger.warning(
                "Scorer %s failed for run %s",
                getattr(scorer, "name", scorer),
                run_id,
                exc_info=True,
            )


def _mark_run_failed(db: Session, run: Run, run_id: Any, t_start: float) -> None:
    """Persist ``run`` as failed, recovering from a session that needs rollback.

    If the triggering error left the session unusable (e.g. a failed flush),
    the first commit raises. In that case the session is rolled back, the
    failed status is re-applied and committed again, so the run is not left
    "running". A second commit failure is logged, not raised, so the caller can
    re-raise the original error.
    """
    completed_at = datetime.now(timezone.utc)
    duration_ms = int((time.perf_counter() - t_start) * 1000)

    def _apply_failed_status() -> None:
        run.status = RunStatus.failed
        run.completed_at = completed_at
        run.duration_ms = duration_ms

    _apply_failed_status()
    try:
        db.commit()
        return
    except Exception:
        logger.warning(
            "Commit failed while marking run %s failed; rolling back and retrying",
            run_id,
            exc_info=True,
        )
        db.rollback()

    # The rollback discarded the uncommitted status change (and any
    # uncommitted partial rows from this attempt), so apply the status again.
    _apply_failed_status()
    try:
        db.commit()
    except Exception:
        logger.error("Could not persist failed status for run %s", run_id, exc_info=True)
        db.rollback()


async def run_single(
    db: Session,
    run: Run,
    adapter: BaseAdapter,
    cache: ResponseCacheLayer,
    scorers: list[Any] | None = None,
    event_bus: EventBus | None = None,
) -> Run:
    """Execute a single Run through all its pipeline stages.

    Args:
        db: SQLAlchemy session.
        run: The Run ORM object (must already be persisted with config).
        adapter: LLM adapter to use for completions.
        cache: Response cache layer.
        scorers: Optional list of scorer instances (must have .name and .score()).
        event_bus: Optional event publisher for progress updates.

    Returns:
        The updated Run object.

    Raises:
        Exception: any error from rendering, caching, the adapter or the
            database is re-raised after the run is marked failed and a
            ``run.failed`` event is published. Stage results flushed before
            the error are committed with the failed status when the session is
            still usable.
    """
    scorers = scorers or []

    config = run.config or {}
    stages = _resolve_stages(config)
    input_data = config.get("input_data")

    # Mark run as running before any LLM work so progress is visible.
    run.status = RunStatus.running
    run.started_at = datetime.now(timezone.utc)
    db.commit()

    # Read identifiers once while the session is healthy, so the failure path
    # can still log and publish them if a DB error has poisoned the session.
    run_id = run.id
    experiment_id = run.experiment_id

    if event_bus:
        event_bus.publish({
            "type": "run.started",
            "run_id": str(run_id),
            "experiment_id": str(experiment_id),
        })

    total_tokens_in = 0
    total_tokens_out = 0
    completed_stages: list[dict[str, Any]] = []
    t_start = time.perf_counter()

    # NOTE(review): asyncio.CancelledError derives from BaseException, so a
    # cancelled run bypasses the handler below and stays "running"; confirm
    # how cancellation should be recorded.
    try:
        for stage_index, stage_def in enumerate(stages):
            summary = await _execute_stage(
                db=db,
                run_id=run_id,
                adapter=adapter,
                cache=cache,
                config=config,
                stage_index=stage_index,
                stage_def=stage_def,
                completed_stages=completed_stages,
                input_data=input_data,
            )
            completed_stages.append(summary)
            total_tokens_in += summary["tokens_in"]
            total_tokens_out += summary["tokens_out"]

            if event_bus:
                event_bus.publish({
                    "type": "run.stage_completed",
                    "run_id": str(run_id),
                    "experiment_id": str(experiment_id),
                    "stage_index": stage_index,
                    "total_stages": len(stages),
                })

        # Scorers judge only the last stage's output.
        final_output = completed_stages[-1]["output"] if completed_stages else ""
        scorer_context = {
            "config": config,
            "stages": completed_stages,
            "input_data": input_data,
        }
        await _apply_scorers(db, run_id, scorers, input_data, final_output, scorer_context)

        duration_ms = int((time.perf_counter() - t_start) * 1000)
        run.status = RunStatus.completed
        run.completed_at = datetime.now(timezone.utc)
        run.duration_ms = duration_ms
        run.tokens_in = total_tokens_in
        run.tokens_out = total_tokens_out
        db.commit()

    except Exception as exc:
        logger.error("Run %s failed: %s", run_id, exc, exc_info=True)
        _mark_run_failed(db, run, run_id, t_start)

        if event_bus:
            event_bus.publish({
                "type": "run.failed",
                "run_id": str(run_id),
                "experiment_id": str(experiment_id),
                "error": str(exc),
            })
        raise

    # Published outside the try: a publishing problem must never flip an
    # already-committed completed run to failed.
    if event_bus:
        event_bus.publish({
            "type": "run.completed",
            "run_id": str(run_id),
            "experiment_id": str(experiment_id),
            "duration_ms": duration_ms,
            "tokens_in": total_tokens_in,
            "tokens_out": total_tokens_out,
        })

    return run
|