Completed slices: - S01: Desire Embedding & Clustering - S02: Fulfillment Flow & Frontend Branch: milestone/M001
137 lines
4.7 KiB
Python
137 lines
4.7 KiB
Python
"""Unit tests for the text embedding service.
|
|
|
|
Validates that TF-IDF + TruncatedSVD produces 512-dim L2-normalized vectors
|
|
with meaningful cosine similarity for shader/visual-art domain text.
|
|
"""
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
from app.services.embedding import EmbeddingService, embed_text
|
|
|
|
|
|
def cosine_sim(a: list[float], b: list[float]) -> float:
    """Compute the cosine similarity between two vectors.

    The embeddings under test are expected to be L2-normalized, in which case
    this reduces to a plain dot product. The explicit division by the norms is
    deliberate: with a bare dot product, a scaled copy of a vector would score
    above 1.0 and could let similarity assertions pass for the wrong reason if
    normalization ever regressed.

    Args:
        a: First vector.
        b: Second vector, same length as ``a``.

    Returns:
        ``dot(a, b) / (||a|| * ||b||)`` as a Python float, in [-1, 1].

    Raises:
        ValueError: If either vector has zero norm (cosine similarity is
            undefined there; silently returning 0.0 would satisfy
            "low similarity" assertions).
    """
    va = np.asarray(a, dtype=float)
    vb = np.asarray(b, dtype=float)
    denom = float(np.linalg.norm(va) * np.linalg.norm(vb))
    if denom == 0.0:
        raise ValueError("cosine similarity is undefined for zero-norm vectors")
    return float(np.dot(va, vb) / denom)
|
|
|
|
|
|
class TestEmbedDimension:
    """Checks on the length and element type of a single embedding."""

    def test_embed_produces_512_dim_vector(self) -> None:
        vector = embed_text("particle system fluid simulation")
        dims = len(vector)
        assert dims == 512, f"Expected 512 dims, got {dims}"

    def test_embed_returns_list_of_floats(self) -> None:
        vector = embed_text("fractal noise pattern")
        assert isinstance(vector, list)
        # Collect offenders instead of a bare all() so a failure shows them.
        non_floats = [value for value in vector if not isinstance(value, float)]
        assert not non_floats
|
|
|
|
|
|
class TestNormalization:
    """Checks that embeddings come back with unit L2 norm."""

    def test_embed_vectors_are_normalized(self) -> None:
        vec = embed_text("raymarching distance field shapes")
        length = np.linalg.norm(vec)
        assert abs(length - 1.0) < 1e-6, f"Expected norm ≈ 1.0, got {length}"

    def test_various_inputs_all_normalized(self) -> None:
        # Inputs of very different lengths, including a single word.
        samples = (
            "short",
            "a much longer description of a complex visual effect with many words",
            "ragdoll physics dark moody atmosphere simulation",
        )
        for sample in samples:
            length = np.linalg.norm(embed_text(sample))
            assert abs(length - 1.0) < 1e-6, (
                f"Norm for '{sample}' = {length}, expected ≈ 1.0"
            )
|
|
|
|
|
|
class TestSimilarity:
    """Checks that cosine similarity between embeddings tracks text meaning."""

    @staticmethod
    def _pair_similarity(first: str, second: str) -> float:
        """Embed ``first`` then ``second`` and return their ``cosine_sim``."""
        return cosine_sim(embed_text(first), embed_text(second))

    def test_similar_texts_have_high_cosine_similarity(self) -> None:
        score = self._pair_similarity(
            "ragdoll physics dark and slow",
            "dark physics simulation ragdoll",
        )
        assert score > 0.8, (
            f"Similar texts should have >0.8 cosine sim, got {score:.4f}"
        )

    def test_dissimilar_texts_have_low_cosine_similarity(self) -> None:
        score = self._pair_similarity(
            "ragdoll physics dark",
            "bright colorful kaleidoscope flowers",
        )
        assert score < 0.5, (
            f"Dissimilar texts should have <0.5 cosine sim, got {score:.4f}"
        )

    def test_identical_texts_have_perfect_similarity(self) -> None:
        phrase = "procedural noise fractal generation"
        # Embedding the same phrase twice must be repeatable.
        score = self._pair_similarity(phrase, phrase)
        assert score > 0.999, (
            f"Identical texts should have ~1.0 cosine sim, got {score:.4f}"
        )
|
|
|
|
|
|
class TestBatch:
    """Verify batch embedding matches individual embeddings.

    Besides matching the single-text path, every batched vector must have
    the expected dimension and unit L2 norm.
    """

    def test_embed_batch_matches_individual(self) -> None:
        texts = [
            "particle system fluid",
            "ragdoll physics dark moody",
            "kaleidoscope symmetry rotation",
        ]

        # Fresh service to ensure deterministic results
        service = EmbeddingService()
        individual = [service.embed_text(t) for t in texts]

        # Reset and do batch on a separate fresh instance
        service2 = EmbeddingService()
        batched = service2.embed_batch(texts)

        assert len(batched) == len(individual)
        for i, (ind, bat) in enumerate(zip(individual, batched)):
            # Check the unit-norm contract explicitly: a similarity score
            # alone does not prove the batch path normalizes its output.
            norm = np.linalg.norm(bat)
            assert abs(norm - 1.0) < 1e-6, (
                f"Batch result {i} is not L2-normalized: norm={norm}"
            )
            sim = cosine_sim(ind, bat)
            assert sim > 0.999, (
                f"Batch result {i} doesn't match individual: sim={sim:.6f}"
            )

    def test_batch_dimensions(self) -> None:
        texts = ["fire smoke volumetric", "crystal refraction light"]
        results = EmbeddingService().embed_batch(texts)
        # One vector per input, rather than a hard-coded count.
        assert len(results) == len(texts)
        for vec in results:
            assert len(vec) == 512
|
|
|
|
|
|
class TestErrorHandling:
    """Blank input must be rejected with a descriptive ValueError."""

    # Fragment every rejection message is expected to contain.
    _BLANK_MSG = "empty or whitespace"

    def test_empty_string_raises_valueerror(self) -> None:
        with pytest.raises(ValueError, match=self._BLANK_MSG):
            embed_text("")

    def test_whitespace_only_raises_valueerror(self) -> None:
        with pytest.raises(ValueError, match=self._BLANK_MSG):
            embed_text(" \n\t ")

    def test_batch_with_empty_string_raises_valueerror(self) -> None:
        # Construct outside the raises-context so only embed_batch is tested.
        embedder = EmbeddingService()
        with pytest.raises(ValueError, match=self._BLANK_MSG):
            embedder.embed_batch(["valid text", ""])

    def test_batch_with_whitespace_raises_valueerror(self) -> None:
        embedder = EmbeddingService()
        with pytest.raises(ValueError, match=self._BLANK_MSG):
            embedder.embed_batch([" ", "valid text"])
|