# GSD-Task: S09/T01 — related files:
#   backend/pipeline/quality/chat_scorer.py
#   backend/pipeline/quality/chat_eval.py
#   backend/pipeline/quality/fixtures/chat_test_suite.yaml
#   backend/pipeline/quality/__main__.py
"""Chat evaluation harness — sends queries to the live chat endpoint, scores responses.

Loads a test suite (YAML or JSON), calls the chat HTTP endpoint for each query,
parses SSE events to collect response text and sources, then scores each using
ChatScoreRunner. Writes results to a JSON file.

Usage:
    python -m pipeline.quality chat_eval --suite fixtures/chat_test_suite.yaml
    python -m pipeline.quality chat_eval --suite fixtures/chat_test_suite.yaml --base-url http://ub01:8096
"""
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
import time
|
|
from dataclasses import dataclass, field
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import httpx
|
|
|
|
from pipeline.llm_client import LLMClient
|
|
from pipeline.quality.chat_scorer import CHAT_DIMENSIONS, ChatScoreResult, ChatScoreRunner
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
_DEFAULT_BASE_URL = "http://localhost:8096"
|
|
_CHAT_ENDPOINT = "/api/chat"
|
|
_REQUEST_TIMEOUT = 120.0 # seconds — LLM streaming can be slow
|
|
|
|
|
|
@dataclass
|
|
class ChatTestCase:
|
|
"""A single test case from the test suite."""
|
|
|
|
query: str
|
|
creator: str | None = None
|
|
personality_weight: float = 0.0
|
|
category: str = "general"
|
|
description: str = ""
|
|
|
|
|
|
@dataclass
|
|
class ChatEvalResult:
|
|
"""Result of evaluating a single test case."""
|
|
|
|
test_case: ChatTestCase
|
|
response: str = ""
|
|
sources: list[dict] = field(default_factory=list)
|
|
cascade_tier: str = ""
|
|
score: ChatScoreResult | None = None
|
|
request_error: str | None = None
|
|
latency_seconds: float = 0.0
|
|
|
|
|
|
class ChatEvalRunner:
    """Runs a chat evaluation suite against a live endpoint.

    For each test case the runner POSTs the query to the chat SSE endpoint,
    accumulates the streamed response text and sources, then delegates
    scoring to the injected ChatScoreRunner.
    """

    def __init__(
        self,
        scorer: ChatScoreRunner,
        base_url: str = _DEFAULT_BASE_URL,
        timeout: float = _REQUEST_TIMEOUT,
    ) -> None:
        """Initialize the runner.

        Args:
            scorer: Scorer used to grade each response.
            base_url: Base URL of the chat service (no trailing slash needed).
            timeout: Per-request timeout in seconds for the SSE stream.
        """
        self.scorer = scorer
        # Strip trailing slash so joining with _CHAT_ENDPOINT never doubles it.
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout

    @staticmethod
    def load_suite(path: str | Path) -> list[ChatTestCase]:
        """Load test cases from a YAML or JSON file.

        Expected format (YAML):
            queries:
              - query: "How do I sidechain a bass?"
                creator: null
                personality_weight: 0.0
                category: technical
                description: "Basic sidechain compression question"

        Returns:
            The parsed test cases; an empty list for an empty suite file.

        Raises:
            ImportError: A YAML suite was given but PyYAML is not installed.
            ValueError: The file parses but is not a mapping at the top level.
        """
        filepath = Path(path)
        text = filepath.read_text(encoding="utf-8")

        if filepath.suffix in (".yaml", ".yml"):
            try:
                import yaml
            except ImportError as exc:
                # Chain the cause so the original import failure is visible.
                raise ImportError(
                    "PyYAML is required to load YAML test suites. "
                    "Install with: pip install pyyaml"
                ) from exc
            data = yaml.safe_load(text)
        else:
            data = json.loads(text)

        # yaml.safe_load returns None for an empty document; treat that as
        # "no queries" instead of crashing on data.get() below.
        if data is None:
            return []
        if not isinstance(data, dict):
            raise ValueError(
                "Test suite must be a mapping with a 'queries' key, "
                f"got {type(data).__name__}"
            )

        cases: list[ChatTestCase] = []
        for q in data.get("queries", []):
            cases.append(ChatTestCase(
                # "query" is required — a KeyError here means a malformed entry.
                query=q["query"],
                creator=q.get("creator"),
                personality_weight=float(q.get("personality_weight", 0.0)),
                category=q.get("category", "general"),
                description=q.get("description", ""),
            ))
        return cases

    def run_suite(self, cases: list[ChatTestCase]) -> list[ChatEvalResult]:
        """Execute all test cases sequentially, scoring each response.

        Prints a one-line progress/status message per case; never raises on
        a failed case (failures are captured in the result objects).
        """
        results: list[ChatEvalResult] = []

        for i, case in enumerate(cases, 1):
            print(f"\n  [{i}/{len(cases)}] {case.category}: {case.query[:60]}...")
            result = self._run_single(case)
            results.append(result)

            if result.request_error:
                print(f"    ✗ Request error: {result.request_error}")
            elif result.score and result.score.error:
                print(f"    ✗ Scoring error: {result.score.error}")
            elif result.score:
                print(f"    ✓ Composite: {result.score.composite:.3f} "
                      f"(latency: {result.latency_seconds:.1f}s)")

        return results

    def _run_single(self, case: ChatTestCase) -> ChatEvalResult:
        """Execute a single test case: call endpoint, parse SSE, score.

        Any exception from the HTTP call is captured into
        ``request_error`` rather than propagated, so one bad case cannot
        abort the whole suite.
        """
        eval_result = ChatEvalResult(test_case=case)

        # Call the chat endpoint; monotonic clock so latency is wall-clock safe.
        t0 = time.monotonic()
        try:
            response_text, sources, cascade_tier = self._call_chat_endpoint(case)
            eval_result.latency_seconds = round(time.monotonic() - t0, 2)
        except Exception as exc:
            eval_result.latency_seconds = round(time.monotonic() - t0, 2)
            eval_result.request_error = str(exc)
            logger.error("chat_eval_request_error query=%r error=%s", case.query, exc)
            return eval_result

        eval_result.response = response_text
        eval_result.sources = sources
        eval_result.cascade_tier = cascade_tier

        if not response_text:
            eval_result.request_error = "Empty response from chat endpoint"
            return eval_result

        # Score the response
        eval_result.score = self.scorer.score_response(
            query=case.query,
            response=response_text,
            sources=sources,
            personality_weight=case.personality_weight,
            creator_name=case.creator,
        )

        return eval_result

    def _call_chat_endpoint(
        self, case: ChatTestCase
    ) -> tuple[str, list[dict], str]:
        """Call the chat SSE endpoint and parse the event stream.

        Returns (accumulated_text, sources_list, cascade_tier).

        Raises:
            RuntimeError: The stream delivered an "error" event.
            httpx.HTTPStatusError: Non-2xx HTTP response.
        """
        url = f"{self.base_url}{_CHAT_ENDPOINT}"
        payload: dict[str, Any] = {"query": case.query}
        if case.creator:
            payload["creator"] = case.creator
        if case.personality_weight > 0:
            payload["personality_weight"] = case.personality_weight

        sources: list[dict] = []
        accumulated = ""
        cascade_tier = ""

        with httpx.Client(timeout=self.timeout) as client:
            with client.stream("POST", url, json=payload) as resp:
                resp.raise_for_status()

                # SSE events are separated by blank lines; buffer chunks and
                # peel off complete event blocks as they arrive.
                buffer = ""
                for chunk in resp.iter_text():
                    buffer += chunk
                    while "\n\n" in buffer:
                        event_block, buffer = buffer.split("\n\n", 1)
                        event_type, event_data = self._parse_sse_event(event_block)

                        if event_type == "sources":
                            sources = event_data if isinstance(event_data, list) else []
                        elif event_type == "token":
                            accumulated += event_data if isinstance(event_data, str) else str(event_data)
                        elif event_type == "done":
                            if isinstance(event_data, dict):
                                cascade_tier = event_data.get("cascade_tier", "")
                        elif event_type == "error":
                            msg = event_data.get("message", str(event_data)) if isinstance(event_data, dict) else str(event_data)
                            raise RuntimeError(f"Chat endpoint returned error: {msg}")

        return accumulated, sources, cascade_tier

    @staticmethod
    def _parse_sse_event(block: str) -> tuple[str, Any]:
        """Parse a single SSE event block into (event_type, data).

        Per the SSE spec, a field is "name:" followed by an optional single
        space; both forms are accepted for "event" and "data". Multiple
        data lines are joined with newlines. Data that is not valid JSON is
        returned as the raw string (a plain text token).
        """
        event_type = ""
        data_lines: list[str] = []

        for line in block.strip().splitlines():
            if line.startswith("event: "):
                event_type = line[7:].strip()
            elif line.startswith("event:"):
                # SSE allows the field colon with no following space.
                event_type = line[6:].strip()
            elif line.startswith("data: "):
                data_lines.append(line[6:])
            elif line.startswith("data:"):
                data_lines.append(line[5:])

        raw_data = "\n".join(data_lines)
        try:
            parsed = json.loads(raw_data)
        except ValueError:  # json.JSONDecodeError subclasses ValueError
            parsed = raw_data  # plain text token
        return event_type, parsed

    @staticmethod
    def write_results(
        results: list[ChatEvalResult],
        output_path: str | Path,
    ) -> str:
        """Write evaluation results to a JSON file. Returns the path.

        If ``output_path`` is an existing directory, a timestamped file name
        is generated inside it; otherwise it is used as the file path.
        """
        out = Path(output_path)
        out.parent.mkdir(parents=True, exist_ok=True)

        timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
        if out.is_dir():
            filepath = out / f"chat_eval_{timestamp}.json"
        else:
            filepath = out

        # Build serializable payload
        entries: list[dict] = []
        for r in results:
            entry: dict[str, Any] = {
                "query": r.test_case.query,
                "creator": r.test_case.creator,
                "personality_weight": r.test_case.personality_weight,
                "category": r.test_case.category,
                "description": r.test_case.description,
                "response_length": len(r.response),
                "source_count": len(r.sources),
                "cascade_tier": r.cascade_tier,
                "latency_seconds": r.latency_seconds,
            }

            if r.request_error:
                entry["error"] = r.request_error
            elif r.score:
                entry["scores"] = r.score.scores
                entry["composite"] = r.score.composite
                entry["justifications"] = r.score.justifications
                entry["scoring_time"] = r.score.elapsed_seconds
                if r.score.error:
                    entry["scoring_error"] = r.score.error

            entries.append(entry)

        # Summary stats — only entries that produced a composite score count.
        scored = [e for e in entries if "composite" in e]
        avg_composite = (
            sum(e["composite"] for e in scored) / len(scored) if scored else 0.0
        )
        dim_avgs: dict[str, float] = {}
        for dim in CHAT_DIMENSIONS:
            vals = [e["scores"][dim] for e in scored if dim in e.get("scores", {})]
            dim_avgs[dim] = round(sum(vals) / len(vals), 3) if vals else 0.0

        payload = {
            "timestamp": timestamp,
            "total_queries": len(results),
            "scored_queries": len(scored),
            "errors": len(results) - len(scored),
            "average_composite": round(avg_composite, 3),
            "dimension_averages": dim_avgs,
            "results": entries,
        }

        filepath.write_text(json.dumps(payload, indent=2), encoding="utf-8")
        return str(filepath)

    @staticmethod
    def print_summary(results: list[ChatEvalResult]) -> None:
        """Print a summary table of evaluation results to stdout."""
        print("\n" + "=" * 72)
        print("  CHAT EVALUATION SUMMARY")
        print("=" * 72)

        scored = [r for r in results if r.score and not r.score.error and not r.request_error]
        errored = [r for r in results if r.request_error or (r.score and r.score.error)]

        if not scored:
            print("\n  No successfully scored responses.\n")
            if errored:
                print(f"  Errors: {len(errored)}")
                for r in errored:
                    err = r.request_error or (r.score.error if r.score else "unknown")
                    print(f"    - {r.test_case.query[:50]}: {err}")
            print("=" * 72 + "\n")
            return

        # Header
        print(f"\n  {'Category':<12s} {'Query':<30s} {'Comp':>5s} {'Cite':>5s} {'Struct':>6s} {'Domain':>6s} {'Ground':>6s} {'Person':>6s}")
        print(f"  {'─'*12} {'─'*30} {'─'*5} {'─'*5} {'─'*6} {'─'*6} {'─'*6} {'─'*6}")

        for r in scored:
            s = r.score
            assert s is not None  # guaranteed by the `scored` filter above
            q = r.test_case.query[:30]
            cat = r.test_case.category[:12]
            print(
                f"  {cat:<12s} {q:<30s} "
                f"{s.composite:5.2f} "
                f"{s.citation_accuracy:5.2f} "
                f"{s.response_structure:6.2f} "
                f"{s.domain_expertise:6.2f} "
                f"{s.source_grounding:6.2f} "
                f"{s.personality_fidelity:6.2f}"
            )

        # Averages — missing dimensions count as 0.0 toward the mean.
        avg_comp = sum(r.score.composite for r in scored) / len(scored)
        avg_dims = {}
        for dim in CHAT_DIMENSIONS:
            vals = [r.score.scores.get(dim, 0.0) for r in scored]
            avg_dims[dim] = sum(vals) / len(vals)

        print(f"\n  {'AVERAGE':<12s} {'':30s} "
              f"{avg_comp:5.2f} "
              f"{avg_dims['citation_accuracy']:5.2f} "
              f"{avg_dims['response_structure']:6.2f} "
              f"{avg_dims['domain_expertise']:6.2f} "
              f"{avg_dims['source_grounding']:6.2f} "
              f"{avg_dims['personality_fidelity']:6.2f}")

        if errored:
            print(f"\n  Errors: {len(errored)}")
            for r in errored:
                err = r.request_error or (r.score.error if r.score else "unknown")
                print(f"    - {r.test_case.query[:50]}: {err}")

        print("=" * 72 + "\n")
|