chrysopedia/backend/pipeline/quality/chat_eval.py
jlightner 846db2aad5 test: Created chat-specific LLM-as-judge scorer (5 dimensions), SSE-par…
- "backend/pipeline/quality/chat_scorer.py"
- "backend/pipeline/quality/chat_eval.py"
- "backend/pipeline/quality/fixtures/chat_test_suite.yaml"
- "backend/pipeline/quality/__main__.py"

GSD-Task: S09/T01
2026-04-04 14:43:52 +00:00

352 lines
13 KiB
Python

"""Chat evaluation harness — sends queries to the live chat endpoint, scores responses.
Loads a test suite (YAML or JSON), calls the chat HTTP endpoint for each query,
parses SSE events to collect response text and sources, then scores each using
ChatScoreRunner. Writes results to a JSON file.
Usage:
python -m pipeline.quality chat_eval --suite fixtures/chat_test_suite.yaml
python -m pipeline.quality chat_eval --suite fixtures/chat_test_suite.yaml --base-url http://ub01:8096
"""
from __future__ import annotations
import json
import logging
import time
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
import httpx
from pipeline.llm_client import LLMClient
from pipeline.quality.chat_scorer import CHAT_DIMENSIONS, ChatScoreResult, ChatScoreRunner
logger = logging.getLogger(__name__)

# Default chat service root when --base-url is not supplied.
_DEFAULT_BASE_URL = "http://localhost:8096"
# SSE streaming chat endpoint, appended to the base URL.
_CHAT_ENDPOINT = "/api/chat"
_REQUEST_TIMEOUT = 120.0  # seconds — LLM streaming can be slow
@dataclass
class ChatTestCase:
    """A single test case from the test suite."""

    # The user query sent verbatim to the chat endpoint.
    query: str
    # Optional creator name; omitted from the request payload when None.
    creator: str | None = None
    # Persona blend weight; only included in the payload when > 0.
    personality_weight: float = 0.0
    # Free-form grouping label shown in console output (e.g. "technical").
    category: str = "general"
    # Human-readable note about what the case exercises.
    description: str = ""
@dataclass
class ChatEvalResult:
    """Result of evaluating a single test case."""

    # The test case this result belongs to.
    test_case: ChatTestCase
    # Full response text accumulated from "token" SSE events.
    response: str = ""
    # Source documents delivered by the "sources" SSE event.
    sources: list[dict] = field(default_factory=list)
    # Cascade tier reported in the "done" SSE event, if any.
    cascade_tier: str = ""
    # Judge scores; None when the HTTP request itself failed.
    score: ChatScoreResult | None = None
    # Set when the HTTP call failed or the response was empty.
    request_error: str | None = None
    # Wall-clock duration of the HTTP request/stream, rounded to 2 decimals.
    latency_seconds: float = 0.0
class ChatEvalRunner:
    """Runs a chat evaluation suite against a live endpoint.

    For each test case the runner POSTs the query to the chat SSE endpoint,
    accumulates the streamed "token" events into a response string, then
    scores the response with the injected ``ChatScoreRunner``.
    """

    def __init__(
        self,
        scorer: ChatScoreRunner,
        base_url: str = _DEFAULT_BASE_URL,
        timeout: float = _REQUEST_TIMEOUT,
    ) -> None:
        """Create a runner.

        Args:
            scorer: LLM-as-judge scorer applied to each collected response.
            base_url: Root URL of the chat service.
            timeout: Per-request timeout in seconds for the SSE stream.
        """
        self.scorer = scorer
        # Strip a trailing slash so f"{base_url}{_CHAT_ENDPOINT}" joins cleanly.
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout

    @staticmethod
    def load_suite(path: str | Path) -> list[ChatTestCase]:
        """Load test cases from a YAML or JSON file.

        Expected format (YAML):
        queries:
          - query: "How do I sidechain a bass?"
            creator: null
            personality_weight: 0.0
            category: technical
            description: "Basic sidechain compression question"

        Raises:
            ImportError: for a YAML suite when PyYAML is not installed.
        """
        filepath = Path(path)
        text = filepath.read_text(encoding="utf-8")
        if filepath.suffix in (".yaml", ".yml"):
            try:
                import yaml
            except ImportError as exc:
                # Chain the cause so the missing-module traceback is preserved.
                raise ImportError(
                    "PyYAML is required to load YAML test suites. "
                    "Install with: pip install pyyaml"
                ) from exc
            data = yaml.safe_load(text)
        else:
            data = json.loads(text)
        # yaml.safe_load returns None for an empty document; treat that as an
        # empty suite rather than raising AttributeError on .get().
        queries = (data or {}).get("queries", [])
        return [
            ChatTestCase(
                query=q["query"],
                creator=q.get("creator"),
                personality_weight=float(q.get("personality_weight", 0.0)),
                category=q.get("category", "general"),
                description=q.get("description", ""),
            )
            for q in queries
        ]

    def run_suite(self, cases: list[ChatTestCase]) -> list[ChatEvalResult]:
        """Execute all test cases sequentially, scoring each response.

        Progress is printed to stdout as each case completes.
        """
        results: list[ChatEvalResult] = []
        for i, case in enumerate(cases, 1):
            print(f"\n [{i}/{len(cases)}] {case.category}: {case.query[:60]}...")
            result = self._run_single(case)
            results.append(result)
            if result.request_error:
                print(f" ✗ Request error: {result.request_error}")
            elif result.score and result.score.error:
                print(f" ✗ Scoring error: {result.score.error}")
            elif result.score:
                print(f" ✓ Composite: {result.score.composite:.3f} "
                      f"(latency: {result.latency_seconds:.1f}s)")
        return results

    def _run_single(self, case: ChatTestCase) -> ChatEvalResult:
        """Execute a single test case: call endpoint, parse SSE, score.

        Request failures are captured in ``request_error`` rather than raised,
        so one bad case cannot abort the whole suite.
        """
        eval_result = ChatEvalResult(test_case=case)
        # Call the chat endpoint, timing the full stream.
        t0 = time.monotonic()
        try:
            response_text, sources, cascade_tier = self._call_chat_endpoint(case)
            eval_result.latency_seconds = round(time.monotonic() - t0, 2)
        except Exception as exc:
            # Broad catch is intentional at this boundary: any transport or
            # protocol failure becomes a recorded per-case error.
            eval_result.latency_seconds = round(time.monotonic() - t0, 2)
            eval_result.request_error = str(exc)
            logger.error("chat_eval_request_error query=%r error=%s", case.query, exc)
            return eval_result
        eval_result.response = response_text
        eval_result.sources = sources
        eval_result.cascade_tier = cascade_tier
        if not response_text:
            eval_result.request_error = "Empty response from chat endpoint"
            return eval_result
        # Score the response
        eval_result.score = self.scorer.score_response(
            query=case.query,
            response=response_text,
            sources=sources,
            personality_weight=case.personality_weight,
            creator_name=case.creator,
        )
        return eval_result

    def _call_chat_endpoint(
        self, case: ChatTestCase
    ) -> tuple[str, list[dict], str]:
        """Call the chat SSE endpoint and parse the event stream.

        Returns (accumulated_text, sources_list, cascade_tier).

        Raises:
            RuntimeError: when the server emits an "error" SSE event.
            httpx.HTTPStatusError: for non-2xx responses.
        """
        url = f"{self.base_url}{_CHAT_ENDPOINT}"
        payload: dict[str, Any] = {"query": case.query}
        if case.creator:
            payload["creator"] = case.creator
        if case.personality_weight > 0:
            payload["personality_weight"] = case.personality_weight

        sources: list[dict] = []
        accumulated = ""
        cascade_tier = ""

        def consume(event_block: str) -> None:
            """Apply one parsed SSE event to the accumulators."""
            nonlocal sources, accumulated, cascade_tier
            event_type, event_data = self._parse_sse_event(event_block)
            if event_type == "sources":
                sources = event_data if isinstance(event_data, list) else []
            elif event_type == "token":
                accumulated += event_data if isinstance(event_data, str) else str(event_data)
            elif event_type == "done":
                if isinstance(event_data, dict):
                    cascade_tier = event_data.get("cascade_tier", "")
            elif event_type == "error":
                msg = (
                    event_data.get("message", str(event_data))
                    if isinstance(event_data, dict)
                    else str(event_data)
                )
                raise RuntimeError(f"Chat endpoint returned error: {msg}")

        with httpx.Client(timeout=self.timeout) as client:
            with client.stream("POST", url, json=payload) as resp:
                resp.raise_for_status()
                buffer = ""
                for chunk in resp.iter_text():
                    buffer += chunk
                    # SSE events are separated by a blank line ("\n\n").
                    while "\n\n" in buffer:
                        event_block, buffer = buffer.split("\n\n", 1)
                        consume(event_block)
                # Fix: flush a final event that arrived without a trailing
                # blank-line terminator — previously it was silently dropped.
                if buffer.strip():
                    consume(buffer)
        return accumulated, sources, cascade_tier

    @staticmethod
    def _parse_sse_event(block: str) -> tuple[str, Any]:
        """Parse a single SSE event block into (event_type, data).

        Multiple ``data:`` lines are joined with newlines. The joined data is
        returned as parsed JSON when possible, otherwise as the raw string.
        """
        event_type = ""
        data_lines: list[str] = []
        for line in block.strip().splitlines():
            if line.startswith("event:"):
                # Accept both "event: x" and "event:x" (the SSE spec makes the
                # space after the colon optional), mirroring the data handling.
                event_type = line[6:].strip()
            elif line.startswith("data: "):
                data_lines.append(line[6:])
            elif line.startswith("data:"):
                data_lines.append(line[5:])
        raw_data = "\n".join(data_lines)
        try:
            parsed = json.loads(raw_data)
        except (json.JSONDecodeError, ValueError):
            parsed = raw_data  # plain text token
        return event_type, parsed

    @staticmethod
    def write_results(
        results: list[ChatEvalResult],
        output_path: str | Path,
    ) -> str:
        """Write evaluation results to a JSON file. Returns the path.

        If ``output_path`` is an existing directory, a timestamped file name
        is generated inside it; otherwise the path is used as-is.
        """
        out = Path(output_path)
        out.parent.mkdir(parents=True, exist_ok=True)
        timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
        if out.is_dir():
            filepath = out / f"chat_eval_{timestamp}.json"
        else:
            filepath = out
        # Build serializable payload
        entries: list[dict] = []
        for r in results:
            entry: dict[str, Any] = {
                "query": r.test_case.query,
                "creator": r.test_case.creator,
                "personality_weight": r.test_case.personality_weight,
                "category": r.test_case.category,
                "description": r.test_case.description,
                "response_length": len(r.response),
                "source_count": len(r.sources),
                "cascade_tier": r.cascade_tier,
                "latency_seconds": r.latency_seconds,
            }
            if r.request_error:
                entry["error"] = r.request_error
            elif r.score:
                entry["scores"] = r.score.scores
                entry["composite"] = r.score.composite
                entry["justifications"] = r.score.justifications
                entry["scoring_time"] = r.score.elapsed_seconds
                if r.score.error:
                    entry["scoring_error"] = r.score.error
            entries.append(entry)
        # Summary stats over entries that were actually scored.
        scored = [e for e in entries if "composite" in e]
        avg_composite = (
            sum(e["composite"] for e in scored) / len(scored) if scored else 0.0
        )
        dim_avgs: dict[str, float] = {}
        for dim in CHAT_DIMENSIONS:
            vals = [e["scores"][dim] for e in scored if dim in e.get("scores", {})]
            dim_avgs[dim] = round(sum(vals) / len(vals), 3) if vals else 0.0
        payload = {
            "timestamp": timestamp,
            "total_queries": len(results),
            "scored_queries": len(scored),
            "errors": len(results) - len(scored),
            "average_composite": round(avg_composite, 3),
            "dimension_averages": dim_avgs,
            "results": entries,
        }
        filepath.write_text(json.dumps(payload, indent=2), encoding="utf-8")
        return str(filepath)

    @staticmethod
    def _print_errors(errored: list[ChatEvalResult]) -> None:
        """Print one line per failed case (request error or scoring error)."""
        print(f" Errors: {len(errored)}")
        for r in errored:
            err = r.request_error or (r.score.error if r.score else "unknown")
            print(f" - {r.test_case.query[:50]}: {err}")

    @staticmethod
    def print_summary(results: list[ChatEvalResult]) -> None:
        """Print a summary table of evaluation results.

        One row per successfully scored case (composite plus the five
        dimension scores), followed by column averages and any errors.
        """
        print("\n" + "=" * 72)
        print(" CHAT EVALUATION SUMMARY")
        print("=" * 72)
        scored = [r for r in results if r.score and not r.score.error and not r.request_error]
        errored = [r for r in results if r.request_error or (r.score and r.score.error)]
        if not scored:
            print("\n No successfully scored responses.\n")
            if errored:
                ChatEvalRunner._print_errors(errored)
            print("=" * 72 + "\n")
            return
        # Header
        print(f"\n {'Category':<12s} {'Query':<30s} {'Comp':>5s} {'Cite':>5s} {'Struct':>6s} {'Domain':>6s} {'Ground':>6s} {'Person':>6s}")
        # Fix: the separator row previously multiplied the EMPTY string
        # (''*12 etc.), printing nothing but spaces; use dashes as intended.
        print(f" {'-'*12} {'-'*30} {'-'*5} {'-'*5} {'-'*6} {'-'*6} {'-'*6} {'-'*6}")
        for r in scored:
            s = r.score
            assert s is not None  # type narrowing; guaranteed by `scored` filter
            q = r.test_case.query[:30]
            cat = r.test_case.category[:12]
            print(
                f" {cat:<12s} {q:<30s} "
                f"{s.composite:5.2f} "
                f"{s.citation_accuracy:5.2f} "
                f"{s.response_structure:6.2f} "
                f"{s.domain_expertise:6.2f} "
                f"{s.source_grounding:6.2f} "
                f"{s.personality_fidelity:6.2f}"
            )
        # Averages
        avg_comp = sum(r.score.composite for r in scored) / len(scored)
        avg_dims: dict[str, float] = {}
        for dim in CHAT_DIMENSIONS:
            vals = [r.score.scores.get(dim, 0.0) for r in scored]
            avg_dims[dim] = sum(vals) / len(vals)
        print(f"\n {'AVERAGE':<12s} {'':30s} "
              f"{avg_comp:5.2f} "
              f"{avg_dims['citation_accuracy']:5.2f} "
              f"{avg_dims['response_structure']:6.2f} "
              f"{avg_dims['domain_expertise']:6.2f} "
              f"{avg_dims['source_grounding']:6.2f} "
              f"{avg_dims['personality_fidelity']:6.2f}")
        if errored:
            print()
            ChatEvalRunner._print_errors(errored)
        print("=" * 72 + "\n")