feat: Built Redis sliding-window rate limiter, ChatUsageLog model with…

- "backend/rate_limiter.py"
- "backend/models.py"
- "backend/routers/chat.py"
- "backend/chat_service.py"
- "backend/config.py"
- "alembic/versions/031_add_chat_usage_log.py"

GSD-Task: S04/T01
This commit is contained in:
jlightner 2026-04-04 13:36:29 +00:00
parent a0e228d5b4
commit 638477cc8e
6 changed files with 352 additions and 7 deletions

View file

@ -0,0 +1,40 @@
"""add_chat_usage_log
Revision ID: 031_chat_usage_log
Revises: 030_onboarding
Create Date: 2026-04-04
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import UUID
# Alembic revision identifiers: this migration's id and the id of the
# migration it is applied on top of.
revision = "031_chat_usage_log"
down_revision = "030_onboarding"

# No branch labels and no cross-branch dependencies for this revision.
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Create the ``chat_usage_log`` table and index it on ``created_at``."""
    table_name = "chat_usage_log"

    # Identity and request-context columns. The primary key is generated by
    # the database; the user FK is nulled (not cascaded) when a user is deleted.
    leading_columns = [
        sa.Column(
            "id",
            UUID(as_uuid=True),
            primary_key=True,
            server_default=sa.func.gen_random_uuid(),
        ),
        sa.Column(
            "user_id",
            UUID(as_uuid=True),
            sa.ForeignKey("users.id", ondelete="SET NULL"),
            nullable=True,
        ),
        sa.Column("client_ip", sa.String(45), nullable=True),
        sa.Column("creator_slug", sa.String(255), nullable=True),
        sa.Column("query", sa.Text(), nullable=False),
    ]

    # The three token counters share one definition: NOT NULL, defaulting to 0.
    token_columns = [
        sa.Column(counter, sa.Integer(), nullable=False, server_default="0")
        for counter in ("prompt_tokens", "completion_tokens", "total_tokens")
    ]

    # Completion metadata and the row timestamp (filled in by the database).
    trailing_columns = [
        sa.Column("cascade_tier", sa.String(50), nullable=True),
        sa.Column("model", sa.String(100), nullable=True),
        sa.Column("latency_ms", sa.Float(), nullable=True),
        sa.Column(
            "created_at",
            sa.DateTime(),
            nullable=False,
            server_default=sa.func.now(),
        ),
    ]

    op.create_table(
        table_name,
        *leading_columns,
        *token_columns,
        *trailing_columns,
    )
    op.create_index("ix_chat_usage_log_created_at", table_name, ["created_at"])
def downgrade() -> None:
    """Reverse ``upgrade``: drop the ``created_at`` index, then the table."""
    table_name = "chat_usage_log"
    op.drop_index("ix_chat_usage_log_created_at", table_name=table_name)
    op.drop_table(table_name)

View file

@ -127,6 +127,46 @@ class ChatService:
voice_block = _build_personality_block(creator_name, profile, weight) voice_block = _build_personality_block(creator_name, profile, weight)
return system_prompt + "\n\n" + voice_block return system_prompt + "\n\n" + voice_block
async def _log_usage(
    self,
    db: AsyncSession,
    user_id: Any | None,
    client_ip: str | None,
    creator_slug: str | None,
    query: str,
    usage: dict[str, int],
    cascade_tier: str,
    model: str,
    latency_ms: float,
) -> None:
    """Insert one ChatUsageLog row on a best-effort basis.

    The insert and commit are awaited, but any failure is logged and
    swallowed so that usage accounting can never break the chat response.

    Args:
        db: Session used to add and commit the row; rolled back on failure.
        user_id: Authenticated user's id, or None for anonymous requests.
        client_ip: Caller's IP as reported by the router, or None.
        creator_slug: Creator filter supplied with the request, or None.
        query: The user's chat query (stored truncated to 2000 characters).
        usage: Token counts keyed by ``prompt_tokens``, ``completion_tokens``
            and ``total_tokens``; missing keys are stored as 0.
        cascade_tier: Retrieval tier that served the request.
        model: Name of the LLM model used.
        latency_ms: End-to-end request latency in milliseconds.
    """
    try:
        # Imported here rather than at module top, as in the original —
        # presumably to avoid an import cycle; confirm before hoisting.
        from models import ChatUsageLog

        # client_ip and creator_slug originate from request data, so clamp
        # them to the String(45) / String(255) column widths; an over-long
        # value would otherwise fail the insert and lose the whole row.
        log_entry = ChatUsageLog(
            user_id=user_id,
            client_ip=client_ip[:45] if client_ip else client_ip,
            creator_slug=creator_slug[:255] if creator_slug else creator_slug,
            query=query[:2000],  # truncate very long queries
            prompt_tokens=usage.get("prompt_tokens", 0),
            completion_tokens=usage.get("completion_tokens", 0),
            total_tokens=usage.get("total_tokens", 0),
            cascade_tier=cascade_tier,
            model=model,
            latency_ms=latency_ms,
        )
        db.add(log_entry)
        await db.commit()
    except Exception:
        logger.error(
            "chat_usage_log_insert_error user=%s ip=%s",
            user_id, client_ip, exc_info=True,
        )
        try:
            await db.rollback()
        except Exception:
            # Still best-effort, but leave a trace: a failed rollback means
            # the session may be unusable for whatever runs next.
            logger.warning(
                "chat_usage_log_rollback_error user=%s ip=%s",
                user_id, client_ip, exc_info=True,
            )
async def stream_response( async def stream_response(
self, self,
query: str, query: str,
@ -134,6 +174,8 @@ class ChatService:
creator: str | None = None, creator: str | None = None,
conversation_id: str | None = None, conversation_id: str | None = None,
personality_weight: float = 0.0, personality_weight: float = 0.0,
user_id: Any | None = None,
client_ip: str | None = None,
) -> AsyncIterator[str]: ) -> AsyncIterator[str]:
"""Yield SSE-formatted events for a chat query. """Yield SSE-formatted events for a chat query.
@ -201,17 +243,26 @@ class ChatService:
messages.append({"role": "user", "content": query}) messages.append({"role": "user", "content": query})
accumulated_response = "" accumulated_response = ""
usage_data: dict[str, int] | None = None
try: try:
stream = await self._openai.chat.completions.create( stream = await self._openai.chat.completions.create(
model=self.settings.llm_model, model=self.settings.llm_model,
messages=messages, messages=messages,
stream=True, stream=True,
stream_options={"include_usage": True},
temperature=temperature, temperature=temperature,
max_tokens=2048, max_tokens=2048,
) )
async for chunk in stream: async for chunk in stream:
# The final chunk with stream_options carries usage in chunk.usage
if hasattr(chunk, "usage") and chunk.usage is not None:
usage_data = {
"prompt_tokens": chunk.usage.prompt_tokens or 0,
"completion_tokens": chunk.usage.completion_tokens or 0,
"total_tokens": chunk.usage.total_tokens or 0,
}
choice = chunk.choices[0] if chunk.choices else None choice = chunk.choices[0] if chunk.choices else None
if choice and choice.delta and choice.delta.content: if choice and choice.delta and choice.delta.content:
text = choice.delta.content text = choice.delta.content
@ -227,11 +278,38 @@ class ChatService:
# ── 4. Save conversation history ──────────────────────────────── # ── 4. Save conversation history ────────────────────────────────
await self._save_history(conversation_id, history, query, accumulated_response) await self._save_history(conversation_id, history, query, accumulated_response)
# ── 5. Done event ─────────────────────────────────────────────── # ── 5. Log token usage ──────────────────────────────────────────
latency_ms = (time.monotonic() - start) * 1000 latency_ms = (time.monotonic() - start) * 1000
# Fallback: estimate tokens from character counts if stream_options not available
if usage_data is None:
prompt_chars = sum(len(m.get("content", "")) for m in messages)
est_prompt = prompt_chars // 4
est_completion = len(accumulated_response) // 4
usage_data = {
"prompt_tokens": est_prompt,
"completion_tokens": est_completion,
"total_tokens": est_prompt + est_completion,
}
logger.warning("chat_usage_estimated cid=%s (stream_options usage not available)", conversation_id)
await self._log_usage(
db=db,
user_id=user_id,
client_ip=client_ip,
creator_slug=creator,
query=query,
usage=usage_data,
cascade_tier=cascade_tier,
model=self.settings.llm_model,
latency_ms=latency_ms,
)
# ── 6. Done event ───────────────────────────────────────────────
logger.info( logger.info(
"chat_done query=%r creator=%r cascade_tier=%s source_count=%d latency_ms=%.1f cid=%s", "chat_done query=%r creator=%r cascade_tier=%s source_count=%d latency_ms=%.1f cid=%s tokens=%d",
query, creator, cascade_tier, len(sources), latency_ms, conversation_id, query, creator, cascade_tier, len(sources), latency_ms, conversation_id,
usage_data.get("total_tokens", 0),
) )
yield _sse("done", {"cascade_tier": cascade_tier, "conversation_id": conversation_id}) yield _sse("done", {"cascade_tier": cascade_tier, "conversation_id": conversation_id})

View file

@ -91,6 +91,11 @@ class Settings(BaseSettings):
smtp_from_address: str = "" smtp_from_address: str = ""
smtp_tls: bool = True smtp_tls: bool = True
# Rate limiting (per hour)
rate_limit_user_per_hour: int = 30
rate_limit_ip_per_hour: int = 10
rate_limit_creator_per_hour: int = 60
# Git commit SHA (set at Docker build time or via env var) # Git commit SHA (set at Docker build time or via env var)
git_commit_sha: str = "unknown" git_commit_sha: str = "unknown"

View file

@ -902,3 +902,31 @@ class GeneratedShort(Base):
# relationships # relationships
highlight_candidate: Mapped[HighlightCandidate] = sa_relationship() highlight_candidate: Mapped[HighlightCandidate] = sa_relationship()
# ── Chat Usage Tracking ──────────────────────────────────────────────────────
class ChatUsageLog(Base):
    """Token-usage record for a single chat completion request.

    Append-only table: one row is written per chat request. The data is
    used for cost tracking, rate limit analytics, and the admin usage
    dashboard.
    """

    __tablename__ = "chat_usage_log"

    id: Mapped[uuid.UUID] = _uuid_pk()

    # ── Who made the request ────────────────────────────────────────────
    # Nullable FK: deleting a user sets this to NULL instead of removing
    # the usage row.
    user_id: Mapped[uuid.UUID | None] = mapped_column(
        ForeignKey("users.id", ondelete="SET NULL"),
        nullable=True,
    )
    client_ip: Mapped[str | None] = mapped_column(String(45), nullable=True)

    # ── What was asked ──────────────────────────────────────────────────
    creator_slug: Mapped[str | None] = mapped_column(String(255), nullable=True)
    query: Mapped[str] = mapped_column(Text, nullable=False)

    # ── Token accounting ────────────────────────────────────────────────
    prompt_tokens: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
    completion_tokens: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
    total_tokens: Mapped[int] = mapped_column(Integer, nullable=False, default=0)

    # ── How the request was served ──────────────────────────────────────
    cascade_tier: Mapped[str | None] = mapped_column(String(50), nullable=True)
    model: Mapped[str | None] = mapped_column(String(100), nullable=True)
    latency_ms: Mapped[float | None] = mapped_column(Float, nullable=True)

    # Indexed so usage can be queried by time range.
    created_at: Mapped[datetime] = mapped_column(
        default=_now,
        server_default=func.now(),
        index=True,
    )

116
backend/rate_limiter.py Normal file
View file

@ -0,0 +1,116 @@
"""Redis sliding-window rate limiter using sorted sets.
Each rate limit key is a Redis sorted set where members are unique
request identifiers (timestamps with microseconds) and scores are
Unix timestamps. On each check, expired entries are pruned, the
current request is added, and the count determines whether the
request is allowed.
Fail-open: If Redis is unavailable, requests are allowed through
with a WARNING log.
"""
from __future__ import annotations

import logging
import time
import uuid
from dataclasses import dataclass

import redis.asyncio as aioredis
logger = logging.getLogger("chrysopedia.rate_limiter")
_KEY_PREFIX = "chrysopedia:ratelimit"
@dataclass
class RateLimitResult:
    """Outcome of a single rate limit check."""

    # True when the request may proceed.
    allowed: bool
    # Requests still available in the current window.
    remaining: int
    # Seconds until the window slides enough to allow a request; 0 if allowed.
    retry_after: int
class RateLimiter:
    """Sliding-window rate limiter backed by Redis sorted sets.

    Usage::

        limiter = RateLimiter(redis)
        result = await limiter.check_rate_limit("user:abc123", limit=30, window_seconds=3600)
        if not result.allowed:
            return 429, result.retry_after
    """

    def __init__(self, redis: aioredis.Redis) -> None:
        """Store the async Redis client used for all rate limit buckets."""
        self._redis = redis

    @staticmethod
    def key(scope: str, identifier: str) -> str:
        """Build a namespaced Redis key for a rate limit bucket."""
        return f"{_KEY_PREFIX}:{scope}:{identifier}"

    async def check_rate_limit(
        self,
        key: str,
        limit: int,
        window_seconds: int = 3600,
    ) -> RateLimitResult:
        """Check whether a request is within the rate limit.

        The bucket is a sorted set scored by Unix timestamp. A single
        MULTI/EXEC transaction:

        - ZREMRANGEBYSCORE prunes entries older than the window
        - ZADD records the current request
        - ZCARD counts entries, including the one just added
        - EXPIRE refreshes the key TTL so idle buckets disappear

        Pruning, adding and counting in one transaction means concurrent
        callers each observe a distinct count, so they cannot all slip in
        under the limit. If the count exceeds ``limit`` the entry just added
        is removed again and the request is rejected.

        Args:
            key: Redis key of the bucket (see :meth:`key`).
            limit: Maximum number of requests allowed per window.
            window_seconds: Length of the sliding window in seconds.

        Returns:
            A RateLimitResult with allowed/remaining/retry_after.
            On Redis errors, fails open (allowed=True).
        """
        now = time.time()
        window_start = now - window_seconds
        # A random suffix keeps members unique even when two calls share the
        # same timestamp; a repeated member would overwrite the earlier entry
        # in the sorted set and undercount requests.
        member = f"{now}:{uuid.uuid4().hex}"

        try:
            pipe = self._redis.pipeline(transaction=True)
            pipe.zremrangebyscore(key, "-inf", window_start)
            pipe.zadd(key, {member: now})
            pipe.zcard(key)
            # TTL slightly longer than the window so the key auto-expires
            # once every entry in it has aged out.
            pipe.expire(key, window_seconds + 60)
            results = await pipe.execute()
            current_count: int = results[2]

            if current_count <= limit:
                return RateLimitResult(
                    allowed=True,
                    remaining=limit - current_count,
                    retry_after=0,
                )

            # Over limit — withdraw this request's entry so rejected calls do
            # not consume quota, and fetch the oldest surviving entry to work
            # out when the window will next have room.
            undo = self._redis.pipeline(transaction=True)
            undo.zrem(key, member)
            undo.zrange(key, 0, 0, withscores=True)
            oldest = (await undo.execute())[1]

            if oldest:
                oldest_score = oldest[0][1]
                retry_after = max(int(oldest_score + window_seconds - now) + 1, 1)
            else:
                retry_after = window_seconds

            return RateLimitResult(
                allowed=False,
                remaining=0,
                retry_after=retry_after,
            )

        except Exception:
            # Deliberate fail-open: a Redis problem must not take chat down.
            logger.warning(
                "rate_limit_redis_error key=%s — failing open", key, exc_info=True
            )
            return RateLimitResult(allowed=True, remaining=limit, retry_after=0)

View file

@ -2,20 +2,25 @@
Accepts a query and optional creator filter, returns a Server-Sent Events Accepts a query and optional creator filter, returns a Server-Sent Events
stream with sources, token, done, and error events. stream with sources, token, done, and error events.
Rate limiting: per-user (authenticated), per-IP (anonymous), and per-creator.
""" """
from __future__ import annotations from __future__ import annotations
import logging import logging
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends, Request
from fastapi.responses import StreamingResponse from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from auth import get_optional_user
from chat_service import ChatService from chat_service import ChatService
from config import Settings, get_settings from config import Settings, get_settings
from database import get_session from database import get_session
from models import User
from rate_limiter import RateLimiter
from redis_client import get_redis from redis_client import get_redis
logger = logging.getLogger("chrysopedia.chat.router") logger = logging.getLogger("chrysopedia.chat.router")
@ -32,23 +37,94 @@ class ChatRequest(BaseModel):
personality_weight: float = Field(default=0.0, ge=0.0, le=1.0) personality_weight: float = Field(default=0.0, ge=0.0, le=1.0)
@router.post("") def _get_client_ip(request: Request) -> str:
"""Extract client IP, preferring X-Forwarded-For behind a reverse proxy."""
forwarded = request.headers.get("x-forwarded-for")
if forwarded:
return forwarded.split(",")[0].strip()
return request.client.host if request.client else "unknown"
@router.post("", response_model=None)
async def chat( async def chat(
body: ChatRequest, body: ChatRequest,
request: Request,
db: AsyncSession = Depends(get_session), db: AsyncSession = Depends(get_session),
settings: Settings = Depends(get_settings), settings: Settings = Depends(get_settings),
) -> StreamingResponse: user: User | None = Depends(get_optional_user),
):
"""Stream a chat response as Server-Sent Events. """Stream a chat response as Server-Sent Events.
Rate limits are checked before processing:
- Authenticated users: ``rate_limit_user_per_hour`` requests/hour
- Anonymous (IP-based): ``rate_limit_ip_per_hour`` requests/hour
- Per-creator (if creator filter set): ``rate_limit_creator_per_hour`` requests/hour
SSE protocol: SSE protocol:
- ``event: sources`` citation metadata array (sent first) - ``event: sources`` citation metadata array (sent first)
- ``event: token`` streamed text chunk (repeated) - ``event: token`` streamed text chunk (repeated)
- ``event: done`` completion metadata with cascade_tier, conversation_id - ``event: done`` completion metadata with cascade_tier, conversation_id
- ``event: error`` error message (on failure) - ``event: error`` error message (on failure)
""" """
logger.info("chat_request query=%r creator=%r cid=%r weight=%.2f", body.query, body.creator, body.conversation_id, body.personality_weight) client_ip = _get_client_ip(request)
user_id = user.id if user else None
logger.info(
"chat_request query=%r creator=%r cid=%r weight=%.2f user=%s ip=%s",
body.query, body.creator, body.conversation_id,
body.personality_weight, user_id, client_ip,
)
redis = await get_redis() redis = await get_redis()
# ── Rate limiting ───────────────────────────────────────────────────
limiter = RateLimiter(redis)
# User-based limit (authenticated) or IP-based limit (anonymous)
if user_id:
identity_key = RateLimiter.key("user", str(user_id))
identity_limit = settings.rate_limit_user_per_hour
else:
identity_key = RateLimiter.key("ip", client_ip)
identity_limit = settings.rate_limit_ip_per_hour
result = await limiter.check_rate_limit(identity_key, identity_limit, window_seconds=3600)
if not result.allowed:
scope = "user" if user_id else "ip"
logger.warning(
"rate_limit_exceeded scope=%s key=%s remaining=%d retry_after=%d",
scope, identity_key, result.remaining, result.retry_after,
)
return JSONResponse(
status_code=429,
content={
"error": "Rate limit exceeded",
"retry_after": result.retry_after,
},
headers={"Retry-After": str(result.retry_after)},
)
# Per-creator limit (if creator filter is provided)
if body.creator:
creator_key = RateLimiter.key("creator", body.creator)
creator_result = await limiter.check_rate_limit(
creator_key, settings.rate_limit_creator_per_hour, window_seconds=3600,
)
if not creator_result.allowed:
logger.warning(
"rate_limit_exceeded scope=creator key=%s retry_after=%d",
creator_key, creator_result.retry_after,
)
return JSONResponse(
status_code=429,
content={
"error": "Creator rate limit exceeded",
"retry_after": creator_result.retry_after,
},
headers={"Retry-After": str(creator_result.retry_after)},
)
# ── Stream response ─────────────────────────────────────────────────
service = ChatService(settings, redis=redis) service = ChatService(settings, redis=redis)
return StreamingResponse( return StreamingResponse(
@ -58,6 +134,8 @@ async def chat(
creator=body.creator, creator=body.creator,
conversation_id=body.conversation_id, conversation_id=body.conversation_id,
personality_weight=body.personality_weight, personality_weight=body.personality_weight,
user_id=user_id,
client_ip=client_ip,
), ),
media_type="text/event-stream", media_type="text/event-stream",
headers={ headers={