diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index f94d9c8a14..264afd8971 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -23,6 +23,7 @@ from dataclasses import dataclass from functools import cached_property from itertools import chain +from types import TracebackType from typing import ( TYPE_CHECKING, Any, @@ -33,6 +34,7 @@ Optional, Set, Tuple, + Type, TypeVar, Union, ) @@ -237,9 +239,12 @@ def __enter__(self) -> Transaction: """Start a transaction to update the table.""" return self - def __exit__(self, _: Any, value: Any, traceback: Any) -> None: - """Close and commit the transaction.""" - self.commit_transaction() + def __exit__( + self, exctype: Optional[Type[BaseException]], excinst: Optional[BaseException], exctb: Optional[TracebackType] + ) -> None: + """Close and commit the transaction if no exceptions have been raised.""" + if exctype is None and excinst is None and exctb is None: + self.commit_transaction() def _apply(self, updates: Tuple[TableUpdate, ...], requirements: Tuple[TableRequirement, ...] = ()) -> Transaction: """Check if the requirements are met, and applies the updates to the metadata.""" diff --git a/tests/integration/test_writes/test_writes.py b/tests/integration/test_writes/test_writes.py index fc2746c614..9cccb542d6 100644 --- a/tests/integration/test_writes/test_writes.py +++ b/tests/integration/test_writes/test_writes.py @@ -1448,3 +1448,27 @@ def test_rewrite_manifest_after_partition_evolution(session_catalog: Catalog) -> EqualTo("category", "A"), ), ) + + +@pytest.mark.integration +@pytest.mark.parametrize("format_version", [1, 2]) +def test_abort_table_transaction_on_exception( + spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int +) -> None: + identifier = "default.table_test_abort_table_transaction_on_exception" + tbl = _create_table(session_catalog, identifier, properties={"format-version": format_version}) + + # Pre-populate some data + tbl.append(arrow_table_with_null) + table_size = len(arrow_table_with_null) + assert len(tbl.scan().to_pandas()) == table_size + + # try to commit a transaction that raises exception at the middle + with pytest.raises(ValueError): + with tbl.transaction() as txn: + txn.append(arrow_table_with_null) + raise ValueError + txn.append(arrow_table_with_null) # type: ignore + + # Validate the transaction is aborted and no partial update is applied + assert len(tbl.scan().to_pandas()) == table_size # type: ignore