chrysopedia/backend/rate_limiter.py
jlightner a5d3af55ca feat: Built Redis sliding-window rate limiter, ChatUsageLog model with…
- "backend/rate_limiter.py"
- "backend/models.py"
- "backend/routers/chat.py"
- "backend/chat_service.py"
- "backend/config.py"
- "alembic/versions/031_add_chat_usage_log.py"

GSD-Task: S04/T01
2026-04-04 13:36:29 +00:00

116 lines
3.7 KiB
Python

"""Redis sliding-window rate limiter using sorted sets.
Each rate limit key is a Redis sorted set where members are unique
request identifiers (timestamps with microseconds) and scores are
Unix timestamps. On each check, expired entries are pruned and the
remaining entries are counted; the current request is added only when
that count is below the limit, otherwise it is rejected.
Fail-open: If Redis is unavailable, requests are allowed through
with a WARNING log.
"""
from __future__ import annotations

import logging
import time
import uuid
from dataclasses import dataclass

import redis.asyncio as aioredis
# Module-level logger; RateLimiter.check_rate_limit reports Redis failures
# through it (at WARNING level) before failing open.
logger = logging.getLogger("chrysopedia.rate_limiter")
# Namespace prepended to every bucket key built by RateLimiter.key().
_KEY_PREFIX = "chrysopedia:ratelimit"
@dataclass
class RateLimitResult:
    """Outcome of a single rate limit check.

    Attributes:
        allowed: True when the request may proceed.
        remaining: Requests still available in the current window.
        retry_after: Seconds until the window slides enough to allow a
            request; 0 if allowed.
    """

    allowed: bool
    remaining: int
    retry_after: int
class RateLimiter:
    """Sliding-window rate limiter backed by Redis sorted sets.

    Each bucket is one sorted set. Every admitted request is stored as a
    member scored with its Unix timestamp, so the cardinality of the set
    after pruning is the number of requests used in the current window.

    Usage::

        limiter = RateLimiter(redis)
        bucket = RateLimiter.key("user", "abc123")
        result = await limiter.check_rate_limit(bucket, limit=30, window_seconds=3600)
        if not result.allowed:
            return 429, result.retry_after
    """

    def __init__(self, redis: aioredis.Redis) -> None:
        """Keep a reference to the async Redis client used for all buckets."""
        self._redis = redis

    @staticmethod
    def key(scope: str, identifier: str) -> str:
        """Build a namespaced Redis key for a rate limit bucket."""
        return f"{_KEY_PREFIX}:{scope}:{identifier}"

    async def check_rate_limit(
        self,
        key: str,
        limit: int,
        window_seconds: int = 3600,
    ) -> RateLimitResult:
        """Check whether a request is within the rate limit.

        Uses a sorted set where:
        - ZREMRANGEBYSCORE prunes entries older than the window
        - ZCARD counts current entries
        - ZADD adds the current request if under limit

        Args:
            key: Redis key of the bucket (see :meth:`key`).
            limit: Maximum number of requests admitted per window.
            window_seconds: Length of the sliding window in seconds.

        Returns:
            A RateLimitResult with allowed/remaining/retry_after.
            On Redis errors, fails open (allowed=True, remaining=limit).

        NOTE: the count and the insert are separate round trips, so
        concurrent callers for the same key can each observe a count below
        the limit and all be admitted; the limit is approximate under
        contention.
        """
        now = time.time()
        window_start = now - window_seconds
        try:
            # Prune and count in a single MULTI/EXEC so the count reflects
            # the set immediately after expired entries were dropped.
            prune = self._redis.pipeline(transaction=True)
            prune.zremrangebyscore(key, "-inf", window_start)
            prune.zcard(key)
            _, current_count = await prune.execute()

            if current_count >= limit:
                # Over limit: a slot frees up once the oldest entry leaves
                # the window, so derive retry_after from its score.
                oldest = await self._redis.zrange(key, 0, 0, withscores=True)
                if oldest:
                    oldest_score = oldest[0][1]
                    retry_after = int(oldest_score + window_seconds - now) + 1
                    retry_after = max(retry_after, 1)
                else:
                    # Nothing stored (e.g. limit <= 0): fall back to a full window.
                    retry_after = window_seconds
                return RateLimitResult(
                    allowed=False,
                    remaining=0,
                    retry_after=retry_after,
                )

            # Under limit: record this request. The member must be unique
            # per call, because ZADD on an existing member only updates its
            # score and the request would then go uncounted. A random suffix
            # guarantees uniqueness even when two calls share a timestamp.
            member = f"{now}:{uuid.uuid4().hex}"

            # Send ZADD and EXPIRE together so a bucket can never be left
            # behind without a TTL if the second command is not delivered.
            record = self._redis.pipeline(transaction=True)
            record.zadd(key, {member: now})
            # TTL slightly longer than the window so idle buckets expire.
            record.expire(key, window_seconds + 60)
            await record.execute()

            remaining = limit - current_count - 1
            return RateLimitResult(
                allowed=True,
                remaining=max(remaining, 0),
                retry_after=0,
            )
        except Exception:
            # Deliberate fail-open: an unavailable Redis must not block traffic.
            logger.warning(
                "rate_limit_redis_error key=%s — failing open", key, exc_info=True
            )
            return RateLimitResult(allowed=True, remaining=limit, retry_after=0)