test: Added GET /api/v1/search/suggestions endpoint returning popular t…
- "backend/schemas.py" - "backend/routers/search.py" - "backend/tests/test_search.py" GSD-Task: S04/T01
This commit is contained in:
parent
836fcb2304
commit
9107323a66
3 changed files with 197 additions and 1 deletions
|
|
@ -6,11 +6,18 @@ import logging
|
|||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from config import get_settings
|
||||
from database import get_session
|
||||
from schemas import SearchResponse, SearchResultItem
|
||||
from models import Creator, TechniquePage
|
||||
from schemas import (
|
||||
SearchResponse,
|
||||
SearchResultItem,
|
||||
SuggestionItem,
|
||||
SuggestionsResponse,
|
||||
)
|
||||
from search_service import SearchService
|
||||
|
||||
logger = logging.getLogger("chrysopedia.search.router")
|
||||
|
|
@ -44,3 +51,68 @@ async def search(
|
|||
query=result["query"],
|
||||
fallback_used=result["fallback_used"],
|
||||
)
|
||||
|
||||
|
||||
@router.get("/suggestions", response_model=SuggestionsResponse)
|
||||
async def suggestions(
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> SuggestionsResponse:
|
||||
"""Return popular search suggestions for autocomplete.
|
||||
|
||||
Combines top technique pages (by view_count), popular topic tags
|
||||
(by technique count), and top creators (by view_count).
|
||||
Returns 8–12 deduplicated items.
|
||||
"""
|
||||
seen: set[str] = set()
|
||||
items: list[SuggestionItem] = []
|
||||
|
||||
def _add(text: str, type_: str) -> None:
|
||||
key = text.lower()
|
||||
if key not in seen:
|
||||
seen.add(key)
|
||||
items.append(SuggestionItem(text=text, type=type_))
|
||||
|
||||
# Top 4 technique pages by view_count
|
||||
tp_stmt = (
|
||||
select(TechniquePage.title)
|
||||
.order_by(TechniquePage.view_count.desc(), TechniquePage.title)
|
||||
.limit(4)
|
||||
)
|
||||
tp_result = await db.execute(tp_stmt)
|
||||
for (title,) in tp_result.all():
|
||||
_add(title, "technique")
|
||||
|
||||
# Top 4 topic tags by how many technique pages use them
|
||||
# Unnest the topic_tags ARRAY and count occurrences
|
||||
tag_unnest = (
|
||||
select(
|
||||
func.unnest(TechniquePage.topic_tags).label("tag"),
|
||||
)
|
||||
.where(TechniquePage.topic_tags.isnot(None))
|
||||
.subquery()
|
||||
)
|
||||
tag_stmt = (
|
||||
select(
|
||||
tag_unnest.c.tag,
|
||||
func.count().label("cnt"),
|
||||
)
|
||||
.group_by(tag_unnest.c.tag)
|
||||
.order_by(func.count().desc(), tag_unnest.c.tag)
|
||||
.limit(4)
|
||||
)
|
||||
tag_result = await db.execute(tag_stmt)
|
||||
for tag, _cnt in tag_result.all():
|
||||
_add(tag, "topic")
|
||||
|
||||
# Top 4 creators by view_count
|
||||
cr_stmt = (
|
||||
select(Creator.name)
|
||||
.where(Creator.hidden.is_(False))
|
||||
.order_by(Creator.view_count.desc(), Creator.name)
|
||||
.limit(4)
|
||||
)
|
||||
cr_result = await db.execute(cr_stmt)
|
||||
for (name,) in cr_result.all():
|
||||
_add(name, "creator")
|
||||
|
||||
return SuggestionsResponse(suggestions=items)
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from __future__ import annotations
|
|||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
|
@ -223,6 +224,17 @@ class SearchResponse(BaseModel):
|
|||
fallback_used: bool = False
|
||||
|
||||
|
||||
class SuggestionItem(BaseModel):
|
||||
"""A single autocomplete suggestion."""
|
||||
text: str
|
||||
type: Literal["topic", "technique", "creator"]
|
||||
|
||||
|
||||
class SuggestionsResponse(BaseModel):
|
||||
"""Popular search suggestions for autocomplete."""
|
||||
suggestions: list[SuggestionItem] = Field(default_factory=list)
|
||||
|
||||
|
||||
# ── Technique Page Detail ────────────────────────────────────────────────────
|
||||
|
||||
class KeyMomentSummary(BaseModel):
|
||||
|
|
|
|||
|
|
@ -403,3 +403,115 @@ async def test_keyword_search_key_moment_without_technique_page(db_engine):
|
|||
km_results = [r for r in results if r["type"] == "key_moment"]
|
||||
assert len(km_results) == 1
|
||||
assert km_results[0]["technique_page_slug"] == ""
|
||||
|
||||
|
||||
# ── Suggestions endpoint tests ───────────────────────────────────────────────
|
||||
|
||||
SUGGESTIONS_URL = "/api/v1/search/suggestions"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_suggestions_returns_correct_shape(client, db_engine):
|
||||
"""Suggestions endpoint returns items with text and type fields."""
|
||||
await _seed_search_data(db_engine)
|
||||
|
||||
resp = await client.get(SUGGESTIONS_URL)
|
||||
assert resp.status_code == 200
|
||||
|
||||
data = resp.json()
|
||||
assert "suggestions" in data
|
||||
assert isinstance(data["suggestions"], list)
|
||||
assert len(data["suggestions"]) > 0
|
||||
|
||||
for item in data["suggestions"]:
|
||||
assert "text" in item
|
||||
assert "type" in item
|
||||
assert item["type"] in ("topic", "technique", "creator")
|
||||
assert len(item["text"]) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_suggestions_includes_all_types(client, db_engine):
|
||||
"""Suggestions should include technique, topic, and creator types."""
|
||||
await _seed_search_data(db_engine)
|
||||
|
||||
resp = await client.get(SUGGESTIONS_URL)
|
||||
assert resp.status_code == 200
|
||||
|
||||
data = resp.json()
|
||||
types_present = {item["type"] for item in data["suggestions"]}
|
||||
assert "technique" in types_present, "Expected technique suggestions"
|
||||
assert "topic" in types_present, "Expected topic suggestions"
|
||||
assert "creator" in types_present, "Expected creator suggestions"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_suggestions_no_duplicates(client, db_engine):
|
||||
"""Suggestions should not contain duplicate texts (case-insensitive)."""
|
||||
await _seed_search_data(db_engine)
|
||||
|
||||
resp = await client.get(SUGGESTIONS_URL)
|
||||
assert resp.status_code == 200
|
||||
|
||||
data = resp.json()
|
||||
texts_lower = [item["text"].lower() for item in data["suggestions"]]
|
||||
assert len(texts_lower) == len(set(texts_lower)), "Duplicate suggestions found"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_suggestions_empty_db(client, db_engine):
|
||||
"""Suggestions endpoint returns empty list on empty database."""
|
||||
resp = await client.get(SUGGESTIONS_URL)
|
||||
assert resp.status_code == 200
|
||||
|
||||
data = resp.json()
|
||||
assert data["suggestions"] == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_suggestions_respects_view_count_ordering(client, db_engine):
|
||||
"""Higher view_count technique pages should appear first among techniques."""
|
||||
session_factory = async_sessionmaker(
|
||||
db_engine, class_=AsyncSession, expire_on_commit=False
|
||||
)
|
||||
async with session_factory() as session:
|
||||
creator = Creator(
|
||||
name="Test Creator",
|
||||
slug="test-creator",
|
||||
genres=["Electronic"],
|
||||
folder_name="TestCreator",
|
||||
view_count=10,
|
||||
)
|
||||
session.add(creator)
|
||||
await session.flush()
|
||||
|
||||
tp_low = TechniquePage(
|
||||
creator_id=creator.id,
|
||||
title="Low Views Page",
|
||||
slug="low-views-page",
|
||||
topic_category="Sound design",
|
||||
topic_tags=["bass"],
|
||||
view_count=5,
|
||||
)
|
||||
tp_high = TechniquePage(
|
||||
creator_id=creator.id,
|
||||
title="High Views Page",
|
||||
slug="high-views-page",
|
||||
topic_category="Synthesis",
|
||||
topic_tags=["pads"],
|
||||
view_count=100,
|
||||
)
|
||||
session.add_all([tp_low, tp_high])
|
||||
await session.commit()
|
||||
|
||||
resp = await client.get(SUGGESTIONS_URL)
|
||||
assert resp.status_code == 200
|
||||
|
||||
data = resp.json()
|
||||
technique_items = [
|
||||
item for item in data["suggestions"] if item["type"] == "technique"
|
||||
]
|
||||
assert len(technique_items) >= 2
|
||||
# High Views Page should come before Low Views Page
|
||||
titles = [item["text"] for item in technique_items]
|
||||
assert titles.index("High Views Page") < titles.index("Low Views Page")
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue