test: Added GET /api/v1/search/suggestions endpoint returning popular t…

- "backend/schemas.py"
- "backend/routers/search.py"
- "backend/tests/test_search.py"

GSD-Task: S04/T01
This commit is contained in:
jlightner 2026-03-31 06:35:37 +00:00
parent 836fcb2304
commit 9107323a66
3 changed files with 197 additions and 1 deletions

View file

@ -6,11 +6,18 @@ import logging
from typing import Annotated
from fastapi import APIRouter, Depends, Query
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from config import get_settings
from database import get_session
from schemas import SearchResponse, SearchResultItem
from models import Creator, TechniquePage
from schemas import (
SearchResponse,
SearchResultItem,
SuggestionItem,
SuggestionsResponse,
)
from search_service import SearchService
logger = logging.getLogger("chrysopedia.search.router")
@ -44,3 +51,68 @@ async def search(
query=result["query"],
fallback_used=result["fallback_used"],
)
@router.get("/suggestions", response_model=SuggestionsResponse)
async def suggestions(
    db: AsyncSession = Depends(get_session),
) -> SuggestionsResponse:
    """Return popular search suggestions for autocomplete.

    Combines three sources, in this order:

    1. top technique pages (by ``view_count``),
    2. popular topic tags (by how many technique pages carry them),
    3. top non-hidden creators (by ``view_count``).

    Each source contributes at most 4 entries and ties are broken
    alphabetically. Entries are de-duplicated case-insensitively across all
    sources (the first occurrence wins), so the response holds between 0 and
    12 items; an empty database yields an empty list.

    Args:
        db: Async database session injected by FastAPI.

    Returns:
        A ``SuggestionsResponse`` whose ``suggestions`` list preserves the
        source order described above.
    """
    seen: set[str] = set()
    items: list[SuggestionItem] = []

    def _add(text: str, type_: str) -> None:
        """Append a suggestion unless it is blank or already present."""
        # unnest() yields a NULL row for a NULL array element, and blank
        # strings are useless as suggestions: skip both instead of crashing
        # on ``None.lower()`` or emitting an empty entry.
        if text is None or not text.strip():
            return
        key = text.lower()
        if key not in seen:
            seen.add(key)
            items.append(SuggestionItem(text=text, type=type_))

    # Top 4 technique pages by view_count (title as a deterministic tiebreak).
    tp_stmt = (
        select(TechniquePage.title)
        .order_by(TechniquePage.view_count.desc(), TechniquePage.title)
        .limit(4)
    )
    tp_result = await db.execute(tp_stmt)
    for (title,) in tp_result.all():
        _add(title, "technique")

    # Top 4 topic tags by how many technique pages use them:
    # unnest the topic_tags ARRAY into one row per tag, then count occurrences.
    tag_unnest = (
        select(
            func.unnest(TechniquePage.topic_tags).label("tag"),
        )
        .where(TechniquePage.topic_tags.isnot(None))
        .subquery()
    )
    tag_stmt = (
        select(
            tag_unnest.c.tag,
            func.count().label("cnt"),
        )
        .group_by(tag_unnest.c.tag)
        .order_by(func.count().desc(), tag_unnest.c.tag)
        .limit(4)
    )
    tag_result = await db.execute(tag_stmt)
    for tag, _cnt in tag_result.all():
        _add(tag, "topic")

    # Top 4 visible creators by view_count (name as a deterministic tiebreak).
    cr_stmt = (
        select(Creator.name)
        .where(Creator.hidden.is_(False))
        .order_by(Creator.view_count.desc(), Creator.name)
        .limit(4)
    )
    cr_result = await db.execute(cr_stmt)
    for (name,) in cr_result.all():
        _add(name, "creator")

    return SuggestionsResponse(suggestions=items)

View file

@ -8,6 +8,7 @@ from __future__ import annotations
import uuid
from datetime import datetime
from typing import Literal
from pydantic import BaseModel, ConfigDict, Field
@ -223,6 +224,17 @@ class SearchResponse(BaseModel):
fallback_used: bool = False
class SuggestionItem(BaseModel):
    """One entry offered to the client as an autocomplete suggestion."""

    # Text of the suggestion.
    text: str
    # Kind of suggestion; restricted to exactly these three literal values.
    type: Literal["topic", "technique", "creator"]
class SuggestionsResponse(BaseModel):
    """Response envelope carrying the popular autocomplete suggestions."""

    # Defaults to a fresh empty list when no suggestions are supplied.
    suggestions: list[SuggestionItem] = Field(default_factory=list)
# ── Technique Page Detail ────────────────────────────────────────────────────
class KeyMomentSummary(BaseModel):

View file

@ -403,3 +403,115 @@ async def test_keyword_search_key_moment_without_technique_page(db_engine):
km_results = [r for r in results if r["type"] == "key_moment"]
assert len(km_results) == 1
assert km_results[0]["technique_page_slug"] == ""
# ── Suggestions endpoint tests ───────────────────────────────────────────────
# URL path of the suggestions endpoint requested by the tests in this section.
SUGGESTIONS_URL = "/api/v1/search/suggestions"
@pytest.mark.asyncio
async def test_suggestions_returns_correct_shape(client, db_engine):
    """Every suggestion exposes a non-empty ``text`` and a known ``type``."""
    await _seed_search_data(db_engine)

    response = await client.get(SUGGESTIONS_URL)
    assert response.status_code == 200

    payload = response.json()
    assert "suggestions" in payload
    entries = payload["suggestions"]
    assert isinstance(entries, list)
    assert len(entries) > 0

    allowed_types = ("topic", "technique", "creator")
    for entry in entries:
        assert "text" in entry
        assert "type" in entry
        assert entry["type"] in allowed_types
        assert len(entry["text"]) > 0
@pytest.mark.asyncio
async def test_suggestions_includes_all_types(client, db_engine):
    """Seeded data must surface technique, topic, and creator suggestions."""
    await _seed_search_data(db_engine)

    response = await client.get(SUGGESTIONS_URL)
    assert response.status_code == 200

    found_types = set()
    for entry in response.json()["suggestions"]:
        found_types.add(entry["type"])

    # Checked in the same order as before: technique, topic, creator.
    expectations = (
        ("technique", "Expected technique suggestions"),
        ("topic", "Expected topic suggestions"),
        ("creator", "Expected creator suggestions"),
    )
    for kind, failure_message in expectations:
        assert kind in found_types, failure_message
@pytest.mark.asyncio
async def test_suggestions_no_duplicates(client, db_engine):
    """No two suggestions may share the same text, ignoring letter case."""
    await _seed_search_data(db_engine)

    response = await client.get(SUGGESTIONS_URL)
    assert response.status_code == 200

    lowered = [entry["text"].lower() for entry in response.json()["suggestions"]]
    distinct = set(lowered)
    assert len(lowered) == len(distinct), "Duplicate suggestions found"
@pytest.mark.asyncio
async def test_suggestions_empty_db(client, db_engine):
    """With nothing seeded, the endpoint answers 200 with no suggestions."""
    response = await client.get(SUGGESTIONS_URL)
    assert response.status_code == 200

    payload = response.json()
    assert payload["suggestions"] == []
@pytest.mark.asyncio
async def test_suggestions_respects_view_count_ordering(client, db_engine):
    """Among technique suggestions, the more-viewed page is listed first."""
    make_session = async_sessionmaker(
        db_engine, class_=AsyncSession, expire_on_commit=False
    )
    async with make_session() as session:
        owner = Creator(
            name="Test Creator",
            slug="test-creator",
            genres=["Electronic"],
            folder_name="TestCreator",
            view_count=10,
        )
        session.add(owner)
        # Flush so the creator's id is populated before the pages reference it.
        await session.flush()

        rarely_viewed = TechniquePage(
            creator_id=owner.id,
            title="Low Views Page",
            slug="low-views-page",
            topic_category="Sound design",
            topic_tags=["bass"],
            view_count=5,
        )
        often_viewed = TechniquePage(
            creator_id=owner.id,
            title="High Views Page",
            slug="high-views-page",
            topic_category="Synthesis",
            topic_tags=["pads"],
            view_count=100,
        )
        session.add_all([rarely_viewed, often_viewed])
        await session.commit()

    response = await client.get(SUGGESTIONS_URL)
    assert response.status_code == 200

    technique_entries = []
    for entry in response.json()["suggestions"]:
        if entry["type"] == "technique":
            technique_entries.append(entry)
    assert len(technique_entries) >= 2

    # The page with 100 views must precede the page with 5 views.
    ordered_titles = [entry["text"] for entry in technique_entries]
    high_position = ordered_titles.index("High Views Page")
    low_position = ordered_titles.index("Low Views Page")
    assert high_position < low_position