diff --git a/mkdocs/docs/configuration.md b/mkdocs/docs/configuration.md index 5afd6d805b..93b198c328 100644 --- a/mkdocs/docs/configuration.md +++ b/mkdocs/docs/configuration.md @@ -170,6 +170,8 @@ catalog: | credential | t-1234:secret | Credential to use for OAuth2 credential flow when initializing the catalog | | token | FEW23.DFSDF.FSDF | Bearer token value to use for `Authorization` header | | scope | openid offline corpds:ds:profile | Desired scope of the requested security token (default : catalog) | +| resource | rest_catalog.iceberg.com | URI for the target resource or service | +| audience | rest_catalog | Logical name of target resource or service | | rest.sigv4-enabled | true | Sign requests to the REST Server using AWS SigV4 protocol | | rest.signing-region | us-east-1 | The region to use when SigV4 signing a request | | rest.signing-name | execute-api | The service signing name to use when SigV4 signing a request | diff --git a/pyiceberg/catalog/rest.py b/pyiceberg/catalog/rest.py index e7f0ddd899..79fc37a398 100644 --- a/pyiceberg/catalog/rest.py +++ b/pyiceberg/catalog/rest.py @@ -105,6 +105,8 @@ class Endpoints: CREDENTIAL = "credential" GRANT_TYPE = "grant_type" SCOPE = "scope" +AUDIENCE = "audience" +RESOURCE = "resource" TOKEN_EXCHANGE = "urn:ietf:params:oauth:grant-type:token-exchange" SEMICOLON = ":" KEY = "key" @@ -289,16 +291,26 @@ def auth_url(self) -> str: else: return self.url(Endpoints.get_token, prefixed=False) + def _extract_optional_oauth_params(self) -> Dict[str, str]: + optional_oauth_param = {SCOPE: self.properties.get(SCOPE) or CATALOG_SCOPE} + set_of_optional_params = {AUDIENCE, RESOURCE} + for param in set_of_optional_params: + if param_value := self.properties.get(param): + optional_oauth_param[param] = param_value + + return optional_oauth_param + def _fetch_access_token(self, session: Session, credential: str) -> str: if SEMICOLON in credential: client_id, client_secret = credential.split(SEMICOLON) else: client_id, client_secret = None, credential - # take scope from properties or use default CATALOG_SCOPE - scope = self.properties.get(SCOPE) or CATALOG_SCOPE + data = {GRANT_TYPE: CLIENT_CREDENTIALS, CLIENT_ID: client_id, CLIENT_SECRET: client_secret} + + optional_oauth_params = self._extract_optional_oauth_params() + data.update(optional_oauth_params) - data = {GRANT_TYPE: CLIENT_CREDENTIALS, CLIENT_ID: client_id, CLIENT_SECRET: client_secret, SCOPE: scope} response = session.post( url=self.auth_url, data=data, headers={**session.headers, "Content-type": "application/x-www-form-urlencoded"} ) diff --git a/tests/catalog/test_rest.py b/tests/catalog/test_rest.py index 20fdbfa4ea..51bc286267 100644 --- a/tests/catalog/test_rest.py +++ b/tests/catalog/test_rest.py @@ -47,6 +47,9 @@ TEST_AUTH_URL = "https://auth-endpoint/" TEST_TOKEN = "some_jwt_token" TEST_SCOPE = "openid_offline_corpds_ds_profile" +TEST_AUDIENCE = "test_audience" +TEST_RESOURCE = "test_resource" + TEST_HEADERS = { "Content-type": "application/json", "X-Client-Version": "0.14.1", @@ -137,6 +140,48 @@ def test_token_200_without_optional_fields(rest_mock: Mocker) -> None: ) +def test_token_with_optional_oauth_params(rest_mock: Mocker) -> None: + mock_request = rest_mock.post( + f"{TEST_URI}v1/oauth/tokens", + json={ + "access_token": TEST_TOKEN, + "token_type": "Bearer", + "expires_in": 86400, + "issued_token_type": "urn:ietf:params:oauth:token-type:access_token", + }, + status_code=200, + request_headers=OAUTH_TEST_HEADERS, + ) + assert ( + RestCatalog( + "rest", uri=TEST_URI, credential=TEST_CREDENTIALS, audience=TEST_AUDIENCE, resource=TEST_RESOURCE + )._session.headers["Authorization"] + == f"Bearer {TEST_TOKEN}" + ) + assert TEST_AUDIENCE in mock_request.last_request.text + assert TEST_RESOURCE in mock_request.last_request.text + + +def test_token_with_optional_oauth_params_as_empty(rest_mock: Mocker) -> None: + mock_request = rest_mock.post( + f"{TEST_URI}v1/oauth/tokens", + json={ + "access_token": TEST_TOKEN, + "token_type": "Bearer", + "expires_in": 86400, + "issued_token_type": "urn:ietf:params:oauth:token-type:access_token", + }, + status_code=200, + request_headers=OAUTH_TEST_HEADERS, + ) + assert ( + RestCatalog("rest", uri=TEST_URI, credential=TEST_CREDENTIALS, audience="", resource="")._session.headers["Authorization"] + == f"Bearer {TEST_TOKEN}" + ) + assert TEST_AUDIENCE not in mock_request.last_request.text + assert TEST_RESOURCE not in mock_request.last_request.text + + def test_token_with_default_scope(rest_mock: Mocker) -> None: mock_request = rest_mock.post( f"{TEST_URI}v1/oauth/tokens",