promptlooper/backend/tests/test_auth.py

238 lines
8.7 KiB
Python

"""Tests for backend/auth.py — JWT, API key, setup flow, and auth dependency."""
import os
from datetime import timedelta
from unittest.mock import patch

import pytest
from fastapi import Depends, FastAPI, HTTPException
from fastapi.testclient import TestClient
@pytest.fixture(autouse=True)
def _isolate_settings(tmp_path):
    """Ensure tests use a temp SQLite DB and no Redis.

    A fresh ``config.Settings`` is built from a patched environment (with
    ``_env_file=None`` so no env file is read). It is rebound as ``settings``
    in ``config``, ``main`` and ``auth``, so modules that hold their own
    reference see it. Then the DB and Redis clients are re-initialised and
    all tables are created in the per-test SQLite file.

    On teardown, the engine's pooled connections are released and the
    previously bound ``settings`` objects are put back. This prevents later
    tests from inheriting a configuration that points into this test's temp
    directory.
    """
    env = {
        "DATABASE_URL": f"sqlite:///{tmp_path / 'test.db'}",
        "REDIS_URL": "",
        "DATA_DIR": str(tmp_path),
        "JWT_SECRET": "test-secret-key-for-jwt-signing",
        "API_KEY": "test-api-key-12345",
    }
    with patch.dict(os.environ, env, clear=False):
        import config

        new_settings = config.Settings(_env_file=None)
        # (module, previous settings) pairs so teardown can undo the rebinding.
        previous = [(config, config.settings)]
        config.settings = new_settings

        import main

        previous.append((main, getattr(main, "settings", None)))
        main.settings = new_settings
        main._init_db()
        main._init_redis()

        from models import Base

        Base.metadata.create_all(bind=main.engine)

        # Also patch auth module's settings reference
        import auth

        previous.append((auth, getattr(auth, "settings", None)))
        auth.settings = new_settings

        try:
            yield
        finally:
            # Release pooled connections to the temp SQLite file.
            main.engine.dispose()
            for module, old_settings in reversed(previous):
                if old_settings is not None:
                    module.settings = old_settings
@pytest.fixture
def db_session():
    """Provide a session obtained from the app's ``get_db`` dependency.

    The dependency generator is driven by hand. It is advanced once to
    obtain the session, then resumed after the test so that the code
    following its ``yield`` runs.
    """
    from main import get_db

    dependency = get_db()
    session = next(dependency)
    yield session
    # Resume the generator; the ``None`` default absorbs the StopIteration
    # raised once it is exhausted.
    next(dependency, None)
# ---------------------------------------------------------------------------
# Password hashing
# ---------------------------------------------------------------------------
class TestPasswordHashing:
    """``hash_password`` and ``verify_password`` behave as a matched pair."""

    def test_hash_and_verify(self):
        import auth

        plaintext = "my-secret-password"
        digest = auth.hash_password(plaintext)
        # The stored value must not be the plaintext itself...
        assert digest != plaintext
        # ...yet it must still verify against the original password.
        assert auth.verify_password(plaintext, digest)

    def test_wrong_password_fails(self):
        import auth

        digest = auth.hash_password("correct-password")
        assert not auth.verify_password("wrong-password", digest)
# ---------------------------------------------------------------------------
# JWT
# ---------------------------------------------------------------------------
class TestJWT:
    """Round-trip and rejection behaviour of the JWT helpers.

    Failures are expected to surface as ``HTTPException`` with status 401.
    """

    def test_create_and_decode_token(self):
        from auth import create_access_token, decode_access_token

        token = create_access_token("user-123")
        assert decode_access_token(token) == "user-123"

    def test_expired_token_raises(self):
        from auth import create_access_token, decode_access_token

        # A negative lifetime produces a token that is already expired.
        token = create_access_token("user-123", expires_delta=timedelta(seconds=-1))
        with pytest.raises(HTTPException) as exc_info:
            decode_access_token(token)
        assert exc_info.value.status_code == 401

    def test_invalid_token_raises(self):
        from auth import decode_access_token

        with pytest.raises(HTTPException) as exc_info:
            decode_access_token("not-a-valid-token")
        assert exc_info.value.status_code == 401

    def test_token_without_sub_raises(self):
        import config
        from jose import jwt

        from auth import decode_access_token

        # Signed with the configured secret, so the missing "sub" claim is the
        # intended defect (assumes auth also uses HS256 — confirm if that changes).
        token = jwt.encode({"foo": "bar"}, config.settings.jwt_secret, algorithm="HS256")
        with pytest.raises(HTTPException) as exc_info:
            decode_access_token(token)
        assert exc_info.value.status_code == 401
# ---------------------------------------------------------------------------
# First-boot setup
# ---------------------------------------------------------------------------
class TestSetup:
    """First-boot flow: ``needs_setup`` and one-shot ``create_admin``."""

    def test_needs_setup_true_when_no_users(self, db_session):
        from auth import needs_setup

        # The autouse fixture gives every test a freshly created database.
        assert needs_setup(db_session) is True

    def test_create_admin_succeeds(self, db_session):
        from auth import create_admin, needs_setup

        user = create_admin(db_session, "admin", "password123")
        assert user.username == "admin"
        assert user.is_admin is True
        assert needs_setup(db_session) is False

    def test_create_admin_twice_raises_409(self, db_session):
        from auth import create_admin

        create_admin(db_session, "admin", "password123")
        # A second admin is rejected even under a different username.
        with pytest.raises(HTTPException) as exc_info:
            create_admin(db_session, "admin2", "password456")
        assert exc_info.value.status_code == 409

    def test_admin_password_is_hashed(self, db_session):
        from auth import create_admin

        user = create_admin(db_session, "admin", "password123")
        assert user.password_hash != "password123"
        # "$2b$" is the bcrypt hash prefix.
        assert user.password_hash.startswith("$2b$")
# ---------------------------------------------------------------------------
# Authenticate user (login)
# ---------------------------------------------------------------------------
class TestAuthenticateUser:
    """Username/password login via ``authenticate_user``.

    Rejections are expected as ``HTTPException`` with status 401.
    """

    def test_valid_credentials(self, db_session):
        from auth import authenticate_user, create_admin

        create_admin(db_session, "admin", "password123")
        user = authenticate_user(db_session, "admin", "password123")
        assert user.username == "admin"

    def test_wrong_password_raises_401(self, db_session):
        from auth import authenticate_user, create_admin

        create_admin(db_session, "admin", "password123")
        with pytest.raises(HTTPException) as exc_info:
            authenticate_user(db_session, "admin", "wrong")
        assert exc_info.value.status_code == 401

    def test_unknown_user_raises_401(self, db_session):
        from auth import authenticate_user

        with pytest.raises(HTTPException) as exc_info:
            authenticate_user(db_session, "nonexistent", "password")
        assert exc_info.value.status_code == 401
# ---------------------------------------------------------------------------
# get_current_user dependency (integration via test app)
# ---------------------------------------------------------------------------
@pytest.fixture
def auth_app():
    """Create a minimal FastAPI app with a protected endpoint for testing auth.

    ``GET /protected`` depends on ``get_current_user``. It returns the
    resolved user's id (as a string) and username, so tests can check both
    the status code and which user was authenticated.
    """
    from auth import get_current_user

    test_app = FastAPI()

    @test_app.get("/protected")
    def protected(user=Depends(get_current_user)):
        return {"user_id": str(user.id), "username": user.username}

    return test_app
@pytest.fixture
def auth_client(auth_app):
    """HTTP test client bound to the app built by the ``auth_app`` fixture."""
    client = TestClient(app=auth_app)
    return client
class TestGetCurrentUser:
    """End-to-end checks of ``get_current_user`` through the ``auth_client`` app."""

    def test_no_auth_returns_401(self, auth_client):
        resp = auth_client.get("/protected")
        assert resp.status_code == 401
        assert "Missing authentication" in resp.json()["detail"]

    def test_invalid_bearer_format_returns_401(self, auth_client):
        # An Authorization header that does not use the "Bearer" scheme is rejected.
        resp = auth_client.get("/protected", headers={"Authorization": "NotBearer token"})
        assert resp.status_code == 401

    def test_jwt_auth_succeeds(self, auth_client, db_session):
        from auth import create_access_token, create_admin

        user = create_admin(db_session, "admin", "password123")
        token = create_access_token(str(user.id))
        resp = auth_client.get("/protected", headers={"Authorization": f"Bearer {token}"})
        assert resp.status_code == 200
        assert resp.json()["username"] == "admin"

    def test_jwt_for_deleted_user_returns_401(self, auth_client, db_session):
        """A validly signed token whose subject matches no user is rejected."""
        import uuid

        from auth import create_access_token, create_admin

        # Seed a real user so the 401 cannot be explained by an empty database.
        create_admin(db_session, "admin", "password123")
        token = create_access_token(str(uuid.uuid4()))
        resp = auth_client.get("/protected", headers={"Authorization": f"Bearer {token}"})
        assert resp.status_code == 401

    def test_api_key_auth_succeeds(self, auth_client, db_session):
        from auth import create_admin

        create_admin(db_session, "admin", "password123")
        resp = auth_client.get("/protected", headers={"X-Api-Key": "test-api-key-12345"})
        assert resp.status_code == 200
        # A valid API key authenticates as the admin user.
        assert resp.json()["username"] == "admin"

    def test_wrong_api_key_returns_401(self, auth_client, db_session):
        from auth import create_admin

        # Without an admin the request is rejected regardless of the key (see
        # test_api_key_without_admin_returns_401). Seed one so this test checks
        # the key comparison itself.
        create_admin(db_session, "admin", "password123")
        resp = auth_client.get("/protected", headers={"X-Api-Key": "wrong-key"})
        assert resp.status_code == 401

    def test_api_key_without_admin_returns_401(self, auth_client):
        # Correct key, but no admin user exists yet to authenticate as.
        resp = auth_client.get("/protected", headers={"X-Api-Key": "test-api-key-12345"})
        assert resp.status_code == 401

    def test_api_key_disabled_when_not_configured(self, auth_client, db_session, monkeypatch):
        """When API_KEY is not set in config, API key auth should fail."""
        import auth
        import config

        auth.create_admin(db_session, "admin", "password123")
        # monkeypatch reverts both changes at teardown, even if the request fails.
        monkeypatch.setattr(config.settings, "api_key", None)
        monkeypatch.setattr(auth, "settings", config.settings)
        resp = auth_client.get("/protected", headers={"X-Api-Key": "test-api-key-12345"})
        assert resp.status_code == 401