Skip to content

Commit f542b6f

Browse files
bbtfrliyang
andauthored
support s3 proxy redirect (#426)
Co-authored-by: liyang <[email protected]>
1 parent 29120f3 commit f542b6f

File tree

2 files changed

+81
-30
lines changed

2 files changed

+81
-30
lines changed

megfile/s3_path.py

Lines changed: 53 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@
122122
max_keys = 1000
123123

124124

125-
def _patch_make_request(client: botocore.client.BaseClient):
125+
def _patch_make_request(client: botocore.client.BaseClient, redirect: bool = False):
126126
def after_callback(result: Tuple[AWSResponse, dict], *args, **kwargs):
127127
if (
128128
not isinstance(result, tuple)
@@ -161,6 +161,26 @@ def before_callback(operation_model, request_dict, request_context):
161161
before_callback=before_callback,
162162
retry_callback=retry_callback,
163163
)
164+
165+
def patch_send_request(send_request):
166+
def patched_send_request(request_dict, operation_model):
167+
http, parsed_response = send_request(request_dict, operation_model)
168+
if (
169+
request_dict["method"] == "GET" # only support GET method for now
170+
and http.status_code in (301, 302, 307, 308)
171+
and "Location" in http.headers
172+
):
173+
request_dict["url"] = http.headers["Location"]
174+
http, parsed_response = send_request(request_dict, operation_model)
175+
return http, parsed_response
176+
177+
return patched_send_request
178+
179+
if redirect:
180+
client._endpoint._send_request = patch_send_request(
181+
client._endpoint._send_request
182+
)
183+
164184
return client
165185

166186

@@ -180,7 +200,11 @@ def parse_s3_url(s3_url: PathLike) -> Tuple[str, str]:
180200

181201

182202
def get_scoped_config(profile_name: Optional[str] = None) -> Dict:
183-
return get_s3_session(profile_name=profile_name)._session.get_scoped_config()
203+
try:
204+
session = get_s3_session(profile_name=profile_name)
205+
except botocore.exceptions.ProfileNotFound:
206+
session = get_s3_session()
207+
return session._session.get_scoped_config()
184208

185209

186210
@lru_cache()
@@ -224,25 +248,22 @@ def get_s3_session(profile_name=None) -> boto3.Session:
224248
)
225249

226250

251+
def get_env_var(env_name: str, profile_name=None):
252+
if profile_name:
253+
return os.getenv(f"{profile_name}__{env_name}".upper())
254+
return os.getenv(env_name.upper())
255+
256+
257+
def parse_boolean(value: Optional[str], default: bool = False) -> bool:
258+
if value is None:
259+
return default
260+
return value.lower() in ("true", "yes", "1")
261+
262+
227263
def get_access_token(profile_name=None):
228-
access_key_env_name = (
229-
f"{profile_name}__AWS_ACCESS_KEY_ID".upper()
230-
if profile_name
231-
else "AWS_ACCESS_KEY_ID"
232-
)
233-
secret_key_env_name = (
234-
f"{profile_name}__AWS_SECRET_ACCESS_KEY".upper()
235-
if profile_name
236-
else "AWS_SECRET_ACCESS_KEY"
237-
)
238-
session_token_env_name = (
239-
f"{profile_name}__AWS_SESSION_TOKEN".upper()
240-
if profile_name
241-
else "AWS_SESSION_TOKEN"
242-
)
243-
access_key = os.getenv(access_key_env_name)
244-
secret_key = os.getenv(secret_key_env_name)
245-
session_token = os.getenv(session_token_env_name)
264+
access_key = get_env_var("AWS_ACCESS_KEY_ID", profile_name=profile_name)
265+
secret_key = get_env_var("AWS_SECRET_ACCESS_KEY", profile_name=profile_name)
266+
session_token = get_env_var("AWS_SESSION_TOKEN", profile_name=profile_name)
246267
if access_key and secret_key:
247268
return access_key, secret_key, session_token
248269

@@ -289,10 +310,7 @@ def get_s3_client(
289310
connect_timeout=5, max_pool_connections=GLOBAL_MAX_WORKERS
290311
)
291312

292-
addressing_style_env_key = "AWS_S3_ADDRESSING_STYLE"
293-
if profile_name:
294-
addressing_style_env_key = f"{profile_name}__AWS_S3_ADDRESSING_STYLE".upper()
295-
addressing_style = os.environ.get(addressing_style_env_key)
313+
addressing_style = get_env_var("AWS_S3_ADDRESSING_STYLE", profile_name=profile_name)
296314
if addressing_style:
297315
config = config.merge(
298316
botocore.config.Config(s3={"addressing_style": addressing_style})
@@ -303,15 +321,25 @@ def get_s3_client(
303321
session = get_s3_session(profile_name=profile_name)
304322
except botocore.exceptions.ProfileNotFound:
305323
session = get_s3_session()
324+
325+
s3_config = get_scoped_config(profile_name=profile_name).get("s3", {})
326+
verify = get_env_var("AWS_S3_VERIFY", profile_name=profile_name)
327+
verify = verify or s3_config.get("verify")
328+
verify = parse_boolean(verify, default=True)
329+
redirect = get_env_var("AWS_S3_REDIRECT", profile_name=profile_name)
330+
redirect = redirect or s3_config.get("redirect")
331+
redirect = parse_boolean(redirect, default=False)
332+
306333
client = session.client(
307334
"s3",
308335
endpoint_url=get_endpoint_url(profile_name=profile_name),
336+
verify=verify,
309337
config=config,
310338
aws_access_key_id=access_key,
311339
aws_secret_access_key=secret_key,
312340
aws_session_token=session_token,
313341
)
314-
client = _patch_make_request(client)
342+
client = _patch_make_request(client, redirect=redirect)
315343
return client
316344

317345

tests/test_s3.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -331,18 +331,22 @@ def test_get_s3_client(mocker):
331331
mocker.patch("megfile.s3_path.get_scoped_config", return_value={})
332332
mocker.patch("megfile.s3_path.get_s3_session", return_value=mock_session)
333333

334-
s3.get_s3_client()
334+
client = s3.get_s3_client()
335335
access_key, secret_key, session_token = s3_path.get_access_token()
336336

337337
mock_session.client.assert_called_with(
338338
"s3",
339339
endpoint_url="https://s3.amazonaws.com",
340340
config=Any(),
341+
verify=True,
341342
aws_access_key_id=access_key,
342343
aws_secret_access_key=secret_key,
343344
aws_session_token=session_token,
344345
)
345346

347+
# assert _send_request is not patched
348+
assert "_send_request" not in client._endpoint.__dict__
349+
346350
client = s3.get_s3_client(cache_key="test")
347351
assert client is s3.get_s3_client(cache_key="test")
348352

@@ -354,6 +358,8 @@ def test_get_s3_client(mocker):
354358
"TEST__AWS_S3_ADDRESSING_STYLE": "auto",
355359
"AWS_ACCESS_KEY_ID": "test",
356360
"AWS_SECRET_ACCESS_KEY": "test",
361+
"AWS_S3_VERIFY": "false",
362+
"AWS_S3_REDIRECT": "true",
357363
},
358364
)
359365
def test_get_s3_client_v2():
@@ -377,48 +383,65 @@ def test_get_s3_client_v2():
377383
== "virtual"
378384
)
379385

386+
# assert _send_request is patched
387+
assert "_send_request" in client._endpoint.__dict__
388+
380389

381390
def test_get_s3_client_from_env(mocker):
382391
mock_session = mocker.Mock(spec=boto3.Session)
383392
mocker.patch("megfile.s3_path.get_scoped_config", return_value={})
384393
mocker.patch("megfile.s3_path.get_s3_session", return_value=mock_session)
385-
mocker.patch.dict(os.environ, {"OSS_ENDPOINT": "oss-endpoint"})
394+
mocker.patch.dict(
395+
os.environ,
396+
{"OSS_ENDPOINT": "oss-endpoint", "AWS_S3_VERIFY": "0", "AWS_S3_REDIRECT": "1"},
397+
)
386398

387-
s3.get_s3_client()
399+
client = s3.get_s3_client()
388400
access_key, secret_key, session_token = s3_path.get_access_token()
389401

390402
mock_session.client.assert_called_with(
391403
"s3",
392404
endpoint_url="oss-endpoint",
393405
config=Any(),
406+
verify=False,
394407
aws_access_key_id=access_key,
395408
aws_secret_access_key=secret_key,
396409
aws_session_token=session_token,
397410
)
398411

412+
# assert _send_request is patched
413+
assert "_send_request" in client._endpoint.__dict__
414+
399415

400416
def test_get_s3_client_with_config(mocker):
401417
mock_session = mocker.Mock(spec=boto3.Session)
402-
mocker.patch("megfile.s3_path.get_scoped_config", return_value={})
418+
mocker.patch(
419+
"megfile.s3_path.get_scoped_config",
420+
return_value={"s3": {"verify": "no", "redirect": "yes"}},
421+
)
403422
mocker.patch("megfile.s3_path.get_s3_session", return_value=mock_session)
404423

405424
class EQConfig(botocore.config.Config):
406425
def __eq__(self, other):
407426
return self._user_provided_options == other._user_provided_options
408427

409428
config = EQConfig(max_pool_connections=GLOBAL_MAX_WORKERS, connect_timeout=1)
410-
s3.get_s3_client(config)
429+
client = s3.get_s3_client(config)
411430
access_key, secret_key, session_token = s3_path.get_access_token()
412431

413432
mock_session.client.assert_called_with(
414433
"s3",
415434
endpoint_url="https://s3.amazonaws.com",
416435
config=config,
436+
verify=False,
417437
aws_access_key_id=access_key,
418438
aws_secret_access_key=secret_key,
419439
aws_session_token=session_token,
420440
)
421441

442+
# assert _send_request is patched
443+
assert "_send_request" in client._endpoint.__dict__
444+
422445

423446
def test_get_s3_session_threading(mocker):
424447
session_call = mocker.patch("boto3.Session")

0 commit comments

Comments
 (0)