- "backend/models.py" - "backend/schemas.py" - "backend/routers/auth.py" - "alembic/versions/030_add_onboarding_completed.py" GSD-Task: S03/T01
189 lines
6.8 KiB
Python
189 lines
6.8 KiB
Python
"""Auth router — registration, login, profile management."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
from datetime import datetime, timezone
|
|
from typing import Annotated
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, status
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from auth import (
|
|
create_access_token,
|
|
get_current_user,
|
|
hash_password,
|
|
reject_impersonation,
|
|
verify_password,
|
|
)
|
|
from database import get_session
|
|
from models import Creator, InviteCode, User
|
|
from schemas import (
|
|
LoginRequest,
|
|
RegisterRequest,
|
|
TokenResponse,
|
|
UpdateProfileRequest,
|
|
UserResponse,
|
|
)
|
|
|
|
# Dedicated named logger so auth events can be filtered/configured separately.
logger = logging.getLogger("chrysopedia.auth")

# Every route declared on this router is served under the /auth prefix and
# grouped under the "auth" tag in the generated OpenAPI docs.
router = APIRouter(prefix="/auth", tags=["auth"])
|
|
|
|
|
|
# ── Registration ─────────────────────────────────────────────────────────────
|
|
|
|
|
|
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
async def register(
    body: RegisterRequest,
    session: Annotated[AsyncSession, Depends(get_session)],
):
    """Register a new user with a valid invite code.

    The invite code is validated (exists, not expired, has uses left), the
    email is checked for uniqueness, an optional creator slug is resolved to
    a creator id, and finally the user is created and one invite use is
    consumed in a single commit.

    Args:
        body: Registration payload (email, password, display name, invite
            code and optional creator slug).
        session: Request-scoped async database session.

    Returns:
        The newly created user, serialized through ``UserResponse``.

    Raises:
        HTTPException: 403 if the invite code is unknown, expired or
            exhausted; 409 if the email is already registered.
    """
    # Function-scope import so this block does not depend on a change to the
    # file-level import section.
    from sqlalchemy.exc import IntegrityError

    # 1. Validate invite code.
    # FOR UPDATE locks the invite row until this transaction ends, so two
    # concurrent registrations cannot both observe the same uses_remaining
    # and redeem the last use twice. (Backends without row locks, such as
    # SQLite, simply omit the clause.)
    invite_result = await session.execute(
        select(InviteCode)
        .where(InviteCode.code == body.invite_code)
        .with_for_update()
    )
    invite = invite_result.scalar_one_or_none()
    if invite is None:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Invalid invite code")

    # The comparison is done on naive datetimes; presumably expires_at is
    # stored as naive UTC -- confirm against the model.
    now = datetime.now(timezone.utc).replace(tzinfo=None)
    if invite.expires_at is not None and invite.expires_at < now:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Invite code has expired")

    if invite.uses_remaining <= 0:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Invite code exhausted")

    # 2. Check email uniqueness
    existing = await session.execute(select(User).where(User.email == body.email))
    if existing.scalar_one_or_none() is not None:
        raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Email already registered")

    # 3. Optionally resolve creator_id from slug.
    # An unknown slug is not an error: the user is created without a creator.
    creator_id = None
    if body.creator_slug:
        creator_result = await session.execute(
            select(Creator).where(Creator.slug == body.creator_slug)
        )
        creator = creator_result.scalar_one_or_none()
        if creator is not None:
            creator_id = creator.id

    # 4. Create user
    user = User(
        email=body.email,
        hashed_password=hash_password(body.password),
        display_name=body.display_name,
        creator_id=creator_id,
    )
    session.add(user)

    # 5. Decrement invite code uses
    invite.uses_remaining -= 1

    try:
        await session.commit()
    except IntegrityError:
        # Another request may have inserted the same email between the check
        # in step 2 and this commit (presumably rejected by a unique
        # constraint on the email column -- confirm against the model).
        # Report that as 409 instead of an unhandled 500.
        await session.rollback()
        duplicate = await session.execute(select(User).where(User.email == body.email))
        if duplicate.scalar_one_or_none() is None:
            # Some other constraint failed; do not mislabel it.
            raise
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Email already registered",
        ) from None

    # Reload server-generated columns (e.g. the primary key) before returning.
    await session.refresh(user)

    logger.info("User registered: %s (email=%s)", user.id, user.email)
    return user
|
|
|
|
|
|
# ── Login ────────────────────────────────────────────────────────────────────
|
|
|
|
|
|
@router.post("/login", response_model=TokenResponse)
async def login(
    body: LoginRequest,
    session: Annotated[AsyncSession, Depends(get_session)],
):
    """Exchange an email/password pair for a JWT access token.

    Raises:
        HTTPException: 401 when no user has the given email or the password
            does not verify. Both cases share one message so the response
            does not reveal which of the two was wrong.
    """
    lookup = await session.execute(select(User).where(User.email == body.email))
    account = lookup.scalar_one_or_none()

    # Short-circuit: the password is only verified when an account was found.
    credentials_ok = account is not None and verify_password(
        body.password, account.hashed_password
    )
    if not credentials_ok:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid email or password",
        )

    # The token is built from the user's id and the value of their role enum.
    access_token = create_access_token(account.id, account.role.value)
    logger.info("User logged in: %s", account.id)
    return TokenResponse(access_token=access_token)
|
|
|
|
|
|
# ── Profile ──────────────────────────────────────────────────────────────────
|
|
|
|
|
|
@router.get("/me", response_model=UserResponse)
async def get_profile(
    current_user: Annotated[User, Depends(get_current_user)],
):
    """Return the authenticated user's profile.

    The response's ``impersonating`` field is set to True when the user
    object carries a non-None ``_impersonating_admin_id`` attribute.
    """
    profile = UserResponse.model_validate(current_user)

    # The attribute may be absent entirely, hence getattr with a default.
    if getattr(current_user, "_impersonating_admin_id", None) is not None:
        profile.impersonating = True

    return profile
|
|
|
|
|
|
@router.put("/me", response_model=UserResponse)
async def update_profile(
    body: UpdateProfileRequest,
    current_user: Annotated[User, Depends(reject_impersonation)],
    session: Annotated[AsyncSession, Depends(get_session)],
):
    """Change the authenticated user's display name and/or password.

    Fields left as None in the request are not touched. Setting a new
    password requires the current password to be supplied and to verify.

    Raises:
        HTTPException: 400 when a new password is given without the current
            one, or when the current password does not verify.
    """
    new_display_name = body.display_name
    if new_display_name is not None:
        current_user.display_name = new_display_name

    wants_password_change = body.new_password is not None
    if wants_password_change:
        supplied_password = body.current_password
        if supplied_password is None:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Current password required to set new password",
            )

        password_matches = verify_password(supplied_password, current_user.hashed_password)
        if not password_matches:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Current password is incorrect",
            )

        current_user.hashed_password = hash_password(body.new_password)

    # Persist the changes, then reload the row before serializing it.
    await session.commit()
    await session.refresh(current_user)

    logger.info("Profile updated: %s", current_user.id)
    return current_user
|
|
|
|
|
|
# ── Seed ─────────────────────────────────────────────────────────────────────
|
|
|
|
|
|
async def seed_invite_codes(session: AsyncSession) -> None:
    """Create the default invite code if the invite table is empty.

    Intended to be called from the application lifespan or a CLI command.
    Does nothing when at least one invite code already exists.

    Args:
        session: Async database session used for the check and the insert.
    """
    default_code = "CHRYSOPEDIA-ALPHA-2026"

    # Only existence matters here. LIMIT 1 + first() tolerates any number of
    # rows; scalar_one_or_none() on an unbounded select raises
    # MultipleResultsFound as soon as two or more invite codes exist.
    existing = await session.execute(select(InviteCode).limit(1))
    if existing.scalars().first() is not None:
        return

    session.add(InviteCode(code=default_code, uses_remaining=100))
    await session.commit()
    logger.info("Seeded default invite code: %s", default_code)
|
|
|
|
|
|
# ── Onboarding ───────────────────────────────────────────────────────────────
|
|
|
|
|
|
@router.post("/onboarding-complete", response_model=UserResponse)
async def complete_onboarding(
    current_user: Annotated[User, Depends(get_current_user)],
    session: Annotated[AsyncSession, Depends(get_session)],
):
    """Flag the authenticated user as having finished onboarding.

    Sets ``onboarding_completed`` to True unconditionally, commits, and
    returns the refreshed user.
    """
    # NOTE(review): this mutating endpoint depends on get_current_user,
    # whereas update_profile depends on reject_impersonation -- confirm that
    # allowing it during an impersonated session is intended.
    current_user.onboarding_completed = True

    await session.commit()
    await session.refresh(current_user)

    logger.info("Onboarding completed: %s", current_user.id)
    refreshed = UserResponse.model_validate(current_user)
    return refreshed
|