"""Redis sliding-window rate limiter using sorted sets.

Each rate limit key is a Redis sorted set where members are unique request
identifiers (random UUID hex strings) and scores are the Unix timestamps at
which the requests were recorded. On each check, expired entries are pruned,
the current request is recorded, and the resulting count determines whether
the request is allowed. Rejected requests are removed again, so they do not
consume quota.

Fail-open: If Redis is unavailable, requests are allowed through with a
WARNING log.
"""

from __future__ import annotations

import logging
import math
import time
import uuid
from dataclasses import dataclass

import redis.asyncio as aioredis

logger = logging.getLogger("chrysopedia.rate_limiter")

_KEY_PREFIX = "chrysopedia:ratelimit"

# Extra TTL (seconds) on top of the window, so an idle bucket is eventually
# reclaimed by Redis but never before its newest entry has left the window.
_TTL_GRACE_SECONDS = 60


@dataclass
class RateLimitResult:
    """Result of a rate limit check."""

    allowed: bool
    remaining: int
    retry_after: int  # seconds until the window slides enough to allow a request; 0 if allowed


class RateLimiter:
    """Sliding-window rate limiter backed by Redis sorted sets.

    Usage::

        limiter = RateLimiter(redis)
        result = await limiter.check_rate_limit("user:abc123", limit=30, window_seconds=3600)
        if not result.allowed:
            return 429, result.retry_after
    """

    def __init__(self, redis: aioredis.Redis) -> None:
        self._redis = redis

    @staticmethod
    def key(scope: str, identifier: str) -> str:
        """Build a namespaced Redis key for a rate limit bucket."""
        return f"{_KEY_PREFIX}:{scope}:{identifier}"

    async def check_rate_limit(
        self,
        key: str,
        limit: int,
        window_seconds: int = 3600,
    ) -> RateLimitResult:
        """Check whether a request is within the rate limit, recording it if so.

        One MULTI/EXEC transaction performs:

        - ZREMRANGEBYSCORE: prune entries scored at or before the window start
        - ZADD: record this request under a random, unique member
        - ZCARD: count entries in the window, including this request
        - EXPIRE: let an idle bucket disappear after window + grace period

        Recording and counting in the same transaction closes the
        check-then-act race in which concurrent callers could all observe
        "under limit" before any of them had recorded itself. If the count
        exceeds ``limit``, this request's entry is removed again, so rejected
        requests do not consume quota.

        Args:
            key: Redis key of the bucket (typically built with :meth:`key`).
            limit: Maximum number of requests admitted per window.
            window_seconds: Length of the sliding window in seconds.

        Returns:
            A RateLimitResult with allowed/remaining/retry_after. On any error
            while talking to Redis, fails open (allowed=True, remaining=limit).
        """
        now = time.time()
        window_start = now - window_seconds
        # Random member: timestamp-based members collide when two requests
        # share a clock tick (coarse clocks, several workers), and ZADD would
        # then overwrite an existing entry instead of adding one.
        member = uuid.uuid4().hex

        try:
            async with self._redis.pipeline(transaction=True) as pipe:
                pipe.zremrangebyscore(key, "-inf", window_start)
                pipe.zadd(key, {member: now})
                pipe.zcard(key)
                pipe.expire(key, window_seconds + _TTL_GRACE_SECONDS)
                results = await pipe.execute()
            current_count: int = results[2]  # includes this request

            if current_count <= limit:
                return RateLimitResult(
                    allowed=True,
                    remaining=max(limit - current_count, 0),
                    retry_after=0,
                )

            # Until the ZREM below runs, this rejected entry is visible to
            # concurrent callers, which can only make them stricter, never
            # let extra requests through.
            return await self._reject(key, member, current_count, limit, window_seconds, now)
        except Exception:
            logger.warning(
                "rate_limit_redis_error key=%s — failing open", key, exc_info=True
            )
            return RateLimitResult(allowed=True, remaining=limit, retry_after=0)

    async def _reject(
        self,
        key: str,
        member: str,
        count: int,
        limit: int,
        window_seconds: int,
        now: float,
    ) -> RateLimitResult:
        """Un-record a rejected request and compute how long the caller must wait."""
        # `count` included this request, so `count - 1` entries remain after the
        # ZREM. For a new request to fit, `count - limit` of them must expire;
        # the last of those sits at ascending-score index `count - 1 - limit`
        # (index 0, i.e. the oldest entry, when the bucket is exactly full).
        blocking_index = count - 1 - limit
        async with self._redis.pipeline(transaction=True) as pipe:
            pipe.zrem(key, member)
            pipe.zrange(key, blocking_index, blocking_index, withscores=True)
            _, blocking = await pipe.execute()

        if blocking:
            blocking_score = blocking[0][1]
            # An entry leaves the window once now >= score + window_seconds.
            retry_after = max(math.ceil(blocking_score + window_seconds - now), 1)
        else:
            # No entry whose expiry would admit a request (e.g. limit <= 0, or
            # entries were pruned concurrently): fall back to a full window.
            retry_after = window_seconds

        return RateLimitResult(allowed=False, remaining=0, retry_after=retry_after)