Skip to content

Commit f132dfa

Browse files
bbtfr and liyang authored
fix: redirect fix auth headers issue (#428)
Co-authored-by: liyang <[email protected]>
1 parent b237cc5 commit f132dfa

File tree

6 files changed

+57
-24
lines changed

6 files changed

+57
-24
lines changed

megfile/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,5 @@
6060
SFTP_MAX_RETRY_TIMES = int(
6161
os.getenv("MEGFILE_SFTP_MAX_RETRY_TIMES") or DEFAULT_MAX_RETRY_TIMES
6262
)
63+
64+
# Names of HTTP headers that carry credentials. The S3 client's redirect
# handling removes these from a request before re-sending it to a redirect
# target whose host is neither the original host nor one of its subdomains.
HTTP_AUTH_HEADERS = ("Authorization", "Www-Authenticate", "Cookie", "Cookie2")

megfile/lib/s3_prefetch_reader.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,11 @@ def _get_content_size(self):
7676
try:
7777
start, end = 0, self._block_size - 1
7878
first_index_response = self._fetch_response(start=start, end=end)
79-
content_size = int(first_index_response["ContentRange"].split("/")[-1])
79+
if "ContentRange" in first_index_response:
80+
content_size = int(first_index_response["ContentRange"].split("/")[-1])
81+
else:
82+
# usually happens when reading a file that has only one block
83+
content_size = int(first_index_response["ContentLength"])
8084
except S3InvalidRangeError:
8185
# usually happens when reading an empty file
8286
# can use minio test empty file: https://hub.docker.com/r/minio/minio

megfile/s3_path.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,18 @@
77
from functools import cached_property, lru_cache, wraps
88
from logging import getLogger as get_logger
99
from typing import IO, Any, BinaryIO, Callable, Dict, Iterator, List, Optional, Tuple
10+
from urllib.parse import urlparse
1011

1112
import boto3
1213
import botocore
13-
from botocore.awsrequest import AWSResponse
14+
from botocore.awsrequest import AWSPreparedRequest, AWSResponse
1415

1516
from megfile.config import (
1617
DEFAULT_BLOCK_SIZE,
1718
DEFAULT_MAX_BLOCK_SIZE,
1819
DEFAULT_MIN_BLOCK_SIZE,
1920
GLOBAL_MAX_WORKERS,
21+
HTTP_AUTH_HEADERS,
2022
S3_CLIENT_CACHE_MODE,
2123
S3_MAX_RETRY_TIMES,
2224
)
@@ -76,6 +78,7 @@
7678
generate_cache_path,
7779
get_binary_mode,
7880
get_content_offset,
81+
is_domain_or_subdomain,
7982
is_readable,
8083
necessary_params,
8184
process_local,
@@ -162,24 +165,30 @@ def before_callback(operation_model, request_dict, request_context):
162165
retry_callback=retry_callback,
163166
)
164167

165-
def patch_send_request(send_request):
166-
def patched_send_request(request_dict, operation_model):
167-
http, parsed_response = send_request(request_dict, operation_model)
168+
def patch_send(send):
169+
def patched_send(request: AWSPreparedRequest) -> AWSResponse:
170+
response: AWSResponse = send(request)
168171
if (
169-
request_dict["method"] == "GET" # only support GET method for now
170-
and http.status_code in (301, 302, 307, 308)
171-
and "Location" in http.headers
172+
request.method == "GET" # only support GET method for now
173+
and response.status_code in (301, 302, 307, 308)
174+
and "Location" in response.headers
172175
):
173-
request_dict["url"] = http.headers["Location"]
174-
http, parsed_response = send_request(request_dict, operation_model)
175-
return http, parsed_response
176-
177-
return patched_send_request
176+
# Permit sending auth/cookie headers from "foo.com" to "sub.foo.com".
177+
# See also: https://go.dev/src/net/http/client.go#L980
178+
location = response.headers["Location"]
179+
ihost = urlparse(request.url).hostname
180+
dhost = urlparse(location).hostname
181+
if not is_domain_or_subdomain(dhost, ihost):
182+
for name in HTTP_AUTH_HEADERS:
183+
request.headers.pop(name, None)
184+
request.url = location
185+
response = send(request)
186+
return response
187+
188+
return patched_send
178189

179190
if redirect:
180-
client._endpoint._send_request = patch_send_request(
181-
client._endpoint._send_request
182-
)
191+
client._endpoint._send = patch_send(client._endpoint._send)
183192

184193
return client
185194

megfile/utils/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,3 +346,11 @@ def __get__( # pyre-ignore[14]
346346
val = self.func(cls)
347347
setattr(cls, self.attrname, val)
348348
return val
349+
350+
351+
def is_domain_or_subdomain(sub, parent) -> bool:
    """Return True if host ``sub`` is ``parent`` itself or a subdomain of it.

    Used to decide whether credential headers may follow a redirect, so any
    input that cannot be positively matched yields ``False``.

    :param sub: Candidate host name, e.g. ``"sub.foo.com"``. May be ``None``
        (``urlparse(url).hostname`` is ``None`` for a URL without a host).
    :param parent: Host name that ``sub`` is compared against, e.g. ``"foo.com"``.
    :returns: ``True`` when both values are equal, or when ``sub`` ends with
        ``"." + parent``; ``False`` otherwise.
    """
    if sub == parent:
        return True
    # A missing host (None) or an empty parent can never be a real
    # domain/subdomain pair; without this guard ``None.endswith`` raises and
    # an empty parent would match any host that merely ends with ".".
    if not isinstance(sub, str) or not isinstance(parent, str) or not parent:
        return False
    # A ":" or "%" means ``sub`` is an IPv6 literal (possibly with a zone id),
    # not a host name; skip the suffix test so zone contents cannot match.
    if ":" in sub or "%" in sub:
        return False
    # The leading dot ensures "atest.com" is not treated as a subdomain of
    # "test.com".
    return sub.endswith(f".{parent}")

tests/test_s3.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -344,8 +344,8 @@ def test_get_s3_client(mocker):
344344
aws_session_token=session_token,
345345
)
346346

347-
# assert _send_request is not patched
348-
assert "_send_request" not in client._endpoint.__dict__
347+
# assert _send is not patched
348+
assert "_send" not in client._endpoint.__dict__
349349

350350
client = s3.get_s3_client(cache_key="test")
351351
assert client is s3.get_s3_client(cache_key="test")
@@ -383,8 +383,8 @@ def test_get_s3_client_v2():
383383
== "virtual"
384384
)
385385

386-
# assert _send_request is patched
387-
assert "_send_request" in client._endpoint.__dict__
386+
# assert _send is patched
387+
assert "_send" in client._endpoint.__dict__
388388

389389

390390
def test_get_s3_client_from_env(mocker):
@@ -409,8 +409,8 @@ def test_get_s3_client_from_env(mocker):
409409
aws_session_token=session_token,
410410
)
411411

412-
# assert _send_request is patched
413-
assert "_send_request" in client._endpoint.__dict__
412+
# assert _send is patched
413+
assert "_send" in client._endpoint.__dict__
414414

415415

416416
def test_get_s3_client_with_config(mocker):
@@ -439,8 +439,8 @@ def __eq__(self, other):
439439
aws_session_token=session_token,
440440
)
441441

442-
# assert _send_request is patched
443-
assert "_send_request" in client._endpoint.__dict__
442+
# assert _send is patched
443+
assert "_send" in client._endpoint.__dict__
444444

445445

446446
def test_get_s3_session_threading(mocker):

tests/utils/test_init.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
cached_classproperty,
1010
combine,
1111
get_human_size,
12+
is_domain_or_subdomain,
1213
necessary_params,
1314
patch_rlimit,
1415
)
@@ -108,3 +109,12 @@ def test__is_pickle():
108109
fileObj.name = "test"
109110
fileObj.mode = "wb"
110111
assert _is_pickle(fileObj) is False
112+
113+
114+
def test_is_domain_or_subdomain():
    """Check is_domain_or_subdomain against (sub, parent, expected) triples."""
    cases = (
        # Unrelated or sibling hosts are rejected.
        ("test1.com", "test2.com", False),
        ("test1.test.com", "test2.test.com", False),
        # The same host and a direct subdomain are accepted.
        ("test.com", "test.com", True),
        ("test1.test.com", "test.com", True),
        # The relation is directional: a parent is not a subdomain of its child.
        ("test.com", "test1.test.com", False),
    )
    for sub, parent, expected in cases:
        assert is_domain_or_subdomain(sub, parent) is expected

0 commit comments

Comments
 (0)