diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py
index fcfd5b4904..ba336e0e61 100644
--- a/pyiceberg/io/pyarrow.py
+++ b/pyiceberg/io/pyarrow.py
@@ -124,7 +124,7 @@
     visit,
     visit_with_partner,
 )
-from pyiceberg.table import PropertyUtil, TableProperties, WriteTask
+from pyiceberg.table import AddFileTask, PropertyUtil, TableProperties, WriteTask
 from pyiceberg.table.metadata import TableMetadata
 from pyiceberg.table.name_mapping import NameMapping
 from pyiceberg.transforms import TruncateTransform
@@ -1772,31 +1772,32 @@ def write_file(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[WriteT
     return iter([data_file])
 
 
-def parquet_file_to_data_file(io: FileIO, table_metadata: TableMetadata, file_path: str) -> DataFile:
-    input_file = io.new_input(file_path)
-    with input_file.open() as input_stream:
-        parquet_metadata = pq.read_metadata(input_stream)
-
-    schema = table_metadata.schema()
-    data_file = DataFile(
-        content=DataFileContent.DATA,
-        file_path=file_path,
-        file_format=FileFormat.PARQUET,
-        partition=Record(),
-        record_count=parquet_metadata.num_rows,
-        file_size_in_bytes=len(input_file),
-        sort_order_id=None,
-        spec_id=table_metadata.default_spec_id,
-        equality_ids=None,
-        key_metadata=None,
-    )
-    fill_parquet_file_metadata(
-        data_file=data_file,
-        parquet_metadata=parquet_metadata,
-        stats_columns=compute_statistics_plan(schema, table_metadata.properties),
-        parquet_column_mapping=parquet_path_to_id_mapping(schema),
-    )
-    return data_file
+def parquet_files_to_data_files(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[AddFileTask]) -> Iterator[DataFile]:
+    for task in tasks:
+        input_file = io.new_input(task.file_path)
+        with input_file.open() as input_stream:
+            parquet_metadata = pq.read_metadata(input_stream)
+
+        schema = table_metadata.schema()
+        data_file = DataFile(
+            content=DataFileContent.DATA,
+            file_path=task.file_path,
+            file_format=FileFormat.PARQUET,
+            partition=task.partition_field_value,
+            record_count=parquet_metadata.num_rows,
+            file_size_in_bytes=len(input_file),
+            sort_order_id=None,
+            spec_id=table_metadata.default_spec_id,
+            equality_ids=None,
+            key_metadata=None,
+        )
+        fill_parquet_file_metadata(
+            data_file=data_file,
+            parquet_metadata=parquet_metadata,
+            stats_columns=compute_statistics_plan(schema, table_metadata.properties),
+            parquet_column_mapping=parquet_path_to_id_mapping(schema),
+        )
+        yield data_file
 
 
 ICEBERG_UNCOMPRESSED_CODEC = "uncompressed"
diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index 6d63fc9105..d8d05740ed 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -33,6 +33,7 @@
     Dict,
     Generic,
     Iterable,
+    Iterator,
     List,
     Literal,
     Optional,
@@ -56,6 +57,7 @@
     parser,
     visitors,
 )
+from pyiceberg.expressions.literals import StringLiteral
 from pyiceberg.expressions.visitors import _InclusiveMetricsEvaluator, inclusive_projection
 from pyiceberg.io import FileIO, load_file_io
 from pyiceberg.manifest import (
@@ -117,6 +119,7 @@
     Identifier,
     KeyDefaultDict,
     Properties,
+    Record,
 )
 from pyiceberg.types import (
     IcebergType,
@@ -1140,7 +1143,7 @@ def add_files(self, file_paths: List[str]) -> None:
         Shorthand API for adding files as data files to the table.
 
         Args:
-            files: The list of full file paths to be added as data files to the table
+            file_paths: The list of full file paths to be added as data files to the table
         """
         if self.name_mapping() is None:
             with self.transaction() as tx:
@@ -2449,17 +2452,8 @@ def generate_data_file_filename(self, extension: str) -> str:
 
 @dataclass(frozen=True)
 class AddFileTask:
-    write_uuid: uuid.UUID
-    task_id: int
-    df: pa.Table
-    sort_order_id: Optional[int] = None
-
-    # Later to be extended with partition information
-
-    def generate_data_file_filename(self, extension: str) -> str:
-        # Mimics the behavior in the Java API:
-        # https://github.com/apache/iceberg/blob/a582968975dd30ff4917fbbe999f1be903efac02/core/src/main/java/org/apache/iceberg/io/OutputFileFactory.java#L92-L101
-        return f"00000-{self.task_id}-{self.write_uuid}.{extension}"
+    file_path: str
+    partition_field_value: Record
 
 
 def _new_manifest_path(location: str, num: int, commit_uuid: uuid.UUID) -> str:
@@ -2493,16 +2487,34 @@ def _dataframe_to_data_files(
     yield from write_file(io=io, table_metadata=table_metadata, tasks=iter([WriteTask(write_uuid, next(counter), df)]))
 
 
+def add_file_tasks_from_file_paths(file_paths: List[str], table_metadata: TableMetadata) -> Iterator[AddFileTask]:
+    partition_spec = table_metadata.spec()
+    partition_struct = partition_spec.partition_type(table_metadata.schema())
+
+    for file_path in file_paths:
+        # file_path = 's3://warehouse/default/part1=2024-03-04/part2=ABCD'
+        # ['part1=2024-03-04', 'part2=ABCD']
+        parts = [part for part in file_path.split("/") if "=" in part]
+
+        partition_field_values = {}
+        for part in parts:
+            partition_name, string_value = part.split("=")
+            if partition_field := partition_struct.field_by_name(partition_name):
+                partition_field_values[partition_name] = StringLiteral(string_value).to(partition_field.field_type).value
+
+        yield AddFileTask(file_path=file_path, partition_field_value=Record(**partition_field_values))
+
+
 def _parquet_files_to_data_files(table_metadata: TableMetadata, file_paths: List[str], io: FileIO) -> Iterable[DataFile]:
     """Convert a list files into DataFiles.
 
     Returns:
         An iterable that supplies DataFiles that describe the parquet files.
     """
-    from pyiceberg.io.pyarrow import parquet_file_to_data_file
+    from pyiceberg.io.pyarrow import parquet_files_to_data_files
 
-    for file_path in file_paths:
-        yield parquet_file_to_data_file(io=io, table_metadata=table_metadata, file_path=file_path)
+    tasks = add_file_tasks_from_file_paths(file_paths, table_metadata)
+    yield from parquet_files_to_data_files(io=io, table_metadata=table_metadata, tasks=tasks)
 
 
 class _MergingSnapshotProducer(UpdateTableMetadata["_MergingSnapshotProducer"]):
diff --git a/tests/integration/test_add_files.py b/tests/integration/test_add_files.py
index d76a5bf2bf..2aa33342cd 100644
--- a/tests/integration/test_add_files.py
+++ b/tests/integration/test_add_files.py
@@ -15,75 +15,92 @@
 #  specific language governing permissions and limitations
 #  under the License.
 # pylint:disable=redefined-outer-name
+
+from datetime import date
+from typing import Optional
+
 import pyarrow as pa
 import pyarrow.parquet as pq
 import pytest
-from pathlib import Path
+from pyspark.sql import SparkSession
 
-from pyiceberg.catalog import Catalog, Properties, Table
-from pyiceberg.io.pyarrow import schema_to_pyarrow
+from pyiceberg.catalog import Catalog, Table
+from pyiceberg.exceptions import NoSuchTableError
+from pyiceberg.partitioning import PartitionField, PartitionSpec
 from pyiceberg.schema import Schema
+from pyiceberg.transforms import IdentityTransform, MonthTransform
 from pyiceberg.types import (
     BooleanType,
+    DateType,
     IntegerType,
     NestedField,
     StringType,
 )
-from pyiceberg.exceptions import NoSuchTableError
-from pyspark.sql import SparkSession
 
 TABLE_SCHEMA = Schema(
     NestedField(field_id=1, name="foo", field_type=BooleanType(), required=False),
     NestedField(field_id=2, name="bar", field_type=StringType(), required=False),
     NestedField(field_id=4, name="baz", field_type=IntegerType(), required=False),
+    NestedField(field_id=10, name="qux", field_type=DateType(), required=False),
 )
 
+ARROW_SCHEMA = pa.schema([
+    ("foo", pa.bool_()),
+    ("bar", pa.string()),
+    ("baz", pa.int32()),
+    ("qux", pa.date32()),
+])
+
 ARROW_TABLE = pa.Table.from_pylist(
-    [
-        {
-            "foo": True,
-            "bar": "bar_string",
-            "baz": 123,
-        }
-    ],
-    schema=schema_to_pyarrow(TABLE_SCHEMA),
-    )
-
-def _create_table(session_catalog: Catalog, identifier: str) -> Table:
+    [
+        {
+            "foo": True,
+            "bar": "bar_string",
+            "baz": 123,
+            "qux": date(2024, 3, 7),
+        }
+    ],
+    schema=ARROW_SCHEMA,
+)
+
+PARTITION_SPEC = PartitionSpec(
+    PartitionField(source_id=4, field_id=1000, transform=IdentityTransform(), name="baz"),
+    PartitionField(source_id=10, field_id=1001, transform=IdentityTransform(), name="qux"),
+    spec_id=0,
+)
+
+
+def _create_table(session_catalog: Catalog, identifier: str, partition_spec: Optional[PartitionSpec] = None) -> Table:
     try:
         session_catalog.drop_table(identifier=identifier)
     except NoSuchTableError:
         pass
 
-    tbl = session_catalog.create_table(identifier=identifier, schema=TABLE_SCHEMA)
+    tbl = session_catalog.create_table(
+        identifier=identifier, schema=TABLE_SCHEMA, partition_spec=partition_spec if partition_spec else PartitionSpec()
+    )
 
     return tbl
 
+
 @pytest.mark.integration
-def test_add_files_to_unpartitioned_table(spark: SparkSession, session_catalog: Catalog, warehouse: Path) -> None:
+def test_add_files_to_unpartitioned_table(spark: SparkSession, session_catalog: Catalog) -> None:
     identifier = "default.unpartitioned_table"
     tbl = _create_table(session_catalog, identifier)
-    # rows = spark.sql(
-    #     f"""
-    #     SELECT added_data_files_count, existing_data_files_count, deleted_data_files_count
-    #     FROM {identifier}.all_manifests
-    #     """
-    # ).collect()
-
-    # assert [row.added_data_files_count for row in rows] == []
-    # assert [row.existing_data_files_count for row in rows] == []
-    # assert [row.deleted_data_files_count for row in rows] == []
-
-    file_paths = [f"/{warehouse}/test-{i}.parquet" for i in range(5)]
+
+    file_paths = [f"s3://warehouse/default/unpartitioned/test-{i}.parquet" for i in range(5)]
     # write parquet files
     for file_path in file_paths:
         fo = tbl.io.new_output(file_path)
         with fo.create(overwrite=True) as fos:
-            with pq.ParquetWriter(fos, schema=ARROW_TABLE.schema) as writer:
+            with pq.ParquetWriter(fos, schema=ARROW_SCHEMA) as writer:
                 writer.write_table(ARROW_TABLE)
 
     # add the parquet files as data files
-    tbl.add_files(file_paths)
+    tbl.add_files(file_paths=file_paths)
+
+    # NameMapping must have been set to enable reads
+    assert tbl.name_mapping() is not None
 
     rows = spark.sql(
         f"""
@@ -94,4 +111,45 @@ def test_add_files_to_unpartitioned_table(spark: SparkSession, session_catalog:
 
     assert [row.added_data_files_count for row in rows] == [5]
     assert [row.existing_data_files_count for row in rows] == [0]
-    assert [row.deleted_data_files_count for row in rows] == [0]
\ No newline at end of file
+    assert [row.deleted_data_files_count for row in rows] == [0]
+
+    df = spark.table(identifier)
+    assert df.count() == 5, "Expected 5 rows"
+    for col in df.columns:
+        assert df.filter(df[col].isNotNull()).count() == 5, "Expected all 5 rows to be non-null"
+
+
+@pytest.mark.integration
+def test_add_files_to_partitioned_table(spark: SparkSession, session_catalog: Catalog) -> None:
+    identifier = "default.partitioned_table"
+    tbl = _create_table(session_catalog, identifier, PARTITION_SPEC)
+
+    file_paths = [f"s3://warehouse/default/baz=123/qux=2024-03-07/test-{i}.parquet" for i in range(5)]
+    # write parquet files
+    for file_path in file_paths:
+        fo = tbl.io.new_output(file_path)
+        with fo.create(overwrite=True) as fos:
+            with pq.ParquetWriter(fos, schema=ARROW_SCHEMA) as writer:
+                writer.write_table(ARROW_TABLE)
+
+    # add the parquet files as data files
+    tbl.add_files(file_paths=file_paths)
+
+    # NameMapping must have been set to enable reads
+    assert tbl.name_mapping() is not None
+
+    rows = spark.sql(
+        f"""
+        SELECT added_data_files_count, existing_data_files_count, deleted_data_files_count
+        FROM {identifier}.all_manifests
+        """
+    ).collect()
+
+    assert [row.added_data_files_count for row in rows] == [5]
+    assert [row.existing_data_files_count for row in rows] == [0]
+    assert [row.deleted_data_files_count for row in rows] == [0]
+
+    df = spark.table(identifier)
+    assert df.count() == 5, "Expected 5 rows"
+    for col in df.columns:
+        assert df.filter(df[col].isNotNull()).count() == 5, "Expected all 5 rows to be non-null"
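For reference, a minimal usage sketch of the Table.add_files entry point this diff wires up. It is not part of the patch; the catalog name, table identifier, and file paths below are illustrative assumptions.

    from pyiceberg.catalog import load_catalog

    # Assumes a catalog named "default" is configured for this environment (hypothetical setup).
    catalog = load_catalog("default")
    tbl = catalog.load_table("default.partitioned_table")

    # Register existing Parquet files as data files without rewriting them.
    # Partition values are parsed from Hive-style path segments such as baz=123/qux=2024-03-07.
    tbl.add_files(file_paths=[
        "s3://warehouse/default/baz=123/qux=2024-03-07/test-0.parquet",
    ])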