232 lines
7.9 KiB
Python
232 lines
7.9 KiB
Python
"""LLM-as-judge scorer.
|
||
|
||
Sends the LLM output to a separate LLM with a configurable judge prompt,
|
||
asks for a 1–10 rating, and normalizes to the 0.0–1.0 range.
|
||
|
||
**This scorer costs tokens** — every evaluation makes an LLM call. The
|
||
judge's response is cached via PromptLooper's response cache layer to
|
||
avoid redundant calls when re-scoring the same output.
|
||
"""
|
||
|
||
import re
|
||
from typing import Any
|
||
|
||
import httpx
|
||
|
||
from engine.scorers.base import BaseScorer
|
||
|
||
|
||
# System prompt sent to the judge model unless a custom ``judge_prompt`` is
# passed to ``LLMJudgeScorer``. It asks for a bare integer so that the reply
# can be parsed with ``_RATING_RE``.
DEFAULT_JUDGE_PROMPT = (
    "You are an impartial evaluator. "
    "You will receive an input and an LLM-generated output. "
    "Rate the quality of the output on a scale of 1 to 10, "
    "where 1 is terrible and 10 is perfect."
    "\n\n"
    "Respond with ONLY a single integer between 1 and 10. "
    "Do not include any other text."
)

# Finds the first standalone integer in the range 1–10. The ``\b`` anchors
# reject digits embedded in longer numbers (e.g. "42" or "100" do not match),
# while a decimal such as "7.5" yields its integer part ("7").
_RATING_RE = re.compile(r"\b(10|[1-9])\b")
|
||
|
||
|
||
class LLMJudgeScorer(BaseScorer):
    """Score outputs by asking a separate LLM to rate them 1–10.

    The judge is called through an OpenAI-style ``/chat/completions``
    endpoint. Its reply is reduced to a 0.0–1.0 score by ``_parse_rating``.

    Args:
        base_url: Chat completions API base URL (a trailing ``/`` is stripped).
        model: Model to use for judging.
        api_key: Optional API key, sent as ``Authorization: Bearer <key>``.
        judge_prompt: System prompt for the judge LLM.
        timeout: HTTP request timeout in seconds.
        max_retries: Number of attempts made on transient failures
            (transport errors and HTTP 429/500/502/503/504). Values below 1
            still make a single attempt.
        cache_layer: Optional ``ResponseCacheLayer`` instance. When provided
            together with *db_session_factory*, judge responses are cached to
            avoid duplicate LLM calls.
        db_session_factory: Callable returning a SQLAlchemy ``Session``.
            Required when *cache_layer* is supplied.
    """

    # Marker for the UI so it can warn users about token cost.
    COSTS_TOKENS = True

    # HTTP status codes treated as transient and therefore retried.
    _RETRYABLE_STATUS = frozenset({429, 500, 502, 503, 504})

    def __init__(
        self,
        base_url: str = "http://localhost:11434/v1",
        model: str = "llama3",
        api_key: str | None = None,
        judge_prompt: str = DEFAULT_JUDGE_PROMPT,
        timeout: float = 120.0,
        max_retries: int = 3,
        cache_layer: Any = None,
        db_session_factory: Any = None,
    ) -> None:
        self.base_url = base_url.rstrip("/")
        self.model = model
        self.api_key = api_key
        self.judge_prompt = judge_prompt
        self.timeout = timeout
        self.max_retries = max_retries
        self._cache_layer = cache_layer
        self._db_session_factory = db_session_factory

    @property
    def name(self) -> str:
        """Return the scorer's name, ``"llm_judge"``."""
        return "llm_judge"

    # ------------------------------------------------------------------
    # Synchronous entry point
    # ------------------------------------------------------------------

    def score(self, input_data: Any, output: str, context: dict) -> float:
        """Synchronous scoring — runs :meth:`score_async` on a fresh event loop.

        Raises:
            RuntimeError: If called while an event loop is already running in
                this thread; use :meth:`score_async` there instead.
        """
        import asyncio

        try:
            asyncio.get_running_loop()
        except RuntimeError:
            pass  # No running loop in this thread, so it is safe to start one.
        else:
            raise RuntimeError(
                "LLMJudgeScorer.score() cannot be called from an async context. "
                "Use score_async() instead."
            )

        # asyncio.run() creates and closes its own loop. This avoids relying on
        # get_event_loop() implicitly creating one, which is deprecated.
        return asyncio.run(self.score_async(input_data, output, context))

    # ------------------------------------------------------------------
    # Async entry point
    # ------------------------------------------------------------------

    async def score_async(
        self, input_data: Any, output: str, context: dict
    ) -> float:
        """Ask the judge LLM to rate the output and return a normalised score.

        When both a cache layer and a session factory are configured, a cached
        judge reply for the same prompt/model/message is reused instead of
        calling the LLM, and new non-empty replies are stored.

        Returns:
            A score in 0.0–1.0; 0.0 if the judge reply contains no 1–10 rating.
        """
        user_message = self._build_user_message(input_data, output, context)

        use_cache = bool(self._cache_layer and self._db_session_factory)
        config_hash: str | None = None

        # Check cache first.
        if use_cache:
            from engine.cache import compute_config_hash

            config_hash = compute_config_hash(
                prompt=self.judge_prompt,
                model=self.model,
                params={"scorer": "llm_judge"},
                input_data=user_message,
            )
            db = self._db_session_factory()
            try:
                cached = self._cache_layer.get(db, config_hash)
                if cached is not None:
                    # Read the response while the session is still open.
                    return _parse_rating(cached.response)
            finally:
                db.close()

        # Call the judge LLM.
        judge_response = await self._call_judge(user_message)

        # Cache the judge response. Empty replies (no choices/content) are
        # skipped: caching one would pin this output's score at 0.0 forever.
        if use_cache and config_hash and judge_response:
            db = self._db_session_factory()
            try:
                self._cache_layer.put(
                    db,
                    config_hash=config_hash,
                    response=judge_response,
                    model=self.model,
                )
            finally:
                db.close()

        return _parse_rating(judge_response)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _build_user_message(
        self, input_data: Any, output: str, context: dict
    ) -> str:
        """Build the user message sent to the judge LLM.

        The message contains markdown-headed sections for the input (omitted
        when ``None``), the output, and ``context["reference"]`` when present
        and truthy.
        """
        parts = []
        if input_data is not None:
            parts.append(f"## Input\n{input_data}")
        parts.append(f"## Output\n{output}")

        # Include reference answer if available — helps the judge compare.
        reference = (context or {}).get("reference")
        if reference:
            parts.append(f"## Reference Answer\n{reference}")

        return "\n\n".join(parts)

    @staticmethod
    def _extract_content(data: Any) -> str:
        """Return the stripped content of the first choice, or ``""``.

        Tolerates a missing or empty ``choices`` list and a missing or ``null``
        ``message``/``content``. OpenAI-compatible APIs may return
        ``"content": null``.
        """
        choices = data.get("choices") or []
        if not choices:
            return ""
        content = (choices[0].get("message") or {}).get("content")
        return content.strip() if isinstance(content, str) else ""

    async def _call_judge(self, user_message: str) -> str:
        """Send a chat completion request to the judge LLM with retries.

        Transport errors and HTTP 429/500/502/503/504 are retried with
        exponential back-off (1s, 2s, 4s, ...). Any other non-2xx status is
        raised immediately.

        Returns:
            The stripped content of the first choice, or ``""`` if the
            response carries no choices or no message content.

        Raises:
            httpx.HTTPStatusError: On a non-retryable error status.
            RuntimeError: When every attempt fails; chained from the last error.
        """
        import asyncio

        url = f"{self.base_url}/chat/completions"

        headers: dict[str, str] = {"Content-Type": "application/json"}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"

        body = {
            "model": self.model,
            "messages": [
                {"role": "system", "content": self.judge_prompt},
                {"role": "user", "content": user_message},
            ],
            "temperature": 0.0,
            "max_tokens": 16,
        }

        # Always make at least one request, even if max_retries <= 0.
        attempts = max(1, self.max_retries)
        last_exc: Exception | None = None

        # One client (and connection pool) is shared by all attempts.
        async with httpx.AsyncClient(
            timeout=httpx.Timeout(self.timeout), headers=headers
        ) as client:
            for attempt in range(attempts):
                try:
                    resp = await client.post(url, json=body)
                except httpx.HTTPError as exc:
                    # Transport-level failure (connect/read timeout, etc.).
                    last_exc = exc
                else:
                    if resp.is_success:
                        return self._extract_content(resp.json())
                    if resp.status_code not in self._RETRYABLE_STATUS:
                        # Client errors etc. will not fix themselves, so fail fast.
                        resp.raise_for_status()
                    last_exc = httpx.HTTPStatusError(
                        f"HTTP {resp.status_code}",
                        request=resp.request,
                        response=resp,
                    )

                if attempt < attempts - 1:
                    await asyncio.sleep(2**attempt)

        raise RuntimeError(
            f"All {attempts} attempts failed for judge LLM at {url}"
        ) from last_exc
|
||
|
||
|
||
def _parse_rating(text: str | None) -> float:
    """Extract a 1–10 rating from the judge response and normalise to 0.0–1.0.

    The first standalone integer between 1 and 10 in *text* is used
    (e.g. ``"8/10"`` gives 8). The rating maps linearly: 1 → 0.0, 10 → 1.0.

    Falls back to 0.0 if *text* is empty or ``None``, or if no valid rating
    is found.
    """
    if not text:
        return 0.0
    match = _RATING_RE.search(text)
    if match is None:
        return 0.0
    rating = int(match.group(1))
    # Normalise: 1 → 0.0, 10 → 1.0. The clamp is defensive; the regex only
    # admits 1–10.
    return max(0.0, min(1.0, (rating - 1) / 9.0))
|