Completed slices: - S01: Desire Embedding & Clustering - S02: Fulfillment Flow & Frontend Branch: milestone/M001
366 lines
12 KiB
Python
366 lines
12 KiB
Python
"""Unit tests for the clustering service.
|
|
|
|
Tests use mocked async DB sessions to isolate clustering logic from
|
|
pgvector and database concerns. Synthetic 512-dim vectors verify the
|
|
service's orchestration, heat calculation, and threshold behavior.
|
|
"""
|
|
|
|
import uuid
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from app.models.models import DesireCluster
|
|
from app.services.clustering import (
|
|
add_to_cluster,
|
|
cluster_desire,
|
|
create_cluster,
|
|
find_nearest_cluster,
|
|
recalculate_heat,
|
|
)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _make_embedding(dim: int = 512) -> list[float]:
|
|
"""Create a synthetic embedding vector for testing."""
|
|
import numpy as np
|
|
rng = np.random.default_rng(42)
|
|
vec = rng.standard_normal(dim)
|
|
vec = vec / np.linalg.norm(vec)
|
|
return vec.tolist()
|
|
|
|
|
|
def _mock_result_row(**kwargs):
|
|
"""Create a mock DB result row with named attributes."""
|
|
row = MagicMock()
|
|
for key, value in kwargs.items():
|
|
setattr(row, key, value)
|
|
return row
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Tests: cluster_desire orchestration
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestClusterDesireOrchestration:
    """Exercise ``cluster_desire`` with its collaborators patched out.

    Each test stubs the lower-level service functions so that only the
    orchestration (which branch runs and what dict comes back) is tested.
    """

    @pytest.mark.asyncio
    @patch("app.services.clustering.find_nearest_cluster", new_callable=AsyncMock)
    @patch("app.services.clustering.create_cluster", new_callable=AsyncMock)
    async def test_new_desire_creates_own_cluster(
        self, create_mock, find_mock
    ) -> None:
        """No neighbouring cluster -> a fresh cluster with heat 1.0."""
        desire_id = uuid.uuid4()
        fresh_cluster_id = uuid.uuid4()
        vector = _make_embedding()
        # The nearest-cluster search finds nothing, so creation is expected.
        find_mock.return_value = (None, 0.0)
        create_mock.return_value = fresh_cluster_id
        session = AsyncMock()

        outcome = await cluster_desire(desire_id, vector, session)

        find_mock.assert_awaited_once_with(vector, session)
        create_mock.assert_awaited_once_with(desire_id, session)
        assert outcome["is_new"] is True
        assert outcome["cluster_id"] == fresh_cluster_id
        assert outcome["heat_score"] == 1.0

    @pytest.mark.asyncio
    @patch("app.services.clustering.find_nearest_cluster", new_callable=AsyncMock)
    @patch("app.services.clustering.add_to_cluster", new_callable=AsyncMock)
    @patch("app.services.clustering.recalculate_heat", new_callable=AsyncMock)
    async def test_similar_desire_joins_existing_cluster(
        self, recalc_mock, add_mock, find_mock
    ) -> None:
        """A neighbouring cluster is joined and its heat recomputed."""
        target_cluster = uuid.uuid4()
        desire_id = uuid.uuid4()
        vector = _make_embedding()
        match_similarity = 0.92
        find_mock.return_value = (target_cluster, match_similarity)
        recalc_mock.return_value = 3.0
        session = AsyncMock()

        outcome = await cluster_desire(desire_id, vector, session)

        find_mock.assert_awaited_once_with(vector, session)
        add_mock.assert_awaited_once_with(
            target_cluster, desire_id, match_similarity, session
        )
        recalc_mock.assert_awaited_once_with(target_cluster, session)
        assert outcome["is_new"] is False
        assert outcome["cluster_id"] == target_cluster
        assert outcome["heat_score"] == 3.0

    @pytest.mark.asyncio
    @patch("app.services.clustering.find_nearest_cluster", new_callable=AsyncMock)
    @patch("app.services.clustering.create_cluster", new_callable=AsyncMock)
    async def test_cluster_desire_returns_observability_dict(
        self, create_mock, find_mock
    ) -> None:
        """The result always carries typed cluster_id / is_new / heat_score."""
        expected_types = {
            "cluster_id": uuid.UUID,
            "is_new": bool,
            "heat_score": float,
        }
        find_mock.return_value = (None, 0.0)
        create_mock.return_value = uuid.uuid4()
        session = AsyncMock()

        outcome = await cluster_desire(uuid.uuid4(), _make_embedding(), session)

        for key, expected_type in expected_types.items():
            assert key in outcome
            assert isinstance(outcome[key], expected_type)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Tests: recalculate_heat
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestRecalculateHeat:
    """Check ``recalculate_heat`` against canned COUNT results."""

    @staticmethod
    def _session_with_member_count(member_count: int) -> AsyncMock:
        """Return a mock session whose ``execute`` yields two results in order.

        The first result answers the COUNT query
        (``scalar_one() -> member_count``). The second is a throwaway
        result for the UPDATE.
        """
        count_row = MagicMock()
        count_row.scalar_one.return_value = member_count
        session = AsyncMock()
        session.execute = AsyncMock(side_effect=[count_row, MagicMock()])
        return session

    @pytest.mark.asyncio
    async def test_heat_scales_with_cluster_size(self) -> None:
        """Heat equals the member count (linear); the UPDATE is flushed."""
        session = self._session_with_member_count(3)

        heat = await recalculate_heat(uuid.uuid4(), session)

        assert heat == 3.0
        assert session.execute.await_count == 2
        assert session.flush.await_count >= 1

    @pytest.mark.asyncio
    async def test_heat_for_single_member_cluster(self) -> None:
        """One member -> heat_score of exactly 1.0."""
        session = self._session_with_member_count(1)

        heat = await recalculate_heat(uuid.uuid4(), session)

        assert heat == 1.0

    @pytest.mark.asyncio
    async def test_heat_for_large_cluster(self) -> None:
        """Linear scaling still holds for a larger membership."""
        session = self._session_with_member_count(15)

        heat = await recalculate_heat(uuid.uuid4(), session)

        assert heat == 15.0
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Tests: find_nearest_cluster
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestFindNearestCluster:
    """Cover ``find_nearest_cluster`` by faking the rows its queries return.

    Distance filtering is done by the SQL WHERE clause, so these tests only
    control what the mocked ``execute`` calls hand back. They then check how
    the function interprets those rows.
    """

    @staticmethod
    def _result_with_first(row) -> MagicMock:
        """Return a mock query result whose ``first()`` yields *row*."""
        result = MagicMock()
        result.first.return_value = row
        return result

    @pytest.mark.asyncio
    async def test_empty_db_returns_none(self) -> None:
        """No desires with embeddings -> (None, 0.0)."""
        session = AsyncMock()
        # Every execute returns an empty result.
        session.execute = AsyncMock(return_value=self._result_with_first(None))

        found_id, similarity = await find_nearest_cluster(
            _make_embedding(), session
        )

        assert found_id is None
        assert similarity == 0.0

    @pytest.mark.asyncio
    async def test_match_found_with_cluster(self) -> None:
        """A close desire that belongs to a cluster -> that cluster."""
        desire_id, cluster_id = uuid.uuid4(), uuid.uuid4()
        session = AsyncMock()
        # 1st execute: nearest desire at distance 0.08 (similarity 0.92).
        # 2nd execute: the cluster membership of that desire.
        session.execute = AsyncMock(
            side_effect=[
                self._result_with_first(
                    _mock_result_row(desire_id=desire_id, distance=0.08)
                ),
                self._result_with_first(_mock_result_row(cluster_id=cluster_id)),
            ]
        )

        found_id, sim = await find_nearest_cluster(_make_embedding(), session)

        assert found_id == cluster_id
        assert abs(sim - 0.92) < 1e-6

    @pytest.mark.asyncio
    async def test_match_found_without_cluster(self) -> None:
        """A close desire with no cluster membership -> (None, 0.0)."""
        session = AsyncMock()
        # 1st execute: a nearby desire; 2nd execute: no membership row.
        session.execute = AsyncMock(
            side_effect=[
                self._result_with_first(
                    _mock_result_row(desire_id=uuid.uuid4(), distance=0.10)
                ),
                self._result_with_first(None),
            ]
        )

        found_id, sim = await find_nearest_cluster(_make_embedding(), session)

        assert found_id is None
        assert sim == 0.0

    @pytest.mark.asyncio
    async def test_threshold_boundary_at_0_82(self) -> None:
        """Similarity exactly at the 0.82 threshold still matches.

        A threshold of 0.82 corresponds to a maximum distance of 0.18. A
        desire sitting exactly at distance 0.18 is therefore returned by the
        SQL query (``distance <= 0.18``) and its cluster is used.
        """
        desire_id, cluster_id = uuid.uuid4(), uuid.uuid4()
        session = AsyncMock()
        session.execute = AsyncMock(
            side_effect=[
                # Exactly on the boundary: distance 0.18 -> similarity 0.82.
                self._result_with_first(
                    _mock_result_row(desire_id=desire_id, distance=0.18)
                ),
                self._result_with_first(_mock_result_row(cluster_id=cluster_id)),
            ]
        )

        found_id, sim = await find_nearest_cluster(
            _make_embedding(), session, threshold=0.82
        )

        assert found_id == cluster_id
        assert abs(sim - 0.82) < 1e-6

    @pytest.mark.asyncio
    async def test_below_threshold_returns_none(self) -> None:
        """A desire past the distance limit never reaches the function.

        With threshold=0.82 (max distance 0.18), a desire at distance 0.19
        (similarity 0.81) is dropped by the WHERE clause. The mock models
        that by returning no rows at all.
        """
        session = AsyncMock()
        session.execute = AsyncMock(return_value=self._result_with_first(None))

        found_id, sim = await find_nearest_cluster(
            _make_embedding(), session, threshold=0.82
        )

        assert found_id is None
        assert sim == 0.0
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Tests: create_cluster
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestCreateCluster:
    """Unit tests for ``create_cluster``."""

    @staticmethod
    def _session() -> AsyncMock:
        """Mock async session whose ``add`` is a plain (synchronous) mock.

        ``Session.add()`` is synchronous, so it must not be an awaitable.
        """
        session = AsyncMock()
        session.add = MagicMock()
        return session

    @pytest.mark.asyncio
    async def test_create_cluster_returns_uuid(self) -> None:
        """A new cluster id is a UUID; one row is added and then flushed."""
        session = self._session()

        new_id = await create_cluster(uuid.uuid4(), session)

        assert isinstance(new_id, uuid.UUID)
        session.add.assert_called_once()
        session.flush.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_create_cluster_adds_desire_cluster_row(self) -> None:
        """The added membership row links the desire to itself (sim 1.0)."""
        session = self._session()
        desire_id = uuid.uuid4()

        new_id = await create_cluster(desire_id, session)

        membership = session.add.call_args.args[0]
        assert isinstance(membership, DesireCluster)
        assert membership.cluster_id == new_id
        assert membership.desire_id == desire_id
        assert membership.similarity == 1.0
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Tests: add_to_cluster
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestAddToCluster:
    """Test adding a desire to an existing cluster."""

    @pytest.mark.asyncio
    async def test_add_to_cluster_executes_insert(self) -> None:
        """One insert is awaited with the expected bind parameters, then flushed."""
        db = AsyncMock()
        cluster_id = uuid.uuid4()
        desire_id = uuid.uuid4()

        await add_to_cluster(cluster_id, desire_id, 0.91, db)

        db.execute.assert_awaited_once()
        db.flush.assert_awaited()

        # ``add_to_cluster`` is expected to call ``db.execute(statement, params)``.
        # The bind-parameter mapping is the second *positional* argument, not a
        # keyword argument. ``await_args`` records the awaited call of the
        # AsyncMock.
        params = db.execute.await_args.args[1]
        assert params["cluster_id"] == cluster_id
        assert params["desire_id"] == desire_id
        assert params["similarity"] == 0.91
|