From 09193eb018cc67bc2227529bd5d136b54ec8fa09 Mon Sep 17 00:00:00 2001 From: Maksym Shalenyi Date: Mon, 10 Jun 2024 23:48:11 -0700 Subject: [PATCH] adding add_files_overwrite method use delete instead of overwrite check history too --- mkdocs/docs/api.md | 5 + pyiceberg/table/__init__.py | 65 ++++- tests/integration/test_add_files.py | 392 +++++++++++++++++++++++++++- 3 files changed, 449 insertions(+), 13 deletions(-) diff --git a/mkdocs/docs/api.md b/mkdocs/docs/api.md index 70b5fd62eb..80330cee1a 100644 --- a/mkdocs/docs/api.md +++ b/mkdocs/docs/api.md @@ -672,7 +672,12 @@ file_paths = [ tbl.add_files(file_paths=file_paths) +# or if you want to overwrite + +tbl.add_files_overwrite(file_paths=file_paths) + # A new snapshot is committed to the table with manifests pointing to the existing parquet files + ``` diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index e0214d1bde..99e32bbd3a 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -569,6 +569,27 @@ def add_files(self, file_paths: List[str], snapshot_properties: Dict[str, str] = for data_file in data_files: update_snapshot.append_data_file(data_file) + def add_files_overwrite(self, file_paths: List[str], snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None: + """ + Shorthand API for adding files as data files and overwriting the table. + + Args: + file_paths: The list of full file paths to be added as data files to the table + snapshot_properties: Custom properties to be added to the snapshot summary + + Raises: + FileNotFoundError: If the file does not exist. + """ + if self._table.name_mapping() is None: + self.set_properties(**{TableProperties.DEFAULT_NAME_MAPPING: self._table.schema().name_mapping.model_dump_json()}) + self.delete(delete_filter=ALWAYS_TRUE, snapshot_properties=snapshot_properties) + with self.update_snapshot(snapshot_properties=snapshot_properties).fast_append() as update_snapshot: + data_files = _parquet_files_to_data_files( + table_metadata=self._table.metadata, file_paths=file_paths, io=self._table.io + ) + for data_file in data_files: + update_snapshot.append_data_file(data_file) + def update_spec(self) -> UpdateSpec: """Create a new UpdateSpec to update the partitioning of the table. @@ -1480,6 +1501,20 @@ def add_files(self, file_paths: List[str], snapshot_properties: Dict[str, str] = with self.transaction() as tx: tx.add_files(file_paths=file_paths, snapshot_properties=snapshot_properties) + def add_files_overwrite(self, file_paths: List[str], snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None: + """ + Shorthand API for adding files as data files and overwriting the table. + + Args: + file_paths: The list of full file paths to be added as data files to the table + snapshot_properties: Custom properties to be added to the snapshot summary + + Raises: + FileNotFoundError: If the file does not exist. + """ + with self.transaction() as tx: + tx.add_files_overwrite(file_paths=file_paths, snapshot_properties=snapshot_properties) + def update_spec(self, case_sensitive: bool = True) -> UpdateSpec: return UpdateSpec(Transaction(self, autocommit=True), case_sensitive=case_sensitive) @@ -3273,9 +3308,9 @@ def fast_append(self) -> FastAppendFiles: def overwrite(self, commit_uuid: Optional[uuid.UUID] = None) -> OverwriteFiles: return OverwriteFiles( commit_uuid=commit_uuid, - operation=Operation.OVERWRITE - if self._transaction.table_metadata.current_snapshot() is not None - else Operation.APPEND, + operation=( + Operation.OVERWRITE if self._transaction.table_metadata.current_snapshot() is not None else Operation.APPEND + ), transaction=self._transaction, io=self._io, snapshot_properties=self._snapshot_properties, @@ -3665,12 +3700,16 @@ def _readable_metrics_struct(bound_type: PrimitiveType) -> pa.StructType: "null_value_count": null_value_counts.get(field.field_id), "nan_value_count": nan_value_counts.get(field.field_id), # Makes them readable - "lower_bound": from_bytes(field.field_type, lower_bound) - if (lower_bound := lower_bounds.get(field.field_id)) - else None, - "upper_bound": from_bytes(field.field_type, upper_bound) - if (upper_bound := upper_bounds.get(field.field_id)) - else None, + "lower_bound": ( + from_bytes(field.field_type, lower_bound) + if (lower_bound := lower_bounds.get(field.field_id)) + else None + ), + "upper_bound": ( + from_bytes(field.field_type, upper_bound) + if (upper_bound := upper_bounds.get(field.field_id)) + else None + ), } for field in self.tbl.metadata.schema().fields } @@ -3905,9 +3944,11 @@ def _partition_summaries_to_rows( "added_delete_files_count": manifest.added_files_count if is_delete_file else 0, "existing_delete_files_count": manifest.existing_files_count if is_delete_file else 0, "deleted_delete_files_count": manifest.deleted_files_count if is_delete_file else 0, - "partition_summaries": _partition_summaries_to_rows(specs[manifest.partition_spec_id], manifest.partitions) - if manifest.partitions - else [], + "partition_summaries": ( + _partition_summaries_to_rows(specs[manifest.partition_spec_id], manifest.partitions) + if manifest.partitions + else [] + ), }) return pa.Table.from_pylist( diff --git a/tests/integration/test_add_files.py b/tests/integration/test_add_files.py index 84729fcca4..fe9a4d8396 100644 --- a/tests/integration/test_add_files.py +++ b/tests/integration/test_add_files.py @@ -16,8 +16,10 @@ # under the License. # pylint:disable=redefined-outer-name +import time from datetime import date -from typing import Iterator, Optional +from typing import Any, Dict, Iterator, List, Optional +from uuid import uuid4 import pyarrow as pa import pyarrow.parquet as pq @@ -448,3 +450,391 @@ def test_add_files_snapshot_properties(spark: SparkSession, session_catalog: Cat assert "snapshot_prop_a" in summary assert summary["snapshot_prop_a"] == "test_prop_a" + + +@pytest.mark.integration +def test_add_files_overwrite_to_unpartitioned_table(spark: SparkSession, session_catalog: Catalog, format_version: int) -> None: + identifier = f"default.unpartitioned_table_v{format_version}" + tbl = _create_table(session_catalog, identifier, format_version) + + data = { + "foo": True, + "bar": "bar_string", + "baz": 123, + "qux": date(2024, 3, 7), + } + + # creating initial snapshot + tbl.add_files(file_paths=_create_parquet_files(tbl, [data])) + + # testing overwrite with new data files + file_paths = _create_parquet_files(tbl, [data for _ in range(5)]) + tbl.add_files_overwrite(file_paths=file_paths) + + # NameMapping must have been set to enable reads + assert tbl.name_mapping() is not None + + rows = spark.sql( + f""" + SELECT added_data_files_count, existing_data_files_count, deleted_data_files_count + FROM {identifier}.all_manifests + """ + ).collect() + + assert [row.added_data_files_count for row in rows] == [1, 0, 5] + assert [row.existing_data_files_count for row in rows] == [0, 0, 0] + assert [row.deleted_data_files_count for row in rows] == [0, 1, 0] + + df = spark.table(identifier) + assert df.count() == 5, "Expected 5 rows" + for col in df.columns: + assert df.filter(df[col].isNotNull()).count() == 5, "Expected all 5 rows to be non-null" + + # check that the table can be read by pyiceberg + assert len(tbl.scan().to_arrow()) == 5, "Expected 5 rows" + + # check history + assert len(tbl.scan(snapshot_id=tbl.history()[0].snapshot_id).to_arrow()) == 1, "Expected 1 row" + + +@pytest.mark.integration +def test_add_files_overwrite_to_unpartitioned_table_raises_file_not_found( + spark: SparkSession, session_catalog: Catalog, format_version: int +) -> None: + identifier = f"default.unpartitioned_raises_not_found_v{format_version}" + tbl = _create_table(session_catalog, identifier, format_version) + + file_paths = [f"s3://warehouse/default/unpartitioned_raises_not_found/v{format_version}/test-{i}.parquet" for i in range(5)] + # write parquet files + for file_path in file_paths: + fo = tbl.io.new_output(file_path) + with fo.create(overwrite=True) as fos: + with pq.ParquetWriter(fos, schema=ARROW_SCHEMA) as writer: + writer.write_table(ARROW_TABLE) + + # add the parquet files as data files + with pytest.raises(FileNotFoundError): + tbl.add_files_overwrite(file_paths=file_paths + ["s3://warehouse/default/unpartitioned_raises_not_found/unknown.parquet"]) + + +@pytest.mark.integration +def test_add_files_overwrite_to_unpartitioned_table_raises_has_field_ids( + spark: SparkSession, session_catalog: Catalog, format_version: int +) -> None: + identifier = f"default.unpartitioned_raises_field_ids_v{format_version}" + tbl = _create_table(session_catalog, identifier, format_version) + + file_paths = [f"s3://warehouse/default/unpartitioned_raises_field_ids/v{format_version}/test-{i}.parquet" for i in range(5)] + # write parquet files + for file_path in file_paths: + fo = tbl.io.new_output(file_path) + with fo.create(overwrite=True) as fos: + with pq.ParquetWriter(fos, schema=ARROW_SCHEMA_WITH_IDS) as writer: + writer.write_table(ARROW_TABLE_WITH_IDS) + + # add the parquet files as data files + with pytest.raises(NotImplementedError): + tbl.add_files_overwrite(file_paths=file_paths) + + +@pytest.mark.integration +def test_add_files_overwrite_to_unpartitioned_table_with_schema_updates( + spark: SparkSession, session_catalog: Catalog, format_version: int +) -> None: + identifier = f"default.unpartitioned_table_schema_updates_v{format_version}" + tbl = _create_table(session_catalog, identifier, format_version) + + file_paths = [f"s3://warehouse/default/unpartitioned_schema_updates/v{format_version}/test-{i}.parquet" for i in range(5)] + # write parquet files + for file_path in file_paths: + fo = tbl.io.new_output(file_path) + with fo.create(overwrite=True) as fos: + with pq.ParquetWriter(fos, schema=ARROW_SCHEMA) as writer: + writer.write_table(ARROW_TABLE) + + # add the parquet files as data files + tbl.add_files(file_paths=file_paths) + + # NameMapping must have been set to enable reads + assert tbl.name_mapping() is not None + + with tbl.update_schema() as update: + update.add_column("quux", IntegerType()) + update.delete_column("bar") + + file_path = f"s3://warehouse/default/unpartitioned_schema_updates/v{format_version}/test-6.parquet" + # write parquet files + fo = tbl.io.new_output(file_path) + with fo.create(overwrite=True) as fos: + with pq.ParquetWriter(fos, schema=ARROW_SCHEMA_UPDATED) as writer: + writer.write_table(ARROW_TABLE_UPDATED) + + # add the parquet files as data files + tbl.add_files_overwrite(file_paths=[file_path]) + rows = spark.sql( + f""" + SELECT added_data_files_count, existing_data_files_count, deleted_data_files_count + FROM {identifier}.all_manifests + """ + ).collect() + + assert [row.added_data_files_count for row in rows] == [5, 0, 1] + assert [row.existing_data_files_count for row in rows] == [0, 0, 0] + assert [row.deleted_data_files_count for row in rows] == [0, 5, 0] + + df = spark.table(identifier) + assert df.count() == 1, "Expected 1 rows" + assert len(df.columns) == 4, "Expected 4 columns" + + for col in df.columns: + value_count = 1 + assert df.filter(df[col].isNotNull()).count() == value_count, f"Expected {value_count} rows to be non-null" + + # check that the table can be read by pyiceberg + assert len(tbl.scan().to_arrow()) == 1, "Expected 1 rows" + + +def _create_parquet_files(tbl: Table, files_constents: List[Dict[str, Any]]) -> List[str]: + uid = uuid4().hex + time.time_ns() + file_paths = [] + for i, contents in enumerate(files_constents): + file_path = f"s3://warehouse/default/partitioned/test-{time.time_ns()}/test-{uid}-{i}.parquet" + file_paths.append(file_path) + fo = tbl.io.new_output(file_path) + with fo.create(overwrite=True) as fos: + with pq.ParquetWriter(fos, schema=ARROW_SCHEMA) as writer: + writer.write_table( + pa.Table.from_pylist( + [contents], + schema=ARROW_SCHEMA, + ) + ) + return file_paths + + +@pytest.mark.integration +def test_add_files_overwrite_to_partitioned_table(spark: SparkSession, session_catalog: Catalog, format_version: int) -> None: + identifier = f"default.partitioned_table_v{format_version}" + + partition_spec = PartitionSpec( + PartitionField(source_id=4, field_id=1000, transform=IdentityTransform(), name="baz"), + PartitionField(source_id=10, field_id=1001, transform=MonthTransform(), name="qux_month"), + spec_id=0, + ) + + tbl = _create_table(session_catalog, identifier, format_version, partition_spec) + + files_constents = [ + { + "foo": True, + "bar": "bar_string", + "baz": 123, + "qux": i, + } + for i in [date(2024, 3, 7), date(2024, 3, 8), date(2024, 3, 16), date(2024, 3, 18), date(2024, 3, 19)] + ] + file_paths = _create_parquet_files(tbl, files_constents) + + # add the parquet files as data files + tbl.add_files(file_paths=file_paths) + + # NameMapping must have been set to enable reads + assert tbl.name_mapping() is not None + + rows = spark.sql( + f""" + SELECT added_data_files_count, existing_data_files_count, deleted_data_files_count + FROM {identifier}.all_manifests + """ + ).collect() + + assert [row.added_data_files_count for row in rows] == [5] + assert [row.existing_data_files_count for row in rows] == [0] + assert [row.deleted_data_files_count for row in rows] == [0] + + df = spark.table(identifier) + assert df.count() == 5, "Expected 5 rows" + for col in df.columns: + assert df.filter(df[col].isNotNull()).count() == 5, "Expected all 5 rows to be non-null" + + partition_rows = spark.sql( + f""" + SELECT partition, record_count, file_count + FROM {identifier}.partitions + """ + ).collect() + assert [row.record_count for row in partition_rows] == [5] + assert [row.file_count for row in partition_rows] == [5] + assert [(row.partition.baz, row.partition.qux_month) for row in partition_rows] == [(123, 650)] + + # check that the table can be read by pyiceberg + assert len(tbl.scan().to_arrow()) == 5, "Expected 5 rows" + + ## + + files_constents = [ + { + "foo": True, + "bar": "bar_string-overwrite", + "baz": 456, + "qux": i, + } + for i in [date(2025, 3, 7), date(2025, 3, 8), date(2025, 3, 16)] + ] + file_paths = _create_parquet_files(tbl, files_constents) + + # add the parquet files as data files + tbl.add_files_overwrite(file_paths=file_paths) + + # NameMapping must have been set to enable reads + assert tbl.name_mapping() is not None + + rows = spark.sql( + f""" + SELECT added_data_files_count, existing_data_files_count, deleted_data_files_count + FROM {identifier}.all_manifests + """ + ).collect() + + assert [row.added_data_files_count for row in rows] == [5, 0, 3] + assert [row.existing_data_files_count for row in rows] == [0, 0, 0] + assert [row.deleted_data_files_count for row in rows] == [0, 5, 0] + + df = spark.table(identifier) + assert df.count() == 3, "Expected 3 rows" + for col in df.columns: + assert df.filter(df[col].isNotNull()).count() == 3, "Expected all 3 rows to be non-null" + + partition_rows = spark.sql( + f""" + SELECT partition, record_count, file_count + FROM {identifier}.partitions + """ + ).collect() + assert [row.record_count for row in partition_rows] == [3] + assert [row.file_count for row in partition_rows] == [3] + assert [(row.partition.baz, row.partition.qux_month) for row in partition_rows] == [(456, 662)] + + # check that the table can be read by pyiceberg + assert len(tbl.scan().to_arrow()) == 3, "Expected 3 rows" + + # check history + assert len(tbl.scan(snapshot_id=tbl.history()[0].snapshot_id).to_arrow()) == 5, "Expected 5 rows" + + +@pytest.mark.integration +def test_add_files_overwrite_to_bucket_partitioned_table_fails( + spark: SparkSession, session_catalog: Catalog, format_version: int +) -> None: + identifier = f"default.partitioned_table_bucket_fails_v{format_version}" + + partition_spec = PartitionSpec( + PartitionField(source_id=4, field_id=1000, transform=BucketTransform(num_buckets=3), name="baz_bucket_3"), + spec_id=0, + ) + + tbl = _create_table(session_catalog, identifier, format_version, partition_spec) + + int_iter = iter(range(5)) + + file_paths = [f"s3://warehouse/default/partitioned_table_bucket_fails/v{format_version}/test-{i}.parquet" for i in range(5)] + # write parquet files + for file_path in file_paths: + fo = tbl.io.new_output(file_path) + with fo.create(overwrite=True) as fos: + with pq.ParquetWriter(fos, schema=ARROW_SCHEMA) as writer: + writer.write_table( + pa.Table.from_pylist( + [ + { + "foo": True, + "bar": "bar_string", + "baz": next(int_iter), + "qux": date(2024, 3, 7), + } + ], + schema=ARROW_SCHEMA, + ) + ) + + # add the parquet files as data files + with pytest.raises(ValueError) as exc_info: + tbl.add_files_overwrite(file_paths=file_paths) + assert ( + "Cannot infer partition value from parquet metadata for a non-linear Partition Field: baz_bucket_3 with transform bucket[3]" + in str(exc_info.value) + ) + + +@pytest.mark.integration +def test_add_files_overwrite_to_partitioned_table_fails_with_lower_and_upper_mismatch( + spark: SparkSession, session_catalog: Catalog, format_version: int +) -> None: + identifier = f"default.partitioned_table_mismatch_fails_v{format_version}" + + partition_spec = PartitionSpec( + PartitionField(source_id=4, field_id=1000, transform=IdentityTransform(), name="baz"), + spec_id=0, + ) + + tbl = _create_table(session_catalog, identifier, format_version, partition_spec) + + file_paths = [f"s3://warehouse/default/partitioned_table_mismatch_fails/v{format_version}/test-{i}.parquet" for i in range(5)] + # write parquet files + for file_path in file_paths: + fo = tbl.io.new_output(file_path) + with fo.create(overwrite=True) as fos: + with pq.ParquetWriter(fos, schema=ARROW_SCHEMA) as writer: + writer.write_table( + pa.Table.from_pylist( + [ + { + "foo": True, + "bar": "bar_string", + "baz": 123, + "qux": date(2024, 3, 7), + }, + { + "foo": True, + "bar": "bar_string", + "baz": 124, + "qux": date(2024, 3, 7), + }, + ], + schema=ARROW_SCHEMA, + ) + ) + + # add the parquet files as data files + with pytest.raises(ValueError) as exc_info: + tbl.add_files_overwrite(file_paths=file_paths) + assert ( + "Cannot infer partition value from parquet metadata as there are more than one partition values for Partition Field: baz. lower_value=123, upper_value=124" + in str(exc_info.value) + ) + + +@pytest.mark.integration +def test_add_files_overwrite_snapshot_properties(spark: SparkSession, session_catalog: Catalog, format_version: int) -> None: + identifier = f"default.unpartitioned_table_v{format_version}" + tbl = _create_table(session_catalog, identifier, format_version) + + file_paths = [f"s3://warehouse/default/unpartitioned/v{format_version}/test-{i}.parquet" for i in range(5)] + # write parquet files + for file_path in file_paths: + fo = tbl.io.new_output(file_path) + with fo.create(overwrite=True) as fos: + with pq.ParquetWriter(fos, schema=ARROW_SCHEMA) as writer: + writer.write_table(ARROW_TABLE) + + # add the parquet files as data files + tbl.add_files_overwrite(file_paths=file_paths, snapshot_properties={"snapshot_prop_a": "test_prop_a"}) + + # NameMapping must have been set to enable reads + assert tbl.name_mapping() is not None + + summary = spark.sql(f"SELECT * FROM {identifier}.snapshots;").collect()[0].summary + + assert "snapshot_prop_a" in summary + assert summary["snapshot_prop_a"] == "test_prop_a"