- "backend/pipeline/quality/__init__.py" - "backend/pipeline/quality/__main__.py" - "backend/pipeline/quality/fitness.py" GSD-Task: S01/T01
489 lines
20 KiB
Python
489 lines
20 KiB
Python
"""FYN-LLM fitness test runner.
|
|
|
|
Tests four categories:
|
|
1. Mandelbrot reasoning — factual knowledge / reasoning depth
|
|
2. JSON compliance — simple and nested structured output
|
|
3. Instruction following — bullet count, keyword inclusion, casing
|
|
4. Diverse prompt battery — summarization, classification, extraction
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
import time
|
|
from dataclasses import dataclass, field
|
|
|
|
import openai
|
|
from pydantic import BaseModel
|
|
|
|
from pipeline.llm_client import LLMClient
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# ── Result types ─────────────────────────────────────────────────────────────
|
|
|
|
@dataclass
class TestResult:
    """Outcome of a single fitness test.

    Attributes are documented inline below. ``token_count`` and ``detail``
    are optional so a failure that never obtained a response can still be
    recorded.
    """

    # The class name starts with "Test", so pytest would try to collect it
    # (and warn about its __init__) whenever it is imported into a test
    # module. An un-annotated class attribute is not a dataclass field, so
    # this does not change the constructor.
    __test__ = False

    # Identifier of the test, shown in the report.
    name: str
    # Whether the check succeeded.
    passed: bool
    # Wall-clock duration of the test in seconds.
    elapsed_seconds: float
    # Completion token count when one is available, otherwise None.
    token_count: int | None = None
    # Human-readable failure explanation; empty when there is nothing to add.
    detail: str = ""
|
|
|
|
|
|
@dataclass
class CategoryReport:
    """Collected test outcomes that belong to one named category."""

    # Display name of the category.
    category: str
    # Individual outcomes; each instance gets its own fresh list.
    results: list[TestResult] = field(default_factory=list)

    @property
    def all_passed(self) -> bool:
        """Return True unless at least one recorded result failed.

        A category with no results therefore counts as passed.
        """
        for outcome in self.results:
            if not outcome.passed:
                return False
        return True
|
|
|
|
|
|
# ── Pydantic models for JSON compliance tests ────────────────────────────────
|
|
|
|
class SimpleItem(BaseModel):
    """Schema for the flat JSON object requested by the simple JSON test."""

    # String value stored under the "name" key.
    name: str
    # Integer value stored under the "count" key.
    count: int
|
|
|
|
|
|
class Address(BaseModel):
    """Schema for the nested "address" object; every field is a string."""

    street: str
    city: str
    # Kept as a string rather than an integer.
    zip_code: str
|
|
|
|
|
|
class PersonWithAddress(BaseModel):
    """Schema for the nested JSON test: a person with an embedded address."""

    name: str
    age: int
    # Nested object validated against the Address model.
    address: Address
|
|
|
|
|
|
# ── Runner ───────────────────────────────────────────────────────────────────
|
|
|
|
class FitnessRunner:
    """Runs all fitness tests against the configured LLM endpoint.

    Each ``_test_*`` method issues one completion through ``self.client`` and
    converts the outcome -- including any exception raised along the way --
    into a ``TestResult``, so one misbehaving test never aborts the run.

    NOTE(review): the response returned by ``client.complete`` is treated as
    string-like (it is always coerced with ``str()``) and presumably carries a
    ``completion_tokens`` attribute; confirm against ``LLMClient``.
    """

    def __init__(self, client: LLMClient) -> None:
        """Store the client that every test uses to talk to the model."""
        self.client = client

    # ── Public entry point ───────────────────────────────────────────────

    def run_all(self) -> int:
        """Run all fitness tests, print report, return exit code (0=pass, 1=fail)."""
        # Connectivity pre-check — fail fast with a clear message
        try:
            self._probe_connectivity()
        except (openai.APIConnectionError, openai.APITimeoutError) as exc:
            url = self.client.settings.llm_api_url
            fallback = self.client.settings.llm_fallback_url
            print(
                f"\n✗ Cannot reach LLM endpoint at {url} (fallback {fallback})\n"
                f" Error: {exc}\n"
            )
            return 1

        categories: list[CategoryReport] = [
            self._run_mandelbrot(),
            self._run_json_compliance(),
            self._run_instruction_following(),
            self._run_diverse_battery(),
        ]

        self._print_report(categories)

        return 0 if all(c.all_passed for c in categories) else 1

    # ── Connectivity probe ───────────────────────────────────────────────

    def _probe_connectivity(self) -> None:
        """Quick completion to verify the endpoint is reachable."""
        self.client.complete(
            system_prompt="You are a test probe.",
            user_prompt="Respond with the single word: ok",
        )

    # ── Shared result helpers ────────────────────────────────────────────

    @staticmethod
    def _outcome(
        name: str,
        passed: bool,
        elapsed: float,
        resp: object,
        detail: str = "",
    ) -> TestResult:
        """Build the TestResult for a test that received a response.

        The token count is read with ``getattr`` so a response object without
        ``completion_tokens`` yields ``None`` instead of failing the test.
        """
        return TestResult(
            name=name,
            passed=passed,
            elapsed_seconds=round(elapsed, 2),
            token_count=getattr(resp, "completion_tokens", None),
            detail=detail,
        )

    @staticmethod
    def _errored(name: str, t0: float, exc: Exception) -> TestResult:
        """Build the failing TestResult for a test that raised ``exc``.

        ``t0`` is the ``time.monotonic()`` reading taken when the test began.
        """
        return TestResult(
            name=name,
            passed=False,
            elapsed_seconds=round(time.monotonic() - t0, 2),
            detail=f"Exception: {exc}",
        )

    def _parse_json(
        self,
        name: str,
        resp: object,
        elapsed: float,
    ) -> tuple[object, TestResult | None]:
        """Decode the response text as JSON.

        Returns ``(parsed, None)`` on success, or ``(None, failure)`` where
        ``failure`` is a ready-made failing TestResult for an empty or
        undecodable response. Callers must test the second element, because
        ``None`` is itself a valid decoded value (JSON ``null``).
        """
        text = str(resp).strip()
        if not text:
            return None, self._outcome(name, False, elapsed, resp, "Empty response from LLM")
        try:
            parsed = json.loads(text)
        except json.JSONDecodeError as exc:
            failure = self._outcome(
                name, False, elapsed, resp, f"Invalid JSON: {exc}. Raw: {text[:200]}"
            )
            return None, failure
        return parsed, None

    # ── Category 1: Mandelbrot reasoning ─────────────────────────────────

    def _run_mandelbrot(self) -> CategoryReport:
        """Run the Mandelbrot reasoning category."""
        cat = CategoryReport(category="Mandelbrot Reasoning")
        cat.results.append(self._test_mandelbrot())
        return cat

    def _test_mandelbrot(self) -> TestResult:
        """Check the answer cites the area value and says it is not exactly known."""
        name = "mandelbrot_area_knowledge"
        t0 = time.monotonic()
        try:
            resp = self.client.complete(
                system_prompt="You are a mathematics expert. Answer precisely and concisely.",
                user_prompt=(
                    "What is the approximate area of the Mandelbrot set? "
                    "Include the numerical value and mention whether the exact area is known."
                ),
                modality="thinking",
            )
            elapsed = time.monotonic() - t0
            # Coerce with str() like every other test in this class.
            raw = str(resp)
            text = raw.lower()
            # Check for key concepts ("1.506" also matches longer forms such as 1.50659)
            has_area = any(kw in text for kw in ("1.506", "1.507"))
            has_uncertainty = any(
                kw in text
                for kw in ("not exactly known", "not known exactly", "approximate", "estimated", "conjecture")
            )
            passed = has_area and has_uncertainty
            detail = (
                ""
                if passed
                else f"Missing: area={has_area}, uncertainty={has_uncertainty}. Response: {raw[:200]}"
            )
            return self._outcome(name, passed, elapsed, resp, detail)
        except Exception as exc:
            return self._errored(name, t0, exc)

    # ── Category 2: JSON compliance ──────────────────────────────────────

    def _run_json_compliance(self) -> CategoryReport:
        """Run the JSON compliance category (simple and nested objects)."""
        cat = CategoryReport(category="JSON Compliance")
        cat.results.append(self._test_json_simple())
        cat.results.append(self._test_json_nested())
        return cat

    def _test_json_simple(self) -> TestResult:
        """Request a flat two-key object and validate it against SimpleItem."""
        name = "json_simple_object"
        t0 = time.monotonic()
        try:
            resp = self.client.complete(
                system_prompt="You are a JSON generator. Output ONLY valid JSON, nothing else.",
                user_prompt=(
                    'Generate a JSON object with exactly two keys: "name" (a string) '
                    'and "count" (an integer). Example structure: {"name": "...", "count": N}'
                ),
                response_model=SimpleItem,
                modality="chat",
            )
            elapsed = time.monotonic() - t0
            return self._validate_json(name, resp, SimpleItem, elapsed)
        except Exception as exc:
            return self._errored(name, t0, exc)

    def _test_json_nested(self) -> TestResult:
        """Request a nested object and validate it against PersonWithAddress."""
        name = "json_nested_object"
        t0 = time.monotonic()
        try:
            resp = self.client.complete(
                system_prompt="You are a JSON generator. Output ONLY valid JSON, nothing else.",
                user_prompt=(
                    'Generate a JSON object with keys "name" (string), "age" (integer), '
                    'and "address" (object with "street", "city", "zip_code" string fields).'
                ),
                response_model=PersonWithAddress,
                modality="chat",
            )
            elapsed = time.monotonic() - t0
            return self._validate_json(name, resp, PersonWithAddress, elapsed)
        except Exception as exc:
            return self._errored(name, t0, exc)

    def _validate_json(
        self,
        name: str,
        resp: str,
        model: type[BaseModel],
        elapsed: float,
    ) -> TestResult:
        """Parse response as JSON, validate against Pydantic model.

        Returns a failing result for an empty response, undecodable JSON, or a
        payload rejected by ``model.model_validate``; otherwise a passing one.
        """
        parsed, failure = self._parse_json(name, resp, elapsed)
        if failure is not None:
            return failure
        try:
            model.model_validate(parsed)
        except Exception as exc:
            return self._outcome(name, False, elapsed, resp, f"Schema validation failed: {exc}")
        return self._outcome(name, True, elapsed, resp)

    # ── Category 3: Instruction following ────────────────────────────────

    def _run_instruction_following(self) -> CategoryReport:
        """Run the instruction-following category."""
        cat = CategoryReport(category="Instruction Following")
        cat.results.append(self._test_bullet_count())
        cat.results.append(self._test_keyword_inclusion())
        cat.results.append(self._test_lowercase_only())
        return cat

    def _test_bullet_count(self) -> TestResult:
        """Check the reply contains exactly three lines starting with '- '."""
        name = "instruction_bullet_count"
        t0 = time.monotonic()
        try:
            resp = self.client.complete(
                system_prompt="Follow instructions exactly.",
                user_prompt="List exactly 3 benefits of exercise. Use bullet points starting with '- '.",
            )
            elapsed = time.monotonic() - t0
            raw = str(resp)
            bullets = [ln for ln in raw.strip().splitlines() if ln.strip().startswith("- ")]
            passed = len(bullets) == 3
            detail = "" if passed else f"Expected 3 bullets, got {len(bullets)}: {raw[:200]}"
            return self._outcome(name, passed, elapsed, resp, detail)
        except Exception as exc:
            return self._errored(name, t0, exc)

    def _test_keyword_inclusion(self) -> TestResult:
        """Check the reply contains the required word "elephant" (case-insensitive)."""
        name = "instruction_keyword_inclusion"
        t0 = time.monotonic()
        try:
            resp = self.client.complete(
                system_prompt="Follow instructions exactly.",
                user_prompt=(
                    "Write one sentence about the weather. "
                    'You MUST include the word "elephant" somewhere in your sentence.'
                ),
            )
            elapsed = time.monotonic() - t0
            raw = str(resp)
            passed = "elephant" in raw.lower()
            detail = "" if passed else f"Missing keyword 'elephant'. Response: {raw[:200]}"
            return self._outcome(name, passed, elapsed, resp, detail)
        except Exception as exc:
            return self._errored(name, t0, exc)

    def _test_lowercase_only(self) -> TestResult:
        """Check the reply has no uppercase letters and is longer than 5 characters."""
        name = "instruction_lowercase_only"
        t0 = time.monotonic()
        try:
            resp = self.client.complete(
                system_prompt="Follow instructions exactly.",
                user_prompt=(
                    "Write a short sentence about the ocean. "
                    "Use ONLY lowercase letters — no uppercase at all, not even at the start."
                ),
            )
            elapsed = time.monotonic() - t0
            text = str(resp).strip()
            # Allow non-alpha chars (punctuation, spaces, numbers) but no uppercase letters
            has_upper = any(c.isupper() for c in text)
            passed = not has_upper and len(text) > 5
            detail = "" if passed else f"Contains uppercase or too short. Response: {text[:200]}"
            return self._outcome(name, passed, elapsed, resp, detail)
        except Exception as exc:
            return self._errored(name, t0, exc)

    # ── Category 4: Diverse prompt battery ───────────────────────────────

    def _run_diverse_battery(self) -> CategoryReport:
        """Run the diverse prompt battery category."""
        cat = CategoryReport(category="Diverse Prompt Battery")
        cat.results.append(self._test_summarization())
        cat.results.append(self._test_classification())
        cat.results.append(self._test_extraction())
        return cat

    def _test_summarization(self) -> TestResult:
        """Check a requested two-sentence summary is roughly the right length."""
        name = "battery_summarization"
        paragraph = (
            "The James Webb Space Telescope (JWST) is the largest optical telescope in space. "
            "Launched in December 2021, it is designed to conduct infrared astronomy. Its high "
            "resolution and sensitivity allow it to view objects too old and distant for the Hubble "
            "Space Telescope. Among its goals are observing the first stars and the formation of "
            "the first galaxies, and detailed atmospheric characterization of exoplanets."
        )
        t0 = time.monotonic()
        try:
            resp = self.client.complete(
                system_prompt="You are a concise summarizer.",
                user_prompt=f"Summarize the following in exactly 2 sentences:\n\n{paragraph}",
            )
            elapsed = time.monotonic() - t0
            text = str(resp).strip()
            # Rough sentence count. Collapse all whitespace first so sentences
            # separated by newlines are counted, and treat "!" and "?" as
            # terminators alongside ".".
            flat = " ".join(text.split())
            for terminator in ("! ", "? "):
                flat = flat.replace(terminator, ". ")
            sentences = [s.strip() for s in flat.split(". ") if s.strip()]
            # Be generous: 1-3 sentences is acceptable
            passed = 1 <= len(sentences) <= 3 and len(text) > 20
            detail = "" if passed else f"Expected ~2 sentences, got {len(sentences)}. Response: {text[:200]}"
            return self._outcome(name, passed, elapsed, resp, detail)
        except Exception as exc:
            return self._errored(name, t0, exc)

    def _test_classification(self) -> TestResult:
        """Check the reply is exactly one of the offered category labels.

        NOTE(review): any listed label passes; the check does not require the
        semantically correct one. Confirm this is the intended strictness.
        """
        name = "battery_classification"
        categories = ["technology", "sports", "politics", "science", "entertainment"]
        t0 = time.monotonic()
        try:
            resp = self.client.complete(
                system_prompt=(
                    "You are a text classifier. Respond with ONLY one word from the given categories."
                ),
                user_prompt=(
                    f"Classify the following text into one of these categories: {', '.join(categories)}\n\n"
                    "Text: \"NASA's Perseverance rover has discovered organic molecules on Mars, "
                    "suggesting the planet may have once harbored microbial life.\"\n\n"
                    "Category:"
                ),
            )
            elapsed = time.monotonic() - t0
            # Tolerate surrounding whitespace, capitalisation and a trailing period.
            answer = str(resp).strip().lower().rstrip(".")
            passed = answer in categories
            detail = "" if passed else f"Response '{answer}' not in {categories}"
            return self._outcome(name, passed, elapsed, resp, detail)
        except Exception as exc:
            return self._errored(name, t0, exc)

    def _test_extraction(self) -> TestResult:
        """Check the reply is a JSON object containing all requested keys."""
        name = "battery_extraction"
        t0 = time.monotonic()
        try:
            resp = self.client.complete(
                system_prompt="You are a data extractor. Output ONLY valid JSON, nothing else.",
                user_prompt=(
                    "Extract the following fields as a JSON object: "
                    '"event_name", "date", "location"\n\n'
                    "Text: \"The annual Tech Summit 2026 will be held on March 15, 2026 "
                    'in San Francisco, California."\n\n'
                    "JSON:"
                ),
                response_model=BaseModel,  # triggers json mode
                modality="chat",
            )
            elapsed = time.monotonic() - t0
            parsed, failure = self._parse_json(name, resp, elapsed)
            if failure is not None:
                return failure
            # A JSON array or scalar has no keys; report that explicitly
            # instead of surfacing an AttributeError from ``.keys()``.
            if not isinstance(parsed, dict):
                return self._outcome(
                    name, False, elapsed, resp,
                    f"Expected a JSON object, got {type(parsed).__name__}",
                )
            required_keys = {"event_name", "date", "location"}
            missing = required_keys.difference(parsed)
            passed = not missing
            detail = "" if passed else f"Missing keys: {missing}"
            return self._outcome(name, passed, elapsed, resp, detail)
        except Exception as exc:
            return self._errored(name, t0, exc)

    # ── Report formatting ────────────────────────────────────────────────

    def _print_report(self, categories: list[CategoryReport]) -> None:
        """Print a formatted pass/fail report to stdout."""
        total = 0
        passed_count = 0

        print("\n" + "=" * 60)
        print(" FYN-LLM FITNESS REPORT")
        print("=" * 60)

        for cat in categories:
            status = "✓ PASS" if cat.all_passed else "✗ FAIL"
            print(f"\n [{status}] {cat.category}")
            for r in cat.results:
                total += 1
                icon = "✓" if r.passed else "✗"
                # Compare against None so a genuine count of 0 is still shown.
                tokens = f" ({r.token_count} tok)" if r.token_count is not None else ""
                print(f" {icon} {r.name} [{r.elapsed_seconds}s{tokens}]")
                if r.detail:
                    # Indent detail lines
                    for line in r.detail.splitlines():
                        print(f" {line}")
                if r.passed:
                    passed_count += 1

        print("\n" + "-" * 60)
        print(f" Total: {passed_count}/{total} passed")
        if passed_count == total:
            print(" Result: ✓ ALL PASS")
        else:
            print(f" Result: ✗ {total - passed_count} FAILED")
        print("=" * 60 + "\n")
|