Skip to content

Commit 4b7c793

Browse files
committed
feat: expose response headers via thread-safe contextvars
Implement a thread-safe and async-safe way to access response headers using Python's contextvars module. Changes: - Add ContextVar to store response headers isolated per execution context (thread/async task) - Add last_response_headers property to RestClient for accessing rate-limit and other response metadata - Update EmptyResponse to preserve headers from 204 responses - Comprehensive tests covering thread isolation, successive requests, and various HTTP methods Benefits: - Zero breaking changes - fully backwards compatible - Minimal code footprint - only rest.py modified - Thread-safe and async-safe via contextvars (Python 3.7+) - No per-class boilerplate (eliminates need for 32 duplicate properties) Usage: users = Users(domain="tenant.auth0.com", token="token") users.create({"email": "[email protected]"}) headers = users.client.last_response_headers remaining = int(headers.get("X-RateLimit-Remaining", 0))
1 parent 1233acd commit 4b7c793

File tree

2 files changed

+272
-4
lines changed

2 files changed

+272
-4
lines changed

auth0/rest.py

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import base64
44
import platform
55
import sys
6+
from contextvars import ContextVar
67
from json import dumps, loads
78
from random import randint
89
from time import sleep
@@ -19,6 +20,12 @@
1920

2021
UNKNOWN_ERROR = "a0.sdk.internal.unknown"
2122

23+
# Context variable to store response headers in a thread-safe and async-safe manner
24+
# Each execution context (thread or async task) gets its own isolated copy
25+
_response_headers: ContextVar[dict[str, str]] = ContextVar(
26+
"response_headers", default={}
27+
)
28+
2229

2330
class RestClientOptions:
2431
"""Configuration object for RestClient. Used for configuring
@@ -85,6 +92,9 @@ def __init__(
8592
self._metrics = {"retries": 0, "backoff": []}
8693
self._skip_sleep = False
8794

95+
# Initialize context variable for this client instance
96+
_response_headers.set({})
97+
8898
self.base_headers = {
8999
"Content-Type": "application/json",
90100
}
@@ -121,6 +131,26 @@ def __init__(
121131
self.telemetry = options.telemetry
122132
self.timeout = options.timeout
123133

134+
@property
135+
def last_response_headers(self) -> dict[str, str]:
136+
"""Get the headers from the most recent API response.
137+
138+
This property is thread-safe and async-safe, using context variables
139+
to isolate response headers per execution context (thread or async task).
140+
141+
Returns:
142+
dict[str, str]: Response headers including rate-limit information
143+
(X-RateLimit-Limit, X-RateLimit-Remaining, X-RateLimit-Reset).
144+
Returns an empty dict if no request has been made yet.
145+
146+
Example:
147+
>>> users = Users(domain="tenant.auth0.com", token="token")
148+
>>> users.create({"email": "[email protected]"})
149+
>>> headers = users.client.last_response_headers
150+
>>> remaining = int(headers.get("X-RateLimit-Remaining", 0))
151+
"""
152+
return _response_headers.get()
153+
124154
# Returns a hard cap for the maximum number of retries allowed (10)
125155
def MAX_REQUEST_RETRIES(self) -> int:
126156
return 10
@@ -262,11 +292,15 @@ def _calculate_wait(self, attempt: int) -> int:
262292
return wait
263293

264294
def _process_response(self, response: requests.Response) -> Any:
265-
return self._parse(response).content()
295+
parsed_response = self._parse(response)
296+
content = parsed_response.content()
297+
# Store headers in context variable for thread-safe/async-safe access
298+
_response_headers.set(dict(parsed_response._headers))
299+
return content
266300

267301
def _parse(self, response: requests.Response) -> Response:
268302
if not response.text:
269-
return EmptyResponse(response.status_code)
303+
return EmptyResponse(response.status_code, response.headers)
270304
try:
271305
return JsonResponse(response)
272306
except ValueError:
@@ -356,8 +390,8 @@ def _error_message(self) -> str:
356390

357391

358392
class EmptyResponse(Response):
359-
def __init__(self, status_code: int) -> None:
360-
super().__init__(status_code, "", {})
393+
def __init__(self, status_code: int, headers: Mapping[str, str] | None = None) -> None:
394+
super().__init__(status_code, "", headers or {})
361395

362396
def _error_code(self) -> str:
363397
return UNKNOWN_ERROR
Lines changed: 234 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,234 @@
1+
"""Tests for context-var based response headers in RestClient.
2+
3+
Tests verify that headers are properly isolated across threads and async contexts.
4+
"""
5+
6+
import threading
7+
import time
8+
import unittest
9+
from unittest import mock
10+
11+
import responses
12+
13+
from auth0.rest import RestClient, RestClientOptions
14+
15+
16+
class TestRestClientHeadersContextVar(unittest.TestCase):
17+
"""Test that response headers are properly stored and accessed via contextvars."""
18+
19+
@responses.activate
20+
def test_headers_accessible_after_request(self):
21+
"""Test that headers are stored and accessible after a successful request."""
22+
responses.add(
23+
responses.GET,
24+
"https://example.com/api/test",
25+
json={"result": "ok"},
26+
status=200,
27+
headers={
28+
"X-RateLimit-Limit": "60",
29+
"X-RateLimit-Remaining": "59",
30+
"X-RateLimit-Reset": "1640000000",
31+
},
32+
)
33+
34+
client = RestClient(jwt="test-token")
35+
result = client.get("https://example.com/api/test")
36+
37+
self.assertEqual(result, {"result": "ok"})
38+
headers = client.last_response_headers
39+
self.assertEqual(headers.get("X-RateLimit-Limit"), "60")
40+
self.assertEqual(headers.get("X-RateLimit-Remaining"), "59")
41+
self.assertEqual(headers.get("X-RateLimit-Reset"), "1640000000")
42+
43+
@responses.activate
44+
def test_headers_on_204_response(self):
45+
"""Test that headers are preserved on 204 No Content responses."""
46+
responses.add(
47+
responses.DELETE,
48+
"https://example.com/api/resource/123",
49+
status=204,
50+
headers={
51+
"X-RateLimit-Limit": "30",
52+
"X-RateLimit-Remaining": "25",
53+
"X-RateLimit-Reset": "1640000100",
54+
},
55+
)
56+
57+
client = RestClient(jwt="test-token")
58+
result = client.delete("https://example.com/api/resource/123")
59+
60+
# 204 returns empty content
61+
self.assertEqual(result, "")
62+
# But headers should still be accessible
63+
headers = client.last_response_headers
64+
self.assertEqual(headers.get("X-RateLimit-Limit"), "30")
65+
self.assertEqual(headers.get("X-RateLimit-Remaining"), "25")
66+
67+
@responses.activate
68+
def test_headers_updated_on_successive_requests(self):
69+
"""Test that headers are updated with each new request."""
70+
# First request
71+
responses.add(
72+
responses.GET,
73+
"https://example.com/api/test1",
74+
json={"id": 1},
75+
status=200,
76+
headers={"X-RateLimit-Remaining": "59"},
77+
)
78+
79+
# Second request
80+
responses.add(
81+
responses.GET,
82+
"https://example.com/api/test2",
83+
json={"id": 2},
84+
status=200,
85+
headers={"X-RateLimit-Remaining": "58"},
86+
)
87+
88+
client = RestClient(jwt="test-token")
89+
90+
# First request
91+
client.get("https://example.com/api/test1")
92+
self.assertEqual(client.last_response_headers.get("X-RateLimit-Remaining"), "59")
93+
94+
# Second request should update headers
95+
client.get("https://example.com/api/test2")
96+
self.assertEqual(client.last_response_headers.get("X-RateLimit-Remaining"), "58")
97+
98+
@responses.activate
99+
def test_headers_empty_initially(self):
100+
"""Test that headers are empty before any request is made."""
101+
client = RestClient(jwt="test-token")
102+
headers = client.last_response_headers
103+
self.assertEqual(headers, {})
104+
105+
@responses.activate
106+
def test_thread_isolation(self):
107+
"""Test that response headers are isolated between threads.
108+
109+
This is the key thread-safety test: each thread should have its own
110+
response headers when using contextvars.
111+
"""
112+
results = {}
113+
errors = []
114+
115+
# Setup responses with different headers for each endpoint
116+
responses.add(
117+
responses.GET,
118+
"https://example.com/api/thread1",
119+
json={"thread": 1},
120+
status=200,
121+
headers={"X-RateLimit-Remaining": "100"},
122+
)
123+
124+
responses.add(
125+
responses.GET,
126+
"https://example.com/api/thread2",
127+
json={"thread": 2},
128+
status=200,
129+
headers={"X-RateLimit-Remaining": "200"},
130+
)
131+
132+
def thread_worker(thread_id: int, endpoint: str, remaining: str):
133+
"""Worker function for thread test."""
134+
try:
135+
client = RestClient(jwt="test-token")
136+
client.get(f"https://example.com/api/{endpoint}")
137+
# Each thread should see its own headers, not contaminated by other threads
138+
results[thread_id] = client.last_response_headers.get(
139+
"X-RateLimit-Remaining"
140+
)
141+
except Exception as e:
142+
errors.append(str(e))
143+
144+
# Start two threads that make requests simultaneously
145+
thread1 = threading.Thread(target=thread_worker, args=(1, "thread1", "100"))
146+
thread2 = threading.Thread(target=thread_worker, args=(2, "thread2", "200"))
147+
148+
thread1.start()
149+
thread2.start()
150+
151+
thread1.join()
152+
thread2.join()
153+
154+
# Verify no errors occurred
155+
self.assertEqual(errors, [], f"Errors in threads: {errors}")
156+
157+
# Verify each thread got the correct headers for its request
158+
self.assertEqual(results[1], "100", "Thread 1 should see its own headers")
159+
self.assertEqual(results[2], "200", "Thread 2 should see its own headers")
160+
161+
@responses.activate
162+
def test_headers_in_same_context_reflect_latest_request(self):
163+
"""Test that in the same execution context, headers reflect the latest request.
164+
165+
Contextvars are context-specific (thread or async task), not client-specific.
166+
When multiple clients make requests in the same context, the contextvar reflects
167+
the most recent response. For isolation per client, use different threads.
168+
"""
169+
responses.add(
170+
responses.GET,
171+
"https://example.com/api/request1",
172+
json={"request": 1},
173+
status=200,
174+
headers={"X-Request-ID": "request1"},
175+
)
176+
177+
responses.add(
178+
responses.GET,
179+
"https://example.com/api/request2",
180+
json={"request": 2},
181+
status=200,
182+
headers={"X-Request-ID": "request2"},
183+
)
184+
185+
client1 = RestClient(jwt="token1")
186+
client2 = RestClient(jwt="token2")
187+
188+
# First request
189+
client1.get("https://example.com/api/request1")
190+
self.assertEqual(client1.last_response_headers.get("X-Request-ID"), "request1")
191+
192+
# Second request in same context overwrites the contextvar
193+
client2.get("https://example.com/api/request2")
194+
# Both clients see the latest headers because they're in the same context
195+
self.assertEqual(client1.last_response_headers.get("X-Request-ID"), "request2")
196+
self.assertEqual(client2.last_response_headers.get("X-Request-ID"), "request2")
197+
198+
@responses.activate
199+
def test_post_request_headers(self):
200+
"""Test that headers are captured on POST requests."""
201+
responses.add(
202+
responses.POST,
203+
"https://example.com/api/create",
204+
json={"id": "new-id"},
205+
status=201,
206+
headers={"X-RateLimit-Remaining": "55"},
207+
)
208+
209+
client = RestClient(jwt="test-token")
210+
result = client.post("https://example.com/api/create", data={"name": "test"})
211+
212+
self.assertEqual(result["id"], "new-id")
213+
self.assertEqual(client.last_response_headers.get("X-RateLimit-Remaining"), "55")
214+
215+
@responses.activate
216+
def test_patch_request_headers(self):
217+
"""Test that headers are captured on PATCH requests."""
218+
responses.add(
219+
responses.PATCH,
220+
"https://example.com/api/update/123",
221+
json={"id": "123", "updated": True},
222+
status=200,
223+
headers={"X-RateLimit-Remaining": "54"},
224+
)
225+
226+
client = RestClient(jwt="test-token")
227+
result = client.patch("https://example.com/api/update/123", data={"name": "updated"})
228+
229+
self.assertEqual(result["updated"], True)
230+
self.assertEqual(client.last_response_headers.get("X-RateLimit-Remaining"), "54")
231+
232+
233+
if __name__ == "__main__":
234+
unittest.main()

0 commit comments

Comments
 (0)