diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 8e94e02ea3..c852aee565 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -238,7 +238,7 @@ def __exit__( ) -> None: """Close and commit the transaction, or handle exceptions.""" # Only commit the full transaction, if there is no exception in all updates on the chain - if exctb is None: + if exctype is None and excinst is None and exctb is None: self.commit_transaction() def _apply(self, updates: Tuple[TableUpdate, ...], requirements: Tuple[TableRequirement, ...] = ()) -> Transaction: diff --git a/tests/catalog/test_base.py b/tests/catalog/test_base.py index 4ff7d82f07..4faed0416f 100644 --- a/tests/catalog/test_base.py +++ b/tests/catalog/test_base.py @@ -41,17 +41,14 @@ NoSuchTableError, TableAlreadyExistsError, ) -from pyiceberg.expressions import BooleanExpression -from pyiceberg.io import WAREHOUSE, FileIO, load_file_io +from pyiceberg.io import WAREHOUSE, load_file_io from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionField, PartitionSpec from pyiceberg.schema import Schema from pyiceberg.table import ( - ALWAYS_TRUE, CommitTableResponse, Table, - Transaction, ) -from pyiceberg.table.metadata import TableMetadata, new_table_metadata +from pyiceberg.table.metadata import new_table_metadata from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder from pyiceberg.table.update import ( AddSchemaUpdate, @@ -268,42 +265,6 @@ def drop_view(self, identifier: Union[str, Identifier]) -> None: raise NotImplementedError -class TransactionThrowExceptionInOverwrite(Transaction): - def __init__(self, table: Table): - super().__init__(table) - - # Override the default overwrite to simulate exception during append - def overwrite( - self, - df: pa.Table, - overwrite_filter: Union[BooleanExpression, str] = ALWAYS_TRUE, - snapshot_properties: Dict[str, str] = EMPTY_DICT, - ) -> None: - self.delete(delete_filter=overwrite_filter, snapshot_properties=snapshot_properties) - raise Exception("Fail Append Commit Exception") - - -class TableThrowExceptionInOverwrite(Table): - def __init__(self, identifier: Identifier, metadata: TableMetadata, metadata_location: str, io: FileIO, catalog: Catalog): - # Call the constructor of the parent class - super().__init__(identifier, metadata, metadata_location, io, catalog) - - def transaction(self) -> Transaction: - return TransactionThrowExceptionInOverwrite(self) - - -def given_catalog_has_a_table_throw_exception_in_overwrite( - catalog: InMemoryCatalog, properties: Properties = EMPTY_DICT -) -> TableThrowExceptionInOverwrite: - table = catalog.create_table( - identifier=TEST_TABLE_IDENTIFIER, - schema=TEST_TABLE_SCHEMA, - partition_spec=TEST_TABLE_PARTITION_SPEC, - properties=properties or TEST_TABLE_PROPERTIES, - ) - return TableThrowExceptionInOverwrite(table.identifier, table.metadata, table.metadata_location, table.io, table.catalog) - - @pytest.fixture def catalog(tmp_path: PosixPath) -> InMemoryCatalog: return InMemoryCatalog("test.in_memory.catalog", **{WAREHOUSE: tmp_path.absolute().as_posix(), "test.key": "test.value"}) @@ -808,24 +769,23 @@ def test_table_properties_raise_for_none_value(catalog: InMemoryCatalog) -> None def test_table_overwrite_with_exception(catalog: InMemoryCatalog) -> None: - given_table = given_catalog_has_a_table_throw_exception_in_overwrite(catalog) + tbl = given_catalog_has_a_table(catalog) # Populate some initial data data = pa.Table.from_pylist( [{"x": 1, "y": 2, "z": 3}, {"x": 4, "y": 5, "z": 6}], schema=TEST_TABLE_SCHEMA.as_arrow(), ) - given_table.append(data) + tbl.append(data) # Data to overwrite data = pa.Table.from_pylist( - [{"x": 7, "y": 8, "z": 9}], + [{"x": 7, "y": 8, "z": 9}, {"x": 7, "y": 8, "z": 9}, {"x": 7, "y": 8, "z": 9}], schema=TEST_TABLE_SCHEMA.as_arrow(), ) - # Since overwrite has an exception, we should fail the whole overwrite transaction - try: - given_table.overwrite(data) - except Exception as e: - assert str(e) == "Fail Append Commit Exception", f"Expected 'Fail Append Commit Exception', but got '{str(e)}'" + with pytest.raises(ValueError): + with tbl.transaction() as txn: + txn.overwrite(data) + raise ValueError - assert len(given_table.scan().to_arrow()) == 2 + assert len(tbl.scan().to_pandas()) == 2 # type: ignore