"""Format scorer — checks if LLM output matches expected formats. Supports four format checks: - json: valid JSON parse - markdown: has headers and/or lists - length: within min/max token count - structure: matches a provided JSON schema """ import json import re from typing import Any from engine.scorers.base import BaseScorer class FormatScorer(BaseScorer): """Score outputs based on format compliance. Args: format_type: One of "json", "markdown", "length", "structure". min_tokens: Minimum token count (for "length" mode). max_tokens: Maximum token count (for "length" mode). json_schema: JSON schema dict (for "structure" mode). """ VALID_FORMATS = {"json", "markdown", "length", "structure"} def __init__( self, format_type: str = "json", min_tokens: int | None = None, max_tokens: int | None = None, json_schema: dict | None = None, ) -> None: if format_type not in self.VALID_FORMATS: raise ValueError( f"Invalid format_type '{format_type}'. " f"Must be one of: {', '.join(sorted(self.VALID_FORMATS))}" ) self.format_type = format_type self.min_tokens = min_tokens self.max_tokens = max_tokens self.json_schema = json_schema @property def name(self) -> str: return "format" def score(self, input_data: Any, output: str, context: dict) -> float: """Score output based on format compliance. Returns 1.0 if the output matches the expected format, 0.0 otherwise. For length checks, returns a proportional score based on how close the output is to the acceptable range. """ checkers = { "json": self._check_json, "markdown": self._check_markdown, "length": self._check_length, "structure": self._check_structure, } return checkers[self.format_type](output) def _check_json(self, output: str) -> float: """Check if output is valid JSON.""" try: json.loads(output.strip()) return 1.0 except (json.JSONDecodeError, ValueError): return 0.0 def _check_markdown(self, output: str) -> float: """Check if output contains markdown formatting (headers and/or lists). Scoring: - 0.5 for having headers (lines starting with #) - 0.5 for having lists (lines starting with - or * or numbered) - 1.0 for having both """ score = 0.0 # Check for headers if re.search(r"^#{1,6}\s+\S", output, re.MULTILINE): score += 0.5 # Check for lists (unordered or ordered) if re.search(r"^[\s]*[-*]\s+\S", output, re.MULTILINE) or re.search( r"^[\s]*\d+[.)]\s+\S", output, re.MULTILINE ): score += 0.5 return score def _check_length(self, output: str) -> float: """Check if output length is within min/max token bounds. Uses whitespace tokenization as an approximation. Returns 1.0 if within bounds, scaled score if outside. """ token_count = len(output.split()) if self.min_tokens is None and self.max_tokens is None: return 1.0 min_t = self.min_tokens or 0 max_t = self.max_tokens or float("inf") if min_t <= token_count <= max_t: return 1.0 # Score proportionally based on distance from acceptable range if token_count < min_t: return max(0.0, token_count / min_t) if min_t > 0 else 0.0 # token_count > max_t if max_t > 0 and max_t != float("inf"): # Linearly decay: at 2x max, score = 0 ratio = max_t / token_count return max(0.0, 2 * ratio - 1.0) return 0.0 def _check_structure(self, output: str) -> float: """Check if output matches a JSON schema. Returns 1.0 if valid against the schema, 0.0 otherwise. """ if self.json_schema is None: return 0.0 try: parsed = json.loads(output.strip()) except (json.JSONDecodeError, ValueError): return 0.0 try: import jsonschema jsonschema.validate(instance=parsed, schema=self.json_schema) return 1.0 except ImportError: # Fallback: basic type and required-field checking without jsonschema return self._basic_schema_check(parsed, self.json_schema) except jsonschema.ValidationError: return 0.0 def _basic_schema_check(self, data: Any, schema: dict) -> float: """Basic JSON schema validation without jsonschema library. Checks type and required fields only. """ schema_type = schema.get("type") if schema_type: type_map = { "object": dict, "array": list, "string": str, "number": (int, float), "integer": int, "boolean": bool, "null": type(None), } expected = type_map.get(schema_type) if expected and not isinstance(data, expected): return 0.0 if schema_type == "object" and isinstance(data, dict): required = schema.get("required", []) if required: present = sum(1 for k in required if k in data) return present / len(required) return 1.0