chrysopedia/backend/pipeline/quality/optimizer.py
jlightner d75ec80c98 optimize: Stage 5 synthesis prompt — round 0 winner (0.95→1.0 composite)
Applied first optimization result: tighter voice preservation instructions,
improved section flow guidance, trimmed redundant metadata instructions.
13382→11123 chars (-17%).
2026-04-01 10:15:24 +00:00

"""Automated prompt optimization loop for pipeline stages 2-5.
Orchestrates a generate→score→select cycle:
1. Score the current best prompt against reference fixtures
2. Generate N variants targeting weak dimensions
3. Score each variant
4. Keep the best scorer as the new baseline
5. Repeat for K iterations
Usage (via CLI):
python -m pipeline.quality optimize --stage 5 --iterations 10
python -m pipeline.quality optimize --stage 3 --iterations 5 --file fixtures/sample_topic_group.json
"""
from __future__ import annotations
import json
import logging
import time
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from pipeline.llm_client import LLMClient
from pipeline.quality.scorer import STAGE_CONFIGS, ScoreResult, ScoreRunner
logger = logging.getLogger(__name__)
@dataclass
class OptimizationResult:
"""Full result of an optimization run."""
best_prompt: str = ""
best_score: ScoreResult = field(default_factory=ScoreResult)
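    # Each history entry mirrors the dict appended in OptimizationLoop.run():
    # {"iteration", "variant_index", "prompt_text", "prompt_length",
    #  "composite", "scores", "error", "label"}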
history: list[dict] = field(default_factory=list)
elapsed_seconds: float = 0.0
class OptimizationLoop:
"""Runs iterative prompt optimization for a pipeline stage.
Each iteration generates *variants_per_iter* prompt mutations,
scores each against reference fixture data, and keeps the
highest-composite-scoring variant as the new baseline.
Parameters
----------
client:
LLMClient instance for LLM calls (synthesis + scoring + variant gen).
stage:
Pipeline stage number (2-5).
fixture_path:
Path to a JSON fixture file matching the stage's expected keys.
iterations:
Number of generate→score→select cycles.
variants_per_iter:
Number of variant prompts to generate per iteration.
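
    Examples
    --------
    A minimal sketch, assuming ``LLMClient()`` can be constructed with no
    arguments (its real signature may differ) and that the stage-3 sample
    fixture referenced in the module docstring exists::

        client = LLMClient()
        loop = OptimizationLoop(
            client,
            stage=3,
            fixture_path="fixtures/sample_topic_group.json",
            iterations=5,
        )
        result = loop.run()
        print(result.best_score.composite, result.elapsed_seconds)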
"""
def __init__(
self,
client: LLMClient,
stage: int,
fixture_path: str,
iterations: int = 5,
variants_per_iter: int = 2,
output_dir: str | None = None,
) -> None:
if stage not in STAGE_CONFIGS:
raise ValueError(
f"Unsupported stage {stage}. Valid stages: {sorted(STAGE_CONFIGS)}"
)
self.client = client
self.stage = stage
self.fixture_path = fixture_path
self.iterations = iterations
self.variants_per_iter = variants_per_iter
self.config = STAGE_CONFIGS[stage]
self.output_dir = output_dir
self.scorer = ScoreRunner(client)
self.generator = PromptVariantGenerator(client)
def run(self) -> OptimizationResult:
"""Execute the full optimization loop.
Returns
-------
OptimizationResult
Contains the best prompt, its scores, full iteration history,
and wall-clock elapsed time.
"""
from pipeline.stages import _load_prompt
t0 = time.monotonic()
dimensions = self.config.dimensions
# Load base prompt using the stage's configured prompt file
prompt_file = self.config.prompt_file
try:
base_prompt = _load_prompt(prompt_file)
except FileNotFoundError:
logger.error("Prompt file not found: %s", prompt_file)
return OptimizationResult(
best_prompt="",
best_score=ScoreResult(error=f"Prompt file not found: {prompt_file}"),
elapsed_seconds=round(time.monotonic() - t0, 2),
)
# Load fixture data
try:
fixture = self._load_fixture()
except (FileNotFoundError, json.JSONDecodeError, KeyError) as exc:
logger.error("Failed to load fixture: %s", exc)
return OptimizationResult(
best_prompt=base_prompt,
best_score=ScoreResult(error=f"Fixture load error: {exc}"),
elapsed_seconds=round(time.monotonic() - t0, 2),
)
history: list[dict] = []
# Score the baseline
print(f"\n{'='*60}")
print(f" PROMPT OPTIMIZATION — Stage {self.stage}")
print(f" Iterations: {self.iterations}, Variants/iter: {self.variants_per_iter}")
print(f"{'='*60}\n")
print(" Scoring baseline prompt...")
best_score = self._score_variant(base_prompt, fixture)
best_prompt = base_prompt
history.append({
"iteration": 0,
"variant_index": 0,
"prompt_text": base_prompt[:200] + "..." if len(base_prompt) > 200 else base_prompt,
"prompt_length": len(base_prompt),
"composite": best_score.composite,
"scores": {d: best_score.scores.get(d, 0.0) for d in dimensions},
"error": best_score.error,
"label": "baseline",
})
if best_score.error:
print(f" ✗ Baseline scoring failed: {best_score.error}")
print(" Aborting optimization — fix the baseline first.\n")
return OptimizationResult(
best_prompt=best_prompt,
best_score=best_score,
history=history,
elapsed_seconds=round(time.monotonic() - t0, 2),
)
baseline_composite = best_score.composite
total_variants_scored = 0
self._write_progress(
phase="baseline_scored",
iteration=0, variant=0,
total_variants_scored=0,
best_composite=best_score.composite,
baseline_composite=baseline_composite,
elapsed_seconds=round(time.monotonic() - t0, 2),
best_label="baseline",
)
self._print_iteration_summary(0, best_score, is_baseline=True)
# Iterate
best_label = "baseline"
for iteration in range(1, self.iterations + 1):
print(f"\n ── Iteration {iteration}/{self.iterations} ──")
# Generate variants with stage-appropriate markers
variants = self.generator.generate(
base_prompt=best_prompt,
scores=best_score,
n=self.variants_per_iter,
stage=self.stage,
)
if not variants:
print(" ⚠ No valid variants generated — skipping iteration")
continue
# Score each variant
            iteration_best_score = best_score
            iteration_best_prompt = best_prompt
            iteration_best_label = best_label
for vi, variant_prompt in enumerate(variants):
print(f" Scoring variant {vi + 1}/{len(variants)}...")
score = self._score_variant(variant_prompt, fixture)
history.append({
"iteration": iteration,
"variant_index": vi + 1,
"prompt_text": variant_prompt[:200] + "..." if len(variant_prompt) > 200 else variant_prompt,
"prompt_length": len(variant_prompt),
"composite": score.composite,
"scores": {d: score.scores.get(d, 0.0) for d in dimensions},
"error": score.error,
"label": f"iter{iteration}_v{vi+1}",
})
if score.error:
print(f" ✗ Variant {vi + 1} errored: {score.error}")
total_variants_scored += 1
self._write_progress(
phase="variant_scored",
iteration=iteration, variant=vi + 1,
total_variants_scored=total_variants_scored,
                        best_composite=max(best_score.composite, iteration_best_score.composite),
baseline_composite=baseline_composite,
elapsed_seconds=round(time.monotonic() - t0, 2),
                        best_label=iteration_best_label,
)
continue
total_variants_scored += 1
                if score.composite > iteration_best_score.composite:
                    print(f" ✓ New best: {score.composite:.3f} (was {iteration_best_score.composite:.3f})")
                    iteration_best_score = score
                    iteration_best_prompt = variant_prompt
                    iteration_best_label = f"iter{iteration}_v{vi+1}"
                else:
                    print(f" · Score {score.composite:.3f} ≤ current best {iteration_best_score.composite:.3f}")
self._write_progress(
phase="variant_scored",
iteration=iteration, variant=vi + 1,
total_variants_scored=total_variants_scored,
best_composite=max(best_score.composite, iteration_best_score.composite),
baseline_composite=baseline_composite,
elapsed_seconds=round(time.monotonic() - t0, 2),
                    best_label=iteration_best_label,
)
# Update global best if this iteration improved
if iteration_best_score.composite > best_score.composite:
best_score = iteration_best_score
best_prompt = iteration_best_prompt
best_label = f"iter{iteration}"
print(f" ★ Iteration {iteration} improved: {best_score.composite:.3f}")
else:
print(f" · No improvement in iteration {iteration}")
self._print_iteration_summary(iteration, best_score)
# Final report
elapsed = round(time.monotonic() - t0, 2)
self._print_final_report(best_score, history, elapsed)
self._write_progress(
phase="complete",
iteration=self.iterations,
variant=self.variants_per_iter,
total_variants_scored=total_variants_scored,
best_composite=best_score.composite,
baseline_composite=baseline_composite,
elapsed_seconds=elapsed,
best_label=best_label,
)
return OptimizationResult(
best_prompt=best_prompt,
best_score=best_score,
history=history,
elapsed_seconds=elapsed,
)
# ── Internal helpers ──────────────────────────────────────────────────
def _write_progress(
self,
*,
phase: str,
iteration: int,
variant: int,
total_variants_scored: int,
best_composite: float,
baseline_composite: float,
elapsed_seconds: float,
best_label: str = "",
) -> None:
"""Write a progress.json file to the output directory for external monitoring.
File is atomic-replaced so readers never see partial writes.
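
        A minimal reader sketch (hypothetical monitor script; the path
        assumes output_dir="out" and stage 5)::

            import json
            import time
            from pathlib import Path

            path = Path("out/progress_stage5.json")
            while True:
                if path.exists():
                    prog = json.loads(path.read_text(encoding="utf-8"))
                    print(f"{prog['percent_complete']}% ETA {prog['eta_seconds']}s")
                    if prog["phase"] == "complete":
                        break
                time.sleep(5)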
"""
if not self.output_dir:
return
out_dir = Path(self.output_dir)
out_dir.mkdir(parents=True, exist_ok=True)
progress_path = out_dir / f"progress_stage{self.stage}.json"
total_expected = self.iterations * self.variants_per_iter
pct = (total_variants_scored / total_expected * 100) if total_expected else 0
# ETA: average time per variant × remaining
remaining = total_expected - total_variants_scored
avg_per_variant = (elapsed_seconds / total_variants_scored) if total_variants_scored > 0 else 0
eta_seconds = round(avg_per_variant * remaining, 1)
payload = {
"stage": self.stage,
"phase": phase,
"iteration": iteration,
"total_iterations": self.iterations,
"variant": variant,
"variants_per_iter": self.variants_per_iter,
"total_variants_scored": total_variants_scored,
"total_expected": total_expected,
"percent_complete": round(pct, 1),
"baseline_composite": round(baseline_composite, 4),
"best_composite": round(best_composite, 4),
"improvement": round(best_composite - baseline_composite, 4),
"best_label": best_label,
"elapsed_seconds": round(elapsed_seconds, 1),
"eta_seconds": eta_seconds,
"updated_at": datetime.now(timezone.utc).isoformat(),
}
        # Atomic write: dump to a temp file, then replace() it over the real
        # path so readers never observe a partial file (os.replace also
        # succeeds on Windows, where a plain rename fails if the target exists)
        tmp_path = progress_path.with_suffix(".tmp")
        tmp_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
        tmp_path.replace(progress_path)
def _load_fixture(self) -> dict:
"""Load and validate the fixture JSON file against stage-specific keys."""
path = Path(self.fixture_path)
if not path.exists():
raise FileNotFoundError(f"Fixture not found: {path}")
data = json.loads(path.read_text(encoding="utf-8"))
for key in self.config.fixture_keys:
if key not in data:
raise KeyError(
f"Stage {self.stage} fixture must contain '{key}' key "
f"(required: {self.config.fixture_keys})"
)
return data
def _score_variant(
self,
variant_prompt: str,
fixture: dict,
) -> ScoreResult:
"""Score a variant prompt by running LLM completion + scoring.
Dispatches to stage-specific synthesis logic:
- Stages 2-4: call LLM with the variant prompt and fixture input,
parse with the stage's schema, then score via score_stage_output()
- Stage 5: original flow (synthesis + page scoring)
"""
        from pipeline.stages import _get_stage_config
        import openai as _openai
model_override, modality = _get_stage_config(self.stage)
schema_class = self.config.get_schema()
# Build user prompt from fixture data — stage-specific formatting
user_prompt = self._build_user_prompt(fixture)
t0 = time.monotonic()
try:
raw = self.client.complete(
system_prompt=variant_prompt,
user_prompt=user_prompt,
response_model=schema_class,
modality=modality,
model_override=model_override,
)
elapsed_synth = round(time.monotonic() - t0, 2)
except (_openai.APIConnectionError, _openai.APITimeoutError) as exc:
elapsed_synth = round(time.monotonic() - t0, 2)
return ScoreResult(
elapsed_seconds=elapsed_synth,
error=f"LLM error (stage {self.stage}): {exc}",
)
except Exception as exc:
elapsed_synth = round(time.monotonic() - t0, 2)
logger.exception("Unexpected error during variant synthesis (stage %d)", self.stage)
return ScoreResult(
elapsed_seconds=elapsed_synth,
error=f"Unexpected synthesis error: {exc}",
)
# Parse the LLM response into the stage schema
raw_text = str(raw).strip()
try:
parsed = self.client.parse_response(raw_text, schema_class)
except Exception as exc:
return ScoreResult(
elapsed_seconds=elapsed_synth,
error=f"Variant parse error (stage {self.stage}): {exc}",
)
# Convert parsed output to JSON for the scorer
output_json = self._schema_to_output_json(parsed)
if output_json is None:
return ScoreResult(
elapsed_seconds=elapsed_synth,
error=f"Stage {self.stage} produced empty output",
)
# Score using the generic stage scorer
result = self.scorer.score_stage_output(
stage=self.stage,
output_json=output_json,
input_json=self._fixture_to_input_json(fixture),
)
result.elapsed_seconds = round(result.elapsed_seconds + elapsed_synth, 2)
return result
def _build_user_prompt(self, fixture: dict) -> str:
"""Build a stage-appropriate user prompt from fixture data."""
if self.stage == 2:
segments_json = json.dumps(fixture["transcript_segments"], indent=2)
return f"<transcript_segments>\n{segments_json}\n</transcript_segments>"
elif self.stage == 3:
segments_json = json.dumps(fixture["topic_segments"], indent=2)
return f"<topic_segments>\n{segments_json}\n</topic_segments>"
elif self.stage == 4:
moments_json = json.dumps(fixture["extracted_moments"], indent=2)
taxonomy = fixture.get("taxonomy", "")
prompt = f"<moments>\n{moments_json}\n</moments>"
if taxonomy:
prompt += f"\n<taxonomy>{taxonomy}</taxonomy>"
return prompt
elif self.stage == 5:
moments_json = json.dumps(fixture["moments"], indent=2)
creator = fixture.get("creator_name", "Unknown")
return f"<creator>{creator}</creator>\n<moments>\n{moments_json}\n</moments>"
else:
return json.dumps(fixture, indent=2)
def _schema_to_output_json(self, parsed: object) -> dict | list | None:
"""Convert a parsed Pydantic schema instance to JSON-serializable dict."""
if hasattr(parsed, "model_dump"):
return parsed.model_dump()
elif hasattr(parsed, "dict"):
return parsed.dict()
return None
def _fixture_to_input_json(self, fixture: dict) -> dict | list:
"""Extract the primary input data from the fixture for scorer context."""
if self.stage == 2:
return fixture["transcript_segments"]
elif self.stage == 3:
return fixture["topic_segments"]
elif self.stage == 4:
return fixture["extracted_moments"]
elif self.stage == 5:
return fixture["moments"]
return fixture
def _print_iteration_summary(
self,
iteration: int,
score: ScoreResult,
is_baseline: bool = False,
) -> None:
"""Print a compact one-line summary of the current best scores."""
label = "BASELINE" if is_baseline else f"ITER {iteration}"
dimensions = self.config.dimensions
dims = " ".join(
f"{d[:4]}={score.scores.get(d, 0.0):.2f}" for d in dimensions
)
print(f" [{label}] composite={score.composite:.3f} {dims}")
def _print_final_report(
self,
best_score: ScoreResult,
history: list[dict],
elapsed: float,
) -> None:
"""Print the final optimization summary."""
dimensions = self.config.dimensions
print(f"\n{'='*60}")
print(" OPTIMIZATION COMPLETE")
print(f"{'='*60}")
print(f" Total time: {elapsed}s")
print(f" Iterations: {self.iterations}")
print(f" Variants scored: {len(history) - 1}") # minus baseline
baseline_composite = history[0]["composite"] if history else 0.0
improvement = best_score.composite - baseline_composite
print(f"\n Baseline composite: {baseline_composite:.3f}")
print(f" Best composite: {best_score.composite:.3f}")
if improvement > 0:
print(f" Improvement: +{improvement:.3f}")
else:
print(f" Improvement: {improvement:.3f} (no gain)")
print(f"\n Per-dimension best scores:")
for d in dimensions:
val = best_score.scores.get(d, 0.0)
bar = "" * int(val * 20) + "" * (20 - int(val * 20))
print(f" {d.replace('_', ' ').title():25s} {val:.2f} {bar}")
errored = sum(1 for h in history if h.get("error"))
if errored:
print(f"\n{errored} variant(s) errored during scoring")
print(f"{'='*60}\n")
# Late import to avoid circular dependency (scorer imports at module level,
# variant_generator imports scorer)
from pipeline.quality.variant_generator import PromptVariantGenerator # noqa: E402