MAESTRO: Implement LLMJudgeScorer with configurable judge prompt, rating parsing, and response caching
This commit is contained in:
parent
0d5a6169c5
commit
fb78eac1b0
4 changed files with 621 additions and 2 deletions
|
|
@ -26,7 +26,8 @@ Implement the core experiment execution engine: LLM adapters, response caching,
|
||||||
- [x] Implement backend/engine/scorers/keyword.py — checks for presence/absence of required keywords in output. Configurable with required_present and required_absent lists. Score = (found / required) ratio.
|
- [x] Implement backend/engine/scorers/keyword.py — checks for presence/absence of required keywords in output. Configurable with required_present and required_absent lists. Score = (found / required) ratio.
|
||||||
<!-- Completed: KeywordScorer with required_present/required_absent lists, case-sensitive option, combined ratio scoring. 37 tests in test_scorer_keyword.py, all passing. -->
|
<!-- Completed: KeywordScorer with required_present/required_absent lists, case-sensitive option, combined ratio scoring. 37 tests in test_scorer_keyword.py, all passing. -->
|
||||||
|
|
||||||
- [ ] Implement backend/engine/scorers/llm_judge.py — sends the output to a separate LLM with a configurable judge prompt and asks for a 1-10 rating. Parses the numeric score from the response. This scorer requires an LLM call so it should be clearly marked as "costs tokens" in the UI. Cache the judge's response too.
|
- [x] Implement backend/engine/scorers/llm_judge.py — sends the output to a separate LLM with a configurable judge prompt and asks for a 1-10 rating. Parses the numeric score from the response. This scorer requires an LLM call so it should be clearly marked as "costs tokens" in the UI. Cache the judge's response too.
|
||||||
|
<!-- Completed: LLMJudgeScorer with configurable judge prompt, 1-10 rating parsing via regex, normalized to 0.0-1.0. COSTS_TOKENS class marker for UI. Optional ResponseCacheLayer integration for caching judge responses. Retries with exponential backoff. 36 tests in test_scorer_llm_judge.py, all passing. -->
|
||||||
|
|
||||||
- [ ] Wire up the Celery worker in backend/worker.py. Define tasks: execute_run(run_id), execute_sweep(experiment_id, sweep_config). Configure Celery to use Redis as broker. In single-container mode (no Redis), implement a simple synchronous fallback that runs tasks in-process.
|
- [ ] Wire up the Celery worker in backend/worker.py. Define tasks: execute_run(run_id), execute_sweep(experiment_id, sweep_config). Configure Celery to use Redis as broker. In single-container mode (no Redis), implement a simple synchronous fallback that runs tasks in-process.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,5 +4,12 @@ from engine.scorers.base import BaseScorer
|
||||||
from engine.scorers.embedding import EmbeddingScorer
|
from engine.scorers.embedding import EmbeddingScorer
|
||||||
from engine.scorers.format import FormatScorer
|
from engine.scorers.format import FormatScorer
|
||||||
from engine.scorers.keyword import KeywordScorer
|
from engine.scorers.keyword import KeywordScorer
|
||||||
|
from engine.scorers.llm_judge import LLMJudgeScorer
|
||||||
|
|
||||||
__all__ = ["BaseScorer", "EmbeddingScorer", "FormatScorer", "KeywordScorer"]
|
__all__ = [
|
||||||
|
"BaseScorer",
|
||||||
|
"EmbeddingScorer",
|
||||||
|
"FormatScorer",
|
||||||
|
"KeywordScorer",
|
||||||
|
"LLMJudgeScorer",
|
||||||
|
]
|
||||||
|
|
|
||||||
232
backend/engine/scorers/llm_judge.py
Normal file
232
backend/engine/scorers/llm_judge.py
Normal file
|
|
@ -0,0 +1,232 @@
|
||||||
|
"""LLM-as-judge scorer.
|
||||||
|
|
||||||
|
Sends the LLM output to a separate LLM with a configurable judge prompt,
|
||||||
|
asks for a 1–10 rating, and normalizes to the 0.0–1.0 range.
|
||||||
|
|
||||||
|
**This scorer costs tokens** — every evaluation makes an LLM call. The
|
||||||
|
judge's response is cached via PromptLooper's response cache layer to
|
||||||
|
avoid redundant calls when re-scoring the same output.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import re
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from engine.scorers.base import BaseScorer
|
||||||
|
|
||||||
|
|
||||||
|
# Default judge system prompt — can be overridden at construction time.
|
||||||
|
DEFAULT_JUDGE_PROMPT = (
|
||||||
|
"You are an impartial evaluator. You will receive an input and an LLM-generated "
|
||||||
|
"output. Rate the quality of the output on a scale of 1 to 10, where 1 is terrible "
|
||||||
|
"and 10 is perfect.\n\n"
|
||||||
|
"Respond with ONLY a single integer between 1 and 10. Do not include any other text."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Regex to extract the first integer 1–10 from the judge response.
|
||||||
|
_RATING_RE = re.compile(r"\b(10|[1-9])\b")
|
||||||
|
|
||||||
|
|
||||||
|
class LLMJudgeScorer(BaseScorer):
|
||||||
|
"""Score outputs by asking a separate LLM to rate them 1–10.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
base_url: Chat completions API base URL.
|
||||||
|
model: Model to use for judging.
|
||||||
|
api_key: Optional API key.
|
||||||
|
judge_prompt: System prompt for the judge LLM.
|
||||||
|
timeout: HTTP request timeout in seconds.
|
||||||
|
max_retries: Retry attempts on transient failures.
|
||||||
|
cache_layer: Optional ``ResponseCacheLayer`` instance. When provided,
|
||||||
|
judge responses are cached to avoid duplicate LLM calls.
|
||||||
|
db_session_factory: Callable returning a SQLAlchemy ``Session``.
|
||||||
|
Required when *cache_layer* is supplied.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Marker for the UI so it can warn users about token cost.
|
||||||
|
COSTS_TOKENS = True
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
base_url: str = "http://localhost:11434/v1",
|
||||||
|
model: str = "llama3",
|
||||||
|
api_key: str | None = None,
|
||||||
|
judge_prompt: str = DEFAULT_JUDGE_PROMPT,
|
||||||
|
timeout: float = 120.0,
|
||||||
|
max_retries: int = 3,
|
||||||
|
cache_layer: Any = None,
|
||||||
|
db_session_factory: Any = None,
|
||||||
|
) -> None:
|
||||||
|
self.base_url = base_url.rstrip("/")
|
||||||
|
self.model = model
|
||||||
|
self.api_key = api_key
|
||||||
|
self.judge_prompt = judge_prompt
|
||||||
|
self.timeout = timeout
|
||||||
|
self.max_retries = max_retries
|
||||||
|
self._cache_layer = cache_layer
|
||||||
|
self._db_session_factory = db_session_factory
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "llm_judge"
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Synchronous entry point
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def score(self, input_data: Any, output: str, context: dict) -> float:
|
||||||
|
"""Synchronous scoring — delegates to the async variant."""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
try:
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
except RuntimeError:
|
||||||
|
loop = None
|
||||||
|
|
||||||
|
if loop and loop.is_running():
|
||||||
|
raise RuntimeError(
|
||||||
|
"LLMJudgeScorer.score() cannot be called from an async context. "
|
||||||
|
"Use score_async() instead."
|
||||||
|
)
|
||||||
|
|
||||||
|
return asyncio.get_event_loop().run_until_complete(
|
||||||
|
self.score_async(input_data, output, context)
|
||||||
|
)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Async entry point
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def score_async(
|
||||||
|
self, input_data: Any, output: str, context: dict
|
||||||
|
) -> float:
|
||||||
|
"""Ask the judge LLM to rate the output and return a normalised score."""
|
||||||
|
user_message = self._build_user_message(input_data, output, context)
|
||||||
|
|
||||||
|
# Check cache first.
|
||||||
|
config_hash: str | None = None
|
||||||
|
if self._cache_layer and self._db_session_factory:
|
||||||
|
from engine.cache import compute_config_hash
|
||||||
|
|
||||||
|
config_hash = compute_config_hash(
|
||||||
|
prompt=self.judge_prompt,
|
||||||
|
model=self.model,
|
||||||
|
params={"scorer": "llm_judge"},
|
||||||
|
input_data=user_message,
|
||||||
|
)
|
||||||
|
db = self._db_session_factory()
|
||||||
|
try:
|
||||||
|
cached = self._cache_layer.get(db, config_hash)
|
||||||
|
if cached is not None:
|
||||||
|
return _parse_rating(cached.response)
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
# Call the judge LLM.
|
||||||
|
judge_response = await self._call_judge(user_message)
|
||||||
|
|
||||||
|
# Cache the judge response.
|
||||||
|
if self._cache_layer and self._db_session_factory and config_hash:
|
||||||
|
db = self._db_session_factory()
|
||||||
|
try:
|
||||||
|
self._cache_layer.put(
|
||||||
|
db,
|
||||||
|
config_hash=config_hash,
|
||||||
|
response=judge_response,
|
||||||
|
model=self.model,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
return _parse_rating(judge_response)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Internal helpers
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _build_user_message(
|
||||||
|
self, input_data: Any, output: str, context: dict
|
||||||
|
) -> str:
|
||||||
|
"""Build the user message sent to the judge LLM."""
|
||||||
|
parts = []
|
||||||
|
if input_data is not None:
|
||||||
|
parts.append(f"## Input\n{input_data}")
|
||||||
|
parts.append(f"## Output\n{output}")
|
||||||
|
|
||||||
|
# Include reference answer if available — helps the judge compare.
|
||||||
|
reference = context.get("reference")
|
||||||
|
if reference:
|
||||||
|
parts.append(f"## Reference Answer\n{reference}")
|
||||||
|
|
||||||
|
return "\n\n".join(parts)
|
||||||
|
|
||||||
|
async def _call_judge(self, user_message: str) -> str:
|
||||||
|
"""Send a chat completion request to the judge LLM with retries."""
|
||||||
|
url = f"{self.base_url}/chat/completions"
|
||||||
|
|
||||||
|
headers: dict[str, str] = {"Content-Type": "application/json"}
|
||||||
|
if self.api_key:
|
||||||
|
headers["Authorization"] = f"Bearer {self.api_key}"
|
||||||
|
|
||||||
|
body = {
|
||||||
|
"model": self.model,
|
||||||
|
"messages": [
|
||||||
|
{"role": "system", "content": self.judge_prompt},
|
||||||
|
{"role": "user", "content": user_message},
|
||||||
|
],
|
||||||
|
"temperature": 0.0,
|
||||||
|
"max_tokens": 16,
|
||||||
|
}
|
||||||
|
|
||||||
|
last_exc: Exception | None = None
|
||||||
|
retryable = {429, 500, 502, 503, 504}
|
||||||
|
|
||||||
|
for attempt in range(self.max_retries):
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(
|
||||||
|
timeout=httpx.Timeout(self.timeout), headers=headers
|
||||||
|
) as client:
|
||||||
|
resp = await client.post(url, json=body)
|
||||||
|
|
||||||
|
if resp.status_code == 200:
|
||||||
|
data = resp.json()
|
||||||
|
choices = data.get("choices", [])
|
||||||
|
if choices:
|
||||||
|
return choices[0].get("message", {}).get("content", "").strip()
|
||||||
|
return ""
|
||||||
|
|
||||||
|
if resp.status_code not in retryable:
|
||||||
|
resp.raise_for_status()
|
||||||
|
|
||||||
|
last_exc = httpx.HTTPStatusError(
|
||||||
|
f"HTTP {resp.status_code}",
|
||||||
|
request=resp.request,
|
||||||
|
response=resp,
|
||||||
|
)
|
||||||
|
except httpx.HTTPStatusError:
|
||||||
|
raise
|
||||||
|
except httpx.HTTPError as exc:
|
||||||
|
last_exc = exc
|
||||||
|
|
||||||
|
if attempt < self.max_retries - 1:
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
await asyncio.sleep(2**attempt)
|
||||||
|
|
||||||
|
raise RuntimeError(
|
||||||
|
f"All {self.max_retries} attempts failed for judge LLM at {url}"
|
||||||
|
) from last_exc
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_rating(text: str) -> float:
|
||||||
|
"""Extract a 1–10 rating from the judge response and normalise to 0.0–1.0.
|
||||||
|
|
||||||
|
Falls back to 0.0 if no valid rating is found.
|
||||||
|
"""
|
||||||
|
match = _RATING_RE.search(text)
|
||||||
|
if match is None:
|
||||||
|
return 0.0
|
||||||
|
rating = int(match.group(1))
|
||||||
|
# Normalise: 1 → ~0.0, 10 → 1.0
|
||||||
|
return max(0.0, min(1.0, (rating - 1) / 9.0))
|
||||||
379
backend/tests/test_scorer_llm_judge.py
Normal file
379
backend/tests/test_scorer_llm_judge.py
Normal file
|
|
@ -0,0 +1,379 @@
|
||||||
|
"""Tests for the LLMJudgeScorer."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from engine.scorers.base import BaseScorer
|
||||||
|
from engine.scorers.llm_judge import (
|
||||||
|
DEFAULT_JUDGE_PROMPT,
|
||||||
|
LLMJudgeScorer,
|
||||||
|
_parse_rating,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _chat_response(content: str) -> dict:
|
||||||
|
"""Build a fake OpenAI-compatible chat completion response."""
|
||||||
|
return {
|
||||||
|
"id": "chatcmpl-test",
|
||||||
|
"object": "chat.completion",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"message": {"role": "assistant", "content": content},
|
||||||
|
"finish_reason": "stop",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"usage": {"prompt_tokens": 50, "completion_tokens": 2},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _mock_client(response_data: dict, status_code: int = 200):
|
||||||
|
"""Create a mocked httpx.AsyncClient that returns *response_data*."""
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.status_code = status_code
|
||||||
|
mock_response.json.return_value = response_data
|
||||||
|
mock_response.raise_for_status = MagicMock()
|
||||||
|
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.post.return_value = mock_response
|
||||||
|
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||||
|
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
|
||||||
|
return mock_client
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _parse_rating unit tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestParseRating:
|
||||||
|
def test_single_digit(self):
|
||||||
|
assert _parse_rating("7") == pytest.approx((7 - 1) / 9.0)
|
||||||
|
|
||||||
|
def test_rating_10(self):
|
||||||
|
assert _parse_rating("10") == pytest.approx(1.0)
|
||||||
|
|
||||||
|
def test_rating_1(self):
|
||||||
|
assert _parse_rating("1") == pytest.approx(0.0)
|
||||||
|
|
||||||
|
def test_rating_5(self):
|
||||||
|
assert _parse_rating("5") == pytest.approx((5 - 1) / 9.0)
|
||||||
|
|
||||||
|
def test_embedded_in_text(self):
|
||||||
|
assert _parse_rating("I would rate this a 8 out of 10") == pytest.approx(
|
||||||
|
(8 - 1) / 9.0
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_no_rating(self):
|
||||||
|
assert _parse_rating("This is great!") == 0.0
|
||||||
|
|
||||||
|
def test_empty_string(self):
|
||||||
|
assert _parse_rating("") == 0.0
|
||||||
|
|
||||||
|
def test_zero_not_matched(self):
|
||||||
|
# 0 is outside the 1–10 range, should not match.
|
||||||
|
assert _parse_rating("0") == 0.0
|
||||||
|
|
||||||
|
def test_number_above_10_extracts_partial(self):
|
||||||
|
# "11" should match "1" (the first valid 1–10 integer).
|
||||||
|
result = _parse_rating("11")
|
||||||
|
assert result == pytest.approx(0.0) # matches "1"
|
||||||
|
|
||||||
|
def test_rating_with_newlines(self):
|
||||||
|
assert _parse_rating("\n\n8\n") == pytest.approx((8 - 1) / 9.0)
|
||||||
|
|
||||||
|
def test_rating_with_period(self):
|
||||||
|
assert _parse_rating("Rating: 9.") == pytest.approx((9 - 1) / 9.0)
|
||||||
|
|
||||||
|
def test_ten_in_sentence(self):
|
||||||
|
assert _parse_rating("Score: 10 points") == pytest.approx(1.0)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# LLMJudgeScorer class tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestLLMJudgeScorerProperties:
|
||||||
|
def test_is_base_scorer_subclass(self):
|
||||||
|
scorer = LLMJudgeScorer()
|
||||||
|
assert isinstance(scorer, BaseScorer)
|
||||||
|
|
||||||
|
def test_name(self):
|
||||||
|
assert LLMJudgeScorer().name == "llm_judge"
|
||||||
|
|
||||||
|
def test_costs_tokens_marker(self):
|
||||||
|
assert LLMJudgeScorer.COSTS_TOKENS is True
|
||||||
|
|
||||||
|
def test_custom_judge_prompt(self):
|
||||||
|
scorer = LLMJudgeScorer(judge_prompt="Custom prompt")
|
||||||
|
assert scorer.judge_prompt == "Custom prompt"
|
||||||
|
|
||||||
|
def test_default_judge_prompt(self):
|
||||||
|
scorer = LLMJudgeScorer()
|
||||||
|
assert scorer.judge_prompt == DEFAULT_JUDGE_PROMPT
|
||||||
|
|
||||||
|
|
||||||
|
class TestBuildUserMessage:
|
||||||
|
def test_basic_message(self):
|
||||||
|
scorer = LLMJudgeScorer()
|
||||||
|
msg = scorer._build_user_message("What is 2+2?", "4", {})
|
||||||
|
assert "## Input" in msg
|
||||||
|
assert "What is 2+2?" in msg
|
||||||
|
assert "## Output" in msg
|
||||||
|
assert "4" in msg
|
||||||
|
|
||||||
|
def test_no_input(self):
|
||||||
|
scorer = LLMJudgeScorer()
|
||||||
|
msg = scorer._build_user_message(None, "Hello", {})
|
||||||
|
assert "## Input" not in msg
|
||||||
|
assert "## Output" in msg
|
||||||
|
|
||||||
|
def test_with_reference(self):
|
||||||
|
scorer = LLMJudgeScorer()
|
||||||
|
msg = scorer._build_user_message("q", "a", {"reference": "correct answer"})
|
||||||
|
assert "## Reference Answer" in msg
|
||||||
|
assert "correct answer" in msg
|
||||||
|
|
||||||
|
def test_without_reference(self):
|
||||||
|
scorer = LLMJudgeScorer()
|
||||||
|
msg = scorer._build_user_message("q", "a", {})
|
||||||
|
assert "## Reference Answer" not in msg
|
||||||
|
|
||||||
|
|
||||||
|
class TestScoreAsync:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_basic_score(self):
|
||||||
|
scorer = LLMJudgeScorer(base_url="http://fake:11434/v1")
|
||||||
|
mock_client = _mock_client(_chat_response("8"))
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient", return_value=mock_client):
|
||||||
|
score = await scorer.score_async("input", "output", {})
|
||||||
|
|
||||||
|
assert score == pytest.approx((8 - 1) / 9.0)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_perfect_score(self):
|
||||||
|
scorer = LLMJudgeScorer(base_url="http://fake:11434/v1")
|
||||||
|
mock_client = _mock_client(_chat_response("10"))
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient", return_value=mock_client):
|
||||||
|
score = await scorer.score_async("input", "output", {})
|
||||||
|
|
||||||
|
assert score == pytest.approx(1.0)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_lowest_score(self):
|
||||||
|
scorer = LLMJudgeScorer(base_url="http://fake:11434/v1")
|
||||||
|
mock_client = _mock_client(_chat_response("1"))
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient", return_value=mock_client):
|
||||||
|
score = await scorer.score_async("input", "output", {})
|
||||||
|
|
||||||
|
assert score == pytest.approx(0.0)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_unparseable_response(self):
|
||||||
|
scorer = LLMJudgeScorer(base_url="http://fake:11434/v1")
|
||||||
|
mock_client = _mock_client(_chat_response("I cannot rate this"))
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient", return_value=mock_client):
|
||||||
|
score = await scorer.score_async("input", "output", {})
|
||||||
|
|
||||||
|
assert score == 0.0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_empty_choices(self):
|
||||||
|
scorer = LLMJudgeScorer(base_url="http://fake:11434/v1")
|
||||||
|
response = {"choices": []}
|
||||||
|
mock_client = _mock_client(response)
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient", return_value=mock_client):
|
||||||
|
score = await scorer.score_async("input", "output", {})
|
||||||
|
|
||||||
|
assert score == 0.0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_request_includes_correct_body(self):
|
||||||
|
scorer = LLMJudgeScorer(
|
||||||
|
base_url="http://fake:11434/v1",
|
||||||
|
model="judge-model",
|
||||||
|
judge_prompt="Rate it",
|
||||||
|
)
|
||||||
|
mock_client = _mock_client(_chat_response("7"))
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient", return_value=mock_client):
|
||||||
|
await scorer.score_async("test input", "test output", {})
|
||||||
|
|
||||||
|
call_kwargs = mock_client.post.call_args
|
||||||
|
body = call_kwargs.kwargs.get("json") or call_kwargs[1].get("json")
|
||||||
|
assert body["model"] == "judge-model"
|
||||||
|
assert body["messages"][0]["role"] == "system"
|
||||||
|
assert body["messages"][0]["content"] == "Rate it"
|
||||||
|
assert body["messages"][1]["role"] == "user"
|
||||||
|
assert body["temperature"] == 0.0
|
||||||
|
assert body["max_tokens"] == 16
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_api_key_in_headers(self):
|
||||||
|
scorer = LLMJudgeScorer(
|
||||||
|
base_url="http://fake:11434/v1",
|
||||||
|
api_key="sk-test-key",
|
||||||
|
)
|
||||||
|
mock_client = _mock_client(_chat_response("5"))
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient", return_value=mock_client) as mock_cls:
|
||||||
|
await scorer.score_async("input", "output", {})
|
||||||
|
|
||||||
|
call_kwargs = mock_cls.call_args
|
||||||
|
headers = call_kwargs.kwargs.get("headers") or call_kwargs[1].get("headers")
|
||||||
|
assert headers["Authorization"] == "Bearer sk-test-key"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reference_included_in_message(self):
|
||||||
|
scorer = LLMJudgeScorer(base_url="http://fake:11434/v1")
|
||||||
|
mock_client = _mock_client(_chat_response("9"))
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient", return_value=mock_client):
|
||||||
|
await scorer.score_async(
|
||||||
|
"q", "a", {"reference": "expected answer"}
|
||||||
|
)
|
||||||
|
|
||||||
|
call_kwargs = mock_client.post.call_args
|
||||||
|
body = call_kwargs.kwargs.get("json") or call_kwargs[1].get("json")
|
||||||
|
user_msg = body["messages"][1]["content"]
|
||||||
|
assert "expected answer" in user_msg
|
||||||
|
|
||||||
|
|
||||||
|
class TestScoreAsyncWithCache:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_hit_skips_llm_call(self):
|
||||||
|
mock_cache = MagicMock()
|
||||||
|
cached_response = MagicMock()
|
||||||
|
cached_response.response = "8"
|
||||||
|
mock_cache.get.return_value = cached_response
|
||||||
|
|
||||||
|
mock_session = MagicMock()
|
||||||
|
mock_session_factory = MagicMock(return_value=mock_session)
|
||||||
|
|
||||||
|
scorer = LLMJudgeScorer(
|
||||||
|
base_url="http://fake:11434/v1",
|
||||||
|
cache_layer=mock_cache,
|
||||||
|
db_session_factory=mock_session_factory,
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient") as mock_http:
|
||||||
|
score = await scorer.score_async("input", "output", {})
|
||||||
|
|
||||||
|
assert score == pytest.approx((8 - 1) / 9.0)
|
||||||
|
mock_http.assert_not_called()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_miss_calls_llm_and_stores(self):
|
||||||
|
mock_cache = MagicMock()
|
||||||
|
mock_cache.get.return_value = None
|
||||||
|
|
||||||
|
mock_session = MagicMock()
|
||||||
|
mock_session_factory = MagicMock(return_value=mock_session)
|
||||||
|
|
||||||
|
scorer = LLMJudgeScorer(
|
||||||
|
base_url="http://fake:11434/v1",
|
||||||
|
cache_layer=mock_cache,
|
||||||
|
db_session_factory=mock_session_factory,
|
||||||
|
)
|
||||||
|
mock_client = _mock_client(_chat_response("6"))
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient", return_value=mock_client):
|
||||||
|
score = await scorer.score_async("input", "output", {})
|
||||||
|
|
||||||
|
assert score == pytest.approx((6 - 1) / 9.0)
|
||||||
|
mock_cache.put.assert_called_once()
|
||||||
|
put_kwargs = mock_cache.put.call_args
|
||||||
|
assert put_kwargs.kwargs.get("response") == "6" or (
|
||||||
|
len(put_kwargs.args) > 2 and put_kwargs.args[2] == "6"
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_cache_when_layer_not_configured(self):
|
||||||
|
scorer = LLMJudgeScorer(base_url="http://fake:11434/v1")
|
||||||
|
mock_client = _mock_client(_chat_response("7"))
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient", return_value=mock_client):
|
||||||
|
score = await scorer.score_async("input", "output", {})
|
||||||
|
|
||||||
|
assert score == pytest.approx((7 - 1) / 9.0)
|
||||||
|
|
||||||
|
|
||||||
|
class TestScoreSync:
|
||||||
|
def test_sync_score_works_outside_async(self):
|
||||||
|
scorer = LLMJudgeScorer(base_url="http://fake:11434/v1")
|
||||||
|
mock_client = _mock_client(_chat_response("5"))
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient", return_value=mock_client):
|
||||||
|
score = scorer.score("input", "output", {})
|
||||||
|
|
||||||
|
assert score == pytest.approx((5 - 1) / 9.0)
|
||||||
|
|
||||||
|
def test_sync_score_raises_in_async_context(self):
|
||||||
|
scorer = LLMJudgeScorer(base_url="http://fake:11434/v1")
|
||||||
|
|
||||||
|
async def _run():
|
||||||
|
scorer.score("input", "output", {})
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError, match="cannot be called from an async"):
|
||||||
|
asyncio.get_event_loop().run_until_complete(_run())
|
||||||
|
|
||||||
|
|
||||||
|
class TestRetries:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_retries_on_server_error(self):
|
||||||
|
scorer = LLMJudgeScorer(
|
||||||
|
base_url="http://fake:11434/v1", max_retries=3
|
||||||
|
)
|
||||||
|
|
||||||
|
# First two calls fail with 500, third succeeds.
|
||||||
|
fail_resp = MagicMock()
|
||||||
|
fail_resp.status_code = 500
|
||||||
|
fail_resp.json.return_value = {}
|
||||||
|
|
||||||
|
ok_resp = MagicMock()
|
||||||
|
ok_resp.status_code = 200
|
||||||
|
ok_resp.json.return_value = _chat_response("7")
|
||||||
|
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.post.side_effect = [fail_resp, fail_resp, ok_resp]
|
||||||
|
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||||
|
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient", return_value=mock_client):
|
||||||
|
with patch("asyncio.sleep", new_callable=AsyncMock):
|
||||||
|
score = await scorer.score_async("input", "output", {})
|
||||||
|
|
||||||
|
assert score == pytest.approx((7 - 1) / 9.0)
|
||||||
|
assert mock_client.post.call_count == 3
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_all_retries_exhausted(self):
|
||||||
|
scorer = LLMJudgeScorer(
|
||||||
|
base_url="http://fake:11434/v1", max_retries=2
|
||||||
|
)
|
||||||
|
|
||||||
|
fail_resp = MagicMock()
|
||||||
|
fail_resp.status_code = 500
|
||||||
|
fail_resp.json.return_value = {}
|
||||||
|
fail_resp.request = MagicMock()
|
||||||
|
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.post.return_value = fail_resp
|
||||||
|
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||||
|
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient", return_value=mock_client):
|
||||||
|
with patch("asyncio.sleep", new_callable=AsyncMock):
|
||||||
|
with pytest.raises(RuntimeError, match="All 2 attempts failed"):
|
||||||
|
await scorer.score_async("input", "output", {})
|
||||||
Loading…
Add table
Reference in a new issue