diff --git a/pyiceberg/catalog/sql.py b/pyiceberg/catalog/sql.py index 6c198767e7..ff7831d77f 100644 --- a/pyiceberg/catalog/sql.py +++ b/pyiceberg/catalog/sql.py @@ -60,7 +60,7 @@ from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionSpec from pyiceberg.schema import Schema from pyiceberg.serializers import FromInputFile -from pyiceberg.table import CommitTableRequest, CommitTableResponse, Table, update_table_metadata +from pyiceberg.table import CommitTableRequest, CommitTableResponse, Table from pyiceberg.table.metadata import new_table_metadata from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder from pyiceberg.typedef import EMPTY_DICT, Identifier, Properties @@ -402,59 +402,83 @@ def _commit_table(self, table_request: CommitTableRequest) -> CommitTableRespons identifier_tuple = self.identifier_to_tuple_without_catalog( tuple(table_request.identifier.namespace.root + [table_request.identifier.name]) ) - current_table = self.load_table(identifier_tuple) namespace_tuple = Catalog.namespace_from(identifier_tuple) namespace = Catalog.namespace_to_string(namespace_tuple) table_name = Catalog.table_name_from(identifier_tuple) - base_metadata = current_table.metadata - for requirement in table_request.requirements: - requirement.validate(base_metadata) - updated_metadata = update_table_metadata(base_metadata, table_request.updates) - if updated_metadata == base_metadata: - # no changes, do nothing - return CommitTableResponse(metadata=base_metadata, metadata_location=current_table.metadata_location) + current_table: Optional[Table] + try: + current_table = self.load_table(identifier_tuple) + except NoSuchTableError: + current_table = None - # write new metadata - new_metadata_version = self._parse_metadata_version(current_table.metadata_location) + 1 - new_metadata_location = self._get_metadata_location(current_table.metadata.location, new_metadata_version) - self._write_metadata(updated_metadata, current_table.io, new_metadata_location) + updated_staged_table = self._update_and_stage_table(current_table, table_request) + if current_table and updated_staged_table.metadata == current_table.metadata: + # no changes, do nothing + return CommitTableResponse(metadata=current_table.metadata, metadata_location=current_table.metadata_location) + self._write_metadata( + metadata=updated_staged_table.metadata, + io=updated_staged_table.io, + metadata_path=updated_staged_table.metadata_location, + ) with Session(self.engine) as session: - if self.engine.dialect.supports_sane_rowcount: - stmt = ( - update(IcebergTables) - .where( - IcebergTables.catalog_name == self.name, - IcebergTables.table_namespace == namespace, - IcebergTables.table_name == table_name, - IcebergTables.metadata_location == current_table.metadata_location, - ) - .values(metadata_location=new_metadata_location, previous_metadata_location=current_table.metadata_location) - ) - result = session.execute(stmt) - if result.rowcount < 1: - raise CommitFailedException(f"Table has been updated by another process: {namespace}.{table_name}") - else: - try: - tbl = ( - session.query(IcebergTables) - .with_for_update(of=IcebergTables) - .filter( + if current_table: + # table exists, update it + if self.engine.dialect.supports_sane_rowcount: + stmt = ( + update(IcebergTables) + .where( IcebergTables.catalog_name == self.name, IcebergTables.table_namespace == namespace, IcebergTables.table_name == table_name, IcebergTables.metadata_location == current_table.metadata_location, ) - .one() + .values( + metadata_location=updated_staged_table.metadata_location, + previous_metadata_location=current_table.metadata_location, + ) ) - tbl.metadata_location = new_metadata_location - tbl.previous_metadata_location = current_table.metadata_location - except NoResultFound as e: - raise CommitFailedException(f"Table has been updated by another process: {namespace}.{table_name}") from e - session.commit() + result = session.execute(stmt) + if result.rowcount < 1: + raise CommitFailedException(f"Table has been updated by another process: {namespace}.{table_name}") + else: + try: + tbl = ( + session.query(IcebergTables) + .with_for_update(of=IcebergTables) + .filter( + IcebergTables.catalog_name == self.name, + IcebergTables.table_namespace == namespace, + IcebergTables.table_name == table_name, + IcebergTables.metadata_location == current_table.metadata_location, + ) + .one() + ) + tbl.metadata_location = updated_staged_table.metadata_location + tbl.previous_metadata_location = current_table.metadata_location + except NoResultFound as e: + raise CommitFailedException(f"Table has been updated by another process: {namespace}.{table_name}") from e + session.commit() + else: + # table does not exist, create it + try: + session.add( + IcebergTables( + catalog_name=self.name, + table_namespace=namespace, + table_name=table_name, + metadata_location=updated_staged_table.metadata_location, + previous_metadata_location=None, + ) + ) + session.commit() + except IntegrityError as e: + raise TableAlreadyExistsError(f"Table {namespace}.{table_name} already exists") from e - return CommitTableResponse(metadata=updated_metadata, metadata_location=new_metadata_location) + return CommitTableResponse( + metadata=updated_staged_table.metadata, metadata_location=updated_staged_table.metadata_location + ) def _namespace_exists(self, identifier: Union[str, Identifier]) -> bool: namespace_tuple = Catalog.identifier_to_tuple(identifier) diff --git a/tests/catalog/test_sql.py b/tests/catalog/test_sql.py index 6dc498233e..545916223a 100644 --- a/tests/catalog/test_sql.py +++ b/tests/catalog/test_sql.py @@ -1350,6 +1350,66 @@ def test_write_and_evolve(catalog: SqlCatalog, format_version: int) -> None: snapshot_update.append_data_file(data_file) +@pytest.mark.parametrize( + "catalog", + [ + lazy_fixture("catalog_memory"), + lazy_fixture("catalog_sqlite"), + lazy_fixture("catalog_sqlite_without_rowcount"), + ], +) +@pytest.mark.parametrize("format_version", [1, 2]) +def test_create_table_transaction(catalog: SqlCatalog, format_version: int) -> None: + identifier = f"default.arrow_create_table_transaction_{catalog.name}_{format_version}" + try: + catalog.create_namespace("default") + except NamespaceAlreadyExistsError: + pass + + try: + catalog.drop_table(identifier=identifier) + except NoSuchTableError: + pass + + pa_table = pa.Table.from_pydict( + { + "foo": ["a", None, "z"], + }, + schema=pa.schema([pa.field("foo", pa.string(), nullable=True)]), + ) + + pa_table_with_column = pa.Table.from_pydict( + { + "foo": ["a", None, "z"], + "bar": [19, None, 25], + }, + schema=pa.schema([ + pa.field("foo", pa.string(), nullable=True), + pa.field("bar", pa.int32(), nullable=True), + ]), + ) + + with catalog.create_table_transaction( + identifier=identifier, schema=pa_table.schema, properties={"format-version": str(format_version)} + ) as txn: + with txn.update_snapshot().fast_append() as snapshot_update: + for data_file in _dataframe_to_data_files(table_metadata=txn.table_metadata, df=pa_table, io=txn._table.io): + snapshot_update.append_data_file(data_file) + + with txn.update_schema() as schema_txn: + schema_txn.union_by_name(pa_table_with_column.schema) + + with txn.update_snapshot().fast_append() as snapshot_update: + for data_file in _dataframe_to_data_files( + table_metadata=txn.table_metadata, df=pa_table_with_column, io=txn._table.io + ): + snapshot_update.append_data_file(data_file) + + tbl = catalog.load_table(identifier=identifier) + assert tbl.format_version == format_version + assert len(tbl.scan().to_arrow()) == 6 + + @pytest.mark.parametrize( "catalog", [