Spaces:
Build error
Build error
import logging
import time
from typing import Dict, List, Optional

from open_webui.env import REDIS_KEY_PREFIX
class RateLimiter:
    """
    General-purpose rate limiter using Redis with a rolling window strategy.
    Falls back to in-memory storage if Redis is not available.

    Events are counted in buckets of ``bucket_size`` seconds. A check sums the
    current bucket plus the ``window // bucket_size`` buckets before it, so the
    window is approximated at bucket granularity. Keys are compared
    case-insensitively on both the Redis and the in-memory path.
    """

    # In-memory fallback storage: normalized key -> {bucket index: event count}.
    # This is a class attribute, so every instance in the process shares it,
    # just as every instance shares the Redis keyspace (bucket keys carry no
    # per-limiter namespace).
    # NOTE(review): two limiters fed the same key therefore share counters.
    _memory_store: Dict[str, Dict[int, int]] = {}

    # Reports Redis failures before falling back to the in-memory store.
    _log = logging.getLogger(__name__)

    def __init__(
        self,
        redis_client,
        limit: int,
        window: int,
        bucket_size: int = 60,
        enabled: bool = True,
    ):
        """
        :param redis_client: Redis client instance or None
        :param limit: Max allowed events in the window
        :param window: Time window in seconds
        :param bucket_size: Bucket resolution in seconds; must be positive
        :param enabled: Turn on/off rate limiting globally
        :raises ValueError: if bucket_size is not positive
        """
        if bucket_size <= 0:
            # Zero would divide by zero below; negatives make bucket math meaningless.
            raise ValueError(f"bucket_size must be positive, got {bucket_size}")
        self.r = redis_client
        self.limit = limit
        self.window = window
        self.bucket_size = bucket_size
        # Number of buckets *before* the current one that still count.
        self.num_buckets = window // bucket_size
        self.enabled = enabled

    @staticmethod
    def _normalize_key(key: str) -> str:
        """Return the canonical (lower-cased) form of *key* shared by both backends."""
        return key.lower()

    def _bucket_key(self, key: str, bucket_index: int) -> str:
        """Build the Redis key holding the count for *key* in bucket *bucket_index*."""
        return f"{REDIS_KEY_PREFIX}:ratelimit:{self._normalize_key(key)}:{bucket_index}"

    def _current_bucket(self) -> int:
        """Index of the bucket containing the current wall-clock second."""
        return int(time.time()) // self.bucket_size

    def _redis_available(self) -> bool:
        """True when a client was supplied (the connection itself is not probed)."""
        return self.r is not None

    def is_limited(self, key: str) -> bool:
        """
        Main rate-limit check.

        Records one event for *key*, then returns True if the window total now
        exceeds ``limit``. When the limiter is disabled, this records nothing
        and returns False. If Redis raises, the error is logged and the
        in-memory store is used instead.

        NOTE(review): if Redis fails after INCR succeeded, the event is
        counted in both backends.
        """
        if not self.enabled:
            return False
        if self._redis_available():
            try:
                return self._is_limited_redis(key)
            except Exception:
                # Best-effort: a Redis outage must not break callers, but it
                # should not be invisible either.
                self._log.warning(
                    "Redis rate-limit check failed; falling back to in-memory storage",
                    exc_info=True,
                )
        return self._is_limited_memory(key)

    def get_count(self, key: str) -> int:
        """
        Return the number of events counted for *key* in the current window.

        Records no new event. Returns 0 when the limiter is disabled. Falls
        back to the in-memory store (with a logged warning) if Redis raises.
        """
        if not self.enabled:
            return 0
        if self._redis_available():
            try:
                return self._get_count_redis(key)
            except Exception:
                self._log.warning(
                    "Redis rate-limit count failed; falling back to in-memory storage",
                    exc_info=True,
                )
        return self._get_count_memory(key)

    def remaining(self, key: str) -> int:
        """Return how many more events *key* may make before exceeding ``limit`` (never negative)."""
        used = self.get_count(key)
        return max(0, self.limit - used)

    def _window_bucket_keys(self, key: str, now_bucket: int) -> List[str]:
        """Redis keys for bucket *now_bucket* and the ``num_buckets`` buckets before it."""
        return [
            self._bucket_key(key, now_bucket - i) for i in range(self.num_buckets + 1)
        ]

    def _sum_redis(self, key: str, now_bucket: int) -> int:
        """Sum the Redis bucket counts in the window ending at *now_bucket*."""
        counts = self.r.mget(self._window_bucket_keys(key, now_bucket))
        # Missing (never written or expired) buckets come back empty and are skipped.
        return sum(int(c) for c in counts if c)

    def _is_limited_redis(self, key: str) -> bool:
        """Increment the current Redis bucket and test the window total against ``limit``."""
        # One snapshot of the bucket index, so the increment and the sum
        # cannot straddle a bucket boundary.
        now_bucket = self._current_bucket()
        bucket_key = self._bucket_key(key, now_bucket)
        attempts = self.r.incr(bucket_key)
        if attempts == 1:
            # First event in this bucket: expire it once it can no longer fall
            # inside any window.
            self.r.expire(bucket_key, self.window + self.bucket_size)
        return self._sum_redis(key, now_bucket) > self.limit

    def _get_count_redis(self, key: str) -> int:
        """Return the Redis window total for *key* without incrementing."""
        return self._sum_redis(key, self._current_bucket())

    def _prune_memory(self, store: Dict[int, int], now_bucket: int) -> None:
        """Delete buckets from *store* that have slid out of the window ending at *now_bucket*."""
        min_bucket = now_bucket - self.num_buckets
        # Materialize first: the dict cannot change size while being iterated.
        for bucket in [b for b in store if b < min_bucket]:
            del store[bucket]

    def _is_limited_memory(self, key: str) -> bool:
        """In-memory counterpart of ``_is_limited_redis``."""
        now_bucket = self._current_bucket()
        store = self._memory_store.setdefault(self._normalize_key(key), {})
        store[now_bucket] = store.get(now_bucket, 0) + 1
        self._prune_memory(store, now_bucket)
        return sum(store.values()) > self.limit

    def _get_count_memory(self, key: str) -> int:
        """In-memory counterpart of ``_get_count_redis``; also evicts fully expired keys."""
        normalized = self._normalize_key(key)
        store = self._memory_store.get(normalized)
        if store is None:
            return 0
        self._prune_memory(store, self._current_bucket())
        if not store:
            # Every bucket expired: forget the key so idle keys don't accumulate.
            self._memory_store.pop(normalized, None)
            return 0
        return sum(store.values())