Implement column projection #1443

Open · wants to merge 8 commits into base: main
54 changes: 50 additions & 4 deletions pyiceberg/io/pyarrow.py
@@ -129,6 +129,7 @@
SchemaVisitorPerPrimitiveType,
SchemaWithPartnerVisitor,
_check_schema_compatible,
build_position_accessors,
pre_order_visit,
promote,
prune_columns,
@@ -138,7 +139,7 @@
)
from pyiceberg.table.metadata import TableMetadata
from pyiceberg.table.name_mapping import NameMapping, apply_name_mapping
from pyiceberg.transforms import TruncateTransform
from pyiceberg.transforms import IdentityTransform, TruncateTransform
from pyiceberg.typedef import EMPTY_DICT, Properties, Record
from pyiceberg.types import (
BinaryType,
@@ -1216,6 +1217,32 @@ def _field_id(self, field: pa.Field) -> int:
return -1


def _get_column_projection_values(
file: DataFile,
projected_schema: Schema,
project_schema_diff: Set[int],
partition_spec: PartitionSpec,
) -> Dict[str, object]:
"""Apply Column Projection rules to File Schema."""
projected_missing_fields: Dict[str, Any] = {}

partition_schema = partition_spec.partition_type(projected_schema)
accessors = build_position_accessors(partition_schema)

for field_id in project_schema_diff:
for partition_field in partition_spec.fields_by_source_id(field_id):
if isinstance(partition_field.transform, IdentityTransform):
accessor = accessors.get(partition_field.field_id)

if accessor is None:
continue

if partition_value := accessor.get(file.partition):
projected_missing_fields[partition_field.name] = partition_value

return projected_missing_fields
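
As a rough standalone illustration of what this helper relies on (not code from the PR), the accessor lookup for an identity-transformed partition value works roughly like the sketch below; the schema, spec, field IDs, and partition record are made-up examples that mirror the tests further down:

from pyiceberg.partitioning import PartitionField, PartitionSpec
from pyiceberg.schema import Schema, build_position_accessors
from pyiceberg.transforms import IdentityTransform
from pyiceberg.typedef import Record
from pyiceberg.types import IntegerType, NestedField, StringType

# Hypothetical table schema: "partition_id" is identity-partitioned, while the
# data file itself only contains "other_field".
schema = Schema(
    NestedField(1, "other_field", StringType(), required=False),
    NestedField(2, "partition_id", IntegerType(), required=False),
)
spec = PartitionSpec(PartitionField(2, 1000, IdentityTransform(), "partition_id"))

# The partition struct type drives the accessors, which are keyed by the
# partition field id (1000 here), exactly as in the helper above.
partition_schema = spec.partition_type(schema)
accessors = build_position_accessors(partition_schema)

# The missing value lives in the manifest entry's partition record, e.g.
# DataFile.partition == Record(partition_id=1); the accessor reads it by position.
value = accessors[1000].get(Record(partition_id=1))
print(value)  # 1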


def _task_to_record_batches(
fs: FileSystem,
task: FileScanTask,
@@ -1226,6 +1253,7 @@
case_sensitive: bool,
name_mapping: Optional[NameMapping] = None,
use_large_types: bool = True,
partition_spec: Optional[PartitionSpec] = None,
) -> Iterator[pa.RecordBatch]:
_, _, path = _parse_location(task.file.file_path)
arrow_format = ds.ParquetFileFormat(pre_buffer=True, buffer_size=(ONE_MEGABYTE * 8))
@@ -1237,16 +1265,26 @@
# When V3 support is introduced, we will update `downcast_ns_timestamp_to_us` flag based on
# the table format version.
file_schema = pyarrow_to_schema(physical_schema, name_mapping, downcast_ns_timestamp_to_us=True)

pyarrow_filter = None
if bound_row_filter is not AlwaysTrue():
translated_row_filter = translate_column_names(bound_row_filter, file_schema, case_sensitive=case_sensitive)
bound_file_filter = bind(file_schema, translated_row_filter, case_sensitive=case_sensitive)
pyarrow_filter = expression_to_pyarrow(bound_file_filter)

# Apply column projection rules for missing partitions and default values
# https://iceberg.apache.org/spec/#column-projection
file_project_schema = prune_columns(file_schema, projected_field_ids, select_full_types=False)

if file_schema is None:
raise ValueError(f"Missing Iceberg schema in Metadata for file: {path}")
project_schema_diff = projected_field_ids.difference(file_project_schema.field_ids)
should_project_columns = len(project_schema_diff) > 0

projected_missing_fields = {}

if should_project_columns and partition_spec is not None:
projected_missing_fields = _get_column_projection_values(
task.file, projected_schema, project_schema_diff, partition_spec
)
Comment on lines +1279 to +1287 (Contributor):
Nit: wdyt about structuring the code like this?

Suggested change: replace

    project_schema_diff = projected_field_ids.difference(file_project_schema.field_ids)
    should_project_columns = len(project_schema_diff) > 0

    projected_missing_fields = {}

    if should_project_columns and partition_spec is not None:
        projected_missing_fields = _get_column_projection_values(
            task.file, projected_schema, project_schema_diff, partition_spec
        )

with

    should_project_columns, projected_missing_fields = _get_column_projection_values(
        task.file, projected_schema, partition_spec
    )

and in _get_column_projection_values, move the rest of the logic

def _get_column_projection_values(...):
    project_schema_diff = projected_field_ids.difference(file_project_schema.field_ids)
    should_project_columns = len(project_schema_diff) > 0
    projected_missing_fields = {}
    if not should_project_columns:
        return False, {}
    ...
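
For reference, a minimal sketch of how that restructuring could look if adopted; this is not the code in this PR, it assumes the helper also receives the pruned file schema's field IDs so it can compute the diff itself, and it reuses the names pyarrow.py already imports (plus Tuple from typing):

from typing import Any, Dict, Optional, Set, Tuple  # typing names used by this sketch

def _get_column_projection_values(
    file: DataFile,
    projected_schema: Schema,
    partition_spec: Optional[PartitionSpec],
    file_project_field_ids: Set[int],
) -> Tuple[bool, Dict[str, Any]]:
    """Apply column projection rules (https://iceberg.apache.org/spec/#column-projection)."""
    # Compute which projected fields are absent from the file's schema here,
    # instead of at the call site.
    project_schema_diff = projected_schema.field_ids.difference(file_project_field_ids)
    if len(project_schema_diff) == 0 or partition_spec is None:
        return False, {}

    partition_schema = partition_spec.partition_type(projected_schema)
    accessors = build_position_accessors(partition_schema)

    projected_missing_fields: Dict[str, Any] = {}
    for field_id in project_schema_diff:
        for partition_field in partition_spec.fields_by_source_id(field_id):
            if isinstance(partition_field.transform, IdentityTransform):
                accessor = accessors.get(partition_field.field_id)
                if accessor is not None and (partition_value := accessor.get(file.partition)) is not None:
                    projected_missing_fields[partition_field.name] = partition_value

    return True, projected_missing_fields

# The call site in _task_to_record_batches would then collapse to:
#     should_project_columns, projected_missing_fields = _get_column_projection_values(
#         task.file, projected_schema, partition_spec, file_project_schema.field_ids
#     )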


fragment_scanner = ds.Scanner.from_fragment(
fragment=fragment,
@@ -1286,14 +1324,21 @@ def _task_to_record_batches(
continue
output_batches = arrow_table.to_batches()
for output_batch in output_batches:
yield _to_requested_schema(
result_batch = _to_requested_schema(
projected_schema,
file_project_schema,
output_batch,
downcast_ns_timestamp_to_us=True,
use_large_types=use_large_types,
)

# Inject projected column values if available
if should_project_columns:
for name, value in projected_missing_fields.items():
result_batch = result_batch.set_column(result_batch.schema.get_field_index(name), name, [value])

yield result_batch
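
As a small aside, the injection above leans on pyarrow's schema.get_field_index and RecordBatch.set_column; a toy, standalone illustration (assuming a pyarrow version that provides RecordBatch.set_column, with made-up values on a one-row batch) looks like this:

import pyarrow as pa

# A one-row batch as it might look after _to_requested_schema: the projected
# column is present in the schema but holds only nulls because the file lacks it.
batch = pa.RecordBatch.from_pydict(
    {"other_field": ["foo"], "partition_id": pa.array([None], type=pa.int64())}
)

# Replace the placeholder column with the value recovered from manifest metadata.
idx = batch.schema.get_field_index("partition_id")
batch = batch.set_column(idx, "partition_id", [1])

print(batch.to_pydict())  # {'other_field': ['foo'], 'partition_id': [1]}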


def _task_to_table(
fs: FileSystem,
@@ -1517,6 +1562,7 @@ def _record_batches_from_scan_tasks_and_deletes(
self._case_sensitive,
self._table_metadata.name_mapping(),
self._use_large_types,
self._table_metadata.spec(),
)
for batch in batches:
if self._limit is not None:
128 changes: 127 additions & 1 deletion tests/io/test_pyarrow.py
@@ -69,15 +69,19 @@
_read_deletes,
_to_requested_schema,
bin_pack_arrow_table,
compute_statistics_plan,
data_file_statistics_from_parquet_metadata,
expression_to_pyarrow,
parquet_path_to_id_mapping,
schema_to_pyarrow,
)
from pyiceberg.manifest import DataFile, DataFileContent, FileFormat
from pyiceberg.partitioning import PartitionField, PartitionSpec
from pyiceberg.schema import Schema, make_compatible_name, visit
from pyiceberg.table import FileScanTask, TableProperties
from pyiceberg.table.metadata import TableMetadataV2
from pyiceberg.transforms import IdentityTransform
from pyiceberg.table.name_mapping import create_mapping_from_schema
from pyiceberg.transforms import IdentityTransform, VoidTransform
from pyiceberg.typedef import UTF8, Properties, Record
from pyiceberg.types import (
BinaryType,
@@ -99,6 +103,7 @@
TimestamptzType,
TimeType,
)
from tests.catalog.test_base import InMemoryCatalog
from tests.conftest import UNIFIED_AWS_SESSION_PROPERTIES


@@ -1122,6 +1127,127 @@ def test_projection_concat_files(schema_int: Schema, file_int: str) -> None:
assert repr(result_table.schema) == "id: int32"


def test_projection_single_partition_value(tmp_path: str, catalog: InMemoryCatalog) -> None:
# Test by adding a non-partitioned data file to a partitioned table, verifying partition value projection from manifest metadata.
# TODO: Update to use a data file created by writing data to an unpartitioned table once add_files supports field IDs.
# (context: https://github.com/apache/iceberg-python/pull/1443#discussion_r1901374875)

schema = Schema(
NestedField(1, "other_field", StringType(), required=False), NestedField(2, "partition_id", IntegerType(), required=False)
)

partition_spec = PartitionSpec(
PartitionField(2, 1000, IdentityTransform(), "partition_id"),
)

table = catalog.create_table(
"default.test_projection_partition",
schema=schema,
partition_spec=partition_spec,
properties={TableProperties.DEFAULT_NAME_MAPPING: create_mapping_from_schema(schema).model_dump_json()},
)

file_data = pa.array(["foo"], type=pa.string())
file_loc = f"{tmp_path}/test.parquet"
pq.write_table(pa.table([file_data], names=["other_field"]), file_loc)

statistics = data_file_statistics_from_parquet_metadata(
parquet_metadata=pq.read_metadata(file_loc),
stats_columns=compute_statistics_plan(table.schema(), table.metadata.properties),
parquet_column_mapping=parquet_path_to_id_mapping(table.schema()),
)

unpartitioned_file = DataFile(
content=DataFileContent.DATA,
file_path=file_loc,
file_format=FileFormat.PARQUET,
partition=Record(partition_id=1),
file_size_in_bytes=os.path.getsize(file_loc),
sort_order_id=None,
spec_id=table.metadata.default_spec_id,
equality_ids=None,
key_metadata=None,
**statistics.to_serialized_dict(),
)

with table.transaction() as transaction:
with transaction.update_snapshot().overwrite() as update:
update.append_data_file(unpartitioned_file)

assert (
str(table.scan().to_arrow())
== """pyarrow.Table
other_field: large_string
partition_id: int64
----
other_field: [["foo"]]
partition_id: [[1]]"""
)


def test_projection_multiple_partition_values(tmp_path: str, catalog: InMemoryCatalog) -> None:
# Test by adding a non-partitioned data file to a multi-partitioned table, verifying partition value projection from manifest metadata.
# TODO: Update to use a data file created by writing data to an unpartitioned table once add_files supports field IDs.
# (context: https://github.com/apache/iceberg-python/pull/1443#discussion_r1901374875)
schema = Schema(
NestedField(1, "other_field", StringType(), required=False), NestedField(2, "partition_id", IntegerType(), required=False)
)

partition_spec = PartitionSpec(
PartitionField(2, 1000, VoidTransform(), "void_partition_id"),
PartitionField(2, 1001, IdentityTransform(), "partition_id"),
)
Comment on lines +1196 to +1199 (Contributor):
I think we'd want to test multiple IdentityTransforms here.

I'm thinking about a case with multiple levels of hive-style partitioning:

s3://my_table/a=100/b=foo/...parquet

I think _get_column_projection_values might not support this right now.
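
To make that hive-style case concrete, a hypothetical two-level identity spec might look like the sketch below; the field IDs, names, and path are illustrative and not part of this PR's tests (it reuses the imports already present in this test module):

# Hypothetical schema and spec for a table partitioned hive-style on two columns.
hive_schema = Schema(
    NestedField(1, "other_field", StringType(), required=False),
    NestedField(2, "a", IntegerType(), required=False),
    NestedField(3, "b", StringType(), required=False),
)
hive_spec = PartitionSpec(
    PartitionField(2, 1000, IdentityTransform(), "a"),
    PartitionField(3, 1001, IdentityTransform(), "b"),
)
# A manifest entry for a file written under s3://my_table/a=100/b=foo/ would carry
# partition=Record(a=100, b="foo"), and such a test would assert that both "a" and
# "b" get projected into the scan result.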


table = catalog.create_table(
"default.test_projection_partitions",
schema=schema,
partition_spec=partition_spec,
properties={TableProperties.DEFAULT_NAME_MAPPING: create_mapping_from_schema(schema).model_dump_json()},
)

file_data = pa.array(["foo"], type=pa.string())
file_loc = f"{tmp_path}/test.parquet"
pq.write_table(pa.table([file_data], names=["other_field"]), file_loc)

statistics = data_file_statistics_from_parquet_metadata(
parquet_metadata=pq.read_metadata(file_loc),
stats_columns=compute_statistics_plan(table.schema(), table.metadata.properties),
parquet_column_mapping=parquet_path_to_id_mapping(table.schema()),
)

unpartitioned_file = DataFile(
content=DataFileContent.DATA,
file_path=file_loc,
file_format=FileFormat.PARQUET,
partition=Record(void_partition_id=None, partition_id=1),
file_size_in_bytes=os.path.getsize(file_loc),
sort_order_id=None,
spec_id=table.metadata.default_spec_id,
equality_ids=None,
key_metadata=None,
**statistics.to_serialized_dict(),
)

with table.transaction() as transaction:
with transaction.update_snapshot().overwrite() as update:
update.append_data_file(unpartitioned_file)

assert (
str(table.scan().to_arrow())
== """pyarrow.Table
other_field: large_string
partition_id: int64
----
other_field: [["foo"]]
partition_id: [[1]]"""
)


@pytest.fixture
def catalog() -> InMemoryCatalog:
return InMemoryCatalog("test.in_memory.catalog", **{"test.key": "test.value"})


def test_projection_filter(schema_int: Schema, file_int: str) -> None:
result_table = project(schema_int, [file_int], GreaterThan("id", 4))
assert len(result_table.columns[0]) == 0