MAESTRO: Implement LLMJudgeScorer with configurable judge prompt, rating parsing, and response caching
This commit is contained in:
parent
0d5a6169c5
commit
fb78eac1b0
4 changed files with 621 additions and 2 deletions
|
|
@ -26,7 +26,8 @@ Implement the core experiment execution engine: LLM adapters, response caching,
|
|||
- [x] Implement backend/engine/scorers/keyword.py — checks for presence/absence of required keywords in output. Configurable with required_present and required_absent lists. Score = (found / required) ratio.
|
||||
<!-- Completed: KeywordScorer with required_present/required_absent lists, case-sensitive option, combined ratio scoring. 37 tests in test_scorer_keyword.py, all passing. -->
|
||||
|
||||
- [ ] Implement backend/engine/scorers/llm_judge.py — sends the output to a separate LLM with a configurable judge prompt and asks for a 1-10 rating. Parses the numeric score from the response. This scorer requires an LLM call so it should be clearly marked as "costs tokens" in the UI. Cache the judge's response too.
|
||||
- [x] Implement backend/engine/scorers/llm_judge.py — sends the output to a separate LLM with a configurable judge prompt and asks for a 1-10 rating. Parses the numeric score from the response. This scorer requires an LLM call so it should be clearly marked as "costs tokens" in the UI. Cache the judge's response too.
|
||||
<!-- Completed: LLMJudgeScorer with configurable judge prompt, 1-10 rating parsing via regex, normalized to 0.0-1.0. COSTS_TOKENS class marker for UI. Optional ResponseCacheLayer integration for caching judge responses. Retries with exponential backoff. 36 tests in test_scorer_llm_judge.py, all passing. -->
|
||||
|
||||
- [ ] Wire up the Celery worker in backend/worker.py. Define tasks: execute_run(run_id), execute_sweep(experiment_id, sweep_config). Configure Celery to use Redis as broker. In single-container mode (no Redis), implement a simple synchronous fallback that runs tasks in-process.
|
||||
|
||||
|
|
|
|||
|
|
@ -4,5 +4,12 @@ from engine.scorers.base import BaseScorer
|
|||
from engine.scorers.embedding import EmbeddingScorer
|
||||
from engine.scorers.format import FormatScorer
|
||||
from engine.scorers.keyword import KeywordScorer
|
||||
from engine.scorers.llm_judge import LLMJudgeScorer
|
||||
|
||||
__all__ = ["BaseScorer", "EmbeddingScorer", "FormatScorer", "KeywordScorer"]
|
||||
__all__ = [
|
||||
"BaseScorer",
|
||||
"EmbeddingScorer",
|
||||
"FormatScorer",
|
||||
"KeywordScorer",
|
||||
"LLMJudgeScorer",
|
||||
]
|
||||
|
|
|
|||
232
backend/engine/scorers/llm_judge.py
Normal file
232
backend/engine/scorers/llm_judge.py
Normal file
|
|
@ -0,0 +1,232 @@
|
|||
"""LLM-as-judge scorer.
|
||||
|
||||
Sends the LLM output to a separate LLM with a configurable judge prompt,
|
||||
asks for a 1–10 rating, and normalizes to the 0.0–1.0 range.
|
||||
|
||||
**This scorer costs tokens** — every evaluation makes an LLM call. The
|
||||
judge's response is cached via PromptLooper's response cache layer to
|
||||
avoid redundant calls when re-scoring the same output.
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from engine.scorers.base import BaseScorer
|
||||
|
||||
|
||||
# Default judge system prompt — can be overridden at construction time.
|
||||
DEFAULT_JUDGE_PROMPT = (
|
||||
"You are an impartial evaluator. You will receive an input and an LLM-generated "
|
||||
"output. Rate the quality of the output on a scale of 1 to 10, where 1 is terrible "
|
||||
"and 10 is perfect.\n\n"
|
||||
"Respond with ONLY a single integer between 1 and 10. Do not include any other text."
|
||||
)
|
||||
|
||||
# Regex to extract the first integer 1–10 from the judge response.
|
||||
_RATING_RE = re.compile(r"\b(10|[1-9])\b")
|
||||
|
||||
|
||||
class LLMJudgeScorer(BaseScorer):
|
||||
"""Score outputs by asking a separate LLM to rate them 1–10.
|
||||
|
||||
Args:
|
||||
base_url: Chat completions API base URL.
|
||||
model: Model to use for judging.
|
||||
api_key: Optional API key.
|
||||
judge_prompt: System prompt for the judge LLM.
|
||||
timeout: HTTP request timeout in seconds.
|
||||
max_retries: Retry attempts on transient failures.
|
||||
cache_layer: Optional ``ResponseCacheLayer`` instance. When provided,
|
||||
judge responses are cached to avoid duplicate LLM calls.
|
||||
db_session_factory: Callable returning a SQLAlchemy ``Session``.
|
||||
Required when *cache_layer* is supplied.
|
||||
"""
|
||||
|
||||
# Marker for the UI so it can warn users about token cost.
|
||||
COSTS_TOKENS = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str = "http://localhost:11434/v1",
|
||||
model: str = "llama3",
|
||||
api_key: str | None = None,
|
||||
judge_prompt: str = DEFAULT_JUDGE_PROMPT,
|
||||
timeout: float = 120.0,
|
||||
max_retries: int = 3,
|
||||
cache_layer: Any = None,
|
||||
db_session_factory: Any = None,
|
||||
) -> None:
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.model = model
|
||||
self.api_key = api_key
|
||||
self.judge_prompt = judge_prompt
|
||||
self.timeout = timeout
|
||||
self.max_retries = max_retries
|
||||
self._cache_layer = cache_layer
|
||||
self._db_session_factory = db_session_factory
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "llm_judge"
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Synchronous entry point
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def score(self, input_data: Any, output: str, context: dict) -> float:
|
||||
"""Synchronous scoring — delegates to the async variant."""
|
||||
import asyncio
|
||||
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
loop = None
|
||||
|
||||
if loop and loop.is_running():
|
||||
raise RuntimeError(
|
||||
"LLMJudgeScorer.score() cannot be called from an async context. "
|
||||
"Use score_async() instead."
|
||||
)
|
||||
|
||||
return asyncio.get_event_loop().run_until_complete(
|
||||
self.score_async(input_data, output, context)
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Async entry point
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def score_async(
|
||||
self, input_data: Any, output: str, context: dict
|
||||
) -> float:
|
||||
"""Ask the judge LLM to rate the output and return a normalised score."""
|
||||
user_message = self._build_user_message(input_data, output, context)
|
||||
|
||||
# Check cache first.
|
||||
config_hash: str | None = None
|
||||
if self._cache_layer and self._db_session_factory:
|
||||
from engine.cache import compute_config_hash
|
||||
|
||||
config_hash = compute_config_hash(
|
||||
prompt=self.judge_prompt,
|
||||
model=self.model,
|
||||
params={"scorer": "llm_judge"},
|
||||
input_data=user_message,
|
||||
)
|
||||
db = self._db_session_factory()
|
||||
try:
|
||||
cached = self._cache_layer.get(db, config_hash)
|
||||
if cached is not None:
|
||||
return _parse_rating(cached.response)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
# Call the judge LLM.
|
||||
judge_response = await self._call_judge(user_message)
|
||||
|
||||
# Cache the judge response.
|
||||
if self._cache_layer and self._db_session_factory and config_hash:
|
||||
db = self._db_session_factory()
|
||||
try:
|
||||
self._cache_layer.put(
|
||||
db,
|
||||
config_hash=config_hash,
|
||||
response=judge_response,
|
||||
model=self.model,
|
||||
)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
return _parse_rating(judge_response)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _build_user_message(
|
||||
self, input_data: Any, output: str, context: dict
|
||||
) -> str:
|
||||
"""Build the user message sent to the judge LLM."""
|
||||
parts = []
|
||||
if input_data is not None:
|
||||
parts.append(f"## Input\n{input_data}")
|
||||
parts.append(f"## Output\n{output}")
|
||||
|
||||
# Include reference answer if available — helps the judge compare.
|
||||
reference = context.get("reference")
|
||||
if reference:
|
||||
parts.append(f"## Reference Answer\n{reference}")
|
||||
|
||||
return "\n\n".join(parts)
|
||||
|
||||
async def _call_judge(self, user_message: str) -> str:
|
||||
"""Send a chat completion request to the judge LLM with retries."""
|
||||
url = f"{self.base_url}/chat/completions"
|
||||
|
||||
headers: dict[str, str] = {"Content-Type": "application/json"}
|
||||
if self.api_key:
|
||||
headers["Authorization"] = f"Bearer {self.api_key}"
|
||||
|
||||
body = {
|
||||
"model": self.model,
|
||||
"messages": [
|
||||
{"role": "system", "content": self.judge_prompt},
|
||||
{"role": "user", "content": user_message},
|
||||
],
|
||||
"temperature": 0.0,
|
||||
"max_tokens": 16,
|
||||
}
|
||||
|
||||
last_exc: Exception | None = None
|
||||
retryable = {429, 500, 502, 503, 504}
|
||||
|
||||
for attempt in range(self.max_retries):
|
||||
try:
|
||||
async with httpx.AsyncClient(
|
||||
timeout=httpx.Timeout(self.timeout), headers=headers
|
||||
) as client:
|
||||
resp = await client.post(url, json=body)
|
||||
|
||||
if resp.status_code == 200:
|
||||
data = resp.json()
|
||||
choices = data.get("choices", [])
|
||||
if choices:
|
||||
return choices[0].get("message", {}).get("content", "").strip()
|
||||
return ""
|
||||
|
||||
if resp.status_code not in retryable:
|
||||
resp.raise_for_status()
|
||||
|
||||
last_exc = httpx.HTTPStatusError(
|
||||
f"HTTP {resp.status_code}",
|
||||
request=resp.request,
|
||||
response=resp,
|
||||
)
|
||||
except httpx.HTTPStatusError:
|
||||
raise
|
||||
except httpx.HTTPError as exc:
|
||||
last_exc = exc
|
||||
|
||||
if attempt < self.max_retries - 1:
|
||||
import asyncio
|
||||
|
||||
await asyncio.sleep(2**attempt)
|
||||
|
||||
raise RuntimeError(
|
||||
f"All {self.max_retries} attempts failed for judge LLM at {url}"
|
||||
) from last_exc
|
||||
|
||||
|
||||
def _parse_rating(text: str) -> float:
|
||||
"""Extract a 1–10 rating from the judge response and normalise to 0.0–1.0.
|
||||
|
||||
Falls back to 0.0 if no valid rating is found.
|
||||
"""
|
||||
match = _RATING_RE.search(text)
|
||||
if match is None:
|
||||
return 0.0
|
||||
rating = int(match.group(1))
|
||||
# Normalise: 1 → ~0.0, 10 → 1.0
|
||||
return max(0.0, min(1.0, (rating - 1) / 9.0))
|
||||
379
backend/tests/test_scorer_llm_judge.py
Normal file
379
backend/tests/test_scorer_llm_judge.py
Normal file
|
|
@ -0,0 +1,379 @@
|
|||
"""Tests for the LLMJudgeScorer."""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from engine.scorers.base import BaseScorer
|
||||
from engine.scorers.llm_judge import (
|
||||
DEFAULT_JUDGE_PROMPT,
|
||||
LLMJudgeScorer,
|
||||
_parse_rating,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _chat_response(content: str) -> dict:
|
||||
"""Build a fake OpenAI-compatible chat completion response."""
|
||||
return {
|
||||
"id": "chatcmpl-test",
|
||||
"object": "chat.completion",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {"role": "assistant", "content": content},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
"usage": {"prompt_tokens": 50, "completion_tokens": 2},
|
||||
}
|
||||
|
||||
|
||||
def _mock_client(response_data: dict, status_code: int = 200):
|
||||
"""Create a mocked httpx.AsyncClient that returns *response_data*."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = status_code
|
||||
mock_response.json.return_value = response_data
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
return mock_client
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _parse_rating unit tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestParseRating:
|
||||
def test_single_digit(self):
|
||||
assert _parse_rating("7") == pytest.approx((7 - 1) / 9.0)
|
||||
|
||||
def test_rating_10(self):
|
||||
assert _parse_rating("10") == pytest.approx(1.0)
|
||||
|
||||
def test_rating_1(self):
|
||||
assert _parse_rating("1") == pytest.approx(0.0)
|
||||
|
||||
def test_rating_5(self):
|
||||
assert _parse_rating("5") == pytest.approx((5 - 1) / 9.0)
|
||||
|
||||
def test_embedded_in_text(self):
|
||||
assert _parse_rating("I would rate this a 8 out of 10") == pytest.approx(
|
||||
(8 - 1) / 9.0
|
||||
)
|
||||
|
||||
def test_no_rating(self):
|
||||
assert _parse_rating("This is great!") == 0.0
|
||||
|
||||
def test_empty_string(self):
|
||||
assert _parse_rating("") == 0.0
|
||||
|
||||
def test_zero_not_matched(self):
|
||||
# 0 is outside the 1–10 range, should not match.
|
||||
assert _parse_rating("0") == 0.0
|
||||
|
||||
def test_number_above_10_extracts_partial(self):
|
||||
# "11" should match "1" (the first valid 1–10 integer).
|
||||
result = _parse_rating("11")
|
||||
assert result == pytest.approx(0.0) # matches "1"
|
||||
|
||||
def test_rating_with_newlines(self):
|
||||
assert _parse_rating("\n\n8\n") == pytest.approx((8 - 1) / 9.0)
|
||||
|
||||
def test_rating_with_period(self):
|
||||
assert _parse_rating("Rating: 9.") == pytest.approx((9 - 1) / 9.0)
|
||||
|
||||
def test_ten_in_sentence(self):
|
||||
assert _parse_rating("Score: 10 points") == pytest.approx(1.0)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LLMJudgeScorer class tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestLLMJudgeScorerProperties:
|
||||
def test_is_base_scorer_subclass(self):
|
||||
scorer = LLMJudgeScorer()
|
||||
assert isinstance(scorer, BaseScorer)
|
||||
|
||||
def test_name(self):
|
||||
assert LLMJudgeScorer().name == "llm_judge"
|
||||
|
||||
def test_costs_tokens_marker(self):
|
||||
assert LLMJudgeScorer.COSTS_TOKENS is True
|
||||
|
||||
def test_custom_judge_prompt(self):
|
||||
scorer = LLMJudgeScorer(judge_prompt="Custom prompt")
|
||||
assert scorer.judge_prompt == "Custom prompt"
|
||||
|
||||
def test_default_judge_prompt(self):
|
||||
scorer = LLMJudgeScorer()
|
||||
assert scorer.judge_prompt == DEFAULT_JUDGE_PROMPT
|
||||
|
||||
|
||||
class TestBuildUserMessage:
|
||||
def test_basic_message(self):
|
||||
scorer = LLMJudgeScorer()
|
||||
msg = scorer._build_user_message("What is 2+2?", "4", {})
|
||||
assert "## Input" in msg
|
||||
assert "What is 2+2?" in msg
|
||||
assert "## Output" in msg
|
||||
assert "4" in msg
|
||||
|
||||
def test_no_input(self):
|
||||
scorer = LLMJudgeScorer()
|
||||
msg = scorer._build_user_message(None, "Hello", {})
|
||||
assert "## Input" not in msg
|
||||
assert "## Output" in msg
|
||||
|
||||
def test_with_reference(self):
|
||||
scorer = LLMJudgeScorer()
|
||||
msg = scorer._build_user_message("q", "a", {"reference": "correct answer"})
|
||||
assert "## Reference Answer" in msg
|
||||
assert "correct answer" in msg
|
||||
|
||||
def test_without_reference(self):
|
||||
scorer = LLMJudgeScorer()
|
||||
msg = scorer._build_user_message("q", "a", {})
|
||||
assert "## Reference Answer" not in msg
|
||||
|
||||
|
||||
class TestScoreAsync:
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_score(self):
|
||||
scorer = LLMJudgeScorer(base_url="http://fake:11434/v1")
|
||||
mock_client = _mock_client(_chat_response("8"))
|
||||
|
||||
with patch("httpx.AsyncClient", return_value=mock_client):
|
||||
score = await scorer.score_async("input", "output", {})
|
||||
|
||||
assert score == pytest.approx((8 - 1) / 9.0)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_perfect_score(self):
|
||||
scorer = LLMJudgeScorer(base_url="http://fake:11434/v1")
|
||||
mock_client = _mock_client(_chat_response("10"))
|
||||
|
||||
with patch("httpx.AsyncClient", return_value=mock_client):
|
||||
score = await scorer.score_async("input", "output", {})
|
||||
|
||||
assert score == pytest.approx(1.0)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lowest_score(self):
|
||||
scorer = LLMJudgeScorer(base_url="http://fake:11434/v1")
|
||||
mock_client = _mock_client(_chat_response("1"))
|
||||
|
||||
with patch("httpx.AsyncClient", return_value=mock_client):
|
||||
score = await scorer.score_async("input", "output", {})
|
||||
|
||||
assert score == pytest.approx(0.0)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unparseable_response(self):
|
||||
scorer = LLMJudgeScorer(base_url="http://fake:11434/v1")
|
||||
mock_client = _mock_client(_chat_response("I cannot rate this"))
|
||||
|
||||
with patch("httpx.AsyncClient", return_value=mock_client):
|
||||
score = await scorer.score_async("input", "output", {})
|
||||
|
||||
assert score == 0.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_choices(self):
|
||||
scorer = LLMJudgeScorer(base_url="http://fake:11434/v1")
|
||||
response = {"choices": []}
|
||||
mock_client = _mock_client(response)
|
||||
|
||||
with patch("httpx.AsyncClient", return_value=mock_client):
|
||||
score = await scorer.score_async("input", "output", {})
|
||||
|
||||
assert score == 0.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_includes_correct_body(self):
|
||||
scorer = LLMJudgeScorer(
|
||||
base_url="http://fake:11434/v1",
|
||||
model="judge-model",
|
||||
judge_prompt="Rate it",
|
||||
)
|
||||
mock_client = _mock_client(_chat_response("7"))
|
||||
|
||||
with patch("httpx.AsyncClient", return_value=mock_client):
|
||||
await scorer.score_async("test input", "test output", {})
|
||||
|
||||
call_kwargs = mock_client.post.call_args
|
||||
body = call_kwargs.kwargs.get("json") or call_kwargs[1].get("json")
|
||||
assert body["model"] == "judge-model"
|
||||
assert body["messages"][0]["role"] == "system"
|
||||
assert body["messages"][0]["content"] == "Rate it"
|
||||
assert body["messages"][1]["role"] == "user"
|
||||
assert body["temperature"] == 0.0
|
||||
assert body["max_tokens"] == 16
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_in_headers(self):
|
||||
scorer = LLMJudgeScorer(
|
||||
base_url="http://fake:11434/v1",
|
||||
api_key="sk-test-key",
|
||||
)
|
||||
mock_client = _mock_client(_chat_response("5"))
|
||||
|
||||
with patch("httpx.AsyncClient", return_value=mock_client) as mock_cls:
|
||||
await scorer.score_async("input", "output", {})
|
||||
|
||||
call_kwargs = mock_cls.call_args
|
||||
headers = call_kwargs.kwargs.get("headers") or call_kwargs[1].get("headers")
|
||||
assert headers["Authorization"] == "Bearer sk-test-key"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reference_included_in_message(self):
|
||||
scorer = LLMJudgeScorer(base_url="http://fake:11434/v1")
|
||||
mock_client = _mock_client(_chat_response("9"))
|
||||
|
||||
with patch("httpx.AsyncClient", return_value=mock_client):
|
||||
await scorer.score_async(
|
||||
"q", "a", {"reference": "expected answer"}
|
||||
)
|
||||
|
||||
call_kwargs = mock_client.post.call_args
|
||||
body = call_kwargs.kwargs.get("json") or call_kwargs[1].get("json")
|
||||
user_msg = body["messages"][1]["content"]
|
||||
assert "expected answer" in user_msg
|
||||
|
||||
|
||||
class TestScoreAsyncWithCache:
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_hit_skips_llm_call(self):
|
||||
mock_cache = MagicMock()
|
||||
cached_response = MagicMock()
|
||||
cached_response.response = "8"
|
||||
mock_cache.get.return_value = cached_response
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session_factory = MagicMock(return_value=mock_session)
|
||||
|
||||
scorer = LLMJudgeScorer(
|
||||
base_url="http://fake:11434/v1",
|
||||
cache_layer=mock_cache,
|
||||
db_session_factory=mock_session_factory,
|
||||
)
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_http:
|
||||
score = await scorer.score_async("input", "output", {})
|
||||
|
||||
assert score == pytest.approx((8 - 1) / 9.0)
|
||||
mock_http.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_miss_calls_llm_and_stores(self):
|
||||
mock_cache = MagicMock()
|
||||
mock_cache.get.return_value = None
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session_factory = MagicMock(return_value=mock_session)
|
||||
|
||||
scorer = LLMJudgeScorer(
|
||||
base_url="http://fake:11434/v1",
|
||||
cache_layer=mock_cache,
|
||||
db_session_factory=mock_session_factory,
|
||||
)
|
||||
mock_client = _mock_client(_chat_response("6"))
|
||||
|
||||
with patch("httpx.AsyncClient", return_value=mock_client):
|
||||
score = await scorer.score_async("input", "output", {})
|
||||
|
||||
assert score == pytest.approx((6 - 1) / 9.0)
|
||||
mock_cache.put.assert_called_once()
|
||||
put_kwargs = mock_cache.put.call_args
|
||||
assert put_kwargs.kwargs.get("response") == "6" or (
|
||||
len(put_kwargs.args) > 2 and put_kwargs.args[2] == "6"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_cache_when_layer_not_configured(self):
|
||||
scorer = LLMJudgeScorer(base_url="http://fake:11434/v1")
|
||||
mock_client = _mock_client(_chat_response("7"))
|
||||
|
||||
with patch("httpx.AsyncClient", return_value=mock_client):
|
||||
score = await scorer.score_async("input", "output", {})
|
||||
|
||||
assert score == pytest.approx((7 - 1) / 9.0)
|
||||
|
||||
|
||||
class TestScoreSync:
|
||||
def test_sync_score_works_outside_async(self):
|
||||
scorer = LLMJudgeScorer(base_url="http://fake:11434/v1")
|
||||
mock_client = _mock_client(_chat_response("5"))
|
||||
|
||||
with patch("httpx.AsyncClient", return_value=mock_client):
|
||||
score = scorer.score("input", "output", {})
|
||||
|
||||
assert score == pytest.approx((5 - 1) / 9.0)
|
||||
|
||||
def test_sync_score_raises_in_async_context(self):
|
||||
scorer = LLMJudgeScorer(base_url="http://fake:11434/v1")
|
||||
|
||||
async def _run():
|
||||
scorer.score("input", "output", {})
|
||||
|
||||
with pytest.raises(RuntimeError, match="cannot be called from an async"):
|
||||
asyncio.get_event_loop().run_until_complete(_run())
|
||||
|
||||
|
||||
class TestRetries:
|
||||
@pytest.mark.asyncio
|
||||
async def test_retries_on_server_error(self):
|
||||
scorer = LLMJudgeScorer(
|
||||
base_url="http://fake:11434/v1", max_retries=3
|
||||
)
|
||||
|
||||
# First two calls fail with 500, third succeeds.
|
||||
fail_resp = MagicMock()
|
||||
fail_resp.status_code = 500
|
||||
fail_resp.json.return_value = {}
|
||||
|
||||
ok_resp = MagicMock()
|
||||
ok_resp.status_code = 200
|
||||
ok_resp.json.return_value = _chat_response("7")
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.side_effect = [fail_resp, fail_resp, ok_resp]
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("httpx.AsyncClient", return_value=mock_client):
|
||||
with patch("asyncio.sleep", new_callable=AsyncMock):
|
||||
score = await scorer.score_async("input", "output", {})
|
||||
|
||||
assert score == pytest.approx((7 - 1) / 9.0)
|
||||
assert mock_client.post.call_count == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_retries_exhausted(self):
|
||||
scorer = LLMJudgeScorer(
|
||||
base_url="http://fake:11434/v1", max_retries=2
|
||||
)
|
||||
|
||||
fail_resp = MagicMock()
|
||||
fail_resp.status_code = 500
|
||||
fail_resp.json.return_value = {}
|
||||
fail_resp.request = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = fail_resp
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("httpx.AsyncClient", return_value=mock_client):
|
||||
with patch("asyncio.sleep", new_callable=AsyncMock):
|
||||
with pytest.raises(RuntimeError, match="All 2 attempts failed"):
|
||||
await scorer.score_async("input", "output", {})
|
||||
Loading…
Add table
Reference in a new issue