From 15348695c4c6eb3c8b9c2b628105874fb0f054fd Mon Sep 17 00:00:00 2001 From: HonahX Date: Tue, 16 Apr 2024 23:46:34 -0700 Subject: [PATCH] support createTableTransaction in hive --- pyiceberg/catalog/__init__.py | 2 +- pyiceberg/catalog/hive.py | 91 +++++++++++++------- tests/integration/test_writes/test_writes.py | 9 +- 3 files changed, 68 insertions(+), 34 deletions(-) diff --git a/pyiceberg/catalog/__init__.py b/pyiceberg/catalog/__init__.py index f104aa94da..f8de9a58d1 100644 --- a/pyiceberg/catalog/__init__.py +++ b/pyiceberg/catalog/__init__.py @@ -719,7 +719,7 @@ def _create_staged_table( metadata = new_table_metadata( location=location, schema=schema, partition_spec=partition_spec, sort_order=sort_order, properties=properties ) - io = load_file_io(properties=self.properties, location=metadata_location) + io = self._load_file_io(properties=properties, location=metadata_location) return StagedTable( identifier=(self.name, database_name, table_name), metadata=metadata, diff --git a/pyiceberg/catalog/hive.py b/pyiceberg/catalog/hive.py index 804b1105cc..bcb99d6a60 100644 --- a/pyiceberg/catalog/hive.py +++ b/pyiceberg/catalog/hive.py @@ -74,8 +74,15 @@ from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionSpec from pyiceberg.schema import Schema, SchemaVisitor, visit from pyiceberg.serializers import FromInputFile -from pyiceberg.table import CommitTableRequest, CommitTableResponse, PropertyUtil, Table, TableProperties, update_table_metadata -from pyiceberg.table.metadata import new_table_metadata +from pyiceberg.table import ( + CommitTableRequest, + CommitTableResponse, + PropertyUtil, + StagedTable, + Table, + TableProperties, + update_table_metadata, +) from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder from pyiceberg.typedef import EMPTY_DICT, Identifier, Properties from pyiceberg.types import ( @@ -266,6 +273,26 @@ def _convert_hive_into_iceberg(self, table: HiveTable, io: FileIO) -> Table: catalog=self, ) + def _convert_iceberg_into_hive(self, table: Table) -> HiveTable: + identifier_tuple = self.identifier_to_tuple_without_catalog(table.identifier) + database_name, table_name = self.identifier_to_database_and_table(identifier_tuple, NoSuchTableError) + current_time_millis = int(time.time() * 1000) + + return HiveTable( + dbName=database_name, + tableName=table_name, + owner=table.properties[OWNER] if table.properties and OWNER in table.properties else getpass.getuser(), + createTime=current_time_millis // 1000, + lastAccessTime=current_time_millis // 1000, + sd=_construct_hive_storage_descriptor( + table.schema(), + table.location(), + PropertyUtil.property_as_bool(self.properties, HIVE2_COMPATIBLE, HIVE2_COMPATIBLE_DEFAULT), + ), + tableType=EXTERNAL_TABLE, + parameters=_construct_parameters(table.metadata_location), + ) + def create_table( self, identifier: Union[str, Identifier], @@ -292,37 +319,19 @@ def create_table( AlreadyExistsError: If a table with the name already exists. ValueError: If the identifier is invalid. """ - schema: Schema = self._convert_schema_if_needed(schema) # type: ignore - properties = {**DEFAULT_PROPERTIES, **properties} - database_name, table_name = self.identifier_to_database_and_table(identifier) - current_time_millis = int(time.time() * 1000) - - location = self._resolve_table_location(location, database_name, table_name) - - metadata_location = self._get_metadata_location(location=location) - metadata = new_table_metadata( - location=location, + staged_table = self._create_staged_table( + identifier=identifier, schema=schema, + location=location, partition_spec=partition_spec, sort_order=sort_order, properties=properties, ) - io = load_file_io({**self.properties, **properties}, location=location) - self._write_metadata(metadata, io, metadata_location) + database_name, table_name = self.identifier_to_database_and_table(identifier) - tbl = HiveTable( - dbName=database_name, - tableName=table_name, - owner=properties[OWNER] if properties and OWNER in properties else getpass.getuser(), - createTime=current_time_millis // 1000, - lastAccessTime=current_time_millis // 1000, - sd=_construct_hive_storage_descriptor( - schema, location, PropertyUtil.property_as_bool(self.properties, HIVE2_COMPATIBLE, HIVE2_COMPATIBLE_DEFAULT) - ), - tableType=EXTERNAL_TABLE, - parameters=_construct_parameters(metadata_location), - ) + self._write_metadata(staged_table.metadata, staged_table.io, staged_table.metadata_location) + tbl = self._convert_iceberg_into_hive(staged_table) try: with self._client as open_client: open_client.create_table(tbl) @@ -330,7 +339,7 @@ def create_table( except AlreadyExistsException as e: raise TableAlreadyExistsError(f"Table {database_name}.{table_name} already exists") from e - return self._convert_hive_into_iceberg(hive_table, io) + return self._convert_hive_into_iceberg(hive_table, staged_table.io) def register_table(self, identifier: Union[str, Identifier], metadata_location: str) -> Table: """Register a new table using existing metadata. @@ -404,8 +413,32 @@ def _commit_table(self, table_request: CommitTableRequest) -> CommitTableRespons metadata_location=new_metadata_location, previous_metadata_location=current_table.metadata_location ) open_client.alter_table(dbname=database_name, tbl_name=table_name, new_tbl=hive_table) - except NoSuchObjectException as e: - raise NoSuchTableError(f"Table does not exist: {table_name}") from e + except NoSuchObjectException: + updated_metadata = update_table_metadata( + base_metadata=self._empty_table_metadata(), updates=table_request.updates, enforce_validation=True + ) + new_metadata_version = 0 + new_metadata_location = self._get_metadata_location(updated_metadata.location, new_metadata_version) + io = self._load_file_io(updated_metadata.properties, new_metadata_location) + self._write_metadata( + updated_metadata, + io, + new_metadata_location, + ) + + tbl = self._convert_iceberg_into_hive( + StagedTable( + identifier=(self.name, database_name, table_name), + metadata=updated_metadata, + metadata_location=new_metadata_location, + io=io, + catalog=self, + ) + ) + try: + open_client.create_table(tbl) + except AlreadyExistsException as e: + raise TableAlreadyExistsError(f"Table {database_name}.{table_name} already exists") from e finally: open_client.unlock(UnlockRequest(lockid=lock.lockid)) diff --git a/tests/integration/test_writes/test_writes.py b/tests/integration/test_writes/test_writes.py index 775a6f9d42..951932ecef 100644 --- a/tests/integration/test_writes/test_writes.py +++ b/tests/integration/test_writes/test_writes.py @@ -592,7 +592,8 @@ def test_write_and_evolve(session_catalog: Catalog, format_version: int) -> None @pytest.mark.integration @pytest.mark.parametrize("format_version", [2]) -def test_create_table_transaction(session_catalog: Catalog, format_version: int) -> None: +@pytest.mark.parametrize('catalog', [pytest.lazy_fixture('session_catalog_hive'), pytest.lazy_fixture('session_catalog')]) +def test_create_table_transaction(catalog: Catalog, format_version: int) -> None: if format_version == 1: pytest.skip( "There is a bug in the REST catalog (maybe server side) that prevents create and commit a staged version 1 table" @@ -601,7 +602,7 @@ def test_create_table_transaction(session_catalog: Catalog, format_version: int) identifier = f"default.arrow_create_table_transaction{format_version}" try: - session_catalog.drop_table(identifier=identifier) + catalog.drop_table(identifier=identifier) except NoSuchTableError: pass @@ -623,7 +624,7 @@ def test_create_table_transaction(session_catalog: Catalog, format_version: int) ]), ) - with session_catalog.create_table_transaction( + with catalog.create_table_transaction( identifier=identifier, schema=pa_table.schema, properties={"format-version": str(format_version)} ) as txn: with txn.update_snapshot().fast_append() as snapshot_update: @@ -639,7 +640,7 @@ def test_create_table_transaction(session_catalog: Catalog, format_version: int) ): snapshot_update.append_data_file(data_file) - tbl = session_catalog.load_table(identifier=identifier) + tbl = catalog.load_table(identifier=identifier) assert tbl.format_version == format_version assert len(tbl.scan().to_arrow()) == 6