Skip to content

Commit

Permalink
abort the whole table transaction if any updates in the transaction has failed (#1246)
Browse files Browse the repository at this point in the history

* abort the whole transaction if any update on the chain has failed

* Update tests/integration/test_writes/test_writes.py

Co-authored-by: Kevin Liu <[email protected]>

* Update tests/integration/test_writes/test_writes.py

Co-authored-by: Kevin Liu <[email protected]>

* add type:ignore to prevent lint error

---------

Co-authored-by: Yingjian Wu <[email protected]>
Co-authored-by: Kevin Liu <[email protected]>
  • Loading branch information
3 people authored Oct 29, 2024
1 parent 3f8cb17 commit fba79ba
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 3 deletions.
11 changes: 8 additions & 3 deletions pyiceberg/table/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from dataclasses import dataclass
from functools import cached_property
from itertools import chain
from types import TracebackType
from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -33,6 +34,7 @@
Optional,
Set,
Tuple,
Type,
TypeVar,
Union,
)
Expand Down Expand Up @@ -237,9 +239,12 @@ def __enter__(self) -> Transaction:
"""Start a transaction to update the table."""
return self

def __exit__(self, _: Any, value: Any, traceback: Any) -> None:
    """Commit the transaction when the ``with`` block is left.

    The three exception arguments are accepted but never inspected, so the
    commit is attempted regardless of how the block ended.
    """
    self.commit_transaction()
    return None
def __exit__(
    self, exctype: Optional[Type[BaseException]], excinst: Optional[BaseException], exctb: Optional[TracebackType]
) -> None:
    """Commit the transaction on a clean exit from the ``with`` block.

    Args:
        exctype: Class of the exception raised inside the block, or None.
        excinst: The exception instance raised inside the block, or None.
        exctb: Traceback of that exception, or None.

    The commit happens only when all three arguments are None. If any of them
    is set, nothing is committed. The method always returns None.
    """
    exception_details = (exctype, excinst, exctb)
    if any(detail is not None for detail in exception_details):
        # Something went wrong inside the block: leave without committing.
        return
    self.commit_transaction()

def _apply(self, updates: Tuple[TableUpdate, ...], requirements: Tuple[TableRequirement, ...] = ()) -> Transaction:
"""Check if the requirements are met, and applies the updates to the metadata."""
Expand Down
24 changes: 24 additions & 0 deletions tests/integration/test_writes/test_writes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1448,3 +1448,27 @@ def test_rewrite_manifest_after_partition_evolution(session_catalog: Catalog) ->
EqualTo("category", "A"),
),
)


@pytest.mark.integration
@pytest.mark.parametrize("format_version", [1, 2])
def test_abort_table_transaction_on_exception(
    spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int
) -> None:
    """Check that an exception raised inside a transaction block leaves the row count unchanged."""
    identifier = "default.table_test_abort_table_transaction_on_exception"
    table = _create_table(session_catalog, identifier, properties={"format-version": format_version})

    # Seed the table so there is a known row count to compare against.
    table.append(arrow_table_with_null)
    expected_rows = len(arrow_table_with_null)
    assert len(table.scan().to_pandas()) == expected_rows

    # Raise in the middle of the transaction, after one append has been staged.
    with pytest.raises(ValueError), table.transaction() as transaction:
        transaction.append(arrow_table_with_null)
        raise ValueError
        # Deliberately unreachable: a second append that must never be applied.
        transaction.append(arrow_table_with_null)  # type: ignore

    # The scan must still return only the seeded rows.
    assert len(table.scan().to_pandas()) == expected_rows  # type: ignore

0 comments on commit fba79ba

Please sign in to comment.