test: Added GET /api/v1/search/suggestions endpoint returning popular t…

- "backend/schemas.py"
- "backend/routers/search.py"
- "backend/tests/test_search.py"

GSD-Task: S04/T01
This commit is contained in:
jlightner 2026-03-31 06:35:37 +00:00
parent 836fcb2304
commit 9107323a66
3 changed files with 197 additions and 1 deletions

View file

@ -6,11 +6,18 @@ import logging
from typing import Annotated from typing import Annotated
from fastapi import APIRouter, Depends, Query from fastapi import APIRouter, Depends, Query
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from config import get_settings from config import get_settings
from database import get_session from database import get_session
from schemas import SearchResponse, SearchResultItem from models import Creator, TechniquePage
from schemas import (
SearchResponse,
SearchResultItem,
SuggestionItem,
SuggestionsResponse,
)
from search_service import SearchService from search_service import SearchService
logger = logging.getLogger("chrysopedia.search.router") logger = logging.getLogger("chrysopedia.search.router")
@ -44,3 +51,68 @@ async def search(
query=result["query"], query=result["query"],
fallback_used=result["fallback_used"], fallback_used=result["fallback_used"],
) )
@router.get("/suggestions", response_model=SuggestionsResponse)
async def suggestions(
    db: AsyncSession = Depends(get_session),
) -> SuggestionsResponse:
    """Return popular search suggestions for autocomplete.

    Three queries are run in order, each capped at four rows:

    1. technique page titles, highest ``view_count`` first;
    2. topic tags, ranked by how many technique pages carry them;
    3. names of non-hidden creators, highest ``view_count`` first.

    Texts are deduplicated case-insensitively across all three groups
    (the first occurrence wins), so the response holds at most 12 items.
    """
    collected: list[SuggestionItem] = []
    used_keys: set[str] = set()

    def _extend(labels, kind: str) -> None:
        # Append every label whose lowercase form has not been emitted yet.
        for label in labels:
            folded = label.lower()
            if folded in used_keys:
                continue
            used_keys.add(folded)
            collected.append(SuggestionItem(text=label, type=kind))

    # Technique pages: most viewed first, title as the tie-breaker.
    technique_query = select(TechniquePage.title)
    technique_query = technique_query.order_by(
        TechniquePage.view_count.desc(), TechniquePage.title
    )
    technique_query = technique_query.limit(4)
    technique_rows = (await db.execute(technique_query)).all()
    _extend((title for (title,) in technique_rows), "technique")

    # Topic tags: expand the topic_tags ARRAY into one row per tag, then
    # count how many times each tag occurs.
    expanded_tags = (
        select(func.unnest(TechniquePage.topic_tags).label("tag"))
        .where(TechniquePage.topic_tags.isnot(None))
        .subquery()
    )
    tag_column = expanded_tags.c.tag
    tag_query = (
        select(tag_column, func.count().label("cnt"))
        .group_by(tag_column)
        .order_by(func.count().desc(), tag_column)
        .limit(4)
    )
    tag_rows = (await db.execute(tag_query)).all()
    _extend((tag for tag, _count in tag_rows), "topic")

    # Creators: hidden creators are excluded; most viewed first, name as
    # the tie-breaker.
    creator_query = select(Creator.name).where(Creator.hidden.is_(False))
    creator_query = creator_query.order_by(
        Creator.view_count.desc(), Creator.name
    )
    creator_query = creator_query.limit(4)
    creator_rows = (await db.execute(creator_query)).all()
    _extend((name for (name,) in creator_rows), "creator")

    return SuggestionsResponse(suggestions=collected)

View file

@ -8,6 +8,7 @@ from __future__ import annotations
import uuid import uuid
from datetime import datetime from datetime import datetime
from typing import Literal
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field
@ -223,6 +224,17 @@ class SearchResponse(BaseModel):
fallback_used: bool = False fallback_used: bool = False
class SuggestionItem(BaseModel):
    """One entry in the autocomplete suggestion list."""

    # Text of the suggestion.
    text: str
    # Kind of suggestion; validation restricts it to these three values.
    type: Literal["topic", "technique", "creator"]
class SuggestionsResponse(BaseModel):
    """Response body carrying the popular search suggestions."""

    # Defaults to a fresh empty list when no suggestions are supplied.
    suggestions: list[SuggestionItem] = Field(default_factory=list)
# ── Technique Page Detail ──────────────────────────────────────────────────── # ── Technique Page Detail ────────────────────────────────────────────────────
class KeyMomentSummary(BaseModel): class KeyMomentSummary(BaseModel):

View file

@ -403,3 +403,115 @@ async def test_keyword_search_key_moment_without_technique_page(db_engine):
km_results = [r for r in results if r["type"] == "key_moment"] km_results = [r for r in results if r["type"] == "key_moment"]
assert len(km_results) == 1 assert len(km_results) == 1
assert km_results[0]["technique_page_slug"] == "" assert km_results[0]["technique_page_slug"] == ""
# ── Suggestions endpoint tests ───────────────────────────────────────────────
# URL requested by every suggestions test in this section.
SUGGESTIONS_URL = "/api/v1/search/suggestions"
@pytest.mark.asyncio
async def test_suggestions_returns_correct_shape(client, db_engine):
    """Each suggestion carries a non-empty text and a known type."""
    await _seed_search_data(db_engine)

    response = await client.get(SUGGESTIONS_URL)
    assert response.status_code == 200

    payload = response.json()
    assert "suggestions" in payload
    entries = payload["suggestions"]
    assert isinstance(entries, list)
    assert len(entries) > 0

    allowed_types = ("topic", "technique", "creator")
    for entry in entries:
        assert "text" in entry
        assert "type" in entry
        assert entry["type"] in allowed_types
        assert len(entry["text"]) > 0
@pytest.mark.asyncio
async def test_suggestions_includes_all_types(client, db_engine):
    """Seeded data yields technique, topic, and creator suggestions."""
    await _seed_search_data(db_engine)

    response = await client.get(SUGGESTIONS_URL)
    assert response.status_code == 200

    returned_types = set()
    for entry in response.json()["suggestions"]:
        returned_types.add(entry["type"])

    # Checked in the same order as before so the first failure is unchanged.
    expectations = (
        ("technique", "Expected technique suggestions"),
        ("topic", "Expected topic suggestions"),
        ("creator", "Expected creator suggestions"),
    )
    for kind, failure_message in expectations:
        assert kind in returned_types, failure_message
@pytest.mark.asyncio
async def test_suggestions_no_duplicates(client, db_engine):
    """No two suggestions share the same text, ignoring case."""
    await _seed_search_data(db_engine)

    response = await client.get(SUGGESTIONS_URL)
    assert response.status_code == 200

    lowered = []
    for entry in response.json()["suggestions"]:
        lowered.append(entry["text"].lower())

    # A set drops repeats, so equal lengths mean every text is unique.
    unique_count = len(set(lowered))
    assert len(lowered) == unique_count, "Duplicate suggestions found"
@pytest.mark.asyncio
async def test_suggestions_empty_db(client, db_engine):
    """With nothing seeded, the endpoint answers 200 with an empty list."""
    response = await client.get(SUGGESTIONS_URL)
    assert response.status_code == 200

    payload = response.json()
    assert payload["suggestions"] == []
@pytest.mark.asyncio
async def test_suggestions_respects_view_count_ordering(client, db_engine):
    """A technique page with more views is listed before one with fewer."""
    make_session = async_sessionmaker(
        db_engine, class_=AsyncSession, expire_on_commit=False
    )

    async with make_session() as session:
        owner = Creator(
            name="Test Creator",
            slug="test-creator",
            genres=["Electronic"],
            folder_name="TestCreator",
            view_count=10,
        )
        session.add(owner)
        # Flush before building the pages, which read owner.id.
        await session.flush()

        rarely_viewed = TechniquePage(
            creator_id=owner.id,
            title="Low Views Page",
            slug="low-views-page",
            topic_category="Sound design",
            topic_tags=["bass"],
            view_count=5,
        )
        often_viewed = TechniquePage(
            creator_id=owner.id,
            title="High Views Page",
            slug="high-views-page",
            topic_category="Synthesis",
            topic_tags=["pads"],
            view_count=100,
        )
        session.add_all([rarely_viewed, often_viewed])
        await session.commit()

    response = await client.get(SUGGESTIONS_URL)
    assert response.status_code == 200

    technique_entries = []
    for entry in response.json()["suggestions"]:
        if entry["type"] == "technique":
            technique_entries.append(entry)
    assert len(technique_entries) >= 2

    # High Views Page should come before Low Views Page
    ordered_titles = [entry["text"] for entry in technique_entries]
    high_position = ordered_titles.index("High Views Page")
    low_position = ordered_titles.index("Low Views Page")
    assert high_position < low_position