Skip to content

Commit

Permalink
Remove initial_change when CreateTableTransaction apply table updat…
Browse files Browse the repository at this point in the history
…es on an empty metadata (apache#1219)

* make table metadata without validaiton

* update deletes test

* remove info

* add deprecation message

* revert lib version updates

* remove initial_changes usage in code

* move test to integration

* fix typo

* update error string
  • Loading branch information
HonahX authored and sungwy committed Dec 7, 2024
1 parent df8f516 commit 8526880
Show file tree
Hide file tree
Showing 6 changed files with 114 additions and 21 deletions.
2 changes: 1 addition & 1 deletion pyiceberg/catalog/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1011,4 +1011,4 @@ def _empty_table_metadata() -> TableMetadata:
Returns:
TableMetadata: An empty TableMetadata instance.
"""
return TableMetadataV1(location="", last_column_id=-1, schema=Schema())
return TableMetadataV1.model_construct(last_column_id=-1, schema=Schema())
10 changes: 5 additions & 5 deletions pyiceberg/table/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -703,22 +703,22 @@ def _initial_changes(self, table_metadata: TableMetadata) -> None:

schema: Schema = table_metadata.schema()
self._updates += (
AddSchemaUpdate(schema_=schema, last_column_id=schema.highest_field_id, initial_change=True),
AddSchemaUpdate(schema_=schema, last_column_id=schema.highest_field_id),
SetCurrentSchemaUpdate(schema_id=-1),
)

spec: PartitionSpec = table_metadata.spec()
if spec.is_unpartitioned():
self._updates += (AddPartitionSpecUpdate(spec=UNPARTITIONED_PARTITION_SPEC, initial_change=True),)
self._updates += (AddPartitionSpecUpdate(spec=UNPARTITIONED_PARTITION_SPEC),)
else:
self._updates += (AddPartitionSpecUpdate(spec=spec, initial_change=True),)
self._updates += (AddPartitionSpecUpdate(spec=spec),)
self._updates += (SetDefaultSpecUpdate(spec_id=-1),)

sort_order: Optional[SortOrder] = table_metadata.sort_order_by_id(table_metadata.default_sort_order_id)
if sort_order is None or sort_order.is_unsorted:
self._updates += (AddSortOrderUpdate(sort_order=UNSORTED_SORT_ORDER, initial_change=True),)
self._updates += (AddSortOrderUpdate(sort_order=UNSORTED_SORT_ORDER),)
else:
self._updates += (AddSortOrderUpdate(sort_order=sort_order, initial_change=True),)
self._updates += (AddSortOrderUpdate(sort_order=sort_order),)
self._updates += (SetDefaultSortOrderUpdate(sort_order_id=-1),)

self._updates += (
Expand Down
16 changes: 16 additions & 0 deletions pyiceberg/table/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,5 +587,21 @@ def parse_obj(data: Dict[str, Any]) -> TableMetadata:
else:
raise ValidationError(f"Unknown format version: {format_version}")

@staticmethod
def _construct_without_validation(table_metadata: TableMetadata) -> TableMetadata:
"""Construct table metadata from an existing table without performing validation.
This method is useful during a sequence of table updates when the model needs to be re-constructed but is not yet ready for validation.
"""
if table_metadata.format_version is None:
raise ValidationError(f"Missing format-version in TableMetadata: {table_metadata}")

if table_metadata.format_version == 1:
return TableMetadataV1.model_construct(**dict(table_metadata))
elif table_metadata.format_version == 2:
return TableMetadataV2.model_construct(**dict(table_metadata))
else:
raise ValidationError(f"Unknown format version: {table_metadata.format_version}")


TableMetadata = Annotated[Union[TableMetadataV1, TableMetadataV2], Field(discriminator="format_version")] # type: ignore
39 changes: 28 additions & 11 deletions pyiceberg/table/update/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

import uuid
from abc import ABC, abstractmethod
from copy import copy
from datetime import datetime
from functools import singledispatch
from typing import TYPE_CHECKING, Any, Dict, Generic, List, Literal, Optional, Tuple, TypeVar, Union
Expand All @@ -45,6 +44,7 @@
transform_dict_value_to_str,
)
from pyiceberg.utils.datetime import datetime_to_millis
from pyiceberg.utils.deprecated import deprecation_notice
from pyiceberg.utils.properties import property_as_int

if TYPE_CHECKING:
Expand Down Expand Up @@ -90,7 +90,13 @@ class AddSchemaUpdate(IcebergBaseModel):
# This field is required: https://github.com/apache/iceberg/pull/7445
last_column_id: int = Field(alias="last-column-id")

initial_change: bool = Field(default=False, exclude=True)
initial_change: bool = Field(
default=False,
exclude=True,
deprecated=deprecation_notice(
deprecated_in="0.8.0", removed_in="0.9.0", help_message="CreateTableTransaction can work without this field"
),
)


class SetCurrentSchemaUpdate(IcebergBaseModel):
Expand All @@ -104,7 +110,13 @@ class AddPartitionSpecUpdate(IcebergBaseModel):
action: Literal["add-spec"] = Field(default="add-spec")
spec: PartitionSpec

initial_change: bool = Field(default=False, exclude=True)
initial_change: bool = Field(
default=False,
exclude=True,
deprecated=deprecation_notice(
deprecated_in="0.8.0", removed_in="0.9.0", help_message="CreateTableTransaction can work without this field"
),
)


class SetDefaultSpecUpdate(IcebergBaseModel):
Expand All @@ -118,7 +130,13 @@ class AddSortOrderUpdate(IcebergBaseModel):
action: Literal["add-sort-order"] = Field(default="add-sort-order")
sort_order: SortOrder = Field(alias="sort-order")

initial_change: bool = Field(default=False, exclude=True)
initial_change: bool = Field(
default=False,
exclude=True,
deprecated=deprecation_notice(
deprecated_in="0.8.0", removed_in="0.9.0", help_message="CreateTableTransaction can work without this field"
),
)


class SetDefaultSortOrderUpdate(IcebergBaseModel):
Expand Down Expand Up @@ -267,11 +285,10 @@ def _(
elif update.format_version == base_metadata.format_version:
return base_metadata

updated_metadata_data = copy(base_metadata.model_dump())
updated_metadata_data["format-version"] = update.format_version
updated_metadata = base_metadata.model_copy(update={"format_version": update.format_version})

context.add_update(update)
return TableMetadataUtil.parse_obj(updated_metadata_data)
return TableMetadataUtil._construct_without_validation(updated_metadata)


@_apply_table_update.register(SetPropertiesUpdate)
Expand Down Expand Up @@ -306,7 +323,7 @@ def _(update: AddSchemaUpdate, base_metadata: TableMetadata, context: _TableMeta

metadata_updates: Dict[str, Any] = {
"last_column_id": update.last_column_id,
"schemas": [update.schema_] if update.initial_change else base_metadata.schemas + [update.schema_],
"schemas": base_metadata.schemas + [update.schema_],
}

context.add_update(update)
Expand Down Expand Up @@ -336,11 +353,11 @@ def _(update: SetCurrentSchemaUpdate, base_metadata: TableMetadata, context: _Ta
@_apply_table_update.register(AddPartitionSpecUpdate)
def _(update: AddPartitionSpecUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata:
for spec in base_metadata.partition_specs:
if spec.spec_id == update.spec.spec_id and not update.initial_change:
if spec.spec_id == update.spec.spec_id:
raise ValueError(f"Partition spec with id {spec.spec_id} already exists: {spec}")

metadata_updates: Dict[str, Any] = {
"partition_specs": [update.spec] if update.initial_change else base_metadata.partition_specs + [update.spec],
"partition_specs": base_metadata.partition_specs + [update.spec],
"last_partition_id": max(
max([field.field_id for field in update.spec.fields], default=0),
base_metadata.last_partition_id or PARTITION_FIELD_ID_START - 1,
Expand Down Expand Up @@ -448,7 +465,7 @@ def _(update: AddSortOrderUpdate, base_metadata: TableMetadata, context: _TableM
context.add_update(update)
return base_metadata.model_copy(
update={
"sort_orders": [update.sort_order] if update.initial_change else base_metadata.sort_orders + [update.sort_order],
"sort_orders": base_metadata.sort_orders + [update.sort_order],
}
)

Expand Down
9 changes: 6 additions & 3 deletions pyiceberg/utils/deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,17 @@ def new_func(*args: Any, **kwargs: Any) -> Any:
return decorator


def deprecation_notice(deprecated_in: str, removed_in: str, help_message: Optional[str]) -> str:
"""Return a deprecation notice."""
return f"Deprecated in {deprecated_in}, will be removed in {removed_in}. {help_message}"


def deprecation_message(deprecated_in: str, removed_in: str, help_message: Optional[str]) -> None:
"""Mark properties or behaviors as deprecated.
Adding this will result in a warning being emitted.
"""
message = f"Deprecated in {deprecated_in}, will be removed in {removed_in}. {help_message}"

_deprecation_warning(message)
_deprecation_warning(deprecation_notice(deprecated_in, removed_in, help_message))


def _deprecation_warning(message: str) -> None:
Expand Down
59 changes: 58 additions & 1 deletion tests/integration/test_writes/test_writes.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from pyiceberg.partitioning import PartitionField, PartitionSpec
from pyiceberg.schema import Schema
from pyiceberg.table import TableProperties
from pyiceberg.table.sorting import SortDirection, SortField, SortOrder
from pyiceberg.transforms import DayTransform, HourTransform, IdentityTransform
from pyiceberg.types import (
DateType,
Expand Down Expand Up @@ -738,7 +739,7 @@ def test_write_and_evolve(session_catalog: Catalog, format_version: int) -> None
def test_create_table_transaction(catalog: Catalog, format_version: int) -> None:
if format_version == 1 and isinstance(catalog, RestCatalog):
pytest.skip(
"There is a bug in the REST catalog (maybe server side) that prevents create and commit a staged version 1 table"
"There is a bug in the REST catalog image (https://github.com/apache/iceberg/issues/8756) that prevents create and commit a staged version 1 table"
)

identifier = f"default.arrow_create_table_transaction_{catalog.name}_{format_version}"
Expand Down Expand Up @@ -787,6 +788,62 @@ def test_create_table_transaction(catalog: Catalog, format_version: int) -> None
assert len(tbl.scan().to_arrow()) == 6


@pytest.mark.integration
@pytest.mark.parametrize("format_version", [1, 2])
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
def test_create_table_with_non_default_values(catalog: Catalog, table_schema_with_all_types: Schema, format_version: int) -> None:
if format_version == 1 and isinstance(catalog, RestCatalog):
pytest.skip(
"There is a bug in the REST catalog image (https://github.com/apache/iceberg/issues/8756) that prevents create and commit a staged version 1 table"
)

identifier = f"default.arrow_create_table_transaction_with_non_default_values_{catalog.name}_{format_version}"
identifier_ref = f"default.arrow_create_table_transaction_with_non_default_values_ref_{catalog.name}_{format_version}"

try:
catalog.drop_table(identifier=identifier)
except NoSuchTableError:
pass

try:
catalog.drop_table(identifier=identifier_ref)
except NoSuchTableError:
pass

iceberg_spec = PartitionSpec(*[
PartitionField(source_id=2, field_id=1001, transform=IdentityTransform(), name="integer_partition")
])

sort_order = SortOrder(*[SortField(source_id=2, transform=IdentityTransform(), direction=SortDirection.ASC)])

txn = catalog.create_table_transaction(
identifier=identifier,
schema=table_schema_with_all_types,
partition_spec=iceberg_spec,
sort_order=sort_order,
properties={"format-version": format_version},
)
txn.commit_transaction()

tbl = catalog.load_table(identifier)

tbl_ref = catalog.create_table(
identifier=identifier_ref,
schema=table_schema_with_all_types,
partition_spec=iceberg_spec,
sort_order=sort_order,
properties={"format-version": format_version},
)

assert tbl.format_version == tbl_ref.format_version
assert tbl.schema() == tbl_ref.schema()
assert tbl.schemas() == tbl_ref.schemas()
assert tbl.spec() == tbl_ref.spec()
assert tbl.specs() == tbl_ref.specs()
assert tbl.sort_order() == tbl_ref.sort_order()
assert tbl.sort_orders() == tbl_ref.sort_orders()


@pytest.mark.integration
@pytest.mark.parametrize("format_version", [1, 2])
def test_table_properties_int_value(
Expand Down

0 comments on commit 8526880

Please sign in to comment.