Commit

Add should_project check, add lookup by accessor, multiple-partition test
Gabriel Igliozzi committed Jan 17, 2025
1 parent fee24ab commit 8362803
Showing 2 changed files with 97 additions and 18 deletions.
43 changes: 29 additions & 14 deletions pyiceberg/io/pyarrow.py
@@ -129,6 +129,7 @@
     SchemaVisitorPerPrimitiveType,
     SchemaWithPartnerVisitor,
     _check_schema_compatible,
+    build_position_accessors,
     pre_order_visit,
     promote,
     prune_columns,
@@ -1219,18 +1220,25 @@ def _field_id(self, field: pa.Field) -> int:
 def _get_column_projection_values(
     file: DataFile,
     projected_schema: Schema,
-    projected_field_ids: Set[int],
-    file_project_schema: Schema,
-    partition_spec: Optional[PartitionSpec] = None,
+    project_schema_diff: Set[int],
+    partition_spec: PartitionSpec,
 ) -> Dict[str, object]:
     """Apply Column Projection rules to File Schema."""
-    projected_missing_fields = {}
+    projected_missing_fields: Dict[str, Any] = {}
+
+    partition_schema = partition_spec.partition_type(projected_schema)
+    accessors = build_position_accessors(partition_schema)
+
+    for field_id in project_schema_diff:
+        for partition_field in partition_spec.fields_by_source_id(field_id):
+            if isinstance(partition_field.transform, IdentityTransform):
+                accesor = accessors.get(partition_field.field_id)

-    for field_id in projected_field_ids.difference(file_project_schema.field_ids):
-        if partition_spec is not None:
-            for partition_field in partition_spec.fields_by_source_id(field_id):
-                if isinstance(partition_field.transform, IdentityTransform) and partition_field.name in file.partition.__dict__:
-                    projected_missing_fields[partition_field.name] = file.partition.__dict__[partition_field.name]
+                if accesor is None:
+                    continue
+
+                if partition_value := accesor.get(file.partition):
+                    projected_missing_fields[partition_field.name] = partition_value

     return projected_missing_fields

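Note on the lookup by accessor (illustrative sketch, not part of this commit): build_position_accessors maps each partition field ID to a positional accessor over the file's partition Record, replacing the earlier lookup through file.partition.__dict__ by field name. A minimal example, assuming pyiceberg's Schema, PartitionSpec, and Record APIs behave as they are used in the diff above:

# Illustrative sketch only; mirrors the spec and schema used in the tests below.
from pyiceberg.partitioning import PartitionField, PartitionSpec
from pyiceberg.schema import Schema, build_position_accessors
from pyiceberg.transforms import IdentityTransform
from pyiceberg.typedef import Record
from pyiceberg.types import IntegerType, NestedField, StringType

schema = Schema(
    NestedField(1, "other_field", StringType(), required=False),
    NestedField(2, "partition_id", IntegerType(), required=False),
)
spec = PartitionSpec(PartitionField(2, 1000, IdentityTransform(), "partition_id"))

# The partition type is a struct keyed by partition field IDs (1000 here),
# so accessors are looked up by field ID rather than by attribute name.
partition_schema = spec.partition_type(schema)
accessors = build_position_accessors(partition_schema)

accessor = accessors[1000]
print(accessor.get(Record(partition_id=42)))  # 42
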
@@ -1268,9 +1276,15 @@ def _task_to_record_batches(
         # https://iceberg.apache.org/spec/#column-projection
         file_project_schema = prune_columns(file_schema, projected_field_ids, select_full_types=False)

-        projected_missing_fields = _get_column_projection_values(
-            task.file, projected_schema, projected_field_ids, file_project_schema, partition_spec
-        )
+        project_schema_diff = projected_field_ids.difference(file_project_schema.field_ids)
+        should_project_columns = len(project_schema_diff) > 0
+
+        projected_missing_fields = {}
+
+        if should_project_columns and partition_spec is not None:
+            projected_missing_fields = _get_column_projection_values(
+                task.file, projected_schema, project_schema_diff, partition_spec
+            )

         fragment_scanner = ds.Scanner.from_fragment(
             fragment=fragment,
@@ -1319,8 +1333,9 @@ def _task_to_record_batches(
             )

             # Inject projected column values if available
-            for name, value in projected_missing_fields.items():
-                result_batch = result_batch.set_column(result_batch.schema.get_field_index(name), name, [value])
+            if should_project_columns:
+                for name, value in projected_missing_fields.items():
+                    result_batch = result_batch.set_column(result_batch.schema.get_field_index(name), name, [value])

             yield result_batch

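Note on the should_project check (illustrative sketch, not part of this commit): the flag only becomes true when the query requests field IDs that the pruned file schema does not contain, so the per-batch set_column injection is skipped entirely for fully materialized files. A minimal example with hypothetical field-ID sets:

# Illustrative sketch only; the names mirror the diff above.
projected_field_ids = {1, 2}        # field IDs requested by the projected table schema
file_project_field_ids = {1}        # field IDs actually present in the pruned file schema

project_schema_diff = projected_field_ids.difference(file_project_field_ids)
should_project_columns = len(project_schema_diff) > 0

print(project_schema_diff, should_project_columns)  # {2} True -> fill field 2 from partition metadata
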
72 changes: 68 additions & 4 deletions tests/io/test_pyarrow.py
@@ -81,7 +81,7 @@
 from pyiceberg.table import FileScanTask, TableProperties
 from pyiceberg.table.metadata import TableMetadataV2
 from pyiceberg.table.name_mapping import create_mapping_from_schema
-from pyiceberg.transforms import IdentityTransform
+from pyiceberg.transforms import IdentityTransform, VoidTransform
 from pyiceberg.typedef import UTF8, Properties, Record
 from pyiceberg.types import (
     BinaryType,
@@ -1127,15 +1127,21 @@ def test_projection_concat_files(schema_int: Schema, file_int: str) -> None:
     assert repr(result_table.schema) == "id: int32"


-def test_projection_partition_inference(tmp_path: str, catalog: InMemoryCatalog) -> None:
+def test_projection_single_partition_value(tmp_path: str, catalog: InMemoryCatalog) -> None:
+    # Test by adding a non-partitioned data file to a partitioned table, verifying partition value projection from manifest metadata.
+    # TODO: Update to use a data file created by writing data to an unpartitioned table once add_files supports field IDs.
+    # (context: https://github.com/apache/iceberg-python/pull/1443#discussion_r1901374875)
+
     schema = Schema(
         NestedField(1, "other_field", StringType(), required=False), NestedField(2, "partition_id", IntegerType(), required=False)
     )

-    partition_spec = PartitionSpec(PartitionField(2, 1000, IdentityTransform(), "partition_id"))
+    partition_spec = PartitionSpec(
+        PartitionField(2, 1000, IdentityTransform(), "partition_id"),
+    )

     table = catalog.create_table(
-        "default.test_projection_partition_inference",
+        "default.test_projection_partition",
         schema=schema,
         partition_spec=partition_spec,
         properties={TableProperties.DEFAULT_NAME_MAPPING: create_mapping_from_schema(schema).model_dump_json()},
@@ -1179,6 +1185,64 @@ def test_projection_partition_inference(tmp_path: str, catalog: InMemoryCatalog)
     )


+def test_projection_multiple_partition_values(tmp_path: str, catalog: InMemoryCatalog) -> None:
+    # Test by adding a non-partitioned data file to a multi-partitioned table, verifying partition value projection from manifest metadata.
+    # TODO: Update to use a data file created by writing data to an unpartitioned table once add_files supports field IDs.
+    # (context: https://github.com/apache/iceberg-python/pull/1443#discussion_r1901374875)
+    schema = Schema(
+        NestedField(1, "other_field", StringType(), required=False), NestedField(2, "partition_id", IntegerType(), required=False)
+    )
+
+    partition_spec = PartitionSpec(
+        PartitionField(2, 1000, VoidTransform(), "void_partition_id"),
+        PartitionField(2, 1001, IdentityTransform(), "partition_id"),
+    )
+
+    table = catalog.create_table(
+        "default.test_projection_partitions",
+        schema=schema,
+        partition_spec=partition_spec,
+        properties={TableProperties.DEFAULT_NAME_MAPPING: create_mapping_from_schema(schema).model_dump_json()},
+    )
+
+    file_data = pa.array(["foo"], type=pa.string())
+    file_loc = f"{tmp_path}/test.parquet"
+    pq.write_table(pa.table([file_data], names=["other_field"]), file_loc)
+
+    statistics = data_file_statistics_from_parquet_metadata(
+        parquet_metadata=pq.read_metadata(file_loc),
+        stats_columns=compute_statistics_plan(table.schema(), table.metadata.properties),
+        parquet_column_mapping=parquet_path_to_id_mapping(table.schema()),
+    )
+
+    unpartitioned_file = DataFile(
+        content=DataFileContent.DATA,
+        file_path=file_loc,
+        file_format=FileFormat.PARQUET,
+        partition=Record(void_partition_id=None, partition_id=1),
+        file_size_in_bytes=os.path.getsize(file_loc),
+        sort_order_id=None,
+        spec_id=table.metadata.default_spec_id,
+        equality_ids=None,
+        key_metadata=None,
+        **statistics.to_serialized_dict(),
+    )
+
+    with table.transaction() as transaction:
+        with transaction.update_snapshot().overwrite() as update:
+            update.append_data_file(unpartitioned_file)
+
+    assert (
+        str(table.scan().to_arrow())
+        == """pyarrow.Table
+other_field: large_string
+partition_id: int64
+----
+other_field: [["foo"]]
+partition_id: [[1]]"""
+    )
+
+
 @pytest.fixture
 def catalog() -> InMemoryCatalog:
     return InMemoryCatalog("test.in_memory.catalog", **{"test.key": "test.value"})
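Note on the multiple-partition test (illustrative sketch, not part of this commit): the spec derives two partition fields from the same source column, but only the identity-transformed field can be projected back as a column value; the void-transformed field is skipped by the isinstance(..., IdentityTransform) check. Assuming pyiceberg's keyword-argument Record as used in the test:

# Illustrative sketch only: the partition record written for the added data file.
from pyiceberg.typedef import Record

partition_record = Record(void_partition_id=None, partition_id=1)

# VoidTransform always produces None, so field 1000 contributes nothing;
# the IdentityTransform field (1001) still carries the original source value.
print(partition_record[0], partition_record[1])  # None 1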
