# promptlooper/backend/auth.py
#
# 154 lines
# 5.6 KiB
# Python

"""PromptLooper authentication — JWT tokens, API keys, first-boot setup."""
import hmac
import uuid as _uuid
from datetime import datetime, timedelta, timezone
from typing import Generator

from fastapi import Depends, HTTPException, Header, status
from jose import JWTError, jwt
from passlib.context import CryptContext
from sqlalchemy.orm import Session

from config import settings
from models import User
# ---------------------------------------------------------------------------
# Password hashing
# ---------------------------------------------------------------------------
# Shared passlib context used by hash_password() and verify_password();
# bcrypt is the only configured scheme.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def hash_password(password: str) -> str:
    """Hash a plaintext password with the module-level ``pwd_context``.

    Args:
        password: The plaintext password to hash.

    Returns:
        The encoded hash string produced by ``pwd_context.hash``.
    """
    encoded_hash = pwd_context.hash(password)
    return encoded_hash
def verify_password(plain: str, hashed: str) -> bool:
    """Check a plaintext password against a previously stored hash.

    Args:
        plain: The plaintext password supplied by the caller.
        hashed: The stored hash to verify against.

    Returns:
        The result of ``pwd_context.verify(plain, hashed)``.
    """
    is_match = pwd_context.verify(plain, hashed)
    return is_match
# ---------------------------------------------------------------------------
# JWT
# ---------------------------------------------------------------------------
# Signing algorithm passed to jwt.encode() and accepted by jwt.decode().
ALGORITHM = "HS256"
# Default token lifetime, used by create_access_token() when the caller does
# not supply an explicit expires_delta.
ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 # 24 hours
def create_access_token(user_id: str, *, expires_delta: timedelta | None = None) -> str:
    """Create a signed JWT access token for *user_id*.

    Args:
        user_id: Identifier stored in the token's ``sub`` claim. It is passed
            through ``str()`` so non-string ids (e.g. a ``UUID``) are encoded
            as their string form.
        expires_delta: Lifetime of the token measured from now (UTC). ``None``
            selects the default of ``ACCESS_TOKEN_EXPIRE_MINUTES``. Any
            explicit value, including ``timedelta(0)``, is honored as given.

    Returns:
        The encoded JWT, signed with ``settings.jwt_secret`` using
        ``ALGORITHM``.
    """
    # Compare against None explicitly: timedelta(0) is falsy, so using `or`
    # here would silently swap an explicit zero lifetime for the default.
    if expires_delta is None:
        expires_delta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    expire = datetime.now(timezone.utc) + expires_delta
    payload = {"sub": str(user_id), "exp": expire}
    return jwt.encode(payload, settings.jwt_secret, algorithm=ALGORITHM)
def decode_access_token(token: str) -> str:
    """Return the user id (``sub`` claim) from a valid JWT.

    Args:
        token: The encoded JWT taken from the ``Authorization`` header.

    Returns:
        The value of the token's ``sub`` claim.

    Raises:
        HTTPException: 401 "Invalid token" when decoding raises ``JWTError``
            or when the payload carries no ``sub`` claim. The response
            includes a ``WWW-Authenticate: Bearer`` challenge.
    """
    invalid_token = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid token",
        headers={"WWW-Authenticate": "Bearer"},
    )
    # Keep the try body to the single call that can raise JWTError.
    try:
        payload = jwt.decode(token, settings.jwt_secret, algorithms=[ALGORITHM])
    except JWTError as exc:
        # Chain the cause so server-side logs keep the underlying JWT failure.
        raise invalid_token from exc
    user_id: str | None = payload.get("sub")
    if user_id is None:
        raise invalid_token
    return user_id
# ---------------------------------------------------------------------------
# First-boot setup
# ---------------------------------------------------------------------------
def needs_setup(db: Session) -> bool:
    """Report whether the application is still in its first-boot state.

    Args:
        db: Session used to count ``User`` rows.

    Returns:
        ``True`` when the ``User`` query counts zero rows, ``False`` otherwise.
    """
    user_count = db.query(User).count()
    return user_count == 0
def create_admin(db: Session, username: str, password: str) -> User:
    """Create the initial admin account during first-boot setup.

    Args:
        db: Session used to check for existing users and persist the new one.
        username: Username for the admin account.
        password: Plaintext password; only its hash is stored.

    Returns:
        The newly committed ``User``, refreshed from the database.

    Raises:
        HTTPException: 409 when ``needs_setup`` reports that users already
            exist.
    """
    already_configured = not needs_setup(db)
    if already_configured:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Admin account already exists",
        )

    password_hash = hash_password(password)
    admin = User(username=username, password_hash=password_hash, is_admin=True)

    # Persist, then refresh so database-generated fields are loaded.
    db.add(admin)
    db.commit()
    db.refresh(admin)
    return admin
# ---------------------------------------------------------------------------
# Authenticate (login)
# ---------------------------------------------------------------------------
def authenticate_user(db: Session, username: str, password: str) -> User:
    """Verify login credentials and return the matching user.

    Args:
        db: Session used to look the user up by username.
        username: Username supplied by the client.
        password: Plaintext password supplied by the client.

    Returns:
        The ``User`` whose stored hash matches *password*.

    Raises:
        HTTPException: 401 "Invalid credentials" when the username is unknown
            or the password does not match. Both cases produce the same
            response.
    """
    user = db.query(User).filter(User.username == username).first()
    if user is None:
        # Spend roughly the time a real hash check would take, so response
        # timing does not reveal whether the username exists.
        pwd_context.dummy_verify()
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials")
    if not verify_password(password, user.password_hash):
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials")
    return user
# ---------------------------------------------------------------------------
# Database session dependency (local to avoid circular import with main.py)
# ---------------------------------------------------------------------------
def _get_db() -> Generator[Session, None, None]:
    """Yield a DB session by delegating to ``main.get_db``.

    The import happens inside the function body rather than at module load,
    which avoids a circular import between this module and ``main``.
    """
    from main import get_db as _session_source

    # Delegate with `yield from` so the generator protocol is forwarded to the
    # underlying get_db() generator.
    yield from _session_source()
# ---------------------------------------------------------------------------
# Dependency: get current user (JWT or API key)
# ---------------------------------------------------------------------------
def _user_from_api_key(db: Session, x_api_key: str) -> User:
    """Resolve the user for an ``X-Api-Key`` header value, or raise 401."""
    configured_key = settings.api_key
    # An unset OR empty configured key means API-key auth is disabled;
    # otherwise an empty header value would match an empty setting.
    # compare_digest avoids leaking how much of the key matched via timing;
    # both sides are encoded to bytes because it rejects non-ASCII str.
    key_matches = bool(configured_key) and hmac.compare_digest(
        x_api_key.encode("utf-8"), configured_key.encode("utf-8")
    )
    if not key_matches:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key")
    # The API key acts as an admin user. No ordering is applied to the query,
    # so which admin is returned when several exist is up to the database.
    admin = db.query(User).filter(User.is_admin.is_(True)).first()
    if admin is None:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="No admin user exists")
    return admin


def _user_from_bearer(db: Session, authorization: str | None) -> User:
    """Resolve the user for an ``Authorization: Bearer <jwt>`` header, or raise 401."""
    bearer_challenge = {"WWW-Authenticate": "Bearer"}
    if authorization is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Missing authentication",
            headers=bearer_challenge,
        )
    scheme, _, token = authorization.partition(" ")
    if scheme.lower() != "bearer" or not token:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid authorization header",
            headers=bearer_challenge,
        )
    user_id_str = decode_access_token(token)
    # The `sub` claim must parse as a UUID before it is used in the lookup.
    try:
        user_id = _uuid.UUID(user_id_str)
    except ValueError as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token",
            headers=bearer_challenge,
        ) from exc
    user = db.query(User).filter(User.id == user_id).first()
    if user is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="User not found",
            headers=bearer_challenge,
        )
    return user


def get_current_user(
    authorization: str | None = Header(None),
    x_api_key: str | None = Header(None),
    db: Session = Depends(_get_db),
) -> User:
    """FastAPI dependency: resolve the current user from an API key or JWT.

    Priority:
    1. ``X-Api-Key`` header, matched against ``settings.api_key``; on a match
       an admin user is returned. When this header is present the
       ``Authorization`` header is not consulted, even if the key is rejected.
    2. ``Authorization: Bearer <jwt>``, decoded to a user id that is looked up
       in the database.

    Args:
        authorization: Raw ``Authorization`` header value, if sent.
        x_api_key: Raw ``X-Api-Key`` header value, if sent.
        db: Database session provided by the ``_get_db`` dependency.

    Returns:
        The authenticated ``User``.

    Raises:
        HTTPException: 401 when the API key is wrong or not configured, no
            admin user exists, the ``Authorization`` header is missing or
            malformed, the token is invalid, or the user no longer exists.
    """
    if x_api_key is not None:
        return _user_from_api_key(db, x_api_key)
    return _user_from_bearer(db, authorization)