diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py
index 117b40cfcc..930aeba678 100644
--- a/pyiceberg/io/pyarrow.py
+++ b/pyiceberg/io/pyarrow.py
@@ -123,6 +123,7 @@
 )
 from pyiceberg.partitioning import PartitionField, PartitionFieldValue, PartitionKey, PartitionSpec, partition_record_value
 from pyiceberg.schema import (
+    Accessor,
     PartnerAccessor,
     PreOrderSchemaVisitor,
     Schema,
@@ -1218,16 +1219,24 @@ def _field_id(self, field: pa.Field) -> int:
 
 
 def _get_column_projection_values(
-    file: DataFile,
-    projected_schema: Schema,
-    project_schema_diff: Set[int],
-    partition_spec: PartitionSpec,
-) -> Dict[str, object]:
+    file: DataFile, projected_schema: Schema, partition_spec: Optional[PartitionSpec], file_project_field_ids: Set[int]
+) -> Tuple[bool, Dict[str, Any]]:
     """Apply Column Projection rules to File Schema."""
+    project_schema_diff = projected_schema.field_ids.difference(file_project_field_ids)
+    should_project_columns = len(project_schema_diff) > 0
     projected_missing_fields: Dict[str, Any] = {}
 
-    partition_schema = partition_spec.partition_type(projected_schema)
-    accessors = build_position_accessors(partition_schema)
+    if not should_project_columns:
+        return False, {}
+
+    partition_schema: StructType
+    accessors: Dict[int, Accessor]
+
+    if partition_spec is not None:
+        partition_schema = partition_spec.partition_type(projected_schema)
+        accessors = build_position_accessors(partition_schema)
+    else:
+        return False, {}
 
     for field_id in project_schema_diff:
         for partition_field in partition_spec.fields_by_source_id(field_id):
@@ -1240,7 +1249,7 @@ def _get_column_projection_values(
                 if partition_value := accesor.get(file.partition):
                     projected_missing_fields[partition_field.name] = partition_value
 
-    return projected_missing_fields
+    return True, projected_missing_fields
 
 
 def _task_to_record_batches(
@@ -1272,19 +1281,12 @@
         bound_file_filter = bind(file_schema, translated_row_filter, case_sensitive=case_sensitive)
         pyarrow_filter = expression_to_pyarrow(bound_file_filter)
 
-    # Apply column projection rules for missing partitions and default values
+    # Apply column projection rules
     # https://iceberg.apache.org/spec/#column-projection
     file_project_schema = prune_columns(file_schema, projected_field_ids, select_full_types=False)
-
-    project_schema_diff = projected_field_ids.difference(file_project_schema.field_ids)
-    should_project_columns = len(project_schema_diff) > 0
-
-    projected_missing_fields = {}
-
-    if should_project_columns and partition_spec is not None:
-        projected_missing_fields = _get_column_projection_values(
-            task.file, projected_schema, project_schema_diff, partition_spec
-        )
+    should_project_columns, projected_missing_fields = _get_column_projection_values(
+        task.file, projected_schema, partition_spec, file_project_schema.field_ids
+    )
 
     fragment_scanner = ds.Scanner.from_fragment(
         fragment=fragment,
@@ -1335,7 +1337,9 @@
 
             # Inject projected column values if available
             if should_project_columns:
                 for name, value in projected_missing_fields.items():
-                    result_batch = result_batch.set_column(result_batch.schema.get_field_index(name), name, [value])
+                    index = result_batch.schema.get_field_index(name)
+                    if index != -1:
+                        result_batch = result_batch.set_column(index, name, [value])
 
             yield result_batch
diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py
index 00c5607c67..ec149c0d6b 100644
--- a/tests/io/test_pyarrow.py
+++ b/tests/io/test_pyarrow.py
@@ -81,7 +81,7 @@
 from pyiceberg.table import FileScanTask, TableProperties
 from pyiceberg.table.metadata import TableMetadataV2
 from pyiceberg.table.name_mapping import create_mapping_from_schema
-from pyiceberg.transforms import IdentityTransform, VoidTransform
+from pyiceberg.transforms import IdentityTransform
 from pyiceberg.typedef import UTF8, Properties, Record
 from pyiceberg.types import (
     BinaryType,
@@ -1127,7 +1127,7 @@ def test_projection_concat_files(schema_int: Schema, file_int: str) -> None:
     assert repr(result_table.schema) == "id: int32"
 
 
-def test_projection_single_partition_value(tmp_path: str, catalog: InMemoryCatalog) -> None:
+def test_identity_transform_column_projection(tmp_path: str, catalog: InMemoryCatalog) -> None:
     # Test by adding a non-partitioned data file to a partitioned table, verifying partition value projection from manifest metadata.
     # TODO: Update to use a data file created by writing data to an unpartitioned table once add_files supports field IDs.
     # (context: https://github.com/apache/iceberg-python/pull/1443#discussion_r1901374875)
@@ -1161,6 +1161,7 @@ def test_projection_single_partition_value(tmp_path: str, catalog: InMemoryCatal
         content=DataFileContent.DATA,
         file_path=file_loc,
         file_format=FileFormat.PARQUET,
+        # projected value
         partition=Record(partition_id=1),
         file_size_in_bytes=os.path.getsize(file_loc),
         sort_order_id=None,
@@ -1185,7 +1186,7 @@
     )
 
 
-def test_projection_multiple_partition_values(tmp_path: str, catalog: InMemoryCatalog) -> None:
+def test_identity_transform_columns_projection(tmp_path: str, catalog: InMemoryCatalog) -> None:
     # Test by adding a non-partitioned data file to a multi-partitioned table, verifying partition value projection from manifest metadata.
     # TODO: Update to use a data file created by writing data to an unpartitioned table once add_files supports field IDs.
     # (context: https://github.com/apache/iceberg-python/pull/1443#discussion_r1901374875)
@@ -1194,7 +1195,7 @@ def test_projection_multiple_partition_values(tmp_path: str, catalog: InMemoryCa
     )
 
     partition_spec = PartitionSpec(
-        PartitionField(2, 1000, VoidTransform(), "void_partition_id"),
+        PartitionField(2, 1000, IdentityTransform(), "void_partition_id"),
         PartitionField(2, 1001, IdentityTransform(), "partition_id"),
     )
 
@@ -1219,7 +1220,8 @@
         content=DataFileContent.DATA,
         file_path=file_loc,
         file_format=FileFormat.PARQUET,
-        partition=Record(void_partition_id=None, partition_id=1),
+        # projected value
+        partition=Record(void_partition_id=12, partition_id=1),
         file_size_in_bytes=os.path.getsize(file_loc),
         sort_order_id=None,
         spec_id=table.metadata.default_spec_id,
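
Note: the sketch below is a minimal, self-contained illustration of the contract this patch gives _get_column_projection_values, not pyiceberg's actual API. The helper now computes the field-id diff itself and returns a (should_project, missing_fields) tuple, so _task_to_record_batches no longer pre-computes the diff or branches before calling it. The parameters partition_values and partition_names are hypothetical stand-ins for what pyiceberg derives from DataFile.partition and the PartitionSpec accessors.

from typing import Any, Dict, Set, Tuple


def get_column_projection_values_sketch(
    projected_field_ids: Set[int],
    file_field_ids: Set[int],
    partition_values: Dict[int, Any],  # field_id -> partition value (stand-in for file.partition)
    partition_names: Dict[int, str],  # field_id -> column name (stand-in for the partition spec)
) -> Tuple[bool, Dict[str, Any]]:
    # Field IDs the reader asked for that are absent from the data file's own schema.
    project_schema_diff = projected_field_ids.difference(file_field_ids)
    if not project_schema_diff:
        # Nothing missing: the caller can skip the set_column() pass entirely.
        return False, {}
    # Resolve each missing field from partition metadata (identity transforms only).
    projected_missing_fields = {
        partition_names[fid]: partition_values[fid] for fid in project_schema_diff if fid in partition_values
    }
    return True, projected_missing_fields


# Field 2 is missing from the file, so its value is projected from partition metadata.
assert get_column_projection_values_sketch({1, 2}, {1}, {2: 1}, {2: "partition_id"}) == (True, {"partition_id": 1})
# All requested fields are present in the file: no projection work needed.
assert get_column_projection_values_sketch({1}, {1}, {}, {}) == (False, {})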