test: Implemented auth API router with register/login/me/update-profile…

- "backend/routers/auth.py"
- "backend/main.py"
- "backend/auth.py"
- "backend/requirements.txt"
- "backend/tests/conftest.py"
- "backend/tests/test_auth.py"

GSD-Task: S02/T02
This commit is contained in:
jlightner 2026-04-03 21:54:11 +00:00
parent ae62c09881
commit f4020251b9
6 changed files with 535 additions and 7 deletions

View file

@ -6,10 +6,10 @@ import uuid
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from typing import Annotated from typing import Annotated
import bcrypt
import jwt import jwt
from fastapi import Depends, HTTPException, status from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
from passlib.context import CryptContext
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
@ -19,17 +19,15 @@ from models import User, UserRole
# ── Password hashing ───────────────────────────────────────────────────────── # ── Password hashing ─────────────────────────────────────────────────────────
_pwd_ctx = CryptContext(schemes=["bcrypt"], deprecated="auto")
def hash_password(plain: str) -> str: def hash_password(plain: str) -> str:
"""Hash a plaintext password with bcrypt.""" """Hash a plaintext password with bcrypt."""
return _pwd_ctx.hash(plain) return bcrypt.hashpw(plain.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")
def verify_password(plain: str, hashed: str) -> bool: def verify_password(plain: str, hashed: str) -> bool:
"""Verify a plaintext password against a bcrypt hash.""" """Verify a plaintext password against a bcrypt hash."""
return _pwd_ctx.verify(plain, hashed) return bcrypt.checkpw(plain.encode("utf-8"), hashed.encode("utf-8"))
# ── JWT ────────────────────────────────────────────────────────────────────── # ── JWT ──────────────────────────────────────────────────────────────────────

View file

@ -12,7 +12,7 @@ from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from config import get_settings from config import get_settings
from routers import creators, health, ingest, pipeline, reports, search, stats, techniques, topics, videos from routers import auth, creators, health, ingest, pipeline, reports, search, stats, techniques, topics, videos
def _setup_logging() -> None: def _setup_logging() -> None:
@ -78,6 +78,7 @@ app.add_middleware(
app.include_router(health.router) app.include_router(health.router)
# Versioned API # Versioned API
app.include_router(auth.router, prefix="/api/v1")
app.include_router(creators.router, prefix="/api/v1") app.include_router(creators.router, prefix="/api/v1")
app.include_router(ingest.router, prefix="/api/v1") app.include_router(ingest.router, prefix="/api/v1")
app.include_router(pipeline.router, prefix="/api/v1") app.include_router(pipeline.router, prefix="/api/v1")

View file

@ -16,7 +16,7 @@ pyyaml>=6.0,<7.0
psycopg2-binary>=2.9,<3.0 psycopg2-binary>=2.9,<3.0
watchdog>=4.0,<5.0 watchdog>=4.0,<5.0
PyJWT>=2.8,<3.0 PyJWT>=2.8,<3.0
passlib[bcrypt]>=1.7,<2.0 bcrypt>=4.0,<6.0
# Test dependencies # Test dependencies
pytest>=8.0,<10.0 pytest>=8.0,<10.0
pytest-asyncio>=0.24,<1.0 pytest-asyncio>=0.24,<1.0

168
backend/routers/auth.py Normal file
View file

@ -0,0 +1,168 @@
"""Auth router — registration, login, profile management."""
from __future__ import annotations
import logging
from datetime import datetime, timezone
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from auth import (
create_access_token,
get_current_user,
hash_password,
verify_password,
)
from database import get_session
from models import Creator, InviteCode, User
from schemas import (
LoginRequest,
RegisterRequest,
TokenResponse,
UpdateProfileRequest,
UserResponse,
)
logger = logging.getLogger("chrysopedia.auth")
router = APIRouter(prefix="/auth", tags=["auth"])
# ── Registration ─────────────────────────────────────────────────────────────
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
async def register(
body: RegisterRequest,
session: Annotated[AsyncSession, Depends(get_session)],
):
"""Register a new user with a valid invite code."""
# 1. Validate invite code
result = await session.execute(
select(InviteCode).where(InviteCode.code == body.invite_code)
)
invite = result.scalar_one_or_none()
if invite is None:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Invalid invite code")
now = datetime.now(timezone.utc).replace(tzinfo=None)
if invite.expires_at is not None and invite.expires_at < now:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Invite code has expired")
if invite.uses_remaining <= 0:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Invite code exhausted")
# 2. Check email uniqueness
existing = await session.execute(select(User).where(User.email == body.email))
if existing.scalar_one_or_none() is not None:
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Email already registered")
# 3. Optionally resolve creator_id from slug
creator_id = None
if body.creator_slug:
creator_result = await session.execute(
select(Creator).where(Creator.slug == body.creator_slug)
)
creator = creator_result.scalar_one_or_none()
if creator is not None:
creator_id = creator.id
# 4. Create user
user = User(
email=body.email,
hashed_password=hash_password(body.password),
display_name=body.display_name,
creator_id=creator_id,
)
session.add(user)
# 5. Decrement invite code uses
invite.uses_remaining -= 1
await session.commit()
await session.refresh(user)
logger.info("User registered: %s (email=%s)", user.id, user.email)
return user
# ── Login ────────────────────────────────────────────────────────────────────
@router.post("/login", response_model=TokenResponse)
async def login(
body: LoginRequest,
session: Annotated[AsyncSession, Depends(get_session)],
):
"""Authenticate with email + password, return JWT."""
result = await session.execute(select(User).where(User.email == body.email))
user = result.scalar_one_or_none()
if user is None or not verify_password(body.password, user.hashed_password):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid email or password",
)
token = create_access_token(user.id, user.role.value)
logger.info("User logged in: %s", user.id)
return TokenResponse(access_token=token)
# ── Profile ──────────────────────────────────────────────────────────────────
@router.get("/me", response_model=UserResponse)
async def get_profile(
current_user: Annotated[User, Depends(get_current_user)],
):
"""Return the current user's profile."""
return current_user
@router.put("/me", response_model=UserResponse)
async def update_profile(
body: UpdateProfileRequest,
current_user: Annotated[User, Depends(get_current_user)],
session: Annotated[AsyncSession, Depends(get_session)],
):
"""Update the current user's display name and/or password."""
if body.display_name is not None:
current_user.display_name = body.display_name
if body.new_password is not None:
if body.current_password is None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Current password required to set new password",
)
if not verify_password(body.current_password, current_user.hashed_password):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Current password is incorrect",
)
current_user.hashed_password = hash_password(body.new_password)
await session.commit()
await session.refresh(current_user)
logger.info("Profile updated: %s", current_user.id)
return current_user
# ── Seed ─────────────────────────────────────────────────────────────────────
async def seed_invite_codes(session: AsyncSession) -> None:
"""Create default invite code if none exist. Call from lifespan or CLI."""
result = await session.execute(select(InviteCode))
if result.scalar_one_or_none() is None:
session.add(InviteCode(
code="CHRYSOPEDIA-ALPHA-2026",
uses_remaining=100,
))
await session.commit()
logger.info("Seeded default invite code: CHRYSOPEDIA-ALPHA-2026")

View file

@ -34,9 +34,12 @@ from main import app # noqa: E402
from models import ( # noqa: E402 from models import ( # noqa: E402
ContentType, ContentType,
Creator, Creator,
InviteCode,
ProcessingStatus, ProcessingStatus,
SourceVideo, SourceVideo,
TranscriptSegment, TranscriptSegment,
User,
UserRole,
) )
TEST_DATABASE_URL = os.getenv( TEST_DATABASE_URL = os.getenv(
@ -190,3 +193,47 @@ def pre_ingested_video(sync_engine):
session.close() session.close()
return result return result
# ── Auth fixtures ────────────────────────────────────────────────────────────
_TEST_INVITE_CODE = "TEST-INVITE-2026"
_TEST_EMAIL = "testuser@chrysopedia.com"
_TEST_PASSWORD = "securepass123"
@pytest_asyncio.fixture()
async def invite_code(db_engine):
"""Create a test invite code in the DB and return the code string."""
factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
async with factory() as session:
code = InviteCode(code=_TEST_INVITE_CODE, uses_remaining=10)
session.add(code)
await session.commit()
return _TEST_INVITE_CODE
@pytest_asyncio.fixture()
async def registered_user(client, invite_code):
"""Register a user via the API and return the response dict."""
resp = await client.post("/api/v1/auth/register", json={
"email": _TEST_EMAIL,
"password": _TEST_PASSWORD,
"display_name": "Test User",
"invite_code": invite_code,
})
assert resp.status_code == 201
return resp.json()
@pytest_asyncio.fixture()
async def auth_headers(client, registered_user):
"""Log in and return Authorization headers dict."""
resp = await client.post("/api/v1/auth/login", json={
"email": _TEST_EMAIL,
"password": _TEST_PASSWORD,
})
assert resp.status_code == 200
token = resp.json()["access_token"]
return {"Authorization": f"Bearer {token}"}

314
backend/tests/test_auth.py Normal file
View file

@ -0,0 +1,314 @@
"""Integration tests for the auth router — registration, login, profile."""
import uuid
from datetime import datetime, timedelta, timezone
import pytest
import pytest_asyncio
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from models import InviteCode, User
# ── Registration ─────────────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_register_valid(client, invite_code):
"""Register with a valid invite code → 201 + user created."""
resp = await client.post("/api/v1/auth/register", json={
"email": "newuser@example.com",
"password": "strongpass1",
"display_name": "New User",
"invite_code": invite_code,
})
assert resp.status_code == 201
data = resp.json()
assert data["email"] == "newuser@example.com"
assert data["display_name"] == "New User"
assert data["role"] == "creator"
assert "id" in data
# Password not leaked
assert "hashed_password" not in data
@pytest.mark.asyncio
async def test_register_invalid_invite_code(client, invite_code):
"""Register with a wrong invite code → 403."""
resp = await client.post("/api/v1/auth/register", json={
"email": "bad@example.com",
"password": "strongpass1",
"display_name": "Bad",
"invite_code": "WRONG-CODE",
})
assert resp.status_code == 403
assert "Invalid invite code" in resp.json()["detail"]
@pytest.mark.asyncio
async def test_register_expired_invite_code(client, db_engine):
"""Register with an expired invite code → 403."""
factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
async with factory() as session:
past = (datetime.now(timezone.utc) - timedelta(days=1)).replace(tzinfo=None)
session.add(InviteCode(code="EXPIRED-CODE", uses_remaining=10, expires_at=past))
await session.commit()
resp = await client.post("/api/v1/auth/register", json={
"email": "exp@example.com",
"password": "strongpass1",
"display_name": "Expired",
"invite_code": "EXPIRED-CODE",
})
assert resp.status_code == 403
assert "expired" in resp.json()["detail"].lower()
@pytest.mark.asyncio
async def test_register_exhausted_invite_code(client, db_engine):
"""Register with an invite code that has uses_remaining=0 → 403."""
factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
async with factory() as session:
session.add(InviteCode(code="EXHAUSTED", uses_remaining=0))
await session.commit()
resp = await client.post("/api/v1/auth/register", json={
"email": "nope@example.com",
"password": "strongpass1",
"display_name": "Nope",
"invite_code": "EXHAUSTED",
})
assert resp.status_code == 403
assert "exhausted" in resp.json()["detail"].lower()
@pytest.mark.asyncio
async def test_register_invite_code_decrements(client, db_engine):
"""Invite code uses_remaining decrements after registration."""
factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
async with factory() as session:
session.add(InviteCode(code="SINGLE-USE", uses_remaining=1))
await session.commit()
# First registration succeeds
resp = await client.post("/api/v1/auth/register", json={
"email": "first@example.com",
"password": "strongpass1",
"display_name": "First",
"invite_code": "SINGLE-USE",
})
assert resp.status_code == 201
# Second registration with same code fails (exhausted)
resp = await client.post("/api/v1/auth/register", json={
"email": "second@example.com",
"password": "strongpass1",
"display_name": "Second",
"invite_code": "SINGLE-USE",
})
assert resp.status_code == 403
@pytest.mark.asyncio
async def test_register_duplicate_email(client, invite_code, registered_user):
"""Register with an already-used email → 409."""
resp = await client.post("/api/v1/auth/register", json={
"email": "testuser@chrysopedia.com",
"password": "anotherpass1",
"display_name": "Dup",
"invite_code": invite_code,
})
assert resp.status_code == 409
assert "already registered" in resp.json()["detail"].lower()
# ── Login ────────────────────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_login_success(client, registered_user):
"""Login with correct credentials → 200 + JWT."""
resp = await client.post("/api/v1/auth/login", json={
"email": "testuser@chrysopedia.com",
"password": "securepass123",
})
assert resp.status_code == 200
data = resp.json()
assert "access_token" in data
assert data["token_type"] == "bearer"
@pytest.mark.asyncio
async def test_login_wrong_password(client, registered_user):
"""Login with wrong password → 401."""
resp = await client.post("/api/v1/auth/login", json={
"email": "testuser@chrysopedia.com",
"password": "wrongpassword",
})
assert resp.status_code == 401
@pytest.mark.asyncio
async def test_login_nonexistent_email(client):
"""Login with an email that doesn't exist → 401."""
resp = await client.post("/api/v1/auth/login", json={
"email": "nobody@example.com",
"password": "somepass123",
})
assert resp.status_code == 401
# ── Profile (GET /me) ───────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_get_me_authenticated(client, auth_headers):
"""GET /me with valid token → 200 + profile."""
resp = await client.get("/api/v1/auth/me", headers=auth_headers)
assert resp.status_code == 200
data = resp.json()
assert data["email"] == "testuser@chrysopedia.com"
assert data["display_name"] == "Test User"
@pytest.mark.asyncio
async def test_get_me_no_token(client, db_engine):
"""GET /me without token → 401."""
resp = await client.get("/api/v1/auth/me")
assert resp.status_code == 401
@pytest.mark.asyncio
async def test_get_me_invalid_token(client, db_engine):
"""GET /me with garbage token → 401."""
resp = await client.get("/api/v1/auth/me", headers={
"Authorization": "Bearer invalid.garbage.token",
})
assert resp.status_code == 401
@pytest.mark.asyncio
async def test_get_me_expired_token(client, db_engine, invite_code):
"""GET /me with an expired JWT → 401."""
from auth import create_access_token
# Register a user first
resp = await client.post("/api/v1/auth/register", json={
"email": "expired@example.com",
"password": "strongpass1",
"display_name": "Expired Token User",
"invite_code": invite_code,
})
assert resp.status_code == 201
user_id = resp.json()["id"]
# Create a token that expires immediately
token = create_access_token(user_id, "creator", expires_minutes=-1)
resp = await client.get("/api/v1/auth/me", headers={
"Authorization": f"Bearer {token}",
})
assert resp.status_code == 401
# ── Profile (PUT /me) ───────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_update_display_name(client, auth_headers):
"""PUT /me updates display_name → 200 + new name."""
resp = await client.put("/api/v1/auth/me", json={
"display_name": "Updated Name",
}, headers=auth_headers)
assert resp.status_code == 200
assert resp.json()["display_name"] == "Updated Name"
# Verify persistence
resp2 = await client.get("/api/v1/auth/me", headers=auth_headers)
assert resp2.json()["display_name"] == "Updated Name"
@pytest.mark.asyncio
async def test_update_password(client, invite_code):
"""PUT /me changes password → can login with new password."""
# Register
await client.post("/api/v1/auth/register", json={
"email": "pwchange@example.com",
"password": "oldpassword1",
"display_name": "PW User",
"invite_code": invite_code,
})
# Login
login_resp = await client.post("/api/v1/auth/login", json={
"email": "pwchange@example.com",
"password": "oldpassword1",
})
token = login_resp.json()["access_token"]
headers = {"Authorization": f"Bearer {token}"}
# Change password
resp = await client.put("/api/v1/auth/me", json={
"current_password": "oldpassword1",
"new_password": "newpassword1",
}, headers=headers)
assert resp.status_code == 200
# Old password fails
resp_old = await client.post("/api/v1/auth/login", json={
"email": "pwchange@example.com",
"password": "oldpassword1",
})
assert resp_old.status_code == 401
# New password works
resp_new = await client.post("/api/v1/auth/login", json={
"email": "pwchange@example.com",
"password": "newpassword1",
})
assert resp_new.status_code == 200
# ── Malformed inputs (422 validation) ───────────────────────────────────────
@pytest.mark.asyncio
async def test_register_missing_fields(client):
"""Register with missing fields → 422."""
resp = await client.post("/api/v1/auth/register", json={})
assert resp.status_code == 422
@pytest.mark.asyncio
async def test_register_empty_password(client):
"""Register with empty password → 422 (min_length=8)."""
resp = await client.post("/api/v1/auth/register", json={
"email": "a@b.com",
"password": "",
"display_name": "X",
"invite_code": "CODE",
})
assert resp.status_code == 422
@pytest.mark.asyncio
async def test_login_empty_body(client):
"""Login with empty body → 422."""
resp = await client.post("/api/v1/auth/login", json={})
assert resp.status_code == 422
# ── Public endpoints unaffected ──────────────────────────────────────────────
@pytest.mark.asyncio
async def test_public_techniques_no_auth(client, db_engine):
"""GET /api/v1/techniques works without auth."""
resp = await client.get("/api/v1/techniques")
# 200 even if empty — no 401/403
assert resp.status_code == 200
@pytest.mark.asyncio
async def test_public_creators_no_auth(client, db_engine):
"""GET /api/v1/creators works without auth."""
resp = await client.get("/api/v1/creators")
assert resp.status_code == 200