MAESTRO: Implement LLMJudgeScorer with configurable judge prompt, rating parsing, and response caching

This commit is contained in:
John Lightner 2026-04-07 03:05:00 -05:00
parent 0d5a6169c5
commit fb78eac1b0
4 changed files with 621 additions and 2 deletions

View file

@ -26,7 +26,8 @@ Implement the core experiment execution engine: LLM adapters, response caching,
- [x] Implement backend/engine/scorers/keyword.py — checks for presence/absence of required keywords in output. Configurable with required_present and required_absent lists. Score = (found / required) ratio. - [x] Implement backend/engine/scorers/keyword.py — checks for presence/absence of required keywords in output. Configurable with required_present and required_absent lists. Score = (found / required) ratio.
<!-- Completed: KeywordScorer with required_present/required_absent lists, case-sensitive option, combined ratio scoring. 37 tests in test_scorer_keyword.py, all passing. --> <!-- Completed: KeywordScorer with required_present/required_absent lists, case-sensitive option, combined ratio scoring. 37 tests in test_scorer_keyword.py, all passing. -->
- [ ] Implement backend/engine/scorers/llm_judge.py — sends the output to a separate LLM with a configurable judge prompt and asks for a 1-10 rating. Parses the numeric score from the response. This scorer requires an LLM call so it should be clearly marked as "costs tokens" in the UI. Cache the judge's response too. - [x] Implement backend/engine/scorers/llm_judge.py — sends the output to a separate LLM with a configurable judge prompt and asks for a 1-10 rating. Parses the numeric score from the response. This scorer requires an LLM call so it should be clearly marked as "costs tokens" in the UI. Cache the judge's response too.
<!-- Completed: LLMJudgeScorer with configurable judge prompt, 1-10 rating parsing via regex, normalized to 0.0-1.0. COSTS_TOKENS class marker for UI. Optional ResponseCacheLayer integration for caching judge responses. Retries with exponential backoff. 36 tests in test_scorer_llm_judge.py, all passing. -->
- [ ] Wire up the Celery worker in backend/worker.py. Define tasks: execute_run(run_id), execute_sweep(experiment_id, sweep_config). Configure Celery to use Redis as broker. In single-container mode (no Redis), implement a simple synchronous fallback that runs tasks in-process. - [ ] Wire up the Celery worker in backend/worker.py. Define tasks: execute_run(run_id), execute_sweep(experiment_id, sweep_config). Configure Celery to use Redis as broker. In single-container mode (no Redis), implement a simple synchronous fallback that runs tasks in-process.

View file

@ -4,5 +4,12 @@ from engine.scorers.base import BaseScorer
from engine.scorers.embedding import EmbeddingScorer from engine.scorers.embedding import EmbeddingScorer
from engine.scorers.format import FormatScorer from engine.scorers.format import FormatScorer
from engine.scorers.keyword import KeywordScorer from engine.scorers.keyword import KeywordScorer
from engine.scorers.llm_judge import LLMJudgeScorer
__all__ = ["BaseScorer", "EmbeddingScorer", "FormatScorer", "KeywordScorer"] __all__ = [
"BaseScorer",
"EmbeddingScorer",
"FormatScorer",
"KeywordScorer",
"LLMJudgeScorer",
]

View file

@ -0,0 +1,232 @@
"""LLM-as-judge scorer.
Sends the LLM output to a separate LLM with a configurable judge prompt,
asks for a 110 rating, and normalizes to the 0.01.0 range.
**This scorer costs tokens** every evaluation makes an LLM call. The
judge's response is cached via PromptLooper's response cache layer to
avoid redundant calls when re-scoring the same output.
"""
import re
from typing import Any
import httpx
from engine.scorers.base import BaseScorer
# Default judge system prompt — can be overridden at construction time.
DEFAULT_JUDGE_PROMPT = (
"You are an impartial evaluator. You will receive an input and an LLM-generated "
"output. Rate the quality of the output on a scale of 1 to 10, where 1 is terrible "
"and 10 is perfect.\n\n"
"Respond with ONLY a single integer between 1 and 10. Do not include any other text."
)
# Regex to extract the first integer 110 from the judge response.
_RATING_RE = re.compile(r"\b(10|[1-9])\b")
class LLMJudgeScorer(BaseScorer):
"""Score outputs by asking a separate LLM to rate them 110.
Args:
base_url: Chat completions API base URL.
model: Model to use for judging.
api_key: Optional API key.
judge_prompt: System prompt for the judge LLM.
timeout: HTTP request timeout in seconds.
max_retries: Retry attempts on transient failures.
cache_layer: Optional ``ResponseCacheLayer`` instance. When provided,
judge responses are cached to avoid duplicate LLM calls.
db_session_factory: Callable returning a SQLAlchemy ``Session``.
Required when *cache_layer* is supplied.
"""
# Marker for the UI so it can warn users about token cost.
COSTS_TOKENS = True
def __init__(
self,
base_url: str = "http://localhost:11434/v1",
model: str = "llama3",
api_key: str | None = None,
judge_prompt: str = DEFAULT_JUDGE_PROMPT,
timeout: float = 120.0,
max_retries: int = 3,
cache_layer: Any = None,
db_session_factory: Any = None,
) -> None:
self.base_url = base_url.rstrip("/")
self.model = model
self.api_key = api_key
self.judge_prompt = judge_prompt
self.timeout = timeout
self.max_retries = max_retries
self._cache_layer = cache_layer
self._db_session_factory = db_session_factory
@property
def name(self) -> str:
return "llm_judge"
# ------------------------------------------------------------------
# Synchronous entry point
# ------------------------------------------------------------------
def score(self, input_data: Any, output: str, context: dict) -> float:
"""Synchronous scoring — delegates to the async variant."""
import asyncio
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None
if loop and loop.is_running():
raise RuntimeError(
"LLMJudgeScorer.score() cannot be called from an async context. "
"Use score_async() instead."
)
return asyncio.get_event_loop().run_until_complete(
self.score_async(input_data, output, context)
)
# ------------------------------------------------------------------
# Async entry point
# ------------------------------------------------------------------
async def score_async(
self, input_data: Any, output: str, context: dict
) -> float:
"""Ask the judge LLM to rate the output and return a normalised score."""
user_message = self._build_user_message(input_data, output, context)
# Check cache first.
config_hash: str | None = None
if self._cache_layer and self._db_session_factory:
from engine.cache import compute_config_hash
config_hash = compute_config_hash(
prompt=self.judge_prompt,
model=self.model,
params={"scorer": "llm_judge"},
input_data=user_message,
)
db = self._db_session_factory()
try:
cached = self._cache_layer.get(db, config_hash)
if cached is not None:
return _parse_rating(cached.response)
finally:
db.close()
# Call the judge LLM.
judge_response = await self._call_judge(user_message)
# Cache the judge response.
if self._cache_layer and self._db_session_factory and config_hash:
db = self._db_session_factory()
try:
self._cache_layer.put(
db,
config_hash=config_hash,
response=judge_response,
model=self.model,
)
finally:
db.close()
return _parse_rating(judge_response)
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
def _build_user_message(
self, input_data: Any, output: str, context: dict
) -> str:
"""Build the user message sent to the judge LLM."""
parts = []
if input_data is not None:
parts.append(f"## Input\n{input_data}")
parts.append(f"## Output\n{output}")
# Include reference answer if available — helps the judge compare.
reference = context.get("reference")
if reference:
parts.append(f"## Reference Answer\n{reference}")
return "\n\n".join(parts)
async def _call_judge(self, user_message: str) -> str:
"""Send a chat completion request to the judge LLM with retries."""
url = f"{self.base_url}/chat/completions"
headers: dict[str, str] = {"Content-Type": "application/json"}
if self.api_key:
headers["Authorization"] = f"Bearer {self.api_key}"
body = {
"model": self.model,
"messages": [
{"role": "system", "content": self.judge_prompt},
{"role": "user", "content": user_message},
],
"temperature": 0.0,
"max_tokens": 16,
}
last_exc: Exception | None = None
retryable = {429, 500, 502, 503, 504}
for attempt in range(self.max_retries):
try:
async with httpx.AsyncClient(
timeout=httpx.Timeout(self.timeout), headers=headers
) as client:
resp = await client.post(url, json=body)
if resp.status_code == 200:
data = resp.json()
choices = data.get("choices", [])
if choices:
return choices[0].get("message", {}).get("content", "").strip()
return ""
if resp.status_code not in retryable:
resp.raise_for_status()
last_exc = httpx.HTTPStatusError(
f"HTTP {resp.status_code}",
request=resp.request,
response=resp,
)
except httpx.HTTPStatusError:
raise
except httpx.HTTPError as exc:
last_exc = exc
if attempt < self.max_retries - 1:
import asyncio
await asyncio.sleep(2**attempt)
raise RuntimeError(
f"All {self.max_retries} attempts failed for judge LLM at {url}"
) from last_exc
def _parse_rating(text: str) -> float:
"""Extract a 110 rating from the judge response and normalise to 0.01.0.
Falls back to 0.0 if no valid rating is found.
"""
match = _RATING_RE.search(text)
if match is None:
return 0.0
rating = int(match.group(1))
# Normalise: 1 → ~0.0, 10 → 1.0
return max(0.0, min(1.0, (rating - 1) / 9.0))

View file

@ -0,0 +1,379 @@
"""Tests for the LLMJudgeScorer."""
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from engine.scorers.base import BaseScorer
from engine.scorers.llm_judge import (
DEFAULT_JUDGE_PROMPT,
LLMJudgeScorer,
_parse_rating,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _chat_response(content: str) -> dict:
"""Build a fake OpenAI-compatible chat completion response."""
return {
"id": "chatcmpl-test",
"object": "chat.completion",
"choices": [
{
"index": 0,
"message": {"role": "assistant", "content": content},
"finish_reason": "stop",
}
],
"usage": {"prompt_tokens": 50, "completion_tokens": 2},
}
def _mock_client(response_data: dict, status_code: int = 200):
"""Create a mocked httpx.AsyncClient that returns *response_data*."""
mock_response = MagicMock()
mock_response.status_code = status_code
mock_response.json.return_value = response_data
mock_response.raise_for_status = MagicMock()
mock_client = AsyncMock()
mock_client.post.return_value = mock_response
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
return mock_client
# ---------------------------------------------------------------------------
# _parse_rating unit tests
# ---------------------------------------------------------------------------
class TestParseRating:
def test_single_digit(self):
assert _parse_rating("7") == pytest.approx((7 - 1) / 9.0)
def test_rating_10(self):
assert _parse_rating("10") == pytest.approx(1.0)
def test_rating_1(self):
assert _parse_rating("1") == pytest.approx(0.0)
def test_rating_5(self):
assert _parse_rating("5") == pytest.approx((5 - 1) / 9.0)
def test_embedded_in_text(self):
assert _parse_rating("I would rate this a 8 out of 10") == pytest.approx(
(8 - 1) / 9.0
)
def test_no_rating(self):
assert _parse_rating("This is great!") == 0.0
def test_empty_string(self):
assert _parse_rating("") == 0.0
def test_zero_not_matched(self):
# 0 is outside the 110 range, should not match.
assert _parse_rating("0") == 0.0
def test_number_above_10_extracts_partial(self):
# "11" should match "1" (the first valid 110 integer).
result = _parse_rating("11")
assert result == pytest.approx(0.0) # matches "1"
def test_rating_with_newlines(self):
assert _parse_rating("\n\n8\n") == pytest.approx((8 - 1) / 9.0)
def test_rating_with_period(self):
assert _parse_rating("Rating: 9.") == pytest.approx((9 - 1) / 9.0)
def test_ten_in_sentence(self):
assert _parse_rating("Score: 10 points") == pytest.approx(1.0)
# ---------------------------------------------------------------------------
# LLMJudgeScorer class tests
# ---------------------------------------------------------------------------
class TestLLMJudgeScorerProperties:
def test_is_base_scorer_subclass(self):
scorer = LLMJudgeScorer()
assert isinstance(scorer, BaseScorer)
def test_name(self):
assert LLMJudgeScorer().name == "llm_judge"
def test_costs_tokens_marker(self):
assert LLMJudgeScorer.COSTS_TOKENS is True
def test_custom_judge_prompt(self):
scorer = LLMJudgeScorer(judge_prompt="Custom prompt")
assert scorer.judge_prompt == "Custom prompt"
def test_default_judge_prompt(self):
scorer = LLMJudgeScorer()
assert scorer.judge_prompt == DEFAULT_JUDGE_PROMPT
class TestBuildUserMessage:
def test_basic_message(self):
scorer = LLMJudgeScorer()
msg = scorer._build_user_message("What is 2+2?", "4", {})
assert "## Input" in msg
assert "What is 2+2?" in msg
assert "## Output" in msg
assert "4" in msg
def test_no_input(self):
scorer = LLMJudgeScorer()
msg = scorer._build_user_message(None, "Hello", {})
assert "## Input" not in msg
assert "## Output" in msg
def test_with_reference(self):
scorer = LLMJudgeScorer()
msg = scorer._build_user_message("q", "a", {"reference": "correct answer"})
assert "## Reference Answer" in msg
assert "correct answer" in msg
def test_without_reference(self):
scorer = LLMJudgeScorer()
msg = scorer._build_user_message("q", "a", {})
assert "## Reference Answer" not in msg
class TestScoreAsync:
@pytest.mark.asyncio
async def test_basic_score(self):
scorer = LLMJudgeScorer(base_url="http://fake:11434/v1")
mock_client = _mock_client(_chat_response("8"))
with patch("httpx.AsyncClient", return_value=mock_client):
score = await scorer.score_async("input", "output", {})
assert score == pytest.approx((8 - 1) / 9.0)
@pytest.mark.asyncio
async def test_perfect_score(self):
scorer = LLMJudgeScorer(base_url="http://fake:11434/v1")
mock_client = _mock_client(_chat_response("10"))
with patch("httpx.AsyncClient", return_value=mock_client):
score = await scorer.score_async("input", "output", {})
assert score == pytest.approx(1.0)
@pytest.mark.asyncio
async def test_lowest_score(self):
scorer = LLMJudgeScorer(base_url="http://fake:11434/v1")
mock_client = _mock_client(_chat_response("1"))
with patch("httpx.AsyncClient", return_value=mock_client):
score = await scorer.score_async("input", "output", {})
assert score == pytest.approx(0.0)
@pytest.mark.asyncio
async def test_unparseable_response(self):
scorer = LLMJudgeScorer(base_url="http://fake:11434/v1")
mock_client = _mock_client(_chat_response("I cannot rate this"))
with patch("httpx.AsyncClient", return_value=mock_client):
score = await scorer.score_async("input", "output", {})
assert score == 0.0
@pytest.mark.asyncio
async def test_empty_choices(self):
scorer = LLMJudgeScorer(base_url="http://fake:11434/v1")
response = {"choices": []}
mock_client = _mock_client(response)
with patch("httpx.AsyncClient", return_value=mock_client):
score = await scorer.score_async("input", "output", {})
assert score == 0.0
@pytest.mark.asyncio
async def test_request_includes_correct_body(self):
scorer = LLMJudgeScorer(
base_url="http://fake:11434/v1",
model="judge-model",
judge_prompt="Rate it",
)
mock_client = _mock_client(_chat_response("7"))
with patch("httpx.AsyncClient", return_value=mock_client):
await scorer.score_async("test input", "test output", {})
call_kwargs = mock_client.post.call_args
body = call_kwargs.kwargs.get("json") or call_kwargs[1].get("json")
assert body["model"] == "judge-model"
assert body["messages"][0]["role"] == "system"
assert body["messages"][0]["content"] == "Rate it"
assert body["messages"][1]["role"] == "user"
assert body["temperature"] == 0.0
assert body["max_tokens"] == 16
@pytest.mark.asyncio
async def test_api_key_in_headers(self):
scorer = LLMJudgeScorer(
base_url="http://fake:11434/v1",
api_key="sk-test-key",
)
mock_client = _mock_client(_chat_response("5"))
with patch("httpx.AsyncClient", return_value=mock_client) as mock_cls:
await scorer.score_async("input", "output", {})
call_kwargs = mock_cls.call_args
headers = call_kwargs.kwargs.get("headers") or call_kwargs[1].get("headers")
assert headers["Authorization"] == "Bearer sk-test-key"
@pytest.mark.asyncio
async def test_reference_included_in_message(self):
scorer = LLMJudgeScorer(base_url="http://fake:11434/v1")
mock_client = _mock_client(_chat_response("9"))
with patch("httpx.AsyncClient", return_value=mock_client):
await scorer.score_async(
"q", "a", {"reference": "expected answer"}
)
call_kwargs = mock_client.post.call_args
body = call_kwargs.kwargs.get("json") or call_kwargs[1].get("json")
user_msg = body["messages"][1]["content"]
assert "expected answer" in user_msg
class TestScoreAsyncWithCache:
@pytest.mark.asyncio
async def test_cache_hit_skips_llm_call(self):
mock_cache = MagicMock()
cached_response = MagicMock()
cached_response.response = "8"
mock_cache.get.return_value = cached_response
mock_session = MagicMock()
mock_session_factory = MagicMock(return_value=mock_session)
scorer = LLMJudgeScorer(
base_url="http://fake:11434/v1",
cache_layer=mock_cache,
db_session_factory=mock_session_factory,
)
with patch("httpx.AsyncClient") as mock_http:
score = await scorer.score_async("input", "output", {})
assert score == pytest.approx((8 - 1) / 9.0)
mock_http.assert_not_called()
@pytest.mark.asyncio
async def test_cache_miss_calls_llm_and_stores(self):
mock_cache = MagicMock()
mock_cache.get.return_value = None
mock_session = MagicMock()
mock_session_factory = MagicMock(return_value=mock_session)
scorer = LLMJudgeScorer(
base_url="http://fake:11434/v1",
cache_layer=mock_cache,
db_session_factory=mock_session_factory,
)
mock_client = _mock_client(_chat_response("6"))
with patch("httpx.AsyncClient", return_value=mock_client):
score = await scorer.score_async("input", "output", {})
assert score == pytest.approx((6 - 1) / 9.0)
mock_cache.put.assert_called_once()
put_kwargs = mock_cache.put.call_args
assert put_kwargs.kwargs.get("response") == "6" or (
len(put_kwargs.args) > 2 and put_kwargs.args[2] == "6"
)
@pytest.mark.asyncio
async def test_no_cache_when_layer_not_configured(self):
scorer = LLMJudgeScorer(base_url="http://fake:11434/v1")
mock_client = _mock_client(_chat_response("7"))
with patch("httpx.AsyncClient", return_value=mock_client):
score = await scorer.score_async("input", "output", {})
assert score == pytest.approx((7 - 1) / 9.0)
class TestScoreSync:
def test_sync_score_works_outside_async(self):
scorer = LLMJudgeScorer(base_url="http://fake:11434/v1")
mock_client = _mock_client(_chat_response("5"))
with patch("httpx.AsyncClient", return_value=mock_client):
score = scorer.score("input", "output", {})
assert score == pytest.approx((5 - 1) / 9.0)
def test_sync_score_raises_in_async_context(self):
scorer = LLMJudgeScorer(base_url="http://fake:11434/v1")
async def _run():
scorer.score("input", "output", {})
with pytest.raises(RuntimeError, match="cannot be called from an async"):
asyncio.get_event_loop().run_until_complete(_run())
class TestRetries:
@pytest.mark.asyncio
async def test_retries_on_server_error(self):
scorer = LLMJudgeScorer(
base_url="http://fake:11434/v1", max_retries=3
)
# First two calls fail with 500, third succeeds.
fail_resp = MagicMock()
fail_resp.status_code = 500
fail_resp.json.return_value = {}
ok_resp = MagicMock()
ok_resp.status_code = 200
ok_resp.json.return_value = _chat_response("7")
mock_client = AsyncMock()
mock_client.post.side_effect = [fail_resp, fail_resp, ok_resp]
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
with patch("httpx.AsyncClient", return_value=mock_client):
with patch("asyncio.sleep", new_callable=AsyncMock):
score = await scorer.score_async("input", "output", {})
assert score == pytest.approx((7 - 1) / 9.0)
assert mock_client.post.call_count == 3
@pytest.mark.asyncio
async def test_all_retries_exhausted(self):
scorer = LLMJudgeScorer(
base_url="http://fake:11434/v1", max_retries=2
)
fail_resp = MagicMock()
fail_resp.status_code = 500
fail_resp.json.return_value = {}
fail_resp.request = MagicMock()
mock_client = AsyncMock()
mock_client.post.return_value = fail_resp
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
with patch("httpx.AsyncClient", return_value=mock_client):
with patch("asyncio.sleep", new_callable=AsyncMock):
with pytest.raises(RuntimeError, match="All 2 attempts failed"):
await scorer.score_async("input", "output", {})