From 19bf07ca746fd203ea6c8e476ba8abfbba8ffa79 Mon Sep 17 00:00:00 2001
From: HonahX
Date: Tue, 30 Apr 2024 23:20:32 -0700
Subject: [PATCH 1/3] create table transaction for sql

---
 pyiceberg/catalog/sql.py  | 106 ++++++++++++++++++++++++--------------
 tests/catalog/test_sql.py |  60 +++++++++++++++++++++
 2 files changed, 126 insertions(+), 40 deletions(-)

diff --git a/pyiceberg/catalog/sql.py b/pyiceberg/catalog/sql.py
index 978109b2a3..9dc17a4d85 100644
--- a/pyiceberg/catalog/sql.py
+++ b/pyiceberg/catalog/sql.py
@@ -59,7 +59,7 @@
 from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionSpec
 from pyiceberg.schema import Schema
 from pyiceberg.serializers import FromInputFile
-from pyiceberg.table import CommitTableRequest, CommitTableResponse, Table, update_table_metadata
+from pyiceberg.table import CommitTableRequest, CommitTableResponse, Table
 from pyiceberg.table.metadata import new_table_metadata
 from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder
 from pyiceberg.typedef import EMPTY_DICT, Identifier, Properties
@@ -376,57 +376,83 @@ def _commit_table(self, table_request: CommitTableRequest) -> CommitTableRespons
         identifier_tuple = self.identifier_to_tuple_without_catalog(
             tuple(table_request.identifier.namespace.root + [table_request.identifier.name])
         )
-        current_table = self.load_table(identifier_tuple)
         database_name, table_name = self.identifier_to_database_and_table(identifier_tuple, NoSuchTableError)
-        base_metadata = current_table.metadata
-        for requirement in table_request.requirements:
-            requirement.validate(base_metadata)
 
-        updated_metadata = update_table_metadata(base_metadata, table_request.updates)
-        if updated_metadata == base_metadata:
-            # no changes, do nothing
-            return CommitTableResponse(metadata=base_metadata, metadata_location=current_table.metadata_location)
+        current_table: Optional[Table]
+        try:
+            current_table = self.load_table(identifier_tuple)
+        except NoSuchTableError:
+            current_table = None
 
-        # write new metadata
-        new_metadata_version = self._parse_metadata_version(current_table.metadata_location) + 1
-        new_metadata_location = self._get_metadata_location(current_table.metadata.location, new_metadata_version)
-        self._write_metadata(updated_metadata, current_table.io, new_metadata_location)
+        updated_staged_table = self._update_and_stage_table(current_table, table_request)
+        if current_table and updated_staged_table.metadata == current_table.metadata:
+            # no changes, do nothing
+            return CommitTableResponse(metadata=current_table.metadata, metadata_location=current_table.metadata_location)
+        self._write_metadata(
+            metadata=updated_staged_table.metadata,
+            io=updated_staged_table.io,
+            metadata_path=updated_staged_table.metadata_location,
+        )
 
         with Session(self.engine) as session:
-            if self.engine.dialect.supports_sane_rowcount:
-                stmt = (
-                    update(IcebergTables)
-                    .where(
-                        IcebergTables.catalog_name == self.name,
-                        IcebergTables.table_namespace == database_name,
-                        IcebergTables.table_name == table_name,
-                        IcebergTables.metadata_location == current_table.metadata_location,
-                    )
-                    .values(metadata_location=new_metadata_location, previous_metadata_location=current_table.metadata_location)
-                )
-                result = session.execute(stmt)
-                if result.rowcount < 1:
-                    raise CommitFailedException(f"Table has been updated by another process: {database_name}.{table_name}")
-            else:
-                try:
-                    tbl = (
-                        session.query(IcebergTables)
-                        .with_for_update(of=IcebergTables)
-                        .filter(
+            if current_table:
+                # table exists, update it
+                if self.engine.dialect.supports_sane_rowcount:
+                    stmt = (
+                        update(IcebergTables)
+                        .where(
                             IcebergTables.catalog_name == self.name,
                             IcebergTables.table_namespace == database_name,
                             IcebergTables.table_name == table_name,
                             IcebergTables.metadata_location == current_table.metadata_location,
                         )
-                        .one()
+                        .values(
+                            metadata_location=updated_staged_table.metadata_location,
+                            previous_metadata_location=current_table.metadata_location,
+                        )
                     )
-                    tbl.metadata_location = new_metadata_location
-                    tbl.previous_metadata_location = current_table.metadata_location
-                except NoResultFound as e:
-                    raise CommitFailedException(f"Table has been updated by another process: {database_name}.{table_name}") from e
-                session.commit()
+                    result = session.execute(stmt)
+                    if result.rowcount < 1:
+                        raise CommitFailedException(f"Table has been updated by another process: {database_name}.{table_name}")
+                else:
+                    try:
+                        tbl = (
+                            session.query(IcebergTables)
+                            .with_for_update(of=IcebergTables)
+                            .filter(
+                                IcebergTables.catalog_name == self.name,
+                                IcebergTables.table_namespace == database_name,
+                                IcebergTables.table_name == table_name,
+                                IcebergTables.metadata_location == current_table.metadata_location,
+                            )
+                            .one()
+                        )
+                        tbl.metadata_location = updated_staged_table.metadata_location
+                        tbl.previous_metadata_location = current_table.metadata_location
+                    except NoResultFound as e:
+                        raise CommitFailedException(
+                            f"Table has been updated by another process: {database_name}.{table_name}"
+                        ) from e
+                session.commit()
+            else:
+                # table does not exist, create it
+                try:
+                    session.add(
+                        IcebergTables(
+                            catalog_name=self.name,
+                            table_namespace=database_name,
+                            table_name=table_name,
+                            metadata_location=updated_staged_table.metadata_location,
+                            previous_metadata_location=None,
+                        )
+                    )
+                    session.commit()
+                except IntegrityError as e:
+                    raise TableAlreadyExistsError(f"Table {database_name}.{table_name} already exists") from e
 
-        return CommitTableResponse(metadata=updated_metadata, metadata_location=new_metadata_location)
+        return CommitTableResponse(
+            metadata=updated_staged_table.metadata, metadata_location=updated_staged_table.metadata_location
+        )
 
     def _namespace_exists(self, identifier: Union[str, Identifier]) -> bool:
         namespace = self.identifier_to_database(identifier)
diff --git a/tests/catalog/test_sql.py b/tests/catalog/test_sql.py
index 99b8550602..9f2f370413 100644
--- a/tests/catalog/test_sql.py
+++ b/tests/catalog/test_sql.py
@@ -948,6 +948,66 @@ def test_write_and_evolve(catalog: SqlCatalog, format_version: int) -> None:
                 snapshot_update.append_data_file(data_file)
 
 
+@pytest.mark.parametrize(
+    'catalog',
+    [
+        lazy_fixture('catalog_memory'),
+        lazy_fixture('catalog_sqlite'),
+        lazy_fixture('catalog_sqlite_without_rowcount'),
+    ],
+)
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_create_table_transaction(catalog: SqlCatalog, format_version: int) -> None:
+    identifier = f"default.arrow_create_table_transaction_{catalog.name}_{format_version}"
+    try:
+        catalog.create_namespace("default")
+    except NamespaceAlreadyExistsError:
+        pass
+
+    try:
+        catalog.drop_table(identifier=identifier)
+    except NoSuchTableError:
+        pass
+
+    pa_table = pa.Table.from_pydict(
+        {
+            'foo': ['a', None, 'z'],
+        },
+        schema=pa.schema([pa.field("foo", pa.string(), nullable=True)]),
+    )
+
+    pa_table_with_column = pa.Table.from_pydict(
+        {
+            'foo': ['a', None, 'z'],
+            'bar': [19, None, 25],
+        },
+        schema=pa.schema([
+            pa.field("foo", pa.string(), nullable=True),
+            pa.field("bar", pa.int32(), nullable=True),
+        ]),
+    )
+
+    with catalog.create_table_transaction(
+        identifier=identifier, schema=pa_table.schema, properties={"format-version": str(format_version)}
+    ) as txn:
+        with txn.update_snapshot().fast_append() as snapshot_update:
+            for data_file in _dataframe_to_data_files(table_metadata=txn.table_metadata, df=pa_table, io=txn._table.io):
+                snapshot_update.append_data_file(data_file)
+
+        with txn.update_schema() as schema_txn:
+            schema_txn.union_by_name(pa_table_with_column.schema)
+
+        with txn.update_snapshot().fast_append() as snapshot_update:
+            for data_file in _dataframe_to_data_files(
+                table_metadata=txn.table_metadata, df=pa_table_with_column, io=txn._table.io
+            ):
+                snapshot_update.append_data_file(data_file)
+
+    tbl = catalog.load_table(identifier=identifier)
+    assert tbl.format_version == format_version
+    assert len(tbl.scan().to_arrow()) == 6
+
+
 @pytest.mark.parametrize(
     'catalog',
     [

From 436beeb7b17ccfa85abf8510b1b64819c1f645f2 Mon Sep 17 00:00:00 2001
From: HonahX
Date: Thu, 30 May 2024 00:13:41 -0700
Subject: [PATCH 2/3] fix merge conflict

---
 pyiceberg/catalog/sql.py | 15 ++++++---------
 1 file changed, 6 insertions(+), 9 deletions(-)

diff --git a/pyiceberg/catalog/sql.py b/pyiceberg/catalog/sql.py
index 6f70ba607d..ff7831d77f 100644
--- a/pyiceberg/catalog/sql.py
+++ b/pyiceberg/catalog/sql.py
@@ -402,7 +402,6 @@ def _commit_table(self, table_request: CommitTableRequest) -> CommitTableRespons
         identifier_tuple = self.identifier_to_tuple_without_catalog(
             tuple(table_request.identifier.namespace.root + [table_request.identifier.name])
         )
-        current_table = self.load_table(identifier_tuple)
         namespace_tuple = Catalog.namespace_from(identifier_tuple)
         namespace = Catalog.namespace_to_string(namespace_tuple)
         table_name = Catalog.table_name_from(identifier_tuple)
@@ -431,7 +430,7 @@ def _commit_table(self, table_request: CommitTableRequest) -> CommitTableRespons
                         update(IcebergTables)
                         .where(
                             IcebergTables.catalog_name == self.name,
-                            IcebergTables.table_namespace == database_name,
+                            IcebergTables.table_namespace == namespace,
                             IcebergTables.table_name == table_name,
                             IcebergTables.metadata_location == current_table.metadata_location,
                         )
@@ -442,7 +441,7 @@ def _commit_table(self, table_request: CommitTableRequest) -> CommitTableRespons
                     )
                     result = session.execute(stmt)
                     if result.rowcount < 1:
-                        raise CommitFailedException(f"Table has been updated by another process: {database_name}.{table_name}")
+                        raise CommitFailedException(f"Table has been updated by another process: {namespace}.{table_name}")
                 else:
                     try:
                         tbl = (
@@ -450,7 +449,7 @@ def _commit_table(self, table_request: CommitTableRequest) -> CommitTableRespons
                             .with_for_update(of=IcebergTables)
                             .filter(
                                 IcebergTables.catalog_name == self.name,
-                                IcebergTables.table_namespace == database_name,
+                                IcebergTables.table_namespace == namespace,
                                 IcebergTables.table_name == table_name,
                                 IcebergTables.metadata_location == current_table.metadata_location,
                             )
@@ -459,9 +458,7 @@ def _commit_table(self, table_request: CommitTableRequest) -> CommitTableRespons
                         tbl.metadata_location = updated_staged_table.metadata_location
                         tbl.previous_metadata_location = current_table.metadata_location
                     except NoResultFound as e:
-                        raise CommitFailedException(
-                            f"Table has been updated by another process: {database_name}.{table_name}"
-                        ) from e
+                        raise CommitFailedException(f"Table has been updated by another process: {namespace}.{table_name}") from e
                 session.commit()
             else:
                 # table does not exist, create it
@@ -469,7 +466,7 @@ def _commit_table(self, table_request: CommitTableRequest) -> CommitTableRespons
                     session.add(
                         IcebergTables(
                             catalog_name=self.name,
-                            table_namespace=database_name,
+                            table_namespace=namespace,
                             table_name=table_name,
                             metadata_location=updated_staged_table.metadata_location,
                             previous_metadata_location=None,
@@ -477,7 +474,7 @@ def _commit_table(self, table_request: CommitTableRequest) -> CommitTableRespons
                     )
                     session.commit()
                 except IntegrityError as e:
-                    raise TableAlreadyExistsError(f"Table {database_name}.{table_name} already exists") from e
+                    raise TableAlreadyExistsError(f"Table {namespace}.{table_name} already exists") from e
 
         return CommitTableResponse(
             metadata=updated_staged_table.metadata, metadata_location=updated_staged_table.metadata_location

From d7f7b561042ea4eb0169d00e8b0771a12e357be3 Mon Sep 17 00:00:00 2001
From: HonahX
Date: Thu, 30 May 2024 23:05:44 -0700
Subject: [PATCH 3/3] fix lint

---
 tests/catalog/test_sql.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/tests/catalog/test_sql.py b/tests/catalog/test_sql.py
index 810d094e4d..545916223a 100644
--- a/tests/catalog/test_sql.py
+++ b/tests/catalog/test_sql.py
@@ -1373,15 +1373,15 @@ def test_create_table_transaction(catalog: SqlCatalog, format_version: int) -> N
 
     pa_table = pa.Table.from_pydict(
         {
-            'foo': ['a', None, 'z'],
+            "foo": ["a", None, "z"],
         },
         schema=pa.schema([pa.field("foo", pa.string(), nullable=True)]),
     )
 
     pa_table_with_column = pa.Table.from_pydict(
         {
-            'foo': ['a', None, 'z'],
-            'bar': [19, None, 25],
+            "foo": ["a", None, "z"],
+            "bar": [19, None, 25],
         },
         schema=pa.schema([
             pa.field("foo", pa.string(), nullable=True),
@@ -1411,11 +1411,11 @@ def test_create_table_transaction(catalog: SqlCatalog, format_version: int) -> N
 
 
 @pytest.mark.parametrize(
-    'catalog',
+    "catalog",
     [
-        lazy_fixture('catalog_memory'),
-        lazy_fixture('catalog_sqlite'),
-        lazy_fixture('catalog_sqlite_without_rowcount'),
+        lazy_fixture("catalog_memory"),
+        lazy_fixture("catalog_sqlite"),
+        lazy_fixture("catalog_sqlite_without_rowcount"),
     ],
 )
 @pytest.mark.parametrize(