"""Admin router — user management, impersonation, and usage analytics."""

from __future__ import annotations

import logging
from datetime import datetime, timedelta, timezone
from typing import Annotated
from uuid import UUID

from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
from pydantic import BaseModel
from sqlalchemy import case, func, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import aliased

from auth import (
    create_impersonation_token,
    decode_access_token,
    get_current_user,
    require_role,
)
from database import get_session
from models import ChatUsageLog, ImpersonationLog, User, UserRole

logger = logging.getLogger("chrysopedia.admin")

router = APIRouter(prefix="/admin", tags=["admin"])

# Dependency built by auth.require_role for UserRole.admin; presumably it
# resolves the current user and rejects callers without that role.
_require_admin = require_role(UserRole.admin)

# Name reported in the audit log when the referenced user row no longer exists.
_UNKNOWN_USER = "(deleted user)"

# Number of rows in the "top creators" / "top users" leaderboards.
_TOP_N = 10


# ── Schemas ──────────────────────────────────────────────────────────────────

class UserListItem(BaseModel):
    """Summary of a user account as returned by the admin endpoints."""

    id: str
    email: str
    display_name: str
    role: str  # UserRole value
    creator_id: str | None
    is_active: bool

    class Config:
        from_attributes = True


class ImpersonateResponse(BaseModel):
    """Token issued when an admin starts impersonating a user."""

    access_token: str
    token_type: str = "bearer"
    target_user: UserListItem


class StopImpersonateResponse(BaseModel):
    """Acknowledgement returned when impersonation is stopped."""

    message: str


class StartImpersonationRequest(BaseModel):
    """Optional body for starting impersonation."""

    # When False (the default) the token is requested in read-only mode.
    write_mode: bool = False


class ImpersonationLogItem(BaseModel):
    """One row of the impersonation audit log."""

    id: str
    admin_name: str
    target_name: str
    action: str  # "start" or "stop" as written by this module
    write_mode: bool
    ip_address: str | None
    created_at: datetime


# ── Helpers ──────────────────────────────────────────────────────────────────

def _client_ip(request: Request) -> str | None:
    """Best-effort client IP from X-Forwarded-For or direct connection.

    Returns the left-most ``X-Forwarded-For`` entry when that header is set,
    otherwise the socket peer host, otherwise ``None``.
    """
    # NOTE(review): the left-most X-Forwarded-For entry is client-supplied and
    # can be spoofed unless a trusted reverse proxy overwrites the header.
    # This value is written to the impersonation audit log — confirm the
    # deployment's proxy setup (or use a trusted-proxy middleware).
    forwarded = request.headers.get("x-forwarded-for")
    if forwarded:
        return forwarded.split(",")[0].strip()
    if request.client:
        return request.client.host
    return None


def _user_list_item(user: User) -> UserListItem:
    """Serialize a ``User`` row into the ``UserListItem`` response schema."""
    return UserListItem(
        id=str(user.id),
        email=user.email,
        display_name=user.display_name,
        role=user.role.value,
        creator_id=str(user.creator_id) if user.creator_id else None,
        is_active=user.is_active,
    )


# ── Endpoints ────────────────────────────────────────────────────────────────

@router.get("/users", response_model=list[UserListItem])
async def list_users(
    _admin: Annotated[User, Depends(_require_admin)],
    session: Annotated[AsyncSession, Depends(get_session)],
):
    """List all users ordered by display name. Admin only."""
    result = await session.execute(select(User).order_by(User.display_name))
    return [_user_list_item(u) for u in result.scalars().all()]


# IMPORTANT: this static route must be registered BEFORE
# "/impersonate/{user_id}". FastAPI matches routes in declaration order, so if
# the parameterised route came first it would capture "/impersonate/stop"
# ("stop" as user_id) and this handler would never be reached.
@router.post("/impersonate/stop", response_model=StopImpersonateResponse)
async def stop_impersonation(
    request: Request,
    current_user: Annotated[User, Depends(get_current_user)],
    session: Annotated[AsyncSession, Depends(get_session)],
):
    """Stop impersonation. Requires a valid impersonation token.

    ``current_user`` is the impersonated (target) user; the originating admin
    id is read from its ``_impersonating_admin_id`` attribute, which is
    presumably attached by ``get_current_user`` for impersonation tokens.
    This endpoint only records a "stop" audit entry — it does not revoke the
    token, so the client is expected to discard it.

    Raises:
        HTTPException(400): the caller is not using an impersonation token.
    """
    admin_id = getattr(current_user, "_impersonating_admin_id", None)
    if admin_id is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Not currently impersonating",
        )

    # Capture before commit: if the session expires instances on commit,
    # reading ORM attributes afterwards would need implicit (disallowed)
    # async I/O.
    target_id = current_user.id

    # Audit log
    session.add(ImpersonationLog(
        admin_user_id=admin_id,
        target_user_id=target_id,
        action="stop",
        ip_address=_client_ip(request),
    ))
    await session.commit()

    logger.info(
        "Impersonation stopped: admin=%s target=%s",
        admin_id, target_id,
    )
    return StopImpersonateResponse(message="Impersonation ended")


@router.post("/impersonate/{user_id}", response_model=ImpersonateResponse)
async def start_impersonation(
    user_id: UUID,
    request: Request,
    admin: Annotated[User, Depends(_require_admin)],
    session: Annotated[AsyncSession, Depends(get_session)],
    body: StartImpersonationRequest | None = None,
):
    """Start impersonating a user. Admin only. Returns a scoped JWT.

    A missing body is treated as ``StartImpersonationRequest()`` (read-only).
    Every start is recorded in ``ImpersonationLog`` with the caller's IP.

    Raises:
        HTTPException(400): the admin tried to impersonate themselves.
        HTTPException(404): no user exists with ``user_id``.
    """
    if body is None:
        body = StartImpersonationRequest()

    if admin.id == user_id:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Cannot impersonate yourself",
        )

    result = await session.execute(select(User).where(User.id == user_id))
    target = result.scalar_one_or_none()
    if target is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Target user not found",
        )

    token = create_impersonation_token(
        admin_user_id=admin.id,
        target_user_id=target.id,
        target_role=target.role.value,
        write_mode=body.write_mode,
    )

    # Serialize before commit: if the session expires instances on commit,
    # touching ORM attributes afterwards would need implicit (disallowed)
    # async I/O.
    admin_id = admin.id
    target_id = target.id
    target_item = _user_list_item(target)

    # Audit log
    session.add(ImpersonationLog(
        admin_user_id=admin_id,
        target_user_id=target_id,
        action="start",
        write_mode=body.write_mode,
        ip_address=_client_ip(request),
    ))
    await session.commit()

    logger.info(
        "Impersonation started: admin=%s target=%s write_mode=%s",
        admin_id, target_id, body.write_mode,
    )
    return ImpersonateResponse(access_token=token, target_user=target_item)


@router.get("/impersonation-log", response_model=list[ImpersonationLogItem])
async def get_impersonation_log(
    _admin: Annotated[User, Depends(_require_admin)],
    session: Annotated[AsyncSession, Depends(get_session)],
    page: int = Query(1, ge=1),
    page_size: int = Query(50, ge=1, le=200),
):
    """Paginated impersonation audit log, newest first. Admin only.

    Outer joins keep entries whose admin or target user row no longer exists;
    such names are reported as "(deleted user)".
    """
    admin_user = aliased(User, name="admin_user")
    target_user = aliased(User, name="target_user")
    stmt = (
        select(ImpersonationLog, admin_user.display_name, target_user.display_name)
        .outerjoin(admin_user, ImpersonationLog.admin_user_id == admin_user.id)
        .outerjoin(target_user, ImpersonationLog.target_user_id == target_user.id)
        # The id tiebreak keeps pagination stable when created_at values tie.
        .order_by(ImpersonationLog.created_at.desc(), ImpersonationLog.id.desc())
        .offset((page - 1) * page_size)
        .limit(page_size)
    )
    rows = (await session.execute(stmt)).all()
    return [
        ImpersonationLogItem(
            id=str(log.id),
            admin_name=admin_name or _UNKNOWN_USER,
            target_name=target_name or _UNKNOWN_USER,
            action=log.action,
            write_mode=log.write_mode,
            ip_address=log.ip_address,
            created_at=log.created_at,
        )
        for log, admin_name, target_name in rows
    ]


@router.post("/creators/{slug}/extract-profile")
async def extract_creator_profile(
    slug: str,
    _admin: Annotated[User, Depends(_require_admin)],
    session: Annotated[AsyncSession, Depends(get_session)],
):
    """Queue personality profile extraction for a creator. Admin only.

    Returns ``{"status": "queued", "creator_id": ...}`` once the task has been
    submitted via ``.delay``; extraction itself runs out of band.

    Raises:
        HTTPException(404): no creator has the given slug.
    """
    # Function-scope imports kept as in the original (possibly to avoid
    # import cycles or heavy imports at module load — unconfirmed).
    from models import Creator

    result = await session.execute(select(Creator).where(Creator.slug == slug))
    creator = result.scalar_one_or_none()
    if creator is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Creator not found: {slug}",
        )

    from pipeline.stages import extract_personality_profile

    extract_personality_profile.delay(str(creator.id))
    logger.info("Queued personality extraction for creator=%s (%s)", slug, creator.id)
    return {"status": "queued", "creator_id": str(creator.id)}


# ── Usage Analytics ──────────────────────────────────────────────────────────

class _PeriodStats(BaseModel):
    """Request and token totals for one time window."""

    request_count: int
    total_tokens: int
    prompt_tokens: int
    completion_tokens: int


class _CreatorUsage(BaseModel):
    """Leaderboard row: usage attributed to one creator."""

    creator_slug: str
    request_count: int
    total_tokens: int


class _UserUsage(BaseModel):
    """Leaderboard row: usage attributed to one user or anonymous IP."""

    identifier: str  # display_name or IP
    request_count: int
    total_tokens: int


class _DailyCount(BaseModel):
    """Number of chat requests on one calendar day."""

    date: str  # ISO date YYYY-MM-DD
    request_count: int


class UsageStatsResponse(BaseModel):
    """Payload of ``GET /admin/usage``."""

    today: _PeriodStats
    week: _PeriodStats
    month: _PeriodStats
    top_creators: list[_CreatorUsage]
    top_users: list[_UserUsage]
    daily_counts: list[_DailyCount]


async def _period_stats(
    session: AsyncSession,
    since: datetime,
) -> _PeriodStats:
    """Aggregate token stats for chat usage since a given timestamp.

    Token sums are COALESCEd to 0 so an empty window yields zeros, not NULLs.
    """
    stmt = select(
        func.count().label("cnt"),
        func.coalesce(func.sum(ChatUsageLog.total_tokens), 0).label("total"),
        func.coalesce(func.sum(ChatUsageLog.prompt_tokens), 0).label("prompt"),
        func.coalesce(func.sum(ChatUsageLog.completion_tokens), 0).label("completion"),
    ).where(ChatUsageLog.created_at >= since)
    row = (await session.execute(stmt)).one()
    return _PeriodStats(
        request_count=row.cnt,
        total_tokens=row.total,
        prompt_tokens=row.prompt,
        completion_tokens=row.completion,
    )


async def _top_creators(
    session: AsyncSession,
    since: datetime,
    limit: int = _TOP_N,
) -> list[_CreatorUsage]:
    """Return the creators with the most total tokens since ``since``."""
    cnt = func.count().label("cnt")
    total = func.coalesce(func.sum(ChatUsageLog.total_tokens), 0).label("total")
    stmt = (
        select(ChatUsageLog.creator_slug, cnt, total)
        .where(
            ChatUsageLog.created_at >= since,
            ChatUsageLog.creator_slug.isnot(None),
        )
        .group_by(ChatUsageLog.creator_slug)
        # Order by the COALESCEd total: PostgreSQL sorts NULLs first under
        # DESC, so ordering by the raw SUM would put creators with no token
        # data at the top of the leaderboard.
        .order_by(total.desc(), ChatUsageLog.creator_slug)
        .limit(limit)
    )
    rows = (await session.execute(stmt)).all()
    return [
        _CreatorUsage(creator_slug=r.creator_slug, request_count=r.cnt, total_tokens=r.total)
        for r in rows
    ]


async def _top_users(
    session: AsyncSession,
    since: datetime,
    limit: int = _TOP_N,
) -> list[_UserUsage]:
    """Return the most active users (by request count) since ``since``.

    Authenticated traffic is bucketed per user and anonymous traffic per
    client IP. Users are labelled by display name, falling back to their id
    if the user row is missing; anonymous buckets by IP (or "anonymous").
    """
    # Only anonymous rows contribute their IP to the grouping key; otherwise
    # one logged-in user seen from several IPs would be split across rows.
    anon_ip = case((ChatUsageLog.user_id.is_(None), ChatUsageLog.client_ip))
    cnt = func.count().label("cnt")
    total = func.coalesce(func.sum(ChatUsageLog.total_tokens), 0).label("total")
    stmt = (
        select(ChatUsageLog.user_id, anon_ip.label("client_ip"), cnt, total)
        .where(ChatUsageLog.created_at >= since)
        .group_by(ChatUsageLog.user_id, anon_ip)
        .order_by(cnt.desc())
        .limit(limit)
    )
    rows = (await session.execute(stmt)).all()

    # Resolve display names for the authenticated buckets in one query.
    user_ids = [r.user_id for r in rows if r.user_id is not None]
    names: dict[str, str] = {}
    if user_ids:
        name_result = await session.execute(
            select(User.id, User.display_name).where(User.id.in_(user_ids))
        )
        names = {str(uid): name for uid, name in name_result.all()}

    top: list[_UserUsage] = []
    for r in rows:
        if r.user_id is not None:
            identifier = names.get(str(r.user_id), str(r.user_id))
        else:
            identifier = r.client_ip or "anonymous"
        top.append(
            _UserUsage(identifier=identifier, request_count=r.cnt, total_tokens=r.total)
        )
    return top


async def _daily_counts(
    session: AsyncSession,
    since: datetime,
) -> list[_DailyCount]:
    """Return per-day request counts since ``since``, oldest day first.

    Days with no requests produce no row (they are omitted, not zero-filled).
    """
    day_col = func.date_trunc("day", ChatUsageLog.created_at).label("day")
    stmt = (
        select(day_col, func.count().label("cnt"))
        .where(ChatUsageLog.created_at >= since)
        .group_by(day_col)
        .order_by(day_col)
    )
    rows = (await session.execute(stmt)).all()
    return [
        _DailyCount(date=r.day.strftime("%Y-%m-%d"), request_count=r.cnt)
        for r in rows
    ]


@router.get("/usage", response_model=UsageStatsResponse)
async def get_usage_stats(
    _admin: Annotated[User, Depends(_require_admin)],
    session: Annotated[AsyncSession, Depends(get_session)],
):
    """Aggregated chat usage statistics. Admin only.

    Window starts are computed from the current UTC time: today from
    midnight, the week from Monday midnight, the month from the 1st. The
    leaderboards cover the current month; daily counts cover today plus the
    six preceding days.
    """
    # Naive UTC boundaries — presumably ChatUsageLog.created_at is stored as
    # naive UTC; TODO confirm against the model/column definition.
    now = datetime.now(timezone.utc).replace(tzinfo=None)
    today_start = now.replace(hour=0, minute=0, second=0, microsecond=0)
    week_start = today_start - timedelta(days=today_start.weekday())  # Monday
    month_start = today_start.replace(day=1)
    seven_days_ago = today_start - timedelta(days=6)

    # Queries run sequentially: they share one AsyncSession, which does not
    # support concurrent use.
    today = await _period_stats(session, today_start)
    week = await _period_stats(session, week_start)
    month = await _period_stats(session, month_start)
    top_creators = await _top_creators(session, month_start)
    top_users = await _top_users(session, month_start)
    daily_counts = await _daily_counts(session, seven_days_ago)

    return UsageStatsResponse(
        today=today,
        week=week,
        month=month,
        top_creators=top_creators,
        top_users=top_users,
        daily_counts=daily_counts,
    )