feat: Added 4-tier creator-scoped cascade (creator → domain → global →…
- "backend/search_service.py" - "backend/schemas.py" - "backend/routers/search.py" GSD-Task: S02/T01
This commit is contained in:
parent
b3f52cc301
commit
2568dc3812
3 changed files with 383 additions and 3 deletions
|
|
@ -58,6 +58,7 @@ async def search(
|
|||
scope: Annotated[str, Query()] = "all",
|
||||
sort: Annotated[str, Query()] = "relevance",
|
||||
limit: Annotated[int, Query(ge=1, le=100)] = 20,
|
||||
creator: Annotated[str, Query(max_length=100)] = "",
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> SearchResponse:
|
||||
"""Semantic search with keyword fallback.
|
||||
|
|
@ -65,9 +66,10 @@ async def search(
|
|||
- **q**: Search query (max 500 chars). Empty → empty results.
|
||||
- **scope**: ``all`` | ``topics`` | ``creators``. Invalid → defaults to ``all``.
|
||||
- **limit**: Max results (1–100, default 20).
|
||||
- **creator**: Creator slug or UUID for cascade search. Empty → normal search.
|
||||
"""
|
||||
svc = _get_search_service()
|
||||
result = await svc.search(query=q, scope=scope, sort=sort, limit=limit, db=db)
|
||||
result = await svc.search(query=q, scope=scope, sort=sort, limit=limit, db=db, creator=creator or None)
|
||||
|
||||
# Fire-and-forget search logging — only non-empty queries
|
||||
if q.strip():
|
||||
|
|
@ -79,6 +81,7 @@ async def search(
|
|||
total=result["total"],
|
||||
query=result["query"],
|
||||
fallback_used=result["fallback_used"],
|
||||
cascade_tier=result.get("cascade_tier", ""),
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -254,6 +254,7 @@ class SearchResponse(BaseModel):
|
|||
total: int = 0
|
||||
query: str = ""
|
||||
fallback_used: bool = False
|
||||
cascade_tier: str = ""
|
||||
|
||||
|
||||
class SuggestionItem(BaseModel):
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ import asyncio
|
|||
import logging
|
||||
import re
|
||||
import time
|
||||
import uuid as uuid_mod
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
|
@ -572,6 +573,259 @@ class SearchService:
|
|||
)
|
||||
return []
|
||||
|
||||
# ── Creator-scoped cascade helpers ──────────────────────────────────
|
||||
|
||||
async def _resolve_creator(
|
||||
self,
|
||||
creator_ref: str,
|
||||
db: AsyncSession,
|
||||
) -> tuple[str | None, str | None]:
|
||||
"""Resolve a creator slug or UUID to (creator_id, creator_name).
|
||||
|
||||
Returns (None, None) if the creator is not found.
|
||||
"""
|
||||
try:
|
||||
creator_uuid = uuid_mod.UUID(creator_ref)
|
||||
stmt = select(Creator).where(Creator.id == creator_uuid)
|
||||
except (ValueError, AttributeError):
|
||||
stmt = select(Creator).where(Creator.slug == creator_ref)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
cr = result.scalars().first()
|
||||
if cr is None:
|
||||
return None, None
|
||||
return str(cr.id), cr.name
|
||||
|
||||
async def _get_creator_domain(
|
||||
self,
|
||||
creator_id: str,
|
||||
db: AsyncSession,
|
||||
) -> str | None:
|
||||
"""Return the dominant topic_category for a creator, or None if <2 technique pages."""
|
||||
stmt = (
|
||||
select(
|
||||
TechniquePage.topic_category,
|
||||
func.count().label("cnt"),
|
||||
)
|
||||
.where(TechniquePage.creator_id == uuid_mod.UUID(creator_id))
|
||||
.group_by(TechniquePage.topic_category)
|
||||
.order_by(func.count().desc())
|
||||
.limit(1)
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
row = result.first()
|
||||
if row is None:
|
||||
return None
|
||||
# Require at least 2 technique pages to declare a domain
|
||||
if row.cnt < 2:
|
||||
return None
|
||||
return row.topic_category
|
||||
|
||||
async def _creator_scoped_search(
|
||||
self,
|
||||
query: str,
|
||||
creator_id: str,
|
||||
creator_name: str,
|
||||
limit: int,
|
||||
db: AsyncSession,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Search LightRAG with creator name as keyword, post-filter by creator_id."""
|
||||
start = time.monotonic()
|
||||
try:
|
||||
resp = await self._httpx.post(
|
||||
f"{self._lightrag_url}/query/data",
|
||||
json={
|
||||
"query": query,
|
||||
"mode": "hybrid",
|
||||
"top_k": limit * 3,
|
||||
"ll_keywords": [creator_name],
|
||||
},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
body = resp.json()
|
||||
except Exception as exc:
|
||||
elapsed_ms = (time.monotonic() - start) * 1000
|
||||
logger.warning(
|
||||
"creator_scoped_search reason=%s query=%r creator=%s latency_ms=%.1f",
|
||||
type(exc).__name__, query, creator_id, elapsed_ms,
|
||||
)
|
||||
return []
|
||||
|
||||
try:
|
||||
data = body.get("data", {})
|
||||
chunks = data.get("chunks", []) if data else []
|
||||
|
||||
slug_set: set[str] = set()
|
||||
slug_order: list[str] = []
|
||||
for chunk in chunks:
|
||||
file_path = chunk.get("file_path", "")
|
||||
m = self._FILE_SOURCE_RE.match(file_path)
|
||||
if m and m.group("slug") not in slug_set:
|
||||
slug = m.group("slug")
|
||||
slug_set.add(slug)
|
||||
slug_order.append(slug)
|
||||
|
||||
if not slug_set:
|
||||
elapsed_ms = (time.monotonic() - start) * 1000
|
||||
logger.warning(
|
||||
"creator_scoped_search reason=no_parseable_results query=%r creator=%s latency_ms=%.1f",
|
||||
query, creator_id, elapsed_ms,
|
||||
)
|
||||
return []
|
||||
|
||||
# Batch lookup and post-filter by creator_id
|
||||
tp_stmt = (
|
||||
select(TechniquePage, Creator)
|
||||
.join(Creator, TechniquePage.creator_id == Creator.id)
|
||||
.where(TechniquePage.slug.in_(list(slug_set)))
|
||||
)
|
||||
tp_rows = await db.execute(tp_stmt)
|
||||
tp_map: dict[str, tuple] = {}
|
||||
for tp, cr in tp_rows.all():
|
||||
if str(tp.creator_id) == creator_id:
|
||||
tp_map[tp.slug] = (tp, cr)
|
||||
|
||||
results: list[dict[str, Any]] = []
|
||||
for idx, slug in enumerate(slug_order):
|
||||
pair = tp_map.get(slug)
|
||||
if not pair:
|
||||
continue
|
||||
tp, cr = pair
|
||||
score = max(1.0 - (idx * 0.05), 0.5)
|
||||
results.append({
|
||||
"type": "technique_page",
|
||||
"title": tp.title,
|
||||
"slug": tp.slug,
|
||||
"technique_page_slug": tp.slug,
|
||||
"summary": tp.summary or "",
|
||||
"topic_category": tp.topic_category,
|
||||
"topic_tags": tp.topic_tags or [],
|
||||
"creator_id": str(tp.creator_id),
|
||||
"creator_name": cr.name,
|
||||
"creator_slug": cr.slug,
|
||||
"created_at": tp.created_at.isoformat() if tp.created_at else "",
|
||||
"score": score,
|
||||
"match_context": "Creator-scoped LightRAG match",
|
||||
})
|
||||
if len(results) >= limit:
|
||||
break
|
||||
|
||||
elapsed_ms = (time.monotonic() - start) * 1000
|
||||
logger.info(
|
||||
"creator_scoped_search query=%r creator=%s latency_ms=%.1f result_count=%d",
|
||||
query, creator_id, elapsed_ms, len(results),
|
||||
)
|
||||
return results
|
||||
|
||||
except (KeyError, ValueError, TypeError) as exc:
|
||||
elapsed_ms = (time.monotonic() - start) * 1000
|
||||
logger.warning(
|
||||
"creator_scoped_search reason=parse_error query=%r creator=%s error=%s latency_ms=%.1f",
|
||||
query, creator_id, exc, elapsed_ms,
|
||||
)
|
||||
return []
|
||||
|
||||
async def _domain_scoped_search(
|
||||
self,
|
||||
query: str,
|
||||
domain: str,
|
||||
limit: int,
|
||||
db: AsyncSession,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Search LightRAG with domain keyword — no post-filtering."""
|
||||
start = time.monotonic()
|
||||
try:
|
||||
resp = await self._httpx.post(
|
||||
f"{self._lightrag_url}/query/data",
|
||||
json={
|
||||
"query": query,
|
||||
"mode": "hybrid",
|
||||
"top_k": limit,
|
||||
"ll_keywords": [domain],
|
||||
},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
body = resp.json()
|
||||
except Exception as exc:
|
||||
elapsed_ms = (time.monotonic() - start) * 1000
|
||||
logger.warning(
|
||||
"domain_scoped_search reason=%s query=%r domain=%s latency_ms=%.1f",
|
||||
type(exc).__name__, query, domain, elapsed_ms,
|
||||
)
|
||||
return []
|
||||
|
||||
try:
|
||||
data = body.get("data", {})
|
||||
chunks = data.get("chunks", []) if data else []
|
||||
|
||||
slug_set: set[str] = set()
|
||||
slug_order: list[str] = []
|
||||
for chunk in chunks:
|
||||
file_path = chunk.get("file_path", "")
|
||||
m = self._FILE_SOURCE_RE.match(file_path)
|
||||
if m and m.group("slug") not in slug_set:
|
||||
slug = m.group("slug")
|
||||
slug_set.add(slug)
|
||||
slug_order.append(slug)
|
||||
|
||||
if not slug_set:
|
||||
elapsed_ms = (time.monotonic() - start) * 1000
|
||||
logger.warning(
|
||||
"domain_scoped_search reason=no_parseable_results query=%r domain=%s latency_ms=%.1f",
|
||||
query, domain, elapsed_ms,
|
||||
)
|
||||
return []
|
||||
|
||||
tp_stmt = (
|
||||
select(TechniquePage, Creator)
|
||||
.join(Creator, TechniquePage.creator_id == Creator.id)
|
||||
.where(TechniquePage.slug.in_(list(slug_set)))
|
||||
)
|
||||
tp_rows = await db.execute(tp_stmt)
|
||||
tp_map: dict[str, tuple] = {}
|
||||
for tp, cr in tp_rows.all():
|
||||
tp_map[tp.slug] = (tp, cr)
|
||||
|
||||
results: list[dict[str, Any]] = []
|
||||
for idx, slug in enumerate(slug_order):
|
||||
pair = tp_map.get(slug)
|
||||
if not pair:
|
||||
continue
|
||||
tp, cr = pair
|
||||
score = max(1.0 - (idx * 0.05), 0.5)
|
||||
results.append({
|
||||
"type": "technique_page",
|
||||
"title": tp.title,
|
||||
"slug": tp.slug,
|
||||
"technique_page_slug": tp.slug,
|
||||
"summary": tp.summary or "",
|
||||
"topic_category": tp.topic_category,
|
||||
"topic_tags": tp.topic_tags or [],
|
||||
"creator_id": str(tp.creator_id),
|
||||
"creator_name": cr.name,
|
||||
"creator_slug": cr.slug,
|
||||
"created_at": tp.created_at.isoformat() if tp.created_at else "",
|
||||
"score": score,
|
||||
"match_context": "Domain-scoped LightRAG match",
|
||||
})
|
||||
if len(results) >= limit:
|
||||
break
|
||||
|
||||
elapsed_ms = (time.monotonic() - start) * 1000
|
||||
logger.info(
|
||||
"domain_scoped_search query=%r domain=%s latency_ms=%.1f result_count=%d",
|
||||
query, domain, elapsed_ms, len(results),
|
||||
)
|
||||
return results
|
||||
|
||||
except (KeyError, ValueError, TypeError) as exc:
|
||||
elapsed_ms = (time.monotonic() - start) * 1000
|
||||
logger.warning(
|
||||
"domain_scoped_search reason=parse_error query=%r domain=%s error=%s latency_ms=%.1f",
|
||||
query, domain, exc, elapsed_ms,
|
||||
)
|
||||
return []
|
||||
|
||||
# ── Orchestrator ─────────────────────────────────────────────────────
|
||||
|
||||
async def search(
|
||||
|
|
@ -581,9 +835,14 @@ class SearchService:
|
|||
limit: int,
|
||||
db: AsyncSession,
|
||||
sort: str = "relevance",
|
||||
creator: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Run semantic and keyword search in parallel, merge and deduplicate.
|
||||
|
||||
When ``creator`` is provided, executes a 4-tier cascade:
|
||||
creator → domain → global → none, returning results from the first
|
||||
tier that produces hits. ``cascade_tier`` indicates which tier served.
|
||||
|
||||
Both engines run concurrently. Keyword results are always included
|
||||
(with match_context). Semantic results above the score threshold are
|
||||
merged in, deduplicated by (type, slug/title). Keyword matches rank
|
||||
|
|
@ -592,12 +851,129 @@ class SearchService:
|
|||
start = time.monotonic()
|
||||
|
||||
if not query or not query.strip():
|
||||
return {"items": [], "partial_matches": [], "total": 0, "query": query, "fallback_used": False}
|
||||
return {"items": [], "partial_matches": [], "total": 0, "query": query, "fallback_used": False, "cascade_tier": ""}
|
||||
|
||||
query = query.strip()[:500]
|
||||
if scope not in ("all", "topics", "creators"):
|
||||
scope = "all"
|
||||
|
||||
cascade_tier = ""
|
||||
|
||||
# ── Creator-scoped cascade ──────────────────────────────────────
|
||||
use_lightrag = len(query) >= self._lightrag_min_query_length
|
||||
|
||||
if creator and use_lightrag:
|
||||
creator_id, creator_name = await self._resolve_creator(creator, db)
|
||||
if creator_id and creator_name:
|
||||
# Tier 1: creator-scoped
|
||||
tier1 = await self._creator_scoped_search(query, creator_id, creator_name, limit, db)
|
||||
if tier1:
|
||||
cascade_tier = "creator"
|
||||
lightrag_results = tier1
|
||||
fallback_used = False
|
||||
else:
|
||||
# Tier 2: domain-scoped
|
||||
domain = await self._get_creator_domain(creator_id, db)
|
||||
tier2: list[dict[str, Any]] = []
|
||||
if domain:
|
||||
tier2 = await self._domain_scoped_search(query, domain, limit, db)
|
||||
if tier2:
|
||||
cascade_tier = "domain"
|
||||
lightrag_results = tier2
|
||||
fallback_used = False
|
||||
else:
|
||||
# Tier 3: global LightRAG
|
||||
tier3 = await self._lightrag_search(query, limit, db)
|
||||
if tier3:
|
||||
cascade_tier = "global"
|
||||
lightrag_results = tier3
|
||||
fallback_used = False
|
||||
else:
|
||||
# Tier 4: no LightRAG results at all
|
||||
cascade_tier = "none"
|
||||
lightrag_results = []
|
||||
fallback_used = True
|
||||
|
||||
elapsed_cascade = (time.monotonic() - start) * 1000
|
||||
logger.info(
|
||||
"cascade_search query=%r creator=%s tier=%s latency_ms=%.1f result_count=%d",
|
||||
query, creator, cascade_tier, elapsed_cascade, len(lightrag_results),
|
||||
)
|
||||
|
||||
# Skip to merge phase (keyword still runs for supplementary)
|
||||
# Jump past the non-cascade LightRAG block
|
||||
kw_result = await self.keyword_search(query, scope, limit, db, sort=sort)
|
||||
|
||||
if fallback_used:
|
||||
# Qdrant semantic fallback
|
||||
vector = await self.embed_query(query)
|
||||
semantic_results: list[dict[str, Any]] = []
|
||||
if vector:
|
||||
raw = await self.search_qdrant(vector, limit=limit)
|
||||
enriched = await self._enrich_qdrant_results(raw, db)
|
||||
semantic_results = [
|
||||
item for item in enriched
|
||||
if item.get("score", 0) >= _SEMANTIC_MIN_SCORE
|
||||
]
|
||||
for item in semantic_results:
|
||||
if not item.get("match_context"):
|
||||
item["match_context"] = "Semantic match"
|
||||
else:
|
||||
semantic_results = []
|
||||
|
||||
# Handle exceptions
|
||||
kw_items = kw_result["items"] if not isinstance(kw_result, Exception) else []
|
||||
partial_matches = kw_result.get("partial_matches", []) if not isinstance(kw_result, Exception) else []
|
||||
|
||||
# Merge: cascade results first, then keyword, then semantic
|
||||
seen_keys: set[str] = set()
|
||||
merged: list[dict[str, Any]] = []
|
||||
|
||||
def _dedup_key(item: dict) -> str:
|
||||
t = item.get("type", "")
|
||||
s = item.get("slug") or item.get("technique_page_slug") or ""
|
||||
title = item.get("title", "")
|
||||
return f"{t}:{s}:{title}"
|
||||
|
||||
for item in lightrag_results:
|
||||
key = _dedup_key(item)
|
||||
if key not in seen_keys:
|
||||
seen_keys.add(key)
|
||||
merged.append(item)
|
||||
|
||||
for item in kw_items:
|
||||
key = _dedup_key(item)
|
||||
if key not in seen_keys:
|
||||
seen_keys.add(key)
|
||||
merged.append(item)
|
||||
|
||||
for item in semantic_results:
|
||||
key = _dedup_key(item)
|
||||
if key not in seen_keys:
|
||||
seen_keys.add(key)
|
||||
merged.append(item)
|
||||
|
||||
merged = self._apply_sort(merged, sort)
|
||||
|
||||
elapsed_ms = (time.monotonic() - start) * 1000
|
||||
logger.info(
|
||||
"Search query=%r scope=%s cascade_tier=%s lightrag=%d keyword=%d semantic=%d merged=%d fallback=%s latency_ms=%.1f",
|
||||
query, scope, cascade_tier, len(lightrag_results), len(kw_items),
|
||||
len(semantic_results), len(merged), fallback_used, elapsed_ms,
|
||||
)
|
||||
|
||||
return {
|
||||
"items": merged[:limit],
|
||||
"partial_matches": partial_matches,
|
||||
"total": len(merged),
|
||||
"query": query,
|
||||
"fallback_used": fallback_used,
|
||||
"cascade_tier": cascade_tier,
|
||||
}
|
||||
else:
|
||||
logger.warning("cascade_search reason=creator_not_found creator_ref=%r", creator)
|
||||
# Fall through to normal search path
|
||||
|
||||
# ── Primary: try LightRAG for queries ≥ min length ─────────────
|
||||
lightrag_results: list[dict[str, Any]] = []
|
||||
fallback_used = True # assume fallback until LightRAG succeeds
|
||||
|
|
@ -699,6 +1075,7 @@ class SearchService:
|
|||
"total": len(merged),
|
||||
"query": query,
|
||||
"fallback_used": fallback_used,
|
||||
"cascade_tier": cascade_tier,
|
||||
}
|
||||
|
||||
# ── Sort helpers ────────────────────────────────────────────────────
|
||||
|
|
@ -744,7 +1121,6 @@ class SearchService:
|
|||
# Batch fetch creators from DB
|
||||
creator_map: dict[str, dict[str, str]] = {}
|
||||
if needs_db_lookup:
|
||||
import uuid as uuid_mod
|
||||
valid_ids = []
|
||||
for cid in needs_db_lookup:
|
||||
try:
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue