Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Yingjian Wu committed Oct 27, 2024
1 parent 4eae64a commit 34ca959
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 51 deletions.
2 changes: 1 addition & 1 deletion pyiceberg/table/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def __exit__(
) -> None:
"""Close and commit the transaction, or handle exceptions."""
# Only commit the full transaction, if there is no exception in all updates on the chain
if exctb is None:
if exctype is None and excinst is None and exctb is None:
self.commit_transaction()

def _apply(self, updates: Tuple[TableUpdate, ...], requirements: Tuple[TableRequirement, ...] = ()) -> Transaction:
Expand Down
60 changes: 10 additions & 50 deletions tests/catalog/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,14 @@
NoSuchTableError,
TableAlreadyExistsError,
)
from pyiceberg.expressions import BooleanExpression
from pyiceberg.io import WAREHOUSE, FileIO, load_file_io
from pyiceberg.io import WAREHOUSE, load_file_io
from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionField, PartitionSpec
from pyiceberg.schema import Schema
from pyiceberg.table import (
ALWAYS_TRUE,
CommitTableResponse,
Table,
Transaction,
)
from pyiceberg.table.metadata import TableMetadata, new_table_metadata
from pyiceberg.table.metadata import new_table_metadata
from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder
from pyiceberg.table.update import (
AddSchemaUpdate,
Expand Down Expand Up @@ -268,42 +265,6 @@ def drop_view(self, identifier: Union[str, Identifier]) -> None:
raise NotImplementedError


class TransactionThrowExceptionInOverwrite(Transaction):
    """Transaction test double whose ``overwrite`` always fails.

    ``overwrite`` issues the delete step and then raises, so tests can
    exercise what happens when an overwrite breaks part-way through.
    """

    def __init__(self, table: Table):
        super().__init__(table)

    def overwrite(
        self,
        df: pa.Table,
        overwrite_filter: Union[BooleanExpression, str] = ALWAYS_TRUE,
        snapshot_properties: Dict[str, str] = EMPTY_DICT,
    ) -> None:
        """Run the delete step, then raise to simulate a failure during the append.

        Args:
            df: Data that would have been written; it is never used here.
            overwrite_filter: Forwarded to ``delete`` as ``delete_filter``.
            snapshot_properties: Forwarded to ``delete`` unchanged.

        Raises:
            Exception: Always, with the message "Fail Append Commit Exception".
        """
        # The delete is issued first so that the failure happens mid-overwrite.
        self.delete(delete_filter=overwrite_filter, snapshot_properties=snapshot_properties)
        simulated_failure = Exception("Fail Append Commit Exception")
        raise simulated_failure


class TableThrowExceptionInOverwrite(Table):
    """Table test double whose transactions are ``TransactionThrowExceptionInOverwrite``."""

    def __init__(
        self,
        identifier: Identifier,
        metadata: TableMetadata,
        metadata_location: str,
        io: FileIO,
        catalog: Catalog,
    ):
        # Construction is delegated unchanged to the parent class.
        super().__init__(identifier, metadata, metadata_location, io, catalog)

    def transaction(self) -> Transaction:
        """Return a new ``TransactionThrowExceptionInOverwrite`` bound to this table."""
        failing_transaction = TransactionThrowExceptionInOverwrite(self)
        return failing_transaction


def given_catalog_has_a_table_throw_exception_in_overwrite(
    catalog: InMemoryCatalog, properties: Properties = EMPTY_DICT
) -> TableThrowExceptionInOverwrite:
    """Create the test table in ``catalog`` and wrap it as a ``TableThrowExceptionInOverwrite``.

    Args:
        catalog: Catalog in which the table is created.
        properties: Table properties; a falsy value falls back to
            ``TEST_TABLE_PROPERTIES``.

    Returns:
        A ``TableThrowExceptionInOverwrite`` built from the created table's
        identifier, metadata, metadata location, io and catalog.
    """
    effective_properties = properties or TEST_TABLE_PROPERTIES
    created = catalog.create_table(
        identifier=TEST_TABLE_IDENTIFIER,
        schema=TEST_TABLE_SCHEMA,
        partition_spec=TEST_TABLE_PARTITION_SPEC,
        properties=effective_properties,
    )
    return TableThrowExceptionInOverwrite(
        created.identifier,
        created.metadata,
        created.metadata_location,
        created.io,
        created.catalog,
    )


@pytest.fixture
def catalog(tmp_path: PosixPath) -> InMemoryCatalog:
    """Build an ``InMemoryCatalog`` whose ``WAREHOUSE`` property is the absolute POSIX path of ``tmp_path``."""
    warehouse_location = tmp_path.absolute().as_posix()
    catalog_properties = {WAREHOUSE: warehouse_location, "test.key": "test.value"}
    return InMemoryCatalog("test.in_memory.catalog", **catalog_properties)
Expand Down Expand Up @@ -808,24 +769,23 @@ def test_table_properties_raise_for_none_value(catalog: InMemoryCatalog) -> None


def test_table_overwrite_with_exception(catalog: InMemoryCatalog) -> None:
given_table = given_catalog_has_a_table_throw_exception_in_overwrite(catalog)
tbl = given_catalog_has_a_table(catalog)
# Populate some initial data
data = pa.Table.from_pylist(
[{"x": 1, "y": 2, "z": 3}, {"x": 4, "y": 5, "z": 6}],
schema=TEST_TABLE_SCHEMA.as_arrow(),
)
given_table.append(data)
tbl.append(data)

# Data to overwrite
data = pa.Table.from_pylist(
[{"x": 7, "y": 8, "z": 9}],
[{"x": 7, "y": 8, "z": 9}, {"x": 7, "y": 8, "z": 9}, {"x": 7, "y": 8, "z": 9}],
schema=TEST_TABLE_SCHEMA.as_arrow(),
)

# Since overwrite has an exception, we should fail the whole overwrite transaction
try:
given_table.overwrite(data)
except Exception as e:
assert str(e) == "Fail Append Commit Exception", f"Expected 'Fail Append Commit Exception', but got '{str(e)}'"
with pytest.raises(ValueError):
with tbl.transaction() as txn:
txn.overwrite(data)
raise ValueError

assert len(given_table.scan().to_arrow()) == 2
assert len(tbl.scan().to_pandas()) == 2 # type: ignore

0 comments on commit 34ca959

Please sign in to comment.