diff --git a/backend/auth.py b/backend/auth.py index 9e464c3..ed05a6d 100644 --- a/backend/auth.py +++ b/backend/auth.py @@ -6,10 +6,10 @@ import uuid from datetime import datetime, timedelta, timezone from typing import Annotated +import bcrypt import jwt from fastapi import Depends, HTTPException, status from fastapi.security import OAuth2PasswordBearer -from passlib.context import CryptContext from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession @@ -19,17 +19,15 @@ from models import User, UserRole # ── Password hashing ───────────────────────────────────────────────────────── -_pwd_ctx = CryptContext(schemes=["bcrypt"], deprecated="auto") - def hash_password(plain: str) -> str: """Hash a plaintext password with bcrypt.""" - return _pwd_ctx.hash(plain) + return bcrypt.hashpw(plain.encode("utf-8"), bcrypt.gensalt()).decode("utf-8") def verify_password(plain: str, hashed: str) -> bool: """Verify a plaintext password against a bcrypt hash.""" - return _pwd_ctx.verify(plain, hashed) + return bcrypt.checkpw(plain.encode("utf-8"), hashed.encode("utf-8")) # ── JWT ────────────────────────────────────────────────────────────────────── diff --git a/backend/main.py b/backend/main.py index c1202d8..68bde81 100644 --- a/backend/main.py +++ b/backend/main.py @@ -12,7 +12,7 @@ from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from config import get_settings -from routers import creators, health, ingest, pipeline, reports, search, stats, techniques, topics, videos +from routers import auth, creators, health, ingest, pipeline, reports, search, stats, techniques, topics, videos def _setup_logging() -> None: @@ -78,6 +78,7 @@ app.add_middleware( app.include_router(health.router) # Versioned API +app.include_router(auth.router, prefix="/api/v1") app.include_router(creators.router, prefix="/api/v1") app.include_router(ingest.router, prefix="/api/v1") app.include_router(pipeline.router, prefix="/api/v1") diff --git a/backend/requirements.txt b/backend/requirements.txt index 58085b0..a2547d0 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -16,7 +16,7 @@ pyyaml>=6.0,<7.0 psycopg2-binary>=2.9,<3.0 watchdog>=4.0,<5.0 PyJWT>=2.8,<3.0 -passlib[bcrypt]>=1.7,<2.0 +bcrypt>=4.0,<6.0 # Test dependencies pytest>=8.0,<10.0 pytest-asyncio>=0.24,<1.0 diff --git a/backend/routers/auth.py b/backend/routers/auth.py new file mode 100644 index 0000000..79b1219 --- /dev/null +++ b/backend/routers/auth.py @@ -0,0 +1,168 @@ +"""Auth router — registration, login, profile management.""" + +from __future__ import annotations + +import logging +from datetime import datetime, timezone +from typing import Annotated + +from fastapi import APIRouter, Depends, HTTPException, status +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from auth import ( + create_access_token, + get_current_user, + hash_password, + verify_password, +) +from database import get_session +from models import Creator, InviteCode, User +from schemas import ( + LoginRequest, + RegisterRequest, + TokenResponse, + UpdateProfileRequest, + UserResponse, +) + +logger = logging.getLogger("chrysopedia.auth") + +router = APIRouter(prefix="/auth", tags=["auth"]) + + +# ── Registration ───────────────────────────────────────────────────────────── + + +@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED) +async def register( + body: RegisterRequest, + session: Annotated[AsyncSession, Depends(get_session)], +): + """Register a new user with a valid invite code.""" + # 1. Validate invite code + result = await session.execute( + select(InviteCode).where(InviteCode.code == body.invite_code) + ) + invite = result.scalar_one_or_none() + if invite is None: + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Invalid invite code") + + now = datetime.now(timezone.utc).replace(tzinfo=None) + if invite.expires_at is not None and invite.expires_at < now: + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Invite code has expired") + + if invite.uses_remaining <= 0: + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Invite code exhausted") + + # 2. Check email uniqueness + existing = await session.execute(select(User).where(User.email == body.email)) + if existing.scalar_one_or_none() is not None: + raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Email already registered") + + # 3. Optionally resolve creator_id from slug + creator_id = None + if body.creator_slug: + creator_result = await session.execute( + select(Creator).where(Creator.slug == body.creator_slug) + ) + creator = creator_result.scalar_one_or_none() + if creator is not None: + creator_id = creator.id + + # 4. Create user + user = User( + email=body.email, + hashed_password=hash_password(body.password), + display_name=body.display_name, + creator_id=creator_id, + ) + session.add(user) + + # 5. Decrement invite code uses + invite.uses_remaining -= 1 + + await session.commit() + await session.refresh(user) + + logger.info("User registered: %s (email=%s)", user.id, user.email) + return user + + +# ── Login ──────────────────────────────────────────────────────────────────── + + +@router.post("/login", response_model=TokenResponse) +async def login( + body: LoginRequest, + session: Annotated[AsyncSession, Depends(get_session)], +): + """Authenticate with email + password, return JWT.""" + result = await session.execute(select(User).where(User.email == body.email)) + user = result.scalar_one_or_none() + + if user is None or not verify_password(body.password, user.hashed_password): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid email or password", + ) + + token = create_access_token(user.id, user.role.value) + logger.info("User logged in: %s", user.id) + return TokenResponse(access_token=token) + + +# ── Profile ────────────────────────────────────────────────────────────────── + + +@router.get("/me", response_model=UserResponse) +async def get_profile( + current_user: Annotated[User, Depends(get_current_user)], +): + """Return the current user's profile.""" + return current_user + + +@router.put("/me", response_model=UserResponse) +async def update_profile( + body: UpdateProfileRequest, + current_user: Annotated[User, Depends(get_current_user)], + session: Annotated[AsyncSession, Depends(get_session)], +): + """Update the current user's display name and/or password.""" + if body.display_name is not None: + current_user.display_name = body.display_name + + if body.new_password is not None: + if body.current_password is None: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Current password required to set new password", + ) + if not verify_password(body.current_password, current_user.hashed_password): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Current password is incorrect", + ) + current_user.hashed_password = hash_password(body.new_password) + + await session.commit() + await session.refresh(current_user) + + logger.info("Profile updated: %s", current_user.id) + return current_user + + +# ── Seed ───────────────────────────────────────────────────────────────────── + + +async def seed_invite_codes(session: AsyncSession) -> None: + """Create default invite code if none exist. Call from lifespan or CLI.""" + result = await session.execute(select(InviteCode)) + if result.scalar_one_or_none() is None: + session.add(InviteCode( + code="CHRYSOPEDIA-ALPHA-2026", + uses_remaining=100, + )) + await session.commit() + logger.info("Seeded default invite code: CHRYSOPEDIA-ALPHA-2026") diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 9de778f..86aa3c1 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -34,9 +34,12 @@ from main import app # noqa: E402 from models import ( # noqa: E402 ContentType, Creator, + InviteCode, ProcessingStatus, SourceVideo, TranscriptSegment, + User, + UserRole, ) TEST_DATABASE_URL = os.getenv( @@ -190,3 +193,47 @@ def pre_ingested_video(sync_engine): session.close() return result + + +# ── Auth fixtures ──────────────────────────────────────────────────────────── + +_TEST_INVITE_CODE = "TEST-INVITE-2026" +_TEST_EMAIL = "testuser@chrysopedia.com" +_TEST_PASSWORD = "securepass123" + + +@pytest_asyncio.fixture() +async def invite_code(db_engine): + """Create a test invite code in the DB and return the code string.""" + factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False) + async with factory() as session: + code = InviteCode(code=_TEST_INVITE_CODE, uses_remaining=10) + session.add(code) + await session.commit() + return _TEST_INVITE_CODE + + +@pytest_asyncio.fixture() +async def registered_user(client, invite_code): + """Register a user via the API and return the response dict.""" + resp = await client.post("/api/v1/auth/register", json={ + "email": _TEST_EMAIL, + "password": _TEST_PASSWORD, + "display_name": "Test User", + "invite_code": invite_code, + }) + assert resp.status_code == 201 + return resp.json() + + +@pytest_asyncio.fixture() +async def auth_headers(client, registered_user): + """Log in and return Authorization headers dict.""" + resp = await client.post("/api/v1/auth/login", json={ + "email": _TEST_EMAIL, + "password": _TEST_PASSWORD, + }) + assert resp.status_code == 200 + token = resp.json()["access_token"] + return {"Authorization": f"Bearer {token}"} + diff --git a/backend/tests/test_auth.py b/backend/tests/test_auth.py new file mode 100644 index 0000000..c48a8c4 --- /dev/null +++ b/backend/tests/test_auth.py @@ -0,0 +1,314 @@ +"""Integration tests for the auth router — registration, login, profile.""" + +import uuid +from datetime import datetime, timedelta, timezone + +import pytest +import pytest_asyncio +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker + +from models import InviteCode, User + + +# ── Registration ───────────────────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_register_valid(client, invite_code): + """Register with a valid invite code → 201 + user created.""" + resp = await client.post("/api/v1/auth/register", json={ + "email": "newuser@example.com", + "password": "strongpass1", + "display_name": "New User", + "invite_code": invite_code, + }) + assert resp.status_code == 201 + data = resp.json() + assert data["email"] == "newuser@example.com" + assert data["display_name"] == "New User" + assert data["role"] == "creator" + assert "id" in data + # Password not leaked + assert "hashed_password" not in data + + +@pytest.mark.asyncio +async def test_register_invalid_invite_code(client, invite_code): + """Register with a wrong invite code → 403.""" + resp = await client.post("/api/v1/auth/register", json={ + "email": "bad@example.com", + "password": "strongpass1", + "display_name": "Bad", + "invite_code": "WRONG-CODE", + }) + assert resp.status_code == 403 + assert "Invalid invite code" in resp.json()["detail"] + + +@pytest.mark.asyncio +async def test_register_expired_invite_code(client, db_engine): + """Register with an expired invite code → 403.""" + factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False) + async with factory() as session: + past = (datetime.now(timezone.utc) - timedelta(days=1)).replace(tzinfo=None) + session.add(InviteCode(code="EXPIRED-CODE", uses_remaining=10, expires_at=past)) + await session.commit() + + resp = await client.post("/api/v1/auth/register", json={ + "email": "exp@example.com", + "password": "strongpass1", + "display_name": "Expired", + "invite_code": "EXPIRED-CODE", + }) + assert resp.status_code == 403 + assert "expired" in resp.json()["detail"].lower() + + +@pytest.mark.asyncio +async def test_register_exhausted_invite_code(client, db_engine): + """Register with an invite code that has uses_remaining=0 → 403.""" + factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False) + async with factory() as session: + session.add(InviteCode(code="EXHAUSTED", uses_remaining=0)) + await session.commit() + + resp = await client.post("/api/v1/auth/register", json={ + "email": "nope@example.com", + "password": "strongpass1", + "display_name": "Nope", + "invite_code": "EXHAUSTED", + }) + assert resp.status_code == 403 + assert "exhausted" in resp.json()["detail"].lower() + + +@pytest.mark.asyncio +async def test_register_invite_code_decrements(client, db_engine): + """Invite code uses_remaining decrements after registration.""" + factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False) + async with factory() as session: + session.add(InviteCode(code="SINGLE-USE", uses_remaining=1)) + await session.commit() + + # First registration succeeds + resp = await client.post("/api/v1/auth/register", json={ + "email": "first@example.com", + "password": "strongpass1", + "display_name": "First", + "invite_code": "SINGLE-USE", + }) + assert resp.status_code == 201 + + # Second registration with same code fails (exhausted) + resp = await client.post("/api/v1/auth/register", json={ + "email": "second@example.com", + "password": "strongpass1", + "display_name": "Second", + "invite_code": "SINGLE-USE", + }) + assert resp.status_code == 403 + + +@pytest.mark.asyncio +async def test_register_duplicate_email(client, invite_code, registered_user): + """Register with an already-used email → 409.""" + resp = await client.post("/api/v1/auth/register", json={ + "email": "testuser@chrysopedia.com", + "password": "anotherpass1", + "display_name": "Dup", + "invite_code": invite_code, + }) + assert resp.status_code == 409 + assert "already registered" in resp.json()["detail"].lower() + + +# ── Login ──────────────────────────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_login_success(client, registered_user): + """Login with correct credentials → 200 + JWT.""" + resp = await client.post("/api/v1/auth/login", json={ + "email": "testuser@chrysopedia.com", + "password": "securepass123", + }) + assert resp.status_code == 200 + data = resp.json() + assert "access_token" in data + assert data["token_type"] == "bearer" + + +@pytest.mark.asyncio +async def test_login_wrong_password(client, registered_user): + """Login with wrong password → 401.""" + resp = await client.post("/api/v1/auth/login", json={ + "email": "testuser@chrysopedia.com", + "password": "wrongpassword", + }) + assert resp.status_code == 401 + + +@pytest.mark.asyncio +async def test_login_nonexistent_email(client): + """Login with an email that doesn't exist → 401.""" + resp = await client.post("/api/v1/auth/login", json={ + "email": "nobody@example.com", + "password": "somepass123", + }) + assert resp.status_code == 401 + + +# ── Profile (GET /me) ─────────────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_get_me_authenticated(client, auth_headers): + """GET /me with valid token → 200 + profile.""" + resp = await client.get("/api/v1/auth/me", headers=auth_headers) + assert resp.status_code == 200 + data = resp.json() + assert data["email"] == "testuser@chrysopedia.com" + assert data["display_name"] == "Test User" + + +@pytest.mark.asyncio +async def test_get_me_no_token(client, db_engine): + """GET /me without token → 401.""" + resp = await client.get("/api/v1/auth/me") + assert resp.status_code == 401 + + +@pytest.mark.asyncio +async def test_get_me_invalid_token(client, db_engine): + """GET /me with garbage token → 401.""" + resp = await client.get("/api/v1/auth/me", headers={ + "Authorization": "Bearer invalid.garbage.token", + }) + assert resp.status_code == 401 + + +@pytest.mark.asyncio +async def test_get_me_expired_token(client, db_engine, invite_code): + """GET /me with an expired JWT → 401.""" + from auth import create_access_token + + # Register a user first + resp = await client.post("/api/v1/auth/register", json={ + "email": "expired@example.com", + "password": "strongpass1", + "display_name": "Expired Token User", + "invite_code": invite_code, + }) + assert resp.status_code == 201 + user_id = resp.json()["id"] + + # Create a token that expires immediately + token = create_access_token(user_id, "creator", expires_minutes=-1) + resp = await client.get("/api/v1/auth/me", headers={ + "Authorization": f"Bearer {token}", + }) + assert resp.status_code == 401 + + +# ── Profile (PUT /me) ─────────────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_update_display_name(client, auth_headers): + """PUT /me updates display_name → 200 + new name.""" + resp = await client.put("/api/v1/auth/me", json={ + "display_name": "Updated Name", + }, headers=auth_headers) + assert resp.status_code == 200 + assert resp.json()["display_name"] == "Updated Name" + + # Verify persistence + resp2 = await client.get("/api/v1/auth/me", headers=auth_headers) + assert resp2.json()["display_name"] == "Updated Name" + + +@pytest.mark.asyncio +async def test_update_password(client, invite_code): + """PUT /me changes password → can login with new password.""" + # Register + await client.post("/api/v1/auth/register", json={ + "email": "pwchange@example.com", + "password": "oldpassword1", + "display_name": "PW User", + "invite_code": invite_code, + }) + # Login + login_resp = await client.post("/api/v1/auth/login", json={ + "email": "pwchange@example.com", + "password": "oldpassword1", + }) + token = login_resp.json()["access_token"] + headers = {"Authorization": f"Bearer {token}"} + + # Change password + resp = await client.put("/api/v1/auth/me", json={ + "current_password": "oldpassword1", + "new_password": "newpassword1", + }, headers=headers) + assert resp.status_code == 200 + + # Old password fails + resp_old = await client.post("/api/v1/auth/login", json={ + "email": "pwchange@example.com", + "password": "oldpassword1", + }) + assert resp_old.status_code == 401 + + # New password works + resp_new = await client.post("/api/v1/auth/login", json={ + "email": "pwchange@example.com", + "password": "newpassword1", + }) + assert resp_new.status_code == 200 + + +# ── Malformed inputs (422 validation) ─────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_register_missing_fields(client): + """Register with missing fields → 422.""" + resp = await client.post("/api/v1/auth/register", json={}) + assert resp.status_code == 422 + + +@pytest.mark.asyncio +async def test_register_empty_password(client): + """Register with empty password → 422 (min_length=8).""" + resp = await client.post("/api/v1/auth/register", json={ + "email": "a@b.com", + "password": "", + "display_name": "X", + "invite_code": "CODE", + }) + assert resp.status_code == 422 + + +@pytest.mark.asyncio +async def test_login_empty_body(client): + """Login with empty body → 422.""" + resp = await client.post("/api/v1/auth/login", json={}) + assert resp.status_code == 422 + + +# ── Public endpoints unaffected ────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_public_techniques_no_auth(client, db_engine): + """GET /api/v1/techniques works without auth.""" + resp = await client.get("/api/v1/techniques") + # 200 even if empty — no 401/403 + assert resp.status_code == 200 + + +@pytest.mark.asyncio +async def test_public_creators_no_auth(client, db_engine): + """GET /api/v1/creators works without auth.""" + resp = await client.get("/api/v1/creators") + assert resp.status_code == 200