Adds format.py scorer supporting four validation modes:
- json: validates parseable JSON
- markdown: checks for headers (0.5) and lists (0.5)
- length: proportional scoring against min/max token bounds
- structure: JSON schema validation via the jsonschema library

Includes 38 passing tests covering all format types, edge cases, and async delegation.
173 lines
5.5 KiB
Python
173 lines
5.5 KiB
Python
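"""Format scorer — checks whether LLM output matches an expected format.

Supports four format checks:
- json: output parses as valid JSON
- markdown: output has headers (0.5) and/or lists (0.5)
- length: whitespace-token count is within min/max bounds, with partial
  credit outside the range
- structure: output is JSON that matches a provided JSON schema (validated
  with jsonschema when installed, otherwise a basic type/required-field check)
"""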
|
|
|
|
import json
|
|
import re
|
|
from typing import Any
|
|
|
|
from engine.scorers.base import BaseScorer
|
|
|
|
|
|
class FormatScorer(BaseScorer):
    """Score outputs based on format compliance.

    Args:
        format_type: One of "json", "markdown", "length", "structure".
        min_tokens: Minimum token count (for "length" mode). ``None`` means
            no lower bound.
        max_tokens: Maximum token count (for "length" mode). ``None`` means
            no upper bound; ``0`` is honored as a real ceiling.
        json_schema: JSON schema dict (for "structure" mode). If omitted,
            "structure" mode scores every output 0.0.

    Raises:
        ValueError: If ``format_type`` is not one of ``VALID_FORMATS``.
    """

    VALID_FORMATS = {"json", "markdown", "length", "structure"}

    # Markdown markers. ``[^\S\n]`` means "whitespace other than newline";
    # plain ``\s`` would let a bare "#" or "-" line match by consuming the
    # line break and the first character of the next line.
    _HEADER_RE = re.compile(r"^#{1,6}[^\S\n]+\S", re.MULTILINE)
    _BULLET_RE = re.compile(r"^[^\S\n]*[-*][^\S\n]+\S", re.MULTILINE)
    _NUMBERED_RE = re.compile(r"^[^\S\n]*\d+[.)][^\S\n]+\S", re.MULTILINE)

    # JSON Schema "type" names -> Python types produced by json.loads.
    # Only used by the fallback validator when jsonschema is unavailable.
    _JSON_TYPES = {
        "object": dict,
        "array": list,
        "string": str,
        "number": (int, float),
        "integer": int,
        "boolean": bool,
        "null": type(None),
    }

    def __init__(
        self,
        format_type: str = "json",
        min_tokens: int | None = None,
        max_tokens: int | None = None,
        json_schema: dict | None = None,
    ) -> None:
        if format_type not in self.VALID_FORMATS:
            raise ValueError(
                f"Invalid format_type '{format_type}'. "
                f"Must be one of: {', '.join(sorted(self.VALID_FORMATS))}"
            )
        self.format_type = format_type
        self.min_tokens = min_tokens
        self.max_tokens = max_tokens
        self.json_schema = json_schema

    @property
    def name(self) -> str:
        """Return the scorer name, ``"format"``."""
        return "format"

    def score(self, input_data: Any, output: str, context: dict) -> float:
        """Score ``output`` for compliance with the configured format.

        ``input_data`` and ``context`` are part of the scorer call signature
        but are not used by this scorer.

        Returns:
            A score in ``[0.0, 1.0]``. "json" returns 1.0 or 0.0; "markdown"
            awards 0.5 for headers and 0.5 for lists; "length" returns
            proportional credit outside the bounds; "structure" returns 1.0
            or 0.0, or partial credit from the fallback validator when
            jsonschema is not installed.
        """
        checkers = {
            "json": self._check_json,
            "markdown": self._check_markdown,
            "length": self._check_length,
            "structure": self._check_structure,
        }
        return checkers[self.format_type](output)

    @staticmethod
    def _parse_json(output: str) -> tuple[bool, Any]:
        """Parse ``output`` as JSON after stripping surrounding whitespace.

        Returns:
            ``(True, value)`` on success, ``(False, None)`` otherwise.
            ``RecursionError`` (pathologically deep nesting) is treated as a
            parse failure so untrusted model output cannot crash the scorer.
        """
        try:
            return True, json.loads(output.strip())
        except (ValueError, RecursionError):  # JSONDecodeError is a ValueError
            return False, None

    def _check_json(self, output: str) -> float:
        """Return 1.0 if ``output`` parses as JSON, else 0.0.

        Surrounding whitespace is ignored.
        """
        ok, _ = self._parse_json(output)
        return 1.0 if ok else 0.0

    def _check_markdown(self, output: str) -> float:
        """Check if output contains markdown formatting (headers and/or lists).

        Scoring:
            - 0.5 for having headers (lines starting with 1-6 ``#`` then text)
            - 0.5 for having lists (lines starting with ``-``, ``*``, or
              ``1.`` / ``1)`` then text, optionally indented)
            - 1.0 for having both
        """
        score = 0.0
        if self._HEADER_RE.search(output):
            score += 0.5
        if self._BULLET_RE.search(output) or self._NUMBERED_RE.search(output):
            score += 0.5
        return score

    def _check_length(self, output: str) -> float:
        """Check if output length is within min/max token bounds.

        Tokens are approximated by whitespace splitting. Returns 1.0 inside
        the bounds (or when no bound is configured). Below ``min_tokens`` the
        score is ``count / min_tokens``. Above ``max_tokens`` the score is
        ``2 * max_tokens / count - 1`` clamped at 0: 1.0 at the ceiling,
        0.5 at 4/3 of it, 0.0 at twice the ceiling and beyond. This decay is
        hyperbolic, not linear.
        """
        token_count = len(output.split())

        if self.min_tokens is None and self.max_tokens is None:
            return 1.0

        # ``is None`` rather than ``or`` so an explicit 0 is a real bound.
        min_t = 0 if self.min_tokens is None else self.min_tokens
        max_t = float("inf") if self.max_tokens is None else self.max_tokens

        if min_t <= token_count <= max_t:
            return 1.0

        if token_count < min_t:
            # token_count >= 0, so min_t > 0 here and the division is safe.
            return token_count / min_t

        # token_count > max_t, hence max_t is finite here.
        if max_t <= 0:
            # A non-positive ceiling leaves nothing to scale against.
            return 0.0
        ratio = max_t / token_count
        return max(0.0, 2 * ratio - 1.0)

    def _check_structure(self, output: str) -> float:
        """Check if output is JSON that matches ``self.json_schema``.

        Returns 1.0 if valid against the schema, 0.0 otherwise (including
        when no schema is configured or the output is not JSON). If the
        optional ``jsonschema`` package is not importable, falls back to
        ``_basic_schema_check``, which may award partial credit.

        Raises:
            jsonschema.SchemaError: Propagated if the configured schema is
                itself invalid (a configuration error, not an output error).
        """
        if self.json_schema is None:
            return 0.0

        ok, parsed = self._parse_json(output)
        if not ok:
            return 0.0

        try:
            import jsonschema  # optional dependency
        except ImportError:
            return self._basic_schema_check(parsed, self.json_schema)

        try:
            jsonschema.validate(instance=parsed, schema=self.json_schema)
        except jsonschema.ValidationError:
            return 0.0
        return 1.0

    @classmethod
    def _matches_type(cls, data: Any, type_name: Any) -> bool:
        """Return whether ``data`` satisfies a single JSON Schema type name.

        Unknown or non-string type names are not enforced (return True).
        Booleans are rejected for "integer"/"number" even though ``bool``
        subclasses ``int`` in Python, because JSON treats them as distinct.
        """
        if not isinstance(type_name, str):
            return True
        expected = cls._JSON_TYPES.get(type_name)
        if expected is None:
            return True
        if isinstance(data, bool) and type_name in ("integer", "number"):
            return False
        return isinstance(data, expected)

    def _basic_schema_check(self, data: Any, schema: dict) -> float:
        """Basic JSON schema validation without the jsonschema library.

        Checks only the ``type`` keyword (a single name or a list of names)
        and, for object instances, the ``required`` keyword.

        Returns:
            0.0 if the type does not match; otherwise the fraction of
            ``required`` keys present for objects, or 1.0.
        """
        schema_type = schema.get("type")
        if schema_type:
            allowed = schema_type if isinstance(schema_type, list) else [schema_type]
            if not any(self._matches_type(data, t) for t in allowed):
                return 0.0

        # JSON Schema applies "required" to any object instance; the type
        # check above has already rejected dicts when "object" is excluded.
        if isinstance(data, dict):
            required = schema.get("required", [])
            if required:
                present = sum(1 for key in required if key in data)
                return present / len(required)

        return 1.0
|