Skip to content

Commit

Permalink
Table commit retries based on table properties
Browse files Browse the repository at this point in the history
  • Loading branch information
Buktoria committed Jan 31, 2024
1 parent 7f7bb03 commit edd1aad
Show file tree
Hide file tree
Showing 4 changed files with 166 additions and 2 deletions.
16 changes: 15 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

74 changes: 73 additions & 1 deletion pyiceberg/table/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from copy import copy
from dataclasses import dataclass
from enum import Enum
from functools import cached_property, singledispatch
from functools import cached_property, partial, singledispatch
from itertools import chain
from typing import (
TYPE_CHECKING,
Expand All @@ -43,6 +43,7 @@

from pydantic import Field, SerializeAsAny
from sortedcontainers import SortedList
from tenacity import RetryError, Retrying, retry_if_exception_type, stop_after_attempt, stop_after_delay, wait_exponential
from typing_extensions import Annotated

from pyiceberg.exceptions import CommitFailedException, ResolveError, ValidationError
Expand Down Expand Up @@ -791,6 +792,76 @@ class CommitTableResponse(IcebergBaseModel):
metadata_location: str = Field(alias="metadata-location")


class TableCommitRetry:
    """Decorator that retries the wrapped method while it raises ``CommitFailedException``.

    The retry policy is rebuilt on every call from the properties mapping found on
    the calling instance, under the attribute named by ``properties_attribute``:

    * ``commit.retry.num-retries``: maximum number of attempts (default 4).
    * ``commit.retry.min-wait-ms``: lower bound of the exponential backoff (default 100).
    * ``commit.retry.max-wait-ms``: upper bound of the exponential backoff (default 60000).
    * ``commit.retry.total-timeout-ms``: overall time budget for all attempts (default 1800000).

    Values that are missing or cannot be parsed as integers fall back to the defaults.
    Exceptions other than ``CommitFailedException`` are not retried.
    """

    num_retries = "commit.retry.num-retries"
    num_retries_default: int = 4
    min_wait_ms = "commit.retry.min-wait-ms"
    min_wait_ms_default: int = 100
    max_wait_ms = "commit.retry.max-wait-ms"
    max_wait_ms_default: int = 60000  # 1 min
    total_timeout_ms = "commit.retry.total-timeout-ms"
    total_timeout_ms_default: int = 1800000  # 30 mins

    def __init__(self, func: Callable[..., Any], properties_attribute: str = "properties") -> None:
        """Wrap ``func``; ``properties_attribute`` names the attribute on the caller holding the config."""
        self.properties_attr: str = properties_attribute
        self.func: Callable[..., Any] = func
        # NOTE: this lives on the decorator object, which is shared by every instance
        # of the decorated class; it is refreshed at the start of each call.
        self.loaded_properties: Properties = {}

    def __get__(self, instance: Any, owner: Any) -> Callable[..., Any]:
        """Bind the wrapped call to ``instance``, like a regular method would be."""
        if instance is None:
            # Accessed on the class itself: there is nothing to bind, so hand back the
            # decorator instead of a partial with ``None`` baked in as the instance.
            return self
        return partial(self.__call__, instance)

    def __call__(self, instance: Any, *args: Any, **kwargs: Any) -> Any:
        """Run the wrapped function on ``instance`` under the retry controller.

        Returns whatever the wrapped function returns. When the retry budget is
        exhausted, the exception raised by the last attempt is re-raised.
        """
        self.loaded_properties = getattr(instance, self.properties_attr)
        try:
            for attempt in self.build_retry_controller():
                with attempt:
                    result = self.func(instance, *args, **kwargs)
        except RetryError as err:
            # Surface the last attempt's own exception rather than tenacity's wrapper.
            err.reraise()
        return result

    @property
    def table_properties(self) -> Properties:
        """Return the properties captured from the instance of the most recent call."""
        return self.loaded_properties

    def build_retry_controller(self) -> Retrying:
        """Build the tenacity controller from the currently loaded properties."""
        max_attempts = self.get_config(self.num_retries, self.num_retries_default)
        # tenacity works in seconds; the table properties are expressed in milliseconds.
        total_timeout_s = self.get_config(self.total_timeout_ms, self.total_timeout_ms_default) / 1000.0
        min_wait_s = self.get_config(self.min_wait_ms, self.min_wait_ms_default) / 1000.0
        max_wait_s = self.get_config(self.max_wait_ms, self.max_wait_ms_default) / 1000.0
        return Retrying(
            stop=stop_after_attempt(max_attempts) | stop_after_delay(total_timeout_s),
            # Without ``max`` the backoff would grow unbounded and ignore max-wait-ms.
            wait=wait_exponential(min=min_wait_s, max=max_wait_s),
            retry=retry_if_exception_type(CommitFailedException),
        )

    def get_config(self, config: str, default: int) -> int:
        """Look up ``config`` in the loaded properties, falling back to ``default``."""
        return self.to_int(self.table_properties.get(config, ""), default)

    @staticmethod
    def to_int(v: str, default: int) -> int:
        """Convert ``v`` to int, returning ``default`` when it cannot be parsed."""
        try:
            return int(v)
        except (ValueError, TypeError):
            return default


def table_commit_retry(properties_attribute: str) -> Callable[..., TableCommitRetry]:
    """Return a decorator that wraps a method in ``TableCommitRetry``.

    ``properties_attribute`` is forwarded to ``TableCommitRetry`` as the name of
    the attribute on the calling instance to read the retry configuration from.
    """
    decorator = partial(TableCommitRetry, properties_attribute=properties_attribute)
    return decorator


class Table:
identifier: Identifier = Field()
metadata: TableMetadata
Expand Down Expand Up @@ -994,6 +1065,7 @@ def refs(self) -> Dict[str, SnapshotRef]:
"""Return the snapshot references in the table."""
return self.metadata.refs

@table_commit_retry("properties")
def _do_commit(self, updates: Tuple[TableUpdate, ...], requirements: Tuple[TableRequirement, ...]) -> None:
response = self.catalog._commit_table( # pylint: disable=W0212
CommitTableRequest(
Expand Down
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ adlfs = { version = ">=2023.1.0,<2024.2.0", optional = true }
gcsfs = { version = ">=2023.1.0,<2024.1.0", optional = true }
psycopg2-binary = { version = ">=2.9.6", optional = true }
sqlalchemy = { version = "^2.0.18", optional = true }
tenacity = "8.2.3"

[tool.poetry.dev-dependencies]
pytest = "7.4.4"
Expand Down Expand Up @@ -295,6 +296,10 @@ ignore_missing_imports = true
module = "setuptools.*"
ignore_missing_imports = true

[[tool.mypy.overrides]]
module = "tenacity.*"
ignore_missing_imports = true

[tool.coverage.run]
source = ['pyiceberg/']

Expand Down
73 changes: 73 additions & 0 deletions tests/table/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
_generate_snapshot_id,
_match_deletes_to_data_file,
_TableMetadataUpdateContext,
table_commit_retry,
update_table_metadata,
)
from pyiceberg.table.metadata import INITIAL_SEQUENCE_NUMBER, TableMetadataUtil, TableMetadataV2
Expand All @@ -77,6 +78,7 @@
SortOrder,
)
from pyiceberg.transforms import BucketTransform, IdentityTransform
from pyiceberg.typedef import Properties
from pyiceberg.types import (
BinaryType,
BooleanType,
Expand Down Expand Up @@ -982,3 +984,74 @@ def test_correct_schema() -> None:
_ = t.scan(snapshot_id=-1).projection()

assert "Snapshot not found: -1" in str(exc_info.value)


def test_non_commit_failure_retry() -> None:
    """An exception other than CommitFailedException propagates after a single call."""

    class CustomException(Exception):
        pass

    class FailsWithCustomError:
        def __init__(self) -> None:
            self.count: int = 0
            self.properties: Properties = {
                "commit.retry.num-retries": "3",
                "commit.retry.max-wait-ms": "0",
                "commit.retry.min-wait-ms": "0",
            }

        @table_commit_retry("properties")
        def my_function(self) -> None:
            self.count += 1
            raise CustomException

    target = FailsWithCustomError()

    with pytest.raises(CustomException):
        target.my_function()

    # Checked after the raising block so the assertion is actually executed.
    assert target.count == 1


def test_custom_retry_commit_config() -> None:
    """With num-retries set to "3", a method that always fails is called three times."""

    class AlwaysFailsCommit:
        def __init__(self) -> None:
            self.count: int = 0
            self.properties: Properties = {
                "commit.retry.num-retries": "3",
                "commit.retry.max-wait-ms": "0",
                "commit.retry.min-wait-ms": "0",
            }

        @table_commit_retry("properties")
        def my_function(self) -> None:
            self.count += 1
            raise CommitFailedException

    target = AlwaysFailsCommit()

    with pytest.raises(CommitFailedException):
        target.my_function()

    # Checked after the raising block so the assertion is actually executed.
    assert target.count == 3


def test_invalid_commit_retry_config() -> None:
    """With an unparseable num-retries value, a method that always fails is called four times."""

    class AlwaysFailsCommit:
        def __init__(self) -> None:
            self.count: int = 0
            self.properties: Properties = {
                "commit.retry.num-retries": "I AM INVALID",
                "commit.retry.max-wait-ms": "0",
                "commit.retry.min-wait-ms": "0",
            }

        @table_commit_retry("properties")
        def my_function(self) -> None:
            self.count += 1
            raise CommitFailedException

    target = AlwaysFailsCommit()

    with pytest.raises(CommitFailedException):
        target.my_function()

    # Checked after the raising block so the assertion is actually executed.
    assert target.count == 4

0 comments on commit edd1aad

Please sign in to comment.