1818from requests import PreparedRequest , Session
1919from requests_cache import CacheMixin
2020
21- from requests_ratelimiter import LimiterMixin , LimiterSession
21+ from requests_ratelimiter import HostBucketFactory , LimiterMixin , LimiterSession
2222from requests_ratelimiter .requests_ratelimiter import _convert_rate , _get_valid_kwargs
2323from test .conftest import (
2424 MOCKED_URL ,
@@ -48,7 +48,6 @@ def __init__(self, *args, flag: bool = False, **kwargs):
4848 lambda : LimiterSession (per_second = 5 ),
4949 lambda : CustomSession (per_second = 5 , flag = True ),
5050 ],
51- ids = ['LimiterSession' , 'CustomSession' ],
5251)
5352def test_rate_limit_enforcement (mock_sleep , session_factory ):
5453 session = mount_mock_adapter (session_factory ())
@@ -92,6 +91,40 @@ def test_custom_limiter(mock_sleep):
9291 assert mock_sleep .called is True
9392
9493
94+ @patch_sleep
95+ def test_custom_limiter__per_host (mock_sleep ):
96+ factory = HostBucketFactory (rates = [Rate (5 , Duration .SECOND )])
97+ limiter = Limiter (factory )
98+ session = get_mock_session (limiter = limiter , per_host = True )
99+
100+ for _ in range (5 ):
101+ session .get (MOCKED_URL )
102+ assert mock_sleep .called is False
103+
104+ # A different host should not be affected
105+ session .get (MOCKED_URL_ALT_HOST )
106+ assert mock_sleep .called is False
107+
108+ session .get (MOCKED_URL )
109+ assert mock_sleep .called is True
110+
111+
112+ @pytest .mark .parametrize (
113+ 'session_kwargs, expect_warning' ,
114+ [
115+ ({'per_host' : True }, True ),
116+ ({'per_host' : False }, False ),
117+ ],
118+ )
119+ def test_custom_limiter__per_host_warning (caplog , session_kwargs , expect_warning ):
120+ """Warn when a custom limiter without HostBucketFactory is used with per_host=True"""
121+ bucket = InMemoryBucket ([Rate (5 , Duration .SECOND )])
122+ limiter = Limiter (bucket )
123+ with caplog .at_level ('WARNING' , logger = 'requests_ratelimiter' ):
124+ LimiterSession (limiter = limiter , ** session_kwargs )
125+ assert ('HostBucketFactory' in caplog .text ) == expect_warning
126+
127+
95128@patch_sleep
96129@pytest .mark .parametrize (
97130 'url, session_kwargs, expect_sleep' ,
@@ -299,7 +332,7 @@ def test_close_before_any_request_and_idempotent():
299332def test_fill_bucket_with_custom_limiter (mock_sleep ):
300333 bucket = InMemoryBucket ([Rate (5 , Duration .SECOND )])
301334 limiter = Limiter (bucket )
302- session = get_mock_session (limiter = limiter )
335+ session = get_mock_session (limiter = limiter , per_host = False )
303336 session .get (MOCKED_URL_429 )
304337 session .get (MOCKED_URL_429 )
305338 assert mock_sleep .called is True
@@ -335,11 +368,11 @@ class MyBucket(InMemoryBucket):
335368 session .close ()
336369
337370
338- def test_bucket_name_overrides_per_host ():
339- session = LimiterSession (per_second = 5 , bucket_name = 'fixed ' , per_host = True )
371+ def test_bucket_name_prefixes_per_host ():
372+ session = LimiterSession (per_second = 5 , bucket_name = 'myapp ' , per_host = True )
340373 req = PreparedRequest ()
341374 req .url = MOCKED_URL
342- assert session ._bucket_name (req ) == 'fixed '
375+ assert session ._bucket_name (req ) == 'myapp:requests-ratelimiter.com '
343376
344377
345378def test_max_delay_logs_warning (caplog ):
@@ -365,7 +398,6 @@ def test_burst_allows_consecutive_requests(mock_sleep):
365398 (InMemoryBucket , {}), # InMemoryBucket
366399 (SQLiteBucket , None ), # SQLiteBucket (will use fixture to provide kwargs)
367400 ],
368- ids = ['in_memory' , 'sqlite' ],
369401)
370402def test_pickling (mock_sleep , bucket_class , bucket_kwargs , tmp_path ):
371403 if bucket_class == SQLiteBucket :
@@ -434,18 +466,19 @@ def test_no_rate_limits_no_limiter():
434466
435467
436468@pytest .mark .parametrize (
437- 'url, expected_name' ,
469+ 'url, bucket_name, expected_name' ,
438470 [
439- ('http+mock://example.com/path' , 'example.com' ),
440- ('http+mock://example.com:8080/path' , 'example.com:8080' ),
441- ('http+mock://192.168.1.1/path' , '192.168.1.1' ),
442- ('http+mock://[::1]/path' , '[::1]' ),
443- ('http+mock://[::1]:8080/path' , '[::1]:8080' ),
471+ ('http+mock://example.com/path' , None , 'example.com' ),
472+ ('http+mock://example.com:8080/path' , None , 'example.com:8080' ),
473+ ('http+mock://192.168.1.1/path' , None , '192.168.1.1' ),
474+ ('http+mock://[::1]/path' , None , '[::1]' ),
475+ ('http+mock://[::1]:8080/path' , None , '[::1]:8080' ),
476+ ('http+mock://example.com/path' , 'myapp' , 'myapp:example.com' ),
444477 ],
445478)
446- def test_bucket_name_from_url (url , expected_name ):
479+ def test_bucket_name_from_url (url , bucket_name , expected_name ):
447480 """per_host bucket names are derived from URL netloc, including ports and IPs"""
448- session = LimiterSession (per_second = 5 , per_host = True )
481+ session = LimiterSession (per_second = 5 , per_host = True , bucket_name = bucket_name )
449482 req = PreparedRequest ()
450483 req .url = url
451484 assert session ._bucket_name (req ) == expected_name
@@ -454,7 +487,7 @@ def test_bucket_name_from_url(url, expected_name):
454487def test_custom_limiter_close_does_not_stop_factory ():
455488 bucket = InMemoryBucket ([Rate (5 , Duration .SECOND )])
456489 limiter = Limiter (bucket )
457- session = get_mock_session (limiter = limiter )
490+ session = get_mock_session (limiter = limiter , per_host = False )
458491 session .get (MOCKED_URL )
459492 session .close ()
460493
0 commit comments