"""Async search service for the public search endpoint. Orchestrates semantic search (embedding + Qdrant) with keyword fallback. All external calls have timeouts and graceful degradation — if embedding or Qdrant fail, the service falls back to keyword-only (ILIKE) search. """ from __future__ import annotations import asyncio import logging import time from typing import Any import openai from qdrant_client import AsyncQdrantClient from qdrant_client.http import exceptions as qdrant_exceptions from qdrant_client.models import FieldCondition, Filter, MatchValue from sqlalchemy import or_, select from sqlalchemy.ext.asyncio import AsyncSession from config import Settings from models import Creator, KeyMoment, SourceVideo, TechniquePage logger = logging.getLogger("chrysopedia.search") # Timeout for external calls (embedding API, Qdrant) in seconds _EXTERNAL_TIMEOUT = 0.3 # 300ms per plan class SearchService: """Async search service with semantic + keyword fallback. Parameters ---------- settings: Application settings containing embedding and Qdrant config. """ def __init__(self, settings: Settings) -> None: self.settings = settings self._openai = openai.AsyncOpenAI( base_url=settings.embedding_api_url, api_key=settings.llm_api_key, ) self._qdrant = AsyncQdrantClient(url=settings.qdrant_url) self._collection = settings.qdrant_collection # ── Embedding ──────────────────────────────────────────────────────── async def embed_query(self, text: str) -> list[float] | None: """Embed a query string into a vector. Returns None on any failure (timeout, connection, malformed response) so the caller can fall back to keyword search. """ try: response = await asyncio.wait_for( self._openai.embeddings.create( model=self.settings.embedding_model, input=text, ), timeout=_EXTERNAL_TIMEOUT, ) except asyncio.TimeoutError: logger.warning("Embedding API timeout (%.0fms limit) for query: %.50s…", _EXTERNAL_TIMEOUT * 1000, text) return None except (openai.APIConnectionError, openai.APITimeoutError) as exc: logger.warning("Embedding API connection error (%s: %s)", type(exc).__name__, exc) return None except openai.APIError as exc: logger.warning("Embedding API error (%s: %s)", type(exc).__name__, exc) return None if not response.data: logger.warning("Embedding API returned empty data for query: %.50s…", text) return None vector = response.data[0].embedding if len(vector) != self.settings.embedding_dimensions: logger.warning( "Embedding dimension mismatch: expected %d, got %d", self.settings.embedding_dimensions, len(vector), ) return None return vector # ── Qdrant vector search ───────────────────────────────────────────── async def search_qdrant( self, vector: list[float], limit: int = 20, type_filter: str | None = None, ) -> list[dict[str, Any]]: """Search Qdrant for nearest neighbours. Returns a list of dicts with 'score' and 'payload' keys. Returns empty list on failure. """ query_filter = None if type_filter: query_filter = Filter( must=[FieldCondition(key="type", match=MatchValue(value=type_filter))] ) try: results = await asyncio.wait_for( self._qdrant.query_points( collection_name=self._collection, query=vector, query_filter=query_filter, limit=limit, with_payload=True, ), timeout=_EXTERNAL_TIMEOUT, ) except asyncio.TimeoutError: logger.warning("Qdrant search timeout (%.0fms limit)", _EXTERNAL_TIMEOUT * 1000) return [] except qdrant_exceptions.UnexpectedResponse as exc: logger.warning("Qdrant search error: %s", exc) return [] except Exception as exc: logger.warning("Qdrant connection error (%s: %s)", type(exc).__name__, exc) return [] return [ {"score": point.score, "payload": point.payload} for point in results.points ] # ── Keyword fallback ───────────────────────────────────────────────── async def keyword_search( self, query: str, scope: str, limit: int, db: AsyncSession, ) -> list[dict[str, Any]]: """ILIKE keyword search across technique pages, key moments, and creators. Searches title/name columns. Returns a unified list of result dicts. """ results: list[dict[str, Any]] = [] pattern = f"%{query}%" if scope in ("all", "topics"): stmt = ( select(TechniquePage) .where( or_( TechniquePage.title.ilike(pattern), TechniquePage.summary.ilike(pattern), ) ) .limit(limit) ) rows = await db.execute(stmt) for tp in rows.scalars().all(): results.append({ "type": "technique_page", "title": tp.title, "slug": tp.slug, "summary": tp.summary or "", "topic_category": tp.topic_category, "topic_tags": tp.topic_tags or [], "creator_id": str(tp.creator_id), "score": 0.0, }) if scope in ("all",): km_stmt = ( select(KeyMoment, SourceVideo, Creator) .join(SourceVideo, KeyMoment.source_video_id == SourceVideo.id) .join(Creator, SourceVideo.creator_id == Creator.id) .where(KeyMoment.title.ilike(pattern)) .limit(limit) ) km_rows = await db.execute(km_stmt) for km, sv, cr in km_rows.all(): results.append({ "type": "key_moment", "title": km.title, "slug": "", "summary": km.summary or "", "topic_category": "", "topic_tags": [], "creator_id": str(cr.id), "creator_name": cr.name, "creator_slug": cr.slug, "score": 0.0, }) if scope in ("all", "creators"): cr_stmt = ( select(Creator) .where(Creator.name.ilike(pattern)) .limit(limit) ) cr_rows = await db.execute(cr_stmt) for cr in cr_rows.scalars().all(): results.append({ "type": "creator", "title": cr.name, "slug": cr.slug, "summary": "", "topic_category": "", "topic_tags": cr.genres or [], "creator_id": str(cr.id), "score": 0.0, }) # Enrich keyword results with creator names kw_creator_ids = {r["creator_id"] for r in results if r.get("creator_id")} kw_creator_map: dict[str, dict[str, str]] = {} if kw_creator_ids: import uuid as _uuid_mod valid = [] for cid in kw_creator_ids: try: valid.append(_uuid_mod.UUID(cid)) except (ValueError, AttributeError): pass if valid: cr_stmt = select(Creator).where(Creator.id.in_(valid)) cr_result = await db.execute(cr_stmt) for c in cr_result.scalars().all(): kw_creator_map[str(c.id)] = {"name": c.name, "slug": c.slug} for r in results: info = kw_creator_map.get(r.get("creator_id", ""), {"name": "", "slug": ""}) r["creator_name"] = info["name"] r["creator_slug"] = info["slug"] return results[:limit] # ── Orchestrator ───────────────────────────────────────────────────── async def search( self, query: str, scope: str, limit: int, db: AsyncSession, ) -> dict[str, Any]: """Run semantic search with keyword fallback. Returns a dict matching the SearchResponse schema shape. """ start = time.monotonic() # Validate / sanitize inputs if not query or not query.strip(): return {"items": [], "total": 0, "query": query, "fallback_used": False} # Truncate long queries query = query.strip()[:500] # Normalize scope if scope not in ("all", "topics", "creators"): scope = "all" # Map scope to Qdrant type filter type_filter_map = { "all": None, "topics": "technique_page", "creators": None, # creators aren't in Qdrant } qdrant_type_filter = type_filter_map.get(scope) fallback_used = False items: list[dict[str, Any]] = [] # Try semantic search vector = await self.embed_query(query) if vector is not None: qdrant_results = await self.search_qdrant(vector, limit=limit, type_filter=qdrant_type_filter) if qdrant_results: # Enrich Qdrant results with DB metadata items = await self._enrich_results(qdrant_results, db) # Fallback to keyword search if semantic failed or returned nothing if not items: items = await self.keyword_search(query, scope, limit, db) fallback_used = True elapsed_ms = (time.monotonic() - start) * 1000 logger.info( "Search query=%r scope=%s results=%d fallback=%s latency_ms=%.1f", query, scope, len(items), fallback_used, elapsed_ms, ) return { "items": items, "total": len(items), "query": query, "fallback_used": fallback_used, } # ── Result enrichment ──────────────────────────────────────────────── async def _enrich_results( self, qdrant_results: list[dict[str, Any]], db: AsyncSession, ) -> list[dict[str, Any]]: """Enrich Qdrant results with creator names and slugs from DB.""" enriched: list[dict[str, Any]] = [] # Collect creator_ids to batch-fetch creator_ids = set() for r in qdrant_results: payload = r.get("payload", {}) cid = payload.get("creator_id") if cid: creator_ids.add(cid) # Batch fetch creators creator_map: dict[str, dict[str, str]] = {} if creator_ids: from sqlalchemy.dialects.postgresql import UUID as PgUUID import uuid as uuid_mod valid_ids = [] for cid in creator_ids: try: valid_ids.append(uuid_mod.UUID(cid)) except (ValueError, AttributeError): pass if valid_ids: stmt = select(Creator).where(Creator.id.in_(valid_ids)) result = await db.execute(stmt) for c in result.scalars().all(): creator_map[str(c.id)] = {"name": c.name, "slug": c.slug} for r in qdrant_results: payload = r.get("payload", {}) cid = payload.get("creator_id", "") creator_info = creator_map.get(cid, {"name": "", "slug": ""}) enriched.append({ "type": payload.get("type", ""), "title": payload.get("title", ""), "slug": payload.get("slug", payload.get("title", "").lower().replace(" ", "-")), "summary": payload.get("summary", ""), "topic_category": payload.get("topic_category", ""), "topic_tags": payload.get("topic_tags", []), "creator_id": cid, "creator_name": creator_info["name"], "creator_slug": creator_info["slug"], "score": r.get("score", 0.0), }) return enriched