From 90e50575f55a12193e2b4dafee0a3a4bf01fd78c Mon Sep 17 00:00:00 2001 From: Howie Wang Date: Wed, 3 Apr 2024 12:23:02 -0700 Subject: [PATCH] Disallow default header to be overwritten (#577) Co-authored-by: Hongyi Wang --- pyiceberg/catalog/rest.py | 4 ++-- tests/catalog/test_rest.py | 37 ++++++++++++++++++++++++++++--------- 2 files changed, 30 insertions(+), 11 deletions(-) diff --git a/pyiceberg/catalog/rest.py b/pyiceberg/catalog/rest.py index 9f0d054493..81a9b09f87 100644 --- a/pyiceberg/catalog/rest.py +++ b/pyiceberg/catalog/rest.py @@ -480,12 +480,12 @@ def _refresh_token(self, session: Optional[Session] = None, initial_token: Optio session.headers[AUTHORIZATION_HEADER] = f"{BEARER_PREFIX} {token}" def _config_headers(self, session: Session) -> None: + header_properties = self._extract_headers_from_properties() + session.headers.update(header_properties) session.headers["Content-type"] = "application/json" session.headers["X-Client-Version"] = ICEBERG_REST_SPEC_VERSION session.headers["User-Agent"] = f"PyIceberg/{__version__}" session.headers["X-Iceberg-Access-Delegation"] = "vended-credentials" - header_properties = self._extract_headers_from_properties() - session.headers.update(header_properties) def _extract_headers_from_properties(self) -> Dict[str, str]: return {key[len(HEADER_PREFIX) :]: value for key, value in self.properties.items() if key.startswith(HEADER_PREFIX)} diff --git a/tests/catalog/test_rest.py b/tests/catalog/test_rest.py index 4956fffe6c..15ddb01b25 100644 --- a/tests/catalog/test_rest.py +++ b/tests/catalog/test_rest.py @@ -277,23 +277,35 @@ def test_properties_sets_headers(requests_mock: Mocker) -> None: ) catalog = RestCatalog( - "rest", uri=TEST_URI, warehouse="s3://some-bucket", **{"header.Content-Type": "application/vnd.api+json"} + "rest", + uri=TEST_URI, + warehouse="s3://some-bucket", + **{"header.Content-Type": "application/vnd.api+json", "header.Customized-Header": "some/value"}, ) assert ( - catalog._session.headers.get("Content-type") == "application/vnd.api+json" - ), "Expected 'Content-Type' header to be 'application/vnd.api+json'" - + catalog._session.headers.get("Content-type") == "application/json" + ), "Expected 'Content-Type' default header not to be overwritten" assert ( - requests_mock.last_request.headers["Content-type"] == "application/vnd.api+json" + requests_mock.last_request.headers["Content-type"] == "application/json" ), "Config request did not include expected 'Content-Type' header" + assert ( + catalog._session.headers.get("Customized-Header") == "some/value" + ), "Expected 'Customized-Header' header to be 'some/value'" + assert ( + requests_mock.last_request.headers["Customized-Header"] == "some/value" + ), "Config request did not include expected 'Customized-Header' header" + def test_config_sets_headers(requests_mock: Mocker) -> None: namespace = "leden" requests_mock.get( f"{TEST_URI}v1/config", - json={"defaults": {"header.Content-Type": "application/vnd.api+json"}, "overrides": {}}, + json={ + "defaults": {"header.Content-Type": "application/vnd.api+json", "header.Customized-Header": "some/value"}, + "overrides": {}, + }, status_code=200, ) requests_mock.post(f"{TEST_URI}v1/namespaces", json={"namespace": [namespace], "properties": {}}, status_code=200) @@ -301,12 +313,19 @@ def test_config_sets_headers(requests_mock: Mocker) -> None: catalog.create_namespace(namespace) assert ( - catalog._session.headers.get("Content-type") == "application/vnd.api+json" - ), "Expected 'Content-Type' header to be 'application/vnd.api+json'" + catalog._session.headers.get("Content-type") == "application/json" + ), "Expected 'Content-Type' default header not to be overwritten" assert ( - requests_mock.last_request.headers["Content-type"] == "application/vnd.api+json" + requests_mock.last_request.headers["Content-type"] == "application/json" ), "Create namespace request did not include expected 'Content-Type' header" + assert ( + catalog._session.headers.get("Customized-Header") == "some/value" + ), "Expected 'Customized-Header' header to be 'some/value'" + assert ( + requests_mock.last_request.headers["Customized-Header"] == "some/value" + ), "Create namespace request did not include expected 'Customized-Header' header" + def test_token_400(rest_mock: Mocker) -> None: rest_mock.post(