"""Topics endpoint — two-level category hierarchy with aggregated counts.""" from __future__ import annotations import logging import os from typing import Annotated, Any import yaml from fastapi import APIRouter, Depends, Query from sqlalchemy import func, select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload from database import get_session from models import Creator, TechniquePage from schemas import ( PaginatedResponse, TechniquePageRead, TopicCategory, TopicSubTopic, ) logger = logging.getLogger("chrysopedia.topics") router = APIRouter(prefix="/topics", tags=["topics"]) # Path to canonical_tags.yaml relative to the backend directory _TAGS_PATH = os.path.join(os.path.dirname(__file__), "..", "..", "config", "canonical_tags.yaml") def _load_canonical_tags() -> list[dict[str, Any]]: """Load the canonical tag categories from YAML.""" path = os.path.normpath(_TAGS_PATH) try: with open(path) as f: data = yaml.safe_load(f) return data.get("categories", []) except FileNotFoundError: logger.warning("canonical_tags.yaml not found at %s", path) return [] @router.get("", response_model=list[TopicCategory]) async def list_topics( db: AsyncSession = Depends(get_session), ) -> list[TopicCategory]: """Return the two-level topic hierarchy with technique/creator counts per sub-topic. Categories come from ``canonical_tags.yaml``. Counts are computed from live DB data by matching ``topic_tags`` array contents. """ categories = _load_canonical_tags() # Pre-fetch all technique pages with their tags and creator_ids for counting tp_stmt = select( TechniquePage.topic_category, TechniquePage.topic_tags, TechniquePage.creator_id, ) tp_result = await db.execute(tp_stmt) tp_rows = tp_result.all() # Build per-sub-topic counts result: list[TopicCategory] = [] for cat in categories: cat_name = cat.get("name", "") cat_desc = cat.get("description", "") sub_topic_names: list[str] = cat.get("sub_topics", []) sub_topics: list[TopicSubTopic] = [] for st_name in sub_topic_names: technique_count = 0 creator_ids: set[str] = set() for tp_cat, tp_tags, tp_creator_id in tp_rows: tags = tp_tags or [] # Match if the sub-topic name appears in the technique's tags # or if the category matches and tag is in sub-topics if st_name.lower() in [t.lower() for t in tags]: technique_count += 1 creator_ids.add(str(tp_creator_id)) sub_topics.append( TopicSubTopic( name=st_name, technique_count=technique_count, creator_count=len(creator_ids), ) ) result.append( TopicCategory( name=cat_name, description=cat_desc, sub_topics=sub_topics, ) ) return result @router.get("/{category_slug}", response_model=PaginatedResponse) async def get_topic_techniques( category_slug: str, offset: Annotated[int, Query(ge=0)] = 0, limit: Annotated[int, Query(ge=1, le=100)] = 50, db: AsyncSession = Depends(get_session), ) -> PaginatedResponse: """Return technique pages filtered by topic_category. The ``category_slug`` is matched case-insensitively against ``technique_pages.topic_category`` (e.g. 'sound-design' matches 'Sound design'). """ # Normalize slug to category name: replace hyphens with spaces, title-case category_name = category_slug.replace("-", " ").title() # Also try exact match on the slug form stmt = select(TechniquePage).where( TechniquePage.topic_category.ilike(category_name) ) count_stmt = select(func.count()).select_from(stmt.subquery()) count_result = await db.execute(count_stmt) total = count_result.scalar() or 0 stmt = stmt.options(selectinload(TechniquePage.creator)).order_by(TechniquePage.title).offset(offset).limit(limit) result = await db.execute(stmt) pages = result.scalars().all() items = [] for p in pages: item = TechniquePageRead.model_validate(p) if p.creator: item.creator_name = p.creator.name item.creator_slug = p.creator.slug items.append(item) return PaginatedResponse( items=items, total=total, offset=offset, limit=limit, )