test: Added 6 integration tests covering creator-scoped cascade tiers (…
- "backend/tests/test_search.py" GSD-Task: S02/T02
This commit is contained in:
parent
2568dc3812
commit
9530c85b9c
1 changed files with 281 additions and 0 deletions
|
|
@ -966,3 +966,284 @@ async def test_search_lightrag_fallback_on_http_error(db_engine):
|
|||
|
||||
assert result["fallback_used"] is True
|
||||
assert result["total"] >= 1
|
||||
|
||||
|
||||
# ── Creator-scoped cascade integration tests ─────────────────────────────────
|
||||
|
||||
|
||||
async def _seed_cascade_data(db_engine) -> dict:
    """Seed creators and technique pages for cascade tier testing.

    Creator 'keota' has 3 Sound Design pages (≥2 → domain='Sound Design').
    Creator 'virtual-riot' has 1 Synthesis page (< 2 → no dominant domain).

    Returns:
        A dict of plain strings describing the seeded rows: ``keota_id``,
        ``keota_name``, ``keota_slug``, ``vr_id``, ``vr_name`` and
        ``tp1_slug`` … ``tp4_slug`` (tp1–tp3 belong to Keota, tp4 to
        Virtual Riot).
    """
    session_factory = async_sessionmaker(
        db_engine, class_=AsyncSession, expire_on_commit=False
    )
    async with session_factory() as session:
        keota = Creator(
            name="Keota",
            slug="keota",
            genres=["Bass music"],
            folder_name="Keota",
        )
        vr = Creator(
            name="Virtual Riot",
            slug="virtual-riot",
            genres=["Dubstep"],
            folder_name="VirtualRiot",
        )
        session.add_all([keota, vr])
        # Flush before building pages: the pages below read keota.id / vr.id.
        await session.flush()

        tp1 = TechniquePage(
            creator_id=keota.id,
            title="Reese Bass Fundamentals",
            slug="reese-bass-fundamentals",
            topic_category="Sound Design",
            topic_tags=["bass", "reese"],
            summary="Fundamentals of reese bass",
        )
        tp2 = TechniquePage(
            creator_id=keota.id,
            title="FM Sound Design",
            slug="fm-sound-design",
            topic_category="Sound Design",
            topic_tags=["fm", "design"],
            summary="FM sound design techniques",
        )
        tp3 = TechniquePage(
            creator_id=keota.id,
            title="Granular Textures",
            slug="granular-textures",
            topic_category="Sound Design",
            topic_tags=["granular"],
            summary="Granular texture design",
        )
        tp4 = TechniquePage(
            creator_id=vr.id,
            title="Serum Wavetable Tricks",
            slug="serum-wavetable-tricks",
            topic_category="Synthesis",
            topic_tags=["serum", "wavetable"],
            summary="Advanced Serum wavetable tricks",
        )
        session.add_all([tp1, tp2, tp3, tp4])
        await session.commit()

    # The session factory uses expire_on_commit=False, so these attributes
    # are still loaded on the instances after the commit above.
    return {
        "keota_id": str(keota.id),
        "keota_name": keota.name,
        "keota_slug": keota.slug,
        "vr_id": str(vr.id),
        "vr_name": vr.name,
        "tp1_slug": tp1.slug,
        "tp2_slug": tp2.slug,
        "tp3_slug": tp3.slug,
        "tp4_slug": tp4.slug,
    }
|
||||
|
||||
|
||||
def _cascade_lightrag_body(chunks: list[dict]) -> dict:
    """Build a LightRAG /query/data response with given chunks.

    Args:
        chunks: Chunk dicts to place under ``data.chunks``; the list is
            embedded as-is (not copied).

    Returns:
        A dict shaped ``{"data": {"chunks": ..., "entities": [],
        "relationships": []}}`` with empty entity and relationship lists.
    """
    return {
        "data": {
            "chunks": chunks,
            "entities": [],
            "relationships": [],
        }
    }
|
||||
|
||||
|
||||
def _chunk(slug: str, creator_id: str, content: str = "chunk content") -> dict:
    """Build one LightRAG chunk dict for a technique page.

    Args:
        slug: Technique page slug embedded in the chunk's ``file_path``.
        creator_id: Creator id (as a string) embedded in the ``file_path``.
        content: Chunk text; defaults to ``"chunk content"``.

    Returns:
        A dict with ``content`` and a ``file_path`` of the form
        ``technique:<slug>:creator:<creator_id>``.
    """
    return {
        "content": content,
        "file_path": f"technique:{slug}:creator:{creator_id}",
    }
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_search_cascade_creator_tier(db_engine):
    """Tier 1: creator-scoped search returns results → cascade_tier='creator'.

    The mocked LightRAG POST always returns one chunk pointing at one of
    Keota's technique pages, and the search is issued with creator="keota".
    """
    seed = await _seed_cascade_data(db_engine)

    session_factory = async_sessionmaker(
        db_engine, class_=AsyncSession, expire_on_commit=False
    )
    async with session_factory() as session:
        from config import Settings

        svc = SearchService(settings=Settings())

        # httpx returns chunks matching keota's technique pages
        body = _cascade_lightrag_body([
            _chunk(seed["tp1_slug"], seed["keota_id"], "Reese bass fundamentals"),
        ])
        mock_resp = _mock_httpx_response(body)
        svc._httpx = AsyncMock()
        svc._httpx.post = AsyncMock(return_value=mock_resp)
        # Stub out embedding so no embedding call is made.
        svc.embed_query = AsyncMock(return_value=None)

        result = await svc.search("reese bass", "all", 10, session, creator="keota")

        assert result["cascade_tier"] == "creator"
        assert result["fallback_used"] is False
        assert result["total"] >= 1
        # At least one returned item is attributed to keota
        cascade_items = [i for i in result["items"] if i.get("creator_slug") == "keota"]
        assert len(cascade_items) >= 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_search_cascade_domain_tier(db_engine):
    """Tier 2: creator-scoped empty → domain-scoped returns results → cascade_tier='domain'.

    The first mocked POST returns a chunk owned by a different creator; every
    later POST returns a chunk owned by Keota.
    """
    seed = await _seed_cascade_data(db_engine)

    session_factory = async_sessionmaker(
        db_engine, class_=AsyncSession, expire_on_commit=False
    )
    async with session_factory() as session:
        from config import Settings

        svc = SearchService(settings=Settings())

        # Call 1 (creator-scoped): returns chunks for a DIFFERENT creator → post-filter removes them
        creator_body = _cascade_lightrag_body([
            _chunk(seed["tp4_slug"], seed["vr_id"], "VR content not Keota"),
        ])
        # Call 2 (domain-scoped with "Sound Design"): returns chunks matching Keota
        domain_body = _cascade_lightrag_body([
            _chunk(seed["tp1_slug"], seed["keota_id"], "Reese bass from domain"),
        ])

        call_count = 0

        async def _side_effect(*args, **kwargs):
            # Serve the creator-tier body on the first POST only; any
            # subsequent POST gets the domain-tier body.
            nonlocal call_count
            call_count += 1
            body = creator_body if call_count == 1 else domain_body
            return _mock_httpx_response(body)

        svc._httpx = AsyncMock()
        svc._httpx.post = AsyncMock(side_effect=_side_effect)
        # Stub out embedding so no embedding call is made.
        svc.embed_query = AsyncMock(return_value=None)

        result = await svc.search("synthesis techniques", "all", 10, session, creator="keota")

        assert result["cascade_tier"] == "domain"
        assert result["fallback_used"] is False
        assert result["total"] >= 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_search_cascade_global_fallback(db_engine):
    """Tier 3: creator + domain empty → global LightRAG returns → cascade_tier='global'.

    The first two mocked POSTs return no chunks; the third and any later POST
    return a single chunk owned by Virtual Riot.
    """
    seed = await _seed_cascade_data(db_engine)

    session_factory = async_sessionmaker(
        db_engine, class_=AsyncSession, expire_on_commit=False
    )
    async with session_factory() as session:
        from config import Settings

        svc = SearchService(settings=Settings())

        # Calls 1-2 (creator + domain): empty chunks
        empty_body = _cascade_lightrag_body([])
        # Call 3 (global _lightrag_search): returns results
        global_body = _cascade_lightrag_body([
            _chunk(seed["tp4_slug"], seed["vr_id"], "Global result"),
        ])

        call_count = 0

        async def _side_effect(*args, **kwargs):
            # Empty body for the first two POSTs, global body afterwards.
            nonlocal call_count
            call_count += 1
            body = empty_body if call_count <= 2 else global_body
            return _mock_httpx_response(body)

        svc._httpx = AsyncMock()
        svc._httpx.post = AsyncMock(side_effect=_side_effect)
        # Stub out embedding so no embedding call is made.
        svc.embed_query = AsyncMock(return_value=None)

        result = await svc.search("mixing tips", "all", 10, session, creator="keota")

        assert result["cascade_tier"] == "global"
        assert result["fallback_used"] is False
        assert result["total"] >= 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_search_cascade_graceful_empty(db_engine):
    """Tier 4: all tiers empty → cascade_tier='none', fallback_used=True.

    Every mocked LightRAG POST returns an empty chunk list.
    """
    # Seed for its side effect only (creator 'keota' must exist); the
    # returned ids/slugs are not needed because no chunks are built here.
    await _seed_cascade_data(db_engine)

    session_factory = async_sessionmaker(
        db_engine, class_=AsyncSession, expire_on_commit=False
    )
    async with session_factory() as session:
        from config import Settings

        svc = SearchService(settings=Settings())

        # All calls return empty chunks
        empty_body = _cascade_lightrag_body([])
        svc._httpx = AsyncMock()
        svc._httpx.post = AsyncMock(return_value=_mock_httpx_response(empty_body))
        # Stub out embedding so no embedding call is made.
        svc.embed_query = AsyncMock(return_value=None)

        result = await svc.search("nonexistent topic xyz", "all", 10, session, creator="keota")

        assert result["cascade_tier"] == "none"
        assert result["fallback_used"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_search_cascade_unknown_creator(db_engine):
    """Unknown creator slug → cascade skipped, normal search, cascade_tier=''.

    The search is issued with creator="nonexistent-slug", which matches
    neither seeded creator.
    """
    seed = await _seed_cascade_data(db_engine)

    session_factory = async_sessionmaker(
        db_engine, class_=AsyncSession, expire_on_commit=False
    )
    async with session_factory() as session:
        from config import Settings

        svc = SearchService(settings=Settings())

        # LightRAG returns normal results (non-cascade path)
        body = _cascade_lightrag_body([
            _chunk(seed["tp4_slug"], seed["vr_id"], "Normal search result"),
        ])
        svc._httpx = AsyncMock()
        svc._httpx.post = AsyncMock(return_value=_mock_httpx_response(body))
        # Stub out embedding so no embedding call is made.
        svc.embed_query = AsyncMock(return_value=None)

        result = await svc.search("bass design", "all", 10, session, creator="nonexistent-slug")

        # Cascade skipped — falls through to normal search
        assert result["cascade_tier"] == ""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_search_no_creator_param_unchanged(db_engine):
    """No creator param → normal search path, cascade_tier='' (empty).

    The search is called without the ``creator`` keyword while the mocked
    LightRAG POST returns one chunk for a Keota technique page.
    """
    seed = await _seed_cascade_data(db_engine)

    session_factory = async_sessionmaker(
        db_engine, class_=AsyncSession, expire_on_commit=False
    )
    async with session_factory() as session:
        from config import Settings

        svc = SearchService(settings=Settings())

        body = _cascade_lightrag_body([
            _chunk(seed["tp1_slug"], seed["keota_id"], "Normal result"),
        ])
        svc._httpx = AsyncMock()
        svc._httpx.post = AsyncMock(return_value=_mock_httpx_response(body))
        # Stub out embedding so no embedding call is made.
        svc.embed_query = AsyncMock(return_value=None)

        # No creator param
        result = await svc.search("reese bass", "all", 10, session)

        assert result["cascade_tier"] == ""
        assert result["total"] >= 1
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue