"""Automated prompt optimization loop for pipeline stages 2-5.
|
||
|
||
Orchestrates a generate→score→select cycle:
|
||
1. Score the current best prompt against reference fixtures
|
||
2. Generate N variants targeting weak dimensions
|
||
3. Score each variant
|
||
4. Keep the best scorer as the new baseline
|
||
5. Repeat for K iterations
|
||
|
||
Usage (via CLI):
|
||
python -m pipeline.quality optimize --stage 5 --iterations 10
|
||
python -m pipeline.quality optimize --stage 3 --iterations 5 --file fixtures/sample_topic_group.json
|
||
"""
|
||
from __future__ import annotations

import json
import logging
import time
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path

from pipeline.llm_client import LLMClient
from pipeline.quality.scorer import STAGE_CONFIGS, ScoreResult, ScoreRunner

logger = logging.getLogger(__name__)


@dataclass
class OptimizationResult:
    """Full result of an optimization run."""

    best_prompt: str = ""
    best_score: ScoreResult = field(default_factory=ScoreResult)
    history: list[dict] = field(default_factory=list)
    elapsed_seconds: float = 0.0


class OptimizationLoop:
"""Runs iterative prompt optimization for a pipeline stage.
|
||
|
||
Each iteration generates *variants_per_iter* prompt mutations,
|
||
scores each against reference fixture data, and keeps the
|
||
highest-composite-scoring variant as the new baseline.
|
||
|
||
Parameters
|
||
----------
|
||
client:
|
||
LLMClient instance for LLM calls (synthesis + scoring + variant gen).
|
||
stage:
|
||
Pipeline stage number (2-5).
|
||
fixture_path:
|
||
Path to a JSON fixture file matching the stage's expected keys.
|
||
iterations:
|
||
Number of generate→score→select cycles.
|
||
variants_per_iter:
|
||
Number of variant prompts to generate per iteration.
|
||
"""

    def __init__(
        self,
        client: LLMClient,
        stage: int,
        fixture_path: str,
        iterations: int = 5,
        variants_per_iter: int = 2,
        output_dir: str | None = None,
    ) -> None:
        if stage not in STAGE_CONFIGS:
            raise ValueError(
                f"Unsupported stage {stage}. Valid stages: {sorted(STAGE_CONFIGS)}"
            )

        self.client = client
        self.stage = stage
        self.fixture_path = fixture_path
        self.iterations = iterations
        self.variants_per_iter = variants_per_iter
        self.config = STAGE_CONFIGS[stage]
        self.output_dir = output_dir

        self.scorer = ScoreRunner(client)
        self.generator = PromptVariantGenerator(client)
    def run(self) -> OptimizationResult:
        """Execute the full optimization loop.

        Returns
        -------
        OptimizationResult
            Contains the best prompt, its scores, full iteration history,
            and wall-clock elapsed time.
        """
        from pipeline.stages import _load_prompt

        t0 = time.monotonic()
        dimensions = self.config.dimensions

        # Load base prompt using the stage's configured prompt file
        prompt_file = self.config.prompt_file
        try:
            base_prompt = _load_prompt(prompt_file)
        except FileNotFoundError:
            logger.error("Prompt file not found: %s", prompt_file)
            return OptimizationResult(
                best_prompt="",
                best_score=ScoreResult(error=f"Prompt file not found: {prompt_file}"),
                elapsed_seconds=round(time.monotonic() - t0, 2),
            )

        # Load fixture data
        try:
            fixture = self._load_fixture()
        except (FileNotFoundError, json.JSONDecodeError, KeyError) as exc:
            logger.error("Failed to load fixture: %s", exc)
            return OptimizationResult(
                best_prompt=base_prompt,
                best_score=ScoreResult(error=f"Fixture load error: {exc}"),
                elapsed_seconds=round(time.monotonic() - t0, 2),
            )

        history: list[dict] = []

        # Score the baseline
        print(f"\n{'='*60}")
        print(f" PROMPT OPTIMIZATION — Stage {self.stage}")
        print(f" Iterations: {self.iterations}, Variants/iter: {self.variants_per_iter}")
        print(f"{'='*60}\n")

        print(" Scoring baseline prompt...")
        best_score = self._score_variant(base_prompt, fixture)
        best_prompt = base_prompt

        history.append({
            "iteration": 0,
            "variant_index": 0,
            "prompt_text": base_prompt[:200] + "..." if len(base_prompt) > 200 else base_prompt,
            "prompt_length": len(base_prompt),
            "composite": best_score.composite,
            "scores": {d: best_score.scores.get(d, 0.0) for d in dimensions},
            "error": best_score.error,
            "label": "baseline",
        })

        if best_score.error:
            print(f" ✗ Baseline scoring failed: {best_score.error}")
            print(" Aborting optimization — fix the baseline first.\n")
            return OptimizationResult(
                best_prompt=best_prompt,
                best_score=best_score,
                history=history,
                elapsed_seconds=round(time.monotonic() - t0, 2),
            )

        baseline_composite = best_score.composite
        total_variants_scored = 0

        self._write_progress(
            phase="baseline_scored",
            iteration=0, variant=0,
            total_variants_scored=0,
            best_composite=best_score.composite,
            baseline_composite=baseline_composite,
            elapsed_seconds=round(time.monotonic() - t0, 2),
            best_label="baseline",
        )

        self._print_iteration_summary(0, best_score, is_baseline=True)

        # Iterate
        best_label = "baseline"
        for iteration in range(1, self.iterations + 1):
            print(f"\n ── Iteration {iteration}/{self.iterations} ──")

            # Generate variants with stage-appropriate markers
            variants = self.generator.generate(
                base_prompt=best_prompt,
                scores=best_score,
                n=self.variants_per_iter,
                stage=self.stage,
            )

            if not variants:
                print(" ⚠ No valid variants generated — skipping iteration")
                continue

            # Score each variant
            iteration_best_score = best_score
            iteration_best_prompt = best_prompt

            for vi, variant_prompt in enumerate(variants):
                print(f" Scoring variant {vi + 1}/{len(variants)}...")

                score = self._score_variant(variant_prompt, fixture)

                history.append({
                    "iteration": iteration,
                    "variant_index": vi + 1,
                    "prompt_text": variant_prompt[:200] + "..." if len(variant_prompt) > 200 else variant_prompt,
                    "prompt_length": len(variant_prompt),
                    "composite": score.composite,
                    "scores": {d: score.scores.get(d, 0.0) for d in dimensions},
                    "error": score.error,
                    "label": f"iter{iteration}_v{vi+1}",
                })

                if score.error:
                    print(f" ✗ Variant {vi + 1} errored: {score.error}")
                    total_variants_scored += 1
                    self._write_progress(
                        phase="variant_scored",
                        iteration=iteration, variant=vi + 1,
                        total_variants_scored=total_variants_scored,
                        best_composite=best_score.composite,
                        baseline_composite=baseline_composite,
                        elapsed_seconds=round(time.monotonic() - t0, 2),
                        best_label=best_label,
                    )
                    continue

                total_variants_scored += 1

                if score.composite > iteration_best_score.composite:
                    iteration_best_score = score
                    iteration_best_prompt = variant_prompt
                    print(f" ✓ New best: {score.composite:.3f} (was {best_score.composite:.3f})")
                else:
                    print(f" · Score {score.composite:.3f} ≤ current best {iteration_best_score.composite:.3f}")

                self._write_progress(
                    phase="variant_scored",
                    iteration=iteration, variant=vi + 1,
                    total_variants_scored=total_variants_scored,
                    best_composite=max(best_score.composite, iteration_best_score.composite),
                    baseline_composite=baseline_composite,
                    elapsed_seconds=round(time.monotonic() - t0, 2),
                    best_label=best_label if iteration_best_score.composite <= best_score.composite
                    else f"iter{iteration}_v{vi+1}",
                )

            # Update global best if this iteration improved
            if iteration_best_score.composite > best_score.composite:
                best_score = iteration_best_score
                best_prompt = iteration_best_prompt
                best_label = f"iter{iteration}"
                print(f" ★ Iteration {iteration} improved: {best_score.composite:.3f}")
            else:
                print(f" · No improvement in iteration {iteration}")

            self._print_iteration_summary(iteration, best_score)

        # Final report
        elapsed = round(time.monotonic() - t0, 2)
        self._print_final_report(best_score, history, elapsed)

        self._write_progress(
            phase="complete",
            iteration=self.iterations,
            variant=self.variants_per_iter,
            total_variants_scored=total_variants_scored,
            best_composite=best_score.composite,
            baseline_composite=baseline_composite,
            elapsed_seconds=elapsed,
            best_label=best_label,
        )

        return OptimizationResult(
            best_prompt=best_prompt,
            best_score=best_score,
            history=history,
            elapsed_seconds=elapsed,
        )

    # ── Internal helpers ──────────────────────────────────────────────────

    def _write_progress(
        self,
        *,
        phase: str,
        iteration: int,
        variant: int,
        total_variants_scored: int,
        best_composite: float,
        baseline_composite: float,
        elapsed_seconds: float,
        best_label: str = "",
    ) -> None:
        """Write a progress.json file to the output directory for external monitoring.

        The file is atomically replaced so readers never see partial writes.
        """
        if not self.output_dir:
            return

        out_dir = Path(self.output_dir)
        out_dir.mkdir(parents=True, exist_ok=True)
        progress_path = out_dir / f"progress_stage{self.stage}.json"

        total_expected = self.iterations * self.variants_per_iter
        pct = (total_variants_scored / total_expected * 100) if total_expected else 0

        # ETA: average time per variant × remaining
        remaining = total_expected - total_variants_scored
        avg_per_variant = (elapsed_seconds / total_variants_scored) if total_variants_scored > 0 else 0
        eta_seconds = round(avg_per_variant * remaining, 1)

        payload = {
            "stage": self.stage,
            "phase": phase,
            "iteration": iteration,
            "total_iterations": self.iterations,
            "variant": variant,
            "variants_per_iter": self.variants_per_iter,
            "total_variants_scored": total_variants_scored,
            "total_expected": total_expected,
            "percent_complete": round(pct, 1),
            "baseline_composite": round(baseline_composite, 4),
            "best_composite": round(best_composite, 4),
            "improvement": round(best_composite - baseline_composite, 4),
            "best_label": best_label,
            "elapsed_seconds": round(elapsed_seconds, 1),
            "eta_seconds": eta_seconds,
            "updated_at": datetime.now(timezone.utc).isoformat(),
        }

        # Atomic write via temp file + rename
        tmp_path = progress_path.with_suffix(".tmp")
        tmp_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
        tmp_path.rename(progress_path)

    def _load_fixture(self) -> dict:
        """Load and validate the fixture JSON file against stage-specific keys."""
        path = Path(self.fixture_path)
        if not path.exists():
            raise FileNotFoundError(f"Fixture not found: {path}")
        data = json.loads(path.read_text(encoding="utf-8"))

        for key in self.config.fixture_keys:
            if key not in data:
                raise KeyError(
                    f"Stage {self.stage} fixture must contain '{key}' key "
                    f"(required: {self.config.fixture_keys})"
                )

        return data

    def _score_variant(
        self,
        variant_prompt: str,
        fixture: dict,
    ) -> ScoreResult:
        """Score a variant prompt by running LLM completion + scoring.

        Dispatches to stage-specific synthesis logic:
        - Stages 2-4: call LLM with the variant prompt and fixture input,
          parse with the stage's schema, then score via score_stage_output()
        - Stage 5: original flow (synthesis + page scoring)
        """
        from pipeline.stages import _get_stage_config

        import json as _json
        import openai as _openai

        model_override, modality = _get_stage_config(self.stage)
        schema_class = self.config.get_schema()

        # Build user prompt from fixture data — stage-specific formatting
        user_prompt = self._build_user_prompt(fixture)

        t0 = time.monotonic()
        try:
            raw = self.client.complete(
                system_prompt=variant_prompt,
                user_prompt=user_prompt,
                response_model=schema_class,
                modality=modality,
                model_override=model_override,
            )
            elapsed_synth = round(time.monotonic() - t0, 2)
        except (_openai.APIConnectionError, _openai.APITimeoutError) as exc:
            elapsed_synth = round(time.monotonic() - t0, 2)
            return ScoreResult(
                elapsed_seconds=elapsed_synth,
                error=f"LLM error (stage {self.stage}): {exc}",
            )
        except Exception as exc:
            elapsed_synth = round(time.monotonic() - t0, 2)
            logger.exception("Unexpected error during variant synthesis (stage %d)", self.stage)
            return ScoreResult(
                elapsed_seconds=elapsed_synth,
                error=f"Unexpected synthesis error: {exc}",
            )

        # Parse the LLM response into the stage schema
        raw_text = str(raw).strip()
        try:
            parsed = self.client.parse_response(raw_text, schema_class)
        except Exception as exc:
            return ScoreResult(
                elapsed_seconds=elapsed_synth,
                error=f"Variant parse error (stage {self.stage}): {exc}",
            )

        # Convert parsed output to JSON for the scorer
        output_json = self._schema_to_output_json(parsed)
        if output_json is None:
            return ScoreResult(
                elapsed_seconds=elapsed_synth,
                error=f"Stage {self.stage} produced empty output",
            )

        # Score using the generic stage scorer
        result = self.scorer.score_stage_output(
            stage=self.stage,
            output_json=output_json,
            input_json=self._fixture_to_input_json(fixture),
        )
        result.elapsed_seconds = round(result.elapsed_seconds + elapsed_synth, 2)
        return result

    def _build_user_prompt(self, fixture: dict) -> str:
        """Build a stage-appropriate user prompt from fixture data."""
        if self.stage == 2:
            segments_json = json.dumps(fixture["transcript_segments"], indent=2)
            return f"<transcript_segments>\n{segments_json}\n</transcript_segments>"

        elif self.stage == 3:
            segments_json = json.dumps(fixture["topic_segments"], indent=2)
            return f"<topic_segments>\n{segments_json}\n</topic_segments>"

        elif self.stage == 4:
            moments_json = json.dumps(fixture["extracted_moments"], indent=2)
            taxonomy = fixture.get("taxonomy", "")
            prompt = f"<moments>\n{moments_json}\n</moments>"
            if taxonomy:
                prompt += f"\n<taxonomy>{taxonomy}</taxonomy>"
            return prompt

        elif self.stage == 5:
            moments_json = json.dumps(fixture["moments"], indent=2)
            creator = fixture.get("creator_name", "Unknown")
            return f"<creator>{creator}</creator>\n<moments>\n{moments_json}\n</moments>"

        else:
            return json.dumps(fixture, indent=2)

    def _schema_to_output_json(self, parsed: object) -> dict | list | None:
        """Convert a parsed Pydantic schema instance to a JSON-serializable dict."""
        if hasattr(parsed, "model_dump"):
            return parsed.model_dump()
        elif hasattr(parsed, "dict"):
            return parsed.dict()
        return None

    def _fixture_to_input_json(self, fixture: dict) -> dict | list:
        """Extract the primary input data from the fixture for scorer context."""
        if self.stage == 2:
            return fixture["transcript_segments"]
        elif self.stage == 3:
            return fixture["topic_segments"]
        elif self.stage == 4:
            return fixture["extracted_moments"]
        elif self.stage == 5:
            return fixture["moments"]
        return fixture

    def _print_iteration_summary(
        self,
        iteration: int,
        score: ScoreResult,
        is_baseline: bool = False,
    ) -> None:
        """Print a compact one-line summary of the current best scores."""
        label = "BASELINE" if is_baseline else f"ITER {iteration}"
        dimensions = self.config.dimensions
        dims = " ".join(
            f"{d[:4]}={score.scores.get(d, 0.0):.2f}" for d in dimensions
        )
        print(f" [{label}] composite={score.composite:.3f} {dims}")

    def _print_final_report(
        self,
        best_score: ScoreResult,
        history: list[dict],
        elapsed: float,
    ) -> None:
        """Print the final optimization summary."""
        dimensions = self.config.dimensions

        print(f"\n{'='*60}")
        print(" OPTIMIZATION COMPLETE")
        print(f"{'='*60}")
        print(f" Total time: {elapsed}s")
        print(f" Iterations: {self.iterations}")
        print(f" Variants scored: {len(history) - 1}")  # minus baseline

        baseline_composite = history[0]["composite"] if history else 0.0
        improvement = best_score.composite - baseline_composite

        print(f"\n Baseline composite: {baseline_composite:.3f}")
        print(f" Best composite: {best_score.composite:.3f}")
        if improvement > 0:
            print(f" Improvement: +{improvement:.3f}")
        else:
            print(f" Improvement: {improvement:.3f} (no gain)")

        print(f"\n Per-dimension best scores:")
        for d in dimensions:
            val = best_score.scores.get(d, 0.0)
            bar = "█" * int(val * 20) + "░" * (20 - int(val * 20))
            print(f" {d.replace('_', ' ').title():25s} {val:.2f} {bar}")

        errored = sum(1 for h in history if h.get("error"))
        if errored:
            print(f"\n ⚠ {errored} variant(s) errored during scoring")

        print(f"{'='*60}\n")


# Imported late to avoid a circular dependency: this module imports scorer at
# module level, and variant_generator itself imports scorer.
from pipeline.quality.variant_generator import PromptVariantGenerator  # noqa: E402