229 lines
8.6 KiB
Python
229 lines
8.6 KiB
Python
"""Tests for the EmbeddingScorer."""
|
|
|
|
import asyncio
|
|
import math
|
|
from typing import Any
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from engine.scorers.base import BaseScorer
|
|
from engine.scorers.embedding import EmbeddingScorer, _cosine_similarity
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _make_embedding_response(embeddings: list[list[float]]) -> dict:
|
|
"""Build a fake OpenAI-compatible /embeddings response."""
|
|
return {
|
|
"object": "list",
|
|
"data": [
|
|
{"object": "embedding", "index": i, "embedding": emb}
|
|
for i, emb in enumerate(embeddings)
|
|
],
|
|
"model": "nomic-embed-text",
|
|
"usage": {"prompt_tokens": 10, "total_tokens": 10},
|
|
}
|
|
|
|
|
|
def _mock_client(response_data: dict) -> tuple[MagicMock, AsyncMock]:
|
|
"""Create a mocked httpx.AsyncClient that returns *response_data*.
|
|
|
|
Returns (mock_client_cls, mock_client) so tests can inspect calls.
|
|
"""
|
|
mock_response = MagicMock()
|
|
mock_response.status_code = 200
|
|
mock_response.json.return_value = response_data
|
|
mock_response.raise_for_status = MagicMock()
|
|
|
|
mock_client = AsyncMock()
|
|
mock_client.post.return_value = mock_response
|
|
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
|
mock_client.__aexit__ = AsyncMock(return_value=False)
|
|
|
|
return mock_client, mock_response
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Unit tests for cosine similarity helper
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestCosineSimilarity:
    """Direct checks of ``_cosine_similarity`` on hand-picked vectors."""

    def test_identical_vectors(self):
        """A vector compared with an equal vector has similarity 1."""
        similarity = _cosine_similarity([1, 2, 3], [1, 2, 3])
        assert similarity == pytest.approx(1.0)

    def test_opposite_vectors(self):
        """Antiparallel vectors have similarity -1."""
        similarity = _cosine_similarity([1, 0], [-1, 0])
        assert similarity == pytest.approx(-1.0)

    def test_orthogonal_vectors(self):
        """Perpendicular vectors have similarity 0."""
        similarity = _cosine_similarity([1, 0], [0, 1])
        assert similarity == pytest.approx(0.0)

    def test_zero_vector(self):
        """A zero-length first argument yields exactly 0.0."""
        similarity = _cosine_similarity([0, 0], [1, 2])
        assert similarity == 0.0

    def test_both_zero(self):
        """Two zero vectors yield exactly 0.0."""
        similarity = _cosine_similarity([0, 0], [0, 0])
        assert similarity == 0.0

    def test_known_angle(self):
        """Vectors 45 degrees apart give cos(45°) = 1/sqrt(2) ≈ 0.7071."""
        x_axis = [1.0, 0.0]
        diagonal = [1.0, 1.0]
        cos_45 = 1.0 / math.sqrt(2)
        similarity = _cosine_similarity(x_axis, diagonal)
        assert similarity == pytest.approx(cos_45, abs=1e-6)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# EmbeddingScorer tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestEmbeddingScorerInterface:
    """Construction and static attributes of ``EmbeddingScorer``."""

    def test_is_base_scorer(self):
        """``EmbeddingScorer`` instances satisfy the ``BaseScorer`` interface."""
        assert isinstance(EmbeddingScorer(), BaseScorer)

    def test_name(self):
        """The scorer registers itself under the name ``"embedding"``."""
        assert EmbeddingScorer().name == "embedding"

    def test_custom_params(self):
        """Every constructor keyword is stored on the matching attribute."""
        expected = {
            "base_url": "https://api.example.com/v1",
            "model": "text-embedding-3-small",
            "api_key": "sk-test",
            "timeout": 60.0,
        }
        scorer = EmbeddingScorer(**expected)
        for attr, value in expected.items():
            assert getattr(scorer, attr) == value

    def test_base_url_strips_trailing_slash(self):
        """A trailing ``/`` on ``base_url`` is removed."""
        with_slash = "http://localhost:11434/v1/"
        scorer = EmbeddingScorer(base_url=with_slash)
        assert scorer.base_url == "http://localhost:11434/v1"
|
|
|
|
|
|
class TestEmbeddingScorerAsync:
    """Behavioral tests for ``EmbeddingScorer.score_async``.

    The HTTP layer is replaced by patching ``httpx.AsyncClient`` inside
    ``engine.scorers.embedding``. Each test feeds a canned embeddings payload.

    Coroutines are driven with ``asyncio.run``, so every test gets a fresh
    event loop. The previous ``asyncio.get_event_loop().run_until_complete``
    relied on an implicit loop, which is deprecated when none is set.
    """

    def test_no_reference_returns_zero(self):
        """Without a ``reference`` key in the third argument the score is 0.0."""
        scorer = EmbeddingScorer()
        result = asyncio.run(scorer.score_async("input", "output", {}))
        assert result == 0.0

    def test_empty_reference_returns_zero(self):
        """An empty ``reference`` string yields a score of 0.0."""
        scorer = EmbeddingScorer()
        result = asyncio.run(
            scorer.score_async("input", "output", {"reference": ""})
        )
        assert result == 0.0

    @patch("engine.scorers.embedding.httpx.AsyncClient")
    def test_identical_texts_score_high(self, mock_client_cls):
        """When output equals reference, embeddings should be identical → score ~1.0."""
        emb = [0.1, 0.2, 0.3, 0.4]
        client, _ = _mock_client(_make_embedding_response([emb, emb]))
        mock_client_cls.return_value = client

        scorer = EmbeddingScorer()
        result = asyncio.run(
            scorer.score_async("input", "hello world", {"reference": "hello world"})
        )
        assert result == pytest.approx(1.0)

    @patch("engine.scorers.embedding.httpx.AsyncClient")
    def test_opposite_embeddings_score_zero(self, mock_client_cls):
        """Opposite embeddings → cosine sim = -1 → normalized to 0.0."""
        client, _ = _mock_client(_make_embedding_response([[1.0, 0.0], [-1.0, 0.0]]))
        mock_client_cls.return_value = client

        scorer = EmbeddingScorer()
        result = asyncio.run(
            scorer.score_async("input", "good", {"reference": "bad"})
        )
        assert result == pytest.approx(0.0)

    @patch("engine.scorers.embedding.httpx.AsyncClient")
    def test_orthogonal_embeddings_score_half(self, mock_client_cls):
        """Orthogonal embeddings → cosine sim = 0 → normalized to 0.5."""
        client, _ = _mock_client(_make_embedding_response([[1.0, 0.0], [0.0, 1.0]]))
        mock_client_cls.return_value = client

        scorer = EmbeddingScorer()
        result = asyncio.run(
            scorer.score_async("input", "alpha", {"reference": "beta"})
        )
        assert result == pytest.approx(0.5)

    @patch("engine.scorers.embedding.httpx.AsyncClient")
    def test_api_key_sent_in_headers(self, mock_client_cls):
        """When api_key is set, Authorization header must be sent."""
        emb = [0.5, 0.5]
        client, _ = _mock_client(_make_embedding_response([emb, emb]))
        mock_client_cls.return_value = client

        scorer = EmbeddingScorer(api_key="sk-test-key")
        asyncio.run(scorer.score_async("input", "text", {"reference": "ref"}))

        # Headers are expected as a keyword argument to the AsyncClient
        # constructor (the patched class), not to the individual post call.
        call_kwargs = mock_client_cls.call_args
        headers = call_kwargs.kwargs.get("headers", {})
        assert headers.get("Authorization") == "Bearer sk-test-key"

    @patch("engine.scorers.embedding.httpx.AsyncClient")
    def test_correct_url_and_body(self, mock_client_cls):
        """Verify the embedding API is called with correct URL and body."""
        emb = [0.1]
        client, _ = _mock_client(_make_embedding_response([emb, emb]))
        mock_client_cls.return_value = client

        scorer = EmbeddingScorer(
            base_url="http://myhost:1234/v1",
            model="my-model",
        )
        asyncio.run(
            scorer.score_async("input", "output text", {"reference": "ref text"})
        )

        client.post.assert_called_once()
        call_args = client.post.call_args
        assert call_args.args[0] == "http://myhost:1234/v1/embeddings"
        body = call_args.kwargs.get("json", {})
        assert body["model"] == "my-model"
        # Output is embedded first, reference second, in a single batch request.
        assert body["input"] == ["output text", "ref text"]

    @patch("engine.scorers.embedding.httpx.AsyncClient")
    def test_out_of_order_index_handling(self, mock_client_cls):
        """A response whose ``data`` entries are not in ``index`` order is accepted.

        NOTE: cosine similarity is symmetric, so swapping the two embeddings
        cannot change the score. This test therefore only checks that an
        out-of-order payload is handled without error and scored correctly.
        It cannot prove that the scorer re-sorts entries by ``index``.
        """
        response_data = {
            "object": "list",
            "data": [
                {"object": "embedding", "index": 1, "embedding": [0.0, 1.0]},
                {"object": "embedding", "index": 0, "embedding": [1.0, 0.0]},
            ],
            "model": "nomic-embed-text",
        }
        client, _ = _mock_client(response_data)
        mock_client_cls.return_value = client

        scorer = EmbeddingScorer()
        result = asyncio.run(
            scorer.score_async("input", "out", {"reference": "ref"})
        )
        # The two vectors are orthogonal in either order → 0.5.
        assert result == pytest.approx(0.5)

    @patch("engine.scorers.embedding.httpx.AsyncClient")
    def test_score_clamped_to_unit_range(self, mock_client_cls):
        """Result is always clamped to [0.0, 1.0]."""
        # Identical non-normalized vectors: floating-point rounding can push
        # the raw similarity marginally above 1, exercising the clamp.
        emb = [1.0, 1.0, 1.0]
        client, _ = _mock_client(_make_embedding_response([emb, emb]))
        mock_client_cls.return_value = client

        scorer = EmbeddingScorer()
        result = asyncio.run(
            scorer.score_async("input", "text", {"reference": "text"})
        )
        assert 0.0 <= result <= 1.0
|