promptlooper/backend/tests/test_scorer_llm_judge.py

379 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Tests for the LLMJudgeScorer."""
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from engine.scorers.base import BaseScorer
from engine.scorers.llm_judge import (
DEFAULT_JUDGE_PROMPT,
LLMJudgeScorer,
_parse_rating,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _chat_response(content: str) -> dict:
"""Build a fake OpenAI-compatible chat completion response."""
return {
"id": "chatcmpl-test",
"object": "chat.completion",
"choices": [
{
"index": 0,
"message": {"role": "assistant", "content": content},
"finish_reason": "stop",
}
],
"usage": {"prompt_tokens": 50, "completion_tokens": 2},
}
def _mock_client(response_data: dict, status_code: int = 200):
"""Create a mocked httpx.AsyncClient that returns *response_data*."""
mock_response = MagicMock()
mock_response.status_code = status_code
mock_response.json.return_value = response_data
mock_response.raise_for_status = MagicMock()
mock_client = AsyncMock()
mock_client.post.return_value = mock_response
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
return mock_client
# ---------------------------------------------------------------------------
# _parse_rating unit tests
# ---------------------------------------------------------------------------
class TestParseRating:
    """Unit tests for ``_parse_rating``.

    The helper extracts an integer rating in the 1-10 range from free text
    and maps it linearly onto [0, 1] via ``(rating - 1) / 9``. Text with no
    usable rating yields ``0.0``.
    """

    def test_single_digit(self):
        assert _parse_rating("7") == pytest.approx((7 - 1) / 9.0)

    def test_rating_10(self):
        assert _parse_rating("10") == pytest.approx(1.0)

    def test_rating_1(self):
        assert _parse_rating("1") == pytest.approx(0.0)

    def test_rating_5(self):
        assert _parse_rating("5") == pytest.approx((5 - 1) / 9.0)

    def test_embedded_in_text(self):
        assert _parse_rating("I would rate this a 8 out of 10") == pytest.approx(
            (8 - 1) / 9.0
        )

    def test_no_rating(self):
        assert _parse_rating("This is great!") == 0.0

    def test_empty_string(self):
        assert _parse_rating("") == 0.0

    def test_zero_not_matched(self):
        # 0 is outside the 1-10 range, so it should not match.
        assert _parse_rating("0") == 0.0

    def test_number_above_10_extracts_partial(self):
        # "11" should match "1" (the first valid 1-10 integer).
        result = _parse_rating("11")
        assert result == pytest.approx(0.0)  # matches "1"

    def test_rating_with_newlines(self):
        assert _parse_rating("\n\n8\n") == pytest.approx((8 - 1) / 9.0)

    def test_rating_with_period(self):
        assert _parse_rating("Rating: 9.") == pytest.approx((9 - 1) / 9.0)

    def test_ten_in_sentence(self):
        assert _parse_rating("Score: 10 points") == pytest.approx(1.0)
# ---------------------------------------------------------------------------
# LLMJudgeScorer class tests
# ---------------------------------------------------------------------------
class TestLLMJudgeScorerProperties:
    """Basic attributes exposed by ``LLMJudgeScorer``."""

    def test_is_base_scorer_subclass(self):
        assert isinstance(LLMJudgeScorer(), BaseScorer)

    def test_name(self):
        scorer = LLMJudgeScorer()
        assert scorer.name == "llm_judge"

    def test_costs_tokens_marker(self):
        # Class-level flag: checked on the class itself, no instance needed.
        assert LLMJudgeScorer.COSTS_TOKENS is True

    def test_custom_judge_prompt(self):
        custom = "Custom prompt"
        assert LLMJudgeScorer(judge_prompt=custom).judge_prompt == custom

    def test_default_judge_prompt(self):
        assert LLMJudgeScorer().judge_prompt == DEFAULT_JUDGE_PROMPT
class TestBuildUserMessage:
    """Markdown sections in the user message built for the judge."""

    @staticmethod
    def _message(*args):
        """Build a user message using a freshly constructed default scorer."""
        return LLMJudgeScorer()._build_user_message(*args)

    def test_basic_message(self):
        msg = self._message("What is 2+2?", "4", {})
        for expected in ("## Input", "What is 2+2?", "## Output", "4"):
            assert expected in msg

    def test_no_input(self):
        msg = self._message(None, "Hello", {})
        assert "## Output" in msg
        assert "## Input" not in msg

    def test_with_reference(self):
        msg = self._message("q", "a", {"reference": "correct answer"})
        for expected in ("## Reference Answer", "correct answer"):
            assert expected in msg

    def test_without_reference(self):
        assert "## Reference Answer" not in self._message("q", "a", {})
class TestScoreAsync:
    """``score_async`` against a mocked ``httpx.AsyncClient``.

    Every test patches ``httpx.AsyncClient``, so no network I/O happens.
    The judge's reply text is supplied through ``_chat_response``.
    """

    @staticmethod
    def _post_kwarg(mock_client, name):
        """Return keyword argument *name* from the client's last ``post`` call.

        ``call_args.kwargs`` is the same mapping as ``call_args[1]``, so a
        single lookup suffices. Indexing (rather than ``.get``) makes a missing
        argument fail with a ``KeyError`` that names it.
        """
        return mock_client.post.call_args.kwargs[name]

    @pytest.mark.asyncio
    async def test_basic_score(self):
        scorer = LLMJudgeScorer(base_url="http://fake:11434/v1")
        mock_client = _mock_client(_chat_response("8"))
        with patch("httpx.AsyncClient", return_value=mock_client):
            score = await scorer.score_async("input", "output", {})
        assert score == pytest.approx((8 - 1) / 9.0)

    @pytest.mark.asyncio
    async def test_perfect_score(self):
        scorer = LLMJudgeScorer(base_url="http://fake:11434/v1")
        mock_client = _mock_client(_chat_response("10"))
        with patch("httpx.AsyncClient", return_value=mock_client):
            score = await scorer.score_async("input", "output", {})
        assert score == pytest.approx(1.0)

    @pytest.mark.asyncio
    async def test_lowest_score(self):
        scorer = LLMJudgeScorer(base_url="http://fake:11434/v1")
        mock_client = _mock_client(_chat_response("1"))
        with patch("httpx.AsyncClient", return_value=mock_client):
            score = await scorer.score_async("input", "output", {})
        assert score == pytest.approx(0.0)

    @pytest.mark.asyncio
    async def test_unparseable_response(self):
        scorer = LLMJudgeScorer(base_url="http://fake:11434/v1")
        mock_client = _mock_client(_chat_response("I cannot rate this"))
        with patch("httpx.AsyncClient", return_value=mock_client):
            score = await scorer.score_async("input", "output", {})
        assert score == 0.0

    @pytest.mark.asyncio
    async def test_empty_choices(self):
        scorer = LLMJudgeScorer(base_url="http://fake:11434/v1")
        response = {"choices": []}
        mock_client = _mock_client(response)
        with patch("httpx.AsyncClient", return_value=mock_client):
            score = await scorer.score_async("input", "output", {})
        assert score == 0.0

    @pytest.mark.asyncio
    async def test_request_includes_correct_body(self):
        scorer = LLMJudgeScorer(
            base_url="http://fake:11434/v1",
            model="judge-model",
            judge_prompt="Rate it",
        )
        mock_client = _mock_client(_chat_response("7"))
        with patch("httpx.AsyncClient", return_value=mock_client):
            await scorer.score_async("test input", "test output", {})
        body = self._post_kwarg(mock_client, "json")
        assert body["model"] == "judge-model"
        assert body["messages"][0]["role"] == "system"
        assert body["messages"][0]["content"] == "Rate it"
        assert body["messages"][1]["role"] == "user"
        assert body["temperature"] == 0.0
        assert body["max_tokens"] == 16

    @pytest.mark.asyncio
    async def test_api_key_in_headers(self):
        scorer = LLMJudgeScorer(
            base_url="http://fake:11434/v1",
            api_key="sk-test-key",
        )
        mock_client = _mock_client(_chat_response("5"))
        with patch("httpx.AsyncClient", return_value=mock_client) as mock_cls:
            await scorer.score_async("input", "output", {})
        # Headers are expected on the AsyncClient constructor, not on post().
        headers = mock_cls.call_args.kwargs["headers"]
        assert headers["Authorization"] == "Bearer sk-test-key"

    @pytest.mark.asyncio
    async def test_reference_included_in_message(self):
        scorer = LLMJudgeScorer(base_url="http://fake:11434/v1")
        mock_client = _mock_client(_chat_response("9"))
        with patch("httpx.AsyncClient", return_value=mock_client):
            await scorer.score_async("q", "a", {"reference": "expected answer"})
        body = self._post_kwarg(mock_client, "json")
        user_msg = body["messages"][1]["content"]
        assert "expected answer" in user_msg
class TestScoreAsyncWithCache:
    """Interaction between ``score_async`` and an optional cache layer."""

    @staticmethod
    def _cached_scorer(cache):
        """Create a scorer wired to *cache* and a mock DB session factory."""
        session_factory = MagicMock(return_value=MagicMock())
        return LLMJudgeScorer(
            base_url="http://fake:11434/v1",
            cache_layer=cache,
            db_session_factory=session_factory,
        )

    @pytest.mark.asyncio
    async def test_cache_hit_skips_llm_call(self):
        hit = MagicMock()
        hit.response = "8"
        cache = MagicMock()
        cache.get.return_value = hit
        scorer = self._cached_scorer(cache)

        with patch("httpx.AsyncClient") as http_cls:
            score = await scorer.score_async("input", "output", {})

        # A cache hit must be served without constructing an HTTP client.
        http_cls.assert_not_called()
        assert score == pytest.approx((8 - 1) / 9.0)

    @pytest.mark.asyncio
    async def test_cache_miss_calls_llm_and_stores(self):
        cache = MagicMock()
        cache.get.return_value = None
        scorer = self._cached_scorer(cache)
        client = _mock_client(_chat_response("6"))

        with patch("httpx.AsyncClient", return_value=client):
            score = await scorer.score_async("input", "output", {})

        assert score == pytest.approx((6 - 1) / 9.0)
        cache.put.assert_called_once()
        # Accept the raw judge reply either as ``response=`` or as the third
        # positional argument of ``put``.
        stored = cache.put.call_args
        by_keyword = stored.kwargs.get("response") == "6"
        by_position = len(stored.args) > 2 and stored.args[2] == "6"
        assert by_keyword or by_position

    @pytest.mark.asyncio
    async def test_no_cache_when_layer_not_configured(self):
        scorer = LLMJudgeScorer(base_url="http://fake:11434/v1")
        client = _mock_client(_chat_response("7"))

        with patch("httpx.AsyncClient", return_value=client):
            score = await scorer.score_async("input", "output", {})

        assert score == pytest.approx((7 - 1) / 9.0)
class TestScoreSync:
    """The synchronous ``score`` wrapper around ``score_async``."""

    def test_sync_score_works_outside_async(self):
        scorer = LLMJudgeScorer(base_url="http://fake:11434/v1")
        mock_client = _mock_client(_chat_response("5"))
        with patch("httpx.AsyncClient", return_value=mock_client):
            score = scorer.score("input", "output", {})
        assert score == pytest.approx((5 - 1) / 9.0)

    def test_sync_score_raises_in_async_context(self):
        scorer = LLMJudgeScorer(base_url="http://fake:11434/v1")

        async def _run():
            scorer.score("input", "output", {})

        # asyncio.run() gives the coroutine a fresh running loop.
        # asyncio.get_event_loop().run_until_complete(...) is deprecated when
        # no loop is current (an error in newer Python releases), and it may
        # pick up a loop already closed by another test. Either way it raises
        # a RuntimeError whose message would not match the one asserted here.
        with pytest.raises(RuntimeError, match="cannot be called from an async"):
            asyncio.run(_run())
class TestRetries:
    """Retry behaviour of ``score_async`` when the judge returns HTTP 500."""

    @staticmethod
    def _server_error():
        """Build a fake HTTP 500 response with an empty JSON body."""
        resp = MagicMock()
        resp.status_code = 500
        resp.json.return_value = {}
        # An explicit request object, presumably so the scorer can build an
        # httpx error from the response — TODO confirm against the scorer.
        resp.request = MagicMock()
        return resp

    @pytest.mark.asyncio
    async def test_retries_on_server_error(self):
        scorer = LLMJudgeScorer(base_url="http://fake:11434/v1", max_retries=3)
        fail_resp = self._server_error()
        ok_resp = MagicMock()
        ok_resp.status_code = 200
        ok_resp.json.return_value = _chat_response("7")
        # _mock_client supplies the async-context-manager wiring; side_effect
        # overrides its default response: two 500s, then a success.
        mock_client = _mock_client({})
        mock_client.post.side_effect = [fail_resp, fail_resp, ok_resp]
        with patch("httpx.AsyncClient", return_value=mock_client):
            # Patch out the back-off delay so the test runs instantly.
            with patch("asyncio.sleep", new_callable=AsyncMock):
                score = await scorer.score_async("input", "output", {})
        assert score == pytest.approx((7 - 1) / 9.0)
        assert mock_client.post.call_count == 3

    @pytest.mark.asyncio
    async def test_all_retries_exhausted(self):
        scorer = LLMJudgeScorer(base_url="http://fake:11434/v1", max_retries=2)
        mock_client = _mock_client({})
        mock_client.post.return_value = self._server_error()
        with patch("httpx.AsyncClient", return_value=mock_client):
            with patch("asyncio.sleep", new_callable=AsyncMock):
                with pytest.raises(RuntimeError, match="All 2 attempts failed"):
                    await scorer.score_async("input", "output", {})
        # The error reports two attempts; pin that exactly max_retries
        # requests were actually sent.
        assert mock_client.post.call_count == 2