
Commit 4333dc0
Check for name before injecting, fix lint issues
Gabriel Igliozzi committed Jan 25, 2025
1 parent 8362803 commit 4333dc0
Showing 2 changed files with 31 additions and 25 deletions.
pyiceberg/io/pyarrow.py (44 changes: 24 additions & 20 deletions)
@@ -123,6 +123,7 @@
 )
 from pyiceberg.partitioning import PartitionField, PartitionFieldValue, PartitionKey, PartitionSpec, partition_record_value
 from pyiceberg.schema import (
+    Accessor,
     PartnerAccessor,
     PreOrderSchemaVisitor,
     Schema,
@@ -1218,16 +1219,24 @@ def _field_id(self, field: pa.Field) -> int:


 def _get_column_projection_values(
-    file: DataFile,
-    projected_schema: Schema,
-    project_schema_diff: Set[int],
-    partition_spec: PartitionSpec,
-) -> Dict[str, object]:
+    file: DataFile, projected_schema: Schema, partition_spec: Optional[PartitionSpec], file_project_field_ids: Set[int]
+) -> Tuple[bool, Dict[str, Any]]:
     """Apply Column Projection rules to File Schema."""
+    project_schema_diff = projected_schema.field_ids.difference(file_project_field_ids)
+    should_project_columns = len(project_schema_diff) > 0
     projected_missing_fields: Dict[str, Any] = {}

-    partition_schema = partition_spec.partition_type(projected_schema)
-    accessors = build_position_accessors(partition_schema)
+    if not should_project_columns:
+        return False, {}
+
+    partition_schema: StructType
+    accessors: Dict[int, Accessor]
+
+    if partition_spec is not None:
+        partition_schema = partition_spec.partition_type(projected_schema)
+        accessors = build_position_accessors(partition_schema)
+    else:
+        return False, {}

     for field_id in project_schema_diff:
         for partition_field in partition_spec.fields_by_source_id(field_id):
@@ -1240,7 +1249,7 @@ def _get_column_projection_values(
                 if partition_value := accesor.get(file.partition):
                     projected_missing_fields[partition_field.name] = partition_value

-    return projected_missing_fields
+    return True, projected_missing_fields


 def _task_to_record_batches(
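
Note: the projection decision that moved into this helper is a plain set
difference between the field IDs the reader requested and those present in
the data file. A standalone, runnable sketch with hypothetical field IDs (not
taken from the commit):

    # Hypothetical field IDs: the reader wants 1, 2 and 1000; the file only has 1 and 2.
    projected_field_ids = {1, 2, 1000}
    file_project_field_ids = {1, 2}

    project_schema_diff = projected_field_ids.difference(file_project_field_ids)
    should_project_columns = len(project_schema_diff) > 0

    assert project_schema_diff == {1000}  # field 1000 is absent from the file,
    assert should_project_columns         # so its value must come from metadata
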
@@ -1272,19 +1281,12 @@ def _task_to_record_batches(
             bound_file_filter = bind(file_schema, translated_row_filter, case_sensitive=case_sensitive)
             pyarrow_filter = expression_to_pyarrow(bound_file_filter)

-        # Apply column projection rules for missing partitions and default values
+        # Apply column projection rules
         # https://iceberg.apache.org/spec/#column-projection
         file_project_schema = prune_columns(file_schema, projected_field_ids, select_full_types=False)
-
-        project_schema_diff = projected_field_ids.difference(file_project_schema.field_ids)
-        should_project_columns = len(project_schema_diff) > 0
-
-        projected_missing_fields = {}
-
-        if should_project_columns and partition_spec is not None:
-            projected_missing_fields = _get_column_projection_values(
-                task.file, projected_schema, project_schema_diff, partition_spec
-            )
+        should_project_columns, projected_missing_fields = _get_column_projection_values(
+            task.file, projected_schema, partition_spec, file_project_schema.field_ids
+        )

         fragment_scanner = ds.Scanner.from_fragment(
             fragment=fragment,
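
Note: folding both the partition_spec None-check and the diff computation into
the helper keeps the call site to a single tuple-unpacked statement. A generic,
runnable sketch of that (flag, payload) contract (names and values
hypothetical, not the pyiceberg API):

    from typing import Any, Dict, Optional, Set, Tuple

    def projection_values(spec: Optional[str], requested: Set[int], in_file: Set[int]) -> Tuple[bool, Dict[str, Any]]:
        # Decision and payload live together, so no caller can
        # forget the None check or recompute the difference.
        missing = requested - in_file
        if spec is None or not missing:
            return False, {}
        return True, {f"field_{i}": None for i in sorted(missing)}  # placeholder values

    should_project, values = projection_values("some-spec", {1, 1000}, {1})
    assert should_project and list(values) == ["field_1000"]
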
@@ -1335,7 +1337,9 @@ def _task_to_record_batches(

             # Inject projected column values if available
             if should_project_columns:
                 for name, value in projected_missing_fields.items():
-                    result_batch = result_batch.set_column(result_batch.schema.get_field_index(name), name, [value])
+                    index = result_batch.schema.get_field_index(name)
+                    if index != -1:
+                        result_batch = result_batch.set_column(index, name, [value])

             yield result_batch

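Note: this is the fix named in the commit title. pa.Schema.get_field_index
returns -1 when no field has the given name, and calling set_column with -1
would fail, so the index is checked before injecting. A minimal runnable
sketch (the batch and the projected value are hypothetical):

    import pyarrow as pa

    batch = pa.RecordBatch.from_pydict({"id": pa.array([1, 2, 3], type=pa.int32())})

    for name, value in {"partition_id": 7}.items():
        index = batch.schema.get_field_index(name)  # -1: no such column in this batch
        if index != -1:
            batch = batch.set_column(index, name, pa.array([value] * batch.num_rows))

    # The missing column is skipped instead of raising; the batch is unchanged.
    assert batch.schema.names == ["id"]
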
tests/io/test_pyarrow.py (12 changes: 7 additions & 5 deletions)
@@ -81,7 +81,7 @@
 from pyiceberg.table import FileScanTask, TableProperties
 from pyiceberg.table.metadata import TableMetadataV2
 from pyiceberg.table.name_mapping import create_mapping_from_schema
-from pyiceberg.transforms import IdentityTransform, VoidTransform
+from pyiceberg.transforms import IdentityTransform
 from pyiceberg.typedef import UTF8, Properties, Record
 from pyiceberg.types import (
     BinaryType,
@@ -1127,7 +1127,7 @@ def test_projection_concat_files(schema_int: Schema, file_int: str) -> None:
     assert repr(result_table.schema) == "id: int32"


-def test_projection_single_partition_value(tmp_path: str, catalog: InMemoryCatalog) -> None:
+def test_identity_transform_column_projection(tmp_path: str, catalog: InMemoryCatalog) -> None:
     # Test by adding a non-partitioned data file to a partitioned table, verifying partition value projection from manifest metadata.
     # TODO: Update to use a data file created by writing data to an unpartitioned table once add_files supports field IDs.
     # (context: https://github.com/apache/iceberg-python/pull/1443#discussion_r1901374875)
@@ -1161,6 +1161,7 @@ def test_projection_single_partition_value(tmp_path: str, catalog: InMemoryCatal
         content=DataFileContent.DATA,
         file_path=file_loc,
         file_format=FileFormat.PARQUET,
+        # projected value
         partition=Record(partition_id=1),
         file_size_in_bytes=os.path.getsize(file_loc),
         sort_order_id=None,
@@ -1185,7 +1186,7 @@ def test_projection_single_partition_value(tmp_path: str, catalog: InMemoryCatal
     )


-def test_projection_multiple_partition_values(tmp_path: str, catalog: InMemoryCatalog) -> None:
+def test_identity_transform_columns_projection(tmp_path: str, catalog: InMemoryCatalog) -> None:
     # Test by adding a non-partitioned data file to a multi-partitioned table, verifying partition value projection from manifest metadata.
     # TODO: Update to use a data file created by writing data to an unpartitioned table once add_files supports field IDs.
     # (context: https://github.com/apache/iceberg-python/pull/1443#discussion_r1901374875)
@@ -1194,7 +1195,7 @@ def test_projection_multiple_partition_values(tmp_path: str, catalog: InMemoryCa
     )

     partition_spec = PartitionSpec(
-        PartitionField(2, 1000, VoidTransform(), "void_partition_id"),
+        PartitionField(2, 1000, IdentityTransform(), "void_partition_id"),
         PartitionField(2, 1001, IdentityTransform(), "partition_id"),
     )

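Note: PartitionField is constructed positionally above; the parameter order is
(source_id, field_id, transform, name), so both partition fields derive from
the same source column (field ID 2) under distinct partition field IDs. The
equivalent keyword form, assuming pyiceberg's models accept their field names
as keywords:

    from pyiceberg.partitioning import PartitionField, PartitionSpec
    from pyiceberg.transforms import IdentityTransform

    partition_spec = PartitionSpec(
        PartitionField(source_id=2, field_id=1000, transform=IdentityTransform(), name="void_partition_id"),
        PartitionField(source_id=2, field_id=1001, transform=IdentityTransform(), name="partition_id"),
    )
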
@@ -1219,7 +1220,8 @@ def test_projection_multiple_partition_values(tmp_path: str, catalog: InMemoryCa
         content=DataFileContent.DATA,
         file_path=file_loc,
         file_format=FileFormat.PARQUET,
-        partition=Record(void_partition_id=None, partition_id=1),
+        # projected value
+        partition=Record(void_partition_id=12, partition_id=1),
         file_size_in_bytes=os.path.getsize(file_loc),
         sort_order_id=None,
         spec_id=table.metadata.default_spec_id,
