|
| 1 | +import time |
| 2 | + |
| 3 | +from redis import Redis |
| 4 | + |
| 5 | +from .exceptions import ThrottlingMisconfigurationException |
| 6 | +from .redis_throttling_configuration import get_redis_throttling_configuration |
| 7 | +from .redis_throttling_configuration import RedisThrottlingConfiguration |
| 8 | + |
| 9 | + |
| 10 | +class RedisThrottlingClient: |
| 11 | + # Redis Lua scripts are atomic |
| 12 | + # Sliding window throttling. |
| 13 | + # Rejected requests aren't counted. |
| 14 | + THROTTLING_LUA = ''' |
| 15 | + local key = KEYS[1] |
| 16 | + local now = tonumber(ARGV[1]) |
| 17 | + local duration = tonumber(ARGV[2]) |
| 18 | + local max_requests = tonumber(ARGV[3]) |
| 19 | +
|
| 20 | + redis.call("ZREMRANGEBYSCORE", key, 0, now - duration) |
| 21 | + local count = redis.call("ZCARD", key) |
| 22 | +
|
| 23 | + if count >= max_requests then |
| 24 | + return 0 |
| 25 | + end |
| 26 | +
|
| 27 | + redis.call("ZADD", key, now, now) |
| 28 | + redis.call("EXPIRE", key, duration) |
| 29 | + return 1 |
| 30 | + ''' |
| 31 | + |
| 32 | + def __init__(self, configuration: RedisThrottlingConfiguration): |
| 33 | + self._redis_client = Redis( |
| 34 | + host=configuration.host, |
| 35 | + port=configuration.port, |
| 36 | + db=configuration.db, |
| 37 | + password=configuration.password, |
| 38 | + decode_responses=True, |
| 39 | + ) |
| 40 | + self._throttling_script = self._redis_client.register_script(self.THROTTLING_LUA) |
| 41 | + |
| 42 | + def is_request_allowed(self, key: str, duration: int, num_requests: int) -> bool: |
| 43 | + now = time.time() |
| 44 | + is_allowed = self._throttling_script( |
| 45 | + keys=[key], |
| 46 | + args=[now, duration, num_requests] |
| 47 | + ) |
| 48 | + return is_allowed == 1 |
| 49 | + |
| 50 | + def delete(self, key: str): |
| 51 | + self._redis_client.delete(key) |
| 52 | + |
| 53 | + |
| 54 | +_redis_throttling_client: RedisThrottlingClient | None = None |
| 55 | + |
| 56 | +def get_redis_throttling_client() -> RedisThrottlingClient: |
| 57 | + global _redis_throttling_client |
| 58 | + |
| 59 | + if _redis_throttling_client is None: |
| 60 | + configuration = get_redis_throttling_configuration() |
| 61 | + |
| 62 | + if configuration is None: |
| 63 | + raise ThrottlingMisconfigurationException('Configuration for Redis must be set before using the throttling') |
| 64 | + |
| 65 | + _redis_throttling_client = RedisThrottlingClient(configuration) |
| 66 | + |
| 67 | + return _redis_throttling_client |
0 commit comments