test: Added GET /api/v1/search/suggestions endpoint returning popular t…
- "backend/schemas.py" - "backend/routers/search.py" - "backend/tests/test_search.py" GSD-Task: S04/T01
This commit is contained in:
parent
836fcb2304
commit
9107323a66
3 changed files with 197 additions and 1 deletions
|
|
@ -6,11 +6,18 @@ import logging
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Query
|
from fastapi import APIRouter, Depends, Query
|
||||||
|
from sqlalchemy import func, select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from config import get_settings
|
from config import get_settings
|
||||||
from database import get_session
|
from database import get_session
|
||||||
from schemas import SearchResponse, SearchResultItem
|
from models import Creator, TechniquePage
|
||||||
|
from schemas import (
|
||||||
|
SearchResponse,
|
||||||
|
SearchResultItem,
|
||||||
|
SuggestionItem,
|
||||||
|
SuggestionsResponse,
|
||||||
|
)
|
||||||
from search_service import SearchService
|
from search_service import SearchService
|
||||||
|
|
||||||
logger = logging.getLogger("chrysopedia.search.router")
|
logger = logging.getLogger("chrysopedia.search.router")
|
||||||
|
|
@ -44,3 +51,68 @@ async def search(
|
||||||
query=result["query"],
|
query=result["query"],
|
||||||
fallback_used=result["fallback_used"],
|
fallback_used=result["fallback_used"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/suggestions", response_model=SuggestionsResponse)
|
||||||
|
async def suggestions(
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> SuggestionsResponse:
|
||||||
|
"""Return popular search suggestions for autocomplete.
|
||||||
|
|
||||||
|
Combines top technique pages (by view_count), popular topic tags
|
||||||
|
(by technique count), and top creators (by view_count).
|
||||||
|
Returns 8–12 deduplicated items.
|
||||||
|
"""
|
||||||
|
seen: set[str] = set()
|
||||||
|
items: list[SuggestionItem] = []
|
||||||
|
|
||||||
|
def _add(text: str, type_: str) -> None:
|
||||||
|
key = text.lower()
|
||||||
|
if key not in seen:
|
||||||
|
seen.add(key)
|
||||||
|
items.append(SuggestionItem(text=text, type=type_))
|
||||||
|
|
||||||
|
# Top 4 technique pages by view_count
|
||||||
|
tp_stmt = (
|
||||||
|
select(TechniquePage.title)
|
||||||
|
.order_by(TechniquePage.view_count.desc(), TechniquePage.title)
|
||||||
|
.limit(4)
|
||||||
|
)
|
||||||
|
tp_result = await db.execute(tp_stmt)
|
||||||
|
for (title,) in tp_result.all():
|
||||||
|
_add(title, "technique")
|
||||||
|
|
||||||
|
# Top 4 topic tags by how many technique pages use them
|
||||||
|
# Unnest the topic_tags ARRAY and count occurrences
|
||||||
|
tag_unnest = (
|
||||||
|
select(
|
||||||
|
func.unnest(TechniquePage.topic_tags).label("tag"),
|
||||||
|
)
|
||||||
|
.where(TechniquePage.topic_tags.isnot(None))
|
||||||
|
.subquery()
|
||||||
|
)
|
||||||
|
tag_stmt = (
|
||||||
|
select(
|
||||||
|
tag_unnest.c.tag,
|
||||||
|
func.count().label("cnt"),
|
||||||
|
)
|
||||||
|
.group_by(tag_unnest.c.tag)
|
||||||
|
.order_by(func.count().desc(), tag_unnest.c.tag)
|
||||||
|
.limit(4)
|
||||||
|
)
|
||||||
|
tag_result = await db.execute(tag_stmt)
|
||||||
|
for tag, _cnt in tag_result.all():
|
||||||
|
_add(tag, "topic")
|
||||||
|
|
||||||
|
# Top 4 creators by view_count
|
||||||
|
cr_stmt = (
|
||||||
|
select(Creator.name)
|
||||||
|
.where(Creator.hidden.is_(False))
|
||||||
|
.order_by(Creator.view_count.desc(), Creator.name)
|
||||||
|
.limit(4)
|
||||||
|
)
|
||||||
|
cr_result = await db.execute(cr_stmt)
|
||||||
|
for (name,) in cr_result.all():
|
||||||
|
_add(name, "creator")
|
||||||
|
|
||||||
|
return SuggestionsResponse(suggestions=items)
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@ from __future__ import annotations
|
||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
|
|
@ -223,6 +224,17 @@ class SearchResponse(BaseModel):
|
||||||
fallback_used: bool = False
|
fallback_used: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class SuggestionItem(BaseModel):
|
||||||
|
"""A single autocomplete suggestion."""
|
||||||
|
text: str
|
||||||
|
type: Literal["topic", "technique", "creator"]
|
||||||
|
|
||||||
|
|
||||||
|
class SuggestionsResponse(BaseModel):
|
||||||
|
"""Popular search suggestions for autocomplete."""
|
||||||
|
suggestions: list[SuggestionItem] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
# ── Technique Page Detail ────────────────────────────────────────────────────
|
# ── Technique Page Detail ────────────────────────────────────────────────────
|
||||||
|
|
||||||
class KeyMomentSummary(BaseModel):
|
class KeyMomentSummary(BaseModel):
|
||||||
|
|
|
||||||
|
|
@ -403,3 +403,115 @@ async def test_keyword_search_key_moment_without_technique_page(db_engine):
|
||||||
km_results = [r for r in results if r["type"] == "key_moment"]
|
km_results = [r for r in results if r["type"] == "key_moment"]
|
||||||
assert len(km_results) == 1
|
assert len(km_results) == 1
|
||||||
assert km_results[0]["technique_page_slug"] == ""
|
assert km_results[0]["technique_page_slug"] == ""
|
||||||
|
|
||||||
|
|
||||||
|
# ── Suggestions endpoint tests ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
SUGGESTIONS_URL = "/api/v1/search/suggestions"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_suggestions_returns_correct_shape(client, db_engine):
|
||||||
|
"""Suggestions endpoint returns items with text and type fields."""
|
||||||
|
await _seed_search_data(db_engine)
|
||||||
|
|
||||||
|
resp = await client.get(SUGGESTIONS_URL)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
data = resp.json()
|
||||||
|
assert "suggestions" in data
|
||||||
|
assert isinstance(data["suggestions"], list)
|
||||||
|
assert len(data["suggestions"]) > 0
|
||||||
|
|
||||||
|
for item in data["suggestions"]:
|
||||||
|
assert "text" in item
|
||||||
|
assert "type" in item
|
||||||
|
assert item["type"] in ("topic", "technique", "creator")
|
||||||
|
assert len(item["text"]) > 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_suggestions_includes_all_types(client, db_engine):
|
||||||
|
"""Suggestions should include technique, topic, and creator types."""
|
||||||
|
await _seed_search_data(db_engine)
|
||||||
|
|
||||||
|
resp = await client.get(SUGGESTIONS_URL)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
data = resp.json()
|
||||||
|
types_present = {item["type"] for item in data["suggestions"]}
|
||||||
|
assert "technique" in types_present, "Expected technique suggestions"
|
||||||
|
assert "topic" in types_present, "Expected topic suggestions"
|
||||||
|
assert "creator" in types_present, "Expected creator suggestions"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_suggestions_no_duplicates(client, db_engine):
|
||||||
|
"""Suggestions should not contain duplicate texts (case-insensitive)."""
|
||||||
|
await _seed_search_data(db_engine)
|
||||||
|
|
||||||
|
resp = await client.get(SUGGESTIONS_URL)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
data = resp.json()
|
||||||
|
texts_lower = [item["text"].lower() for item in data["suggestions"]]
|
||||||
|
assert len(texts_lower) == len(set(texts_lower)), "Duplicate suggestions found"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_suggestions_empty_db(client, db_engine):
|
||||||
|
"""Suggestions endpoint returns empty list on empty database."""
|
||||||
|
resp = await client.get(SUGGESTIONS_URL)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
data = resp.json()
|
||||||
|
assert data["suggestions"] == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_suggestions_respects_view_count_ordering(client, db_engine):
|
||||||
|
"""Higher view_count technique pages should appear first among techniques."""
|
||||||
|
session_factory = async_sessionmaker(
|
||||||
|
db_engine, class_=AsyncSession, expire_on_commit=False
|
||||||
|
)
|
||||||
|
async with session_factory() as session:
|
||||||
|
creator = Creator(
|
||||||
|
name="Test Creator",
|
||||||
|
slug="test-creator",
|
||||||
|
genres=["Electronic"],
|
||||||
|
folder_name="TestCreator",
|
||||||
|
view_count=10,
|
||||||
|
)
|
||||||
|
session.add(creator)
|
||||||
|
await session.flush()
|
||||||
|
|
||||||
|
tp_low = TechniquePage(
|
||||||
|
creator_id=creator.id,
|
||||||
|
title="Low Views Page",
|
||||||
|
slug="low-views-page",
|
||||||
|
topic_category="Sound design",
|
||||||
|
topic_tags=["bass"],
|
||||||
|
view_count=5,
|
||||||
|
)
|
||||||
|
tp_high = TechniquePage(
|
||||||
|
creator_id=creator.id,
|
||||||
|
title="High Views Page",
|
||||||
|
slug="high-views-page",
|
||||||
|
topic_category="Synthesis",
|
||||||
|
topic_tags=["pads"],
|
||||||
|
view_count=100,
|
||||||
|
)
|
||||||
|
session.add_all([tp_low, tp_high])
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
resp = await client.get(SUGGESTIONS_URL)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
data = resp.json()
|
||||||
|
technique_items = [
|
||||||
|
item for item in data["suggestions"] if item["type"] == "technique"
|
||||||
|
]
|
||||||
|
assert len(technique_items) >= 2
|
||||||
|
# High Views Page should come before Low Views Page
|
||||||
|
titles = [item["text"] for item in technique_items]
|
||||||
|
assert titles.index("High Views Page") < titles.index("Low Views Page")
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue