From 12b94589177dcf10113af2994eb63c3c344c8ce6 Mon Sep 17 00:00:00 2001
From: HonahX
Date: Wed, 24 Apr 2024 01:02:28 -0700
Subject: [PATCH] support CreateTableTransaction for SQL

---
 pyiceberg/catalog/sql.py  | 104 +++++++++++++++++++++++++-------------
 tests/catalog/test_sql.py |  60 ++++++++++++++++++++++
 2 files changed, 129 insertions(+), 35 deletions(-)

diff --git a/pyiceberg/catalog/sql.py b/pyiceberg/catalog/sql.py
index 978109b2a3..18d66c6017 100644
--- a/pyiceberg/catalog/sql.py
+++ b/pyiceberg/catalog/sql.py
@@ -376,55 +376,89 @@ def _commit_table(self, table_request: CommitTableRequest) -> CommitTableRespons
         identifier_tuple = self.identifier_to_tuple_without_catalog(
             tuple(table_request.identifier.namespace.root + [table_request.identifier.name])
         )
-        current_table = self.load_table(identifier_tuple)
         database_name, table_name = self.identifier_to_database_and_table(identifier_tuple, NoSuchTableError)
-        base_metadata = current_table.metadata
+
+        current_table: Optional[Table]
+        try:
+            current_table = self.load_table(identifier_tuple)
+        except NoSuchTableError:
+            current_table = None
+
         for requirement in table_request.requirements:
-            requirement.validate(base_metadata)
+            requirement.validate(current_table.metadata if current_table else None)
 
-        updated_metadata = update_table_metadata(base_metadata, table_request.updates)
-        if updated_metadata == base_metadata:
+        updated_metadata = update_table_metadata(
+            base_metadata=current_table.metadata if current_table else self._empty_table_metadata(),
+            updates=table_request.updates,
+            enforce_validation=current_table is None,
+        )
+        if current_table and updated_metadata == current_table.metadata:
             # no changes, do nothing
-            return CommitTableResponse(metadata=base_metadata, metadata_location=current_table.metadata_location)
+            return CommitTableResponse(metadata=current_table.metadata, metadata_location=current_table.metadata_location)
 
         # write new metadata
-        new_metadata_version = self._parse_metadata_version(current_table.metadata_location) + 1
-        new_metadata_location = self._get_metadata_location(current_table.metadata.location, new_metadata_version)
-        self._write_metadata(updated_metadata, current_table.io, new_metadata_location)
+        new_metadata_version = self._parse_metadata_version(current_table.metadata_location) + 1 if current_table else 0
+        new_metadata_location = self._get_metadata_location(updated_metadata.location, new_metadata_version)
+        self._write_metadata(
+            metadata=updated_metadata,
+            io=self._load_file_io(updated_metadata.properties, new_metadata_location),
+            metadata_path=new_metadata_location,
+        )
 
         with Session(self.engine) as session:
-            if self.engine.dialect.supports_sane_rowcount:
-                stmt = (
-                    update(IcebergTables)
-                    .where(
-                        IcebergTables.catalog_name == self.name,
-                        IcebergTables.table_namespace == database_name,
-                        IcebergTables.table_name == table_name,
-                        IcebergTables.metadata_location == current_table.metadata_location,
-                    )
-                    .values(metadata_location=new_metadata_location, previous_metadata_location=current_table.metadata_location)
-                )
-                result = session.execute(stmt)
-                if result.rowcount < 1:
-                    raise CommitFailedException(f"Table has been updated by another process: {database_name}.{table_name}")
-            else:
-                try:
-                    tbl = (
-                        session.query(IcebergTables)
-                        .with_for_update(of=IcebergTables)
-                        .filter(
+            if current_table:
+                # table exists, update it
+                if self.engine.dialect.supports_sane_rowcount:
+                    stmt = (
+                        update(IcebergTables)
+                        .where(
                             IcebergTables.catalog_name == self.name,
                             IcebergTables.table_namespace == database_name,
                             IcebergTables.table_name == table_name,
                             IcebergTables.metadata_location == current_table.metadata_location,
                         )
-                        .one()
+                        .values(
+                            metadata_location=new_metadata_location, previous_metadata_location=current_table.metadata_location
+                        )
                     )
-                    tbl.metadata_location = new_metadata_location
-                    tbl.previous_metadata_location = current_table.metadata_location
-                except NoResultFound as e:
-                    raise CommitFailedException(f"Table has been updated by another process: {database_name}.{table_name}") from e
-            session.commit()
+                    result = session.execute(stmt)
+                    if result.rowcount < 1:
+                        raise CommitFailedException(f"Table has been updated by another process: {database_name}.{table_name}")
+                else:
+                    try:
+                        tbl = (
+                            session.query(IcebergTables)
+                            .with_for_update(of=IcebergTables)
+                            .filter(
+                                IcebergTables.catalog_name == self.name,
+                                IcebergTables.table_namespace == database_name,
+                                IcebergTables.table_name == table_name,
+                                IcebergTables.metadata_location == current_table.metadata_location,
+                            )
+                            .one()
+                        )
+                        tbl.metadata_location = new_metadata_location
+                        tbl.previous_metadata_location = current_table.metadata_location
+                    except NoResultFound as e:
+                        raise CommitFailedException(
+                            f"Table has been updated by another process: {database_name}.{table_name}"
+                        ) from e
+                session.commit()
+            else:
+                # table does not exist, create it
+                try:
+                    session.add(
+                        IcebergTables(
+                            catalog_name=self.name,
+                            table_namespace=database_name,
+                            table_name=table_name,
+                            metadata_location=new_metadata_location,
+                            previous_metadata_location=None,
+                        )
+                    )
+                    session.commit()
+                except IntegrityError as e:
+                    raise TableAlreadyExistsError(f"Table {database_name}.{table_name} already exists") from e
 
         return CommitTableResponse(metadata=updated_metadata, metadata_location=new_metadata_location)
 
diff --git a/tests/catalog/test_sql.py b/tests/catalog/test_sql.py
index 99b8550602..9f2f370413 100644
--- a/tests/catalog/test_sql.py
+++ b/tests/catalog/test_sql.py
@@ -948,6 +948,66 @@ def test_write_and_evolve(catalog: SqlCatalog, format_version: int) -> None:
             snapshot_update.append_data_file(data_file)
 
 
+@pytest.mark.parametrize(
+    'catalog',
+    [
+        lazy_fixture('catalog_memory'),
+        lazy_fixture('catalog_sqlite'),
+        lazy_fixture('catalog_sqlite_without_rowcount'),
+    ],
+)
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_create_table_transaction(catalog: SqlCatalog, format_version: int) -> None:
+    identifier = f"default.arrow_create_table_transaction_{catalog.name}_{format_version}"
+    try:
+        catalog.create_namespace("default")
+    except NamespaceAlreadyExistsError:
+        pass
+
+    try:
+        catalog.drop_table(identifier=identifier)
+    except NoSuchTableError:
+        pass
+
+    pa_table = pa.Table.from_pydict(
+        {
+            'foo': ['a', None, 'z'],
+        },
+        schema=pa.schema([pa.field("foo", pa.string(), nullable=True)]),
+    )
+
+    pa_table_with_column = pa.Table.from_pydict(
+        {
+            'foo': ['a', None, 'z'],
+            'bar': [19, None, 25],
+        },
+        schema=pa.schema([
+            pa.field("foo", pa.string(), nullable=True),
+            pa.field("bar", pa.int32(), nullable=True),
+        ]),
+    )
+
+    with catalog.create_table_transaction(
+        identifier=identifier, schema=pa_table.schema, properties={"format-version": str(format_version)}
+    ) as txn:
+        with txn.update_snapshot().fast_append() as snapshot_update:
+            for data_file in _dataframe_to_data_files(table_metadata=txn.table_metadata, df=pa_table, io=txn._table.io):
+                snapshot_update.append_data_file(data_file)
+
+        with txn.update_schema() as schema_txn:
+            schema_txn.union_by_name(pa_table_with_column.schema)
+
+        with txn.update_snapshot().fast_append() as snapshot_update:
+            for data_file in _dataframe_to_data_files(
+                table_metadata=txn.table_metadata, df=pa_table_with_column, io=txn._table.io
+            ):
+                snapshot_update.append_data_file(data_file)
+
+    tbl = catalog.load_table(identifier=identifier)
+    assert tbl.format_version == format_version
+    assert len(tbl.scan().to_arrow()) == 6
+
+
 @pytest.mark.parametrize(
     'catalog',
     [
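
Editor's note, not part of the patch: a minimal end-to-end sketch of the code path this change enables, assuming a local SQLite-backed SqlCatalog. The catalog name, URI, warehouse path, table identifier, and property value below are illustrative placeholders, not values taken from the patch.

    from pyiceberg.catalog.sql import SqlCatalog
    from pyiceberg.schema import Schema
    from pyiceberg.types import NestedField, StringType

    # Assumption: an in-memory SQLite catalog and a local warehouse path, for illustration only.
    catalog = SqlCatalog(
        "default",
        uri="sqlite:///:memory:",
        warehouse="file:///tmp/warehouse",
    )
    catalog.create_tables()             # create the catalog's backing tables (idempotent)
    catalog.create_namespace("default")

    schema = Schema(NestedField(field_id=1, name="foo", field_type=StringType(), required=False))

    # The table is staged in memory and only persisted when the transaction commits on
    # exit; with this patch the SQL catalog inserts the new row, or raises
    # TableAlreadyExistsError if another process created the table first.
    with catalog.create_table_transaction("default.example", schema=schema) as txn:
        txn.set_properties(owner="example")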