test: Added 6 integration tests covering creator-scoped cascade tiers (…

- "backend/tests/test_search.py"

GSD-Task: S02/T02
This commit is contained in:
jlightner 2026-04-04 05:07:24 +00:00
parent 2568dc3812
commit 9530c85b9c

View file

@ -966,3 +966,284 @@ async def test_search_lightrag_fallback_on_http_error(db_engine):
assert result["fallback_used"] is True assert result["fallback_used"] is True
assert result["total"] >= 1 assert result["total"] >= 1
# ── Creator-scoped cascade integration tests ─────────────────────────────────
async def _seed_cascade_data(db_engine) -> dict:
"""Seed creators and technique pages for cascade tier testing.
Creator 'keota' has 3 Sound Design pages (2 domain='Sound Design').
Creator 'virtual-riot' has 1 Synthesis page (< 2 no dominant domain).
"""
session_factory = async_sessionmaker(
db_engine, class_=AsyncSession, expire_on_commit=False
)
async with session_factory() as session:
keota = Creator(
name="Keota",
slug="keota",
genres=["Bass music"],
folder_name="Keota",
)
vr = Creator(
name="Virtual Riot",
slug="virtual-riot",
genres=["Dubstep"],
folder_name="VirtualRiot",
)
session.add_all([keota, vr])
await session.flush()
tp1 = TechniquePage(
creator_id=keota.id,
title="Reese Bass Fundamentals",
slug="reese-bass-fundamentals",
topic_category="Sound Design",
topic_tags=["bass", "reese"],
summary="Fundamentals of reese bass",
)
tp2 = TechniquePage(
creator_id=keota.id,
title="FM Sound Design",
slug="fm-sound-design",
topic_category="Sound Design",
topic_tags=["fm", "design"],
summary="FM sound design techniques",
)
tp3 = TechniquePage(
creator_id=keota.id,
title="Granular Textures",
slug="granular-textures",
topic_category="Sound Design",
topic_tags=["granular"],
summary="Granular texture design",
)
tp4 = TechniquePage(
creator_id=vr.id,
title="Serum Wavetable Tricks",
slug="serum-wavetable-tricks",
topic_category="Synthesis",
topic_tags=["serum", "wavetable"],
summary="Advanced Serum wavetable tricks",
)
session.add_all([tp1, tp2, tp3, tp4])
await session.commit()
return {
"keota_id": str(keota.id),
"keota_name": keota.name,
"keota_slug": keota.slug,
"vr_id": str(vr.id),
"vr_name": vr.name,
"tp1_slug": tp1.slug,
"tp2_slug": tp2.slug,
"tp3_slug": tp3.slug,
"tp4_slug": tp4.slug,
}
def _cascade_lightrag_body(chunks: list[dict]) -> dict:
"""Build a LightRAG /query/data response with given chunks."""
return {
"data": {
"chunks": chunks,
"entities": [],
"relationships": [],
}
}
def _chunk(slug: str, creator_id: str, content: str = "chunk content") -> dict:
return {
"content": content,
"file_path": f"technique:{slug}:creator:{creator_id}",
}
@pytest.mark.asyncio
async def test_search_cascade_creator_tier(db_engine):
"""Tier 1: creator-scoped search returns results → cascade_tier='creator'."""
seed = await _seed_cascade_data(db_engine)
session_factory = async_sessionmaker(
db_engine, class_=AsyncSession, expire_on_commit=False
)
async with session_factory() as session:
from config import Settings
svc = SearchService(settings=Settings())
# httpx returns chunks matching keota's technique pages
body = _cascade_lightrag_body([
_chunk(seed["tp1_slug"], seed["keota_id"], "Reese bass fundamentals"),
])
mock_resp = _mock_httpx_response(body)
svc._httpx = AsyncMock()
svc._httpx.post = AsyncMock(return_value=mock_resp)
svc.embed_query = AsyncMock(return_value=None)
result = await svc.search("reese bass", "all", 10, session, creator="keota")
assert result["cascade_tier"] == "creator"
assert result["fallback_used"] is False
assert result["total"] >= 1
# All cascade items belong to keota
cascade_items = [i for i in result["items"] if i.get("creator_slug") == "keota"]
assert len(cascade_items) >= 1
@pytest.mark.asyncio
async def test_search_cascade_domain_tier(db_engine):
"""Tier 2: creator-scoped empty → domain-scoped returns results → cascade_tier='domain'."""
seed = await _seed_cascade_data(db_engine)
session_factory = async_sessionmaker(
db_engine, class_=AsyncSession, expire_on_commit=False
)
async with session_factory() as session:
from config import Settings
svc = SearchService(settings=Settings())
# Call 1 (creator-scoped): returns chunks for a DIFFERENT creator → post-filter removes them
creator_body = _cascade_lightrag_body([
_chunk(seed["tp4_slug"], seed["vr_id"], "VR content not Keota"),
])
# Call 2 (domain-scoped with "Sound Design"): returns chunks matching Keota
domain_body = _cascade_lightrag_body([
_chunk(seed["tp1_slug"], seed["keota_id"], "Reese bass from domain"),
])
call_count = 0
async def _side_effect(*args, **kwargs):
nonlocal call_count
call_count += 1
if call_count == 1:
return _mock_httpx_response(creator_body)
else:
return _mock_httpx_response(domain_body)
svc._httpx = AsyncMock()
svc._httpx.post = AsyncMock(side_effect=_side_effect)
svc.embed_query = AsyncMock(return_value=None)
result = await svc.search("synthesis techniques", "all", 10, session, creator="keota")
assert result["cascade_tier"] == "domain"
assert result["fallback_used"] is False
assert result["total"] >= 1
@pytest.mark.asyncio
async def test_search_cascade_global_fallback(db_engine):
"""Tier 3: creator + domain empty → global LightRAG returns → cascade_tier='global'."""
seed = await _seed_cascade_data(db_engine)
session_factory = async_sessionmaker(
db_engine, class_=AsyncSession, expire_on_commit=False
)
async with session_factory() as session:
from config import Settings
svc = SearchService(settings=Settings())
# Calls 1-2 (creator + domain): empty chunks
empty_body = _cascade_lightrag_body([])
# Call 3 (global _lightrag_search): returns results
global_body = _cascade_lightrag_body([
_chunk(seed["tp4_slug"], seed["vr_id"], "Global result"),
])
call_count = 0
async def _side_effect(*args, **kwargs):
nonlocal call_count
call_count += 1
if call_count <= 2:
return _mock_httpx_response(empty_body)
else:
return _mock_httpx_response(global_body)
svc._httpx = AsyncMock()
svc._httpx.post = AsyncMock(side_effect=_side_effect)
svc.embed_query = AsyncMock(return_value=None)
result = await svc.search("mixing tips", "all", 10, session, creator="keota")
assert result["cascade_tier"] == "global"
assert result["fallback_used"] is False
assert result["total"] >= 1
@pytest.mark.asyncio
async def test_search_cascade_graceful_empty(db_engine):
"""Tier 4: all tiers empty → cascade_tier='none', fallback_used=True."""
seed = await _seed_cascade_data(db_engine)
session_factory = async_sessionmaker(
db_engine, class_=AsyncSession, expire_on_commit=False
)
async with session_factory() as session:
from config import Settings
svc = SearchService(settings=Settings())
# All calls return empty chunks
empty_body = _cascade_lightrag_body([])
svc._httpx = AsyncMock()
svc._httpx.post = AsyncMock(return_value=_mock_httpx_response(empty_body))
svc.embed_query = AsyncMock(return_value=None)
result = await svc.search("nonexistent topic xyz", "all", 10, session, creator="keota")
assert result["cascade_tier"] == "none"
assert result["fallback_used"] is True
@pytest.mark.asyncio
async def test_search_cascade_unknown_creator(db_engine):
"""Unknown creator slug → cascade skipped, normal search, cascade_tier=''."""
seed = await _seed_cascade_data(db_engine)
session_factory = async_sessionmaker(
db_engine, class_=AsyncSession, expire_on_commit=False
)
async with session_factory() as session:
from config import Settings
svc = SearchService(settings=Settings())
# LightRAG returns normal results (non-cascade path)
body = _cascade_lightrag_body([
_chunk(seed["tp4_slug"], seed["vr_id"], "Normal search result"),
])
svc._httpx = AsyncMock()
svc._httpx.post = AsyncMock(return_value=_mock_httpx_response(body))
svc.embed_query = AsyncMock(return_value=None)
result = await svc.search("bass design", "all", 10, session, creator="nonexistent-slug")
# Cascade skipped — falls through to normal search
assert result["cascade_tier"] == ""
@pytest.mark.asyncio
async def test_search_no_creator_param_unchanged(db_engine):
"""No creator param → normal search path, cascade_tier='' (empty)."""
seed = await _seed_cascade_data(db_engine)
session_factory = async_sessionmaker(
db_engine, class_=AsyncSession, expire_on_commit=False
)
async with session_factory() as session:
from config import Settings
svc = SearchService(settings=Settings())
body = _cascade_lightrag_body([
_chunk(seed["tp1_slug"], seed["keota_id"], "Normal result"),
])
svc._httpx = AsyncMock()
svc._httpx.post = AsyncMock(return_value=_mock_httpx_response(body))
svc.embed_query = AsyncMock(return_value=None)
# No creator param
result = await svc.search("reese bass", "all", 10, session)
assert result["cascade_tier"] == ""
assert result["total"] >= 1