From fb78eac1b0a90cb86deacbc1f28943e3998f0c95 Mon Sep 17 00:00:00 2001 From: John Lightner Date: Tue, 7 Apr 2026 03:05:00 -0500 Subject: [PATCH] MAESTRO: Implement LLMJudgeScorer with configurable judge prompt, rating parsing, and response caching --- Auto Run Docs/02a-backend-engine.md | 3 +- backend/engine/scorers/__init__.py | 9 +- backend/engine/scorers/llm_judge.py | 232 +++++++++++++++ backend/tests/test_scorer_llm_judge.py | 379 +++++++++++++++++++++++++ 4 files changed, 621 insertions(+), 2 deletions(-) create mode 100644 backend/engine/scorers/llm_judge.py create mode 100644 backend/tests/test_scorer_llm_judge.py diff --git a/Auto Run Docs/02a-backend-engine.md b/Auto Run Docs/02a-backend-engine.md index c307503..e371e41 100644 --- a/Auto Run Docs/02a-backend-engine.md +++ b/Auto Run Docs/02a-backend-engine.md @@ -26,7 +26,8 @@ Implement the core experiment execution engine: LLM adapters, response caching, - [x] Implement backend/engine/scorers/keyword.py — checks for presence/absence of required keywords in output. Configurable with required_present and required_absent lists. Score = (found / required) ratio. -- [ ] Implement backend/engine/scorers/llm_judge.py — sends the output to a separate LLM with a configurable judge prompt and asks for a 1-10 rating. Parses the numeric score from the response. This scorer requires an LLM call so it should be clearly marked as "costs tokens" in the UI. Cache the judge's response too. +- [x] Implement backend/engine/scorers/llm_judge.py — sends the output to a separate LLM with a configurable judge prompt and asks for a 1-10 rating. Parses the numeric score from the response. This scorer requires an LLM call so it should be clearly marked as "costs tokens" in the UI. Cache the judge's response too. + - [ ] Wire up the Celery worker in backend/worker.py. Define tasks: execute_run(run_id), execute_sweep(experiment_id, sweep_config). Configure Celery to use Redis as broker. 
In single-container mode (no Redis), implement a simple synchronous fallback that runs tasks in-process. diff --git a/backend/engine/scorers/__init__.py b/backend/engine/scorers/__init__.py index 8582c87..bd6dbc7 100644 --- a/backend/engine/scorers/__init__.py +++ b/backend/engine/scorers/__init__.py @@ -4,5 +4,12 @@ from engine.scorers.base import BaseScorer from engine.scorers.embedding import EmbeddingScorer from engine.scorers.format import FormatScorer from engine.scorers.keyword import KeywordScorer +from engine.scorers.llm_judge import LLMJudgeScorer -__all__ = ["BaseScorer", "EmbeddingScorer", "FormatScorer", "KeywordScorer"] +__all__ = [ + "BaseScorer", + "EmbeddingScorer", + "FormatScorer", + "KeywordScorer", + "LLMJudgeScorer", +] diff --git a/backend/engine/scorers/llm_judge.py b/backend/engine/scorers/llm_judge.py new file mode 100644 index 0000000..53a86b3 --- /dev/null +++ b/backend/engine/scorers/llm_judge.py @@ -0,0 +1,232 @@ +"""LLM-as-judge scorer. + +Sends the LLM output to a separate LLM with a configurable judge prompt, +asks for a 1–10 rating, and normalizes to the 0.0–1.0 range. + +**This scorer costs tokens** — every evaluation makes an LLM call. The +judge's response is cached via PromptLooper's response cache layer to +avoid redundant calls when re-scoring the same output. +""" + +import re +from typing import Any + +import httpx + +from engine.scorers.base import BaseScorer + + +# Default judge system prompt — can be overridden at construction time. +DEFAULT_JUDGE_PROMPT = ( + "You are an impartial evaluator. You will receive an input and an LLM-generated " + "output. Rate the quality of the output on a scale of 1 to 10, where 1 is terrible " + "and 10 is perfect.\n\n" + "Respond with ONLY a single integer between 1 and 10. Do not include any other text." +) + +# Regex to extract the first integer 1–10 from the judge response. 
# One stand-alone numeric token (integer or decimal). The look-arounds reject
# digits glued to letters/underscores ("gpt4", "10points") and fragments of
# dotted numbers ("1.2.3"); range filtering (1–10) happens in _extract_rating,
# so "0.5", "11" and "100" are seen as whole numbers and then discarded
# instead of leaking a stray digit.
_RATING_RE = re.compile(r"(?<![\w.])[0-9]+(?:\.[0-9]+)?(?![\w]|\.[0-9])")


class LLMJudgeScorer(BaseScorer):
    """Score outputs by asking a separate LLM to rate them 1–10.

    The rating is normalised to 0.0–1.0 (1 → 0.0, 10 → 1.0). A response that
    contains no usable rating scores 0.0.

    Args:
        base_url: Chat completions API base URL (a trailing "/" is stripped).
        model: Model to use for judging.
        api_key: Optional API key, sent as a Bearer token when set.
        judge_prompt: System prompt for the judge LLM.
        timeout: HTTP request timeout in seconds.
        max_retries: Total number of attempts on transient failures
            (connection errors and HTTP 429/500/502/503/504). Values below 1
            are treated as 1 so the request is always tried once.
        cache_layer: Optional ``ResponseCacheLayer`` instance. When provided
            together with *db_session_factory*, judge responses are cached to
            avoid duplicate LLM calls.
        db_session_factory: Callable returning a SQLAlchemy ``Session``.
            Caching is only active when both this and *cache_layer* are
            supplied; otherwise every evaluation calls the judge.
    """

    # Marker for the UI so it can warn users about token cost.
    COSTS_TOKENS = True

    def __init__(
        self,
        base_url: str = "http://localhost:11434/v1",
        model: str = "llama3",
        api_key: str | None = None,
        judge_prompt: str = DEFAULT_JUDGE_PROMPT,
        timeout: float = 120.0,
        max_retries: int = 3,
        cache_layer: Any = None,
        db_session_factory: Any = None,
    ) -> None:
        self.base_url = base_url.rstrip("/")
        self.model = model
        self.api_key = api_key
        self.judge_prompt = judge_prompt
        self.timeout = timeout
        self.max_retries = max_retries
        self._cache_layer = cache_layer
        self._db_session_factory = db_session_factory

    @property
    def name(self) -> str:
        """Scorer identifier."""
        return "llm_judge"

    # ------------------------------------------------------------------
    # Synchronous entry point
    # ------------------------------------------------------------------

    def score(self, input_data: Any, output: str, context: dict) -> float:
        """Synchronous scoring — runs :meth:`score_async` to completion.

        Raises:
            RuntimeError: If called while an event loop is already running in
                this thread; use :meth:`score_async` there instead.
        """
        import asyncio

        try:
            asyncio.get_running_loop()
        except RuntimeError:
            pass  # No running loop: safe to block on our own.
        else:
            raise RuntimeError(
                "LLMJudgeScorer.score() cannot be called from an async context. "
                "Use score_async() instead."
            )

        # A private loop (rather than asyncio.get_event_loop()) works in
        # worker threads, which have no default loop, avoids the deprecated
        # implicit loop creation, and leaves the thread's current-loop state
        # untouched for surrounding code.
        loop = asyncio.new_event_loop()
        try:
            return loop.run_until_complete(
                self.score_async(input_data, output, context)
            )
        finally:
            try:
                loop.run_until_complete(loop.shutdown_asyncgens())
            finally:
                loop.close()

    # ------------------------------------------------------------------
    # Async entry point
    # ------------------------------------------------------------------

    async def score_async(
        self, input_data: Any, output: str, context: dict
    ) -> float:
        """Ask the judge LLM to rate the output and return a normalised score.

        A cached judge response is used when caching is configured. Fresh
        responses are cached only when they contain a usable rating, so an
        empty or garbled reply does not pin a 0.0 score permanently.
        """
        from contextlib import closing

        user_message = self._build_user_message(input_data, output, context)

        # Identity checks: a cache object that happens to be falsy (e.g. it
        # defines __len__ and is empty) must not silently disable caching.
        caching = (
            self._cache_layer is not None
            and self._db_session_factory is not None
        )

        config_hash: str | None = None
        if caching:
            from engine.cache import compute_config_hash

            config_hash = compute_config_hash(
                prompt=self.judge_prompt,
                model=self.model,
                params={"scorer": "llm_judge"},
                input_data=user_message,
            )
            with closing(self._db_session_factory()) as db:
                cached = self._cache_layer.get(db, config_hash)
                if cached is not None:
                    # Read the attribute while the session is still open.
                    return _parse_rating(cached.response)

        judge_response = await self._call_judge(user_message)
        rating = _extract_rating(judge_response)

        if caching and config_hash and rating is not None:
            with closing(self._db_session_factory()) as db:
                self._cache_layer.put(
                    db,
                    config_hash=config_hash,
                    response=judge_response,
                    model=self.model,
                )

        return _parse_rating(judge_response)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _build_user_message(
        self, input_data: Any, output: str, context: dict
    ) -> str:
        """Build the user message sent to the judge LLM.

        Sections are separated by blank lines: an optional ``## Input``, the
        ``## Output``, and an optional ``## Reference Answer`` taken from
        ``context["reference"]`` when it is truthy.
        """
        parts = []
        if input_data is not None:
            parts.append(f"## Input\n{input_data}")
        parts.append(f"## Output\n{output}")

        # Include reference answer if available — helps the judge compare.
        reference = (context or {}).get("reference")
        if reference:
            parts.append(f"## Reference Answer\n{reference}")

        return "\n\n".join(parts)

    @staticmethod
    def _extract_content(payload: Any) -> str:
        """Return the first choice's message text, or "" if there is none.

        Tolerates ``"content": null`` and a missing/null ``message`` instead
        of failing on ``None.strip()``.
        """
        choices = payload.get("choices") if isinstance(payload, dict) else None
        if not choices or not isinstance(choices[0], dict):
            return ""
        message = choices[0].get("message") or {}
        content = message.get("content") if isinstance(message, dict) else None
        return content.strip() if isinstance(content, str) else ""

    async def _call_judge(self, user_message: str) -> str:
        """Send a chat completion request to the judge LLM with retries.

        Returns:
            The stripped text of the first choice ("" if the reply has none).

        Raises:
            httpx.HTTPStatusError: Immediately, for a non-retryable error
                status.
            RuntimeError: When every attempt failed with a transient error;
                the last underlying error is chained as ``__cause__``.
        """
        import asyncio

        url = f"{self.base_url}/chat/completions"

        headers: dict[str, str] = {"Content-Type": "application/json"}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"

        body = {
            "model": self.model,
            "messages": [
                {"role": "system", "content": self.judge_prompt},
                {"role": "user", "content": user_message},
            ],
            "temperature": 0.0,
            "max_tokens": 16,
        }

        retryable = {429, 500, 502, 503, 504}
        attempts = max(1, self.max_retries)  # always try at least once
        last_exc: Exception | None = None

        # One client for all attempts so retries can reuse the connection pool.
        async with httpx.AsyncClient(
            timeout=httpx.Timeout(self.timeout), headers=headers
        ) as client:
            for attempt in range(attempts):
                try:
                    resp = await client.post(url, json=body)
                except httpx.HTTPError as exc:
                    # Transport-level failure (connect/read/timeout): retry.
                    last_exc = exc
                else:
                    if resp.status_code == 200:
                        return self._extract_content(resp.json())

                    if resp.status_code not in retryable:
                        # Client errors will not heal on retry: surface now.
                        resp.raise_for_status()

                    last_exc = httpx.HTTPStatusError(
                        f"HTTP {resp.status_code}",
                        request=resp.request,
                        response=resp,
                    )

                if attempt < attempts - 1:
                    await asyncio.sleep(2**attempt)  # 1s, 2s, 4s, ...

        raise RuntimeError(
            f"All {attempts} attempts failed for judge LLM at {url}"
        ) from last_exc


def _extract_rating(text: str | None) -> float | None:
    """Return the first stand-alone number in *text* that lies in 1–10.

    Fractional ratings such as "7.5" are accepted. Returns ``None`` when
    *text* is empty or contains no number in range.
    """
    if not text:
        return None
    for match in _RATING_RE.finditer(text):
        value = float(match.group())
        if 1.0 <= value <= 10.0:
            return value
    return None


def _parse_rating(text: str) -> float:
    """Extract a 1–10 rating from the judge response and normalise to 0.0–1.0.

    Falls back to 0.0 if no valid rating is found.
    """
    rating = _extract_rating(text)
    if rating is None:
        return 0.0
    # Normalise: 1 → 0.0, 10 → 1.0 (rating is already known to be in range).
    return (rating - 1.0) / 9.0
"""Tests for the LLMJudgeScorer."""

import asyncio
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from engine.scorers.base import BaseScorer
from engine.scorers.llm_judge import (
    DEFAULT_JUDGE_PROMPT,
    LLMJudgeScorer,
    _parse_rating,
)


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _chat_response(content: str) -> dict:
    """Build a fake OpenAI-compatible chat completion response."""
    return {
        "id": "chatcmpl-test",
        "object": "chat.completion",
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": content},
                "finish_reason": "stop",
            }
        ],
        "usage": {"prompt_tokens": 50, "completion_tokens": 2},
    }


def _mock_client(response_data: dict, status_code: int = 200):
    """Create a mocked httpx.AsyncClient that returns *response_data*.

    The mock is its own async context manager, so it works with
    ``async with httpx.AsyncClient(...) as client``.
    """
    mock_response = MagicMock()
    mock_response.status_code = status_code
    mock_response.json.return_value = response_data
    mock_response.raise_for_status = MagicMock()

    mock_client = AsyncMock()
    mock_client.post.return_value = mock_response
    mock_client.__aenter__ = AsyncMock(return_value=mock_client)
    mock_client.__aexit__ = AsyncMock(return_value=False)

    return mock_client


def _posted_json(mock_client) -> dict:
    """Return the JSON body of the most recent ``client.post`` call."""
    call = mock_client.post.call_args
    return call.kwargs.get("json") or call[1].get("json")


# ---------------------------------------------------------------------------
# _parse_rating unit tests
# ---------------------------------------------------------------------------

class TestParseRating:
    def test_single_digit(self):
        assert _parse_rating("7") == pytest.approx((7 - 1) / 9.0)

    def test_rating_10(self):
        assert _parse_rating("10") == pytest.approx(1.0)

    def test_rating_1(self):
        assert _parse_rating("1") == pytest.approx(0.0)

    def test_rating_5(self):
        assert _parse_rating("5") == pytest.approx((5 - 1) / 9.0)

    def test_embedded_in_text(self):
        assert _parse_rating("I would rate this a 8 out of 10") == pytest.approx(
            (8 - 1) / 9.0
        )

    def test_fraction_style_rating(self):
        # "N/10" answers are common; the first number is the rating.
        assert _parse_rating("7/10") == pytest.approx((7 - 1) / 9.0)

    def test_no_rating(self):
        assert _parse_rating("This is great!") == 0.0

    def test_empty_string(self):
        assert _parse_rating("") == 0.0

    def test_zero_not_matched(self):
        # 0 is outside the 1–10 range, should not match.
        assert _parse_rating("0") == 0.0

    def test_number_above_10_is_not_a_rating(self):
        # "11" must be read as a whole number, not as a leading "1": it is out
        # of range, so no rating is found and the 0.0 fallback applies.
        assert _parse_rating("11") == 0.0

    def test_rating_with_newlines(self):
        assert _parse_rating("\n\n8\n") == pytest.approx((8 - 1) / 9.0)

    def test_rating_with_period(self):
        assert _parse_rating("Rating: 9.") == pytest.approx((9 - 1) / 9.0)

    def test_ten_in_sentence(self):
        assert _parse_rating("Score: 10 points") == pytest.approx(1.0)


# ---------------------------------------------------------------------------
# LLMJudgeScorer class tests
# ---------------------------------------------------------------------------

class TestLLMJudgeScorerProperties:
    def test_is_base_scorer_subclass(self):
        scorer = LLMJudgeScorer()
        assert isinstance(scorer, BaseScorer)

    def test_name(self):
        assert LLMJudgeScorer().name == "llm_judge"

    def test_costs_tokens_marker(self):
        assert LLMJudgeScorer.COSTS_TOKENS is True

    def test_custom_judge_prompt(self):
        scorer = LLMJudgeScorer(judge_prompt="Custom prompt")
        assert scorer.judge_prompt == "Custom prompt"

    def test_default_judge_prompt(self):
        scorer = LLMJudgeScorer()
        assert scorer.judge_prompt == DEFAULT_JUDGE_PROMPT


class TestBuildUserMessage:
    def test_basic_message(self):
        scorer = LLMJudgeScorer()
        msg = scorer._build_user_message("What is 2+2?", "4", {})
        assert "## Input" in msg
        assert "What is 2+2?" in msg
        assert "## Output" in msg
        assert "4" in msg

    def test_no_input(self):
        scorer = LLMJudgeScorer()
        msg = scorer._build_user_message(None, "Hello", {})
        assert "## Input" not in msg
        assert "## Output" in msg

    def test_with_reference(self):
        scorer = LLMJudgeScorer()
        msg = scorer._build_user_message("q", "a", {"reference": "correct answer"})
        assert "## Reference Answer" in msg
        assert "correct answer" in msg

    def test_without_reference(self):
        scorer = LLMJudgeScorer()
        msg = scorer._build_user_message("q", "a", {})
        assert "## Reference Answer" not in msg


class TestScoreAsync:
    @pytest.mark.asyncio
    async def test_basic_score(self):
        scorer = LLMJudgeScorer(base_url="http://fake:11434/v1")
        mock_client = _mock_client(_chat_response("8"))

        with patch("httpx.AsyncClient", return_value=mock_client):
            score = await scorer.score_async("input", "output", {})

        assert score == pytest.approx((8 - 1) / 9.0)

    @pytest.mark.asyncio
    async def test_perfect_score(self):
        scorer = LLMJudgeScorer(base_url="http://fake:11434/v1")
        mock_client = _mock_client(_chat_response("10"))

        with patch("httpx.AsyncClient", return_value=mock_client):
            score = await scorer.score_async("input", "output", {})

        assert score == pytest.approx(1.0)

    @pytest.mark.asyncio
    async def test_lowest_score(self):
        scorer = LLMJudgeScorer(base_url="http://fake:11434/v1")
        mock_client = _mock_client(_chat_response("1"))

        with patch("httpx.AsyncClient", return_value=mock_client):
            score = await scorer.score_async("input", "output", {})

        assert score == pytest.approx(0.0)

    @pytest.mark.asyncio
    async def test_unparseable_response(self):
        scorer = LLMJudgeScorer(base_url="http://fake:11434/v1")
        mock_client = _mock_client(_chat_response("I cannot rate this"))

        with patch("httpx.AsyncClient", return_value=mock_client):
            score = await scorer.score_async("input", "output", {})

        assert score == 0.0

    @pytest.mark.asyncio
    async def test_empty_choices(self):
        scorer = LLMJudgeScorer(base_url="http://fake:11434/v1")
        response = {"choices": []}
        mock_client = _mock_client(response)

        with patch("httpx.AsyncClient", return_value=mock_client):
            score = await scorer.score_async("input", "output", {})

        assert score == 0.0

    @pytest.mark.asyncio
    async def test_request_includes_correct_body(self):
        scorer = LLMJudgeScorer(
            base_url="http://fake:11434/v1",
            model="judge-model",
            judge_prompt="Rate it",
        )
        mock_client = _mock_client(_chat_response("7"))

        with patch("httpx.AsyncClient", return_value=mock_client):
            await scorer.score_async("test input", "test output", {})

        body = _posted_json(mock_client)
        assert body["model"] == "judge-model"
        assert body["messages"][0]["role"] == "system"
        assert body["messages"][0]["content"] == "Rate it"
        assert body["messages"][1]["role"] == "user"
        assert body["temperature"] == 0.0
        assert body["max_tokens"] == 16

    @pytest.mark.asyncio
    async def test_api_key_in_headers(self):
        scorer = LLMJudgeScorer(
            base_url="http://fake:11434/v1",
            api_key="sk-test-key",
        )
        mock_client = _mock_client(_chat_response("5"))

        with patch("httpx.AsyncClient", return_value=mock_client) as mock_cls:
            await scorer.score_async("input", "output", {})

        call_kwargs = mock_cls.call_args
        headers = call_kwargs.kwargs.get("headers") or call_kwargs[1].get("headers")
        assert headers["Authorization"] == "Bearer sk-test-key"

    @pytest.mark.asyncio
    async def test_reference_included_in_message(self):
        scorer = LLMJudgeScorer(base_url="http://fake:11434/v1")
        mock_client = _mock_client(_chat_response("9"))

        with patch("httpx.AsyncClient", return_value=mock_client):
            await scorer.score_async(
                "q", "a", {"reference": "expected answer"}
            )

        user_msg = _posted_json(mock_client)["messages"][1]["content"]
        assert "expected answer" in user_msg


class TestScoreAsyncWithCache:
    @pytest.mark.asyncio
    async def test_cache_hit_skips_llm_call(self):
        mock_cache = MagicMock()
        cached_response = MagicMock()
        cached_response.response = "8"
        mock_cache.get.return_value = cached_response

        mock_session = MagicMock()
        mock_session_factory = MagicMock(return_value=mock_session)

        scorer = LLMJudgeScorer(
            base_url="http://fake:11434/v1",
            cache_layer=mock_cache,
            db_session_factory=mock_session_factory,
        )

        with patch("httpx.AsyncClient") as mock_http:
            score = await scorer.score_async("input", "output", {})

        assert score == pytest.approx((8 - 1) / 9.0)
        mock_http.assert_not_called()
        # The lookup session must be released even on the early return.
        mock_session.close.assert_called_once()

    @pytest.mark.asyncio
    async def test_cache_miss_calls_llm_and_stores(self):
        mock_cache = MagicMock()
        mock_cache.get.return_value = None

        mock_session = MagicMock()
        mock_session_factory = MagicMock(return_value=mock_session)

        scorer = LLMJudgeScorer(
            base_url="http://fake:11434/v1",
            cache_layer=mock_cache,
            db_session_factory=mock_session_factory,
        )
        mock_client = _mock_client(_chat_response("6"))

        with patch("httpx.AsyncClient", return_value=mock_client):
            score = await scorer.score_async("input", "output", {})

        assert score == pytest.approx((6 - 1) / 9.0)
        mock_cache.put.assert_called_once()
        put_kwargs = mock_cache.put.call_args
        assert put_kwargs.kwargs.get("response") == "6" or (
            len(put_kwargs.args) > 2 and put_kwargs.args[2] == "6"
        )
        # One session for the lookup, one for the store; both closed.
        assert mock_session.close.call_count == 2

    @pytest.mark.asyncio
    async def test_no_cache_when_layer_not_configured(self):
        scorer = LLMJudgeScorer(base_url="http://fake:11434/v1")
        mock_client = _mock_client(_chat_response("7"))

        with patch("httpx.AsyncClient", return_value=mock_client):
            score = await scorer.score_async("input", "output", {})

        assert score == pytest.approx((7 - 1) / 9.0)


class TestScoreSync:
    def test_sync_score_works_outside_async(self):
        scorer = LLMJudgeScorer(base_url="http://fake:11434/v1")
        mock_client = _mock_client(_chat_response("5"))

        with patch("httpx.AsyncClient", return_value=mock_client):
            score = scorer.score("input", "output", {})

        assert score == pytest.approx((5 - 1) / 9.0)

    def test_sync_score_raises_in_async_context(self):
        scorer = LLMJudgeScorer(base_url="http://fake:11434/v1")

        async def _run():
            scorer.score("input", "output", {})

        # Drive the coroutine on a private loop: asyncio.get_event_loop() is
        # deprecated without a running loop, and a private loop leaves the
        # thread's current-loop state alone for the other tests.
        loop = asyncio.new_event_loop()
        try:
            with pytest.raises(
                RuntimeError, match="cannot be called from an async"
            ):
                loop.run_until_complete(_run())
        finally:
            loop.close()


class TestRetries:
    @pytest.mark.asyncio
    async def test_retries_on_server_error(self):
        scorer = LLMJudgeScorer(
            base_url="http://fake:11434/v1", max_retries=3
        )

        # First two calls fail with 500, third succeeds.
        fail_resp = MagicMock()
        fail_resp.status_code = 500
        fail_resp.json.return_value = {}

        ok_resp = MagicMock()
        ok_resp.status_code = 200
        ok_resp.json.return_value = _chat_response("7")

        mock_client = AsyncMock()
        mock_client.post.side_effect = [fail_resp, fail_resp, ok_resp]
        mock_client.__aenter__ = AsyncMock(return_value=mock_client)
        mock_client.__aexit__ = AsyncMock(return_value=False)

        with patch("httpx.AsyncClient", return_value=mock_client):
            with patch("asyncio.sleep", new_callable=AsyncMock):
                score = await scorer.score_async("input", "output", {})

        assert score == pytest.approx((7 - 1) / 9.0)
        assert mock_client.post.call_count == 3

    @pytest.mark.asyncio
    async def test_all_retries_exhausted(self):
        scorer = LLMJudgeScorer(
            base_url="http://fake:11434/v1", max_retries=2
        )

        fail_resp = MagicMock()
        fail_resp.status_code = 500
        fail_resp.json.return_value = {}
        fail_resp.request = MagicMock()

        mock_client = AsyncMock()
        mock_client.post.return_value = fail_resp
        mock_client.__aenter__ = AsyncMock(return_value=mock_client)
        mock_client.__aexit__ = AsyncMock(return_value=False)

        with patch("httpx.AsyncClient", return_value=mock_client):
            with patch("asyncio.sleep", new_callable=AsyncMock):
                with pytest.raises(RuntimeError, match="All 2 attempts failed"):
                    await scorer.score_async("input", "output", {})