379 lines
13 KiB
Python
379 lines
13 KiB
Python
"""Tests for the LLMJudgeScorer."""
|
||
|
||
import asyncio
|
||
from unittest.mock import AsyncMock, MagicMock, patch
|
||
|
||
import pytest
|
||
|
||
from engine.scorers.base import BaseScorer
|
||
from engine.scorers.llm_judge import (
|
||
DEFAULT_JUDGE_PROMPT,
|
||
LLMJudgeScorer,
|
||
_parse_rating,
|
||
)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Helpers
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def _chat_response(content: str) -> dict:
|
||
"""Build a fake OpenAI-compatible chat completion response."""
|
||
return {
|
||
"id": "chatcmpl-test",
|
||
"object": "chat.completion",
|
||
"choices": [
|
||
{
|
||
"index": 0,
|
||
"message": {"role": "assistant", "content": content},
|
||
"finish_reason": "stop",
|
||
}
|
||
],
|
||
"usage": {"prompt_tokens": 50, "completion_tokens": 2},
|
||
}
|
||
|
||
|
||
def _mock_client(response_data: dict, status_code: int = 200):
|
||
"""Create a mocked httpx.AsyncClient that returns *response_data*."""
|
||
mock_response = MagicMock()
|
||
mock_response.status_code = status_code
|
||
mock_response.json.return_value = response_data
|
||
mock_response.raise_for_status = MagicMock()
|
||
|
||
mock_client = AsyncMock()
|
||
mock_client.post.return_value = mock_response
|
||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||
|
||
return mock_client
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# _parse_rating unit tests
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class TestParseRating:
    """Unit tests for ``_parse_rating``.

    The assertions pin this mapping: a 1-10 rating found in the text becomes
    ``(rating - 1) / 9`` (so 1 -> 0.0 and 10 -> 1.0). Text with no usable
    rating yields 0.0.
    """

    @staticmethod
    def _scaled(rating):
        """Expected normalized score for an integer *rating* in 1..10."""
        return (rating - 1) / 9.0

    def test_single_digit(self):
        assert _parse_rating("7") == pytest.approx(self._scaled(7))

    def test_rating_10(self):
        assert _parse_rating("10") == pytest.approx(1.0)

    def test_rating_1(self):
        assert _parse_rating("1") == pytest.approx(0.0)

    def test_rating_5(self):
        assert _parse_rating("5") == pytest.approx(self._scaled(5))

    def test_embedded_in_text(self):
        # "8" appears before "10", so the rating read is 8.
        text = "I would rate this a 8 out of 10"
        assert _parse_rating(text) == pytest.approx(self._scaled(8))

    def test_no_rating(self):
        assert _parse_rating("This is great!") == 0.0

    def test_empty_string(self):
        assert _parse_rating("") == 0.0

    def test_zero_not_matched(self):
        # 0 lies outside the 1-10 range, so it is not accepted as a rating.
        assert _parse_rating("0") == 0.0

    def test_number_above_10_extracts_partial(self):
        # "11" is read as its leading "1" (the first valid 1-10 integer).
        # NOTE(review): this pins a quirk - an out-of-range answer is treated
        # as the lowest rating rather than rejected; confirm it is intended.
        assert _parse_rating("11") == pytest.approx(0.0)

    def test_rating_with_newlines(self):
        assert _parse_rating("\n\n8\n") == pytest.approx(self._scaled(8))

    def test_rating_with_period(self):
        assert _parse_rating("Rating: 9.") == pytest.approx(self._scaled(9))

    def test_ten_in_sentence(self):
        assert _parse_rating("Score: 10 points") == pytest.approx(1.0)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# LLMJudgeScorer class tests
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class TestLLMJudgeScorerProperties:
    """Class-level attributes and constructor defaults of ``LLMJudgeScorer``."""

    def test_is_base_scorer_subclass(self):
        assert isinstance(LLMJudgeScorer(), BaseScorer)

    def test_name(self):
        scorer = LLMJudgeScorer()
        assert scorer.name == "llm_judge"

    def test_costs_tokens_marker(self):
        # Checked on the class itself; no instance is needed.
        assert LLMJudgeScorer.COSTS_TOKENS is True

    def test_custom_judge_prompt(self):
        prompt = "Custom prompt"
        assert LLMJudgeScorer(judge_prompt=prompt).judge_prompt == prompt

    def test_default_judge_prompt(self):
        assert LLMJudgeScorer().judge_prompt == DEFAULT_JUDGE_PROMPT
|
||
|
||
|
||
class TestBuildUserMessage:
    """Section layout of the prompt built by ``_build_user_message``.

    The message carries ``## Input`` only when an input is given,
    ``## Output`` always, and ``## Reference Answer`` only when the third
    argument has a ``"reference"`` key.
    """

    def test_basic_message(self):
        msg = LLMJudgeScorer()._build_user_message("What is 2+2?", "4", {})
        for fragment in ("## Input", "What is 2+2?", "## Output", "4"):
            assert fragment in msg

    def test_no_input(self):
        msg = LLMJudgeScorer()._build_user_message(None, "Hello", {})
        assert "## Input" not in msg
        assert "## Output" in msg

    def test_with_reference(self):
        extra = {"reference": "correct answer"}
        msg = LLMJudgeScorer()._build_user_message("q", "a", extra)
        for fragment in ("## Reference Answer", "correct answer"):
            assert fragment in msg

    def test_without_reference(self):
        msg = LLMJudgeScorer()._build_user_message("q", "a", {})
        assert "## Reference Answer" not in msg
|
||
|
||
|
||
class TestScoreAsync:
    """``score_async`` run against a patched ``httpx.AsyncClient``.

    Each test builds a scorer and patches ``httpx.AsyncClient`` to return a
    fake client (see ``_mock_client``) that replies with a canned
    chat-completion payload. It then checks the normalized score or the
    request the scorer sent.
    """

    BASE_URL = "http://fake:11434/v1"

    @staticmethod
    async def _score(scorer, response_data, input_="input", output="output", extra=None):
        """Score once against a fake client that replies with *response_data*.

        Returns:
            ``(score, mock_client, client_cls)``:
            - the scorer's result;
            - the fake client (inspect ``mock_client.post`` for the request);
            - the patched ``httpx.AsyncClient`` class (inspect its
              ``call_args`` for constructor keyword arguments).
        """
        mock_client = _mock_client(response_data)
        with patch("httpx.AsyncClient", return_value=mock_client) as client_cls:
            score = await scorer.score_async(
                input_, output, {} if extra is None else extra
            )
        return score, mock_client, client_cls

    @staticmethod
    def _posted_json(mock_client):
        """Return the ``json=`` payload of the most recent ``post`` call.

        ``call_args[1]`` is the same mapping as ``call_args.kwargs``, so an
        ``... or call_args[1].get("json")`` fallback can never find anything
        new. Indexing directly makes a missing keyword fail with a
        ``KeyError`` that names it.
        """
        return mock_client.post.call_args.kwargs["json"]

    @pytest.mark.asyncio
    async def test_basic_score(self):
        scorer = LLMJudgeScorer(base_url=self.BASE_URL)
        score, _, _ = await self._score(scorer, _chat_response("8"))
        assert score == pytest.approx((8 - 1) / 9.0)

    @pytest.mark.asyncio
    async def test_perfect_score(self):
        scorer = LLMJudgeScorer(base_url=self.BASE_URL)
        score, _, _ = await self._score(scorer, _chat_response("10"))
        assert score == pytest.approx(1.0)

    @pytest.mark.asyncio
    async def test_lowest_score(self):
        scorer = LLMJudgeScorer(base_url=self.BASE_URL)
        score, _, _ = await self._score(scorer, _chat_response("1"))
        assert score == pytest.approx(0.0)

    @pytest.mark.asyncio
    async def test_unparseable_response(self):
        scorer = LLMJudgeScorer(base_url=self.BASE_URL)
        score, _, _ = await self._score(scorer, _chat_response("I cannot rate this"))
        assert score == 0.0

    @pytest.mark.asyncio
    async def test_empty_choices(self):
        scorer = LLMJudgeScorer(base_url=self.BASE_URL)
        score, _, _ = await self._score(scorer, {"choices": []})
        assert score == 0.0

    @pytest.mark.asyncio
    async def test_request_includes_correct_body(self):
        scorer = LLMJudgeScorer(
            base_url=self.BASE_URL,
            model="judge-model",
            judge_prompt="Rate it",
        )
        _, mock_client, _ = await self._score(
            scorer, _chat_response("7"), "test input", "test output"
        )

        body = self._posted_json(mock_client)
        assert body["model"] == "judge-model"
        system_msg, user_msg = body["messages"][0], body["messages"][1]
        assert system_msg["role"] == "system"
        assert system_msg["content"] == "Rate it"
        assert user_msg["role"] == "user"
        # The user turn should carry the text being judged.
        assert "test input" in user_msg["content"]
        assert "test output" in user_msg["content"]
        assert body["temperature"] == 0.0
        assert body["max_tokens"] == 16

    @pytest.mark.asyncio
    async def test_api_key_in_headers(self):
        scorer = LLMJudgeScorer(base_url=self.BASE_URL, api_key="sk-test-key")
        _, _, client_cls = await self._score(scorer, _chat_response("5"))

        headers = client_cls.call_args.kwargs["headers"]
        assert headers["Authorization"] == "Bearer sk-test-key"

    @pytest.mark.asyncio
    async def test_reference_included_in_message(self):
        scorer = LLMJudgeScorer(base_url=self.BASE_URL)
        _, mock_client, _ = await self._score(
            scorer, _chat_response("9"), "q", "a", {"reference": "expected answer"}
        )

        user_msg = self._posted_json(mock_client)["messages"][1]["content"]
        assert "expected answer" in user_msg
|
||
|
||
|
||
class TestScoreAsyncWithCache:
    """``score_async`` with and without a cache layer configured."""

    @staticmethod
    def _scorer_with_cache(cache):
        """Build a scorer wired to *cache* and a mock DB session factory."""
        session_factory = MagicMock(return_value=MagicMock())
        return LLMJudgeScorer(
            base_url="http://fake:11434/v1",
            cache_layer=cache,
            db_session_factory=session_factory,
        )

    @pytest.mark.asyncio
    async def test_cache_hit_skips_llm_call(self):
        cache = MagicMock()
        # The cached entry exposes the judge's raw reply as ``.response``.
        cache.get.return_value = MagicMock(response="8")
        scorer = self._scorer_with_cache(cache)

        with patch("httpx.AsyncClient") as http_cls:
            score = await scorer.score_async("input", "output", {})

        assert score == pytest.approx((8 - 1) / 9.0)
        # A cache hit must not even construct an HTTP client.
        http_cls.assert_not_called()

    @pytest.mark.asyncio
    async def test_cache_miss_calls_llm_and_stores(self):
        cache = MagicMock()
        cache.get.return_value = None
        scorer = self._scorer_with_cache(cache)
        client = _mock_client(_chat_response("6"))

        with patch("httpx.AsyncClient", return_value=client):
            score = await scorer.score_async("input", "output", {})

        assert score == pytest.approx((6 - 1) / 9.0)
        cache.put.assert_called_once()
        # Accept the raw reply either as ``response=`` or as the third
        # positional argument of ``put``.
        put_call = cache.put.call_args
        stored_by_keyword = put_call.kwargs.get("response") == "6"
        stored_positionally = len(put_call.args) > 2 and put_call.args[2] == "6"
        assert stored_by_keyword or stored_positionally

    @pytest.mark.asyncio
    async def test_no_cache_when_layer_not_configured(self):
        scorer = LLMJudgeScorer(base_url="http://fake:11434/v1")
        client = _mock_client(_chat_response("7"))

        with patch("httpx.AsyncClient", return_value=client):
            score = await scorer.score_async("input", "output", {})

        assert score == pytest.approx((7 - 1) / 9.0)
|
||
|
||
|
||
class TestScoreSync:
    """The synchronous ``score`` entry point."""

    def test_sync_score_works_outside_async(self):
        scorer = LLMJudgeScorer(base_url="http://fake:11434/v1")
        mock_client = _mock_client(_chat_response("5"))

        with patch("httpx.AsyncClient", return_value=mock_client):
            score = scorer.score("input", "output", {})

        assert score == pytest.approx((5 - 1) / 9.0)

    def test_sync_score_raises_in_async_context(self):
        scorer = LLMJudgeScorer(base_url="http://fake:11434/v1")

        async def _run():
            # Calling the sync API while an event loop is running should be
            # rejected by the scorer.
            scorer.score("input", "output", {})

        # asyncio.run() creates a fresh event loop, runs the coroutine, and
        # closes the loop. Calling asyncio.get_event_loop() outside a running
        # loop is deprecated (and raises on newer Pythons), may return a loop
        # already closed by an earlier async test, and leaves the loop open.
        with pytest.raises(RuntimeError, match="cannot be called from an async"):
            asyncio.run(_run())
|
||
|
||
|
||
class TestRetries:
    """Retry behaviour of ``score_async`` when the server answers HTTP 500."""

    @staticmethod
    def _server_error():
        """Fake response with status 500 and an empty JSON body."""
        resp = MagicMock()
        resp.status_code = 500
        resp.json.return_value = {}
        return resp

    @staticmethod
    def _fake_client():
        """Bare ``AsyncMock`` client usable with ``async with``.

        The caller configures ``post`` on the returned client.
        """
        client = AsyncMock()
        client.__aenter__ = AsyncMock(return_value=client)
        client.__aexit__ = AsyncMock(return_value=False)
        return client

    @pytest.mark.asyncio
    async def test_retries_on_server_error(self):
        scorer = LLMJudgeScorer(base_url="http://fake:11434/v1", max_retries=3)

        # Two 500s followed by a success.
        server_error = self._server_error()
        success = MagicMock()
        success.status_code = 200
        success.json.return_value = _chat_response("7")

        client = self._fake_client()
        client.post.side_effect = [server_error, server_error, success]

        # asyncio.sleep is patched so any backoff between attempts is instant.
        with patch("httpx.AsyncClient", return_value=client), patch(
            "asyncio.sleep", new_callable=AsyncMock
        ):
            score = await scorer.score_async("input", "output", {})

        assert score == pytest.approx((7 - 1) / 9.0)
        assert client.post.call_count == 3

    @pytest.mark.asyncio
    async def test_all_retries_exhausted(self):
        scorer = LLMJudgeScorer(base_url="http://fake:11434/v1", max_retries=2)

        server_error = self._server_error()
        server_error.request = MagicMock()

        client = self._fake_client()
        client.post.return_value = server_error

        with patch("httpx.AsyncClient", return_value=client), patch(
            "asyncio.sleep", new_callable=AsyncMock
        ), pytest.raises(RuntimeError, match="All 2 attempts failed"):
            await scorer.score_async("input", "output", {})
|