"""Fractafrag — JWT Authentication middleware and dependencies.""" from datetime import datetime, timedelta, timezone from uuid import UUID from typing import Optional from fastapi import Depends, HTTPException, status, Request, Response from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from jose import jwt, JWTError import bcrypt from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select from app.config import get_settings from app.database import get_db from app.models import User from app.redis import get_redis settings = get_settings() bearer_scheme = HTTPBearer(auto_error=False) # ── Password Hashing ────────────────────────────────────── def hash_password(password: str) -> str: return bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt(rounds=12)).decode("utf-8") def verify_password(plain: str, hashed: str) -> bool: return bcrypt.checkpw(plain.encode("utf-8"), hashed.encode("utf-8")) # ── JWT Token Management ────────────────────────────────── def create_access_token(user_id: UUID, username: str, role: str, tier: str) -> str: payload = { "sub": str(user_id), "username": username, "role": role, "tier": tier, "iat": datetime.now(timezone.utc), "exp": datetime.now(timezone.utc) + timedelta(minutes=settings.jwt_access_token_expire_minutes), } return jwt.encode(payload, settings.jwt_secret, algorithm=settings.jwt_algorithm) def create_refresh_token(user_id: UUID) -> str: payload = { "sub": str(user_id), "type": "refresh", "iat": datetime.now(timezone.utc), "exp": datetime.now(timezone.utc) + timedelta(days=settings.jwt_refresh_token_expire_days), } return jwt.encode(payload, settings.jwt_secret, algorithm=settings.jwt_algorithm) def decode_token(token: str) -> dict: try: return jwt.decode(token, settings.jwt_secret, algorithms=[settings.jwt_algorithm]) except JWTError: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or expired token") # ── Refresh Token Blocklist (Redis) 
async def is_token_blocklisted(token: str) -> bool:
    """Return True if *token* is present in the Redis blocklist."""
    redis = await get_redis()
    # EXISTS returns a count of matching keys; normalise to the annotated bool.
    return bool(await redis.exists(f"blocklist:{token}"))


async def blocklist_token(token: str, ttl_seconds: int):
    """Add *token* to the Redis blocklist, expiring the entry after *ttl_seconds*.

    Redis rejects SETEX with a non-positive expire time, so such calls are
    skipped. Presumably callers derive the TTL from the token's ``exp``, in which
    case a non-positive value means the token is already expired and is rejected
    by signature/expiry validation anyway — confirm against callers.
    """
    if ttl_seconds <= 0:
        return
    redis = await get_redis()
    await redis.setex(f"blocklist:{token}", ttl_seconds, "1")


# ── FastAPI Dependencies ──────────────────────────────────

_INTERNAL_PREFIX = "internal:"


def _is_valid_internal_token(token: str) -> bool:
    """Return True if *token* is ``internal:<secret>`` with the configured secret.

    Fails closed: if no secret is configured, every internal token is refused.
    NOTE(review): assumes Settings exposes the shared service secret as
    ``internal_service_token`` — confirm the field name in app.config and that
    the MCP server / workers send ``Bearer internal:<that secret>``.
    """
    import hmac

    expected = getattr(settings, "internal_service_token", None)
    # Unwrap pydantic SecretStr-style values; str() of those yields a mask.
    if hasattr(expected, "get_secret_value"):
        expected = expected.get_secret_value()
    if not expected:
        return False
    presented = token[len(_INTERNAL_PREFIX):]
    # Constant-time comparison on bytes (compare_digest rejects non-ASCII str).
    return hmac.compare_digest(presented.encode("utf-8"), str(expected).encode("utf-8"))


def _subject_uuid(payload: dict) -> Optional[UUID]:
    """Return the ``sub`` claim of *payload* as a UUID, or None if missing/malformed."""
    subject = payload.get("sub")
    if not subject:
        return None
    try:
        return UUID(subject)
    except ValueError:
        return None


async def get_current_user(
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(bearer_scheme),
    db: AsyncSession = Depends(get_db),
) -> User:
    """Require authentication. Returns the current user.

    Supports:
    - JWT Bearer tokens (normal user auth)
    - Internal service token: 'Bearer internal:<secret>' from MCP/worker, which
      authenticates as the system account only if <secret> matches the
      configured service secret.

    Raises:
        HTTPException: 401 when credentials are missing or invalid, a refresh
            token is presented, the subject is malformed, or the user (or the
            system user, for internal auth) does not exist.
    """
    if credentials is None:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated")
    token = credentials.credentials

    # Internal service auth — MCP server and workers use this to act as the system
    # account. The secret MUST be checked; otherwise any client could send
    # "Bearer internal:x" and be authenticated as the system user.
    if token.startswith(_INTERNAL_PREFIX):
        if not _is_valid_internal_token(token):
            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or expired token")
        from app.models.models import SYSTEM_USER_ID

        result = await db.execute(select(User).where(User.id == SYSTEM_USER_ID))
        user = result.scalar_one_or_none()
        if user is not None:
            return user
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="System user not found")

    payload = decode_token(token)
    if payload.get("type") == "refresh":
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Cannot use refresh token for API access")

    # A missing or non-UUID subject is a client error (401), not a server crash.
    user_id = _subject_uuid(payload)
    if user_id is None:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token payload")

    result = await db.execute(select(User).where(User.id == user_id))
    user = result.scalar_one_or_none()
    if user is None:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found")
    return user


async def get_optional_user(
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(bearer_scheme),
    db: AsyncSession = Depends(get_db),
) -> Optional[User]:
    """Optional authentication. Returns user or None for anonymous requests.

    Any credential problem yields None rather than an error: an invalid or
    expired token, a refresh token, a missing or malformed subject, or an
    unknown user. Internal service tokens fail JWT decoding here and also
    yield None.
    """
    if credentials is None:
        return None
    try:
        payload = decode_token(credentials.credentials)
    except HTTPException:
        return None
    if payload.get("type") == "refresh":
        return None
    user_id = _subject_uuid(payload)
    if user_id is None:
        return None
    result = await db.execute(select(User).where(User.id == user_id))
    return result.scalar_one_or_none()


def require_role(*roles: str):
    """Dependency factory: require user to have one of the specified roles.

    The returned dependency resolves the user via get_current_user and raises
    HTTPException 403 when ``user.role`` is not in *roles*.
    """
    async def check_role(user: User = Depends(get_current_user)) -> User:
        if user.role not in roles:
            raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Insufficient permissions")
        return user
    return check_role


def require_tier(*tiers: str):
    """Dependency factory: require user to have one of the specified subscription tiers.

    The returned dependency resolves the user via get_current_user and raises
    HTTPException 403, listing the accepted tiers, when
    ``user.subscription_tier`` is not in *tiers*.
    """
    async def check_tier(user: User = Depends(get_current_user)) -> User:
        if user.subscription_tier not in tiers:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"This feature requires one of: {', '.join(tiers)}",
            )
        return user
    return check_tier