diff --git a/Auto Run Docs/02a-backend-engine.md b/Auto Run Docs/02a-backend-engine.md index f4edafe..2e778e4 100644 --- a/Auto Run Docs/02a-backend-engine.md +++ b/Auto Run Docs/02a-backend-engine.md @@ -20,7 +20,8 @@ Implement the core experiment execution engine: LLM adapters, response caching, - [x] Implement backend/engine/scorers/embedding.py — uses a configurable embedding endpoint (Ollama nomic-embed-text or any OpenAI-compatible embedding API) to compute cosine similarity between output and reference answer. Normalize to 0.0–1.0 range. -- [ ] Implement backend/engine/scorers/format.py — checks if output matches expected format. Supports: json (valid JSON parse), markdown (has headers, lists), length (within min/max token count), structure (matches a provided JSON schema). +- [x] Implement backend/engine/scorers/format.py — checks if output matches expected format. Supports: json (valid JSON parse), markdown (has headers, lists), length (within min/max token count), structure (matches a provided JSON schema). + - [ ] Implement backend/engine/scorers/keyword.py — checks for presence/absence of required keywords in output. Configurable with required_present and required_absent lists. Score = (found / required) ratio. diff --git a/backend/engine/scorers/__init__.py b/backend/engine/scorers/__init__.py index afaddb9..1c70065 100644 --- a/backend/engine/scorers/__init__.py +++ b/backend/engine/scorers/__init__.py @@ -2,5 +2,6 @@ from engine.scorers.base import BaseScorer from engine.scorers.embedding import EmbeddingScorer +from engine.scorers.format import FormatScorer -__all__ = ["BaseScorer", "EmbeddingScorer"] +__all__ = ["BaseScorer", "EmbeddingScorer", "FormatScorer"] diff --git a/backend/engine/scorers/format.py b/backend/engine/scorers/format.py new file mode 100644 index 0000000..e0d7210 --- /dev/null +++ b/backend/engine/scorers/format.py @@ -0,0 +1,173 @@ +"""Format scorer — checks if LLM output matches expected formats. + +Supports four format checks: +- json: valid JSON parse +- markdown: has headers and/or lists +- length: within min/max token count +- structure: matches a provided JSON schema +""" + +import json +import re +from typing import Any + +from engine.scorers.base import BaseScorer + + +class FormatScorer(BaseScorer): + """Score outputs based on format compliance. + + Args: + format_type: One of "json", "markdown", "length", "structure". + min_tokens: Minimum token count (for "length" mode). + max_tokens: Maximum token count (for "length" mode). + json_schema: JSON schema dict (for "structure" mode). + """ + + VALID_FORMATS = {"json", "markdown", "length", "structure"} + + def __init__( + self, + format_type: str = "json", + min_tokens: int | None = None, + max_tokens: int | None = None, + json_schema: dict | None = None, + ) -> None: + if format_type not in self.VALID_FORMATS: + raise ValueError( + f"Invalid format_type '{format_type}'. " + f"Must be one of: {', '.join(sorted(self.VALID_FORMATS))}" + ) + self.format_type = format_type + self.min_tokens = min_tokens + self.max_tokens = max_tokens + self.json_schema = json_schema + + @property + def name(self) -> str: + return "format" + + def score(self, input_data: Any, output: str, context: dict) -> float: + """Score output based on format compliance. + + Returns 1.0 if the output matches the expected format, 0.0 otherwise. + For length checks, returns a proportional score based on how close the + output is to the acceptable range. + """ + checkers = { + "json": self._check_json, + "markdown": self._check_markdown, + "length": self._check_length, + "structure": self._check_structure, + } + return checkers[self.format_type](output) + + def _check_json(self, output: str) -> float: + """Check if output is valid JSON.""" + try: + json.loads(output.strip()) + return 1.0 + except (json.JSONDecodeError, ValueError): + return 0.0 + + def _check_markdown(self, output: str) -> float: + """Check if output contains markdown formatting (headers and/or lists). + + Scoring: + - 0.5 for having headers (lines starting with #) + - 0.5 for having lists (lines starting with - or * or numbered) + - 1.0 for having both + """ + score = 0.0 + + # Check for headers + if re.search(r"^#{1,6}\s+\S", output, re.MULTILINE): + score += 0.5 + + # Check for lists (unordered or ordered) + if re.search(r"^[\s]*[-*]\s+\S", output, re.MULTILINE) or re.search( + r"^[\s]*\d+[.)]\s+\S", output, re.MULTILINE + ): + score += 0.5 + + return score + + def _check_length(self, output: str) -> float: + """Check if output length is within min/max token bounds. + + Uses whitespace tokenization as an approximation. + Returns 1.0 if within bounds, scaled score if outside. + """ + token_count = len(output.split()) + + if self.min_tokens is None and self.max_tokens is None: + return 1.0 + + min_t = self.min_tokens or 0 + max_t = self.max_tokens or float("inf") + + if min_t <= token_count <= max_t: + return 1.0 + + # Score proportionally based on distance from acceptable range + if token_count < min_t: + return max(0.0, token_count / min_t) if min_t > 0 else 0.0 + + # token_count > max_t + if max_t > 0 and max_t != float("inf"): + # Linearly decay: at 2x max, score = 0 + ratio = max_t / token_count + return max(0.0, 2 * ratio - 1.0) + + return 0.0 + + def _check_structure(self, output: str) -> float: + """Check if output matches a JSON schema. + + Returns 1.0 if valid against the schema, 0.0 otherwise. + """ + if self.json_schema is None: + return 0.0 + + try: + parsed = json.loads(output.strip()) + except (json.JSONDecodeError, ValueError): + return 0.0 + + try: + import jsonschema + jsonschema.validate(instance=parsed, schema=self.json_schema) + return 1.0 + except ImportError: + # Fallback: basic type and required-field checking without jsonschema + return self._basic_schema_check(parsed, self.json_schema) + except jsonschema.ValidationError: + return 0.0 + + def _basic_schema_check(self, data: Any, schema: dict) -> float: + """Basic JSON schema validation without jsonschema library. + + Checks type and required fields only. + """ + schema_type = schema.get("type") + if schema_type: + type_map = { + "object": dict, + "array": list, + "string": str, + "number": (int, float), + "integer": int, + "boolean": bool, + "null": type(None), + } + expected = type_map.get(schema_type) + if expected and not isinstance(data, expected): + return 0.0 + + if schema_type == "object" and isinstance(data, dict): + required = schema.get("required", []) + if required: + present = sum(1 for k in required if k in data) + return present / len(required) + + return 1.0 diff --git a/backend/requirements.txt b/backend/requirements.txt index fd8b4eb..74dcde4 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -15,3 +15,4 @@ psycopg2-binary>=2.9,<3.0 aiosqlite>=0.20,<1.0 python-multipart>=0.0.9 jinja2>=3.1,<4.0 +jsonschema>=4.20,<5.0 diff --git a/backend/tests/test_scorer_format.py b/backend/tests/test_scorer_format.py new file mode 100644 index 0000000..ee1602e --- /dev/null +++ b/backend/tests/test_scorer_format.py @@ -0,0 +1,228 @@ +"""Tests for the FormatScorer.""" + +import asyncio +import json +from typing import Any + +import pytest + +from engine.scorers.format import FormatScorer + + +class TestFormatScorerInit: + def test_valid_format_types(self): + for fmt in ("json", "markdown", "length", "structure"): + scorer = FormatScorer(format_type=fmt) + assert scorer.format_type == fmt + + def test_invalid_format_type_raises(self): + with pytest.raises(ValueError, match="Invalid format_type"): + FormatScorer(format_type="xml") + + def test_name_property(self): + scorer = FormatScorer() + assert scorer.name == "format" + + def test_is_base_scorer(self): + from engine.scorers.base import BaseScorer + scorer = FormatScorer() + assert isinstance(scorer, BaseScorer) + + +class TestJsonFormat: + def test_valid_json_object(self): + scorer = FormatScorer(format_type="json") + assert scorer.score(None, '{"key": "value"}', {}) == 1.0 + + def test_valid_json_array(self): + scorer = FormatScorer(format_type="json") + assert scorer.score(None, '[1, 2, 3]', {}) == 1.0 + + def test_valid_json_string(self): + scorer = FormatScorer(format_type="json") + assert scorer.score(None, '"hello"', {}) == 1.0 + + def test_valid_json_number(self): + scorer = FormatScorer(format_type="json") + assert scorer.score(None, '42', {}) == 1.0 + + def test_valid_json_with_whitespace(self): + scorer = FormatScorer(format_type="json") + assert scorer.score(None, ' {"key": "value"} ', {}) == 1.0 + + def test_invalid_json(self): + scorer = FormatScorer(format_type="json") + assert scorer.score(None, "not json at all", {}) == 0.0 + + def test_empty_string(self): + scorer = FormatScorer(format_type="json") + assert scorer.score(None, "", {}) == 0.0 + + def test_partial_json(self): + scorer = FormatScorer(format_type="json") + assert scorer.score(None, '{"key":', {}) == 0.0 + + +class TestMarkdownFormat: + def test_headers_only(self): + scorer = FormatScorer(format_type="markdown") + output = "# Title\n\nSome text here." + assert scorer.score(None, output, {}) == 0.5 + + def test_lists_only_unordered(self): + scorer = FormatScorer(format_type="markdown") + output = "Some text\n- item one\n- item two" + assert scorer.score(None, output, {}) == 0.5 + + def test_lists_only_ordered(self): + scorer = FormatScorer(format_type="markdown") + output = "Some text\n1. first\n2. second" + assert scorer.score(None, output, {}) == 0.5 + + def test_both_headers_and_lists(self): + scorer = FormatScorer(format_type="markdown") + output = "# Title\n\n- item one\n- item two" + assert scorer.score(None, output, {}) == 1.0 + + def test_no_markdown(self): + scorer = FormatScorer(format_type="markdown") + output = "Just plain text without any formatting." + assert scorer.score(None, output, {}) == 0.0 + + def test_nested_header_levels(self): + scorer = FormatScorer(format_type="markdown") + output = "## Subtitle\n\nContent here" + assert scorer.score(None, output, {}) == 0.5 + + def test_asterisk_list(self): + scorer = FormatScorer(format_type="markdown") + output = "Some text\n* item one\n* item two" + assert scorer.score(None, output, {}) == 0.5 + + def test_ordered_list_with_parenthesis(self): + scorer = FormatScorer(format_type="markdown") + output = "Text\n1) first\n2) second" + assert scorer.score(None, output, {}) == 0.5 + + +class TestLengthFormat: + def test_within_range(self): + scorer = FormatScorer(format_type="length", min_tokens=5, max_tokens=20) + output = "this is a ten word sentence for the test case" + assert scorer.score(None, output, {}) == 1.0 + + def test_exact_min(self): + scorer = FormatScorer(format_type="length", min_tokens=3, max_tokens=10) + assert scorer.score(None, "one two three", {}) == 1.0 + + def test_exact_max(self): + scorer = FormatScorer(format_type="length", min_tokens=1, max_tokens=3) + assert scorer.score(None, "one two three", {}) == 1.0 + + def test_below_min(self): + scorer = FormatScorer(format_type="length", min_tokens=10, max_tokens=20) + output = "only five words here now" + result = scorer.score(None, output, {}) + assert 0.0 < result < 1.0 + assert result == 5 / 10 # 0.5 + + def test_above_max(self): + scorer = FormatScorer(format_type="length", min_tokens=1, max_tokens=5) + output = "one two three four five six seven eight nine ten" + result = scorer.score(None, output, {}) + assert 0.0 <= result < 1.0 + + def test_no_bounds(self): + scorer = FormatScorer(format_type="length") + assert scorer.score(None, "any text", {}) == 1.0 + + def test_only_min(self): + scorer = FormatScorer(format_type="length", min_tokens=3) + assert scorer.score(None, "one two three four", {}) == 1.0 + + def test_only_max(self): + scorer = FormatScorer(format_type="length", max_tokens=5) + assert scorer.score(None, "one two", {}) == 1.0 + + def test_empty_output(self): + scorer = FormatScorer(format_type="length", min_tokens=5) + # empty string splits to [''], which has length 1 + result = scorer.score(None, "", {}) + assert result < 1.0 + + def test_zero_min(self): + scorer = FormatScorer(format_type="length", min_tokens=0, max_tokens=10) + assert scorer.score(None, "hello", {}) == 1.0 + + +class TestStructureFormat: + def test_valid_structure(self): + schema = { + "type": "object", + "required": ["name", "age"], + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + }, + } + scorer = FormatScorer(format_type="structure", json_schema=schema) + output = json.dumps({"name": "Alice", "age": 30}) + assert scorer.score(None, output, {}) == 1.0 + + def test_missing_required_field(self): + schema = { + "type": "object", + "required": ["name", "age"], + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + }, + } + scorer = FormatScorer(format_type="structure", json_schema=schema) + output = json.dumps({"name": "Alice"}) + assert scorer.score(None, output, {}) == 0.0 + + def test_wrong_type(self): + schema = {"type": "array"} + scorer = FormatScorer(format_type="structure", json_schema=schema) + output = json.dumps({"key": "value"}) + assert scorer.score(None, output, {}) == 0.0 + + def test_valid_array_structure(self): + schema = {"type": "array"} + scorer = FormatScorer(format_type="structure", json_schema=schema) + output = json.dumps([1, 2, 3]) + assert scorer.score(None, output, {}) == 1.0 + + def test_no_schema_returns_zero(self): + scorer = FormatScorer(format_type="structure") + assert scorer.score(None, '{"key": "value"}', {}) == 0.0 + + def test_invalid_json_for_structure(self): + schema = {"type": "object"} + scorer = FormatScorer(format_type="structure", json_schema=schema) + assert scorer.score(None, "not json", {}) == 0.0 + + def test_complex_schema(self): + schema = { + "type": "object", + "required": ["results"], + "properties": { + "results": { + "type": "array", + "items": {"type": "object"}, + }, + }, + } + scorer = FormatScorer(format_type="structure", json_schema=schema) + output = json.dumps({"results": [{"id": 1}, {"id": 2}]}) + assert scorer.score(None, output, {}) == 1.0 + + +class TestAsyncScoring: + def test_async_delegates_to_sync(self): + scorer = FormatScorer(format_type="json") + result = asyncio.get_event_loop().run_until_complete( + scorer.score_async(None, '{"valid": true}', {}) + ) + assert result == 1.0