promptlooper/backend/engine/scorers/llm_judge.py

232 lines
7.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""LLM-as-judge scorer.
Sends the LLM output to a separate LLM with a configurable judge prompt,
asks for a 1-10 rating, and normalizes to the 0.0-1.0 range.
**This scorer costs tokens** — every evaluation makes an LLM call. The
judge's response is cached via PromptLooper's response cache layer to
avoid redundant calls when re-scoring the same output.
"""
import re
from typing import Any
import httpx
from engine.scorers.base import BaseScorer
# Default judge system prompt — can be overridden at construction time via the
# ``judge_prompt`` argument of ``LLMJudgeScorer``. It instructs the judge to
# reply with a single integer so the reply can be parsed with ``_RATING_RE``.
DEFAULT_JUDGE_PROMPT = (
"You are an impartial evaluator. You will receive an input and an LLM-generated "
"output. Rate the quality of the output on a scale of 1 to 10, where 1 is terrible "
"and 10 is perfect.\n\n"
"Respond with ONLY a single integer between 1 and 10. Do not include any other text."
)
# Regex to extract the first standalone integer in the range 1-10 from the
# judge response. The word boundaries keep it from matching digits that are
# part of a longer run of digits/letters (e.g. "100" or "42" do not match);
# punctuation still counts as a boundary, so "3.5" matches as 3.
_RATING_RE = re.compile(r"\b(10|[1-9])\b")
class LLMJudgeScorer(BaseScorer):
    """Score outputs by asking a separate LLM to rate them 1-10.

    The judge's reply is parsed with ``_parse_rating``, which maps the rating
    linearly onto the 0.0-1.0 range and falls back to 0.0 when the reply
    contains no usable rating.

    Args:
        base_url: Chat completions API base URL.
        model: Model to use for judging.
        api_key: Optional API key.
        judge_prompt: System prompt for the judge LLM.
        timeout: HTTP request timeout in seconds.
        max_retries: Retry attempts on transient failures.
        cache_layer: Optional ``ResponseCacheLayer`` instance. When provided,
            judge responses are cached to avoid duplicate LLM calls.
        db_session_factory: Callable returning a SQLAlchemy ``Session``.
            Required when *cache_layer* is supplied.
    """

    # Marker for the UI so it can warn users about token cost.
    COSTS_TOKENS = True

    def __init__(
        self,
        base_url: str = "http://localhost:11434/v1",
        model: str = "llama3",
        api_key: str | None = None,
        judge_prompt: str = DEFAULT_JUDGE_PROMPT,
        timeout: float = 120.0,
        max_retries: int = 3,
        cache_layer: Any = None,
        db_session_factory: Any = None,
    ) -> None:
        # Strip trailing slashes so the endpoint path can be appended with a
        # single "/" in _call_judge().
        self.base_url = base_url.rstrip("/")
        self.model = model
        self.api_key = api_key
        self.judge_prompt = judge_prompt
        self.timeout = timeout
        self.max_retries = max_retries
        self._cache_layer = cache_layer
        self._db_session_factory = db_session_factory

    @property
    def name(self) -> str:
        """Identifier of this scorer."""
        return "llm_judge"

    # ------------------------------------------------------------------
    # Synchronous entry point
    # ------------------------------------------------------------------
    def score(self, input_data: Any, output: str, context: dict) -> float:
        """Synchronous scoring — delegates to the async variant.

        Raises:
            RuntimeError: If called while an event loop is already running in
                this thread; use ``score_async()`` there instead.
        """
        import asyncio

        try:
            asyncio.get_running_loop()
        except RuntimeError:
            inside_async_context = False
        else:
            inside_async_context = True

        if inside_async_context:
            raise RuntimeError(
                "LLMJudgeScorer.score() cannot be called from an async context. "
                "Use score_async() instead."
            )

        # Drive the coroutine on a private loop. asyncio.get_event_loop() is
        # deprecated when no loop is set and raises in non-main threads; a
        # dedicated loop works in any thread and leaves the thread's current
        # event-loop setting untouched.
        loop = asyncio.new_event_loop()
        try:
            return loop.run_until_complete(
                self.score_async(input_data, output, context)
            )
        finally:
            loop.close()

    # ------------------------------------------------------------------
    # Async entry point
    # ------------------------------------------------------------------
    async def score_async(
        self, input_data: Any, output: str, context: dict
    ) -> float:
        """Ask the judge LLM to rate the output and return a normalised score.

        When both a cache layer and a session factory were supplied, the
        cache is consulted first and a fresh judge response is stored
        afterwards. Each cache access opens its own session and closes it.
        """
        user_message = self._build_user_message(input_data, output, context)

        # Compare against None rather than relying on truthiness, so a cache
        # object that happens to evaluate falsy is still used.
        use_cache = (
            self._cache_layer is not None
            and self._db_session_factory is not None
        )

        # Check cache first.
        config_hash: str | None = None
        if use_cache:
            from engine.cache import compute_config_hash

            config_hash = compute_config_hash(
                prompt=self.judge_prompt,
                model=self.model,
                params={"scorer": "llm_judge"},
                input_data=user_message,
            )
            db = self._db_session_factory()
            try:
                cached = self._cache_layer.get(db, config_hash)
            finally:
                db.close()
            if cached is not None:
                return _parse_rating(cached.response)

        # Call the judge LLM.
        judge_response = await self._call_judge(user_message)

        # Cache the judge response.
        if use_cache and config_hash:
            db = self._db_session_factory()
            try:
                self._cache_layer.put(
                    db,
                    config_hash=config_hash,
                    response=judge_response,
                    model=self.model,
                )
            finally:
                db.close()

        return _parse_rating(judge_response)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _build_user_message(
        self, input_data: Any, output: str, context: dict
    ) -> str:
        """Build the user message sent to the judge LLM.

        The message is made of markdown-style sections separated by blank
        lines: an optional ``## Input``, the ``## Output``, and an optional
        ``## Reference Answer`` taken from ``context["reference"]``.
        """
        parts = []
        if input_data is not None:
            parts.append(f"## Input\n{input_data}")
        parts.append(f"## Output\n{output}")
        # Include reference answer if available — helps the judge compare.
        reference = context.get("reference")
        if reference:
            parts.append(f"## Reference Answer\n{reference}")
        return "\n\n".join(parts)

    async def _call_judge(self, user_message: str) -> str:
        """Send a chat completion request to the judge LLM with retries.

        Returns:
            The stripped text of the first choice, or ``""`` when the
            response has no choices or no textual content.

        Raises:
            httpx.HTTPStatusError: When ``raise_for_status()`` rejects a
                status that is not in the retryable set.
            RuntimeError: When every attempt failed; chained from the last
                recorded error.
        """
        import asyncio

        url = f"{self.base_url}/chat/completions"
        headers: dict[str, str] = {"Content-Type": "application/json"}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        body = {
            "model": self.model,
            "messages": [
                {"role": "system", "content": self.judge_prompt},
                {"role": "user", "content": user_message},
            ],
            "temperature": 0.0,
            # The judge is asked for a single integer, so a tiny budget suffices.
            "max_tokens": 16,
        }
        last_exc: Exception | None = None
        retryable = {429, 500, 502, 503, 504}

        for attempt in range(self.max_retries):
            try:
                async with httpx.AsyncClient(
                    timeout=httpx.Timeout(self.timeout), headers=headers
                ) as client:
                    resp = await client.post(url, json=body)
                    if resp.status_code == 200:
                        data = resp.json()
                        choices = data.get("choices") or []
                        if not choices:
                            return ""
                        # "message" and "content" may be present but null;
                        # treat both as an empty reply instead of crashing
                        # on None.
                        message = choices[0].get("message") or {}
                        content = message.get("content") or ""
                        return content.strip()
                    if resp.status_code not in retryable:
                        resp.raise_for_status()
                    # Reached for retryable statuses, and for any other
                    # status raise_for_status() let through: record the
                    # failure and retry.
                    last_exc = httpx.HTTPStatusError(
                        f"HTTP {resp.status_code}",
                        request=resp.request,
                        response=resp,
                    )
            except httpx.HTTPStatusError:
                # Raised by raise_for_status() above: not retryable.
                raise
            except httpx.HTTPError as exc:
                # Other httpx errors are remembered and retried.
                last_exc = exc

            # Exponential backoff (1s, 2s, 4s, ...) between attempts, but not
            # after the final one.
            if attempt < self.max_retries - 1:
                await asyncio.sleep(2**attempt)

        raise RuntimeError(
            f"All {self.max_retries} attempts failed for judge LLM at {url}"
        ) from last_exc
def _parse_rating(text: str | None) -> float:
    """Extract a 1-10 rating from the judge response and normalise to 0.0-1.0.

    Args:
        text: Raw judge response. ``None`` or an empty string is treated as
            "no rating".

    Returns:
        ``(rating - 1) / 9`` for the first match of ``_RATING_RE`` in *text*,
        or 0.0 if no valid rating is found.
    """
    # Guard against a missing response so the documented 0.0 fallback applies
    # instead of a TypeError from the regex search.
    if not text:
        return 0.0
    match = _RATING_RE.search(text)
    if match is None:
        return 0.0
    rating = int(match.group(1))
    # Normalise: 1 -> 0.0, 10 -> 1.0. The clamp is purely defensive; the
    # pattern already restricts the rating to 1..10.
    return max(0.0, min(1.0, (rating - 1) / 9.0))