feat: Built Redis sliding-window rate limiter, ChatUsageLog model with…
- "backend/rate_limiter.py" - "backend/models.py" - "backend/routers/chat.py" - "backend/chat_service.py" - "backend/config.py" - "alembic/versions/031_add_chat_usage_log.py" GSD-Task: S04/T01
This commit is contained in:
parent
a0e228d5b4
commit
638477cc8e
6 changed files with 352 additions and 7 deletions
40
alembic/versions/031_add_chat_usage_log.py
Normal file
40
alembic/versions/031_add_chat_usage_log.py
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
"""add_chat_usage_log
|
||||
|
||||
Revision ID: 031_chat_usage_log
|
||||
Revises: 030_onboarding
|
||||
Create Date: 2026-04-04
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
# revision identifiers
|
||||
revision = "031_chat_usage_log"
|
||||
down_revision = "030_onboarding"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Create the ``chat_usage_log`` table and index it on ``created_at``."""
    table_name = "chat_usage_log"

    # Column definitions, in the order they appear in the table.
    columns = [
        sa.Column(
            "id",
            UUID(as_uuid=True),
            primary_key=True,
            server_default=sa.func.gen_random_uuid(),
        ),
        # The FK uses ON DELETE SET NULL, so the column has to be nullable.
        sa.Column(
            "user_id",
            UUID(as_uuid=True),
            sa.ForeignKey("users.id", ondelete="SET NULL"),
            nullable=True,
        ),
        sa.Column("client_ip", sa.String(45), nullable=True),
        sa.Column("creator_slug", sa.String(255), nullable=True),
        sa.Column("query", sa.Text(), nullable=False),
        # Token counters are NOT NULL with a server-side default of 0.
        sa.Column("prompt_tokens", sa.Integer(), nullable=False, server_default="0"),
        sa.Column("completion_tokens", sa.Integer(), nullable=False, server_default="0"),
        sa.Column("total_tokens", sa.Integer(), nullable=False, server_default="0"),
        sa.Column("cascade_tier", sa.String(50), nullable=True),
        sa.Column("model", sa.String(100), nullable=True),
        sa.Column("latency_ms", sa.Float(), nullable=True),
        sa.Column(
            "created_at",
            sa.DateTime(),
            nullable=False,
            server_default=sa.func.now(),
        ),
    ]

    op.create_table(table_name, *columns)
    op.create_index("ix_chat_usage_log_created_at", table_name, ["created_at"])
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Drop the ``created_at`` index and then the ``chat_usage_log`` table."""
    table_name = "chat_usage_log"
    index_name = "ix_chat_usage_log_created_at"

    # Remove the index first, then the table it belongs to.
    op.drop_index(index_name, table_name=table_name)
    op.drop_table(table_name)
|
||||
|
|
@ -127,6 +127,46 @@ class ChatService:
|
|||
voice_block = _build_personality_block(creator_name, profile, weight)
|
||||
return system_prompt + "\n\n" + voice_block
|
||||
|
||||
async def _log_usage(
    self,
    db: AsyncSession,
    user_id: Any | None,
    client_ip: str | None,
    creator_slug: str | None,
    query: str,
    usage: dict[str, int],
    cascade_tier: str,
    model: str,
    latency_ms: float,
) -> None:
    """Insert one ``ChatUsageLog`` row for a finished chat request.

    This is best-effort bookkeeping: any failure while building, adding or
    committing the row is logged and swallowed so it never reaches the
    caller. After a failure the session is rolled back; a failure of the
    rollback itself is logged as well (previously it was silently ignored).

    Args:
        db: Session used for the insert; it is committed here.
        user_id: Authenticated user's id, or ``None``.
        client_ip: Client address, or ``None``.
        creator_slug: Creator filter of the request, or ``None``.
        query: User query text; only the first 2000 characters are stored.
        usage: Token counts keyed by ``prompt_tokens``, ``completion_tokens``
            and ``total_tokens``; missing keys are stored as 0.
        cascade_tier: Stored as-is in the ``cascade_tier`` column.
        model: Model name stored with the row.
        latency_ms: Request latency in milliseconds.
    """
    try:
        # Imported lazily inside the try block, so an import failure is also
        # handled as a logging failure.
        # NOTE(review): presumably lazy to avoid a circular import — confirm.
        from models import ChatUsageLog

        log_entry = ChatUsageLog(
            user_id=user_id,
            client_ip=client_ip,
            creator_slug=creator_slug,
            query=query[:2000],  # truncate very long queries
            prompt_tokens=usage.get("prompt_tokens", 0),
            completion_tokens=usage.get("completion_tokens", 0),
            total_tokens=usage.get("total_tokens", 0),
            cascade_tier=cascade_tier,
            model=model,
            latency_ms=latency_ms,
        )
        db.add(log_entry)
        await db.commit()
    except Exception:
        logger.error(
            "chat_usage_log_insert_error user=%s ip=%s",
            user_id, client_ip, exc_info=True,
        )
        try:
            await db.rollback()
        except Exception:
            # Still non-fatal, but leave a trace: a failed rollback means the
            # session may be unusable for whatever runs next on it.
            logger.warning(
                "chat_usage_log_rollback_error user=%s ip=%s",
                user_id, client_ip, exc_info=True,
            )
|
||||
|
||||
async def stream_response(
|
||||
self,
|
||||
query: str,
|
||||
|
|
@ -134,6 +174,8 @@ class ChatService:
|
|||
creator: str | None = None,
|
||||
conversation_id: str | None = None,
|
||||
personality_weight: float = 0.0,
|
||||
user_id: Any | None = None,
|
||||
client_ip: str | None = None,
|
||||
) -> AsyncIterator[str]:
|
||||
"""Yield SSE-formatted events for a chat query.
|
||||
|
||||
|
|
@ -201,17 +243,26 @@ class ChatService:
|
|||
messages.append({"role": "user", "content": query})
|
||||
|
||||
accumulated_response = ""
|
||||
usage_data: dict[str, int] | None = None
|
||||
|
||||
try:
|
||||
stream = await self._openai.chat.completions.create(
|
||||
model=self.settings.llm_model,
|
||||
messages=messages,
|
||||
stream=True,
|
||||
stream_options={"include_usage": True},
|
||||
temperature=temperature,
|
||||
max_tokens=2048,
|
||||
)
|
||||
|
||||
async for chunk in stream:
|
||||
# The final chunk with stream_options carries usage in chunk.usage
|
||||
if hasattr(chunk, "usage") and chunk.usage is not None:
|
||||
usage_data = {
|
||||
"prompt_tokens": chunk.usage.prompt_tokens or 0,
|
||||
"completion_tokens": chunk.usage.completion_tokens or 0,
|
||||
"total_tokens": chunk.usage.total_tokens or 0,
|
||||
}
|
||||
choice = chunk.choices[0] if chunk.choices else None
|
||||
if choice and choice.delta and choice.delta.content:
|
||||
text = choice.delta.content
|
||||
|
|
@ -227,11 +278,38 @@ class ChatService:
|
|||
# ── 4. Save conversation history ────────────────────────────────
|
||||
await self._save_history(conversation_id, history, query, accumulated_response)
|
||||
|
||||
# ── 5. Done event ───────────────────────────────────────────────
|
||||
# ── 5. Log token usage ──────────────────────────────────────────
|
||||
latency_ms = (time.monotonic() - start) * 1000
|
||||
|
||||
# Fallback: estimate tokens from character counts if stream_options not available
|
||||
if usage_data is None:
|
||||
prompt_chars = sum(len(m.get("content", "")) for m in messages)
|
||||
est_prompt = prompt_chars // 4
|
||||
est_completion = len(accumulated_response) // 4
|
||||
usage_data = {
|
||||
"prompt_tokens": est_prompt,
|
||||
"completion_tokens": est_completion,
|
||||
"total_tokens": est_prompt + est_completion,
|
||||
}
|
||||
logger.warning("chat_usage_estimated cid=%s (stream_options usage not available)", conversation_id)
|
||||
|
||||
await self._log_usage(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
client_ip=client_ip,
|
||||
creator_slug=creator,
|
||||
query=query,
|
||||
usage=usage_data,
|
||||
cascade_tier=cascade_tier,
|
||||
model=self.settings.llm_model,
|
||||
latency_ms=latency_ms,
|
||||
)
|
||||
|
||||
# ── 6. Done event ───────────────────────────────────────────────
|
||||
logger.info(
|
||||
"chat_done query=%r creator=%r cascade_tier=%s source_count=%d latency_ms=%.1f cid=%s",
|
||||
"chat_done query=%r creator=%r cascade_tier=%s source_count=%d latency_ms=%.1f cid=%s tokens=%d",
|
||||
query, creator, cascade_tier, len(sources), latency_ms, conversation_id,
|
||||
usage_data.get("total_tokens", 0),
|
||||
)
|
||||
yield _sse("done", {"cascade_tier": cascade_tier, "conversation_id": conversation_id})
|
||||
|
||||
|
|
|
|||
|
|
@ -91,6 +91,11 @@ class Settings(BaseSettings):
|
|||
smtp_from_address: str = ""
|
||||
smtp_tls: bool = True
|
||||
|
||||
# Rate limiting (per hour)
|
||||
rate_limit_user_per_hour: int = 30
|
||||
rate_limit_ip_per_hour: int = 10
|
||||
rate_limit_creator_per_hour: int = 60
|
||||
|
||||
# Git commit SHA (set at Docker build time or via env var)
|
||||
git_commit_sha: str = "unknown"
|
||||
|
||||
|
|
|
|||
|
|
@ -902,3 +902,31 @@ class GeneratedShort(Base):
|
|||
|
||||
# relationships
|
||||
highlight_candidate: Mapped[HighlightCandidate] = sa_relationship()
|
||||
|
||||
|
||||
# ── Chat Usage Tracking ──────────────────────────────────────────────────────
|
||||
|
||||
class ChatUsageLog(Base):
    """Token usage record for a single chat completion request.

    The table is append-only, with one row written per chat request. It
    feeds cost tracking, rate limit analytics, and the admin usage dashboard.
    """

    __tablename__ = "chat_usage_log"

    id: Mapped[uuid.UUID] = _uuid_pk()

    # Requester identity. user_id is nullable and set to NULL if the
    # referenced user row is deleted.
    user_id: Mapped[uuid.UUID | None] = mapped_column(
        ForeignKey("users.id", ondelete="SET NULL"),
        nullable=True,
    )
    # NOTE(review): 45 chars is presumably sized for a textual IPv6 address — confirm.
    client_ip: Mapped[str | None] = mapped_column(
        String(45),
        nullable=True,
    )

    # What was asked, and the optional creator filter it was asked with.
    creator_slug: Mapped[str | None] = mapped_column(
        String(255),
        nullable=True,
    )
    query: Mapped[str] = mapped_column(
        Text,
        nullable=False,
    )

    # Token accounting; each counter defaults to 0 on the Python side.
    prompt_tokens: Mapped[int] = mapped_column(
        Integer,
        nullable=False,
        default=0,
    )
    completion_tokens: Mapped[int] = mapped_column(
        Integer,
        nullable=False,
        default=0,
    )
    total_tokens: Mapped[int] = mapped_column(
        Integer,
        nullable=False,
        default=0,
    )

    # Request metadata, all optional.
    cascade_tier: Mapped[str | None] = mapped_column(
        String(50),
        nullable=True,
    )
    model: Mapped[str | None] = mapped_column(
        String(100),
        nullable=True,
    )
    latency_ms: Mapped[float | None] = mapped_column(
        Float,
        nullable=True,
    )

    # Indexed insertion timestamp with both a Python-side and a server-side default.
    created_at: Mapped[datetime] = mapped_column(
        default=_now,
        server_default=func.now(),
        index=True,
    )
|
||||
|
|
|
|||
116
backend/rate_limiter.py
Normal file
116
backend/rate_limiter.py
Normal file
|
|
@ -0,0 +1,116 @@
|
|||
"""Redis sliding-window rate limiter using sorted sets.
|
||||
|
||||
Each rate limit key is a Redis sorted set whose members are per-request
identifiers (prefixed with the request timestamp) and whose scores are
Unix timestamps. On each check, expired entries are pruned and the
remaining entries are counted; the current request is added only if that
count is below the limit, otherwise it is rejected without being recorded.
|
||||
|
||||
Fail-open: If Redis is unavailable, requests are allowed through
|
||||
with a WARNING log.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
|
||||
import redis.asyncio as aioredis
|
||||
|
||||
logger = logging.getLogger("chrysopedia.rate_limiter")
|
||||
|
||||
_KEY_PREFIX = "chrysopedia:ratelimit"
|
||||
|
||||
|
||||
@dataclass
class RateLimitResult:
    """Result of a rate limit check."""

    allowed: bool  # True when the request may proceed
    remaining: int  # requests still available in the current window
    retry_after: int  # seconds until the window slides enough to allow a request; 0 if allowed


class RateLimiter:
    """Sliding-window rate limiter backed by Redis sorted sets.

    Usage::

        limiter = RateLimiter(redis)
        result = await limiter.check_rate_limit("user:abc123", limit=30, window_seconds=3600)
        if not result.allowed:
            return 429, result.retry_after
    """

    def __init__(self, redis: aioredis.Redis) -> None:
        """Keep a reference to the async Redis client used for all checks."""
        self._redis = redis

    @staticmethod
    def key(scope: str, identifier: str) -> str:
        """Build a namespaced Redis key for a rate limit bucket."""
        return f"{_KEY_PREFIX}:{scope}:{identifier}"

    async def check_rate_limit(
        self,
        key: str,
        limit: int,
        window_seconds: int = 3600,
    ) -> RateLimitResult:
        """Check whether a request is within the rate limit, recording it if so.

        Steps, against the sorted set stored at ``key``:
            1. ZREMRANGEBYSCORE prunes entries older than the window.
            2. ZCARD counts the entries that are left.
            3. If the count is below ``limit``, ZADD records this request and
               EXPIRE refreshes the key's TTL (both in one transaction).

        Args:
            key: Redis key of the bucket (see :meth:`key`).
            limit: Maximum number of recorded requests per window.
            window_seconds: Length of the sliding window in seconds.

        Returns:
            A ``RateLimitResult``. When rejected, ``retry_after`` is derived
            from the oldest entry still in the window (at least 1 second), or
            ``window_seconds`` if the set is empty.

        On any Redis error the limiter fails open: the error is logged at
        WARNING level and the request is reported as allowed.
        """
        # Function-scope import keeps this fix self-contained in the block.
        from uuid import uuid4

        now = time.time()
        window_start = now - window_seconds

        try:
            # NOTE: the count below and the later add are separate round
            # trips, so concurrent requests for one key can both pass the
            # check; the limit is approximate under concurrency.
            pipe = self._redis.pipeline(transaction=True)
            # Remove expired entries
            pipe.zremrangebyscore(key, "-inf", window_start)
            # Count remaining entries
            pipe.zcard(key)
            results = await pipe.execute()

            current_count: int = results[1]

            if current_count >= limit:
                # Over limit — calculate retry_after from oldest entry
                oldest = await self._redis.zrange(key, 0, 0, withscores=True)
                if oldest:
                    oldest_score = oldest[0][1]
                    retry_after = int(oldest_score + window_seconds - now) + 1
                    retry_after = max(retry_after, 1)
                else:
                    retry_after = window_seconds

                return RateLimitResult(
                    allowed=False,
                    remaining=0,
                    retry_after=retry_after,
                )

            # Under limit — add this request. The member must be unique per
            # call: sorted-set members are a set, so two requests that share
            # a member would overwrite each other and be counted once. A
            # random UUID guarantees uniqueness even when two calls read the
            # same clock value.
            member = f"{now}:{uuid4().hex}"
            record = self._redis.pipeline(transaction=True)
            record.zadd(key, {member: now})
            # Set TTL on the key so it auto-expires after the window. Sending
            # it in the same transaction as the ZADD avoids leaving a key
            # behind without a TTL.
            record.expire(key, window_seconds + 60)
            await record.execute()

            remaining = limit - current_count - 1
            return RateLimitResult(
                allowed=True,
                remaining=max(remaining, 0),
                retry_after=0,
            )

        except Exception:
            logger.warning(
                "rate_limit_redis_error key=%s — failing open", key, exc_info=True
            )
            return RateLimitResult(allowed=True, remaining=limit, retry_after=0)
|
||||
|
|
@ -2,20 +2,25 @@
|
|||
|
||||
Accepts a query and optional creator filter, returns a Server-Sent Events
|
||||
stream with sources, token, done, and error events.
|
||||
|
||||
Rate limiting: per-user (authenticated), per-IP (anonymous), and per-creator.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from auth import get_optional_user
|
||||
from chat_service import ChatService
|
||||
from config import Settings, get_settings
|
||||
from database import get_session
|
||||
from models import User
|
||||
from rate_limiter import RateLimiter
|
||||
from redis_client import get_redis
|
||||
|
||||
logger = logging.getLogger("chrysopedia.chat.router")
|
||||
|
|
@ -32,23 +37,94 @@ class ChatRequest(BaseModel):
|
|||
personality_weight: float = Field(default=0.0, ge=0.0, le=1.0)
|
||||
|
||||
|
||||
@router.post("")
|
||||
def _get_client_ip(request: Request) -> str:
    """Return the client's IP address as a normalised string.

    The first entry of ``X-Forwarded-For`` is preferred (the app is expected
    to sit behind a reverse proxy), but only when it parses as an IPv4/IPv6
    address. An empty or malformed entry is ignored, so arbitrary header text
    is never returned as the client address. In that case, or when the header
    is absent, the socket peer address is used, or ``"unknown"`` if the
    request carries no client information.

    NOTE(review): ``X-Forwarded-For`` is client-controlled unless the proxy
    overwrites it; callers that rate-limit on this value rely on the proxy
    doing so — confirm the deployment strips untrusted values.
    """
    # Function-scope import keeps this fix self-contained in the block.
    import ipaddress

    peer = request.client.host if request.client else "unknown"

    forwarded = request.headers.get("x-forwarded-for")
    if not forwarded:
        return peer

    candidate = forwarded.split(",", 1)[0].strip()
    try:
        # str() of the parsed address gives one canonical spelling per
        # address (e.g. compressed IPv6).
        return str(ipaddress.ip_address(candidate))
    except ValueError:
        return peer
|
||||
|
||||
|
||||
@router.post("", response_model=None)
|
||||
async def chat(
|
||||
body: ChatRequest,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_session),
|
||||
settings: Settings = Depends(get_settings),
|
||||
) -> StreamingResponse:
|
||||
user: User | None = Depends(get_optional_user),
|
||||
):
|
||||
"""Stream a chat response as Server-Sent Events.
|
||||
|
||||
Rate limits are checked before processing:
|
||||
- Authenticated users: ``rate_limit_user_per_hour`` requests/hour
|
||||
- Anonymous (IP-based): ``rate_limit_ip_per_hour`` requests/hour
|
||||
- Per-creator (if creator filter set): ``rate_limit_creator_per_hour`` requests/hour
|
||||
|
||||
SSE protocol:
|
||||
- ``event: sources`` — citation metadata array (sent first)
|
||||
- ``event: token`` — streamed text chunk (repeated)
|
||||
- ``event: done`` — completion metadata with cascade_tier, conversation_id
|
||||
- ``event: error`` — error message (on failure)
|
||||
"""
|
||||
logger.info("chat_request query=%r creator=%r cid=%r weight=%.2f", body.query, body.creator, body.conversation_id, body.personality_weight)
|
||||
client_ip = _get_client_ip(request)
|
||||
user_id = user.id if user else None
|
||||
|
||||
logger.info(
|
||||
"chat_request query=%r creator=%r cid=%r weight=%.2f user=%s ip=%s",
|
||||
body.query, body.creator, body.conversation_id,
|
||||
body.personality_weight, user_id, client_ip,
|
||||
)
|
||||
|
||||
redis = await get_redis()
|
||||
|
||||
# ── Rate limiting ───────────────────────────────────────────────────
|
||||
limiter = RateLimiter(redis)
|
||||
|
||||
# User-based limit (authenticated) or IP-based limit (anonymous)
|
||||
if user_id:
|
||||
identity_key = RateLimiter.key("user", str(user_id))
|
||||
identity_limit = settings.rate_limit_user_per_hour
|
||||
else:
|
||||
identity_key = RateLimiter.key("ip", client_ip)
|
||||
identity_limit = settings.rate_limit_ip_per_hour
|
||||
|
||||
result = await limiter.check_rate_limit(identity_key, identity_limit, window_seconds=3600)
|
||||
if not result.allowed:
|
||||
scope = "user" if user_id else "ip"
|
||||
logger.warning(
|
||||
"rate_limit_exceeded scope=%s key=%s remaining=%d retry_after=%d",
|
||||
scope, identity_key, result.remaining, result.retry_after,
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=429,
|
||||
content={
|
||||
"error": "Rate limit exceeded",
|
||||
"retry_after": result.retry_after,
|
||||
},
|
||||
headers={"Retry-After": str(result.retry_after)},
|
||||
)
|
||||
|
||||
# Per-creator limit (if creator filter is provided)
|
||||
if body.creator:
|
||||
creator_key = RateLimiter.key("creator", body.creator)
|
||||
creator_result = await limiter.check_rate_limit(
|
||||
creator_key, settings.rate_limit_creator_per_hour, window_seconds=3600,
|
||||
)
|
||||
if not creator_result.allowed:
|
||||
logger.warning(
|
||||
"rate_limit_exceeded scope=creator key=%s retry_after=%d",
|
||||
creator_key, creator_result.retry_after,
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=429,
|
||||
content={
|
||||
"error": "Creator rate limit exceeded",
|
||||
"retry_after": creator_result.retry_after,
|
||||
},
|
||||
headers={"Retry-After": str(creator_result.retry_after)},
|
||||
)
|
||||
|
||||
# ── Stream response ─────────────────────────────────────────────────
|
||||
service = ChatService(settings, redis=redis)
|
||||
|
||||
return StreamingResponse(
|
||||
|
|
@ -58,6 +134,8 @@ async def chat(
|
|||
creator=body.creator,
|
||||
conversation_id=body.conversation_id,
|
||||
personality_weight=body.personality_weight,
|
||||
user_id=user_id,
|
||||
client_ip=client_ip,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue