feat: Added 4-tier creator-scoped cascade (creator → domain → global →…

- "backend/search_service.py"
- "backend/schemas.py"
- "backend/routers/search.py"

GSD-Task: S02/T01
This commit is contained in:
jlightner 2026-04-04 05:02:30 +00:00
parent b3f52cc301
commit 2568dc3812
3 changed files with 383 additions and 3 deletions

View file

@ -58,6 +58,7 @@ async def search(
scope: Annotated[str, Query()] = "all", scope: Annotated[str, Query()] = "all",
sort: Annotated[str, Query()] = "relevance", sort: Annotated[str, Query()] = "relevance",
limit: Annotated[int, Query(ge=1, le=100)] = 20, limit: Annotated[int, Query(ge=1, le=100)] = 20,
creator: Annotated[str, Query(max_length=100)] = "",
db: AsyncSession = Depends(get_session), db: AsyncSession = Depends(get_session),
) -> SearchResponse: ) -> SearchResponse:
"""Semantic search with keyword fallback. """Semantic search with keyword fallback.
@ -65,9 +66,10 @@ async def search(
- **q**: Search query (max 500 chars). Empty empty results. - **q**: Search query (max 500 chars). Empty empty results.
- **scope**: ``all`` | ``topics`` | ``creators``. Invalid defaults to ``all``. - **scope**: ``all`` | ``topics`` | ``creators``. Invalid defaults to ``all``.
- **limit**: Max results (1100, default 20). - **limit**: Max results (1100, default 20).
- **creator**: Creator slug or UUID for cascade search. Empty normal search.
""" """
svc = _get_search_service() svc = _get_search_service()
result = await svc.search(query=q, scope=scope, sort=sort, limit=limit, db=db) result = await svc.search(query=q, scope=scope, sort=sort, limit=limit, db=db, creator=creator or None)
# Fire-and-forget search logging — only non-empty queries # Fire-and-forget search logging — only non-empty queries
if q.strip(): if q.strip():
@ -79,6 +81,7 @@ async def search(
total=result["total"], total=result["total"],
query=result["query"], query=result["query"],
fallback_used=result["fallback_used"], fallback_used=result["fallback_used"],
cascade_tier=result.get("cascade_tier", ""),
) )

View file

@ -254,6 +254,7 @@ class SearchResponse(BaseModel):
total: int = 0 total: int = 0
query: str = "" query: str = ""
fallback_used: bool = False fallback_used: bool = False
cascade_tier: str = ""
class SuggestionItem(BaseModel): class SuggestionItem(BaseModel):

View file

@ -12,6 +12,7 @@ import asyncio
import logging import logging
import re import re
import time import time
import uuid as uuid_mod
from typing import Any from typing import Any
import httpx import httpx
@ -572,6 +573,259 @@ class SearchService:
) )
return [] return []
# ── Creator-scoped cascade helpers ──────────────────────────────────
async def _resolve_creator(
self,
creator_ref: str,
db: AsyncSession,
) -> tuple[str | None, str | None]:
"""Resolve a creator slug or UUID to (creator_id, creator_name).
Returns (None, None) if the creator is not found.
"""
try:
creator_uuid = uuid_mod.UUID(creator_ref)
stmt = select(Creator).where(Creator.id == creator_uuid)
except (ValueError, AttributeError):
stmt = select(Creator).where(Creator.slug == creator_ref)
result = await db.execute(stmt)
cr = result.scalars().first()
if cr is None:
return None, None
return str(cr.id), cr.name
async def _get_creator_domain(
self,
creator_id: str,
db: AsyncSession,
) -> str | None:
"""Return the dominant topic_category for a creator, or None if <2 technique pages."""
stmt = (
select(
TechniquePage.topic_category,
func.count().label("cnt"),
)
.where(TechniquePage.creator_id == uuid_mod.UUID(creator_id))
.group_by(TechniquePage.topic_category)
.order_by(func.count().desc())
.limit(1)
)
result = await db.execute(stmt)
row = result.first()
if row is None:
return None
# Require at least 2 technique pages to declare a domain
if row.cnt < 2:
return None
return row.topic_category
async def _creator_scoped_search(
self,
query: str,
creator_id: str,
creator_name: str,
limit: int,
db: AsyncSession,
) -> list[dict[str, Any]]:
"""Search LightRAG with creator name as keyword, post-filter by creator_id."""
start = time.monotonic()
try:
resp = await self._httpx.post(
f"{self._lightrag_url}/query/data",
json={
"query": query,
"mode": "hybrid",
"top_k": limit * 3,
"ll_keywords": [creator_name],
},
)
resp.raise_for_status()
body = resp.json()
except Exception as exc:
elapsed_ms = (time.monotonic() - start) * 1000
logger.warning(
"creator_scoped_search reason=%s query=%r creator=%s latency_ms=%.1f",
type(exc).__name__, query, creator_id, elapsed_ms,
)
return []
try:
data = body.get("data", {})
chunks = data.get("chunks", []) if data else []
slug_set: set[str] = set()
slug_order: list[str] = []
for chunk in chunks:
file_path = chunk.get("file_path", "")
m = self._FILE_SOURCE_RE.match(file_path)
if m and m.group("slug") not in slug_set:
slug = m.group("slug")
slug_set.add(slug)
slug_order.append(slug)
if not slug_set:
elapsed_ms = (time.monotonic() - start) * 1000
logger.warning(
"creator_scoped_search reason=no_parseable_results query=%r creator=%s latency_ms=%.1f",
query, creator_id, elapsed_ms,
)
return []
# Batch lookup and post-filter by creator_id
tp_stmt = (
select(TechniquePage, Creator)
.join(Creator, TechniquePage.creator_id == Creator.id)
.where(TechniquePage.slug.in_(list(slug_set)))
)
tp_rows = await db.execute(tp_stmt)
tp_map: dict[str, tuple] = {}
for tp, cr in tp_rows.all():
if str(tp.creator_id) == creator_id:
tp_map[tp.slug] = (tp, cr)
results: list[dict[str, Any]] = []
for idx, slug in enumerate(slug_order):
pair = tp_map.get(slug)
if not pair:
continue
tp, cr = pair
score = max(1.0 - (idx * 0.05), 0.5)
results.append({
"type": "technique_page",
"title": tp.title,
"slug": tp.slug,
"technique_page_slug": tp.slug,
"summary": tp.summary or "",
"topic_category": tp.topic_category,
"topic_tags": tp.topic_tags or [],
"creator_id": str(tp.creator_id),
"creator_name": cr.name,
"creator_slug": cr.slug,
"created_at": tp.created_at.isoformat() if tp.created_at else "",
"score": score,
"match_context": "Creator-scoped LightRAG match",
})
if len(results) >= limit:
break
elapsed_ms = (time.monotonic() - start) * 1000
logger.info(
"creator_scoped_search query=%r creator=%s latency_ms=%.1f result_count=%d",
query, creator_id, elapsed_ms, len(results),
)
return results
except (KeyError, ValueError, TypeError) as exc:
elapsed_ms = (time.monotonic() - start) * 1000
logger.warning(
"creator_scoped_search reason=parse_error query=%r creator=%s error=%s latency_ms=%.1f",
query, creator_id, exc, elapsed_ms,
)
return []
async def _domain_scoped_search(
self,
query: str,
domain: str,
limit: int,
db: AsyncSession,
) -> list[dict[str, Any]]:
"""Search LightRAG with domain keyword — no post-filtering."""
start = time.monotonic()
try:
resp = await self._httpx.post(
f"{self._lightrag_url}/query/data",
json={
"query": query,
"mode": "hybrid",
"top_k": limit,
"ll_keywords": [domain],
},
)
resp.raise_for_status()
body = resp.json()
except Exception as exc:
elapsed_ms = (time.monotonic() - start) * 1000
logger.warning(
"domain_scoped_search reason=%s query=%r domain=%s latency_ms=%.1f",
type(exc).__name__, query, domain, elapsed_ms,
)
return []
try:
data = body.get("data", {})
chunks = data.get("chunks", []) if data else []
slug_set: set[str] = set()
slug_order: list[str] = []
for chunk in chunks:
file_path = chunk.get("file_path", "")
m = self._FILE_SOURCE_RE.match(file_path)
if m and m.group("slug") not in slug_set:
slug = m.group("slug")
slug_set.add(slug)
slug_order.append(slug)
if not slug_set:
elapsed_ms = (time.monotonic() - start) * 1000
logger.warning(
"domain_scoped_search reason=no_parseable_results query=%r domain=%s latency_ms=%.1f",
query, domain, elapsed_ms,
)
return []
tp_stmt = (
select(TechniquePage, Creator)
.join(Creator, TechniquePage.creator_id == Creator.id)
.where(TechniquePage.slug.in_(list(slug_set)))
)
tp_rows = await db.execute(tp_stmt)
tp_map: dict[str, tuple] = {}
for tp, cr in tp_rows.all():
tp_map[tp.slug] = (tp, cr)
results: list[dict[str, Any]] = []
for idx, slug in enumerate(slug_order):
pair = tp_map.get(slug)
if not pair:
continue
tp, cr = pair
score = max(1.0 - (idx * 0.05), 0.5)
results.append({
"type": "technique_page",
"title": tp.title,
"slug": tp.slug,
"technique_page_slug": tp.slug,
"summary": tp.summary or "",
"topic_category": tp.topic_category,
"topic_tags": tp.topic_tags or [],
"creator_id": str(tp.creator_id),
"creator_name": cr.name,
"creator_slug": cr.slug,
"created_at": tp.created_at.isoformat() if tp.created_at else "",
"score": score,
"match_context": "Domain-scoped LightRAG match",
})
if len(results) >= limit:
break
elapsed_ms = (time.monotonic() - start) * 1000
logger.info(
"domain_scoped_search query=%r domain=%s latency_ms=%.1f result_count=%d",
query, domain, elapsed_ms, len(results),
)
return results
except (KeyError, ValueError, TypeError) as exc:
elapsed_ms = (time.monotonic() - start) * 1000
logger.warning(
"domain_scoped_search reason=parse_error query=%r domain=%s error=%s latency_ms=%.1f",
query, domain, exc, elapsed_ms,
)
return []
# ── Orchestrator ───────────────────────────────────────────────────── # ── Orchestrator ─────────────────────────────────────────────────────
async def search( async def search(
@ -581,9 +835,14 @@ class SearchService:
limit: int, limit: int,
db: AsyncSession, db: AsyncSession,
sort: str = "relevance", sort: str = "relevance",
creator: str | None = None,
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Run semantic and keyword search in parallel, merge and deduplicate. """Run semantic and keyword search in parallel, merge and deduplicate.
When ``creator`` is provided, executes a 4-tier cascade:
creator domain global none, returning results from the first
tier that produces hits. ``cascade_tier`` indicates which tier served.
Both engines run concurrently. Keyword results are always included Both engines run concurrently. Keyword results are always included
(with match_context). Semantic results above the score threshold are (with match_context). Semantic results above the score threshold are
merged in, deduplicated by (type, slug/title). Keyword matches rank merged in, deduplicated by (type, slug/title). Keyword matches rank
@ -592,12 +851,129 @@ class SearchService:
start = time.monotonic() start = time.monotonic()
if not query or not query.strip(): if not query or not query.strip():
return {"items": [], "partial_matches": [], "total": 0, "query": query, "fallback_used": False} return {"items": [], "partial_matches": [], "total": 0, "query": query, "fallback_used": False, "cascade_tier": ""}
query = query.strip()[:500] query = query.strip()[:500]
if scope not in ("all", "topics", "creators"): if scope not in ("all", "topics", "creators"):
scope = "all" scope = "all"
cascade_tier = ""
# ── Creator-scoped cascade ──────────────────────────────────────
use_lightrag = len(query) >= self._lightrag_min_query_length
if creator and use_lightrag:
creator_id, creator_name = await self._resolve_creator(creator, db)
if creator_id and creator_name:
# Tier 1: creator-scoped
tier1 = await self._creator_scoped_search(query, creator_id, creator_name, limit, db)
if tier1:
cascade_tier = "creator"
lightrag_results = tier1
fallback_used = False
else:
# Tier 2: domain-scoped
domain = await self._get_creator_domain(creator_id, db)
tier2: list[dict[str, Any]] = []
if domain:
tier2 = await self._domain_scoped_search(query, domain, limit, db)
if tier2:
cascade_tier = "domain"
lightrag_results = tier2
fallback_used = False
else:
# Tier 3: global LightRAG
tier3 = await self._lightrag_search(query, limit, db)
if tier3:
cascade_tier = "global"
lightrag_results = tier3
fallback_used = False
else:
# Tier 4: no LightRAG results at all
cascade_tier = "none"
lightrag_results = []
fallback_used = True
elapsed_cascade = (time.monotonic() - start) * 1000
logger.info(
"cascade_search query=%r creator=%s tier=%s latency_ms=%.1f result_count=%d",
query, creator, cascade_tier, elapsed_cascade, len(lightrag_results),
)
# Skip to merge phase (keyword still runs for supplementary)
# Jump past the non-cascade LightRAG block
kw_result = await self.keyword_search(query, scope, limit, db, sort=sort)
if fallback_used:
# Qdrant semantic fallback
vector = await self.embed_query(query)
semantic_results: list[dict[str, Any]] = []
if vector:
raw = await self.search_qdrant(vector, limit=limit)
enriched = await self._enrich_qdrant_results(raw, db)
semantic_results = [
item for item in enriched
if item.get("score", 0) >= _SEMANTIC_MIN_SCORE
]
for item in semantic_results:
if not item.get("match_context"):
item["match_context"] = "Semantic match"
else:
semantic_results = []
# Handle exceptions
kw_items = kw_result["items"] if not isinstance(kw_result, Exception) else []
partial_matches = kw_result.get("partial_matches", []) if not isinstance(kw_result, Exception) else []
# Merge: cascade results first, then keyword, then semantic
seen_keys: set[str] = set()
merged: list[dict[str, Any]] = []
def _dedup_key(item: dict) -> str:
t = item.get("type", "")
s = item.get("slug") or item.get("technique_page_slug") or ""
title = item.get("title", "")
return f"{t}:{s}:{title}"
for item in lightrag_results:
key = _dedup_key(item)
if key not in seen_keys:
seen_keys.add(key)
merged.append(item)
for item in kw_items:
key = _dedup_key(item)
if key not in seen_keys:
seen_keys.add(key)
merged.append(item)
for item in semantic_results:
key = _dedup_key(item)
if key not in seen_keys:
seen_keys.add(key)
merged.append(item)
merged = self._apply_sort(merged, sort)
elapsed_ms = (time.monotonic() - start) * 1000
logger.info(
"Search query=%r scope=%s cascade_tier=%s lightrag=%d keyword=%d semantic=%d merged=%d fallback=%s latency_ms=%.1f",
query, scope, cascade_tier, len(lightrag_results), len(kw_items),
len(semantic_results), len(merged), fallback_used, elapsed_ms,
)
return {
"items": merged[:limit],
"partial_matches": partial_matches,
"total": len(merged),
"query": query,
"fallback_used": fallback_used,
"cascade_tier": cascade_tier,
}
else:
logger.warning("cascade_search reason=creator_not_found creator_ref=%r", creator)
# Fall through to normal search path
# ── Primary: try LightRAG for queries ≥ min length ───────────── # ── Primary: try LightRAG for queries ≥ min length ─────────────
lightrag_results: list[dict[str, Any]] = [] lightrag_results: list[dict[str, Any]] = []
fallback_used = True # assume fallback until LightRAG succeeds fallback_used = True # assume fallback until LightRAG succeeds
@ -699,6 +1075,7 @@ class SearchService:
"total": len(merged), "total": len(merged),
"query": query, "query": query,
"fallback_used": fallback_used, "fallback_used": fallback_used,
"cascade_tier": cascade_tier,
} }
# ── Sort helpers ──────────────────────────────────────────────────── # ── Sort helpers ────────────────────────────────────────────────────
@ -744,7 +1121,6 @@ class SearchService:
# Batch fetch creators from DB # Batch fetch creators from DB
creator_map: dict[str, dict[str, str]] = {} creator_map: dict[str, dict[str, str]] = {}
if needs_db_lookup: if needs_db_lookup:
import uuid as uuid_mod
valid_ids = [] valid_ids = []
for cid in needs_db_lookup: for cid in needs_db_lookup:
try: try: