@@ -153,12 +153,18 @@ class MatrixFederationRequest:
153
153
"""Query arguments.
154
154
"""
155
155
156
- txn_id : Optional [ str ] = None
157
- """Unique ID for this request (for logging)
156
+ txn_id : str = attr . ib ( init = False )
157
+ """Unique ID for this request (for logging), this is autogenerated.
158
158
"""
159
159
160
- uri : bytes = attr .ib (init = False )
161
- """The URI of this request
160
+ uri : bytes = b""
161
+ """The URI of this request, usually generated from the above information.
162
+ """
163
+
164
+ _generate_uri : bool = True
165
+ """True to automatically generate the uri field based on the above information.
166
+
167
+ Set to False if manually configuring the URI.
162
168
"""
163
169
164
170
def __attrs_post_init__ (self ) -> None :
@@ -168,22 +174,23 @@ def __attrs_post_init__(self) -> None:
168
174
169
175
object .__setattr__ (self , "txn_id" , txn_id )
170
176
171
- destination_bytes = self .destination .encode ("ascii" )
172
- path_bytes = self .path .encode ("ascii" )
173
- query_bytes = encode_query_args (self .query )
174
-
175
- # The object is frozen so we can pre-compute this.
176
- uri = urllib .parse .urlunparse (
177
- (
178
- b"matrix-federation" ,
179
- destination_bytes ,
180
- path_bytes ,
181
- None ,
182
- query_bytes ,
183
- b"" ,
177
+ if self ._generate_uri :
178
+ destination_bytes = self .destination .encode ("ascii" )
179
+ path_bytes = self .path .encode ("ascii" )
180
+ query_bytes = encode_query_args (self .query )
181
+
182
+ # The object is frozen so we can pre-compute this.
183
+ uri = urllib .parse .urlunparse (
184
+ (
185
+ b"matrix-federation" ,
186
+ destination_bytes ,
187
+ path_bytes ,
188
+ None ,
189
+ query_bytes ,
190
+ b"" ,
191
+ )
184
192
)
185
- )
186
- object .__setattr__ (self , "uri" , uri )
193
+ object .__setattr__ (self , "uri" , uri )
187
194
188
195
def get_json (self ) -> Optional [JsonDict ]:
189
196
if self .json_callback :
@@ -513,6 +520,7 @@ async def _send_request(
513
520
ignore_backoff : bool = False ,
514
521
backoff_on_404 : bool = False ,
515
522
backoff_on_all_error_codes : bool = False ,
523
+ follow_redirects : bool = False ,
516
524
) -> IResponse :
517
525
"""
518
526
Sends a request to the given server.
@@ -555,6 +563,9 @@ async def _send_request(
555
563
backoff_on_404: Back off if we get a 404
556
564
backoff_on_all_error_codes: Back off if we get any error response
557
565
566
+ follow_redirects: True to follow the Location header of 307/308 redirect
567
+ responses. This does not recurse.
568
+
558
569
Returns:
559
570
Resolves with the HTTP response object on success.
560
571
@@ -714,6 +725,26 @@ async def _send_request(
714
725
response .code ,
715
726
response_phrase ,
716
727
)
728
+ elif (
729
+ response .code in (307 , 308 )
730
+ and follow_redirects
731
+ and response .headers .hasHeader ("Location" )
732
+ ):
733
+ # The Location header *might* be relative so resolve it.
734
+ location = response .headers .getRawHeaders (b"Location" )[0 ]
735
+ new_uri = urllib .parse .urljoin (request .uri , location )
736
+
737
+ return await self ._send_request (
738
+ attr .evolve (request , uri = new_uri , generate_uri = False ),
739
+ retry_on_dns_fail ,
740
+ timeout ,
741
+ long_retries ,
742
+ ignore_backoff ,
743
+ backoff_on_404 ,
744
+ backoff_on_all_error_codes ,
745
+ # Do not continue following redirects.
746
+ follow_redirects = False ,
747
+ )
717
748
else :
718
749
logger .info (
719
750
"{%s} [%s] Got response headers: %d %s" ,
@@ -1383,6 +1414,7 @@ async def get_file(
1383
1414
retry_on_dns_fail : bool = True ,
1384
1415
max_size : Optional [int ] = None ,
1385
1416
ignore_backoff : bool = False ,
1417
+ follow_redirects : bool = False ,
1386
1418
) -> Tuple [int , Dict [bytes , List [bytes ]]]:
1387
1419
"""GETs a file from a given homeserver
1388
1420
Args:
@@ -1392,6 +1424,8 @@ async def get_file(
1392
1424
args: Optional dictionary used to create the query string.
1393
1425
ignore_backoff: true to ignore the historical backoff data
1394
1426
and try the request anyway.
1427
+ follow_redirects: True to follow the Location header of 307/308 redirect
1428
+ responses. This does not recurse.
1395
1429
1396
1430
Returns:
1397
1431
Resolves with an (int,dict) tuple of
@@ -1412,7 +1446,10 @@ async def get_file(
1412
1446
)
1413
1447
1414
1448
response = await self ._send_request (
1415
- request , retry_on_dns_fail = retry_on_dns_fail , ignore_backoff = ignore_backoff
1449
+ request ,
1450
+ retry_on_dns_fail = retry_on_dns_fail ,
1451
+ ignore_backoff = ignore_backoff ,
1452
+ follow_redirects = follow_redirects ,
1416
1453
)
1417
1454
1418
1455
headers = dict (response .headers .getAllRawHeaders ())
0 commit comments