diff --git a/pyiceberg/catalog/rest/__init__.py b/pyiceberg/catalog/rest/__init__.py index 18d875ea65..f2e1989613 100644 --- a/pyiceberg/catalog/rest/__init__.py +++ b/pyiceberg/catalog/rest/__init__.py @@ -15,21 +15,18 @@ # specific language governing permissions and limitations # under the License. from enum import Enum -from json import JSONDecodeError from typing import ( TYPE_CHECKING, Any, Dict, List, - Literal, Optional, Set, Tuple, - Type, Union, ) -from pydantic import Field, ValidationError, field_validator +from pydantic import Field, field_validator from requests import HTTPError, Session from tenacity import RetryCallState, retry, retry_if_exception_type, stop_after_attempt @@ -41,22 +38,18 @@ Catalog, PropertiesUpdateSummary, ) +from pyiceberg.catalog.rest.auth import AuthManager, AuthManagerAdapter, AuthManagerFactory, LegacyOAuth2AuthManager +from pyiceberg.catalog.rest.response import _handle_non_200_response from pyiceberg.exceptions import ( AuthorizationExpiredError, - BadRequestError, CommitFailedException, CommitStateUnknownException, - ForbiddenError, NamespaceAlreadyExistsError, NamespaceNotEmptyError, NoSuchIdentifierError, NoSuchNamespaceError, NoSuchTableError, NoSuchViewError, - OAuthError, - RESTError, - ServerError, - ServiceUnavailableError, TableAlreadyExistsError, UnauthorizedError, ) @@ -181,15 +174,6 @@ class RegisterTableRequest(IcebergBaseModel): metadata_location: str = Field(..., alias="metadata-location") -class TokenResponse(IcebergBaseModel): - access_token: str = Field() - token_type: str = Field() - expires_in: Optional[int] = Field(default=None) - issued_token_type: Optional[str] = Field(default=None) - refresh_token: Optional[str] = Field(default=None) - scope: Optional[str] = Field(default=None) - - class ConfigResponse(IcebergBaseModel): defaults: Properties = Field() overrides: Properties = Field() @@ -228,24 +212,6 @@ class ListViewsResponse(IcebergBaseModel): identifiers: List[ListViewResponseEntry] = Field() -class ErrorResponseMessage(IcebergBaseModel): - message: str = Field() - type: str = Field() - code: int = Field() - - -class ErrorResponse(IcebergBaseModel): - error: ErrorResponseMessage = Field() - - -class OAuthErrorResponse(IcebergBaseModel): - error: Literal[ - "invalid_request", "invalid_client", "invalid_grant", "unauthorized_client", "unsupported_grant_type", "invalid_scope" - ] - error_description: Optional[str] = None - error_uri: Optional[str] = None - - class RestCatalog(Catalog): uri: str _session: Session @@ -278,8 +244,7 @@ def _create_session(self) -> Session: elif ssl_client_cert := ssl_client.get(CERT): session.cert = ssl_client_cert - self._refresh_token(session, self.properties.get(TOKEN)) - + session.auth = AuthManagerAdapter(self._create_legacy_oauth2_auth_manager(session)) # Set HTTP headers self._config_headers(session) @@ -289,6 +254,26 @@ def _create_session(self) -> Session: return session + def _create_legacy_oauth2_auth_manager(self, session: Session) -> AuthManager: + """Create the LegacyOAuth2AuthManager by fetching required properties. + + This will be removed in PyIceberg 1.0 + """ + client_credentials = self.properties.get(CREDENTIAL) + # We want to call `self.auth_url` only when we are using CREDENTIAL + # with the legacy OAUTH2 flow as it will raise a DeprecationWarning + auth_url = self.auth_url if client_credentials is not None else None + + auth_config = { + "session": session, + "auth_url": auth_url, + "credential": client_credentials, + "initial_token": self.properties.get(TOKEN), + "optional_oauth_params": self._extract_optional_oauth_params(), + } + + return AuthManagerFactory.create("legacyoauth2", auth_config) + def _check_valid_namespace_identifier(self, identifier: Union[str, Identifier]) -> Identifier: """Check if the identifier has at least one element.""" identifier_tuple = Catalog.identifier_to_tuple(identifier) @@ -351,27 +336,6 @@ def _extract_optional_oauth_params(self) -> Dict[str, str]: return optional_oauth_param - def _fetch_access_token(self, session: Session, credential: str) -> str: - if SEMICOLON in credential: - client_id, client_secret = credential.split(SEMICOLON) - else: - client_id, client_secret = None, credential - - data = {GRANT_TYPE: CLIENT_CREDENTIALS, CLIENT_ID: client_id, CLIENT_SECRET: client_secret} - - optional_oauth_params = self._extract_optional_oauth_params() - data.update(optional_oauth_params) - - response = session.post( - url=self.auth_url, data=data, headers={**session.headers, "Content-type": "application/x-www-form-urlencoded"} - ) - try: - response.raise_for_status() - except HTTPError as exc: - self._handle_non_200_response(exc, {400: OAuthError, 401: OAuthError}) - - return TokenResponse.model_validate_json(response.text).access_token - def _fetch_config(self) -> None: params = {} if warehouse_location := self.properties.get(WAREHOUSE_LOCATION): @@ -382,7 +346,7 @@ def _fetch_config(self) -> None: try: response.raise_for_status() except HTTPError as exc: - self._handle_non_200_response(exc, {}) + _handle_non_200_response(exc, {}) config_response = ConfigResponse.model_validate_json(response.text) config = config_response.defaults @@ -412,58 +376,6 @@ def _split_identifier_for_json(self, identifier: Union[str, Identifier]) -> Dict identifier_tuple = self._identifier_to_validated_tuple(identifier) return {"namespace": identifier_tuple[:-1], "name": identifier_tuple[-1]} - def _handle_non_200_response(self, exc: HTTPError, error_handler: Dict[int, Type[Exception]]) -> None: - exception: Type[Exception] - - if exc.response is None: - raise ValueError("Did not receive a response") - - code = exc.response.status_code - if code in error_handler: - exception = error_handler[code] - elif code == 400: - exception = BadRequestError - elif code == 401: - exception = UnauthorizedError - elif code == 403: - exception = ForbiddenError - elif code == 422: - exception = RESTError - elif code == 419: - exception = AuthorizationExpiredError - elif code == 501: - exception = NotImplementedError - elif code == 503: - exception = ServiceUnavailableError - elif 500 <= code < 600: - exception = ServerError - else: - exception = RESTError - - try: - if exception == OAuthError: - # The OAuthErrorResponse has a different format - error = OAuthErrorResponse.model_validate_json(exc.response.text) - response = str(error.error) - if description := error.error_description: - response += f": {description}" - if uri := error.error_uri: - response += f" ({uri})" - else: - error = ErrorResponse.model_validate_json(exc.response.text).error - response = f"{error.type}: {error.message}" - except JSONDecodeError: - # In the case we don't have a proper response - response = f"RESTError {exc.response.status_code}: Could not decode json payload: {exc.response.text}" - except ValidationError as e: - # In the case we don't have a proper response - errs = ", ".join(err["msg"] for err in e.errors()) - response = ( - f"RESTError {exc.response.status_code}: Received unexpected JSON Payload: {exc.response.text}, errors: {errs}" - ) - - raise exception(response) from exc - def _init_sigv4(self, session: Session) -> None: from urllib import parse @@ -533,16 +445,13 @@ def _response_to_staged_table(self, identifier_tuple: Tuple[str, ...], table_res catalog=self, ) - def _refresh_token(self, session: Optional[Session] = None, initial_token: Optional[str] = None) -> None: - session = session or self._session - if initial_token is not None: - self.properties[TOKEN] = initial_token - elif CREDENTIAL in self.properties: - self.properties[TOKEN] = self._fetch_access_token(session, self.properties[CREDENTIAL]) - - # Set Auth token for subsequent calls in the session - if token := self.properties.get(TOKEN): - session.headers[AUTHORIZATION_HEADER] = f"{BEARER_PREFIX} {token}" + def _refresh_token(self) -> None: + # Reactive token refresh is atypical - we should proactively refresh tokens in a separate thread + # instead of retrying on Auth Exceptions. Keeping refresh behavior for the LegacyOAuth2AuthManager + # for backward compatibility + auth_manager = self._session.auth.auth_manager # type: ignore[union-attr] + if isinstance(auth_manager, LegacyOAuth2AuthManager): + auth_manager._refresh_token() def _config_headers(self, session: Session) -> None: header_properties = get_header_properties(self.properties) @@ -587,7 +496,7 @@ def _create_table( try: response.raise_for_status() except HTTPError as exc: - self._handle_non_200_response(exc, {409: TableAlreadyExistsError}) + _handle_non_200_response(exc, {409: TableAlreadyExistsError}) return TableResponse.model_validate_json(response.text) @retry(**_RETRY_ARGS) @@ -660,7 +569,7 @@ def register_table(self, identifier: Union[str, Identifier], metadata_location: try: response.raise_for_status() except HTTPError as exc: - self._handle_non_200_response(exc, {409: TableAlreadyExistsError}) + _handle_non_200_response(exc, {409: TableAlreadyExistsError}) table_response = TableResponse.model_validate_json(response.text) return self._response_to_table(self.identifier_to_tuple(identifier), table_response) @@ -673,7 +582,7 @@ def list_tables(self, namespace: Union[str, Identifier]) -> List[Identifier]: try: response.raise_for_status() except HTTPError as exc: - self._handle_non_200_response(exc, {404: NoSuchNamespaceError}) + _handle_non_200_response(exc, {404: NoSuchNamespaceError}) return [(*table.namespace, table.name) for table in ListTablesResponse.model_validate_json(response.text).identifiers] @retry(**_RETRY_ARGS) @@ -682,7 +591,7 @@ def load_table(self, identifier: Union[str, Identifier]) -> Table: try: response.raise_for_status() except HTTPError as exc: - self._handle_non_200_response(exc, {404: NoSuchTableError}) + _handle_non_200_response(exc, {404: NoSuchTableError}) table_response = TableResponse.model_validate_json(response.text) return self._response_to_table(self.identifier_to_tuple(identifier), table_response) @@ -695,7 +604,7 @@ def drop_table(self, identifier: Union[str, Identifier], purge_requested: bool = try: response.raise_for_status() except HTTPError as exc: - self._handle_non_200_response(exc, {404: NoSuchTableError}) + _handle_non_200_response(exc, {404: NoSuchTableError}) @retry(**_RETRY_ARGS) def purge_table(self, identifier: Union[str, Identifier]) -> None: @@ -711,7 +620,7 @@ def rename_table(self, from_identifier: Union[str, Identifier], to_identifier: U try: response.raise_for_status() except HTTPError as exc: - self._handle_non_200_response(exc, {404: NoSuchTableError, 409: TableAlreadyExistsError}) + _handle_non_200_response(exc, {404: NoSuchTableError, 409: TableAlreadyExistsError}) return self.load_table(to_identifier) @@ -734,7 +643,7 @@ def list_views(self, namespace: Union[str, Identifier]) -> List[Identifier]: try: response.raise_for_status() except HTTPError as exc: - self._handle_non_200_response(exc, {404: NoSuchNamespaceError}) + _handle_non_200_response(exc, {404: NoSuchNamespaceError}) return [(*view.namespace, view.name) for view in ListViewsResponse.model_validate_json(response.text).identifiers] @retry(**_RETRY_ARGS) @@ -772,7 +681,7 @@ def commit_table( try: response.raise_for_status() except HTTPError as exc: - self._handle_non_200_response( + _handle_non_200_response( exc, { 409: CommitFailedException, @@ -791,7 +700,7 @@ def create_namespace(self, namespace: Union[str, Identifier], properties: Proper try: response.raise_for_status() except HTTPError as exc: - self._handle_non_200_response(exc, {409: NamespaceAlreadyExistsError}) + _handle_non_200_response(exc, {409: NamespaceAlreadyExistsError}) @retry(**_RETRY_ARGS) def drop_namespace(self, namespace: Union[str, Identifier]) -> None: @@ -801,7 +710,7 @@ def drop_namespace(self, namespace: Union[str, Identifier]) -> None: try: response.raise_for_status() except HTTPError as exc: - self._handle_non_200_response(exc, {404: NoSuchNamespaceError, 409: NamespaceNotEmptyError}) + _handle_non_200_response(exc, {404: NoSuchNamespaceError, 409: NamespaceNotEmptyError}) @retry(**_RETRY_ARGS) def list_namespaces(self, namespace: Union[str, Identifier] = ()) -> List[Identifier]: @@ -816,7 +725,7 @@ def list_namespaces(self, namespace: Union[str, Identifier] = ()) -> List[Identi try: response.raise_for_status() except HTTPError as exc: - self._handle_non_200_response(exc, {404: NoSuchNamespaceError}) + _handle_non_200_response(exc, {404: NoSuchNamespaceError}) return ListNamespaceResponse.model_validate_json(response.text).namespaces @@ -828,7 +737,7 @@ def load_namespace_properties(self, namespace: Union[str, Identifier]) -> Proper try: response.raise_for_status() except HTTPError as exc: - self._handle_non_200_response(exc, {404: NoSuchNamespaceError}) + _handle_non_200_response(exc, {404: NoSuchNamespaceError}) return NamespaceResponse.model_validate_json(response.text).properties @@ -843,7 +752,7 @@ def update_namespace_properties( try: response.raise_for_status() except HTTPError as exc: - self._handle_non_200_response(exc, {404: NoSuchNamespaceError}) + _handle_non_200_response(exc, {404: NoSuchNamespaceError}) parsed_response = UpdateNamespacePropertiesResponse.model_validate_json(response.text) return PropertiesUpdateSummary( removed=parsed_response.removed, @@ -865,7 +774,7 @@ def namespace_exists(self, namespace: Union[str, Identifier]) -> bool: try: response.raise_for_status() except HTTPError as exc: - self._handle_non_200_response(exc, {}) + _handle_non_200_response(exc, {}) return False @@ -891,7 +800,7 @@ def table_exists(self, identifier: Union[str, Identifier]) -> bool: try: response.raise_for_status() except HTTPError as exc: - self._handle_non_200_response(exc, {}) + _handle_non_200_response(exc, {}) return False @@ -916,7 +825,7 @@ def view_exists(self, identifier: Union[str, Identifier]) -> bool: try: response.raise_for_status() except HTTPError as exc: - self._handle_non_200_response(exc, {}) + _handle_non_200_response(exc, {}) return False @@ -928,4 +837,4 @@ def drop_view(self, identifier: Union[str]) -> None: try: response.raise_for_status() except HTTPError as exc: - self._handle_non_200_response(exc, {404: NoSuchViewError}) + _handle_non_200_response(exc, {404: NoSuchViewError}) diff --git a/pyiceberg/catalog/rest/auth.py b/pyiceberg/catalog/rest/auth.py index 041a8a4cd1..89395f1158 100644 --- a/pyiceberg/catalog/rest/auth.py +++ b/pyiceberg/catalog/rest/auth.py @@ -16,12 +16,18 @@ # under the License. import base64 +import importlib from abc import ABC, abstractmethod -from typing import Optional +from typing import Any, Dict, Optional, Type -from requests import PreparedRequest +from requests import HTTPError, PreparedRequest, Session from requests.auth import AuthBase +from pyiceberg.catalog.rest.response import TokenResponse, _handle_non_200_response +from pyiceberg.exceptions import OAuthError + +COLON = ":" + class AuthManager(ABC): """ @@ -49,6 +55,60 @@ def auth_header(self) -> str: return f"Basic {self._token}" +class LegacyOAuth2AuthManager(AuthManager): + _session: Session + _auth_url: Optional[str] + _token: Optional[str] + _credential: Optional[str] + _optional_oauth_params: Optional[Dict[str, str]] + + def __init__( + self, + session: Session, + auth_url: Optional[str] = None, + credential: Optional[str] = None, + initial_token: Optional[str] = None, + optional_oauth_params: Optional[Dict[str, str]] = None, + ): + self._session = session + self._auth_url = auth_url + self._token = initial_token + self._credential = credential + self._optional_oauth_params = optional_oauth_params + self._refresh_token() + + def _fetch_access_token(self, credential: str) -> str: + if COLON in credential: + client_id, client_secret = credential.split(COLON) + else: + client_id, client_secret = None, credential + + data = {"grant_type": "client_credentials", "client_id": client_id, "client_secret": client_secret} + + if self._optional_oauth_params: + data.update(self._optional_oauth_params) + + if self._auth_url is None: + raise ValueError("Cannot fetch access token from undefined auth_url") + + response = self._session.post( + url=self._auth_url, data=data, headers={**self._session.headers, "Content-type": "application/x-www-form-urlencoded"} + ) + try: + response.raise_for_status() + except HTTPError as exc: + _handle_non_200_response(exc, {400: OAuthError, 401: OAuthError}) + + return TokenResponse.model_validate_json(response.text).access_token + + def _refresh_token(self) -> None: + if self._credential is not None: + self._token = self._fetch_access_token(self._credential) + + def auth_header(self) -> str: + return f"Bearer {self._token}" + + class AuthManagerAdapter(AuthBase): """A `requests.auth.AuthBase` adapter that integrates an `AuthManager` into a `requests.Session` to automatically attach the appropriate Authorization header to every request. @@ -80,3 +140,50 @@ def __call__(self, request: PreparedRequest) -> PreparedRequest: if auth_header := self.auth_manager.auth_header(): request.headers["Authorization"] = auth_header return request + + +class AuthManagerFactory: + _registry: Dict[str, Type["AuthManager"]] = {} + + @classmethod + def register(cls, name: str, auth_manager_class: Type["AuthManager"]) -> None: + """ + Register a string name to a known AuthManager class. + + Args: + name (str): unique name like 'oauth2' to register the AuthManager with + auth_manager_class (Type["AuthManager"]): Implementation of AuthManager + + Returns: + None + """ + cls._registry[name] = auth_manager_class + + @classmethod + def create(cls, class_or_name: str, config: Dict[str, Any]) -> AuthManager: + """ + Create an AuthManager by name or fully-qualified class path. + + Args: + class_or_name (str): Either a name like 'oauth2' or a full class path like 'my.module.CustomAuthManager' + config (Dict[str, Any]): Configuration passed to the AuthManager constructor + + Returns: + AuthManager: An instantiated AuthManager subclass + """ + if class_or_name in cls._registry: + manager_cls = cls._registry[class_or_name] + else: + try: + module_path, class_name = class_or_name.rsplit(".", 1) + module = importlib.import_module(module_path) + manager_cls = getattr(module, class_name) + except Exception as err: + raise ValueError(f"Could not load AuthManager class for '{class_or_name}'") from err + + return manager_cls(**config) + + +AuthManagerFactory.register("noop", NoopAuthManager) +AuthManagerFactory.register("basic", BasicAuthManager) +AuthManagerFactory.register("legacyoauth2", LegacyOAuth2AuthManager) diff --git a/pyiceberg/catalog/rest/response.py b/pyiceberg/catalog/rest/response.py new file mode 100644 index 0000000000..8f23af8c35 --- /dev/null +++ b/pyiceberg/catalog/rest/response.py @@ -0,0 +1,111 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from json import JSONDecodeError +from typing import Dict, Literal, Optional, Type + +from pydantic import Field, ValidationError +from requests import HTTPError + +from pyiceberg.exceptions import ( + AuthorizationExpiredError, + BadRequestError, + ForbiddenError, + OAuthError, + RESTError, + ServerError, + ServiceUnavailableError, + UnauthorizedError, +) +from pyiceberg.typedef import IcebergBaseModel + + +class TokenResponse(IcebergBaseModel): + access_token: str = Field() + token_type: str = Field() + expires_in: Optional[int] = Field(default=None) + issued_token_type: Optional[str] = Field(default=None) + refresh_token: Optional[str] = Field(default=None) + scope: Optional[str] = Field(default=None) + + +class ErrorResponseMessage(IcebergBaseModel): + message: str = Field() + type: str = Field() + code: int = Field() + + +class ErrorResponse(IcebergBaseModel): + error: ErrorResponseMessage = Field() + + +class OAuthErrorResponse(IcebergBaseModel): + error: Literal[ + "invalid_request", "invalid_client", "invalid_grant", "unauthorized_client", "unsupported_grant_type", "invalid_scope" + ] + error_description: Optional[str] = None + error_uri: Optional[str] = None + + +def _handle_non_200_response(exc: HTTPError, error_handler: Dict[int, Type[Exception]]) -> None: + exception: Type[Exception] + + if exc.response is None: + raise ValueError("Did not receive a response") + + code = exc.response.status_code + if code in error_handler: + exception = error_handler[code] + elif code == 400: + exception = BadRequestError + elif code == 401: + exception = UnauthorizedError + elif code == 403: + exception = ForbiddenError + elif code == 422: + exception = RESTError + elif code == 419: + exception = AuthorizationExpiredError + elif code == 501: + exception = NotImplementedError + elif code == 503: + exception = ServiceUnavailableError + elif 500 <= code < 600: + exception = ServerError + else: + exception = RESTError + + try: + if exception == OAuthError: + # The OAuthErrorResponse has a different format + error = OAuthErrorResponse.model_validate_json(exc.response.text) + response = str(error.error) + if description := error.error_description: + response += f": {description}" + if uri := error.error_uri: + response += f" ({uri})" + else: + error = ErrorResponse.model_validate_json(exc.response.text).error + response = f"{error.type}: {error.message}" + except JSONDecodeError: + # In the case we don't have a proper response + response = f"RESTError {exc.response.status_code}: Could not decode json payload: {exc.response.text}" + except ValidationError as e: + # In the case we don't have a proper response + errs = ", ".join(err["msg"] for err in e.errors()) + response = f"RESTError {exc.response.status_code}: Received unexpected JSON Payload: {exc.response.text}, errors: {errs}" + + raise exception(response) from exc diff --git a/tests/catalog/test_rest.py b/tests/catalog/test_rest.py index b9c88d2fc4..9682e6afbf 100644 --- a/tests/catalog/test_rest.py +++ b/tests/catalog/test_rest.py @@ -620,6 +620,10 @@ def test_list_namespaces_token_expired_success_on_retries(rest_mock: Mocker, sta status_code=200, ) catalog = RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN, credential=TEST_CREDENTIALS) + # LegacyOAuth2AuthManager is created twice through `_create_session()` + # which results in the token being refreshed twice when the RestCatalog is initialized. + assert tokens.call_count == 2 + assert catalog.list_namespaces() == [ ("default",), ("examples",), @@ -627,7 +631,7 @@ def test_list_namespaces_token_expired_success_on_retries(rest_mock: Mocker, sta ("system",), ] assert namespaces.call_count == 2 - assert tokens.call_count == 1 + assert tokens.call_count == 3 assert catalog.list_namespaces() == [ ("default",), @@ -636,7 +640,7 @@ def test_list_namespaces_token_expired_success_on_retries(rest_mock: Mocker, sta ("system",), ] assert namespaces.call_count == 3 - assert tokens.call_count == 1 + assert tokens.call_count == 3 def test_create_namespace_200(rest_mock: Mocker) -> None: