feat: Added write_mode support to impersonation tokens with conditional…
- "backend/auth.py" - "backend/models.py" - "backend/routers/admin.py" - "backend/tests/test_impersonation.py" GSD-Task: S07/T01
This commit is contained in:
parent
5be499d0ad
commit
5a39850a35
4 changed files with 256 additions and 8 deletions
|
|
@ -63,11 +63,14 @@ def create_impersonation_token(
|
||||||
admin_user_id: uuid.UUID | str,
|
admin_user_id: uuid.UUID | str,
|
||||||
target_user_id: uuid.UUID | str,
|
target_user_id: uuid.UUID | str,
|
||||||
target_role: str,
|
target_role: str,
|
||||||
|
*,
|
||||||
|
write_mode: bool = False,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Create a scoped JWT for admin impersonation.
|
"""Create a scoped JWT for admin impersonation.
|
||||||
|
|
||||||
The token has sub=target_user_id so get_current_user loads the target,
|
The token has sub=target_user_id so get_current_user loads the target,
|
||||||
plus original_user_id so the system knows it's impersonation.
|
plus original_user_id so the system knows it's impersonation.
|
||||||
|
When write_mode is True, the token allows write operations.
|
||||||
"""
|
"""
|
||||||
settings = get_settings()
|
settings = get_settings()
|
||||||
now = datetime.now(timezone.utc)
|
now = datetime.now(timezone.utc)
|
||||||
|
|
@ -79,6 +82,8 @@ def create_impersonation_token(
|
||||||
"iat": now,
|
"iat": now,
|
||||||
"exp": now + timedelta(minutes=_IMPERSONATION_EXPIRE_MINUTES),
|
"exp": now + timedelta(minutes=_IMPERSONATION_EXPIRE_MINUTES),
|
||||||
}
|
}
|
||||||
|
if write_mode:
|
||||||
|
payload["write_mode"] = True
|
||||||
return jwt.encode(payload, settings.app_secret_key, algorithm=_ALGORITHM)
|
return jwt.encode(payload, settings.app_secret_key, algorithm=_ALGORITHM)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -127,6 +132,7 @@ async def get_current_user(
|
||||||
)
|
)
|
||||||
# Attach impersonation metadata (non-column runtime attribute)
|
# Attach impersonation metadata (non-column runtime attribute)
|
||||||
user._impersonating_admin_id = payload.get("original_user_id") # type: ignore[attr-defined]
|
user._impersonating_admin_id = payload.get("original_user_id") # type: ignore[attr-defined]
|
||||||
|
user._impersonation_write_mode = payload.get("write_mode", False) # type: ignore[attr-defined]
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -149,11 +155,17 @@ def require_role(required_role: UserRole):
|
||||||
async def reject_impersonation(
|
async def reject_impersonation(
|
||||||
current_user: Annotated[User, Depends(get_current_user)],
|
current_user: Annotated[User, Depends(get_current_user)],
|
||||||
) -> User:
|
) -> User:
|
||||||
"""Dependency that blocks write operations during impersonation."""
|
"""Dependency that blocks write operations during impersonation.
|
||||||
|
|
||||||
|
If the impersonation token was issued with write_mode=True,
|
||||||
|
writes are permitted.
|
||||||
|
"""
|
||||||
admin_id = getattr(current_user, "_impersonating_admin_id", None)
|
admin_id = getattr(current_user, "_impersonating_admin_id", None)
|
||||||
if admin_id is not None:
|
if admin_id is not None:
|
||||||
raise HTTPException(
|
write_mode = getattr(current_user, "_impersonation_write_mode", False)
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
if not write_mode:
|
||||||
detail="Write operations are not allowed during impersonation",
|
raise HTTPException(
|
||||||
)
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Write operations are not allowed during impersonation",
|
||||||
|
)
|
||||||
return current_user
|
return current_user
|
||||||
|
|
|
||||||
|
|
@ -21,6 +21,7 @@ from sqlalchemy import (
|
||||||
Text,
|
Text,
|
||||||
UniqueConstraint,
|
UniqueConstraint,
|
||||||
func,
|
func,
|
||||||
|
text,
|
||||||
)
|
)
|
||||||
from sqlalchemy.dialects.postgresql import ARRAY, JSONB, UUID
|
from sqlalchemy.dialects.postgresql import ARRAY, JSONB, UUID
|
||||||
from sqlalchemy.orm import Mapped, mapped_column
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
@ -691,6 +692,9 @@ class ImpersonationLog(Base):
|
||||||
action: Mapped[str] = mapped_column(
|
action: Mapped[str] = mapped_column(
|
||||||
String(10), nullable=False, doc="'start' or 'stop'"
|
String(10), nullable=False, doc="'start' or 'stop'"
|
||||||
)
|
)
|
||||||
|
write_mode: Mapped[bool] = mapped_column(
|
||||||
|
default=False, server_default=text("false"),
|
||||||
|
)
|
||||||
ip_address: Mapped[str | None] = mapped_column(String(45), nullable=True)
|
ip_address: Mapped[str | None] = mapped_column(String(45), nullable=True)
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
default=_now, server_default=func.now()
|
default=_now, server_default=func.now()
|
||||||
|
|
|
||||||
|
|
@ -3,13 +3,15 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.orm import aliased
|
||||||
|
|
||||||
from auth import (
|
from auth import (
|
||||||
create_impersonation_token,
|
create_impersonation_token,
|
||||||
|
|
@ -52,6 +54,20 @@ class StopImpersonateResponse(BaseModel):
|
||||||
message: str
|
message: str
|
||||||
|
|
||||||
|
|
||||||
|
class StartImpersonationRequest(BaseModel):
|
||||||
|
write_mode: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class ImpersonationLogItem(BaseModel):
|
||||||
|
id: str
|
||||||
|
admin_name: str
|
||||||
|
target_name: str
|
||||||
|
action: str
|
||||||
|
write_mode: bool
|
||||||
|
ip_address: str | None
|
||||||
|
created_at: datetime
|
||||||
|
|
||||||
|
|
||||||
# ── Helpers ──────────────────────────────────────────────────────────────────
|
# ── Helpers ──────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -97,8 +113,12 @@ async def start_impersonation(
|
||||||
request: Request,
|
request: Request,
|
||||||
admin: Annotated[User, Depends(_require_admin)],
|
admin: Annotated[User, Depends(_require_admin)],
|
||||||
session: Annotated[AsyncSession, Depends(get_session)],
|
session: Annotated[AsyncSession, Depends(get_session)],
|
||||||
|
body: StartImpersonationRequest | None = None,
|
||||||
):
|
):
|
||||||
"""Start impersonating a user. Admin only. Returns a scoped JWT."""
|
"""Start impersonating a user. Admin only. Returns a scoped JWT."""
|
||||||
|
if body is None:
|
||||||
|
body = StartImpersonationRequest()
|
||||||
|
|
||||||
# Cannot impersonate yourself
|
# Cannot impersonate yourself
|
||||||
if admin.id == user_id:
|
if admin.id == user_id:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -120,6 +140,7 @@ async def start_impersonation(
|
||||||
admin_user_id=admin.id,
|
admin_user_id=admin.id,
|
||||||
target_user_id=target.id,
|
target_user_id=target.id,
|
||||||
target_role=target.role.value,
|
target_role=target.role.value,
|
||||||
|
write_mode=body.write_mode,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Audit log
|
# Audit log
|
||||||
|
|
@ -127,13 +148,14 @@ async def start_impersonation(
|
||||||
admin_user_id=admin.id,
|
admin_user_id=admin.id,
|
||||||
target_user_id=target.id,
|
target_user_id=target.id,
|
||||||
action="start",
|
action="start",
|
||||||
|
write_mode=body.write_mode,
|
||||||
ip_address=_client_ip(request),
|
ip_address=_client_ip(request),
|
||||||
))
|
))
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Impersonation started: admin=%s target=%s",
|
"Impersonation started: admin=%s target=%s write_mode=%s",
|
||||||
admin.id, target.id,
|
admin.id, target.id, body.write_mode,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ImpersonateResponse(
|
return ImpersonateResponse(
|
||||||
|
|
@ -178,3 +200,39 @@ async def stop_impersonation(
|
||||||
)
|
)
|
||||||
|
|
||||||
return StopImpersonateResponse(message="Impersonation ended")
|
return StopImpersonateResponse(message="Impersonation ended")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/impersonation-log", response_model=list[ImpersonationLogItem])
|
||||||
|
async def get_impersonation_log(
|
||||||
|
_admin: Annotated[User, Depends(_require_admin)],
|
||||||
|
session: Annotated[AsyncSession, Depends(get_session)],
|
||||||
|
page: int = Query(1, ge=1),
|
||||||
|
page_size: int = Query(50, ge=1, le=200),
|
||||||
|
):
|
||||||
|
"""Paginated impersonation audit log. Admin only."""
|
||||||
|
AdminUser = aliased(User, name="admin_user")
|
||||||
|
TargetUser = aliased(User, name="target_user")
|
||||||
|
|
||||||
|
stmt = (
|
||||||
|
select(ImpersonationLog, AdminUser.display_name, TargetUser.display_name)
|
||||||
|
.join(AdminUser, ImpersonationLog.admin_user_id == AdminUser.id)
|
||||||
|
.join(TargetUser, ImpersonationLog.target_user_id == TargetUser.id)
|
||||||
|
.order_by(ImpersonationLog.created_at.desc())
|
||||||
|
.offset((page - 1) * page_size)
|
||||||
|
.limit(page_size)
|
||||||
|
)
|
||||||
|
result = await session.execute(stmt)
|
||||||
|
rows = result.all()
|
||||||
|
|
||||||
|
return [
|
||||||
|
ImpersonationLogItem(
|
||||||
|
id=str(log.id),
|
||||||
|
admin_name=admin_name,
|
||||||
|
target_name=target_name,
|
||||||
|
action=log.action,
|
||||||
|
write_mode=log.write_mode,
|
||||||
|
ip_address=log.ip_address,
|
||||||
|
created_at=log.created_at,
|
||||||
|
)
|
||||||
|
for log, admin_name, target_name in rows
|
||||||
|
]
|
||||||
|
|
|
||||||
174
backend/tests/test_impersonation.py
Normal file
174
backend/tests/test_impersonation.py
Normal file
|
|
@ -0,0 +1,174 @@
|
||||||
|
"""Integration tests for impersonation write-mode and audit log."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from httpx import AsyncClient
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||||
|
|
||||||
|
from models import InviteCode, User, UserRole
|
||||||
|
|
||||||
|
# Re-use fixtures from conftest: db_engine, client, admin_auth
|
||||||
|
|
||||||
|
|
||||||
|
_TARGET_EMAIL = "impersonate-target@chrysopedia.com"
|
||||||
|
_TARGET_PASSWORD = "targetpass123"
|
||||||
|
_TARGET_INVITE = "IMP-TARGET-INV"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture()
|
||||||
|
async def target_user(client: AsyncClient, db_engine):
|
||||||
|
"""Register a regular user to be the impersonation target."""
|
||||||
|
factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
|
||||||
|
async with factory() as session:
|
||||||
|
code = InviteCode(code=_TARGET_INVITE, uses_remaining=10)
|
||||||
|
session.add(code)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
resp = await client.post("/api/v1/auth/register", json={
|
||||||
|
"email": _TARGET_EMAIL,
|
||||||
|
"password": _TARGET_PASSWORD,
|
||||||
|
"display_name": "Target User",
|
||||||
|
"invite_code": _TARGET_INVITE,
|
||||||
|
})
|
||||||
|
assert resp.status_code == 201
|
||||||
|
return resp.json()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Write-mode tests ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_impersonation_without_write_mode_blocks_writes(
|
||||||
|
client: AsyncClient, admin_auth, target_user,
|
||||||
|
):
|
||||||
|
"""Read-only impersonation (default) should 403 on PUT /auth/me."""
|
||||||
|
# Start impersonation without write_mode
|
||||||
|
resp = await client.post(
|
||||||
|
f"/api/v1/admin/impersonate/{target_user['id']}",
|
||||||
|
headers=admin_auth["headers"],
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
imp_token = resp.json()["access_token"]
|
||||||
|
imp_headers = {"Authorization": f"Bearer {imp_token}"}
|
||||||
|
|
||||||
|
# Attempt a write operation — should be blocked
|
||||||
|
resp = await client.put(
|
||||||
|
"/api/v1/auth/me",
|
||||||
|
headers=imp_headers,
|
||||||
|
json={"display_name": "Hacked Name"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 403
|
||||||
|
assert "impersonation" in resp.json()["detail"].lower()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_impersonation_with_write_mode_allows_writes(
|
||||||
|
client: AsyncClient, admin_auth, target_user,
|
||||||
|
):
|
||||||
|
"""Write-mode impersonation should not 403 on PUT /auth/me."""
|
||||||
|
# Start impersonation WITH write_mode
|
||||||
|
resp = await client.post(
|
||||||
|
f"/api/v1/admin/impersonate/{target_user['id']}",
|
||||||
|
headers=admin_auth["headers"],
|
||||||
|
json={"write_mode": True},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
imp_token = resp.json()["access_token"]
|
||||||
|
imp_headers = {"Authorization": f"Bearer {imp_token}"}
|
||||||
|
|
||||||
|
# Attempt a write — should NOT get 403 from reject_impersonation
|
||||||
|
resp = await client.put(
|
||||||
|
"/api/v1/auth/me",
|
||||||
|
headers=imp_headers,
|
||||||
|
json={"display_name": "Updated Via WriteMode"},
|
||||||
|
)
|
||||||
|
# Should succeed (200) or at least not be a 403
|
||||||
|
assert resp.status_code != 403
|
||||||
|
# Verify the update actually took effect
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json()["display_name"] == "Updated Via WriteMode"
|
||||||
|
|
||||||
|
|
||||||
|
# ── Audit log endpoint tests ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_impersonation_log_returns_entries(
|
||||||
|
client: AsyncClient, admin_auth, target_user,
|
||||||
|
):
|
||||||
|
"""GET /admin/impersonation-log returns log entries with names."""
|
||||||
|
# Create some log entries by starting impersonation
|
||||||
|
resp = await client.post(
|
||||||
|
f"/api/v1/admin/impersonate/{target_user['id']}",
|
||||||
|
headers=admin_auth["headers"],
|
||||||
|
json={"write_mode": True},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
# Fetch the log
|
||||||
|
resp = await client.get(
|
||||||
|
"/api/v1/admin/impersonation-log",
|
||||||
|
headers=admin_auth["headers"],
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
logs = resp.json()
|
||||||
|
assert len(logs) >= 1
|
||||||
|
|
||||||
|
entry = logs[0]
|
||||||
|
assert entry["admin_name"] == "Admin User"
|
||||||
|
assert entry["target_name"] == "Target User"
|
||||||
|
assert entry["action"] == "start"
|
||||||
|
assert entry["write_mode"] is True
|
||||||
|
assert "id" in entry
|
||||||
|
assert "created_at" in entry
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_impersonation_log_non_admin_forbidden(
|
||||||
|
client: AsyncClient, target_user,
|
||||||
|
):
|
||||||
|
"""Non-admin users cannot access the impersonation log."""
|
||||||
|
# Login as the target (regular) user
|
||||||
|
resp = await client.post("/api/v1/auth/login", json={
|
||||||
|
"email": _TARGET_EMAIL,
|
||||||
|
"password": _TARGET_PASSWORD,
|
||||||
|
})
|
||||||
|
assert resp.status_code == 200
|
||||||
|
user_headers = {"Authorization": f"Bearer {resp.json()['access_token']}"}
|
||||||
|
|
||||||
|
resp = await client.get(
|
||||||
|
"/api/v1/admin/impersonation-log",
|
||||||
|
headers=user_headers,
|
||||||
|
)
|
||||||
|
assert resp.status_code == 403
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_impersonation_log_pagination(
|
||||||
|
client: AsyncClient, admin_auth, target_user,
|
||||||
|
):
|
||||||
|
"""Verify pagination params work on impersonation-log."""
|
||||||
|
# Create two entries
|
||||||
|
for _ in range(2):
|
||||||
|
await client.post(
|
||||||
|
f"/api/v1/admin/impersonate/{target_user['id']}",
|
||||||
|
headers=admin_auth["headers"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Fetch page 1, page_size=1
|
||||||
|
resp = await client.get(
|
||||||
|
"/api/v1/admin/impersonation-log",
|
||||||
|
headers=admin_auth["headers"],
|
||||||
|
params={"page": 1, "page_size": 1},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert len(resp.json()) == 1
|
||||||
|
|
||||||
|
# Fetch page 2
|
||||||
|
resp = await client.get(
|
||||||
|
"/api/v1/admin/impersonation-log",
|
||||||
|
headers=admin_auth["headers"],
|
||||||
|
params={"page": 2, "page_size": 1},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert len(resp.json()) == 1
|
||||||
Loading…
Add table
Reference in a new issue