feat: Created async search service with embedding+Qdrant+keyword fallba…
- "backend/search_service.py" - "backend/schemas.py" - "backend/routers/search.py" - "backend/routers/techniques.py" - "backend/routers/topics.py" - "backend/routers/creators.py" - "backend/main.py" GSD-Task: S05/T01
This commit is contained in:
parent
34733b199d
commit
740fb59d9d
7 changed files with 810 additions and 11 deletions
|
|
@ -12,7 +12,7 @@ from fastapi import FastAPI
|
|||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from config import get_settings
|
||||
from routers import creators, health, ingest, pipeline, review, videos
|
||||
from routers import creators, health, ingest, pipeline, review, search, techniques, topics, videos
|
||||
|
||||
|
||||
def _setup_logging() -> None:
|
||||
|
|
@ -82,6 +82,9 @@ app.include_router(creators.router, prefix="/api/v1")
|
|||
app.include_router(ingest.router, prefix="/api/v1")
|
||||
app.include_router(pipeline.router, prefix="/api/v1")
|
||||
app.include_router(review.router, prefix="/api/v1")
|
||||
app.include_router(search.router, prefix="/api/v1")
|
||||
app.include_router(techniques.router, prefix="/api/v1")
|
||||
app.include_router(topics.router, prefix="/api/v1")
|
||||
app.include_router(videos.router, prefix="/api/v1")
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,8 @@
|
|||
"""Creator endpoints for Chrysopedia API."""
|
||||
"""Creator endpoints for Chrysopedia API.
|
||||
|
||||
Enhanced with sort (random default per R014), genre filter, and
|
||||
technique/video counts for browse pages.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Annotated
|
||||
|
|
@ -8,26 +12,79 @@ from sqlalchemy import func, select
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database import get_session
|
||||
from models import Creator, SourceVideo
|
||||
from schemas import CreatorDetail, CreatorRead
|
||||
from models import Creator, SourceVideo, TechniquePage
|
||||
from schemas import CreatorBrowseItem, CreatorDetail, CreatorRead
|
||||
|
||||
logger = logging.getLogger("chrysopedia.creators")
|
||||
|
||||
router = APIRouter(prefix="/creators", tags=["creators"])
|
||||
|
||||
|
||||
@router.get("", response_model=list[CreatorRead])
|
||||
@router.get("", response_model=list[CreatorBrowseItem])
|
||||
async def list_creators(
|
||||
sort: Annotated[str, Query()] = "random",
|
||||
genre: Annotated[str | None, Query()] = None,
|
||||
offset: Annotated[int, Query(ge=0)] = 0,
|
||||
limit: Annotated[int, Query(ge=1, le=100)] = 50,
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> list[CreatorRead]:
|
||||
"""List all creators with pagination."""
|
||||
stmt = select(Creator).order_by(Creator.name).offset(offset).limit(limit)
|
||||
) -> list[CreatorBrowseItem]:
|
||||
"""List creators with sort, genre filter, and technique/video counts.
|
||||
|
||||
- **sort**: ``random`` (default, R014 creator equity), ``alpha``, ``views``
|
||||
- **genre**: filter by genre (matches against ARRAY column)
|
||||
"""
|
||||
# Subqueries for counts
|
||||
technique_count_sq = (
|
||||
select(func.count())
|
||||
.where(TechniquePage.creator_id == Creator.id)
|
||||
.correlate(Creator)
|
||||
.scalar_subquery()
|
||||
)
|
||||
video_count_sq = (
|
||||
select(func.count())
|
||||
.where(SourceVideo.creator_id == Creator.id)
|
||||
.correlate(Creator)
|
||||
.scalar_subquery()
|
||||
)
|
||||
|
||||
stmt = select(
|
||||
Creator,
|
||||
technique_count_sq.label("technique_count"),
|
||||
video_count_sq.label("video_count"),
|
||||
)
|
||||
|
||||
# Genre filter
|
||||
if genre:
|
||||
stmt = stmt.where(Creator.genres.any(genre))
|
||||
|
||||
# Sorting
|
||||
if sort == "alpha":
|
||||
stmt = stmt.order_by(Creator.name)
|
||||
elif sort == "views":
|
||||
stmt = stmt.order_by(Creator.view_count.desc())
|
||||
else:
|
||||
# Default: random (small dataset <100, func.random() is fine)
|
||||
stmt = stmt.order_by(func.random())
|
||||
|
||||
stmt = stmt.offset(offset).limit(limit)
|
||||
result = await db.execute(stmt)
|
||||
creators = result.scalars().all()
|
||||
logger.debug("Listed %d creators (offset=%d, limit=%d)", len(creators), offset, limit)
|
||||
return [CreatorRead.model_validate(c) for c in creators]
|
||||
rows = result.all()
|
||||
|
||||
items: list[CreatorBrowseItem] = []
|
||||
for row in rows:
|
||||
creator = row[0]
|
||||
tc = row[1] or 0
|
||||
vc = row[2] or 0
|
||||
base = CreatorRead.model_validate(creator)
|
||||
items.append(
|
||||
CreatorBrowseItem(**base.model_dump(), technique_count=tc, video_count=vc)
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"Listed %d creators (sort=%s, genre=%s, offset=%d, limit=%d)",
|
||||
len(items), sort, genre, offset, limit,
|
||||
)
|
||||
return items
|
||||
|
||||
|
||||
@router.get("/{slug}", response_model=CreatorDetail)
|
||||
|
|
|
|||
46
backend/routers/search.py
Normal file
46
backend/routers/search.py
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
"""Search endpoint for semantic + keyword search with graceful fallback."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from config import get_settings
|
||||
from database import get_session
|
||||
from schemas import SearchResponse, SearchResultItem
|
||||
from search_service import SearchService
|
||||
|
||||
logger = logging.getLogger("chrysopedia.search.router")
|
||||
|
||||
router = APIRouter(prefix="/search", tags=["search"])
|
||||
|
||||
|
||||
def _get_search_service() -> SearchService:
|
||||
"""Build a SearchService from current settings."""
|
||||
return SearchService(get_settings())
|
||||
|
||||
|
||||
@router.get("", response_model=SearchResponse)
|
||||
async def search(
|
||||
q: Annotated[str, Query(max_length=500)] = "",
|
||||
scope: Annotated[str, Query()] = "all",
|
||||
limit: Annotated[int, Query(ge=1, le=100)] = 20,
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> SearchResponse:
|
||||
"""Semantic search with keyword fallback.
|
||||
|
||||
- **q**: Search query (max 500 chars). Empty → empty results.
|
||||
- **scope**: ``all`` | ``topics`` | ``creators``. Invalid → defaults to ``all``.
|
||||
- **limit**: Max results (1–100, default 20).
|
||||
"""
|
||||
svc = _get_search_service()
|
||||
result = await svc.search(query=q, scope=scope, limit=limit, db=db)
|
||||
return SearchResponse(
|
||||
items=[SearchResultItem(**item) for item in result["items"]],
|
||||
total=result["total"],
|
||||
query=result["query"],
|
||||
fallback_used=result["fallback_used"],
|
||||
)
|
||||
134
backend/routers/techniques.py
Normal file
134
backend/routers/techniques.py
Normal file
|
|
@ -0,0 +1,134 @@
|
|||
"""Technique page endpoints — list and detail with eager-loaded relations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from database import get_session
|
||||
from models import Creator, KeyMoment, RelatedTechniqueLink, TechniquePage
|
||||
from schemas import (
|
||||
CreatorInfo,
|
||||
KeyMomentSummary,
|
||||
PaginatedResponse,
|
||||
RelatedLinkItem,
|
||||
TechniquePageDetail,
|
||||
TechniquePageRead,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("chrysopedia.techniques")
|
||||
|
||||
router = APIRouter(prefix="/techniques", tags=["techniques"])
|
||||
|
||||
|
||||
@router.get("", response_model=PaginatedResponse)
|
||||
async def list_techniques(
|
||||
category: Annotated[str | None, Query()] = None,
|
||||
creator_slug: Annotated[str | None, Query()] = None,
|
||||
offset: Annotated[int, Query(ge=0)] = 0,
|
||||
limit: Annotated[int, Query(ge=1, le=100)] = 50,
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> PaginatedResponse:
|
||||
"""List technique pages with optional category/creator filtering."""
|
||||
stmt = select(TechniquePage)
|
||||
|
||||
if category:
|
||||
stmt = stmt.where(TechniquePage.topic_category == category)
|
||||
|
||||
if creator_slug:
|
||||
# Join to Creator to filter by slug
|
||||
stmt = stmt.join(Creator, TechniquePage.creator_id == Creator.id).where(
|
||||
Creator.slug == creator_slug
|
||||
)
|
||||
|
||||
# Count total before pagination
|
||||
from sqlalchemy import func
|
||||
|
||||
count_stmt = select(func.count()).select_from(stmt.subquery())
|
||||
count_result = await db.execute(count_stmt)
|
||||
total = count_result.scalar() or 0
|
||||
|
||||
stmt = stmt.order_by(TechniquePage.created_at.desc()).offset(offset).limit(limit)
|
||||
result = await db.execute(stmt)
|
||||
pages = result.scalars().all()
|
||||
|
||||
return PaginatedResponse(
|
||||
items=[TechniquePageRead.model_validate(p) for p in pages],
|
||||
total=total,
|
||||
offset=offset,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{slug}", response_model=TechniquePageDetail)
|
||||
async def get_technique(
|
||||
slug: str,
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> TechniquePageDetail:
|
||||
"""Get full technique page detail with key moments, creator, and related links."""
|
||||
stmt = (
|
||||
select(TechniquePage)
|
||||
.where(TechniquePage.slug == slug)
|
||||
.options(
|
||||
selectinload(TechniquePage.key_moments),
|
||||
selectinload(TechniquePage.creator),
|
||||
selectinload(TechniquePage.outgoing_links).selectinload(
|
||||
RelatedTechniqueLink.target_page
|
||||
),
|
||||
selectinload(TechniquePage.incoming_links).selectinload(
|
||||
RelatedTechniqueLink.source_page
|
||||
),
|
||||
)
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
page = result.scalar_one_or_none()
|
||||
|
||||
if page is None:
|
||||
raise HTTPException(status_code=404, detail=f"Technique '{slug}' not found")
|
||||
|
||||
# Build key moments (ordered by start_time)
|
||||
key_moments = sorted(page.key_moments, key=lambda km: km.start_time)
|
||||
key_moment_items = [KeyMomentSummary.model_validate(km) for km in key_moments]
|
||||
|
||||
# Build creator info
|
||||
creator_info = None
|
||||
if page.creator:
|
||||
creator_info = CreatorInfo(
|
||||
name=page.creator.name,
|
||||
slug=page.creator.slug,
|
||||
genres=page.creator.genres,
|
||||
)
|
||||
|
||||
# Build related links (outgoing + incoming)
|
||||
related_links: list[RelatedLinkItem] = []
|
||||
for link in page.outgoing_links:
|
||||
if link.target_page:
|
||||
related_links.append(
|
||||
RelatedLinkItem(
|
||||
target_title=link.target_page.title,
|
||||
target_slug=link.target_page.slug,
|
||||
relationship=link.relationship.value if hasattr(link.relationship, 'value') else str(link.relationship),
|
||||
)
|
||||
)
|
||||
for link in page.incoming_links:
|
||||
if link.source_page:
|
||||
related_links.append(
|
||||
RelatedLinkItem(
|
||||
target_title=link.source_page.title,
|
||||
target_slug=link.source_page.slug,
|
||||
relationship=link.relationship.value if hasattr(link.relationship, 'value') else str(link.relationship),
|
||||
)
|
||||
)
|
||||
|
||||
base = TechniquePageRead.model_validate(page)
|
||||
return TechniquePageDetail(
|
||||
**base.model_dump(),
|
||||
key_moments=key_moment_items,
|
||||
creator_info=creator_info,
|
||||
related_links=related_links,
|
||||
)
|
||||
135
backend/routers/topics.py
Normal file
135
backend/routers/topics.py
Normal file
|
|
@ -0,0 +1,135 @@
|
|||
"""Topics endpoint — two-level category hierarchy with aggregated counts."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Annotated, Any
|
||||
|
||||
import yaml
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database import get_session
|
||||
from models import Creator, TechniquePage
|
||||
from schemas import (
|
||||
PaginatedResponse,
|
||||
TechniquePageRead,
|
||||
TopicCategory,
|
||||
TopicSubTopic,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("chrysopedia.topics")
|
||||
|
||||
router = APIRouter(prefix="/topics", tags=["topics"])
|
||||
|
||||
# Path to canonical_tags.yaml relative to the backend directory
|
||||
_TAGS_PATH = os.path.join(os.path.dirname(__file__), "..", "..", "config", "canonical_tags.yaml")
|
||||
|
||||
|
||||
def _load_canonical_tags() -> list[dict[str, Any]]:
|
||||
"""Load the canonical tag categories from YAML."""
|
||||
path = os.path.normpath(_TAGS_PATH)
|
||||
try:
|
||||
with open(path) as f:
|
||||
data = yaml.safe_load(f)
|
||||
return data.get("categories", [])
|
||||
except FileNotFoundError:
|
||||
logger.warning("canonical_tags.yaml not found at %s", path)
|
||||
return []
|
||||
|
||||
|
||||
@router.get("", response_model=list[TopicCategory])
|
||||
async def list_topics(
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> list[TopicCategory]:
|
||||
"""Return the two-level topic hierarchy with technique/creator counts per sub-topic.
|
||||
|
||||
Categories come from ``canonical_tags.yaml``. Counts are computed
|
||||
from live DB data by matching ``topic_tags`` array contents.
|
||||
"""
|
||||
categories = _load_canonical_tags()
|
||||
|
||||
# Pre-fetch all technique pages with their tags and creator_ids for counting
|
||||
tp_stmt = select(
|
||||
TechniquePage.topic_category,
|
||||
TechniquePage.topic_tags,
|
||||
TechniquePage.creator_id,
|
||||
)
|
||||
tp_result = await db.execute(tp_stmt)
|
||||
tp_rows = tp_result.all()
|
||||
|
||||
# Build per-sub-topic counts
|
||||
result: list[TopicCategory] = []
|
||||
for cat in categories:
|
||||
cat_name = cat.get("name", "")
|
||||
cat_desc = cat.get("description", "")
|
||||
sub_topic_names: list[str] = cat.get("sub_topics", [])
|
||||
|
||||
sub_topics: list[TopicSubTopic] = []
|
||||
for st_name in sub_topic_names:
|
||||
technique_count = 0
|
||||
creator_ids: set[str] = set()
|
||||
|
||||
for tp_cat, tp_tags, tp_creator_id in tp_rows:
|
||||
tags = tp_tags or []
|
||||
# Match if the sub-topic name appears in the technique's tags
|
||||
# or if the category matches and tag is in sub-topics
|
||||
if st_name.lower() in [t.lower() for t in tags]:
|
||||
technique_count += 1
|
||||
creator_ids.add(str(tp_creator_id))
|
||||
|
||||
sub_topics.append(
|
||||
TopicSubTopic(
|
||||
name=st_name,
|
||||
technique_count=technique_count,
|
||||
creator_count=len(creator_ids),
|
||||
)
|
||||
)
|
||||
|
||||
result.append(
|
||||
TopicCategory(
|
||||
name=cat_name,
|
||||
description=cat_desc,
|
||||
sub_topics=sub_topics,
|
||||
)
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/{category_slug}", response_model=PaginatedResponse)
|
||||
async def get_topic_techniques(
|
||||
category_slug: str,
|
||||
offset: Annotated[int, Query(ge=0)] = 0,
|
||||
limit: Annotated[int, Query(ge=1, le=100)] = 50,
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> PaginatedResponse:
|
||||
"""Return technique pages filtered by topic_category.
|
||||
|
||||
The ``category_slug`` is matched case-insensitively against
|
||||
``technique_pages.topic_category`` (e.g. 'sound-design' matches 'Sound design').
|
||||
"""
|
||||
# Normalize slug to category name: replace hyphens with spaces, title-case
|
||||
category_name = category_slug.replace("-", " ").title()
|
||||
|
||||
# Also try exact match on the slug form
|
||||
stmt = select(TechniquePage).where(
|
||||
TechniquePage.topic_category.ilike(category_name)
|
||||
)
|
||||
|
||||
count_stmt = select(func.count()).select_from(stmt.subquery())
|
||||
count_result = await db.execute(count_stmt)
|
||||
total = count_result.scalar() or 0
|
||||
|
||||
stmt = stmt.order_by(TechniquePage.title).offset(offset).limit(limit)
|
||||
result = await db.execute(stmt)
|
||||
pages = result.scalars().all()
|
||||
|
||||
return PaginatedResponse(
|
||||
items=[TechniquePageRead.model_validate(p) for p in pages],
|
||||
total=total,
|
||||
offset=offset,
|
||||
limit=limit,
|
||||
)
|
||||
|
|
@ -248,3 +248,90 @@ class ReviewModeResponse(BaseModel):
|
|||
class ReviewModeUpdate(BaseModel):
|
||||
"""Request to update the review mode."""
|
||||
review_mode: bool
|
||||
|
||||
|
||||
# ── Search ───────────────────────────────────────────────────────────────────
|
||||
|
||||
class SearchResultItem(BaseModel):
|
||||
"""A single search result."""
|
||||
title: str
|
||||
slug: str = ""
|
||||
type: str = ""
|
||||
score: float = 0.0
|
||||
summary: str = ""
|
||||
creator_name: str = ""
|
||||
creator_slug: str = ""
|
||||
topic_category: str = ""
|
||||
topic_tags: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class SearchResponse(BaseModel):
|
||||
"""Top-level search response with metadata."""
|
||||
items: list[SearchResultItem] = Field(default_factory=list)
|
||||
total: int = 0
|
||||
query: str = ""
|
||||
fallback_used: bool = False
|
||||
|
||||
|
||||
# ── Technique Page Detail ────────────────────────────────────────────────────
|
||||
|
||||
class KeyMomentSummary(BaseModel):
|
||||
"""Lightweight key moment for technique page detail."""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: uuid.UUID
|
||||
title: str
|
||||
summary: str
|
||||
start_time: float
|
||||
end_time: float
|
||||
content_type: str
|
||||
plugins: list[str] | None = None
|
||||
|
||||
|
||||
class RelatedLinkItem(BaseModel):
|
||||
"""A related technique link with target info."""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
target_title: str = ""
|
||||
target_slug: str = ""
|
||||
relationship: str = ""
|
||||
|
||||
|
||||
class CreatorInfo(BaseModel):
|
||||
"""Minimal creator info embedded in technique detail."""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
name: str
|
||||
slug: str
|
||||
genres: list[str] | None = None
|
||||
|
||||
|
||||
class TechniquePageDetail(TechniquePageRead):
|
||||
"""Technique page with nested key moments, creator, and related links."""
|
||||
key_moments: list[KeyMomentSummary] = Field(default_factory=list)
|
||||
creator_info: CreatorInfo | None = None
|
||||
related_links: list[RelatedLinkItem] = Field(default_factory=list)
|
||||
|
||||
|
||||
# ── Topics ───────────────────────────────────────────────────────────────────
|
||||
|
||||
class TopicSubTopic(BaseModel):
|
||||
"""A sub-topic with aggregated counts."""
|
||||
name: str
|
||||
technique_count: int = 0
|
||||
creator_count: int = 0
|
||||
|
||||
|
||||
class TopicCategory(BaseModel):
|
||||
"""A top-level topic category with sub-topics."""
|
||||
name: str
|
||||
description: str = ""
|
||||
sub_topics: list[TopicSubTopic] = Field(default_factory=list)
|
||||
|
||||
|
||||
# ── Creator Browse ───────────────────────────────────────────────────────────
|
||||
|
||||
class CreatorBrowseItem(CreatorRead):
|
||||
"""Creator with technique and video counts for browse pages."""
|
||||
technique_count: int = 0
|
||||
video_count: int = 0
|
||||
|
|
|
|||
337
backend/search_service.py
Normal file
337
backend/search_service.py
Normal file
|
|
@ -0,0 +1,337 @@
|
|||
"""Async search service for the public search endpoint.
|
||||
|
||||
Orchestrates semantic search (embedding + Qdrant) with keyword fallback.
|
||||
All external calls have timeouts and graceful degradation — if embedding
|
||||
or Qdrant fail, the service falls back to keyword-only (ILIKE) search.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import openai
|
||||
from qdrant_client import AsyncQdrantClient
|
||||
from qdrant_client.http import exceptions as qdrant_exceptions
|
||||
from qdrant_client.models import FieldCondition, Filter, MatchValue
|
||||
from sqlalchemy import or_, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from config import Settings
|
||||
from models import Creator, KeyMoment, TechniquePage
|
||||
|
||||
logger = logging.getLogger("chrysopedia.search")
|
||||
|
||||
# Timeout for external calls (embedding API, Qdrant) in seconds
|
||||
_EXTERNAL_TIMEOUT = 0.3 # 300ms per plan
|
||||
|
||||
|
||||
class SearchService:
|
||||
"""Async search service with semantic + keyword fallback.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
settings:
|
||||
Application settings containing embedding and Qdrant config.
|
||||
"""
|
||||
|
||||
def __init__(self, settings: Settings) -> None:
|
||||
self.settings = settings
|
||||
self._openai = openai.AsyncOpenAI(
|
||||
base_url=settings.embedding_api_url,
|
||||
api_key=settings.llm_api_key,
|
||||
)
|
||||
self._qdrant = AsyncQdrantClient(url=settings.qdrant_url)
|
||||
self._collection = settings.qdrant_collection
|
||||
|
||||
# ── Embedding ────────────────────────────────────────────────────────
|
||||
|
||||
async def embed_query(self, text: str) -> list[float] | None:
|
||||
"""Embed a query string into a vector.
|
||||
|
||||
Returns None on any failure (timeout, connection, malformed response)
|
||||
so the caller can fall back to keyword search.
|
||||
"""
|
||||
try:
|
||||
response = await asyncio.wait_for(
|
||||
self._openai.embeddings.create(
|
||||
model=self.settings.embedding_model,
|
||||
input=text,
|
||||
),
|
||||
timeout=_EXTERNAL_TIMEOUT,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("Embedding API timeout (%.0fms limit) for query: %.50s…", _EXTERNAL_TIMEOUT * 1000, text)
|
||||
return None
|
||||
except (openai.APIConnectionError, openai.APITimeoutError) as exc:
|
||||
logger.warning("Embedding API connection error (%s: %s)", type(exc).__name__, exc)
|
||||
return None
|
||||
except openai.APIError as exc:
|
||||
logger.warning("Embedding API error (%s: %s)", type(exc).__name__, exc)
|
||||
return None
|
||||
|
||||
if not response.data:
|
||||
logger.warning("Embedding API returned empty data for query: %.50s…", text)
|
||||
return None
|
||||
|
||||
vector = response.data[0].embedding
|
||||
if len(vector) != self.settings.embedding_dimensions:
|
||||
logger.warning(
|
||||
"Embedding dimension mismatch: expected %d, got %d",
|
||||
self.settings.embedding_dimensions,
|
||||
len(vector),
|
||||
)
|
||||
return None
|
||||
|
||||
return vector
|
||||
|
||||
# ── Qdrant vector search ─────────────────────────────────────────────
|
||||
|
||||
async def search_qdrant(
|
||||
self,
|
||||
vector: list[float],
|
||||
limit: int = 20,
|
||||
type_filter: str | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Search Qdrant for nearest neighbours.
|
||||
|
||||
Returns a list of dicts with 'score' and 'payload' keys.
|
||||
Returns empty list on failure.
|
||||
"""
|
||||
query_filter = None
|
||||
if type_filter:
|
||||
query_filter = Filter(
|
||||
must=[FieldCondition(key="type", match=MatchValue(value=type_filter))]
|
||||
)
|
||||
|
||||
try:
|
||||
results = await asyncio.wait_for(
|
||||
self._qdrant.query_points(
|
||||
collection_name=self._collection,
|
||||
query=vector,
|
||||
query_filter=query_filter,
|
||||
limit=limit,
|
||||
with_payload=True,
|
||||
),
|
||||
timeout=_EXTERNAL_TIMEOUT,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("Qdrant search timeout (%.0fms limit)", _EXTERNAL_TIMEOUT * 1000)
|
||||
return []
|
||||
except qdrant_exceptions.UnexpectedResponse as exc:
|
||||
logger.warning("Qdrant search error: %s", exc)
|
||||
return []
|
||||
except Exception as exc:
|
||||
logger.warning("Qdrant connection error (%s: %s)", type(exc).__name__, exc)
|
||||
return []
|
||||
|
||||
return [
|
||||
{"score": point.score, "payload": point.payload}
|
||||
for point in results.points
|
||||
]
|
||||
|
||||
# ── Keyword fallback ─────────────────────────────────────────────────
|
||||
|
||||
async def keyword_search(
|
||||
self,
|
||||
query: str,
|
||||
scope: str,
|
||||
limit: int,
|
||||
db: AsyncSession,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""ILIKE keyword search across technique pages, key moments, and creators.
|
||||
|
||||
Searches title/name columns. Returns a unified list of result dicts.
|
||||
"""
|
||||
results: list[dict[str, Any]] = []
|
||||
pattern = f"%{query}%"
|
||||
|
||||
if scope in ("all", "topics"):
|
||||
stmt = (
|
||||
select(TechniquePage)
|
||||
.where(
|
||||
or_(
|
||||
TechniquePage.title.ilike(pattern),
|
||||
TechniquePage.summary.ilike(pattern),
|
||||
)
|
||||
)
|
||||
.limit(limit)
|
||||
)
|
||||
rows = await db.execute(stmt)
|
||||
for tp in rows.scalars().all():
|
||||
results.append({
|
||||
"type": "technique_page",
|
||||
"title": tp.title,
|
||||
"slug": tp.slug,
|
||||
"summary": tp.summary or "",
|
||||
"topic_category": tp.topic_category,
|
||||
"topic_tags": tp.topic_tags or [],
|
||||
"creator_id": str(tp.creator_id),
|
||||
"score": 0.0,
|
||||
})
|
||||
|
||||
if scope in ("all",):
|
||||
km_stmt = (
|
||||
select(KeyMoment)
|
||||
.where(KeyMoment.title.ilike(pattern))
|
||||
.limit(limit)
|
||||
)
|
||||
km_rows = await db.execute(km_stmt)
|
||||
for km in km_rows.scalars().all():
|
||||
results.append({
|
||||
"type": "key_moment",
|
||||
"title": km.title,
|
||||
"slug": "",
|
||||
"summary": km.summary or "",
|
||||
"topic_category": "",
|
||||
"topic_tags": [],
|
||||
"creator_id": "",
|
||||
"score": 0.0,
|
||||
})
|
||||
|
||||
if scope in ("all", "creators"):
|
||||
cr_stmt = (
|
||||
select(Creator)
|
||||
.where(Creator.name.ilike(pattern))
|
||||
.limit(limit)
|
||||
)
|
||||
cr_rows = await db.execute(cr_stmt)
|
||||
for cr in cr_rows.scalars().all():
|
||||
results.append({
|
||||
"type": "creator",
|
||||
"title": cr.name,
|
||||
"slug": cr.slug,
|
||||
"summary": "",
|
||||
"topic_category": "",
|
||||
"topic_tags": cr.genres or [],
|
||||
"creator_id": str(cr.id),
|
||||
"score": 0.0,
|
||||
})
|
||||
|
||||
return results[:limit]
|
||||
|
||||
# ── Orchestrator ─────────────────────────────────────────────────────
|
||||
|
||||
async def search(
|
||||
self,
|
||||
query: str,
|
||||
scope: str,
|
||||
limit: int,
|
||||
db: AsyncSession,
|
||||
) -> dict[str, Any]:
|
||||
"""Run semantic search with keyword fallback.
|
||||
|
||||
Returns a dict matching the SearchResponse schema shape.
|
||||
"""
|
||||
start = time.monotonic()
|
||||
|
||||
# Validate / sanitize inputs
|
||||
if not query or not query.strip():
|
||||
return {"items": [], "total": 0, "query": query, "fallback_used": False}
|
||||
|
||||
# Truncate long queries
|
||||
query = query.strip()[:500]
|
||||
|
||||
# Normalize scope
|
||||
if scope not in ("all", "topics", "creators"):
|
||||
scope = "all"
|
||||
|
||||
# Map scope to Qdrant type filter
|
||||
type_filter_map = {
|
||||
"all": None,
|
||||
"topics": "technique_page",
|
||||
"creators": None, # creators aren't in Qdrant
|
||||
}
|
||||
qdrant_type_filter = type_filter_map.get(scope)
|
||||
|
||||
fallback_used = False
|
||||
items: list[dict[str, Any]] = []
|
||||
|
||||
# Try semantic search
|
||||
vector = await self.embed_query(query)
|
||||
if vector is not None:
|
||||
qdrant_results = await self.search_qdrant(vector, limit=limit, type_filter=qdrant_type_filter)
|
||||
if qdrant_results:
|
||||
# Enrich Qdrant results with DB metadata
|
||||
items = await self._enrich_results(qdrant_results, db)
|
||||
|
||||
# Fallback to keyword search if semantic failed or returned nothing
|
||||
if not items:
|
||||
items = await self.keyword_search(query, scope, limit, db)
|
||||
fallback_used = True
|
||||
|
||||
elapsed_ms = (time.monotonic() - start) * 1000
|
||||
|
||||
logger.info(
|
||||
"Search query=%r scope=%s results=%d fallback=%s latency_ms=%.1f",
|
||||
query,
|
||||
scope,
|
||||
len(items),
|
||||
fallback_used,
|
||||
elapsed_ms,
|
||||
)
|
||||
|
||||
return {
|
||||
"items": items,
|
||||
"total": len(items),
|
||||
"query": query,
|
||||
"fallback_used": fallback_used,
|
||||
}
|
||||
|
||||
# ── Result enrichment ────────────────────────────────────────────────
|
||||
|
||||
async def _enrich_results(
|
||||
self,
|
||||
qdrant_results: list[dict[str, Any]],
|
||||
db: AsyncSession,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Enrich Qdrant results with creator names and slugs from DB."""
|
||||
enriched: list[dict[str, Any]] = []
|
||||
|
||||
# Collect creator_ids to batch-fetch
|
||||
creator_ids = set()
|
||||
for r in qdrant_results:
|
||||
payload = r.get("payload", {})
|
||||
cid = payload.get("creator_id")
|
||||
if cid:
|
||||
creator_ids.add(cid)
|
||||
|
||||
# Batch fetch creators
|
||||
creator_map: dict[str, dict[str, str]] = {}
|
||||
if creator_ids:
|
||||
from sqlalchemy.dialects.postgresql import UUID as PgUUID
|
||||
import uuid as uuid_mod
|
||||
valid_ids = []
|
||||
for cid in creator_ids:
|
||||
try:
|
||||
valid_ids.append(uuid_mod.UUID(cid))
|
||||
except (ValueError, AttributeError):
|
||||
pass
|
||||
|
||||
if valid_ids:
|
||||
stmt = select(Creator).where(Creator.id.in_(valid_ids))
|
||||
result = await db.execute(stmt)
|
||||
for c in result.scalars().all():
|
||||
creator_map[str(c.id)] = {"name": c.name, "slug": c.slug}
|
||||
|
||||
for r in qdrant_results:
|
||||
payload = r.get("payload", {})
|
||||
cid = payload.get("creator_id", "")
|
||||
creator_info = creator_map.get(cid, {"name": "", "slug": ""})
|
||||
|
||||
enriched.append({
|
||||
"type": payload.get("type", ""),
|
||||
"title": payload.get("title", ""),
|
||||
"slug": payload.get("slug", payload.get("title", "").lower().replace(" ", "-")),
|
||||
"summary": payload.get("summary", ""),
|
||||
"topic_category": payload.get("topic_category", ""),
|
||||
"topic_tags": payload.get("topic_tags", []),
|
||||
"creator_id": cid,
|
||||
"creator_name": creator_info["name"],
|
||||
"creator_slug": creator_info["slug"],
|
||||
"score": r.get("score", 0.0),
|
||||
})
|
||||
|
||||
return enriched
|
||||
Loading…
Add table
Reference in a new issue