"""FYN-LLM fitness test runner. Tests four categories: 1. Mandelbrot reasoning — factual knowledge / reasoning depth 2. JSON compliance — simple and nested structured output 3. Instruction following — bullet count, keyword inclusion, casing 4. Diverse prompt battery — summarization, classification, extraction """ from __future__ import annotations import json import logging import time from dataclasses import dataclass, field import openai from pydantic import BaseModel from pipeline.llm_client import LLMClient logger = logging.getLogger(__name__) # ── Result types ───────────────────────────────────────────────────────────── @dataclass class TestResult: """Outcome of a single fitness test.""" name: str passed: bool elapsed_seconds: float token_count: int | None = None detail: str = "" @dataclass class CategoryReport: """Results for one test category.""" category: str results: list[TestResult] = field(default_factory=list) @property def all_passed(self) -> bool: return all(r.passed for r in self.results) # ── Pydantic models for JSON compliance tests ──────────────────────────────── class SimpleItem(BaseModel): name: str count: int class Address(BaseModel): street: str city: str zip_code: str class PersonWithAddress(BaseModel): name: str age: int address: Address # ── Runner ─────────────────────────────────────────────────────────────────── class FitnessRunner: """Runs all fitness tests against the configured LLM endpoint.""" def __init__(self, client: LLMClient) -> None: self.client = client # ── Public entry point ─────────────────────────────────────────────── def run_all(self) -> int: """Run all fitness tests, print report, return exit code (0=pass, 1=fail).""" categories: list[CategoryReport] = [] # Connectivity pre-check — fail fast with a clear message try: self._probe_connectivity() except (openai.APIConnectionError, openai.APITimeoutError) as exc: url = self.client.settings.llm_api_url fallback = self.client.settings.llm_fallback_url print( f"\n✗ Cannot reach LLM endpoint at {url} (fallback {fallback})\n" f" Error: {exc}\n" ) return 1 categories.append(self._run_mandelbrot()) categories.append(self._run_json_compliance()) categories.append(self._run_instruction_following()) categories.append(self._run_diverse_battery()) self._print_report(categories) return 0 if all(c.all_passed for c in categories) else 1 # ── Connectivity probe ─────────────────────────────────────────────── def _probe_connectivity(self) -> None: """Quick completion to verify the endpoint is reachable.""" self.client.complete( system_prompt="You are a test probe.", user_prompt="Respond with the single word: ok", ) # ── Category 1: Mandelbrot reasoning ───────────────────────────────── def _run_mandelbrot(self) -> CategoryReport: cat = CategoryReport(category="Mandelbrot Reasoning") cat.results.append(self._test_mandelbrot()) return cat def _test_mandelbrot(self) -> TestResult: name = "mandelbrot_area_knowledge" t0 = time.monotonic() try: resp = self.client.complete( system_prompt="You are a mathematics expert. Answer precisely and concisely.", user_prompt=( "What is the approximate area of the Mandelbrot set? " "Include the numerical value and mention whether the exact area is known." ), modality="thinking", ) elapsed = time.monotonic() - t0 text = resp.lower() # Check for key concepts has_area = any(kw in text for kw in ["1.506", "1.507", "1.50659"]) has_uncertainty = any( kw in text for kw in ["not exactly known", "not known exactly", "approximate", "estimated", "conjecture"] ) passed = has_area and has_uncertainty detail = "" if passed else f"Missing: area={has_area}, uncertainty={has_uncertainty}. Response: {resp[:200]}" return TestResult( name=name, passed=passed, elapsed_seconds=round(elapsed, 2), token_count=resp.completion_tokens, detail=detail, ) except Exception as exc: return TestResult( name=name, passed=False, elapsed_seconds=round(time.monotonic() - t0, 2), detail=f"Exception: {exc}", ) # ── Category 2: JSON compliance ────────────────────────────────────── def _run_json_compliance(self) -> CategoryReport: cat = CategoryReport(category="JSON Compliance") cat.results.append(self._test_json_simple()) cat.results.append(self._test_json_nested()) return cat def _test_json_simple(self) -> TestResult: name = "json_simple_object" t0 = time.monotonic() try: resp = self.client.complete( system_prompt="You are a JSON generator. Output ONLY valid JSON, nothing else.", user_prompt=( 'Generate a JSON object with exactly two keys: "name" (a string) ' 'and "count" (an integer). Example structure: {"name": "...", "count": N}' ), response_model=SimpleItem, modality="chat", ) elapsed = time.monotonic() - t0 return self._validate_json(name, resp, SimpleItem, elapsed) except Exception as exc: return TestResult( name=name, passed=False, elapsed_seconds=round(time.monotonic() - t0, 2), detail=f"Exception: {exc}", ) def _test_json_nested(self) -> TestResult: name = "json_nested_object" t0 = time.monotonic() try: resp = self.client.complete( system_prompt="You are a JSON generator. Output ONLY valid JSON, nothing else.", user_prompt=( 'Generate a JSON object with keys "name" (string), "age" (integer), ' 'and "address" (object with "street", "city", "zip_code" string fields).' ), response_model=PersonWithAddress, modality="chat", ) elapsed = time.monotonic() - t0 return self._validate_json(name, resp, PersonWithAddress, elapsed) except Exception as exc: return TestResult( name=name, passed=False, elapsed_seconds=round(time.monotonic() - t0, 2), detail=f"Exception: {exc}", ) def _validate_json( self, name: str, resp: str, model: type[BaseModel], elapsed: float, ) -> TestResult: """Parse response as JSON, validate against Pydantic model.""" text = str(resp).strip() if not text: return TestResult( name=name, passed=False, elapsed_seconds=round(elapsed, 2), token_count=getattr(resp, "completion_tokens", None), detail="Empty response from LLM", ) try: parsed = json.loads(text) except json.JSONDecodeError as exc: return TestResult( name=name, passed=False, elapsed_seconds=round(elapsed, 2), token_count=getattr(resp, "completion_tokens", None), detail=f"Invalid JSON: {exc}. Raw: {text[:200]}", ) try: model.model_validate(parsed) except Exception as exc: return TestResult( name=name, passed=False, elapsed_seconds=round(elapsed, 2), token_count=getattr(resp, "completion_tokens", None), detail=f"Schema validation failed: {exc}", ) return TestResult( name=name, passed=True, elapsed_seconds=round(elapsed, 2), token_count=getattr(resp, "completion_tokens", None), ) # ── Category 3: Instruction following ──────────────────────────────── def _run_instruction_following(self) -> CategoryReport: cat = CategoryReport(category="Instruction Following") cat.results.append(self._test_bullet_count()) cat.results.append(self._test_keyword_inclusion()) cat.results.append(self._test_lowercase_only()) return cat def _test_bullet_count(self) -> TestResult: name = "instruction_bullet_count" t0 = time.monotonic() try: resp = self.client.complete( system_prompt="Follow instructions exactly.", user_prompt="List exactly 3 benefits of exercise. Use bullet points starting with '- '.", ) elapsed = time.monotonic() - t0 lines = [l.strip() for l in str(resp).strip().splitlines() if l.strip().startswith("- ")] passed = len(lines) == 3 detail = "" if passed else f"Expected 3 bullets, got {len(lines)}: {str(resp)[:200]}" return TestResult( name=name, passed=passed, elapsed_seconds=round(elapsed, 2), token_count=resp.completion_tokens, detail=detail, ) except Exception as exc: return TestResult( name=name, passed=False, elapsed_seconds=round(time.monotonic() - t0, 2), detail=f"Exception: {exc}", ) def _test_keyword_inclusion(self) -> TestResult: name = "instruction_keyword_inclusion" t0 = time.monotonic() try: resp = self.client.complete( system_prompt="Follow instructions exactly.", user_prompt=( "Write one sentence about the weather. " 'You MUST include the word "elephant" somewhere in your sentence.' ), ) elapsed = time.monotonic() - t0 passed = "elephant" in str(resp).lower() detail = "" if passed else f"Missing keyword 'elephant'. Response: {str(resp)[:200]}" return TestResult( name=name, passed=passed, elapsed_seconds=round(elapsed, 2), token_count=resp.completion_tokens, detail=detail, ) except Exception as exc: return TestResult( name=name, passed=False, elapsed_seconds=round(time.monotonic() - t0, 2), detail=f"Exception: {exc}", ) def _test_lowercase_only(self) -> TestResult: name = "instruction_lowercase_only" t0 = time.monotonic() try: resp = self.client.complete( system_prompt="Follow instructions exactly.", user_prompt=( "Write a short sentence about the ocean. " "Use ONLY lowercase letters — no uppercase at all, not even at the start." ), ) elapsed = time.monotonic() - t0 text = str(resp).strip() # Allow non-alpha chars (punctuation, spaces, numbers) but no uppercase letters has_upper = any(c.isupper() for c in text) passed = not has_upper and len(text) > 5 detail = "" if passed else f"Contains uppercase or too short. Response: {text[:200]}" return TestResult( name=name, passed=passed, elapsed_seconds=round(elapsed, 2), token_count=resp.completion_tokens, detail=detail, ) except Exception as exc: return TestResult( name=name, passed=False, elapsed_seconds=round(time.monotonic() - t0, 2), detail=f"Exception: {exc}", ) # ── Category 4: Diverse prompt battery ─────────────────────────────── def _run_diverse_battery(self) -> CategoryReport: cat = CategoryReport(category="Diverse Prompt Battery") cat.results.append(self._test_summarization()) cat.results.append(self._test_classification()) cat.results.append(self._test_extraction()) return cat def _test_summarization(self) -> TestResult: name = "battery_summarization" paragraph = ( "The James Webb Space Telescope (JWST) is the largest optical telescope in space. " "Launched in December 2021, it is designed to conduct infrared astronomy. Its high " "resolution and sensitivity allow it to view objects too old and distant for the Hubble " "Space Telescope. Among its goals are observing the first stars and the formation of " "the first galaxies, and detailed atmospheric characterization of exoplanets." ) t0 = time.monotonic() try: resp = self.client.complete( system_prompt="You are a concise summarizer.", user_prompt=f"Summarize the following in exactly 2 sentences:\n\n{paragraph}", ) elapsed = time.monotonic() - t0 text = str(resp).strip() # Rough sentence count: split on period followed by space or end sentences = [s.strip() for s in text.replace("! ", ". ").split(". ") if s.strip()] # Be generous: 1-3 sentences is acceptable passed = 1 <= len(sentences) <= 3 and len(text) > 20 detail = "" if passed else f"Expected ~2 sentences, got {len(sentences)}. Response: {text[:200]}" return TestResult( name=name, passed=passed, elapsed_seconds=round(elapsed, 2), token_count=resp.completion_tokens, detail=detail, ) except Exception as exc: return TestResult( name=name, passed=False, elapsed_seconds=round(time.monotonic() - t0, 2), detail=f"Exception: {exc}", ) def _test_classification(self) -> TestResult: name = "battery_classification" categories = ["technology", "sports", "politics", "science", "entertainment"] t0 = time.monotonic() try: resp = self.client.complete( system_prompt=( "You are a text classifier. Respond with ONLY one word from the given categories." ), user_prompt=( f"Classify the following text into one of these categories: {', '.join(categories)}\n\n" "Text: \"NASA's Perseverance rover has discovered organic molecules on Mars, " "suggesting the planet may have once harbored microbial life.\"\n\n" "Category:" ), ) elapsed = time.monotonic() - t0 answer = str(resp).strip().lower().rstrip(".") passed = answer in categories detail = "" if passed else f"Response '{answer}' not in {categories}" return TestResult( name=name, passed=passed, elapsed_seconds=round(elapsed, 2), token_count=resp.completion_tokens, detail=detail, ) except Exception as exc: return TestResult( name=name, passed=False, elapsed_seconds=round(time.monotonic() - t0, 2), detail=f"Exception: {exc}", ) def _test_extraction(self) -> TestResult: name = "battery_extraction" t0 = time.monotonic() try: resp = self.client.complete( system_prompt="You are a data extractor. Output ONLY valid JSON, nothing else.", user_prompt=( "Extract the following fields as a JSON object: " '"event_name", "date", "location"\n\n' "Text: \"The annual Tech Summit 2026 will be held on March 15, 2026 " 'in San Francisco, California."\n\n' "JSON:" ), response_model=BaseModel, # triggers json mode modality="chat", ) elapsed = time.monotonic() - t0 text = str(resp).strip() if not text: return TestResult( name=name, passed=False, elapsed_seconds=round(elapsed, 2), token_count=getattr(resp, "completion_tokens", None), detail="Empty response from LLM", ) try: parsed = json.loads(text) except json.JSONDecodeError as exc: return TestResult( name=name, passed=False, elapsed_seconds=round(elapsed, 2), token_count=getattr(resp, "completion_tokens", None), detail=f"Invalid JSON: {exc}. Raw: {text[:200]}", ) required_keys = {"event_name", "date", "location"} present = set(parsed.keys()) & required_keys passed = present == required_keys detail = "" if passed else f"Missing keys: {required_keys - present}" return TestResult( name=name, passed=passed, elapsed_seconds=round(elapsed, 2), token_count=getattr(resp, "completion_tokens", None), detail=detail, ) except Exception as exc: return TestResult( name=name, passed=False, elapsed_seconds=round(time.monotonic() - t0, 2), detail=f"Exception: {exc}", ) # ── Report formatting ──────────────────────────────────────────────── def _print_report(self, categories: list[CategoryReport]) -> None: """Print a formatted pass/fail report to stdout.""" total = 0 passed_count = 0 print("\n" + "=" * 60) print(" FYN-LLM FITNESS REPORT") print("=" * 60) for cat in categories: status = "✓ PASS" if cat.all_passed else "✗ FAIL" print(f"\n [{status}] {cat.category}") for r in cat.results: total += 1 icon = "✓" if r.passed else "✗" tokens = f" ({r.token_count} tok)" if r.token_count else "" print(f" {icon} {r.name} [{r.elapsed_seconds}s{tokens}]") if r.detail: # Indent detail lines for line in r.detail.splitlines(): print(f" {line}") if r.passed: passed_count += 1 print("\n" + "-" * 60) print(f" Total: {passed_count}/{total} passed") if passed_count == total: print(" Result: ✓ ALL PASS") else: print(f" Result: ✗ {total - passed_count} FAILED") print("=" * 60 + "\n")