From 952d7c0c47593f25913f67aa6155817db9ea1ead Mon Sep 17 00:00:00 2001 From: Adrian Qin <147659252+jqin61@users.noreply.github.com> Date: Thu, 19 Dec 2024 07:53:32 -0500 Subject: [PATCH] Add Support for Dynamic Overwrite (#931) --- mkdocs/docs/api.md | 121 ++++++ pyiceberg/expressions/literals.py | 4 + pyiceberg/io/pyarrow.py | 1 - pyiceberg/table/__init__.py | 135 +++++- .../test_writes/test_partitioned_writes.py | 401 +++++++++++++++++- 5 files changed, 633 insertions(+), 29 deletions(-) diff --git a/mkdocs/docs/api.md b/mkdocs/docs/api.md index 250f7ad72b..7aa4159016 100644 --- a/mkdocs/docs/api.md +++ b/mkdocs/docs/api.md @@ -353,6 +353,127 @@ lat: [[52.371807,37.773972,53.11254],[53.21917]] long: [[4.896029,-122.431297,6.0989],[6.56667]] ``` +### Partial overwrites + +When using the `overwrite` API, you can use an `overwrite_filter` to delete data that matches the filter before appending new data into the table. + +For example, with an iceberg table created as: + +```python +from pyiceberg.catalog import load_catalog + +catalog = load_catalog("default") + +from pyiceberg.schema import Schema +from pyiceberg.types import NestedField, StringType, DoubleType + +schema = Schema( + NestedField(1, "city", StringType(), required=False), + NestedField(2, "lat", DoubleType(), required=False), + NestedField(3, "long", DoubleType(), required=False), +) + +tbl = catalog.create_table("default.cities", schema=schema) +``` + +And with initial data populating the table: + +```python +import pyarrow as pa +df = pa.Table.from_pylist( + [ + {"city": "Amsterdam", "lat": 52.371807, "long": 4.896029}, + {"city": "San Francisco", "lat": 37.773972, "long": -122.431297}, + {"city": "Drachten", "lat": 53.11254, "long": 6.0989}, + {"city": "Paris", "lat": 48.864716, "long": 2.349014}, + ], +) +tbl.append(df) +``` + +You can overwrite the record of `Paris` with a record of `New York`: + +```python +from pyiceberg.expressions import EqualTo +df = pa.Table.from_pylist( + [ + {"city": "New York", "lat": 40.7128, "long": 74.0060}, + ] +) +tbl.overwrite(df, overwrite_filter=EqualTo('city', "Paris")) +``` + +This produces the following result with `tbl.scan().to_arrow()`: + +```python +pyarrow.Table +city: large_string +lat: double +long: double +---- +city: [["New York"],["Amsterdam","San Francisco","Drachten"]] +lat: [[40.7128],[52.371807,37.773972,53.11254]] +long: [[74.006],[4.896029,-122.431297,6.0989]] +``` + +If the PyIceberg table is partitioned, you can use `tbl.dynamic_partition_overwrite(df)` to replace the existing partitions with new ones provided in the dataframe. The partitions to be replaced are detected automatically from the provided arrow table. 
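+
+Conceptually, `dynamic_partition_overwrite` collects the distinct partition values present in the dataframe, deletes exactly those partitions, and then appends the new data, all within a single transaction (one delete snapshot followed by one append snapshot). Below is a minimal, hand-rolled sketch of that behavior for a hypothetical table `tbl` identity-partitioned on `city` and a replacement dataframe `df`; the actual implementation builds a single combined delete filter and also handles null partition values with `IsNull`:
+
+```python
+from pyiceberg.expressions import AlwaysFalse, EqualTo, Or
+
+# Build one filter that matches every partition value present in the dataframe.
+delete_filter = AlwaysFalse()
+for city in set(df["city"].to_pylist()):
+    delete_filter = Or(delete_filter, EqualTo("city", city))
+
+tbl.delete(delete_filter)  # drop the affected partitions
+tbl.append(df)             # write the replacement data
+```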
+For example, with an iceberg table with a partition specified on `"city"` field: + +```python +from pyiceberg.schema import Schema +from pyiceberg.types import DoubleType, NestedField, StringType + +schema = Schema( + NestedField(1, "city", StringType(), required=False), + NestedField(2, "lat", DoubleType(), required=False), + NestedField(3, "long", DoubleType(), required=False), +) + +tbl = catalog.create_table( + "default.cities", + schema=schema, + partition_spec=PartitionSpec(PartitionField(source_id=1, field_id=1001, transform=IdentityTransform(), name="city_identity")) +) +``` + +And we want to overwrite the data for the partition of `"Paris"`: + +```python +import pyarrow as pa + +df = pa.Table.from_pylist( + [ + {"city": "Amsterdam", "lat": 52.371807, "long": 4.896029}, + {"city": "San Francisco", "lat": 37.773972, "long": -122.431297}, + {"city": "Drachten", "lat": 53.11254, "long": 6.0989}, + {"city": "Paris", "lat": -48.864716, "long": -2.349014}, + ], +) +tbl.append(df) +``` + +Then we can call `dynamic_partition_overwrite` with this arrow table: + +```python +df_corrected = pa.Table.from_pylist([ + {"city": "Paris", "lat": 48.864716, "long": 2.349014} +]) +tbl.dynamic_partition_overwrite(df_corrected) +``` + +This produces the following result with `tbl.scan().to_arrow()`: + +```python +pyarrow.Table +city: large_string +lat: double +long: double +---- +city: [["Paris"],["Amsterdam"],["Drachten"],["San Francisco"]] +lat: [[48.864716],[52.371807],[53.11254],[37.773972]] +long: [[2.349014],[4.896029],[6.0989],[-122.431297]] +``` + ## Inspecting tables To explore the table metadata, tables can be inspected. diff --git a/pyiceberg/expressions/literals.py b/pyiceberg/expressions/literals.py index d9f66ae24a..d1c170d0dd 100644 --- a/pyiceberg/expressions/literals.py +++ b/pyiceberg/expressions/literals.py @@ -311,6 +311,10 @@ def _(self, _: TimeType) -> Literal[int]: def _(self, _: TimestampType) -> Literal[int]: return TimestampLiteral(self.value) + @to.register(TimestamptzType) + def _(self, _: TimestamptzType) -> Literal[int]: + return TimestampLiteral(self.value) + @to.register(DecimalType) def _(self, type_var: DecimalType) -> Literal[Decimal]: unscaled = Decimal(self.value) diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index 7956a83242..9847ec5a1c 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -2519,7 +2519,6 @@ def _check_pyarrow_schema_compatible( raise ValueError( f"PyArrow table contains more columns: {', '.join(sorted(additional_names))}. Update the schema first (hint, use union_by_name)." ) from e - _check_schema_compatible(requested_schema, provided_schema) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 766ffba685..164d347796 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -44,10 +44,14 @@ import pyiceberg.expressions.parser as parser from pyiceberg.expressions import ( + AlwaysFalse, AlwaysTrue, And, BooleanExpression, EqualTo, + IsNull, + Or, + Reference, ) from pyiceberg.expressions.visitors import ( _InclusiveMetricsEvaluator, @@ -117,6 +121,7 @@ _OverwriteFiles, ) from pyiceberg.table.update.spec import UpdateSpec +from pyiceberg.transforms import IdentityTransform from pyiceberg.typedef import ( EMPTY_DICT, IcebergBaseModel, @@ -344,6 +349,48 @@ def _set_ref_snapshot( return updates, requirements + def _build_partition_predicate(self, partition_records: Set[Record]) -> BooleanExpression: + """Build a filter predicate matching any of the input partition records. 
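+
+        For example (illustrative only), for a spec with identity partitions on ``city`` and
+        ``number``, the partition records ``{Record("Paris", 1), Record(None, 2)}`` produce a
+        predicate equivalent to::
+
+            Or(
+                And(EqualTo("city", "Paris"), EqualTo("number", 1)),
+                And(IsNull("city"), EqualTo("number", 2)),
+            )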
+ + Args: + partition_records: A set of partition records to match + Returns: + A predicate matching any of the input partition records. + """ + partition_spec = self.table_metadata.spec() + schema = self.table_metadata.schema() + partition_fields = [schema.find_field(field.source_id).name for field in partition_spec.fields] + + expr: BooleanExpression = AlwaysFalse() + for partition_record in partition_records: + match_partition_expression: BooleanExpression = AlwaysTrue() + + for pos, partition_field in enumerate(partition_fields): + predicate = ( + EqualTo(Reference(partition_field), partition_record[pos]) + if partition_record[pos] is not None + else IsNull(Reference(partition_field)) + ) + match_partition_expression = And(match_partition_expression, predicate) + expr = Or(expr, match_partition_expression) + return expr + + def _append_snapshot_producer(self, snapshot_properties: Dict[str, str]) -> _FastAppendFiles: + """Determine the append type based on table properties. + + Args: + snapshot_properties: Custom properties to be added to the snapshot summary + Returns: + Either a fast-append or a merge-append snapshot producer. + """ + manifest_merge_enabled = property_as_bool( + self.table_metadata.properties, + TableProperties.MANIFEST_MERGE_ENABLED, + TableProperties.MANIFEST_MERGE_ENABLED_DEFAULT, + ) + update_snapshot = self.update_snapshot(snapshot_properties=snapshot_properties) + return update_snapshot.merge_append() if manifest_merge_enabled else update_snapshot.fast_append() + def update_schema(self, allow_incompatible_changes: bool = False, case_sensitive: bool = True) -> UpdateSchema: """Create a new UpdateSchema to alter the columns of this table. @@ -398,15 +445,7 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) self.table_metadata.schema(), provided_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us ) - manifest_merge_enabled = property_as_bool( - self.table_metadata.properties, - TableProperties.MANIFEST_MERGE_ENABLED, - TableProperties.MANIFEST_MERGE_ENABLED_DEFAULT, - ) - update_snapshot = self.update_snapshot(snapshot_properties=snapshot_properties) - append_method = update_snapshot.merge_append if manifest_merge_enabled else update_snapshot.fast_append - - with append_method() as append_files: + with self._append_snapshot_producer(snapshot_properties) as append_files: # skip writing data files if the dataframe is empty if df.shape[0] > 0: data_files = _dataframe_to_data_files( @@ -415,6 +454,62 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) for data_file in data_files: append_files.append_data_file(data_file) + def dynamic_partition_overwrite(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None: + """ + Shorthand for overwriting existing partitions with a PyArrow table. + + The function detects partition values in the provided arrow table using the current + partition spec, and deletes existing partitions matching these values. Finally, the + data in the table is appended to the table. 
+ + Args: + df: The Arrow dataframe that will be used to overwrite the table + snapshot_properties: Custom properties to be added to the snapshot summary + """ + try: + import pyarrow as pa + except ModuleNotFoundError as e: + raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e + + from pyiceberg.io.pyarrow import _check_pyarrow_schema_compatible, _dataframe_to_data_files + + if not isinstance(df, pa.Table): + raise ValueError(f"Expected PyArrow table, got: {df}") + + if self.table_metadata.spec().is_unpartitioned(): + raise ValueError("Cannot apply dynamic overwrite on an unpartitioned table.") + + for field in self.table_metadata.spec().fields: + if not isinstance(field.transform, IdentityTransform): + raise ValueError( + f"For now dynamic overwrite does not support a table with non-identity-transform field in the latest partition spec: {field}" + ) + + downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False + _check_pyarrow_schema_compatible( + self.table_metadata.schema(), provided_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us + ) + + # If dataframe does not have data, there is no need to overwrite + if df.shape[0] == 0: + return + + append_snapshot_commit_uuid = uuid.uuid4() + data_files: List[DataFile] = list( + _dataframe_to_data_files( + table_metadata=self._table.metadata, write_uuid=append_snapshot_commit_uuid, df=df, io=self._table.io + ) + ) + + partitions_to_overwrite = {data_file.partition for data_file in data_files} + delete_filter = self._build_partition_predicate(partition_records=partitions_to_overwrite) + self.delete(delete_filter=delete_filter, snapshot_properties=snapshot_properties) + + with self._append_snapshot_producer(snapshot_properties) as append_files: + append_files.commit_uuid = append_snapshot_commit_uuid + for data_file in data_files: + append_files.append_data_file(data_file) + def overwrite( self, df: pa.Table, @@ -461,14 +556,14 @@ def overwrite( self.delete(delete_filter=overwrite_filter, case_sensitive=case_sensitive, snapshot_properties=snapshot_properties) - with self.update_snapshot(snapshot_properties=snapshot_properties).fast_append() as update_snapshot: + with self._append_snapshot_producer(snapshot_properties) as append_files: # skip writing data files if the dataframe is empty if df.shape[0] > 0: data_files = _dataframe_to_data_files( - table_metadata=self.table_metadata, write_uuid=update_snapshot.commit_uuid, df=df, io=self._table.io + table_metadata=self.table_metadata, write_uuid=append_files.commit_uuid, df=df, io=self._table.io ) for data_file in data_files: - update_snapshot.append_data_file(data_file) + append_files.append_data_file(data_file) def delete( self, @@ -552,9 +647,8 @@ def delete( )) if len(replaced_files) > 0: - with self.update_snapshot(snapshot_properties=snapshot_properties).overwrite( - commit_uuid=commit_uuid - ) as overwrite_snapshot: + with self.update_snapshot(snapshot_properties=snapshot_properties).overwrite() as overwrite_snapshot: + overwrite_snapshot.commit_uuid = commit_uuid for original_data_file, replaced_data_files in replaced_files: overwrite_snapshot.delete_data_file(original_data_file) for replaced_data_file in replaced_data_files: @@ -989,6 +1083,17 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) with self.transaction() as tx: tx.append(df=df, snapshot_properties=snapshot_properties) + def dynamic_partition_overwrite(self, df: pa.Table, snapshot_properties: Dict[str, str] = 
EMPTY_DICT) -> None: + """Shorthand for dynamic overwriting the table with a PyArrow table. + + Old partitions are auto detected and replaced with data files created for input arrow table. + Args: + df: The Arrow dataframe that will be used to overwrite the table + snapshot_properties: Custom properties to be added to the snapshot summary + """ + with self.transaction() as tx: + tx.dynamic_partition_overwrite(df=df, snapshot_properties=snapshot_properties) + def overwrite( self, df: pa.Table, diff --git a/tests/integration/test_writes/test_partitioned_writes.py b/tests/integration/test_writes/test_partitioned_writes.py index b199f00210..b92c338931 100644 --- a/tests/integration/test_writes/test_partitioned_writes.py +++ b/tests/integration/test_writes/test_partitioned_writes.py @@ -38,6 +38,9 @@ TruncateTransform, YearTransform, ) +from pyiceberg.types import ( + StringType, +) from utils import TABLE_SCHEMA, _create_table @@ -181,6 +184,61 @@ def test_query_filter_appended_null_partitioned( assert len(rows) == 6 +@pytest.mark.integration +@pytest.mark.parametrize( + "part_col", + [ + "int", + "bool", + "string", + "string_long", + "long", + "float", + "double", + "date", + "timestamp", + "binary", + "timestamptz", + ], +) +@pytest.mark.parametrize( + "format_version", + [1, 2], +) +def test_query_filter_dynamic_partition_overwrite_null_partitioned( + session_catalog: Catalog, spark: SparkSession, arrow_table_with_null: pa.Table, part_col: str, format_version: int +) -> None: + # Given + identifier = f"default.arrow_table_v{format_version}_appended_with_null_partitioned_on_col_{part_col}" + nested_field = TABLE_SCHEMA.find_field(part_col) + partition_spec = PartitionSpec( + PartitionField(source_id=nested_field.field_id, field_id=1001, transform=IdentityTransform(), name=part_col) + ) + + # When + tbl = _create_table( + session_catalog=session_catalog, + identifier=identifier, + properties={"format-version": str(format_version)}, + data=[], + partition_spec=partition_spec, + ) + # Append with arrow_table_1 with lines [A,B,C] and then arrow_table_2 with lines[A,B,C,A,B,C] + tbl.append(arrow_table_with_null) + tbl.append(pa.concat_tables([arrow_table_with_null, arrow_table_with_null])) + tbl.dynamic_partition_overwrite(arrow_table_with_null) + tbl.dynamic_partition_overwrite(arrow_table_with_null.slice(0, 2)) + # Then + assert tbl.format_version == format_version, f"Expected v{format_version}, got: v{tbl.format_version}" + df = spark.table(identifier) + for col in arrow_table_with_null.column_names: + assert df.where(f"{col} is not null").count() == 2, f"Expected 2 non-null rows for {col}," + assert df.where(f"{col} is null").count() == 1, f"Expected 1 null rows for {col}," + # expecting 3 files: + rows = spark.sql(f"select partition from {identifier}.files").collect() + assert len(rows) == 3 + + @pytest.mark.integration @pytest.mark.parametrize( "part_col", ["int", "bool", "string", "string_long", "long", "float", "double", "date", "timestamptz", "timestamp", "binary"] @@ -222,6 +280,127 @@ def test_query_filter_v1_v2_append_null( assert df.where(f"{col} is null").count() == 2, f"Expected 2 null rows for {col}" +@pytest.mark.integration +@pytest.mark.parametrize( + "spec", + [ + (PartitionSpec(PartitionField(source_id=4, field_id=1001, transform=BucketTransform(2), name="int_bucket"))), + (PartitionSpec(PartitionField(source_id=5, field_id=1001, transform=BucketTransform(2), name="long_bucket"))), + (PartitionSpec(PartitionField(source_id=10, field_id=1001, transform=BucketTransform(2), 
name="date_bucket"))), + (PartitionSpec(PartitionField(source_id=8, field_id=1001, transform=BucketTransform(2), name="timestamp_bucket"))), + (PartitionSpec(PartitionField(source_id=9, field_id=1001, transform=BucketTransform(2), name="timestamptz_bucket"))), + (PartitionSpec(PartitionField(source_id=2, field_id=1001, transform=BucketTransform(2), name="string_bucket"))), + (PartitionSpec(PartitionField(source_id=12, field_id=1001, transform=BucketTransform(2), name="fixed_bucket"))), + (PartitionSpec(PartitionField(source_id=11, field_id=1001, transform=BucketTransform(2), name="binary_bucket"))), + (PartitionSpec(PartitionField(source_id=4, field_id=1001, transform=TruncateTransform(2), name="int_trunc"))), + (PartitionSpec(PartitionField(source_id=5, field_id=1001, transform=TruncateTransform(2), name="long_trunc"))), + (PartitionSpec(PartitionField(source_id=2, field_id=1001, transform=TruncateTransform(2), name="string_trunc"))), + (PartitionSpec(PartitionField(source_id=11, field_id=1001, transform=TruncateTransform(2), name="binary_trunc"))), + (PartitionSpec(PartitionField(source_id=8, field_id=1001, transform=YearTransform(), name="timestamp_year"))), + (PartitionSpec(PartitionField(source_id=9, field_id=1001, transform=YearTransform(), name="timestamptz_year"))), + (PartitionSpec(PartitionField(source_id=10, field_id=1001, transform=YearTransform(), name="date_year"))), + (PartitionSpec(PartitionField(source_id=8, field_id=1001, transform=MonthTransform(), name="timestamp_month"))), + (PartitionSpec(PartitionField(source_id=9, field_id=1001, transform=MonthTransform(), name="timestamptz_month"))), + (PartitionSpec(PartitionField(source_id=10, field_id=1001, transform=MonthTransform(), name="date_month"))), + (PartitionSpec(PartitionField(source_id=8, field_id=1001, transform=DayTransform(), name="timestamp_day"))), + (PartitionSpec(PartitionField(source_id=9, field_id=1001, transform=DayTransform(), name="timestamptz_day"))), + (PartitionSpec(PartitionField(source_id=10, field_id=1001, transform=DayTransform(), name="date_day"))), + (PartitionSpec(PartitionField(source_id=8, field_id=1001, transform=HourTransform(), name="timestamp_hour"))), + (PartitionSpec(PartitionField(source_id=9, field_id=1001, transform=HourTransform(), name="timestamptz_hour"))), + ], +) +@pytest.mark.parametrize( + "format_version", + [1, 2], +) +def test_dynamic_partition_overwrite_non_identity_transform( + session_catalog: Catalog, arrow_table_with_null: pa.Table, spec: PartitionSpec, format_version: int +) -> None: + identifier = "default.dynamic_partition_overwrite_non_identity_transform" + try: + session_catalog.drop_table(identifier=identifier) + except NoSuchTableError: + pass + + tbl = session_catalog.create_table( + identifier=identifier, + schema=TABLE_SCHEMA, + properties={"format-version": format_version}, + partition_spec=spec, + ) + with pytest.raises( + ValueError, + match="For now dynamic overwrite does not support a table with non-identity-transform field in the latest partition spec: *", + ): + tbl.dynamic_partition_overwrite(arrow_table_with_null.slice(0, 1)) + + +@pytest.mark.integration +def test_dynamic_partition_overwrite_invalid_on_unpartitioned_table( + session_catalog: Catalog, arrow_table_with_null: pa.Table +) -> None: + identifier = "default.arrow_data_files" + tbl = _create_table(session_catalog, identifier, {"format-version": "1"}, []) + + with pytest.raises(ValueError, match="Cannot apply dynamic overwrite on an unpartitioned table."): + 
tbl.dynamic_partition_overwrite(arrow_table_with_null) + + +@pytest.mark.integration +@pytest.mark.parametrize( + "part_col", + [ + "int", + "bool", + "string", + "string_long", + "long", + "float", + "double", + "date", + "timestamp", + "binary", + "timestamptz", + ], +) +@pytest.mark.parametrize( + "format_version", + [1, 2], +) +def test_dynamic_partition_overwrite_unpartitioned_evolve_to_identity_transform( + spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, part_col: str, format_version: int +) -> None: + identifier = f"default.unpartitioned_table_v{format_version}_evolve_into_identity_transformed_partition_field_{part_col}" + tbl = session_catalog.create_table( + identifier=identifier, + schema=TABLE_SCHEMA, + properties={"format-version": format_version}, + ) + tbl.append(arrow_table_with_null) + tbl.update_spec().add_field(part_col, IdentityTransform(), f"{part_col}_identity").commit() + tbl.append(arrow_table_with_null) + # each column should be [a, null, b, a, null, b] + # dynamic overwrite a non-null row a, resulting in [null, b, null, b, a] + tbl.dynamic_partition_overwrite(arrow_table_with_null.slice(0, 1)) + df = spark.table(identifier) + assert df.where(f"{part_col} is not null").count() == 3, f"Expected 3 non-null rows for {part_col}," + assert df.where(f"{part_col} is null").count() == 2, f"Expected 2 null rows for {part_col}," + + # The first 2 appends come from 2 calls of the append API, while the dynamic partition overwrite API + # firstly overwrites of the unpartitioned file from first append, + # then it deletes one of the 3 partition files generated by the second append, + # finally it appends with new data. + expected_operations = ["append", "append", "delete", "overwrite", "append"] + + # For a long string, the lower bound and upper bound is truncated + # e.g. 
aaaaaaaaaaaaaaaaaaaaaa has lower bound of aaaaaaaaaaaaaaaa and upper bound of aaaaaaaaaaaaaaab + # this makes strict metric evaluator determine the file evaluate as ROWS_MIGHT_NOT_MATCH + # this further causes the partitioned data file to be overwriten rather than deleted + if part_col == "string_long": + expected_operations = ["append", "append", "overwrite", "append"] + assert tbl.inspect.snapshots().to_pydict()["operation"] == expected_operations + + @pytest.mark.integration def test_summaries_with_null(spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None: identifier = "default.arrow_table_summaries" @@ -239,6 +418,9 @@ def test_summaries_with_null(spark: SparkSession, session_catalog: Catalog, arro tbl.append(arrow_table_with_null) tbl.append(arrow_table_with_null) + tbl.dynamic_partition_overwrite(arrow_table_with_null) + tbl.append(arrow_table_with_null) + tbl.dynamic_partition_overwrite(arrow_table_with_null.slice(0, 2)) rows = spark.sql( f""" @@ -249,10 +431,9 @@ def test_summaries_with_null(spark: SparkSession, session_catalog: Catalog, arro ).collect() operations = [row.operation for row in rows] - assert operations == ["append", "append"] + assert operations == ["append", "append", "delete", "append", "append", "delete", "append"] summaries = [row.summary for row in rows] - file_size = int(summaries[0]["added-files-size"]) assert file_size > 0 @@ -281,12 +462,108 @@ def test_summaries_with_null(spark: SparkSession, session_catalog: Catalog, arro "total-position-deletes": "0", "total-records": "6", } + assert summaries[2] == { + "removed-files-size": str(file_size * 2), + "changed-partition-count": "3", + "total-equality-deletes": "0", + "deleted-data-files": "6", + "total-position-deletes": "0", + "total-delete-files": "0", + "deleted-records": "6", + "total-files-size": "0", + "total-data-files": "0", + "total-records": "0", + } + assert summaries[3] == { + "changed-partition-count": "3", + "added-data-files": "3", + "total-equality-deletes": "0", + "added-records": "3", + "total-position-deletes": "0", + "added-files-size": str(file_size), + "total-delete-files": "0", + "total-files-size": str(file_size), + "total-data-files": "3", + "total-records": "3", + } + assert summaries[4] == { + "changed-partition-count": "3", + "added-data-files": "3", + "total-equality-deletes": "0", + "added-records": "3", + "total-position-deletes": "0", + "added-files-size": str(file_size), + "total-delete-files": "0", + "total-files-size": str(file_size * 2), + "total-data-files": "6", + "total-records": "6", + } + assert summaries[5] == { + "removed-files-size": "15774", + "changed-partition-count": "2", + "total-equality-deletes": "0", + "deleted-data-files": "4", + "total-position-deletes": "0", + "total-delete-files": "0", + "deleted-records": "4", + "total-files-size": "8684", + "total-data-files": "2", + "total-records": "2", + } + assert summaries[6] == { + "changed-partition-count": "2", + "added-data-files": "2", + "total-equality-deletes": "0", + "added-records": "2", + "total-position-deletes": "0", + "added-files-size": "7887", + "total-delete-files": "0", + "total-files-size": "16571", + "total-data-files": "4", + "total-records": "4", + } @pytest.mark.integration def test_data_files_with_table_partitioned_with_null( spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table ) -> None: + # Append : First append has manifestlist file linking to 1 manifest file. 
+    #     ML1 = [M1]
+    #
+    # Append : Second append's manifestlist links to 2 manifest files.
+    #     ML2 = [M1, M2]
+    #
+    # Dynamic Overwrite: Dynamic overwrite on all partitions of the table deletes all data and appends new data.
+    # It creates 2 snapshots, a delete followed by an append:
+    # the delete snapshot generates M3 with 6 delete data entries collected from M1 and M2.
+    #     ML3 = [M3]
+    #
+    # The append snapshot generates M4 with 3 appended data entries; since M3 (the previous manifest) only has delete entries, it does not link to it.
+    #     ML4 = [M4]
+    #
+    # Append : Append generates M5 with new data entries and links to all previous manifests, which is just M4.
+    #     ML5 = [M5, M4]
+    #
+    # Dynamic Overwrite: Dynamic overwrite on some partitions of the table deletes part of the data and appends new data.
+    # It again creates 2 snapshots, a delete followed by an append:
+    # the delete snapshot generates M6 with 4 delete data entries collected from M1 and M2,
+    # then it generates M7 with the remaining existing entries from M1 and M8 with those from M2.
+    #     ML6 = [M6, M7, M8]
+    #
+    # The append snapshot generates M9 with 3 appended data entries and also looks at the manifests in ML6 (the previous manifest list);
+    # it ignores M6 since it only has delete entries, but it links to M7 and M8.
+    #     ML7 = [M9, M7, M8]
+    #
+    # tldr:
+    # APPEND                       ML1 = [M1]
+    # APPEND                       ML2 = [M1, M2]
+    # DYNAMIC_PARTITION_OVERWRITE  ML3 = [M3]
+    #                              ML4 = [M4]
+    # APPEND                       ML5 = [M5, M4]
+    # DYNAMIC_PARTITION_OVERWRITE  ML6 = [M6, M7, M8]
+    #                              ML7 = [M9, M7, M8]
+
     identifier = "default.arrow_data_files"
 
     try:
@@ -296,28 +573,126 @@ def test_data_files_with_table_partitioned_with_null(
     tbl = session_catalog.create_table(
         identifier=identifier,
         schema=TABLE_SCHEMA,
-        partition_spec=PartitionSpec(PartitionField(source_id=4, field_id=1001, transform=IdentityTransform(), name="int")),
+        partition_spec=PartitionSpec(
+            PartitionField(source_id=1, field_id=1001, transform=IdentityTransform(), name="bool"),
+            PartitionField(source_id=4, field_id=1002, transform=IdentityTransform(), name="int"),
+        ),
         properties={"format-version": "1"},
     )
 
     tbl.append(arrow_table_with_null)
     tbl.append(arrow_table_with_null)
-
-    # added_data_files_count, existing_data_files_count, deleted_data_files_count
+    tbl.dynamic_partition_overwrite(arrow_table_with_null)
+    tbl.append(arrow_table_with_null)
+    tbl.dynamic_partition_overwrite(arrow_table_with_null.slice(0, 2))
     rows = spark.sql(
         f"""
-            SELECT added_data_files_count, existing_data_files_count, deleted_data_files_count
+            SELECT *
             FROM {identifier}.all_manifests
         """
     ).collect()
 
-    assert [row.added_data_files_count for row in rows] == [3, 3, 3]
-    assert [row.existing_data_files_count for row in rows] == [
-        0,
-        0,
-        0,
-    ]
-    assert [row.deleted_data_files_count for row in rows] == [0, 0, 0]
+    assert [row.added_data_files_count for row in rows] == [3, 3, 3, 0, 3, 3, 3, 0, 0, 0, 2, 0, 0]
+    assert [row.existing_data_files_count for row in rows] == [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1]
+    assert [row.deleted_data_files_count for row in rows] == [0, 0, 0, 6, 0, 0, 0, 4, 0, 0, 0, 0, 0]
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize(
+    "format_version",
+    [1, 2],
+)
+def test_dynamic_partition_overwrite_rename_column(spark: SparkSession, session_catalog: Catalog, format_version: int) -> None:
+    arrow_table = pa.Table.from_pydict(
+        {
+            "place": ["Amsterdam", "Drachten"],
+            "inhabitants": [921402, 44940],
+        },
+    )
+
+    identifier = f"default.partitioned_{format_version}_dynamic_partition_overwrite_rename_column"
+    with pytest.raises(NoSuchTableError):
+        
session_catalog.drop_table(identifier) + + tbl = session_catalog.create_table( + identifier=identifier, + schema=arrow_table.schema, + properties={"format-version": str(format_version)}, + ) + + with tbl.transaction() as tx: + with tx.update_spec() as schema: + schema.add_identity("place") + + tbl.append(arrow_table) + + with tbl.transaction() as tx: + with tx.update_schema() as schema: + schema.rename_column("place", "city") + + arrow_table = pa.Table.from_pydict( + { + "city": ["Drachten"], + "inhabitants": [44941], # A new baby was born! + }, + ) + + tbl.dynamic_partition_overwrite(arrow_table) + result = tbl.scan().to_arrow() + + assert result["city"].to_pylist() == ["Drachten", "Amsterdam"] + assert result["inhabitants"].to_pylist() == [44941, 921402] + + +@pytest.mark.integration +@pytest.mark.parametrize( + "format_version", + [1, 2], +) +@pytest.mark.filterwarnings("ignore") +def test_dynamic_partition_overwrite_evolve_partition(spark: SparkSession, session_catalog: Catalog, format_version: int) -> None: + arrow_table = pa.Table.from_pydict( + { + "place": ["Amsterdam", "Drachten"], + "inhabitants": [921402, 44940], + }, + ) + + identifier = f"default.partitioned_{format_version}_test_dynamic_partition_overwrite_evolve_partition" + with pytest.raises(NoSuchTableError): + session_catalog.drop_table(identifier) + + tbl = session_catalog.create_table( + identifier=identifier, + schema=arrow_table.schema, + properties={"format-version": str(format_version)}, + ) + + with tbl.transaction() as tx: + with tx.update_spec() as schema: + schema.add_identity("place") + + tbl.append(arrow_table) + + with tbl.transaction() as tx: + with tx.update_schema() as schema: + schema.add_column("country", StringType()) + with tx.update_spec() as schema: + schema.add_identity("country") + + arrow_table = pa.Table.from_pydict( + { + "place": ["Groningen"], + "country": ["Netherlands"], + "inhabitants": [238147], + }, + ) + + tbl.dynamic_partition_overwrite(arrow_table) + result = tbl.scan().to_arrow() + + assert result["place"].to_pylist() == ["Groningen", "Amsterdam", "Drachten"] + assert result["inhabitants"].to_pylist() == [238147, 921402, 44940] @pytest.mark.integration