From 638477cc8edf705bea2118f70e403eee4bc3b640 Mon Sep 17 00:00:00 2001 From: jlightner Date: Sat, 4 Apr 2026 13:36:29 +0000 Subject: [PATCH] =?UTF-8?q?feat:=20Built=20Redis=20sliding-window=20rate?= =?UTF-8?q?=20limiter,=20ChatUsageLog=20model=20with=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - "backend/rate_limiter.py" - "backend/models.py" - "backend/routers/chat.py" - "backend/chat_service.py" - "backend/config.py" - "alembic/versions/031_add_chat_usage_log.py" GSD-Task: S04/T01 --- alembic/versions/031_add_chat_usage_log.py | 40 +++++++ backend/chat_service.py | 82 ++++++++++++++- backend/config.py | 5 + backend/models.py | 28 +++++ backend/rate_limiter.py | 116 +++++++++++++++++++++ backend/routers/chat.py | 88 +++++++++++++++- 6 files changed, 352 insertions(+), 7 deletions(-) create mode 100644 alembic/versions/031_add_chat_usage_log.py create mode 100644 backend/rate_limiter.py diff --git a/alembic/versions/031_add_chat_usage_log.py b/alembic/versions/031_add_chat_usage_log.py new file mode 100644 index 0000000..7b71772 --- /dev/null +++ b/alembic/versions/031_add_chat_usage_log.py @@ -0,0 +1,40 @@ +"""add_chat_usage_log + +Revision ID: 031_chat_usage_log +Revises: 030_onboarding +Create Date: 2026-04-04 +""" + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import UUID + +# revision identifiers +revision = "031_chat_usage_log" +down_revision = "030_onboarding" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.create_table( + "chat_usage_log", + sa.Column("id", UUID(as_uuid=True), primary_key=True, server_default=sa.func.gen_random_uuid()), + sa.Column("user_id", UUID(as_uuid=True), sa.ForeignKey("users.id", ondelete="SET NULL"), nullable=True), + sa.Column("client_ip", sa.String(45), nullable=True), + sa.Column("creator_slug", sa.String(255), nullable=True), + sa.Column("query", sa.Text(), nullable=False), + sa.Column("prompt_tokens", sa.Integer(), nullable=False, server_default="0"), + sa.Column("completion_tokens", sa.Integer(), nullable=False, server_default="0"), + sa.Column("total_tokens", sa.Integer(), nullable=False, server_default="0"), + sa.Column("cascade_tier", sa.String(50), nullable=True), + sa.Column("model", sa.String(100), nullable=True), + sa.Column("latency_ms", sa.Float(), nullable=True), + sa.Column("created_at", sa.DateTime(), nullable=False, server_default=sa.func.now()), + ) + op.create_index("ix_chat_usage_log_created_at", "chat_usage_log", ["created_at"]) + + +def downgrade() -> None: + op.drop_index("ix_chat_usage_log_created_at", table_name="chat_usage_log") + op.drop_table("chat_usage_log") diff --git a/backend/chat_service.py b/backend/chat_service.py index e791d8e..6e27a93 100644 --- a/backend/chat_service.py +++ b/backend/chat_service.py @@ -127,6 +127,46 @@ class ChatService: voice_block = _build_personality_block(creator_name, profile, weight) return system_prompt + "\n\n" + voice_block + async def _log_usage( + self, + db: AsyncSession, + user_id: Any | None, + client_ip: str | None, + creator_slug: str | None, + query: str, + usage: dict[str, int], + cascade_tier: str, + model: str, + latency_ms: float, + ) -> None: + """Insert a ChatUsageLog row. Non-blocking — errors logged, not raised.""" + try: + from models import ChatUsageLog + + log_entry = ChatUsageLog( + user_id=user_id, + client_ip=client_ip, + creator_slug=creator_slug, + query=query[:2000], # truncate very long queries + prompt_tokens=usage.get("prompt_tokens", 0), + completion_tokens=usage.get("completion_tokens", 0), + total_tokens=usage.get("total_tokens", 0), + cascade_tier=cascade_tier, + model=model, + latency_ms=latency_ms, + ) + db.add(log_entry) + await db.commit() + except Exception: + logger.error( + "chat_usage_log_insert_error user=%s ip=%s", + user_id, client_ip, exc_info=True, + ) + try: + await db.rollback() + except Exception: + pass + async def stream_response( self, query: str, @@ -134,6 +174,8 @@ class ChatService: creator: str | None = None, conversation_id: str | None = None, personality_weight: float = 0.0, + user_id: Any | None = None, + client_ip: str | None = None, ) -> AsyncIterator[str]: """Yield SSE-formatted events for a chat query. @@ -201,17 +243,26 @@ class ChatService: messages.append({"role": "user", "content": query}) accumulated_response = "" + usage_data: dict[str, int] | None = None try: stream = await self._openai.chat.completions.create( model=self.settings.llm_model, messages=messages, stream=True, + stream_options={"include_usage": True}, temperature=temperature, max_tokens=2048, ) async for chunk in stream: + # The final chunk with stream_options carries usage in chunk.usage + if hasattr(chunk, "usage") and chunk.usage is not None: + usage_data = { + "prompt_tokens": chunk.usage.prompt_tokens or 0, + "completion_tokens": chunk.usage.completion_tokens or 0, + "total_tokens": chunk.usage.total_tokens or 0, + } choice = chunk.choices[0] if chunk.choices else None if choice and choice.delta and choice.delta.content: text = choice.delta.content @@ -227,11 +278,38 @@ class ChatService: # ── 4. Save conversation history ──────────────────────────────── await self._save_history(conversation_id, history, query, accumulated_response) - # ── 5. Done event ─────────────────────────────────────────────── + # ── 5. Log token usage ────────────────────────────────────────── latency_ms = (time.monotonic() - start) * 1000 + + # Fallback: estimate tokens from character counts if stream_options not available + if usage_data is None: + prompt_chars = sum(len(m.get("content", "")) for m in messages) + est_prompt = prompt_chars // 4 + est_completion = len(accumulated_response) // 4 + usage_data = { + "prompt_tokens": est_prompt, + "completion_tokens": est_completion, + "total_tokens": est_prompt + est_completion, + } + logger.warning("chat_usage_estimated cid=%s (stream_options usage not available)", conversation_id) + + await self._log_usage( + db=db, + user_id=user_id, + client_ip=client_ip, + creator_slug=creator, + query=query, + usage=usage_data, + cascade_tier=cascade_tier, + model=self.settings.llm_model, + latency_ms=latency_ms, + ) + + # ── 6. Done event ─────────────────────────────────────────────── logger.info( - "chat_done query=%r creator=%r cascade_tier=%s source_count=%d latency_ms=%.1f cid=%s", + "chat_done query=%r creator=%r cascade_tier=%s source_count=%d latency_ms=%.1f cid=%s tokens=%d", query, creator, cascade_tier, len(sources), latency_ms, conversation_id, + usage_data.get("total_tokens", 0), ) yield _sse("done", {"cascade_tier": cascade_tier, "conversation_id": conversation_id}) diff --git a/backend/config.py b/backend/config.py index aad490c..2b0f8a8 100644 --- a/backend/config.py +++ b/backend/config.py @@ -91,6 +91,11 @@ class Settings(BaseSettings): smtp_from_address: str = "" smtp_tls: bool = True + # Rate limiting (per hour) + rate_limit_user_per_hour: int = 30 + rate_limit_ip_per_hour: int = 10 + rate_limit_creator_per_hour: int = 60 + # Git commit SHA (set at Docker build time or via env var) git_commit_sha: str = "unknown" diff --git a/backend/models.py b/backend/models.py index 4fbd236..7b6f28e 100644 --- a/backend/models.py +++ b/backend/models.py @@ -902,3 +902,31 @@ class GeneratedShort(Base): # relationships highlight_candidate: Mapped[HighlightCandidate] = sa_relationship() + + +# ── Chat Usage Tracking ────────────────────────────────────────────────────── + +class ChatUsageLog(Base): + """Per-request token usage log for chat completions. + + Append-only table — one row per chat request. Used for cost tracking, + rate limit analytics, and the admin usage dashboard. + """ + __tablename__ = "chat_usage_log" + + id: Mapped[uuid.UUID] = _uuid_pk() + user_id: Mapped[uuid.UUID | None] = mapped_column( + ForeignKey("users.id", ondelete="SET NULL"), nullable=True, + ) + client_ip: Mapped[str | None] = mapped_column(String(45), nullable=True) + creator_slug: Mapped[str | None] = mapped_column(String(255), nullable=True) + query: Mapped[str] = mapped_column(Text, nullable=False) + prompt_tokens: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + completion_tokens: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + total_tokens: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + cascade_tier: Mapped[str | None] = mapped_column(String(50), nullable=True) + model: Mapped[str | None] = mapped_column(String(100), nullable=True) + latency_ms: Mapped[float | None] = mapped_column(Float, nullable=True) + created_at: Mapped[datetime] = mapped_column( + default=_now, server_default=func.now(), index=True, + ) diff --git a/backend/rate_limiter.py b/backend/rate_limiter.py new file mode 100644 index 0000000..968d043 --- /dev/null +++ b/backend/rate_limiter.py @@ -0,0 +1,116 @@ +"""Redis sliding-window rate limiter using sorted sets. + +Each rate limit key is a Redis sorted set where members are unique +request identifiers (timestamps with microseconds) and scores are +Unix timestamps. On each check, expired entries are pruned, the +current request is added, and the count determines whether the +request is allowed. + +Fail-open: If Redis is unavailable, requests are allowed through +with a WARNING log. +""" + +from __future__ import annotations + +import logging +import time +from dataclasses import dataclass + +import redis.asyncio as aioredis + +logger = logging.getLogger("chrysopedia.rate_limiter") + +_KEY_PREFIX = "chrysopedia:ratelimit" + + +@dataclass +class RateLimitResult: + """Result of a rate limit check.""" + + allowed: bool + remaining: int + retry_after: int # seconds until the window slides enough to allow a request; 0 if allowed + + +class RateLimiter: + """Sliding-window rate limiter backed by Redis sorted sets. + + Usage:: + + limiter = RateLimiter(redis) + result = await limiter.check_rate_limit("user:abc123", limit=30, window_seconds=3600) + if not result.allowed: + return 429, result.retry_after + """ + + def __init__(self, redis: aioredis.Redis) -> None: + self._redis = redis + + @staticmethod + def key(scope: str, identifier: str) -> str: + """Build a namespaced Redis key for a rate limit bucket.""" + return f"{_KEY_PREFIX}:{scope}:{identifier}" + + async def check_rate_limit( + self, + key: str, + limit: int, + window_seconds: int = 3600, + ) -> RateLimitResult: + """Check whether a request is within the rate limit. + + Uses a sorted set where: + - ZREMRANGEBYSCORE prunes entries older than the window + - ZCARD counts current entries + - ZADD adds the current request if under limit + + Returns a RateLimitResult with allowed/remaining/retry_after. + On Redis errors, fails open (allowed=True). + """ + now = time.time() + window_start = now - window_seconds + + try: + pipe = self._redis.pipeline(transaction=True) + # Remove expired entries + pipe.zremrangebyscore(key, "-inf", window_start) + # Count remaining entries + pipe.zcard(key) + results = await pipe.execute() + + current_count: int = results[1] + + if current_count >= limit: + # Over limit — calculate retry_after from oldest entry + oldest = await self._redis.zrange(key, 0, 0, withscores=True) + if oldest: + oldest_score = oldest[0][1] + retry_after = int(oldest_score + window_seconds - now) + 1 + retry_after = max(retry_after, 1) + else: + retry_after = window_seconds + + return RateLimitResult( + allowed=False, + remaining=0, + retry_after=retry_after, + ) + + # Under limit — add this request + member = f"{now}:{id(key)}" # unique member per call + await self._redis.zadd(key, {member: now}) + # Set TTL on the key so it auto-expires after the window + await self._redis.expire(key, window_seconds + 60) + + remaining = limit - current_count - 1 + return RateLimitResult( + allowed=True, + remaining=max(remaining, 0), + retry_after=0, + ) + + except Exception: + logger.warning( + "rate_limit_redis_error key=%s — failing open", key, exc_info=True + ) + return RateLimitResult(allowed=True, remaining=limit, retry_after=0) diff --git a/backend/routers/chat.py b/backend/routers/chat.py index 55f93d6..cdb27b5 100644 --- a/backend/routers/chat.py +++ b/backend/routers/chat.py @@ -2,20 +2,25 @@ Accepts a query and optional creator filter, returns a Server-Sent Events stream with sources, token, done, and error events. + +Rate limiting: per-user (authenticated), per-IP (anonymous), and per-creator. """ from __future__ import annotations import logging -from fastapi import APIRouter, Depends -from fastapi.responses import StreamingResponse +from fastapi import APIRouter, Depends, Request +from fastapi.responses import JSONResponse, StreamingResponse from pydantic import BaseModel, Field from sqlalchemy.ext.asyncio import AsyncSession +from auth import get_optional_user from chat_service import ChatService from config import Settings, get_settings from database import get_session +from models import User +from rate_limiter import RateLimiter from redis_client import get_redis logger = logging.getLogger("chrysopedia.chat.router") @@ -32,23 +37,94 @@ class ChatRequest(BaseModel): personality_weight: float = Field(default=0.0, ge=0.0, le=1.0) -@router.post("") +def _get_client_ip(request: Request) -> str: + """Extract client IP, preferring X-Forwarded-For behind a reverse proxy.""" + forwarded = request.headers.get("x-forwarded-for") + if forwarded: + return forwarded.split(",")[0].strip() + return request.client.host if request.client else "unknown" + + +@router.post("", response_model=None) async def chat( body: ChatRequest, + request: Request, db: AsyncSession = Depends(get_session), settings: Settings = Depends(get_settings), -) -> StreamingResponse: + user: User | None = Depends(get_optional_user), +): """Stream a chat response as Server-Sent Events. + Rate limits are checked before processing: + - Authenticated users: ``rate_limit_user_per_hour`` requests/hour + - Anonymous (IP-based): ``rate_limit_ip_per_hour`` requests/hour + - Per-creator (if creator filter set): ``rate_limit_creator_per_hour`` requests/hour + SSE protocol: - ``event: sources`` — citation metadata array (sent first) - ``event: token`` — streamed text chunk (repeated) - ``event: done`` — completion metadata with cascade_tier, conversation_id - ``event: error`` — error message (on failure) """ - logger.info("chat_request query=%r creator=%r cid=%r weight=%.2f", body.query, body.creator, body.conversation_id, body.personality_weight) + client_ip = _get_client_ip(request) + user_id = user.id if user else None + + logger.info( + "chat_request query=%r creator=%r cid=%r weight=%.2f user=%s ip=%s", + body.query, body.creator, body.conversation_id, + body.personality_weight, user_id, client_ip, + ) redis = await get_redis() + + # ── Rate limiting ─────────────────────────────────────────────────── + limiter = RateLimiter(redis) + + # User-based limit (authenticated) or IP-based limit (anonymous) + if user_id: + identity_key = RateLimiter.key("user", str(user_id)) + identity_limit = settings.rate_limit_user_per_hour + else: + identity_key = RateLimiter.key("ip", client_ip) + identity_limit = settings.rate_limit_ip_per_hour + + result = await limiter.check_rate_limit(identity_key, identity_limit, window_seconds=3600) + if not result.allowed: + scope = "user" if user_id else "ip" + logger.warning( + "rate_limit_exceeded scope=%s key=%s remaining=%d retry_after=%d", + scope, identity_key, result.remaining, result.retry_after, + ) + return JSONResponse( + status_code=429, + content={ + "error": "Rate limit exceeded", + "retry_after": result.retry_after, + }, + headers={"Retry-After": str(result.retry_after)}, + ) + + # Per-creator limit (if creator filter is provided) + if body.creator: + creator_key = RateLimiter.key("creator", body.creator) + creator_result = await limiter.check_rate_limit( + creator_key, settings.rate_limit_creator_per_hour, window_seconds=3600, + ) + if not creator_result.allowed: + logger.warning( + "rate_limit_exceeded scope=creator key=%s retry_after=%d", + creator_key, creator_result.retry_after, + ) + return JSONResponse( + status_code=429, + content={ + "error": "Creator rate limit exceeded", + "retry_after": creator_result.retry_after, + }, + headers={"Retry-After": str(creator_result.retry_after)}, + ) + + # ── Stream response ───────────────────────────────────────────────── service = ChatService(settings, redis=redis) return StreamingResponse( @@ -58,6 +134,8 @@ async def chat( creator=body.creator, conversation_id=body.conversation_id, personality_weight=body.personality_weight, + user_id=user_id, + client_ip=client_ip, ), media_type="text/event-stream", headers={