"""Chat evaluation harness — sends queries to the live chat endpoint, scores responses. Loads a test suite (YAML or JSON), calls the chat HTTP endpoint for each query, parses SSE events to collect response text and sources, then scores each using ChatScoreRunner. Writes results to a JSON file. Usage: python -m pipeline.quality chat_eval --suite fixtures/chat_test_suite.yaml python -m pipeline.quality chat_eval --suite fixtures/chat_test_suite.yaml --base-url http://ub01:8096 """ from __future__ import annotations import json import logging import time from dataclasses import dataclass, field from datetime import datetime, timezone from pathlib import Path from typing import Any import httpx from pipeline.llm_client import LLMClient from pipeline.quality.chat_scorer import CHAT_DIMENSIONS, ChatScoreResult, ChatScoreRunner logger = logging.getLogger(__name__) _DEFAULT_BASE_URL = "http://localhost:8096" _CHAT_ENDPOINT = "/api/chat" _REQUEST_TIMEOUT = 120.0 # seconds — LLM streaming can be slow @dataclass class ChatTestCase: """A single test case from the test suite.""" query: str creator: str | None = None personality_weight: float = 0.0 category: str = "general" description: str = "" @dataclass class ChatEvalResult: """Result of evaluating a single test case.""" test_case: ChatTestCase response: str = "" sources: list[dict] = field(default_factory=list) cascade_tier: str = "" score: ChatScoreResult | None = None request_error: str | None = None latency_seconds: float = 0.0 class ChatEvalRunner: """Runs a chat evaluation suite against a live endpoint.""" def __init__( self, scorer: ChatScoreRunner, base_url: str = _DEFAULT_BASE_URL, timeout: float = _REQUEST_TIMEOUT, ) -> None: self.scorer = scorer self.base_url = base_url.rstrip("/") self.timeout = timeout @staticmethod def load_suite(path: str | Path) -> list[ChatTestCase]: """Load test cases from a YAML or JSON file. Expected format (YAML): queries: - query: "How do I sidechain a bass?" creator: null personality_weight: 0.0 category: technical description: "Basic sidechain compression question" """ filepath = Path(path) text = filepath.read_text(encoding="utf-8") if filepath.suffix in (".yaml", ".yml"): try: import yaml except ImportError: raise ImportError( "PyYAML is required to load YAML test suites. " "Install with: pip install pyyaml" ) data = yaml.safe_load(text) else: data = json.loads(text) queries = data.get("queries", []) cases: list[ChatTestCase] = [] for q in queries: cases.append(ChatTestCase( query=q["query"], creator=q.get("creator"), personality_weight=float(q.get("personality_weight", 0.0)), category=q.get("category", "general"), description=q.get("description", ""), )) return cases def run_suite(self, cases: list[ChatTestCase]) -> list[ChatEvalResult]: """Execute all test cases sequentially, scoring each response.""" results: list[ChatEvalResult] = [] for i, case in enumerate(cases, 1): print(f"\n [{i}/{len(cases)}] {case.category}: {case.query[:60]}...") result = self._run_single(case) results.append(result) if result.request_error: print(f" ✗ Request error: {result.request_error}") elif result.score and result.score.error: print(f" ✗ Scoring error: {result.score.error}") elif result.score: print(f" ✓ Composite: {result.score.composite:.3f} " f"(latency: {result.latency_seconds:.1f}s)") return results def _run_single(self, case: ChatTestCase) -> ChatEvalResult: """Execute a single test case: call endpoint, parse SSE, score.""" eval_result = ChatEvalResult(test_case=case) # Call the chat endpoint t0 = time.monotonic() try: response_text, sources, cascade_tier = self._call_chat_endpoint(case) eval_result.latency_seconds = round(time.monotonic() - t0, 2) except Exception as exc: eval_result.latency_seconds = round(time.monotonic() - t0, 2) eval_result.request_error = str(exc) logger.error("chat_eval_request_error query=%r error=%s", case.query, exc) return eval_result eval_result.response = response_text eval_result.sources = sources eval_result.cascade_tier = cascade_tier if not response_text: eval_result.request_error = "Empty response from chat endpoint" return eval_result # Score the response eval_result.score = self.scorer.score_response( query=case.query, response=response_text, sources=sources, personality_weight=case.personality_weight, creator_name=case.creator, ) return eval_result def _call_chat_endpoint( self, case: ChatTestCase ) -> tuple[str, list[dict], str]: """Call the chat SSE endpoint and parse the event stream. Returns (accumulated_text, sources_list, cascade_tier). """ url = f"{self.base_url}{_CHAT_ENDPOINT}" payload: dict[str, Any] = {"query": case.query} if case.creator: payload["creator"] = case.creator if case.personality_weight > 0: payload["personality_weight"] = case.personality_weight sources: list[dict] = [] accumulated = "" cascade_tier = "" with httpx.Client(timeout=self.timeout) as client: with client.stream("POST", url, json=payload) as resp: resp.raise_for_status() buffer = "" for chunk in resp.iter_text(): buffer += chunk # Parse SSE events from buffer while "\n\n" in buffer: event_block, buffer = buffer.split("\n\n", 1) event_type, event_data = self._parse_sse_event(event_block) if event_type == "sources": sources = event_data if isinstance(event_data, list) else [] elif event_type == "token": accumulated += event_data if isinstance(event_data, str) else str(event_data) elif event_type == "done": if isinstance(event_data, dict): cascade_tier = event_data.get("cascade_tier", "") elif event_type == "error": msg = event_data.get("message", str(event_data)) if isinstance(event_data, dict) else str(event_data) raise RuntimeError(f"Chat endpoint returned error: {msg}") return accumulated, sources, cascade_tier @staticmethod def _parse_sse_event(block: str) -> tuple[str, Any]: """Parse a single SSE event block into (event_type, data).""" event_type = "" data_lines: list[str] = [] for line in block.strip().splitlines(): if line.startswith("event: "): event_type = line[7:].strip() elif line.startswith("data: "): data_lines.append(line[6:]) elif line.startswith("data:"): data_lines.append(line[5:]) raw_data = "\n".join(data_lines) try: parsed = json.loads(raw_data) except (json.JSONDecodeError, ValueError): parsed = raw_data # plain text token return event_type, parsed @staticmethod def write_results( results: list[ChatEvalResult], output_path: str | Path, ) -> str: """Write evaluation results to a JSON file. Returns the path.""" out = Path(output_path) out.parent.mkdir(parents=True, exist_ok=True) timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S") if out.is_dir(): filepath = out / f"chat_eval_{timestamp}.json" else: filepath = out # Build serializable payload entries: list[dict] = [] for r in results: entry: dict[str, Any] = { "query": r.test_case.query, "creator": r.test_case.creator, "personality_weight": r.test_case.personality_weight, "category": r.test_case.category, "description": r.test_case.description, "response_length": len(r.response), "source_count": len(r.sources), "cascade_tier": r.cascade_tier, "latency_seconds": r.latency_seconds, } if r.request_error: entry["error"] = r.request_error elif r.score: entry["scores"] = r.score.scores entry["composite"] = r.score.composite entry["justifications"] = r.score.justifications entry["scoring_time"] = r.score.elapsed_seconds if r.score.error: entry["scoring_error"] = r.score.error entries.append(entry) # Summary stats scored = [e for e in entries if "composite" in e] avg_composite = ( sum(e["composite"] for e in scored) / len(scored) if scored else 0.0 ) dim_avgs: dict[str, float] = {} for dim in CHAT_DIMENSIONS: vals = [e["scores"][dim] for e in scored if dim in e.get("scores", {})] dim_avgs[dim] = round(sum(vals) / len(vals), 3) if vals else 0.0 payload = { "timestamp": timestamp, "total_queries": len(results), "scored_queries": len(scored), "errors": len(results) - len(scored), "average_composite": round(avg_composite, 3), "dimension_averages": dim_avgs, "results": entries, } filepath.write_text(json.dumps(payload, indent=2), encoding="utf-8") return str(filepath) @staticmethod def print_summary(results: list[ChatEvalResult]) -> None: """Print a summary table of evaluation results.""" print("\n" + "=" * 72) print(" CHAT EVALUATION SUMMARY") print("=" * 72) scored = [r for r in results if r.score and not r.score.error and not r.request_error] errored = [r for r in results if r.request_error or (r.score and r.score.error)] if not scored: print("\n No successfully scored responses.\n") if errored: print(f" Errors: {len(errored)}") for r in errored: err = r.request_error or (r.score.error if r.score else "unknown") print(f" - {r.test_case.query[:50]}: {err}") print("=" * 72 + "\n") return # Header print(f"\n {'Category':<12s} {'Query':<30s} {'Comp':>5s} {'Cite':>5s} {'Struct':>6s} {'Domain':>6s} {'Ground':>6s} {'Person':>6s}") print(f" {'─'*12} {'─'*30} {'─'*5} {'─'*5} {'─'*6} {'─'*6} {'─'*6} {'─'*6}") for r in scored: s = r.score assert s is not None q = r.test_case.query[:30] cat = r.test_case.category[:12] print( f" {cat:<12s} {q:<30s} " f"{s.composite:5.2f} " f"{s.citation_accuracy:5.2f} " f"{s.response_structure:6.2f} " f"{s.domain_expertise:6.2f} " f"{s.source_grounding:6.2f} " f"{s.personality_fidelity:6.2f}" ) # Averages avg_comp = sum(r.score.composite for r in scored) / len(scored) avg_dims = {} for dim in CHAT_DIMENSIONS: vals = [r.score.scores.get(dim, 0.0) for r in scored] avg_dims[dim] = sum(vals) / len(vals) print(f"\n {'AVERAGE':<12s} {'':30s} " f"{avg_comp:5.2f} " f"{avg_dims['citation_accuracy']:5.2f} " f"{avg_dims['response_structure']:6.2f} " f"{avg_dims['domain_expertise']:6.2f} " f"{avg_dims['source_grounding']:6.2f} " f"{avg_dims['personality_fidelity']:6.2f}") if errored: print(f"\n Errors: {len(errored)}") for r in errored: err = r.request_error or (r.score.error if r.score else "unknown") print(f" - {r.test_case.query[:50]}: {err}") print("=" * 72 + "\n")