Skip to content

Commit

Permalink
Fix tests for Python <= 3.9
Browse files Browse the repository at this point in the history
  • Loading branch information
anupam-saini committed Feb 5, 2024
1 parent a740c5b commit 4daac0d
Showing 1 changed file with 33 additions and 28 deletions.
61 changes: 33 additions & 28 deletions pyiceberg/catalog/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,19 @@ class Endpoints:
NAMESPACE_SEPARATOR = b"\x1F".decode(UTF8)


def _retry_hook(retry_state: RetryCallState) -> None:
rest_catalog: RestCatalog = retry_state.args[0]
rest_catalog._refresh_token() # pylint: disable=protected-access


_RETRY_ARGS = {
"retry": retry_if_exception_type(AuthorizationExpiredError),
"stop": stop_after_attempt(2),
"before": _retry_hook,
"reraise": True,
}


class TableResponse(IcebergBaseModel):
metadata_location: str = Field(alias="metadata-location")
metadata: TableMetadata
Expand Down Expand Up @@ -212,11 +225,6 @@ def __init__(self, name: str, **properties: str):
self._fetch_config()
self._session = self._create_session()

@staticmethod
def _retry_hook(retry_state: RetryCallState) -> None:
rest_catalog: RestCatalog = retry_state.args[0]
rest_catalog._refresh_token() # pylint: disable=protected-access

def _create_session(self) -> Session:
"""Create a request session with provided catalog configuration."""
session = Session()
Expand All @@ -231,13 +239,7 @@ def _create_session(self) -> Session:
elif ssl_client_cert := ssl_client.get(CERT):
session.cert = ssl_client_cert

# If we have credentials, but not a token, we want to fetch a token
if TOKEN not in self.properties and CREDENTIAL in self.properties:
self.properties[TOKEN] = self._fetch_access_token(session, self.properties[CREDENTIAL])

# Set Auth token for subsequent calls in the session
if token := self.properties.get(TOKEN):
session.headers[AUTHORIZATION_HEADER] = f"{BEARER_PREFIX} {token}"
self._refresh_token(session, self.properties.get(TOKEN))

# Set HTTP headers
session.headers["Content-type"] = "application/json"
Expand Down Expand Up @@ -444,16 +446,18 @@ def _response_to_table(self, identifier_tuple: Tuple[str, ...], table_response:
catalog=self,
)

def _refresh_token(self) -> None:
session: Session = self._session
# If we have credentials, fetch a new token
if CREDENTIAL in self.properties:
def _refresh_token(self, session: Optional[Session] = None, new_token: Optional[str] = None) -> None:
session = session or self._session
if new_token is not None:
self.properties[TOKEN] = new_token
elif CREDENTIAL in self.properties:
self.properties[TOKEN] = self._fetch_access_token(session, self.properties[CREDENTIAL])

# Set Auth token for subsequent calls in the session
if token := self.properties.get(TOKEN):
session.headers[AUTHORIZATION_HEADER] = f"{BEARER_PREFIX} {token}"

@retry(retry=retry_if_exception_type(AuthorizationExpiredError), stop=stop_after_attempt(2), before=_retry_hook, reraise=True)
@retry(**_RETRY_ARGS)
def create_table(
self,
identifier: Union[str, Identifier],
Expand Down Expand Up @@ -488,7 +492,7 @@ def create_table(
table_response = TableResponse(**response.json())
return self._response_to_table(self.identifier_to_tuple(identifier), table_response)

@retry(retry=retry_if_exception_type(AuthorizationExpiredError), stop=stop_after_attempt(2), before=_retry_hook, reraise=True)
@retry(**_RETRY_ARGS)
def register_table(self, identifier: Union[str, Identifier], metadata_location: str) -> Table:
"""Register a new table using existing metadata.
Expand Down Expand Up @@ -520,7 +524,7 @@ def register_table(self, identifier: Union[str, Identifier], metadata_location:
table_response = TableResponse(**response.json())
return self._response_to_table(self.identifier_to_tuple(identifier), table_response)

@retry(retry=retry_if_exception_type(AuthorizationExpiredError), stop=stop_after_attempt(2), before=_retry_hook, reraise=True)
@retry(**_RETRY_ARGS)
def list_tables(self, namespace: Union[str, Identifier]) -> List[Identifier]:
namespace_tuple = self._check_valid_namespace_identifier(namespace)
namespace_concat = NAMESPACE_SEPARATOR.join(namespace_tuple)
Expand All @@ -531,7 +535,7 @@ def list_tables(self, namespace: Union[str, Identifier]) -> List[Identifier]:
self._handle_non_200_response(exc, {404: NoSuchNamespaceError})
return [(*table.namespace, table.name) for table in ListTablesResponse(**response.json()).identifiers]

@retry(retry=retry_if_exception_type(AuthorizationExpiredError), stop=stop_after_attempt(2), before=_retry_hook, reraise=True)
@retry(**_RETRY_ARGS)
def load_table(self, identifier: Union[str, Identifier]) -> Table:
identifier_tuple = self.identifier_to_tuple_without_catalog(identifier)
response = self._session.get(
Expand All @@ -545,7 +549,7 @@ def load_table(self, identifier: Union[str, Identifier]) -> Table:
table_response = TableResponse(**response.json())
return self._response_to_table(identifier_tuple, table_response)

@retry(retry=retry_if_exception_type(AuthorizationExpiredError), stop=stop_after_attempt(2), before=_retry_hook, reraise=True)
@retry(**_RETRY_ARGS)
def drop_table(self, identifier: Union[str, Identifier], purge_requested: bool = False) -> None:
identifier_tuple = self.identifier_to_tuple_without_catalog(identifier)
response = self._session.delete(
Expand All @@ -558,11 +562,11 @@ def drop_table(self, identifier: Union[str, Identifier], purge_requested: bool =
except HTTPError as exc:
self._handle_non_200_response(exc, {404: NoSuchTableError})

@retry(retry=retry_if_exception_type(AuthorizationExpiredError), stop=stop_after_attempt(2), before=_retry_hook, reraise=True)
@retry(**_RETRY_ARGS)
def purge_table(self, identifier: Union[str, Identifier]) -> None:
self.drop_table(identifier=identifier, purge_requested=True)

@retry(retry=retry_if_exception_type(AuthorizationExpiredError), stop=stop_after_attempt(2), before=_retry_hook, reraise=True)
@retry(**_RETRY_ARGS)
def rename_table(self, from_identifier: Union[str, Identifier], to_identifier: Union[str, Identifier]) -> Table:
from_identifier_tuple = self.identifier_to_tuple_without_catalog(from_identifier)
payload = {
Expand All @@ -577,6 +581,7 @@ def rename_table(self, from_identifier: Union[str, Identifier], to_identifier: U

return self.load_table(to_identifier)

@retry(**_RETRY_ARGS)
def _commit_table(self, table_request: CommitTableRequest) -> CommitTableResponse:
"""Update the table.
Expand Down Expand Up @@ -607,7 +612,7 @@ def _commit_table(self, table_request: CommitTableRequest) -> CommitTableRespons
)
return CommitTableResponse(**response.json())

@retry(retry=retry_if_exception_type(AuthorizationExpiredError), stop=stop_after_attempt(2), before=_retry_hook, reraise=True)
@retry(**_RETRY_ARGS)
def create_namespace(self, namespace: Union[str, Identifier], properties: Properties = EMPTY_DICT) -> None:
namespace_tuple = self._check_valid_namespace_identifier(namespace)
payload = {"namespace": namespace_tuple, "properties": properties}
Expand All @@ -617,7 +622,7 @@ def create_namespace(self, namespace: Union[str, Identifier], properties: Proper
except HTTPError as exc:
self._handle_non_200_response(exc, {404: NoSuchNamespaceError, 409: NamespaceAlreadyExistsError})

@retry(retry=retry_if_exception_type(AuthorizationExpiredError), stop=stop_after_attempt(2), before=_retry_hook, reraise=True)
@retry(**_RETRY_ARGS)
def drop_namespace(self, namespace: Union[str, Identifier]) -> None:
namespace_tuple = self._check_valid_namespace_identifier(namespace)
namespace = NAMESPACE_SEPARATOR.join(namespace_tuple)
Expand All @@ -627,7 +632,7 @@ def drop_namespace(self, namespace: Union[str, Identifier]) -> None:
except HTTPError as exc:
self._handle_non_200_response(exc, {404: NoSuchNamespaceError})

@retry(retry=retry_if_exception_type(AuthorizationExpiredError), stop=stop_after_attempt(2), before=_retry_hook, reraise=True)
@retry(**_RETRY_ARGS)
def list_namespaces(self, namespace: Union[str, Identifier] = ()) -> List[Identifier]:
namespace_tuple = self.identifier_to_tuple(namespace)
response = self._session.get(
Expand All @@ -645,7 +650,7 @@ def list_namespaces(self, namespace: Union[str, Identifier] = ()) -> List[Identi
namespaces = ListNamespaceResponse(**response.json())
return [namespace_tuple + child_namespace for child_namespace in namespaces.namespaces]

@retry(retry=retry_if_exception_type(AuthorizationExpiredError), stop=stop_after_attempt(2), before=_retry_hook, reraise=True)
@retry(**_RETRY_ARGS)
def load_namespace_properties(self, namespace: Union[str, Identifier]) -> Properties:
namespace_tuple = self._check_valid_namespace_identifier(namespace)
namespace = NAMESPACE_SEPARATOR.join(namespace_tuple)
Expand All @@ -657,7 +662,7 @@ def load_namespace_properties(self, namespace: Union[str, Identifier]) -> Proper

return NamespaceResponse(**response.json()).properties

@retry(retry=retry_if_exception_type(AuthorizationExpiredError), stop=stop_after_attempt(2), before=_retry_hook, reraise=True)
@retry(**_RETRY_ARGS)
def update_namespace_properties(
self, namespace: Union[str, Identifier], removals: Optional[Set[str]] = None, updates: Properties = EMPTY_DICT
) -> PropertiesUpdateSummary:
Expand Down

0 comments on commit 4daac0d

Please sign in to comment.