diff --git a/pyiceberg/catalog/rest.py b/pyiceberg/catalog/rest.py
index df255017e0..6a75328cae 100644
--- a/pyiceberg/catalog/rest.py
+++ b/pyiceberg/catalog/rest.py
@@ -119,6 +119,23 @@ class Endpoints:
 NAMESPACE_SEPARATOR = b"\x1F".decode(UTF8)
 
 
+def _retry_hook(retry_state: RetryCallState) -> None:
+    """Refresh the auth token of the catalog whose call is about to be retried."""
+    rest_catalog: RestCatalog = retry_state.args[0]
+    rest_catalog._refresh_token()  # pylint: disable=protected-access
+
+
+# Shared tenacity settings: retry once when the token has expired. The hook is wired to
+# `before_sleep`, which runs only between a failed attempt and its retry; `before` would
+# also run ahead of the first attempt and refresh the token on every single catalog call.
+_RETRY_ARGS = {
+    "retry": retry_if_exception_type(AuthorizationExpiredError),
+    "stop": stop_after_attempt(2),
+    "before_sleep": _retry_hook,
+    "reraise": True,
+}
+
+
 class TableResponse(IcebergBaseModel):
     metadata_location: str = Field(alias="metadata-location")
     metadata: TableMetadata
@@ -212,11 +229,6 @@ def __init__(self, name: str, **properties: str):
         self._fetch_config()
         self._session = self._create_session()
 
-    @staticmethod
-    def _retry_hook(retry_state: RetryCallState) -> None:
-        rest_catalog: RestCatalog = retry_state.args[0]
-        rest_catalog._refresh_token()  # pylint: disable=protected-access
-
     def _create_session(self) -> Session:
         """Create a request session with provided catalog configuration."""
         session = Session()
@@ -231,13 +243,7 @@ def _create_session(self) -> Session:
                 elif ssl_client_cert := ssl_client.get(CERT):
                     session.cert = ssl_client_cert
 
-        # If we have credentials, but not a token, we want to fetch a token
-        if TOKEN not in self.properties and CREDENTIAL in self.properties:
-            self.properties[TOKEN] = self._fetch_access_token(session, self.properties[CREDENTIAL])
-
-        # Set Auth token for subsequent calls in the session
-        if token := self.properties.get(TOKEN):
-            session.headers[AUTHORIZATION_HEADER] = f"{BEARER_PREFIX} {token}"
+        self._refresh_token(session, self.properties.get(TOKEN))
 
         # Set HTTP headers
         session.headers["Content-type"] = "application/json"
@@ -444,16 +450,25 @@ def _response_to_table(self, identifier_tuple: Tuple[str, ...], table_response:
             catalog=self,
         )
 
-    def _refresh_token(self) -> None:
-        session: Session = self._session
-        # If we have credentials, fetch a new token
-        if CREDENTIAL in self.properties:
+    def _refresh_token(self, session: Optional[Session] = None, new_token: Optional[str] = None) -> None:
+        """Store the current token and install it as the session's bearer header.
+
+        Args:
+            session: Session to update; falls back to ``self._session`` when omitted.
+            new_token: Token to store as-is. When ``None`` and a credential is
+                configured, a new token is fetched with that credential instead.
+        """
+        session = session or self._session
+        if new_token is not None:
+            self.properties[TOKEN] = new_token
+        elif CREDENTIAL in self.properties:
             self.properties[TOKEN] = self._fetch_access_token(session, self.properties[CREDENTIAL])
+
         # Set Auth token for subsequent calls in the session
         if token := self.properties.get(TOKEN):
             session.headers[AUTHORIZATION_HEADER] = f"{BEARER_PREFIX} {token}"
 
-    @retry(retry=retry_if_exception_type(AuthorizationExpiredError), stop=stop_after_attempt(2), before=_retry_hook, reraise=True)
+    @retry(**_RETRY_ARGS)
     def create_table(
         self,
         identifier: Union[str, Identifier],
@@ -488,7 +503,7 @@ def create_table(
         table_response = TableResponse(**response.json())
         return self._response_to_table(self.identifier_to_tuple(identifier), table_response)
 
-    @retry(retry=retry_if_exception_type(AuthorizationExpiredError), stop=stop_after_attempt(2), before=_retry_hook, reraise=True)
+    @retry(**_RETRY_ARGS)
     def register_table(self, identifier: Union[str, Identifier], metadata_location: str) -> Table:
         """Register a new table using existing metadata.
 
@@ -520,7 +535,7 @@ def register_table(self, identifier: Union[str, Identifier], metadata_location:
         table_response = TableResponse(**response.json())
         return self._response_to_table(self.identifier_to_tuple(identifier), table_response)
 
-    @retry(retry=retry_if_exception_type(AuthorizationExpiredError), stop=stop_after_attempt(2), before=_retry_hook, reraise=True)
+    @retry(**_RETRY_ARGS)
     def list_tables(self, namespace: Union[str, Identifier]) -> List[Identifier]:
         namespace_tuple = self._check_valid_namespace_identifier(namespace)
         namespace_concat = NAMESPACE_SEPARATOR.join(namespace_tuple)
@@ -531,7 +546,7 @@ def list_tables(self, namespace: Union[str, Identifier]) -> List[Identifier]:
             self._handle_non_200_response(exc, {404: NoSuchNamespaceError})
         return [(*table.namespace, table.name) for table in ListTablesResponse(**response.json()).identifiers]
 
-    @retry(retry=retry_if_exception_type(AuthorizationExpiredError), stop=stop_after_attempt(2), before=_retry_hook, reraise=True)
+    @retry(**_RETRY_ARGS)
     def load_table(self, identifier: Union[str, Identifier]) -> Table:
         identifier_tuple = self.identifier_to_tuple_without_catalog(identifier)
         response = self._session.get(
@@ -545,7 +560,7 @@ def load_table(self, identifier: Union[str, Identifier]) -> Table:
         table_response = TableResponse(**response.json())
         return self._response_to_table(identifier_tuple, table_response)
 
-    @retry(retry=retry_if_exception_type(AuthorizationExpiredError), stop=stop_after_attempt(2), before=_retry_hook, reraise=True)
+    @retry(**_RETRY_ARGS)
     def drop_table(self, identifier: Union[str, Identifier], purge_requested: bool = False) -> None:
         identifier_tuple = self.identifier_to_tuple_without_catalog(identifier)
         response = self._session.delete(
@@ -558,11 +573,11 @@ def drop_table(self, identifier: Union[str, Identifier], purge_requested: bool =
         except HTTPError as exc:
             self._handle_non_200_response(exc, {404: NoSuchTableError})
 
-    @retry(retry=retry_if_exception_type(AuthorizationExpiredError), stop=stop_after_attempt(2), before=_retry_hook, reraise=True)
+    @retry(**_RETRY_ARGS)
     def purge_table(self, identifier: Union[str, Identifier]) -> None:
         self.drop_table(identifier=identifier, purge_requested=True)
 
-    @retry(retry=retry_if_exception_type(AuthorizationExpiredError), stop=stop_after_attempt(2), before=_retry_hook, reraise=True)
+    @retry(**_RETRY_ARGS)
     def rename_table(self, from_identifier: Union[str, Identifier], to_identifier: Union[str, Identifier]) -> Table:
         from_identifier_tuple = self.identifier_to_tuple_without_catalog(from_identifier)
         payload = {
@@ -577,6 +592,7 @@ def rename_table(self, from_identifier: Union[str, Identifier], to_identifier: U
 
         return self.load_table(to_identifier)
 
+    @retry(**_RETRY_ARGS)
     def _commit_table(self, table_request: CommitTableRequest) -> CommitTableResponse:
         """Update the table.
 
@@ -607,7 +623,7 @@ def _commit_table(self, table_request: CommitTableRequest) -> CommitTableRespons
             )
         return CommitTableResponse(**response.json())
 
-    @retry(retry=retry_if_exception_type(AuthorizationExpiredError), stop=stop_after_attempt(2), before=_retry_hook, reraise=True)
+    @retry(**_RETRY_ARGS)
     def create_namespace(self, namespace: Union[str, Identifier], properties: Properties = EMPTY_DICT) -> None:
         namespace_tuple = self._check_valid_namespace_identifier(namespace)
         payload = {"namespace": namespace_tuple, "properties": properties}
@@ -617,7 +633,7 @@ def create_namespace(self, namespace: Union[str, Identifier], properties: Proper
         except HTTPError as exc:
             self._handle_non_200_response(exc, {404: NoSuchNamespaceError, 409: NamespaceAlreadyExistsError})
 
-    @retry(retry=retry_if_exception_type(AuthorizationExpiredError), stop=stop_after_attempt(2), before=_retry_hook, reraise=True)
+    @retry(**_RETRY_ARGS)
     def drop_namespace(self, namespace: Union[str, Identifier]) -> None:
         namespace_tuple = self._check_valid_namespace_identifier(namespace)
         namespace = NAMESPACE_SEPARATOR.join(namespace_tuple)
@@ -627,7 +643,7 @@ def drop_namespace(self, namespace: Union[str, Identifier]) -> None:
         except HTTPError as exc:
             self._handle_non_200_response(exc, {404: NoSuchNamespaceError})
 
-    @retry(retry=retry_if_exception_type(AuthorizationExpiredError), stop=stop_after_attempt(2), before=_retry_hook, reraise=True)
+    @retry(**_RETRY_ARGS)
     def list_namespaces(self, namespace: Union[str, Identifier] = ()) -> List[Identifier]:
         namespace_tuple = self.identifier_to_tuple(namespace)
         response = self._session.get(
@@ -645,7 +661,7 @@ def list_namespaces(self, namespace: Union[str, Identifier] = ()) -> List[Identi
         namespaces = ListNamespaceResponse(**response.json())
         return [namespace_tuple + child_namespace for child_namespace in namespaces.namespaces]
 
-    @retry(retry=retry_if_exception_type(AuthorizationExpiredError), stop=stop_after_attempt(2), before=_retry_hook, reraise=True)
+    @retry(**_RETRY_ARGS)
     def load_namespace_properties(self, namespace: Union[str, Identifier]) -> Properties:
         namespace_tuple = self._check_valid_namespace_identifier(namespace)
         namespace = NAMESPACE_SEPARATOR.join(namespace_tuple)
@@ -657,7 +673,7 @@ def load_namespace_properties(self, namespace: Union[str, Identifier]) -> Proper
 
         return NamespaceResponse(**response.json()).properties
 
-    @retry(retry=retry_if_exception_type(AuthorizationExpiredError), stop=stop_after_attempt(2), before=_retry_hook, reraise=True)
+    @retry(**_RETRY_ARGS)
     def update_namespace_properties(
         self, namespace: Union[str, Identifier], removals: Optional[Set[str]] = None, updates: Properties = EMPTY_DICT
     ) -> PropertiesUpdateSummary: