Skip to content

Replace Deprecated (Current) OAuth2 Handling with AuthManager Implementation LegacyOAuth2AuthManager #1981

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
189 changes: 49 additions & 140 deletions pyiceberg/catalog/rest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,18 @@
# specific language governing permissions and limitations
# under the License.
from enum import Enum
from json import JSONDecodeError
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Literal,
Optional,
Set,
Tuple,
Type,
Union,
)

from pydantic import Field, ValidationError, field_validator
from pydantic import Field, field_validator
from requests import HTTPError, Session
from tenacity import RetryCallState, retry, retry_if_exception_type, stop_after_attempt

Expand All @@ -41,22 +38,18 @@
Catalog,
PropertiesUpdateSummary,
)
from pyiceberg.catalog.rest.auth import AuthManager, AuthManagerAdapter, AuthManagerFactory, LegacyOAuth2AuthManager
from pyiceberg.catalog.rest.response import _handle_non_200_response
from pyiceberg.exceptions import (
AuthorizationExpiredError,
BadRequestError,
CommitFailedException,
CommitStateUnknownException,
ForbiddenError,
NamespaceAlreadyExistsError,
NamespaceNotEmptyError,
NoSuchIdentifierError,
NoSuchNamespaceError,
NoSuchTableError,
NoSuchViewError,
OAuthError,
RESTError,
ServerError,
ServiceUnavailableError,
TableAlreadyExistsError,
UnauthorizedError,
)
Expand Down Expand Up @@ -181,15 +174,6 @@ class RegisterTableRequest(IcebergBaseModel):
metadata_location: str = Field(..., alias="metadata-location")


class TokenResponse(IcebergBaseModel):
access_token: str = Field()
token_type: str = Field()
expires_in: Optional[int] = Field(default=None)
issued_token_type: Optional[str] = Field(default=None)
refresh_token: Optional[str] = Field(default=None)
scope: Optional[str] = Field(default=None)


class ConfigResponse(IcebergBaseModel):
defaults: Properties = Field()
overrides: Properties = Field()
Expand Down Expand Up @@ -228,24 +212,6 @@ class ListViewsResponse(IcebergBaseModel):
identifiers: List[ListViewResponseEntry] = Field()


class ErrorResponseMessage(IcebergBaseModel):
message: str = Field()
type: str = Field()
code: int = Field()


class ErrorResponse(IcebergBaseModel):
error: ErrorResponseMessage = Field()


class OAuthErrorResponse(IcebergBaseModel):
error: Literal[
"invalid_request", "invalid_client", "invalid_grant", "unauthorized_client", "unsupported_grant_type", "invalid_scope"
]
error_description: Optional[str] = None
error_uri: Optional[str] = None


class RestCatalog(Catalog):
uri: str
_session: Session
Expand Down Expand Up @@ -278,8 +244,7 @@ def _create_session(self) -> Session:
elif ssl_client_cert := ssl_client.get(CERT):
session.cert = ssl_client_cert

self._refresh_token(session, self.properties.get(TOKEN))

session.auth = AuthManagerAdapter(self._create_legacy_oauth2_auth_manager(session))
# Set HTTP headers
self._config_headers(session)

Expand All @@ -289,6 +254,26 @@ def _create_session(self) -> Session:

return session

def _create_legacy_oauth2_auth_manager(self, session: Session) -> AuthManager:
"""Create the LegacyOAuth2AuthManager by fetching required properties.

This will be removed in PyIceberg 1.0
"""
client_credentials = self.properties.get(CREDENTIAL)
# We want to call `self.auth_url` only when we are using CREDENTIAL
# with the legacy OAUTH2 flow as it will raise a DeprecationWarning
auth_url = self.auth_url if client_credentials is not None else None

auth_config = {
"session": session,
"auth_url": auth_url,
"credential": client_credentials,
"initial_token": self.properties.get(TOKEN),
"optional_oauth_params": self._extract_optional_oauth_params(),
}

return AuthManagerFactory.create("legacyoauth2", auth_config)

def _check_valid_namespace_identifier(self, identifier: Union[str, Identifier]) -> Identifier:
"""Check if the identifier has at least one element."""
identifier_tuple = Catalog.identifier_to_tuple(identifier)
Expand Down Expand Up @@ -351,27 +336,6 @@ def _extract_optional_oauth_params(self) -> Dict[str, str]:

return optional_oauth_param

def _fetch_access_token(self, session: Session, credential: str) -> str:
if SEMICOLON in credential:
client_id, client_secret = credential.split(SEMICOLON)
else:
client_id, client_secret = None, credential

data = {GRANT_TYPE: CLIENT_CREDENTIALS, CLIENT_ID: client_id, CLIENT_SECRET: client_secret}

optional_oauth_params = self._extract_optional_oauth_params()
data.update(optional_oauth_params)

response = session.post(
url=self.auth_url, data=data, headers={**session.headers, "Content-type": "application/x-www-form-urlencoded"}
)
try:
response.raise_for_status()
except HTTPError as exc:
self._handle_non_200_response(exc, {400: OAuthError, 401: OAuthError})

return TokenResponse.model_validate_json(response.text).access_token

def _fetch_config(self) -> None:
params = {}
if warehouse_location := self.properties.get(WAREHOUSE_LOCATION):
Expand All @@ -382,7 +346,7 @@ def _fetch_config(self) -> None:
try:
response.raise_for_status()
except HTTPError as exc:
self._handle_non_200_response(exc, {})
_handle_non_200_response(exc, {})
config_response = ConfigResponse.model_validate_json(response.text)

config = config_response.defaults
Expand Down Expand Up @@ -412,58 +376,6 @@ def _split_identifier_for_json(self, identifier: Union[str, Identifier]) -> Dict
identifier_tuple = self._identifier_to_validated_tuple(identifier)
return {"namespace": identifier_tuple[:-1], "name": identifier_tuple[-1]}

def _handle_non_200_response(self, exc: HTTPError, error_handler: Dict[int, Type[Exception]]) -> None:
exception: Type[Exception]

if exc.response is None:
raise ValueError("Did not receive a response")

code = exc.response.status_code
if code in error_handler:
exception = error_handler[code]
elif code == 400:
exception = BadRequestError
elif code == 401:
exception = UnauthorizedError
elif code == 403:
exception = ForbiddenError
elif code == 422:
exception = RESTError
elif code == 419:
exception = AuthorizationExpiredError
elif code == 501:
exception = NotImplementedError
elif code == 503:
exception = ServiceUnavailableError
elif 500 <= code < 600:
exception = ServerError
else:
exception = RESTError

try:
if exception == OAuthError:
# The OAuthErrorResponse has a different format
error = OAuthErrorResponse.model_validate_json(exc.response.text)
response = str(error.error)
if description := error.error_description:
response += f": {description}"
if uri := error.error_uri:
response += f" ({uri})"
else:
error = ErrorResponse.model_validate_json(exc.response.text).error
response = f"{error.type}: {error.message}"
except JSONDecodeError:
# In the case we don't have a proper response
response = f"RESTError {exc.response.status_code}: Could not decode json payload: {exc.response.text}"
except ValidationError as e:
# In the case we don't have a proper response
errs = ", ".join(err["msg"] for err in e.errors())
response = (
f"RESTError {exc.response.status_code}: Received unexpected JSON Payload: {exc.response.text}, errors: {errs}"
)

raise exception(response) from exc

def _init_sigv4(self, session: Session) -> None:
from urllib import parse

Expand Down Expand Up @@ -533,16 +445,13 @@ def _response_to_staged_table(self, identifier_tuple: Tuple[str, ...], table_res
catalog=self,
)

def _refresh_token(self, session: Optional[Session] = None, initial_token: Optional[str] = None) -> None:
session = session or self._session
if initial_token is not None:
self.properties[TOKEN] = initial_token
elif CREDENTIAL in self.properties:
self.properties[TOKEN] = self._fetch_access_token(session, self.properties[CREDENTIAL])

# Set Auth token for subsequent calls in the session
if token := self.properties.get(TOKEN):
session.headers[AUTHORIZATION_HEADER] = f"{BEARER_PREFIX} {token}"
def _refresh_token(self) -> None:
# Reactive token refresh is atypical - we should proactively refresh tokens in a separate thread
# instead of retrying on Auth Exceptions. Keeping refresh behavior for the LegacyOAuth2AuthManager
# for backward compatibility
auth_manager = self._session.auth.auth_manager # type: ignore[union-attr]
if isinstance(auth_manager, LegacyOAuth2AuthManager):
auth_manager._refresh_token()

def _config_headers(self, session: Session) -> None:
header_properties = get_header_properties(self.properties)
Expand Down Expand Up @@ -587,7 +496,7 @@ def _create_table(
try:
response.raise_for_status()
except HTTPError as exc:
self._handle_non_200_response(exc, {409: TableAlreadyExistsError})
_handle_non_200_response(exc, {409: TableAlreadyExistsError})
return TableResponse.model_validate_json(response.text)

@retry(**_RETRY_ARGS)
Expand Down Expand Up @@ -660,7 +569,7 @@ def register_table(self, identifier: Union[str, Identifier], metadata_location:
try:
response.raise_for_status()
except HTTPError as exc:
self._handle_non_200_response(exc, {409: TableAlreadyExistsError})
_handle_non_200_response(exc, {409: TableAlreadyExistsError})

table_response = TableResponse.model_validate_json(response.text)
return self._response_to_table(self.identifier_to_tuple(identifier), table_response)
Expand All @@ -673,7 +582,7 @@ def list_tables(self, namespace: Union[str, Identifier]) -> List[Identifier]:
try:
response.raise_for_status()
except HTTPError as exc:
self._handle_non_200_response(exc, {404: NoSuchNamespaceError})
_handle_non_200_response(exc, {404: NoSuchNamespaceError})
return [(*table.namespace, table.name) for table in ListTablesResponse.model_validate_json(response.text).identifiers]

@retry(**_RETRY_ARGS)
Expand All @@ -682,7 +591,7 @@ def load_table(self, identifier: Union[str, Identifier]) -> Table:
try:
response.raise_for_status()
except HTTPError as exc:
self._handle_non_200_response(exc, {404: NoSuchTableError})
_handle_non_200_response(exc, {404: NoSuchTableError})

table_response = TableResponse.model_validate_json(response.text)
return self._response_to_table(self.identifier_to_tuple(identifier), table_response)
Expand All @@ -695,7 +604,7 @@ def drop_table(self, identifier: Union[str, Identifier], purge_requested: bool =
try:
response.raise_for_status()
except HTTPError as exc:
self._handle_non_200_response(exc, {404: NoSuchTableError})
_handle_non_200_response(exc, {404: NoSuchTableError})

@retry(**_RETRY_ARGS)
def purge_table(self, identifier: Union[str, Identifier]) -> None:
Expand All @@ -711,7 +620,7 @@ def rename_table(self, from_identifier: Union[str, Identifier], to_identifier: U
try:
response.raise_for_status()
except HTTPError as exc:
self._handle_non_200_response(exc, {404: NoSuchTableError, 409: TableAlreadyExistsError})
_handle_non_200_response(exc, {404: NoSuchTableError, 409: TableAlreadyExistsError})

return self.load_table(to_identifier)

Expand All @@ -734,7 +643,7 @@ def list_views(self, namespace: Union[str, Identifier]) -> List[Identifier]:
try:
response.raise_for_status()
except HTTPError as exc:
self._handle_non_200_response(exc, {404: NoSuchNamespaceError})
_handle_non_200_response(exc, {404: NoSuchNamespaceError})
return [(*view.namespace, view.name) for view in ListViewsResponse.model_validate_json(response.text).identifiers]

@retry(**_RETRY_ARGS)
Expand Down Expand Up @@ -772,7 +681,7 @@ def commit_table(
try:
response.raise_for_status()
except HTTPError as exc:
self._handle_non_200_response(
_handle_non_200_response(
exc,
{
409: CommitFailedException,
Expand All @@ -791,7 +700,7 @@ def create_namespace(self, namespace: Union[str, Identifier], properties: Proper
try:
response.raise_for_status()
except HTTPError as exc:
self._handle_non_200_response(exc, {409: NamespaceAlreadyExistsError})
_handle_non_200_response(exc, {409: NamespaceAlreadyExistsError})

@retry(**_RETRY_ARGS)
def drop_namespace(self, namespace: Union[str, Identifier]) -> None:
Expand All @@ -801,7 +710,7 @@ def drop_namespace(self, namespace: Union[str, Identifier]) -> None:
try:
response.raise_for_status()
except HTTPError as exc:
self._handle_non_200_response(exc, {404: NoSuchNamespaceError, 409: NamespaceNotEmptyError})
_handle_non_200_response(exc, {404: NoSuchNamespaceError, 409: NamespaceNotEmptyError})

@retry(**_RETRY_ARGS)
def list_namespaces(self, namespace: Union[str, Identifier] = ()) -> List[Identifier]:
Expand All @@ -816,7 +725,7 @@ def list_namespaces(self, namespace: Union[str, Identifier] = ()) -> List[Identi
try:
response.raise_for_status()
except HTTPError as exc:
self._handle_non_200_response(exc, {404: NoSuchNamespaceError})
_handle_non_200_response(exc, {404: NoSuchNamespaceError})

return ListNamespaceResponse.model_validate_json(response.text).namespaces

Expand All @@ -828,7 +737,7 @@ def load_namespace_properties(self, namespace: Union[str, Identifier]) -> Proper
try:
response.raise_for_status()
except HTTPError as exc:
self._handle_non_200_response(exc, {404: NoSuchNamespaceError})
_handle_non_200_response(exc, {404: NoSuchNamespaceError})

return NamespaceResponse.model_validate_json(response.text).properties

Expand All @@ -843,7 +752,7 @@ def update_namespace_properties(
try:
response.raise_for_status()
except HTTPError as exc:
self._handle_non_200_response(exc, {404: NoSuchNamespaceError})
_handle_non_200_response(exc, {404: NoSuchNamespaceError})
parsed_response = UpdateNamespacePropertiesResponse.model_validate_json(response.text)
return PropertiesUpdateSummary(
removed=parsed_response.removed,
Expand All @@ -865,7 +774,7 @@ def namespace_exists(self, namespace: Union[str, Identifier]) -> bool:
try:
response.raise_for_status()
except HTTPError as exc:
self._handle_non_200_response(exc, {})
_handle_non_200_response(exc, {})

return False

Expand All @@ -891,7 +800,7 @@ def table_exists(self, identifier: Union[str, Identifier]) -> bool:
try:
response.raise_for_status()
except HTTPError as exc:
self._handle_non_200_response(exc, {})
_handle_non_200_response(exc, {})

return False

Expand All @@ -916,7 +825,7 @@ def view_exists(self, identifier: Union[str, Identifier]) -> bool:
try:
response.raise_for_status()
except HTTPError as exc:
self._handle_non_200_response(exc, {})
_handle_non_200_response(exc, {})

return False

Expand All @@ -928,4 +837,4 @@ def drop_view(self, identifier: Union[str]) -> None:
try:
response.raise_for_status()
except HTTPError as exc:
self._handle_non_200_response(exc, {404: NoSuchViewError})
_handle_non_200_response(exc, {404: NoSuchViewError})
Loading