|
7 | 7 | from functools import cached_property, lru_cache, wraps |
8 | 8 | from logging import getLogger as get_logger |
9 | 9 | from typing import IO, Any, BinaryIO, Callable, Dict, Iterator, List, Optional, Tuple |
| 10 | +from urllib.parse import urlparse |
10 | 11 |
|
11 | 12 | import boto3 |
12 | 13 | import botocore |
13 | | -from botocore.awsrequest import AWSResponse |
| 14 | +from botocore.awsrequest import AWSPreparedRequest, AWSResponse |
14 | 15 |
|
15 | 16 | from megfile.config import ( |
16 | 17 | DEFAULT_BLOCK_SIZE, |
17 | 18 | DEFAULT_MAX_BLOCK_SIZE, |
18 | 19 | DEFAULT_MIN_BLOCK_SIZE, |
19 | 20 | GLOBAL_MAX_WORKERS, |
| 21 | + HTTP_AUTH_HEADERS, |
20 | 22 | S3_CLIENT_CACHE_MODE, |
21 | 23 | S3_MAX_RETRY_TIMES, |
22 | 24 | ) |
|
76 | 78 | generate_cache_path, |
77 | 79 | get_binary_mode, |
78 | 80 | get_content_offset, |
| 81 | + is_domain_or_subdomain, |
79 | 82 | is_readable, |
80 | 83 | necessary_params, |
81 | 84 | process_local, |
@@ -162,24 +165,30 @@ def before_callback(operation_model, request_dict, request_context): |
162 | 165 | retry_callback=retry_callback, |
163 | 166 | ) |
164 | 167 |
|
165 | | - def patch_send_request(send_request): |
166 | | - def patched_send_request(request_dict, operation_model): |
167 | | - http, parsed_response = send_request(request_dict, operation_model) |
| 168 | + def patch_send(send): |
| 169 | + def patched_send(request: AWSPreparedRequest) -> AWSResponse: |
| 170 | + response: AWSResponse = send(request) |
168 | 171 | if ( |
169 | | - request_dict["method"] == "GET" # only support GET method for now |
170 | | - and http.status_code in (301, 302, 307, 308) |
171 | | - and "Location" in http.headers |
| 172 | + request.method == "GET" # only support GET method for now |
| 173 | + and response.status_code in (301, 302, 307, 308) |
| 174 | + and "Location" in response.headers |
172 | 175 | ): |
173 | | - request_dict["url"] = http.headers["Location"] |
174 | | - http, parsed_response = send_request(request_dict, operation_model) |
175 | | - return http, parsed_response |
176 | | - |
177 | | - return patched_send_request |
| 176 | + # Permit sending auth/cookie headers from "foo.com" to "sub.foo.com". |
| 177 | + # See also: https://go.dev/src/net/http/client.go#L980 |
| 178 | + location = response.headers["Location"] |
| 179 | + ihost = urlparse(request.url).hostname |
| 180 | + dhost = urlparse(location).hostname |
| 181 | + if not is_domain_or_subdomain(dhost, ihost): |
| 182 | + for name in HTTP_AUTH_HEADERS: |
| 183 | + request.headers.pop(name, None) |
| 184 | + request.url = location |
| 185 | + response = send(request) |
| 186 | + return response |
| 187 | + |
| 188 | + return patched_send |
178 | 189 |
|
179 | 190 | if redirect: |
180 | | - client._endpoint._send_request = patch_send_request( |
181 | | - client._endpoint._send_request |
182 | | - ) |
| 191 | + client._endpoint._send = patch_send(client._endpoint._send) |
183 | 192 |
|
184 | 193 | return client |
185 | 194 |
|
|
0 commit comments