diff --git a/pyiceberg/catalog/__init__.py b/pyiceberg/catalog/__init__.py index db83658f1f..4e1f80910d 100644 --- a/pyiceberg/catalog/__init__.py +++ b/pyiceberg/catalog/__init__.py @@ -45,6 +45,8 @@ from pyiceberg.table import ( CommitTableRequest, CommitTableResponse, + CreateTableTransaction, + StagedTable, Table, TableMetadata, ) @@ -288,6 +290,62 @@ def __init__(self, name: str, **properties: str): def _load_file_io(self, properties: Properties = EMPTY_DICT, location: Optional[str] = None) -> FileIO: return load_file_io({**self.properties, **properties}, location) + def _create_staged_table( + self, + identifier: Union[str, Identifier], + schema: Union[Schema, "pa.Schema"], + location: Optional[str] = None, + partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC, + sort_order: SortOrder = UNSORTED_SORT_ORDER, + properties: Properties = EMPTY_DICT, + ) -> StagedTable: + """Create a table and return the table instance without committing the changes. + + Args: + identifier (str | Identifier): Table identifier. + schema (Schema): Table's schema. + location (str | None): Location for the table. Optional Argument. + partition_spec (PartitionSpec): PartitionSpec for the table. + sort_order (SortOrder): SortOrder for the table. + properties (Properties): Table properties that can be a string based dictionary. + + Returns: + StagedTable: the created staged table instance. + + Raises: + TableAlreadyExistsError: If a table with the name already exists. + """ + raise NotImplementedError + + def create_table_transaction( + self, + identifier: Union[str, Identifier], + schema: Union[Schema, "pa.Schema"], + location: Optional[str] = None, + partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC, + sort_order: SortOrder = UNSORTED_SORT_ORDER, + properties: Properties = EMPTY_DICT, + ) -> CreateTableTransaction: + """Create a CreateTableTransaction. + + Args: + identifier (str | Identifier): Table identifier. + schema (Schema): Table's schema. + location (str | None): Location for the table. Optional Argument. + partition_spec (PartitionSpec): PartitionSpec for the table. + sort_order (SortOrder): SortOrder for the table. + properties (Properties): Table properties that can be a string based dictionary. + + Returns: + CreateTableTransaction: the transaction that stages the table creation. + + Raises: + TableAlreadyExistsError: If a table with the name already exists.
+ """ + return CreateTableTransaction( + self._create_staged_table(identifier, schema, location, partition_spec, sort_order, properties) + ) + @abstractmethod def create_table( self, diff --git a/pyiceberg/catalog/glue.py b/pyiceberg/catalog/glue.py index 089a30ba61..06be06f397 100644 --- a/pyiceberg/catalog/glue.py +++ b/pyiceberg/catalog/glue.py @@ -64,7 +64,14 @@ from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionSpec from pyiceberg.schema import Schema, SchemaVisitor, visit from pyiceberg.serializers import FromInputFile -from pyiceberg.table import CommitTableRequest, CommitTableResponse, Table, update_table_metadata +from pyiceberg.table import ( + CommitTableRequest, + CommitTableResponse, + StagedTable, + Table, + construct_initial_table_metadata, + update_table_metadata, +) from pyiceberg.table.metadata import TableMetadata, new_table_metadata from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder from pyiceberg.typedef import EMPTY_DICT @@ -325,12 +332,39 @@ def _update_glue_table(self, database_name: str, table_name: str, table_input: T f"Cannot commit {database_name}.{table_name} because Glue detected concurrent update to table version {version_id}" ) from e - def _get_glue_table(self, database_name: str, table_name: str) -> TableTypeDef: + def _get_glue_table(self, database_name: str, table_name: str) -> Optional[TableTypeDef]: try: load_table_response = self.glue.get_table(DatabaseName=database_name, Name=table_name) return load_table_response["Table"] - except self.glue.exceptions.EntityNotFoundException as e: - raise NoSuchTableError(f"Table does not exist: {database_name}.{table_name}") from e + except self.glue.exceptions.EntityNotFoundException: + return None + + def _create_staged_table( + self, + identifier: Union[str, Identifier], + schema: Union[Schema, "pa.Schema"], + location: Optional[str] = None, + partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC, + sort_order: SortOrder = UNSORTED_SORT_ORDER, + properties: Properties = EMPTY_DICT, + ) -> StagedTable: + schema: Schema = self._convert_schema_if_needed(schema) # type: ignore + + database_name, table_name = self.identifier_to_database_and_table(identifier) + + location = self._resolve_table_location(location, database_name, table_name) + metadata_location = self._get_metadata_location(location=location) + metadata = new_table_metadata( + location=location, schema=schema, partition_spec=partition_spec, sort_order=sort_order, properties=properties + ) + io = load_file_io(properties=self.properties, location=metadata_location) + return StagedTable( + identifier=(self.name, database_name, table_name), + metadata=metadata, + metadata_location=metadata_location, + io=io, + catalog=self, + ) def create_table( self, @@ -412,45 +446,68 @@ def _commit_table(self, table_request: CommitTableRequest) -> CommitTableRespons database_name, table_name = self.identifier_to_database_and_table(identifier_tuple) current_glue_table = self._get_glue_table(database_name=database_name, table_name=table_name) - glue_table_version_id = current_glue_table.get("VersionId") - if not glue_table_version_id: - raise CommitFailedException(f"Cannot commit {database_name}.{table_name} because Glue table version id is missing") - current_table = self._convert_glue_to_iceberg(glue_table=current_glue_table) - base_metadata = current_table.metadata - - # Validate the update requirements - for requirement in table_request.requirements: - requirement.validate(base_metadata) - - updated_metadata = 
update_table_metadata(base_metadata, table_request.updates) - if updated_metadata == base_metadata: - # no changes, do nothing - return CommitTableResponse(metadata=base_metadata, metadata_location=current_table.metadata_location) - - # write new metadata - new_metadata_version = self._parse_metadata_version(current_table.metadata_location) + 1 - new_metadata_location = self._get_metadata_location(current_table.metadata.location, new_metadata_version) - self._write_metadata(updated_metadata, current_table.io, new_metadata_location) - - update_table_input = _construct_table_input( - table_name=table_name, - metadata_location=new_metadata_location, - properties=current_table.properties, - metadata=updated_metadata, - glue_table=current_glue_table, - prev_metadata_location=current_table.metadata_location, - ) + if current_glue_table is not None: + # Update the table + glue_table_version_id = current_glue_table.get("VersionId") + if not glue_table_version_id: + raise CommitFailedException( + f"Cannot commit {database_name}.{table_name} because Glue table version id is missing" + ) + current_table = self._convert_glue_to_iceberg(glue_table=current_glue_table) + base_metadata = current_table.metadata + + # Validate the update requirements + for requirement in table_request.requirements: + requirement.validate(base_metadata) + + updated_metadata = update_table_metadata(base_metadata, table_request.updates) + if updated_metadata == base_metadata: + # no changes, do nothing + return CommitTableResponse(metadata=base_metadata, metadata_location=current_table.metadata_location) + + # write new metadata + new_metadata_version = self._parse_metadata_version(current_table.metadata_location) + 1 + new_metadata_location = self._get_metadata_location(current_table.metadata.location, new_metadata_version) + self._write_metadata(updated_metadata, current_table.io, new_metadata_location) + + update_table_input = _construct_table_input( + table_name=table_name, + metadata_location=new_metadata_location, + properties=current_table.properties, + metadata=updated_metadata, + glue_table=current_glue_table, + prev_metadata_location=current_table.metadata_location, + ) - # Pass `version_id` to implement optimistic locking: it ensures updates are rejected if concurrent - # modifications occur. See more details at https://iceberg.apache.org/docs/latest/aws/#optimistic-locking - self._update_glue_table( - database_name=database_name, - table_name=table_name, - table_input=update_table_input, - version_id=glue_table_version_id, - ) + # Pass `version_id` to implement optimistic locking: it ensures updates are rejected if concurrent + # modifications occur. 
See more details at https://iceberg.apache.org/docs/latest/aws/#optimistic-locking + self._update_glue_table( + database_name=database_name, + table_name=table_name, + table_input=update_table_input, + version_id=glue_table_version_id, + ) + + return CommitTableResponse(metadata=updated_metadata, metadata_location=new_metadata_location) + else: + # Create the table + updated_metadata = construct_initial_table_metadata(table_request.updates) + new_metadata_version = 0 + new_metadata_location = self._get_metadata_location(updated_metadata.location, new_metadata_version) + self._write_metadata( + updated_metadata, self._load_file_io(updated_metadata.properties, new_metadata_location), new_metadata_location + ) + + create_table_input = _construct_table_input( + table_name=table_name, + metadata_location=new_metadata_location, + properties=updated_metadata.properties, + metadata=updated_metadata, + ) + + self._create_glue_table(database_name=database_name, table_name=table_name, table_input=create_table_input) - return CommitTableResponse(metadata=updated_metadata, metadata_location=new_metadata_location) + return CommitTableResponse(metadata=updated_metadata, metadata_location=new_metadata_location) def load_table(self, identifier: Union[str, Identifier]) -> Table: """Load the table's metadata and returns the table instance. diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 1a4183c914..7808af2e8d 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -24,7 +24,7 @@ from copy import copy from dataclasses import dataclass from enum import Enum -from functools import cached_property, singledispatch +from functools import cached_property, singledispatch, singledispatchmethod from itertools import chain from typing import ( TYPE_CHECKING, @@ -72,6 +72,7 @@ from pyiceberg.partitioning import ( INITIAL_PARTITION_SPEC_ID, PARTITION_FIELD_ID_START, + UNPARTITIONED_PARTITION_SPEC, IdentityTransform, PartitionField, PartitionSpec, @@ -93,6 +94,8 @@ SUPPORTED_TABLE_FORMAT_VERSION, TableMetadata, TableMetadataUtil, + TableMetadataV1, + TableMetadataV2, ) from pyiceberg.table.name_mapping import ( NameMapping, @@ -108,7 +111,7 @@ Summary, update_snapshot_summaries, ) -from pyiceberg.table.sorting import SortOrder +from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder from pyiceberg.transforms import TimeTransform, Transform, VoidTransform from pyiceberg.typedef import ( EMPTY_DICT, @@ -138,7 +141,6 @@ from pyiceberg.catalog import Catalog - ALWAYS_TRUE = AlwaysTrue() TABLE_ROOT_ID = -1 @@ -373,7 +375,45 @@ def commit_transaction(self) -> Table: return self._table +class CreateTableTransaction(Transaction): + @staticmethod + def create_changes(table_metadata: TableMetadata) -> Tuple[TableUpdate, ...]: + changes = [ + AssignUUIDUpdate(uuid=table_metadata.table_uuid), + UpgradeFormatVersionUpdate(format_version=table_metadata.format_version), + ] + + schema: Schema = table_metadata.schema() + changes.append(AddSchemaUpdate(schema_=schema, last_column_id=schema.highest_field_id)) + changes.append(SetCurrentSchemaUpdate(schema_id=-1)) + + spec: PartitionSpec = table_metadata.spec() + if spec.is_unpartitioned(): + changes.append(AddPartitionSpecUpdate(spec=UNPARTITIONED_PARTITION_SPEC)) + else: + changes.append(AddPartitionSpecUpdate(spec=spec)) + changes.append(SetDefaultSpecUpdate(spec_id=-1)) + + sort_order: Optional[SortOrder] = table_metadata.sort_order_by_id(table_metadata.default_sort_order_id) + if sort_order is None or sort_order.is_unsorted: 
+ changes.append(AddSortOrderUpdate(sort_order=UNSORTED_SORT_ORDER)) + else: + changes.append(AddSortOrderUpdate(sort_order=sort_order)) + changes.append(SetDefaultSortOrderUpdate(sort_order_id=-1)) + + changes.append(SetLocationUpdate(location=table_metadata.location)) + changes.append(SetPropertiesUpdate(updates=table_metadata.properties)) + + return tuple(changes) + + def __init__(self, table: StagedTable): + super().__init__(table, autocommit=False) + self._requirements = (AssertCreate(),) + self._updates = self.create_changes(table.metadata) + + class TableUpdateAction(Enum): + assign_uuid = "assign-uuid" upgrade_format_version = "upgrade-format-version" add_schema = "add-schema" set_current_schema = "set-current-schema" @@ -394,6 +434,11 @@ class TableUpdate(IcebergBaseModel): action: TableUpdateAction +class AssignUUIDUpdate(TableUpdate): + action: TableUpdateAction = TableUpdateAction.assign_uuid + uuid: uuid.UUID + + class UpgradeFormatVersionUpdate(TableUpdate): action: TableUpdateAction = TableUpdateAction.upgrade_format_version format_version: int = Field(alias="format-version") @@ -522,6 +567,15 @@ def _apply_table_update(update: TableUpdate, base_metadata: TableMetadata, conte raise NotImplementedError(f"Unsupported table update: {update}") +@_apply_table_update.register(AssignUUIDUpdate) +def _(update: AssignUUIDUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata: + if update.uuid == base_metadata.table_uuid: + return base_metadata + + context.add_update(update) + return base_metadata.model_copy(update={"table_uuid": update.uuid}) + + @_apply_table_update.register(UpgradeFormatVersionUpdate) def _(update: UpgradeFormatVersionUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata: if update.format_version > SUPPORTED_TABLE_FORMAT_VERSION: @@ -753,6 +807,143 @@ def update_table_metadata(base_metadata: TableMetadata, updates: Tuple[TableUpda return new_metadata.model_copy(deep=True) +class InitialTableMetadataConstructor: + table_uuid: uuid.UUID + format_version: int + schema: Schema + current_schema_id: int + spec: PartitionSpec + default_spec_id: int + sort_order: SortOrder + default_sort_order_id: int + location: str + properties: Dict[str, str] + + @singledispatchmethod + def apply_table_update(self, update: TableUpdate) -> None: + raise NotImplementedError(f"Table Update {update} should not be part of initial table metadata construction") + + @apply_table_update.register(AssignUUIDUpdate) + def _(self, update: AssignUUIDUpdate) -> None: + self.table_uuid = update.uuid + + @apply_table_update.register(UpgradeFormatVersionUpdate) + def _(self, update: UpgradeFormatVersionUpdate) -> None: + self.format_version = update.format_version + + @apply_table_update.register(AddSchemaUpdate) + def _(self, update: AddSchemaUpdate) -> None: + self.schema = update.schema_ + + @apply_table_update.register(SetCurrentSchemaUpdate) + def _(self, update: SetCurrentSchemaUpdate) -> None: + if update.schema_id == -1: + if self.schema is None: + raise ValueError("No schema has been added") + self.current_schema_id = self.schema.schema_id + else: + self.current_schema_id = update.schema_id + + @apply_table_update.register(AddPartitionSpecUpdate) + def _(self, update: AddPartitionSpecUpdate) -> None: + self.spec = update.spec + + @apply_table_update.register(SetDefaultSpecUpdate) + def _(self, update: SetDefaultSpecUpdate) -> None: + if update.spec_id == -1: + if self.spec is None: + raise ValueError("No partition spec 
has been added") + self.default_spec_id = self.spec.spec_id + else: + self.default_spec_id = update.spec_id + + @apply_table_update.register(AddSortOrderUpdate) + def _(self, update: AddSortOrderUpdate) -> None: + self.sort_order = update.sort_order + + @apply_table_update.register(SetDefaultSortOrderUpdate) + def _(self, update: SetDefaultSortOrderUpdate) -> None: + if update.sort_order_id == -1: + if self.sort_order is None: + raise ValueError("No sort order has been added") + self.default_sort_order_id = self.sort_order.order_id + else: + self.default_sort_order_id = update.sort_order_id + + @apply_table_update.register(SetLocationUpdate) + def _(self, update: SetLocationUpdate) -> None: + self.location = update.location + + @apply_table_update.register(SetPropertiesUpdate) + def _(self, update: SetPropertiesUpdate) -> None: + self.properties = update.updates + + def ready_to_construct(self) -> bool: + """Return true if all fields are set. + + Note fields may not be able to get from getattr if not set + """ + return all( + hasattr(self, field) and getattr(self, field) is not None + for field in [ + "table_uuid", + "format_version", + "schema", + "current_schema_id", + "spec", + "default_spec_id", + "sort_order", + "default_sort_order_id", + "location", + "properties", + ] + ) + + def construct_initial_metadata(self) -> TableMetadata: + if self.format_version == 1: + return TableMetadataV1( + location=self.location, + last_column_id=self.schema.highest_field_id, + current_schema_id=self.current_schema_id, + schema=self.schema, + partition_spec=[field.model_dump() for field in self.spec.fields], + partition_specs=[self.spec], + default_spec_id=self.default_spec_id, + sort_order=[self.sort_order], + default_sort_order_id=self.default_sort_order_id, + properties=self.properties, + last_partition_id=self.spec.last_assigned_field_id, + table_uuid=self.table_uuid, + ) + + return TableMetadataV2( + location=self.location, + schemas=[self.schema], + last_column_id=self.schema.highest_field_id, + current_schema_id=self.current_schema_id, + partition_specs=[self.spec], + default_spec_id=self.default_spec_id, + sort_orders=[self.sort_order], + default_sort_order_id=self.default_sort_order_id, + properties=self.properties, + last_partition_id=self.spec.last_assigned_field_id, + table_uuid=self.table_uuid, + ) + + +def construct_initial_table_metadata(updates: Tuple[TableUpdate, ...]) -> TableMetadata: + initial_create_update_index = 0 + initial_metadata_constructor = InitialTableMetadataConstructor() + for update in updates: + initial_metadata_constructor.apply_table_update(update) + initial_create_update_index += 1 + if initial_metadata_constructor.ready_to_construct(): + break + base_metadata = initial_metadata_constructor.construct_initial_metadata() + updated_metadata = update_table_metadata(base_metadata, updates[initial_create_update_index:]) + return updated_metadata + + class TableRequirement(IcebergBaseModel): type: str @@ -1211,6 +1402,25 @@ def from_metadata(cls, metadata_location: str, properties: Properties = EMPTY_DI ) +class StagedTable(Table): + def refresh(self) -> Table: + raise ValueError("Cannot refresh a staged table") + + def scan( + self, + row_filter: Union[str, BooleanExpression] = ALWAYS_TRUE, + selected_fields: Tuple[str, ...] 
= ("*",), + case_sensitive: bool = True, + snapshot_id: Optional[int] = None, + options: Properties = EMPTY_DICT, + limit: Optional[int] = None, + ) -> DataScan: + raise ValueError("Cannot scan a staged table") + + def to_daft(self) -> daft.DataFrame: + raise ValueError("Cannot convert a staged table to a Daft DataFrame") + + def _parse_row_filter(expr: Union[str, BooleanExpression]) -> BooleanExpression: """Accept an expression in the form of a BooleanExpression or a string. @@ -1654,7 +1864,8 @@ def union_by_name(self, new_schema: Union[Schema, "pa.Schema"]) -> UpdateSchema: visit_with_partner( Catalog._convert_schema_if_needed(new_schema), -1, - UnionByNameVisitor(update_schema=self, existing_schema=self._schema, case_sensitive=self._case_sensitive), # type: ignore + UnionByNameVisitor(update_schema=self, existing_schema=self._schema, case_sensitive=self._case_sensitive), + # type: ignore PartnerIdByNameAccessor(partner_schema=self._schema, case_sensitive=self._case_sensitive), ) return self diff --git a/pyiceberg/table/metadata.py b/pyiceberg/table/metadata.py index 931b0cfe0a..dd56776189 100644 --- a/pyiceberg/table/metadata.py +++ b/pyiceberg/table/metadata.py @@ -449,6 +449,8 @@ def construct_refs(cls, table_metadata: TableMetadata) -> TableMetadata: TableMetadata = Annotated[Union[TableMetadataV1, TableMetadataV2], Field(discriminator="format_version")] +EMPTY_METADATA = TableMetadataV1(location="", schema=Schema(), last_column_id=-1, partition_spec=[]) + def new_table_metadata( schema: Schema, diff --git a/tests/catalog/test_glue.py b/tests/catalog/test_glue.py index 6e0196c1a2..5f633c30f2 100644 --- a/tests/catalog/test_glue.py +++ b/tests/catalog/test_glue.py @@ -671,3 +671,31 @@ def test_commit_table_properties( updated_table_metadata = table.metadata assert test_catalog._parse_metadata_version(table.metadata_location) == 1 assert updated_table_metadata.properties == {"test_a": "test_aa", "test_c": "test_c"} + + +@mock_aws +def test_create_table_transaction( + _glue: boto3.client, + _bucket_initialize: None, + moto_endpoint_url: str, + table_schema_nested: Schema, + database_name: str, + table_name: str, +) -> None: + catalog_name = "glue" + identifier = (database_name, table_name) + test_catalog = GlueCatalog(catalog_name, **{"s3.endpoint": moto_endpoint_url, "warehouse": f"s3://{BUCKET_NAME}"}) + test_catalog.create_namespace(namespace=database_name) + + with test_catalog.create_table_transaction(identifier, table_schema_nested) as txn: + with txn.update_schema() as update_schema: + update_schema.add_column(path="b", field_type=IntegerType()) + + txn.set_properties(test_a="test_aa", test_b="test_b", test_c="test_c") + + table = test_catalog.load_table(identifier) + + assert TABLE_METADATA_LOCATION_REGEX.match(table.metadata_location) + assert test_catalog._parse_metadata_version(table.metadata_location) == 0 + assert table.schema().find_field("b").field_type == IntegerType() + assert table.properties == {"test_a": "test_aa", "test_b": "test_b", "test_c": "test_c"}