Skip to content

Commit 0bef723

Browse files
authored
Merge pull request #148 from JWCook/per-host-buckets
Fix per-host rate-limiting for Redis and Postgres backends
2 parents faacfff + af6a29e commit 0bef723

File tree

9 files changed

+115
-66
lines changed

9 files changed

+115
-66
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ repos:
88
- id: mixed-line-ending
99
- id: trailing-whitespace
1010
- repo: https://github.com/astral-sh/ruff-pre-commit
11-
rev: v0.15.8
11+
rev: v0.15.9
1212
hooks:
1313
- id: ruff
1414
- id: ruff-format

HISTORY.md

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
# History
22

3+
## 0.10.0 (Unreleased)
4+
* Fix per-host rate-limiting for Redis and Postgres backends
5+
* If both `per_host=True` and a `bucket_name` is specified, use `bucket_name` as a bucket prefix
6+
* Add warning if a custom Limiter object is passed with `per_host=True` and no HostBucketFactory
7+
38
## 0.9.3 (2026-04-02)
4-
* Fix compatibility with `RedisBucket`
5-
* Fix compatibility with `PostgresBucket`
9+
* Fix bucket initialization for `RedisBucket` and `PostgresBucket`
610
* Use built-in support for pickling `Limiter` from pyrate-limiter 4.1.0
711

812
## 0.9.2 (2026-02-27)

README.md

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -123,16 +123,20 @@ The following parameters are available for the most common rate limit intervals:
123123
<!-- TODO: Section explaining burst rate limit -->
124124

125125
### Advanced Settings
126-
If you need to define more complex rate limits, you can create a `Limiter` object instead:
126+
If you need to define more complex rate limits, you can use `Limiter` directly.
127+
Note that it must be used with `HostBucketFactory` if you want per-host rate-limiting.
128+
127129
```python
128-
from pyrate_limiter import Duration, RequestRate, Limiter
129-
from requests_ratelimiter import LimiterSession
130+
from pyrate_limiter import Duration, Rate, Limiter
131+
from requests_ratelimiter import LimiterSession, HostBucketFactory
130132

131-
nanocentury_rate = RequestRate(10, Duration.SECOND * 3.156)
132-
fortnight_rate = RequestRate(1000, Duration.DAY * 14)
133-
trimonthly_rate = RequestRate(10000, Duration.MONTH * 3)
134-
limiter = Limiter(nanocentury_rate, fortnight_rate, trimonthly_rate)
133+
nanocentury_rate = Rate(10, Duration.SECOND * 3.156)
134+
fortnight_rate = Rate(1000, Duration.DAY * 14)
135+
trimonthly_rate = Rate(10000, Duration.DAY * 90)
135136

137+
# This factory object is required for per-host rate-limiting
138+
factory = HostBucketFactory(rates=[nanocentury_rate, fortnight_rate, trimonthly_rate])
139+
limiter = Limiter(factory)
136140
session = LimiterSession(limiter=limiter)
137141
```
138142

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = 'requests-ratelimiter'
3-
version = '0.9.3'
3+
version = '0.10.0'
44
description = 'Rate-limiting for the requests library'
55
authors = [{name = 'Jordan Cook'}]
66
license = 'MIT'

requests_ratelimiter/buckets.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import re
12
from typing import Dict, Optional, Type
23

34
from pyrate_limiter import InMemoryBucket, PostgresBucket, Rate, RedisBucket, SQLiteBucket
@@ -33,53 +34,54 @@ def get(self, item: RateItem) -> AbstractBucket:
3334
"""Get or create a bucket for the given item name"""
3435
if item.name not in self.buckets:
3536
# Create new bucket for this name
36-
bucket = self._create_bucket()
37+
bucket = self._create_bucket(item.name)
3738
self.schedule_leak(bucket)
3839
self.buckets[item.name] = bucket
3940

4041
return self.buckets[item.name]
4142

42-
def _create_bucket(self) -> AbstractBucket:
43-
"""Create a new bucket instance with the configured bucket class"""
43+
def _create_bucket(self, name: str) -> AbstractBucket:
44+
"""Create a new bucket instance, and handle per-host naming for each supported backend"""
4445
if self.bucket_class == InMemoryBucket:
4546
return InMemoryBucket(self.rates)
4647
elif self.bucket_class == SQLiteBucket:
47-
kwargs = prepare_sqlite_kwargs(self.bucket_init_kwargs, self.bucket_name)
48+
kwargs = prepare_sqlite_kwargs(self.bucket_init_kwargs, name)
4849
return SQLiteBucket.init_from_file(rates=self.rates, **kwargs)
4950
elif self.bucket_class == RedisBucket:
5051
kwargs = self.bucket_init_kwargs.copy()
51-
bucket_key = kwargs.pop('bucket_key', self.bucket_name or 'default')
5252
redis = kwargs.pop('redis')
53-
return RedisBucket.init(rates=self.rates, redis=redis, bucket_key=bucket_key)
53+
bucket_key = kwargs.pop('bucket_key', _sanitize_name(name))
54+
return RedisBucket.init(rates=self.rates, redis=redis, bucket_key=bucket_key, **kwargs)
5455
elif self.bucket_class == PostgresBucket:
5556
kwargs = self.bucket_init_kwargs.copy()
5657
pool = kwargs.pop('pool')
57-
table = kwargs.pop('table', self.bucket_name or 'default')
58+
table = kwargs.pop('table', _sanitize_name(name))
5859
return PostgresBucket(pool=pool, table=table, rates=self.rates)
5960
else:
60-
# Generic bucket creation - pass rates as first arg
6161
return self.bucket_class(self.rates, **self.bucket_init_kwargs)
6262

6363
def __getitem__(self, name: str) -> AbstractBucket:
6464
"""Dict-like access for backward compatibility with _fill_bucket() method"""
6565
if name not in self.buckets:
66-
# Create bucket on access
6766
temp_item = RateItem(name, 0, 1)
6867
return self.get(temp_item)
6968
return self.buckets[name]
7069

7170

7271
def prepare_sqlite_kwargs(bucket_kwargs: Dict, bucket_name: Optional[str] = None) -> Dict:
73-
"""Prepare SQLiteBucket kwargs for v4 compatibility"""
7472
kwargs = bucket_kwargs.copy()
7573
if 'path' in kwargs:
7674
kwargs['db_path'] = str(kwargs.pop('path'))
7775

7876
# If bucket_name is specified, use it as the table name to ensure separation
7977
# This allows multiple sessions with different bucket_names to share a db file
8078
if bucket_name and 'table' not in kwargs:
81-
kwargs['table'] = f'bucket_{bucket_name}'
79+
kwargs['table'] = f'bucket_{_sanitize_name(bucket_name)}'
8280

8381
# Filter to only supported parameters for SQLiteBucket.init_from_file
8482
supported_params = {'table', 'db_path', 'create_new_table', 'use_file_lock'}
8583
return {k: v for k, v in kwargs.items() if k in supported_params}
84+
85+
86+
def _sanitize_name(name: str) -> str:
87+
return re.sub(r'[^a-zA-Z0-9_]', '_', name)

requests_ratelimiter/requests_ratelimiter.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,11 @@ def __init__(
6666
if limiter:
6767
self.limiter = limiter
6868
self._custom_limiter = True
69+
if per_host and not isinstance(limiter.bucket_factory, HostBucketFactory):
70+
logger.warning(
71+
'Custom limiter does not use HostBucketFactory; per-host rate limiting will '
72+
'not work. Use HostBucketFactory to enable per-host rate limiting.'
73+
)
6974
else:
7075
factory = HostBucketFactory(
7176
rates=rates,
@@ -101,10 +106,11 @@ def send(self, request: PreparedRequest, **kwargs) -> Response:
101106

102107
def _bucket_name(self, request):
103108
"""Get a bucket name for the given request"""
104-
if self.bucket_name:
109+
if self.per_host:
110+
host = urlparse(request.url).netloc
111+
return f'{self.bucket_name}:{host}' if self.bucket_name else host
112+
elif self.bucket_name:
105113
return self.bucket_name
106-
elif self.per_host:
107-
return urlparse(request.url).netloc
108114
else:
109115
return self._default_bucket
110116

@@ -188,8 +194,9 @@ class LimiterSession(LimiterMixin, Session):
188194
:py:class:`~pyrate_limiter.buckets.redis_bucket.RedisBucket`
189195
bucket_kwargs: Bucket backend keyword arguments
190196
limiter: An existing Limiter object to use instead of the above params
191-
per_host: Track request rate limits separately for each host
192197
limit_statuses: Alternative HTTP status codes that indicate a rate limit was exceeded
198+
per_host: Track request rate limits separately for each host
199+
bucket_name: Override default bucket name. In per-host mode, this sets the bucket prefix.
193200
"""
194201

195202
__attrs__ = Session.__attrs__ + [

test/test_buckets.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -52,14 +52,14 @@ def test_separate_buckets_per_name():
5252

5353

5454
@pytest.mark.parametrize(
55-
'extra_init_kwargs, expected_key',
55+
'extra_init_kwargs, identity, expected_key',
5656
[
57-
({'bucket_key': 'test_bucket'}, 'test_bucket'),
58-
({}, 'default'),
57+
({}, 'api.example.com', 'api_example_com'),
58+
({'bucket_key': 'override'}, 'api.example.com', 'override'),
59+
({}, 'myapp:api.example.com', 'myapp_api_example_com'),
5960
],
60-
ids=['explicit_key', 'default_key'],
6161
)
62-
def test_redis_bucket(extra_init_kwargs, expected_key):
62+
def test_redis_bucket(extra_init_kwargs, identity, expected_key):
6363
mock_redis = MagicMock()
6464
mock_redis.script_load.return_value = 'fake_sha1'
6565
factory = HostBucketFactory(
@@ -69,7 +69,7 @@ def test_redis_bucket(extra_init_kwargs, expected_key):
6969
)
7070

7171
with patch.object(RedisBucket, 'init', wraps=RedisBucket.init) as mock_init:
72-
factory._create_bucket()
72+
factory._create_bucket(identity)
7373

7474
mock_init.assert_called_once_with(
7575
rates=factory.rates, redis=mock_redis, bucket_key=expected_key
@@ -82,20 +82,19 @@ def _postgres_init_stub(self, pool, table, rates):
8282

8383

8484
@pytest.mark.parametrize(
85-
'extra_init_kwargs, factory_kwargs, expected_table',
85+
'extra_init_kwargs, identity, expected_table',
8686
[
87-
({'table': 'test_table'}, {}, 'test_table'),
88-
({}, {'bucket_name': 'my_table'}, 'my_table'),
87+
({}, 'api.example.com', 'api_example_com'),
88+
({'table': 'override'}, 'api.example.com', 'override'),
89+
({}, 'myapp:api.example.com', 'myapp_api_example_com'),
8990
],
90-
ids=['explicit_table', 'default_from_bucket_name'],
9191
)
92-
def test_postgres_bucket(extra_init_kwargs, factory_kwargs, expected_table):
92+
def test_postgres_bucket(extra_init_kwargs, identity, expected_table):
9393
mock_pool = MagicMock()
9494
factory = HostBucketFactory(
9595
rates=[Rate(5, 1000)],
9696
bucket_class=PostgresBucket,
9797
bucket_init_kwargs={'pool': mock_pool, **extra_init_kwargs},
98-
**factory_kwargs,
9998
)
10099

101100
with patch.object(
@@ -105,7 +104,7 @@ def test_postgres_bucket(extra_init_kwargs, factory_kwargs, expected_table):
105104
return_value=None,
106105
side_effect=_postgres_init_stub,
107106
) as mock_init:
108-
factory._create_bucket()
107+
factory._create_bucket(identity)
109108

110109
_, kwargs = mock_init.call_args
111110
assert kwargs == {'pool': mock_pool, 'table': expected_table, 'rates': factory.rates}
@@ -119,7 +118,7 @@ class CustomBucket(InMemoryBucket):
119118
rates=[Rate(5, 1000)],
120119
bucket_class=CustomBucket,
121120
)
122-
bucket = factory._create_bucket()
121+
bucket = factory._create_bucket('test_host')
123122
assert isinstance(bucket, CustomBucket)
124123

125124

test/test_requests_ratelimiter.py

Lines changed: 49 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from requests import PreparedRequest, Session
1919
from requests_cache import CacheMixin
2020

21-
from requests_ratelimiter import LimiterMixin, LimiterSession
21+
from requests_ratelimiter import HostBucketFactory, LimiterMixin, LimiterSession
2222
from requests_ratelimiter.requests_ratelimiter import _convert_rate, _get_valid_kwargs
2323
from test.conftest import (
2424
MOCKED_URL,
@@ -48,7 +48,6 @@ def __init__(self, *args, flag: bool = False, **kwargs):
4848
lambda: LimiterSession(per_second=5),
4949
lambda: CustomSession(per_second=5, flag=True),
5050
],
51-
ids=['LimiterSession', 'CustomSession'],
5251
)
5352
def test_rate_limit_enforcement(mock_sleep, session_factory):
5453
session = mount_mock_adapter(session_factory())
@@ -92,6 +91,40 @@ def test_custom_limiter(mock_sleep):
9291
assert mock_sleep.called is True
9392

9493

94+
@patch_sleep
95+
def test_custom_limiter__per_host(mock_sleep):
96+
factory = HostBucketFactory(rates=[Rate(5, Duration.SECOND)])
97+
limiter = Limiter(factory)
98+
session = get_mock_session(limiter=limiter, per_host=True)
99+
100+
for _ in range(5):
101+
session.get(MOCKED_URL)
102+
assert mock_sleep.called is False
103+
104+
# A different host should not be affected
105+
session.get(MOCKED_URL_ALT_HOST)
106+
assert mock_sleep.called is False
107+
108+
session.get(MOCKED_URL)
109+
assert mock_sleep.called is True
110+
111+
112+
@pytest.mark.parametrize(
113+
'session_kwargs, expect_warning',
114+
[
115+
({'per_host': True}, True),
116+
({'per_host': False}, False),
117+
],
118+
)
119+
def test_custom_limiter__per_host_warning(caplog, session_kwargs, expect_warning):
120+
"""Warn when a custom limiter without HostBucketFactory is used with per_host=True"""
121+
bucket = InMemoryBucket([Rate(5, Duration.SECOND)])
122+
limiter = Limiter(bucket)
123+
with caplog.at_level('WARNING', logger='requests_ratelimiter'):
124+
LimiterSession(limiter=limiter, **session_kwargs)
125+
assert ('HostBucketFactory' in caplog.text) == expect_warning
126+
127+
95128
@patch_sleep
96129
@pytest.mark.parametrize(
97130
'url, session_kwargs, expect_sleep',
@@ -299,7 +332,7 @@ def test_close_before_any_request_and_idempotent():
299332
def test_fill_bucket_with_custom_limiter(mock_sleep):
300333
bucket = InMemoryBucket([Rate(5, Duration.SECOND)])
301334
limiter = Limiter(bucket)
302-
session = get_mock_session(limiter=limiter)
335+
session = get_mock_session(limiter=limiter, per_host=False)
303336
session.get(MOCKED_URL_429)
304337
session.get(MOCKED_URL_429)
305338
assert mock_sleep.called is True
@@ -335,11 +368,11 @@ class MyBucket(InMemoryBucket):
335368
session.close()
336369

337370

338-
def test_bucket_name_overrides_per_host():
339-
session = LimiterSession(per_second=5, bucket_name='fixed', per_host=True)
371+
def test_bucket_name_prefixes_per_host():
372+
session = LimiterSession(per_second=5, bucket_name='myapp', per_host=True)
340373
req = PreparedRequest()
341374
req.url = MOCKED_URL
342-
assert session._bucket_name(req) == 'fixed'
375+
assert session._bucket_name(req) == 'myapp:requests-ratelimiter.com'
343376

344377

345378
def test_max_delay_logs_warning(caplog):
@@ -365,7 +398,6 @@ def test_burst_allows_consecutive_requests(mock_sleep):
365398
(InMemoryBucket, {}), # InMemoryBucket
366399
(SQLiteBucket, None), # SQLiteBucket (will use fixture to provide kwargs)
367400
],
368-
ids=['in_memory', 'sqlite'],
369401
)
370402
def test_pickling(mock_sleep, bucket_class, bucket_kwargs, tmp_path):
371403
if bucket_class == SQLiteBucket:
@@ -434,18 +466,19 @@ def test_no_rate_limits_no_limiter():
434466

435467

436468
@pytest.mark.parametrize(
437-
'url, expected_name',
469+
'url, bucket_name, expected_name',
438470
[
439-
('http+mock://example.com/path', 'example.com'),
440-
('http+mock://example.com:8080/path', 'example.com:8080'),
441-
('http+mock://192.168.1.1/path', '192.168.1.1'),
442-
('http+mock://[::1]/path', '[::1]'),
443-
('http+mock://[::1]:8080/path', '[::1]:8080'),
471+
('http+mock://example.com/path', None, 'example.com'),
472+
('http+mock://example.com:8080/path', None, 'example.com:8080'),
473+
('http+mock://192.168.1.1/path', None, '192.168.1.1'),
474+
('http+mock://[::1]/path', None, '[::1]'),
475+
('http+mock://[::1]:8080/path', None, '[::1]:8080'),
476+
('http+mock://example.com/path', 'myapp', 'myapp:example.com'),
444477
],
445478
)
446-
def test_bucket_name_from_url(url, expected_name):
479+
def test_bucket_name_from_url(url, bucket_name, expected_name):
447480
"""per_host bucket names are derived from URL netloc, including ports and IPs"""
448-
session = LimiterSession(per_second=5, per_host=True)
481+
session = LimiterSession(per_second=5, per_host=True, bucket_name=bucket_name)
449482
req = PreparedRequest()
450483
req.url = url
451484
assert session._bucket_name(req) == expected_name
@@ -454,7 +487,7 @@ def test_bucket_name_from_url(url, expected_name):
454487
def test_custom_limiter_close_does_not_stop_factory():
455488
bucket = InMemoryBucket([Rate(5, Duration.SECOND)])
456489
limiter = Limiter(bucket)
457-
session = get_mock_session(limiter=limiter)
490+
session = get_mock_session(limiter=limiter, per_host=False)
458491
session.get(MOCKED_URL)
459492
session.close()
460493

0 commit comments

Comments
 (0)