feat: Added write_mode support to impersonation tokens with conditional…

- "backend/auth.py"
- "backend/models.py"
- "backend/routers/admin.py"
- "backend/tests/test_impersonation.py"

GSD-Task: S07/T01
This commit is contained in:
jlightner 2026-04-04 06:24:04 +00:00
parent 5be499d0ad
commit 5a39850a35
4 changed files with 256 additions and 8 deletions

View file

@ -63,11 +63,14 @@ def create_impersonation_token(
admin_user_id: uuid.UUID | str,
target_user_id: uuid.UUID | str,
target_role: str,
*,
write_mode: bool = False,
) -> str:
"""Create a scoped JWT for admin impersonation.
The token has sub=target_user_id so get_current_user loads the target,
plus original_user_id so the system knows it's impersonation.
When write_mode is True, the token allows write operations.
"""
settings = get_settings()
now = datetime.now(timezone.utc)
@ -79,6 +82,8 @@ def create_impersonation_token(
"iat": now,
"exp": now + timedelta(minutes=_IMPERSONATION_EXPIRE_MINUTES),
}
if write_mode:
payload["write_mode"] = True
return jwt.encode(payload, settings.app_secret_key, algorithm=_ALGORITHM)
@ -127,6 +132,7 @@ async def get_current_user(
)
# Attach impersonation metadata (non-column runtime attribute)
user._impersonating_admin_id = payload.get("original_user_id") # type: ignore[attr-defined]
user._impersonation_write_mode = payload.get("write_mode", False) # type: ignore[attr-defined]
return user
@ -149,9 +155,15 @@ def require_role(required_role: UserRole):
async def reject_impersonation(
current_user: Annotated[User, Depends(get_current_user)],
) -> User:
"""Dependency that blocks write operations during impersonation."""
"""Dependency that blocks write operations during impersonation.
If the impersonation token was issued with write_mode=True,
writes are permitted.
"""
admin_id = getattr(current_user, "_impersonating_admin_id", None)
if admin_id is not None:
write_mode = getattr(current_user, "_impersonation_write_mode", False)
if not write_mode:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Write operations are not allowed during impersonation",

View file

@ -21,6 +21,7 @@ from sqlalchemy import (
Text,
UniqueConstraint,
func,
text,
)
from sqlalchemy.dialects.postgresql import ARRAY, JSONB, UUID
from sqlalchemy.orm import Mapped, mapped_column
@ -691,6 +692,9 @@ class ImpersonationLog(Base):
action: Mapped[str] = mapped_column(
String(10), nullable=False, doc="'start' or 'stop'"
)
write_mode: Mapped[bool] = mapped_column(
default=False, server_default=text("false"),
)
ip_address: Mapped[str | None] = mapped_column(String(45), nullable=True)
created_at: Mapped[datetime] = mapped_column(
default=_now, server_default=func.now()

View file

@ -3,13 +3,15 @@
from __future__ import annotations
import logging
from datetime import datetime
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import aliased
from auth import (
create_impersonation_token,
@ -52,6 +54,20 @@ class StopImpersonateResponse(BaseModel):
message: str
class StartImpersonationRequest(BaseModel):
write_mode: bool = False
class ImpersonationLogItem(BaseModel):
id: str
admin_name: str
target_name: str
action: str
write_mode: bool
ip_address: str | None
created_at: datetime
# ── Helpers ──────────────────────────────────────────────────────────────────
@ -97,8 +113,12 @@ async def start_impersonation(
request: Request,
admin: Annotated[User, Depends(_require_admin)],
session: Annotated[AsyncSession, Depends(get_session)],
body: StartImpersonationRequest | None = None,
):
"""Start impersonating a user. Admin only. Returns a scoped JWT."""
if body is None:
body = StartImpersonationRequest()
# Cannot impersonate yourself
if admin.id == user_id:
raise HTTPException(
@ -120,6 +140,7 @@ async def start_impersonation(
admin_user_id=admin.id,
target_user_id=target.id,
target_role=target.role.value,
write_mode=body.write_mode,
)
# Audit log
@ -127,13 +148,14 @@ async def start_impersonation(
admin_user_id=admin.id,
target_user_id=target.id,
action="start",
write_mode=body.write_mode,
ip_address=_client_ip(request),
))
await session.commit()
logger.info(
"Impersonation started: admin=%s target=%s",
admin.id, target.id,
"Impersonation started: admin=%s target=%s write_mode=%s",
admin.id, target.id, body.write_mode,
)
return ImpersonateResponse(
@ -178,3 +200,39 @@ async def stop_impersonation(
)
return StopImpersonateResponse(message="Impersonation ended")
@router.get("/impersonation-log", response_model=list[ImpersonationLogItem])
async def get_impersonation_log(
_admin: Annotated[User, Depends(_require_admin)],
session: Annotated[AsyncSession, Depends(get_session)],
page: int = Query(1, ge=1),
page_size: int = Query(50, ge=1, le=200),
):
"""Paginated impersonation audit log. Admin only."""
AdminUser = aliased(User, name="admin_user")
TargetUser = aliased(User, name="target_user")
stmt = (
select(ImpersonationLog, AdminUser.display_name, TargetUser.display_name)
.join(AdminUser, ImpersonationLog.admin_user_id == AdminUser.id)
.join(TargetUser, ImpersonationLog.target_user_id == TargetUser.id)
.order_by(ImpersonationLog.created_at.desc())
.offset((page - 1) * page_size)
.limit(page_size)
)
result = await session.execute(stmt)
rows = result.all()
return [
ImpersonationLogItem(
id=str(log.id),
admin_name=admin_name,
target_name=target_name,
action=log.action,
write_mode=log.write_mode,
ip_address=log.ip_address,
created_at=log.created_at,
)
for log, admin_name, target_name in rows
]

View file

@ -0,0 +1,174 @@
"""Integration tests for impersonation write-mode and audit log."""
import pytest
import pytest_asyncio
from httpx import AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from models import InviteCode, User, UserRole
# Re-use fixtures from conftest: db_engine, client, admin_auth
_TARGET_EMAIL = "impersonate-target@chrysopedia.com"
_TARGET_PASSWORD = "targetpass123"
_TARGET_INVITE = "IMP-TARGET-INV"
@pytest_asyncio.fixture()
async def target_user(client: AsyncClient, db_engine):
"""Register a regular user to be the impersonation target."""
factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
async with factory() as session:
code = InviteCode(code=_TARGET_INVITE, uses_remaining=10)
session.add(code)
await session.commit()
resp = await client.post("/api/v1/auth/register", json={
"email": _TARGET_EMAIL,
"password": _TARGET_PASSWORD,
"display_name": "Target User",
"invite_code": _TARGET_INVITE,
})
assert resp.status_code == 201
return resp.json()
# ── Write-mode tests ────────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_impersonation_without_write_mode_blocks_writes(
client: AsyncClient, admin_auth, target_user,
):
"""Read-only impersonation (default) should 403 on PUT /auth/me."""
# Start impersonation without write_mode
resp = await client.post(
f"/api/v1/admin/impersonate/{target_user['id']}",
headers=admin_auth["headers"],
)
assert resp.status_code == 200
imp_token = resp.json()["access_token"]
imp_headers = {"Authorization": f"Bearer {imp_token}"}
# Attempt a write operation — should be blocked
resp = await client.put(
"/api/v1/auth/me",
headers=imp_headers,
json={"display_name": "Hacked Name"},
)
assert resp.status_code == 403
assert "impersonation" in resp.json()["detail"].lower()
@pytest.mark.asyncio
async def test_impersonation_with_write_mode_allows_writes(
client: AsyncClient, admin_auth, target_user,
):
"""Write-mode impersonation should not 403 on PUT /auth/me."""
# Start impersonation WITH write_mode
resp = await client.post(
f"/api/v1/admin/impersonate/{target_user['id']}",
headers=admin_auth["headers"],
json={"write_mode": True},
)
assert resp.status_code == 200
imp_token = resp.json()["access_token"]
imp_headers = {"Authorization": f"Bearer {imp_token}"}
# Attempt a write — should NOT get 403 from reject_impersonation
resp = await client.put(
"/api/v1/auth/me",
headers=imp_headers,
json={"display_name": "Updated Via WriteMode"},
)
# Should succeed (200) or at least not be a 403
assert resp.status_code != 403
# Verify the update actually took effect
assert resp.status_code == 200
assert resp.json()["display_name"] == "Updated Via WriteMode"
# ── Audit log endpoint tests ────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_impersonation_log_returns_entries(
client: AsyncClient, admin_auth, target_user,
):
"""GET /admin/impersonation-log returns log entries with names."""
# Create some log entries by starting impersonation
resp = await client.post(
f"/api/v1/admin/impersonate/{target_user['id']}",
headers=admin_auth["headers"],
json={"write_mode": True},
)
assert resp.status_code == 200
# Fetch the log
resp = await client.get(
"/api/v1/admin/impersonation-log",
headers=admin_auth["headers"],
)
assert resp.status_code == 200
logs = resp.json()
assert len(logs) >= 1
entry = logs[0]
assert entry["admin_name"] == "Admin User"
assert entry["target_name"] == "Target User"
assert entry["action"] == "start"
assert entry["write_mode"] is True
assert "id" in entry
assert "created_at" in entry
@pytest.mark.asyncio
async def test_impersonation_log_non_admin_forbidden(
client: AsyncClient, target_user,
):
"""Non-admin users cannot access the impersonation log."""
# Login as the target (regular) user
resp = await client.post("/api/v1/auth/login", json={
"email": _TARGET_EMAIL,
"password": _TARGET_PASSWORD,
})
assert resp.status_code == 200
user_headers = {"Authorization": f"Bearer {resp.json()['access_token']}"}
resp = await client.get(
"/api/v1/admin/impersonation-log",
headers=user_headers,
)
assert resp.status_code == 403
@pytest.mark.asyncio
async def test_impersonation_log_pagination(
client: AsyncClient, admin_auth, target_user,
):
"""Verify pagination params work on impersonation-log."""
# Create two entries
for _ in range(2):
await client.post(
f"/api/v1/admin/impersonate/{target_user['id']}",
headers=admin_auth["headers"],
)
# Fetch page 1, page_size=1
resp = await client.get(
"/api/v1/admin/impersonation-log",
headers=admin_auth["headers"],
params={"page": 1, "page_size": 1},
)
assert resp.status_code == 200
assert len(resp.json()) == 1
# Fetch page 2
resp = await client.get(
"/api/v1/admin/impersonation-log",
headers=admin_auth["headers"],
params={"page": 2, "page_size": 1},
)
assert resp.status_code == 200
assert len(resp.json()) == 1