"""Tests for the LLMJudgeScorer.""" import asyncio from unittest.mock import AsyncMock, MagicMock, patch import pytest from engine.scorers.base import BaseScorer from engine.scorers.llm_judge import ( DEFAULT_JUDGE_PROMPT, LLMJudgeScorer, _parse_rating, ) # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _chat_response(content: str) -> dict: """Build a fake OpenAI-compatible chat completion response.""" return { "id": "chatcmpl-test", "object": "chat.completion", "choices": [ { "index": 0, "message": {"role": "assistant", "content": content}, "finish_reason": "stop", } ], "usage": {"prompt_tokens": 50, "completion_tokens": 2}, } def _mock_client(response_data: dict, status_code: int = 200): """Create a mocked httpx.AsyncClient that returns *response_data*.""" mock_response = MagicMock() mock_response.status_code = status_code mock_response.json.return_value = response_data mock_response.raise_for_status = MagicMock() mock_client = AsyncMock() mock_client.post.return_value = mock_response mock_client.__aenter__ = AsyncMock(return_value=mock_client) mock_client.__aexit__ = AsyncMock(return_value=False) return mock_client # --------------------------------------------------------------------------- # _parse_rating unit tests # --------------------------------------------------------------------------- class TestParseRating: def test_single_digit(self): assert _parse_rating("7") == pytest.approx((7 - 1) / 9.0) def test_rating_10(self): assert _parse_rating("10") == pytest.approx(1.0) def test_rating_1(self): assert _parse_rating("1") == pytest.approx(0.0) def test_rating_5(self): assert _parse_rating("5") == pytest.approx((5 - 1) / 9.0) def test_embedded_in_text(self): assert _parse_rating("I would rate this a 8 out of 10") == pytest.approx( (8 - 1) / 9.0 ) def test_no_rating(self): assert _parse_rating("This is great!") == 0.0 def test_empty_string(self): assert _parse_rating("") == 0.0 def test_zero_not_matched(self): # 0 is outside the 1–10 range, should not match. assert _parse_rating("0") == 0.0 def test_number_above_10_extracts_partial(self): # "11" should match "1" (the first valid 1–10 integer). result = _parse_rating("11") assert result == pytest.approx(0.0) # matches "1" def test_rating_with_newlines(self): assert _parse_rating("\n\n8\n") == pytest.approx((8 - 1) / 9.0) def test_rating_with_period(self): assert _parse_rating("Rating: 9.") == pytest.approx((9 - 1) / 9.0) def test_ten_in_sentence(self): assert _parse_rating("Score: 10 points") == pytest.approx(1.0) # --------------------------------------------------------------------------- # LLMJudgeScorer class tests # --------------------------------------------------------------------------- class TestLLMJudgeScorerProperties: def test_is_base_scorer_subclass(self): scorer = LLMJudgeScorer() assert isinstance(scorer, BaseScorer) def test_name(self): assert LLMJudgeScorer().name == "llm_judge" def test_costs_tokens_marker(self): assert LLMJudgeScorer.COSTS_TOKENS is True def test_custom_judge_prompt(self): scorer = LLMJudgeScorer(judge_prompt="Custom prompt") assert scorer.judge_prompt == "Custom prompt" def test_default_judge_prompt(self): scorer = LLMJudgeScorer() assert scorer.judge_prompt == DEFAULT_JUDGE_PROMPT class TestBuildUserMessage: def test_basic_message(self): scorer = LLMJudgeScorer() msg = scorer._build_user_message("What is 2+2?", "4", {}) assert "## Input" in msg assert "What is 2+2?" in msg assert "## Output" in msg assert "4" in msg def test_no_input(self): scorer = LLMJudgeScorer() msg = scorer._build_user_message(None, "Hello", {}) assert "## Input" not in msg assert "## Output" in msg def test_with_reference(self): scorer = LLMJudgeScorer() msg = scorer._build_user_message("q", "a", {"reference": "correct answer"}) assert "## Reference Answer" in msg assert "correct answer" in msg def test_without_reference(self): scorer = LLMJudgeScorer() msg = scorer._build_user_message("q", "a", {}) assert "## Reference Answer" not in msg class TestScoreAsync: @pytest.mark.asyncio async def test_basic_score(self): scorer = LLMJudgeScorer(base_url="http://fake:11434/v1") mock_client = _mock_client(_chat_response("8")) with patch("httpx.AsyncClient", return_value=mock_client): score = await scorer.score_async("input", "output", {}) assert score == pytest.approx((8 - 1) / 9.0) @pytest.mark.asyncio async def test_perfect_score(self): scorer = LLMJudgeScorer(base_url="http://fake:11434/v1") mock_client = _mock_client(_chat_response("10")) with patch("httpx.AsyncClient", return_value=mock_client): score = await scorer.score_async("input", "output", {}) assert score == pytest.approx(1.0) @pytest.mark.asyncio async def test_lowest_score(self): scorer = LLMJudgeScorer(base_url="http://fake:11434/v1") mock_client = _mock_client(_chat_response("1")) with patch("httpx.AsyncClient", return_value=mock_client): score = await scorer.score_async("input", "output", {}) assert score == pytest.approx(0.0) @pytest.mark.asyncio async def test_unparseable_response(self): scorer = LLMJudgeScorer(base_url="http://fake:11434/v1") mock_client = _mock_client(_chat_response("I cannot rate this")) with patch("httpx.AsyncClient", return_value=mock_client): score = await scorer.score_async("input", "output", {}) assert score == 0.0 @pytest.mark.asyncio async def test_empty_choices(self): scorer = LLMJudgeScorer(base_url="http://fake:11434/v1") response = {"choices": []} mock_client = _mock_client(response) with patch("httpx.AsyncClient", return_value=mock_client): score = await scorer.score_async("input", "output", {}) assert score == 0.0 @pytest.mark.asyncio async def test_request_includes_correct_body(self): scorer = LLMJudgeScorer( base_url="http://fake:11434/v1", model="judge-model", judge_prompt="Rate it", ) mock_client = _mock_client(_chat_response("7")) with patch("httpx.AsyncClient", return_value=mock_client): await scorer.score_async("test input", "test output", {}) call_kwargs = mock_client.post.call_args body = call_kwargs.kwargs.get("json") or call_kwargs[1].get("json") assert body["model"] == "judge-model" assert body["messages"][0]["role"] == "system" assert body["messages"][0]["content"] == "Rate it" assert body["messages"][1]["role"] == "user" assert body["temperature"] == 0.0 assert body["max_tokens"] == 16 @pytest.mark.asyncio async def test_api_key_in_headers(self): scorer = LLMJudgeScorer( base_url="http://fake:11434/v1", api_key="sk-test-key", ) mock_client = _mock_client(_chat_response("5")) with patch("httpx.AsyncClient", return_value=mock_client) as mock_cls: await scorer.score_async("input", "output", {}) call_kwargs = mock_cls.call_args headers = call_kwargs.kwargs.get("headers") or call_kwargs[1].get("headers") assert headers["Authorization"] == "Bearer sk-test-key" @pytest.mark.asyncio async def test_reference_included_in_message(self): scorer = LLMJudgeScorer(base_url="http://fake:11434/v1") mock_client = _mock_client(_chat_response("9")) with patch("httpx.AsyncClient", return_value=mock_client): await scorer.score_async( "q", "a", {"reference": "expected answer"} ) call_kwargs = mock_client.post.call_args body = call_kwargs.kwargs.get("json") or call_kwargs[1].get("json") user_msg = body["messages"][1]["content"] assert "expected answer" in user_msg class TestScoreAsyncWithCache: @pytest.mark.asyncio async def test_cache_hit_skips_llm_call(self): mock_cache = MagicMock() cached_response = MagicMock() cached_response.response = "8" mock_cache.get.return_value = cached_response mock_session = MagicMock() mock_session_factory = MagicMock(return_value=mock_session) scorer = LLMJudgeScorer( base_url="http://fake:11434/v1", cache_layer=mock_cache, db_session_factory=mock_session_factory, ) with patch("httpx.AsyncClient") as mock_http: score = await scorer.score_async("input", "output", {}) assert score == pytest.approx((8 - 1) / 9.0) mock_http.assert_not_called() @pytest.mark.asyncio async def test_cache_miss_calls_llm_and_stores(self): mock_cache = MagicMock() mock_cache.get.return_value = None mock_session = MagicMock() mock_session_factory = MagicMock(return_value=mock_session) scorer = LLMJudgeScorer( base_url="http://fake:11434/v1", cache_layer=mock_cache, db_session_factory=mock_session_factory, ) mock_client = _mock_client(_chat_response("6")) with patch("httpx.AsyncClient", return_value=mock_client): score = await scorer.score_async("input", "output", {}) assert score == pytest.approx((6 - 1) / 9.0) mock_cache.put.assert_called_once() put_kwargs = mock_cache.put.call_args assert put_kwargs.kwargs.get("response") == "6" or ( len(put_kwargs.args) > 2 and put_kwargs.args[2] == "6" ) @pytest.mark.asyncio async def test_no_cache_when_layer_not_configured(self): scorer = LLMJudgeScorer(base_url="http://fake:11434/v1") mock_client = _mock_client(_chat_response("7")) with patch("httpx.AsyncClient", return_value=mock_client): score = await scorer.score_async("input", "output", {}) assert score == pytest.approx((7 - 1) / 9.0) class TestScoreSync: def test_sync_score_works_outside_async(self): scorer = LLMJudgeScorer(base_url="http://fake:11434/v1") mock_client = _mock_client(_chat_response("5")) with patch("httpx.AsyncClient", return_value=mock_client): score = scorer.score("input", "output", {}) assert score == pytest.approx((5 - 1) / 9.0) def test_sync_score_raises_in_async_context(self): scorer = LLMJudgeScorer(base_url="http://fake:11434/v1") async def _run(): scorer.score("input", "output", {}) with pytest.raises(RuntimeError, match="cannot be called from an async"): asyncio.get_event_loop().run_until_complete(_run()) class TestRetries: @pytest.mark.asyncio async def test_retries_on_server_error(self): scorer = LLMJudgeScorer( base_url="http://fake:11434/v1", max_retries=3 ) # First two calls fail with 500, third succeeds. fail_resp = MagicMock() fail_resp.status_code = 500 fail_resp.json.return_value = {} ok_resp = MagicMock() ok_resp.status_code = 200 ok_resp.json.return_value = _chat_response("7") mock_client = AsyncMock() mock_client.post.side_effect = [fail_resp, fail_resp, ok_resp] mock_client.__aenter__ = AsyncMock(return_value=mock_client) mock_client.__aexit__ = AsyncMock(return_value=False) with patch("httpx.AsyncClient", return_value=mock_client): with patch("asyncio.sleep", new_callable=AsyncMock): score = await scorer.score_async("input", "output", {}) assert score == pytest.approx((7 - 1) / 9.0) assert mock_client.post.call_count == 3 @pytest.mark.asyncio async def test_all_retries_exhausted(self): scorer = LLMJudgeScorer( base_url="http://fake:11434/v1", max_retries=2 ) fail_resp = MagicMock() fail_resp.status_code = 500 fail_resp.json.return_value = {} fail_resp.request = MagicMock() mock_client = AsyncMock() mock_client.post.return_value = fail_resp mock_client.__aenter__ = AsyncMock(return_value=mock_client) mock_client.__aexit__ = AsyncMock(return_value=False) with patch("httpx.AsyncClient", return_value=mock_client): with patch("asyncio.sleep", new_callable=AsyncMock): with pytest.raises(RuntimeError, match="All 2 attempts failed"): await scorer.score_async("input", "output", {})