- "backend/rate_limiter.py" - "backend/models.py" - "backend/routers/chat.py" - "backend/chat_service.py" - "backend/config.py" - "alembic/versions/031_add_chat_usage_log.py" GSD-Task: S04/T01
116 lines
3.7 KiB
Python
116 lines
3.7 KiB
Python
"""Redis sliding-window rate limiter using sorted sets.

Each rate limit key is a Redis sorted set whose members identify
individual requests and whose scores are the Unix timestamps at which
those requests were recorded. On each check, entries older than the
window are pruned, the remaining entries are counted, and the current
request is recorded only if that count is below the limit.

Fail-open: if Redis is unavailable, requests are allowed through
with a WARNING log.
"""
|
|
|
|
from __future__ import annotations

import logging
import time
import uuid
from dataclasses import dataclass

import redis.asyncio as aioredis
|
|
|
|
# Logger for this module, named under the "chrysopedia" logger hierarchy.
logger: logging.Logger = logging.getLogger("chrysopedia.rate_limiter")

# Namespace prepended to every rate-limit bucket key built by RateLimiter.key.
_KEY_PREFIX: str = "chrysopedia:ratelimit"
|
|
|
|
|
|
@dataclass()
class RateLimitResult:
    """Outcome reported by a single rate limit check."""

    # Whether the request may proceed.
    allowed: bool
    # Request budget the limiter reports as left in the current window.
    remaining: int
    # Seconds until the window slides enough to allow a request; 0 if allowed.
    retry_after: int
|
|
|
|
|
|
class RateLimiter:
    """Sliding-window rate limiter backed by Redis sorted sets.

    Each bucket is a sorted set whose scores are request timestamps taken
    from ``time.time()``. A check does the following inside one Lua script:

    - prunes entries older than the window;
    - counts what remains;
    - records the current request only when the count is below the limit.

    Redis executes a script atomically, so concurrent callers sharing a key
    cannot all pass the count check before any of them records its request.

    Usage::

        limiter = RateLimiter(redis)
        result = await limiter.check_rate_limit("user:abc123", limit=30, window_seconds=3600)
        if not result.allowed:
            return 429, result.retry_after
    """

    # KEYS[1] = bucket key
    # ARGV[1] = current timestamp (score of the new entry)
    # ARGV[2] = window start; entries scored at or below it are pruned
    # ARGV[3] = member identifying this request
    # ARGV[4] = limit
    # ARGV[5] = TTL (seconds) applied to the key when a request is recorded
    # Returns {allowed (1/0), entry count before this request, oldest score or ""}.
    _SLIDING_WINDOW_LUA = """
local key = KEYS[1]
redis.call('ZREMRANGEBYSCORE', key, '-inf', ARGV[2])
local count = redis.call('ZCARD', key)
if count >= tonumber(ARGV[4]) then
    local oldest = redis.call('ZRANGE', key, 0, 0, 'WITHSCORES')
    return {0, count, oldest[2] or ''}
end
redis.call('ZADD', key, ARGV[1], ARGV[3])
redis.call('EXPIRE', key, ARGV[5])
return {1, count, ''}
"""

    def __init__(self, redis: aioredis.Redis) -> None:
        self._redis = redis
        # Registering only computes the script's SHA locally. Calls go through
        # EVALSHA and reload the script if the server reports NOSCRIPT.
        self._script = redis.register_script(self._SLIDING_WINDOW_LUA)

    @staticmethod
    def key(scope: str, identifier: str) -> str:
        """Build a namespaced Redis key for a rate limit bucket."""
        return f"{_KEY_PREFIX}:{scope}:{identifier}"

    async def check_rate_limit(
        self,
        key: str,
        limit: int,
        window_seconds: int = 3600,
    ) -> RateLimitResult:
        """Check whether a request is within the rate limit, recording it if so.

        A single atomic Lua script performs these steps:

        - ZREMRANGEBYSCORE prunes entries scored at or before ``now - window_seconds``.
        - ZCARD counts the remaining entries.
        - If the count is below ``limit``, ZADD records this request and
          EXPIRE refreshes the key's TTL. Otherwise the oldest entry's
          score is read so a retry delay can be computed.

        Args:
            key: Redis key of the bucket, typically built with :meth:`key`.
            limit: Maximum number of requests allowed per window.
            window_seconds: Length of the sliding window in seconds.

        Returns:
            RateLimitResult with allowed/remaining/retry_after. When the
            request is denied, ``retry_after`` is the whole number of seconds
            (at least 1) until the oldest entry leaves the window. If no
            entry exists, it is ``window_seconds``.

        On any error talking to Redis (or interpreting its reply), this
        fails open: the request is allowed and a WARNING is logged.
        """
        now = time.time()
        window_start = now - window_seconds
        # The random suffix keeps members distinct when two requests share a
        # timestamp. ZADD on an existing member updates it in place instead of
        # adding a new entry, which would leave the request uncounted.
        member = f"{now}:{uuid.uuid4().hex}"

        try:
            allowed_flag, count_before, oldest_raw = await self._script(
                keys=[key],
                # The TTL exceeds the window so the key outlives its newest entry.
                args=[now, window_start, member, limit, window_seconds + 60],
            )
            current_count = int(count_before)

            if int(allowed_flag):
                return RateLimitResult(
                    allowed=True,
                    remaining=max(limit - current_count - 1, 0),
                    retry_after=0,
                )

            # Over the limit: wait until the oldest entry slides out of the window.
            if oldest_raw:
                if isinstance(oldest_raw, bytes):
                    # Replies are bytes unless the client decodes responses.
                    oldest_raw = oldest_raw.decode()
                oldest_score = float(oldest_raw)
                retry_after = max(int(oldest_score + window_seconds - now) + 1, 1)
            else:
                retry_after = window_seconds

            return RateLimitResult(
                allowed=False,
                remaining=0,
                retry_after=retry_after,
            )

        except Exception:
            logger.warning(
                "rate_limit_redis_error key=%s — failing open", key, exc_info=True
            )
            return RateLimitResult(allowed=True, remaining=limit, retry_after=0)