# Source: chrysopedia/backend/pipeline/quality/optimizer.py
# Commit: jlightner c6cbb09dd3 — feat: Created PromptVariantGenerator
#   (LLM-powered prompt mutation) and the optimization loop:
#   - backend/pipeline/quality/variant_generator.py
#   - backend/pipeline/quality/optimizer.py
#
# GSD-Task: S03/T01
# Date: 2026-04-01 09:08:01 +00:00
#
# 364 lines · 13 KiB · Python
"""Automated prompt optimization loop for Stage 5 synthesis.
Orchestrates a generate→score→select cycle:
1. Score the current best prompt against reference fixtures
2. Generate N variants targeting weak dimensions
3. Score each variant
4. Keep the best scorer as the new baseline
5. Repeat for K iterations
Usage (via CLI):
python -m pipeline.quality optimize --stage 5 --iterations 10
"""
from __future__ import annotations
import json
import logging
import time
from dataclasses import dataclass, field
from pathlib import Path
from pipeline.llm_client import LLMClient
from pipeline.quality.scorer import DIMENSIONS, ScoreResult, ScoreRunner
from pipeline.quality.variant_generator import PromptVariantGenerator
logger = logging.getLogger(__name__)
@dataclass
class OptimizationResult:
    """Full result of an optimization run."""

    # Winning prompt text; empty string when the stage prompt file was missing.
    best_prompt: str = ""
    # ScoreResult for the winning prompt (carries an `error` on failed runs).
    best_score: ScoreResult = field(default_factory=ScoreResult)
    # One record per scored prompt (baseline + every variant), in scoring order.
    history: list[dict] = field(default_factory=list)
    # Wall-clock duration of the run, rounded to 2 decimal places.
    elapsed_seconds: float = 0.0
class OptimizationLoop:
    """Runs iterative prompt optimization for a pipeline stage.

    Each iteration generates *variants_per_iter* prompt mutations,
    scores each against reference fixture data, and keeps the
    highest-composite-scoring variant as the new baseline.

    Parameters
    ----------
    client:
        LLMClient instance for LLM calls (synthesis + scoring + variant gen).
    stage:
        Pipeline stage number (currently only 5 is supported).
    fixture_path:
        Path to a JSON fixture file containing ``creator_name`` and ``moments``.
    iterations:
        Number of generate→score→select cycles.
    variants_per_iter:
        Number of variant prompts to generate per iteration.
    """

    def __init__(
        self,
        client: LLMClient,
        stage: int,
        fixture_path: str,
        iterations: int = 5,
        variants_per_iter: int = 2,
    ) -> None:
        self.client = client
        self.stage = stage
        self.fixture_path = fixture_path
        self.iterations = iterations
        self.variants_per_iter = variants_per_iter
        # Collaborators: scorer runs synthesis + scoring, generator mutates prompts.
        self.scorer = ScoreRunner(client)
        self.generator = PromptVariantGenerator(client)

    def run(self) -> OptimizationResult:
        """Execute the full optimization loop.

        Returns
        -------
        OptimizationResult
            Contains the best prompt, its scores, full iteration history,
            and wall-clock elapsed time.
        """
        # Function-scoped import kept as in the original — presumably to
        # avoid an import cycle with pipeline.stages; TODO confirm.
        from pipeline.stages import _load_prompt

        t0 = time.monotonic()

        # Load base prompt
        prompt_file = f"stage{self.stage}_synthesis.txt"
        try:
            base_prompt = _load_prompt(prompt_file)
        except FileNotFoundError:
            logger.error("Prompt file not found: %s", prompt_file)
            return OptimizationResult(
                best_prompt="",
                best_score=ScoreResult(error=f"Prompt file not found: {prompt_file}"),
                elapsed_seconds=round(time.monotonic() - t0, 2),
            )

        # Load fixture data
        try:
            fixture = self._load_fixture()
        except (FileNotFoundError, json.JSONDecodeError, KeyError) as exc:
            logger.error("Failed to load fixture: %s", exc)
            return OptimizationResult(
                best_prompt=base_prompt,
                best_score=ScoreResult(error=f"Fixture load error: {exc}"),
                elapsed_seconds=round(time.monotonic() - t0, 2),
            )

        moments = fixture["moments"]
        creator_name = fixture["creator_name"]
        history: list[dict] = []

        # Score the baseline prompt first — every variant is measured against it.
        print(f"\n{'='*60}")
        print(f" PROMPT OPTIMIZATION — Stage {self.stage}")
        print(f" Iterations: {self.iterations}, Variants/iter: {self.variants_per_iter}")
        print(f"{'='*60}\n")
        print(" Scoring baseline prompt...")
        best_score = self.scorer.synthesize_and_score(
            moments=moments,
            creator_name=creator_name,
            voice_level=0.5,
        )
        best_prompt = base_prompt
        history.append({
            "iteration": 0,
            "variant_index": 0,
            "prompt_text": self._snippet(base_prompt),
            "prompt_length": len(base_prompt),
            "composite": best_score.composite,
            "scores": {d: getattr(best_score, d) for d in DIMENSIONS},
            "error": best_score.error,
            "label": "baseline",
        })
        if best_score.error:
            # No point mutating a prompt that cannot even be scored.
            print(f" ✗ Baseline scoring failed: {best_score.error}")
            print(" Aborting optimization — fix the baseline first.\n")
            return OptimizationResult(
                best_prompt=best_prompt,
                best_score=best_score,
                history=history,
                elapsed_seconds=round(time.monotonic() - t0, 2),
            )
        self._print_iteration_summary(0, best_score, is_baseline=True)

        # Iterate: generate variants, score them, promote the best.
        for iteration in range(1, self.iterations + 1):
            print(f"\n ── Iteration {iteration}/{self.iterations} ──")
            # Generate variants targeting the current best prompt's weak dimensions
            variants = self.generator.generate(
                base_prompt=best_prompt,
                scores=best_score,
                n=self.variants_per_iter,
            )
            if not variants:
                print(" ⚠ No valid variants generated — skipping iteration")
                continue
            # Score each variant
            iteration_best_score = best_score
            iteration_best_prompt = best_prompt
            for vi, variant_prompt in enumerate(variants):
                print(f" Scoring variant {vi + 1}/{len(variants)}...")
                # Temporarily replace the base prompt with the variant for synthesis
                score = self._score_variant(
                    variant_prompt, moments, creator_name,
                )
                history.append({
                    "iteration": iteration,
                    "variant_index": vi + 1,
                    "prompt_text": self._snippet(variant_prompt),
                    "prompt_length": len(variant_prompt),
                    "composite": score.composite,
                    "scores": {d: getattr(score, d) for d in DIMENSIONS},
                    "error": score.error,
                    "label": f"iter{iteration}_v{vi+1}",
                })
                if score.error:
                    # An errored variant is recorded but never promoted.
                    print(f" ✗ Variant {vi + 1} errored: {score.error}")
                    continue
                if score.composite > iteration_best_score.composite:
                    iteration_best_score = score
                    iteration_best_prompt = variant_prompt
                    print(f" ✓ New best: {score.composite:.3f} (was {best_score.composite:.3f})")
                else:
                    print(f" · Score {score.composite:.3f} ≤ current best {iteration_best_score.composite:.3f}")
            # Update global best if this iteration improved
            if iteration_best_score.composite > best_score.composite:
                best_score = iteration_best_score
                best_prompt = iteration_best_prompt
                print(f" ★ Iteration {iteration} improved: {best_score.composite:.3f}")
            else:
                print(f" · No improvement in iteration {iteration}")
            self._print_iteration_summary(iteration, best_score)

        # Final report
        elapsed = round(time.monotonic() - t0, 2)
        self._print_final_report(best_score, history, elapsed)
        return OptimizationResult(
            best_prompt=best_prompt,
            best_score=best_score,
            history=history,
            elapsed_seconds=elapsed,
        )

    # ── Internal helpers ──────────────────────────────────────────────────

    @staticmethod
    def _snippet(prompt: str, limit: int = 200) -> str:
        """Return *prompt* truncated to *limit* chars (ellipsis appended) for history records."""
        return prompt[:limit] + "..." if len(prompt) > limit else prompt

    def _load_fixture(self) -> dict:
        """Load and validate the fixture JSON file.

        Raises
        ------
        FileNotFoundError
            If ``self.fixture_path`` does not exist.
        json.JSONDecodeError
            If the file content is not valid JSON.
        KeyError
            If the required ``moments`` / ``creator_name`` keys are missing.
        """
        path = Path(self.fixture_path)
        if not path.exists():
            raise FileNotFoundError(f"Fixture not found: {path}")
        data = json.loads(path.read_text(encoding="utf-8"))
        if "moments" not in data:
            raise KeyError("Fixture must contain 'moments' key")
        if "creator_name" not in data:
            raise KeyError("Fixture must contain 'creator_name' key")
        return data

    def _score_variant(
        self,
        variant_prompt: str,
        moments: list[dict],
        creator_name: str,
    ) -> ScoreResult:
        """Score a variant prompt by running synthesis + scoring.

        Uses the variant as a direct system prompt for synthesis, bypassing
        VoiceDial (the optimization loop owns the full prompt text).
        Never raises — all failures are returned as a ScoreResult whose
        ``error`` field is set.
        """
        from pipeline.schemas import SynthesisResult
        from pipeline.stages import _get_stage_config
        import json as _json
        import openai as _openai

        model_override, modality = _get_stage_config(self.stage)
        moments_json = _json.dumps(moments, indent=2)
        user_prompt = f"<creator>{creator_name}</creator>\n<moments>\n{moments_json}\n</moments>"
        t0 = time.monotonic()
        try:
            raw = self.client.complete(
                system_prompt=variant_prompt,
                user_prompt=user_prompt,
                response_model=SynthesisResult,
                modality=modality,
                model_override=model_override,
            )
            elapsed_synth = round(time.monotonic() - t0, 2)
        except (_openai.APIConnectionError, _openai.APITimeoutError) as exc:
            elapsed_synth = round(time.monotonic() - t0, 2)
            return ScoreResult(
                elapsed_seconds=elapsed_synth,
                error=f"Synthesis LLM error: {exc}",
            )
        except Exception as exc:
            # Broad catch is deliberate: one bad variant must not abort the run.
            elapsed_synth = round(time.monotonic() - t0, 2)
            logger.exception("Unexpected error during variant synthesis")
            return ScoreResult(
                elapsed_seconds=elapsed_synth,
                error=f"Unexpected synthesis error: {exc}",
            )

        # Parse synthesis
        raw_text = str(raw).strip()
        try:
            synthesis = self.client.parse_response(raw_text, SynthesisResult)
        except Exception as exc:
            return ScoreResult(
                elapsed_seconds=elapsed_synth,
                error=f"Variant synthesis parse error: {exc}",
            )
        if not synthesis.pages:
            return ScoreResult(
                elapsed_seconds=elapsed_synth,
                error="Variant synthesis returned no pages",
            )

        # Score the first page
        page = synthesis.pages[0]
        page_json = {
            "title": page.title,
            "creator_name": creator_name,
            "summary": page.summary,
            "body_sections": [
                {"heading": heading, "content": content}
                for heading, content in page.body_sections.items()
            ],
        }
        result = self.scorer.score_page(page_json, moments)
        # Fold synthesis latency into the scorer's reported elapsed time.
        result.elapsed_seconds = round(result.elapsed_seconds + elapsed_synth, 2)
        return result

    def _print_iteration_summary(
        self,
        iteration: int,
        score: ScoreResult,
        is_baseline: bool = False,
    ) -> None:
        """Print a compact one-line summary of the current best scores."""
        label = "BASELINE" if is_baseline else f"ITER {iteration}"
        dims = " ".join(
            f"{d[:4]}={getattr(score, d):.2f}" for d in DIMENSIONS
        )
        print(f" [{label}] composite={score.composite:.3f} {dims}")

    def _print_final_report(
        self,
        best_score: ScoreResult,
        history: list[dict],
        elapsed: float,
    ) -> None:
        """Print the final optimization summary."""
        print(f"\n{'='*60}")
        print(" OPTIMIZATION COMPLETE")
        print(f"{'='*60}")
        print(f" Total time: {elapsed}s")
        print(f" Iterations: {self.iterations}")
        print(f" Variants scored: {len(history) - 1}")  # minus baseline
        baseline_composite = history[0]["composite"] if history else 0.0
        improvement = best_score.composite - baseline_composite
        print(f"\n Baseline composite: {baseline_composite:.3f}")
        print(f" Best composite: {best_score.composite:.3f}")
        if improvement > 0:
            print(f" Improvement: +{improvement:.3f}")
        else:
            print(f" Improvement: {improvement:.3f} (no gain)")
        print("\n Per-dimension best scores:")
        for d in DIMENSIONS:
            val = getattr(best_score, d)
            # BUGFIX: the bar glyphs were empty strings ("" * n), so the bar
            # always rendered blank. Restore a 20-cell filled/empty block bar.
            filled = int(val * 20)
            bar = "█" * filled + "░" * (20 - filled)
            print(f" {d.replace('_', ' ').title():25s} {val:.2f} {bar}")
        errored = sum(1 for h in history if h.get("error"))
        if errored:
            print(f"\n{errored} variant(s) errored during scoring")
        print(f"{'='*60}\n")