test: Implemented auth API router with register/login/me/update-profile…
- "backend/routers/auth.py" - "backend/main.py" - "backend/auth.py" - "backend/requirements.txt" - "backend/tests/conftest.py" - "backend/tests/test_auth.py" GSD-Task: S02/T02
This commit is contained in:
parent
ae62c09881
commit
f4020251b9
6 changed files with 535 additions and 7 deletions
|
|
@ -6,10 +6,10 @@ import uuid
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
|
import bcrypt
|
||||||
import jwt
|
import jwt
|
||||||
from fastapi import Depends, HTTPException, status
|
from fastapi import Depends, HTTPException, status
|
||||||
from fastapi.security import OAuth2PasswordBearer
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
from passlib.context import CryptContext
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
|
@ -19,17 +19,15 @@ from models import User, UserRole
|
||||||
|
|
||||||
# ── Password hashing ─────────────────────────────────────────────────────────
|
# ── Password hashing ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
_pwd_ctx = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
|
||||||
|
|
||||||
|
|
||||||
def hash_password(plain: str) -> str:
|
def hash_password(plain: str) -> str:
|
||||||
"""Hash a plaintext password with bcrypt."""
|
"""Hash a plaintext password with bcrypt."""
|
||||||
return _pwd_ctx.hash(plain)
|
return bcrypt.hashpw(plain.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
def verify_password(plain: str, hashed: str) -> bool:
|
def verify_password(plain: str, hashed: str) -> bool:
|
||||||
"""Verify a plaintext password against a bcrypt hash."""
|
"""Verify a plaintext password against a bcrypt hash."""
|
||||||
return _pwd_ctx.verify(plain, hashed)
|
return bcrypt.checkpw(plain.encode("utf-8"), hashed.encode("utf-8"))
|
||||||
|
|
||||||
|
|
||||||
# ── JWT ──────────────────────────────────────────────────────────────────────
|
# ── JWT ──────────────────────────────────────────────────────────────────────
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,7 @@ from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
from config import get_settings
|
from config import get_settings
|
||||||
from routers import creators, health, ingest, pipeline, reports, search, stats, techniques, topics, videos
|
from routers import auth, creators, health, ingest, pipeline, reports, search, stats, techniques, topics, videos
|
||||||
|
|
||||||
|
|
||||||
def _setup_logging() -> None:
|
def _setup_logging() -> None:
|
||||||
|
|
@ -78,6 +78,7 @@ app.add_middleware(
|
||||||
app.include_router(health.router)
|
app.include_router(health.router)
|
||||||
|
|
||||||
# Versioned API
|
# Versioned API
|
||||||
|
app.include_router(auth.router, prefix="/api/v1")
|
||||||
app.include_router(creators.router, prefix="/api/v1")
|
app.include_router(creators.router, prefix="/api/v1")
|
||||||
app.include_router(ingest.router, prefix="/api/v1")
|
app.include_router(ingest.router, prefix="/api/v1")
|
||||||
app.include_router(pipeline.router, prefix="/api/v1")
|
app.include_router(pipeline.router, prefix="/api/v1")
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,7 @@ pyyaml>=6.0,<7.0
|
||||||
psycopg2-binary>=2.9,<3.0
|
psycopg2-binary>=2.9,<3.0
|
||||||
watchdog>=4.0,<5.0
|
watchdog>=4.0,<5.0
|
||||||
PyJWT>=2.8,<3.0
|
PyJWT>=2.8,<3.0
|
||||||
passlib[bcrypt]>=1.7,<2.0
|
bcrypt>=4.0,<6.0
|
||||||
# Test dependencies
|
# Test dependencies
|
||||||
pytest>=8.0,<10.0
|
pytest>=8.0,<10.0
|
||||||
pytest-asyncio>=0.24,<1.0
|
pytest-asyncio>=0.24,<1.0
|
||||||
|
|
|
||||||
168
backend/routers/auth.py
Normal file
168
backend/routers/auth.py
Normal file
|
|
@ -0,0 +1,168 @@
|
||||||
|
"""Auth router — registration, login, profile management."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from auth import (
|
||||||
|
create_access_token,
|
||||||
|
get_current_user,
|
||||||
|
hash_password,
|
||||||
|
verify_password,
|
||||||
|
)
|
||||||
|
from database import get_session
|
||||||
|
from models import Creator, InviteCode, User
|
||||||
|
from schemas import (
|
||||||
|
LoginRequest,
|
||||||
|
RegisterRequest,
|
||||||
|
TokenResponse,
|
||||||
|
UpdateProfileRequest,
|
||||||
|
UserResponse,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger("chrysopedia.auth")
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||||
|
|
||||||
|
|
||||||
|
# ── Registration ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
||||||
|
async def register(
|
||||||
|
body: RegisterRequest,
|
||||||
|
session: Annotated[AsyncSession, Depends(get_session)],
|
||||||
|
):
|
||||||
|
"""Register a new user with a valid invite code."""
|
||||||
|
# 1. Validate invite code
|
||||||
|
result = await session.execute(
|
||||||
|
select(InviteCode).where(InviteCode.code == body.invite_code)
|
||||||
|
)
|
||||||
|
invite = result.scalar_one_or_none()
|
||||||
|
if invite is None:
|
||||||
|
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Invalid invite code")
|
||||||
|
|
||||||
|
now = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||||
|
if invite.expires_at is not None and invite.expires_at < now:
|
||||||
|
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Invite code has expired")
|
||||||
|
|
||||||
|
if invite.uses_remaining <= 0:
|
||||||
|
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Invite code exhausted")
|
||||||
|
|
||||||
|
# 2. Check email uniqueness
|
||||||
|
existing = await session.execute(select(User).where(User.email == body.email))
|
||||||
|
if existing.scalar_one_or_none() is not None:
|
||||||
|
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Email already registered")
|
||||||
|
|
||||||
|
# 3. Optionally resolve creator_id from slug
|
||||||
|
creator_id = None
|
||||||
|
if body.creator_slug:
|
||||||
|
creator_result = await session.execute(
|
||||||
|
select(Creator).where(Creator.slug == body.creator_slug)
|
||||||
|
)
|
||||||
|
creator = creator_result.scalar_one_or_none()
|
||||||
|
if creator is not None:
|
||||||
|
creator_id = creator.id
|
||||||
|
|
||||||
|
# 4. Create user
|
||||||
|
user = User(
|
||||||
|
email=body.email,
|
||||||
|
hashed_password=hash_password(body.password),
|
||||||
|
display_name=body.display_name,
|
||||||
|
creator_id=creator_id,
|
||||||
|
)
|
||||||
|
session.add(user)
|
||||||
|
|
||||||
|
# 5. Decrement invite code uses
|
||||||
|
invite.uses_remaining -= 1
|
||||||
|
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(user)
|
||||||
|
|
||||||
|
logger.info("User registered: %s (email=%s)", user.id, user.email)
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
# ── Login ────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/login", response_model=TokenResponse)
|
||||||
|
async def login(
|
||||||
|
body: LoginRequest,
|
||||||
|
session: Annotated[AsyncSession, Depends(get_session)],
|
||||||
|
):
|
||||||
|
"""Authenticate with email + password, return JWT."""
|
||||||
|
result = await session.execute(select(User).where(User.email == body.email))
|
||||||
|
user = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if user is None or not verify_password(body.password, user.hashed_password):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid email or password",
|
||||||
|
)
|
||||||
|
|
||||||
|
token = create_access_token(user.id, user.role.value)
|
||||||
|
logger.info("User logged in: %s", user.id)
|
||||||
|
return TokenResponse(access_token=token)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Profile ──────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/me", response_model=UserResponse)
|
||||||
|
async def get_profile(
|
||||||
|
current_user: Annotated[User, Depends(get_current_user)],
|
||||||
|
):
|
||||||
|
"""Return the current user's profile."""
|
||||||
|
return current_user
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/me", response_model=UserResponse)
|
||||||
|
async def update_profile(
|
||||||
|
body: UpdateProfileRequest,
|
||||||
|
current_user: Annotated[User, Depends(get_current_user)],
|
||||||
|
session: Annotated[AsyncSession, Depends(get_session)],
|
||||||
|
):
|
||||||
|
"""Update the current user's display name and/or password."""
|
||||||
|
if body.display_name is not None:
|
||||||
|
current_user.display_name = body.display_name
|
||||||
|
|
||||||
|
if body.new_password is not None:
|
||||||
|
if body.current_password is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="Current password required to set new password",
|
||||||
|
)
|
||||||
|
if not verify_password(body.current_password, current_user.hashed_password):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="Current password is incorrect",
|
||||||
|
)
|
||||||
|
current_user.hashed_password = hash_password(body.new_password)
|
||||||
|
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(current_user)
|
||||||
|
|
||||||
|
logger.info("Profile updated: %s", current_user.id)
|
||||||
|
return current_user
|
||||||
|
|
||||||
|
|
||||||
|
# ── Seed ─────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def seed_invite_codes(session: AsyncSession) -> None:
|
||||||
|
"""Create default invite code if none exist. Call from lifespan or CLI."""
|
||||||
|
result = await session.execute(select(InviteCode))
|
||||||
|
if result.scalar_one_or_none() is None:
|
||||||
|
session.add(InviteCode(
|
||||||
|
code="CHRYSOPEDIA-ALPHA-2026",
|
||||||
|
uses_remaining=100,
|
||||||
|
))
|
||||||
|
await session.commit()
|
||||||
|
logger.info("Seeded default invite code: CHRYSOPEDIA-ALPHA-2026")
|
||||||
|
|
@ -34,9 +34,12 @@ from main import app # noqa: E402
|
||||||
from models import ( # noqa: E402
|
from models import ( # noqa: E402
|
||||||
ContentType,
|
ContentType,
|
||||||
Creator,
|
Creator,
|
||||||
|
InviteCode,
|
||||||
ProcessingStatus,
|
ProcessingStatus,
|
||||||
SourceVideo,
|
SourceVideo,
|
||||||
TranscriptSegment,
|
TranscriptSegment,
|
||||||
|
User,
|
||||||
|
UserRole,
|
||||||
)
|
)
|
||||||
|
|
||||||
TEST_DATABASE_URL = os.getenv(
|
TEST_DATABASE_URL = os.getenv(
|
||||||
|
|
@ -190,3 +193,47 @@ def pre_ingested_video(sync_engine):
|
||||||
session.close()
|
session.close()
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
# ── Auth fixtures ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_TEST_INVITE_CODE = "TEST-INVITE-2026"
|
||||||
|
_TEST_EMAIL = "testuser@chrysopedia.com"
|
||||||
|
_TEST_PASSWORD = "securepass123"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture()
|
||||||
|
async def invite_code(db_engine):
|
||||||
|
"""Create a test invite code in the DB and return the code string."""
|
||||||
|
factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
|
||||||
|
async with factory() as session:
|
||||||
|
code = InviteCode(code=_TEST_INVITE_CODE, uses_remaining=10)
|
||||||
|
session.add(code)
|
||||||
|
await session.commit()
|
||||||
|
return _TEST_INVITE_CODE
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture()
|
||||||
|
async def registered_user(client, invite_code):
|
||||||
|
"""Register a user via the API and return the response dict."""
|
||||||
|
resp = await client.post("/api/v1/auth/register", json={
|
||||||
|
"email": _TEST_EMAIL,
|
||||||
|
"password": _TEST_PASSWORD,
|
||||||
|
"display_name": "Test User",
|
||||||
|
"invite_code": invite_code,
|
||||||
|
})
|
||||||
|
assert resp.status_code == 201
|
||||||
|
return resp.json()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture()
|
||||||
|
async def auth_headers(client, registered_user):
|
||||||
|
"""Log in and return Authorization headers dict."""
|
||||||
|
resp = await client.post("/api/v1/auth/login", json={
|
||||||
|
"email": _TEST_EMAIL,
|
||||||
|
"password": _TEST_PASSWORD,
|
||||||
|
})
|
||||||
|
assert resp.status_code == 200
|
||||||
|
token = resp.json()["access_token"]
|
||||||
|
return {"Authorization": f"Bearer {token}"}
|
||||||
|
|
||||||
|
|
|
||||||
314
backend/tests/test_auth.py
Normal file
314
backend/tests/test_auth.py
Normal file
|
|
@ -0,0 +1,314 @@
|
||||||
|
"""Integration tests for the auth router — registration, login, profile."""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||||
|
|
||||||
|
from models import InviteCode, User
|
||||||
|
|
||||||
|
|
||||||
|
# ── Registration ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_register_valid(client, invite_code):
|
||||||
|
"""Register with a valid invite code → 201 + user created."""
|
||||||
|
resp = await client.post("/api/v1/auth/register", json={
|
||||||
|
"email": "newuser@example.com",
|
||||||
|
"password": "strongpass1",
|
||||||
|
"display_name": "New User",
|
||||||
|
"invite_code": invite_code,
|
||||||
|
})
|
||||||
|
assert resp.status_code == 201
|
||||||
|
data = resp.json()
|
||||||
|
assert data["email"] == "newuser@example.com"
|
||||||
|
assert data["display_name"] == "New User"
|
||||||
|
assert data["role"] == "creator"
|
||||||
|
assert "id" in data
|
||||||
|
# Password not leaked
|
||||||
|
assert "hashed_password" not in data
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_register_invalid_invite_code(client, invite_code):
|
||||||
|
"""Register with a wrong invite code → 403."""
|
||||||
|
resp = await client.post("/api/v1/auth/register", json={
|
||||||
|
"email": "bad@example.com",
|
||||||
|
"password": "strongpass1",
|
||||||
|
"display_name": "Bad",
|
||||||
|
"invite_code": "WRONG-CODE",
|
||||||
|
})
|
||||||
|
assert resp.status_code == 403
|
||||||
|
assert "Invalid invite code" in resp.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_register_expired_invite_code(client, db_engine):
|
||||||
|
"""Register with an expired invite code → 403."""
|
||||||
|
factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
|
||||||
|
async with factory() as session:
|
||||||
|
past = (datetime.now(timezone.utc) - timedelta(days=1)).replace(tzinfo=None)
|
||||||
|
session.add(InviteCode(code="EXPIRED-CODE", uses_remaining=10, expires_at=past))
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
resp = await client.post("/api/v1/auth/register", json={
|
||||||
|
"email": "exp@example.com",
|
||||||
|
"password": "strongpass1",
|
||||||
|
"display_name": "Expired",
|
||||||
|
"invite_code": "EXPIRED-CODE",
|
||||||
|
})
|
||||||
|
assert resp.status_code == 403
|
||||||
|
assert "expired" in resp.json()["detail"].lower()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_register_exhausted_invite_code(client, db_engine):
|
||||||
|
"""Register with an invite code that has uses_remaining=0 → 403."""
|
||||||
|
factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
|
||||||
|
async with factory() as session:
|
||||||
|
session.add(InviteCode(code="EXHAUSTED", uses_remaining=0))
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
resp = await client.post("/api/v1/auth/register", json={
|
||||||
|
"email": "nope@example.com",
|
||||||
|
"password": "strongpass1",
|
||||||
|
"display_name": "Nope",
|
||||||
|
"invite_code": "EXHAUSTED",
|
||||||
|
})
|
||||||
|
assert resp.status_code == 403
|
||||||
|
assert "exhausted" in resp.json()["detail"].lower()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_register_invite_code_decrements(client, db_engine):
|
||||||
|
"""Invite code uses_remaining decrements after registration."""
|
||||||
|
factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
|
||||||
|
async with factory() as session:
|
||||||
|
session.add(InviteCode(code="SINGLE-USE", uses_remaining=1))
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
# First registration succeeds
|
||||||
|
resp = await client.post("/api/v1/auth/register", json={
|
||||||
|
"email": "first@example.com",
|
||||||
|
"password": "strongpass1",
|
||||||
|
"display_name": "First",
|
||||||
|
"invite_code": "SINGLE-USE",
|
||||||
|
})
|
||||||
|
assert resp.status_code == 201
|
||||||
|
|
||||||
|
# Second registration with same code fails (exhausted)
|
||||||
|
resp = await client.post("/api/v1/auth/register", json={
|
||||||
|
"email": "second@example.com",
|
||||||
|
"password": "strongpass1",
|
||||||
|
"display_name": "Second",
|
||||||
|
"invite_code": "SINGLE-USE",
|
||||||
|
})
|
||||||
|
assert resp.status_code == 403
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_register_duplicate_email(client, invite_code, registered_user):
|
||||||
|
"""Register with an already-used email → 409."""
|
||||||
|
resp = await client.post("/api/v1/auth/register", json={
|
||||||
|
"email": "testuser@chrysopedia.com",
|
||||||
|
"password": "anotherpass1",
|
||||||
|
"display_name": "Dup",
|
||||||
|
"invite_code": invite_code,
|
||||||
|
})
|
||||||
|
assert resp.status_code == 409
|
||||||
|
assert "already registered" in resp.json()["detail"].lower()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Login ────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_login_success(client, registered_user):
|
||||||
|
"""Login with correct credentials → 200 + JWT."""
|
||||||
|
resp = await client.post("/api/v1/auth/login", json={
|
||||||
|
"email": "testuser@chrysopedia.com",
|
||||||
|
"password": "securepass123",
|
||||||
|
})
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert "access_token" in data
|
||||||
|
assert data["token_type"] == "bearer"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_login_wrong_password(client, registered_user):
|
||||||
|
"""Login with wrong password → 401."""
|
||||||
|
resp = await client.post("/api/v1/auth/login", json={
|
||||||
|
"email": "testuser@chrysopedia.com",
|
||||||
|
"password": "wrongpassword",
|
||||||
|
})
|
||||||
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_login_nonexistent_email(client):
|
||||||
|
"""Login with an email that doesn't exist → 401."""
|
||||||
|
resp = await client.post("/api/v1/auth/login", json={
|
||||||
|
"email": "nobody@example.com",
|
||||||
|
"password": "somepass123",
|
||||||
|
})
|
||||||
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
# ── Profile (GET /me) ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_me_authenticated(client, auth_headers):
|
||||||
|
"""GET /me with valid token → 200 + profile."""
|
||||||
|
resp = await client.get("/api/v1/auth/me", headers=auth_headers)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert data["email"] == "testuser@chrysopedia.com"
|
||||||
|
assert data["display_name"] == "Test User"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_me_no_token(client, db_engine):
|
||||||
|
"""GET /me without token → 401."""
|
||||||
|
resp = await client.get("/api/v1/auth/me")
|
||||||
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_me_invalid_token(client, db_engine):
|
||||||
|
"""GET /me with garbage token → 401."""
|
||||||
|
resp = await client.get("/api/v1/auth/me", headers={
|
||||||
|
"Authorization": "Bearer invalid.garbage.token",
|
||||||
|
})
|
||||||
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_me_expired_token(client, db_engine, invite_code):
|
||||||
|
"""GET /me with an expired JWT → 401."""
|
||||||
|
from auth import create_access_token
|
||||||
|
|
||||||
|
# Register a user first
|
||||||
|
resp = await client.post("/api/v1/auth/register", json={
|
||||||
|
"email": "expired@example.com",
|
||||||
|
"password": "strongpass1",
|
||||||
|
"display_name": "Expired Token User",
|
||||||
|
"invite_code": invite_code,
|
||||||
|
})
|
||||||
|
assert resp.status_code == 201
|
||||||
|
user_id = resp.json()["id"]
|
||||||
|
|
||||||
|
# Create a token that expires immediately
|
||||||
|
token = create_access_token(user_id, "creator", expires_minutes=-1)
|
||||||
|
resp = await client.get("/api/v1/auth/me", headers={
|
||||||
|
"Authorization": f"Bearer {token}",
|
||||||
|
})
|
||||||
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
# ── Profile (PUT /me) ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_display_name(client, auth_headers):
|
||||||
|
"""PUT /me updates display_name → 200 + new name."""
|
||||||
|
resp = await client.put("/api/v1/auth/me", json={
|
||||||
|
"display_name": "Updated Name",
|
||||||
|
}, headers=auth_headers)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json()["display_name"] == "Updated Name"
|
||||||
|
|
||||||
|
# Verify persistence
|
||||||
|
resp2 = await client.get("/api/v1/auth/me", headers=auth_headers)
|
||||||
|
assert resp2.json()["display_name"] == "Updated Name"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_password(client, invite_code):
|
||||||
|
"""PUT /me changes password → can login with new password."""
|
||||||
|
# Register
|
||||||
|
await client.post("/api/v1/auth/register", json={
|
||||||
|
"email": "pwchange@example.com",
|
||||||
|
"password": "oldpassword1",
|
||||||
|
"display_name": "PW User",
|
||||||
|
"invite_code": invite_code,
|
||||||
|
})
|
||||||
|
# Login
|
||||||
|
login_resp = await client.post("/api/v1/auth/login", json={
|
||||||
|
"email": "pwchange@example.com",
|
||||||
|
"password": "oldpassword1",
|
||||||
|
})
|
||||||
|
token = login_resp.json()["access_token"]
|
||||||
|
headers = {"Authorization": f"Bearer {token}"}
|
||||||
|
|
||||||
|
# Change password
|
||||||
|
resp = await client.put("/api/v1/auth/me", json={
|
||||||
|
"current_password": "oldpassword1",
|
||||||
|
"new_password": "newpassword1",
|
||||||
|
}, headers=headers)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
# Old password fails
|
||||||
|
resp_old = await client.post("/api/v1/auth/login", json={
|
||||||
|
"email": "pwchange@example.com",
|
||||||
|
"password": "oldpassword1",
|
||||||
|
})
|
||||||
|
assert resp_old.status_code == 401
|
||||||
|
|
||||||
|
# New password works
|
||||||
|
resp_new = await client.post("/api/v1/auth/login", json={
|
||||||
|
"email": "pwchange@example.com",
|
||||||
|
"password": "newpassword1",
|
||||||
|
})
|
||||||
|
assert resp_new.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
# ── Malformed inputs (422 validation) ───────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_register_missing_fields(client):
|
||||||
|
"""Register with missing fields → 422."""
|
||||||
|
resp = await client.post("/api/v1/auth/register", json={})
|
||||||
|
assert resp.status_code == 422
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_register_empty_password(client):
|
||||||
|
"""Register with empty password → 422 (min_length=8)."""
|
||||||
|
resp = await client.post("/api/v1/auth/register", json={
|
||||||
|
"email": "a@b.com",
|
||||||
|
"password": "",
|
||||||
|
"display_name": "X",
|
||||||
|
"invite_code": "CODE",
|
||||||
|
})
|
||||||
|
assert resp.status_code == 422
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_login_empty_body(client):
|
||||||
|
"""Login with empty body → 422."""
|
||||||
|
resp = await client.post("/api/v1/auth/login", json={})
|
||||||
|
assert resp.status_code == 422
|
||||||
|
|
||||||
|
|
||||||
|
# ── Public endpoints unaffected ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_public_techniques_no_auth(client, db_engine):
|
||||||
|
"""GET /api/v1/techniques works without auth."""
|
||||||
|
resp = await client.get("/api/v1/techniques")
|
||||||
|
# 200 even if empty — no 401/403
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_public_creators_no_auth(client, db_engine):
|
||||||
|
"""GET /api/v1/creators works without auth."""
|
||||||
|
resp = await client.get("/api/v1/creators")
|
||||||
|
assert resp.status_code == 200
|
||||||
Loading…
Add table
Reference in a new issue