Implement column projection #1443

Open

wants to merge 9 commits into base: main
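This change implements the column-projection rule from the Iceberg spec (https://iceberg.apache.org/spec/#column-projection): when a query selects a field that is not present in a data file but is identity-partitioned, the reader can reconstruct the column from the partition tuple stored in the manifest. Below is a minimal, self-contained sketch of that idea using only pyarrow; the function and column names are illustrative and not part of this PR.

import pyarrow as pa

def fill_identity_partition_columns(
    file_table: pa.Table,
    partition_values: dict,
    selected_columns: list,
) -> pa.Table:
    # For every selected column missing from the file, repeat the single
    # partition value once per row, which is exactly what an identity
    # transform preserves.
    for name in selected_columns:
        if name not in file_table.column_names and name in partition_values:
            value = partition_values[name]
            file_table = file_table.append_column(name, pa.array([value] * len(file_table)))
    return file_table

# The file only contains `other_field`, but the scan also selects `partition_id`.
file_table = pa.table({"other_field": ["foo", "bar"]})
projected = fill_identity_partition_columns(file_table, {"partition_id": 1}, ["other_field", "partition_id"])
print(projected.column_names)  # ['other_field', 'partition_id']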
39 changes: 35 additions & 4 deletions pyiceberg/io/pyarrow.py
@@ -138,7 +138,7 @@
)
from pyiceberg.table.metadata import TableMetadata
from pyiceberg.table.name_mapping import NameMapping, apply_name_mapping
from pyiceberg.transforms import TruncateTransform
from pyiceberg.transforms import IdentityTransform, TruncateTransform
from pyiceberg.typedef import EMPTY_DICT, Properties, Record
from pyiceberg.types import (
BinaryType,
@@ -1216,6 +1216,25 @@ def _field_id(self, field: pa.Field) -> int:
return -1


def _get_column_projection_values(
file: DataFile,
projected_schema: Schema,
projected_field_ids: Set[int],
file_project_schema: Schema,
partition_spec: Optional[PartitionSpec] = None,
) -> Dict[str, object]:
"""Apply Column Projection rules to File Schema."""
projected_missing_fields = {}

for field_id in projected_field_ids.difference(file_project_schema.field_ids):
if partition_spec is not None:
for partition_field in partition_spec.fields_by_source_id(field_id):
if isinstance(partition_field.transform, IdentityTransform) and partition_field.name in file.partition.__dict__:
projected_missing_fields[partition_field.name] = file.partition.__dict__[partition_field.name]

return projected_missing_fields


def _task_to_record_batches(
fs: FileSystem,
task: FileScanTask,
@@ -1226,6 +1245,7 @@
case_sensitive: bool,
name_mapping: Optional[NameMapping] = None,
use_large_types: bool = True,
partition_spec: Optional[PartitionSpec] = None,
) -> Iterator[pa.RecordBatch]:
_, _, path = _parse_location(task.file.file_path)
arrow_format = ds.ParquetFileFormat(pre_buffer=True, buffer_size=(ONE_MEGABYTE * 8))
@@ -1237,16 +1257,20 @@
# When V3 support is introduced, we will update `downcast_ns_timestamp_to_us` flag based on
# the table format version.
file_schema = pyarrow_to_schema(physical_schema, name_mapping, downcast_ns_timestamp_to_us=True)

pyarrow_filter = None
if bound_row_filter is not AlwaysTrue():
translated_row_filter = translate_column_names(bound_row_filter, file_schema, case_sensitive=case_sensitive)
bound_file_filter = bind(file_schema, translated_row_filter, case_sensitive=case_sensitive)
pyarrow_filter = expression_to_pyarrow(bound_file_filter)

# Apply column projection rules for missing partitions and default values
# https://iceberg.apache.org/spec/#column-projection
file_project_schema = prune_columns(file_schema, projected_field_ids, select_full_types=False)

if file_schema is None:
raise ValueError(f"Missing Iceberg schema in Metadata for file: {path}")
projected_missing_fields = _get_column_projection_values(
task.file, projected_schema, projected_field_ids, file_project_schema, partition_spec
)

fragment_scanner = ds.Scanner.from_fragment(
fragment=fragment,
@@ -1286,14 +1310,20 @@ def _task_to_record_batches(
continue
output_batches = arrow_table.to_batches()
for output_batch in output_batches:
yield _to_requested_schema(
result_batch = _to_requested_schema(
projected_schema,
file_project_schema,
output_batch,
downcast_ns_timestamp_to_us=True,
use_large_types=use_large_types,
)

# Inject projected column values if available
for name, value in projected_missing_fields.items():
result_batch = result_batch.set_column(result_batch.schema.get_field_index(name), name, [value])

yield result_batch


def _task_to_table(
fs: FileSystem,
@@ -1517,6 +1547,7 @@ def _record_batches_from_scan_tasks_and_deletes(
self._case_sensitive,
self._table_metadata.name_mapping(),
self._use_large_types,
self._table_metadata.spec(),
)
for batch in batches:
if self._limit is not None:
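For reference, the lookup that the new _get_column_projection_values helper performs can be reproduced standalone with the same classes this diff already uses (PartitionSpec, PartitionField, IdentityTransform, Record). This is a sketch with illustrative field IDs and names, not code taken from the PR.

from pyiceberg.partitioning import PartitionField, PartitionSpec
from pyiceberg.transforms import IdentityTransform
from pyiceberg.typedef import Record

# An identity partition on source field id 2, named "partition_id".
spec = PartitionSpec(PartitionField(2, 1000, IdentityTransform(), "partition_id"))
# The partition tuple that the manifest stores alongside the DataFile.
partition_record = Record(partition_id=1)

missing_field_id = 2  # selected by the query but absent from the Parquet file
projected_missing_fields = {}
for partition_field in spec.fields_by_source_id(missing_field_id):
    # Only identity transforms preserve the source value, so only they can be projected back.
    if isinstance(partition_field.transform, IdentityTransform) and partition_field.name in partition_record.__dict__:
        projected_missing_fields[partition_field.name] = partition_record.__dict__[partition_field.name]

print(projected_missing_fields)  # {'partition_id': 1}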
62 changes: 62 additions & 0 deletions tests/io/test_pyarrow.py
@@ -69,14 +69,18 @@
_read_deletes,
_to_requested_schema,
bin_pack_arrow_table,
compute_statistics_plan,
data_file_statistics_from_parquet_metadata,
expression_to_pyarrow,
parquet_path_to_id_mapping,
schema_to_pyarrow,
)
from pyiceberg.manifest import DataFile, DataFileContent, FileFormat
from pyiceberg.partitioning import PartitionField, PartitionSpec
from pyiceberg.schema import Schema, make_compatible_name, visit
from pyiceberg.table import FileScanTask, TableProperties
from pyiceberg.table.metadata import TableMetadataV2
from pyiceberg.table.name_mapping import create_mapping_from_schema
from pyiceberg.transforms import IdentityTransform
from pyiceberg.typedef import UTF8, Properties, Record
from pyiceberg.types import (
@@ -99,6 +103,7 @@
TimestamptzType,
TimeType,
)
from tests.catalog.test_base import InMemoryCatalog
from tests.conftest import UNIFIED_AWS_SESSION_PROPERTIES


@@ -1122,6 +1127,63 @@ def test_projection_concat_files(schema_int: Schema, file_int: str) -> None:
assert repr(result_table.schema) == "id: int32"


def test_projection_partition_inference(tmp_path: str, catalog: InMemoryCatalog) -> None:
schema = Schema(
NestedField(1, "other_field", StringType(), required=False), NestedField(2, "partition_id", IntegerType(), required=False)
)

partition_spec = PartitionSpec(PartitionField(2, 1000, IdentityTransform(), "partition_id"))

table = catalog.create_table(
"default.test_projection_partition_inference",
schema=schema,
partition_spec=partition_spec,
properties={TableProperties.DEFAULT_NAME_MAPPING: create_mapping_from_schema(schema).model_dump_json()},
)

file_data = pa.array(["foo"], type=pa.string())
file_loc = f"{tmp_path}/test.parquet"
pq.write_table(pa.table([file_data], names=["other_field"]), file_loc)

statistics = data_file_statistics_from_parquet_metadata(
parquet_metadata=pq.read_metadata(file_loc),
stats_columns=compute_statistics_plan(table.schema(), table.metadata.properties),
parquet_column_mapping=parquet_path_to_id_mapping(table.schema()),
)

unpartitioned_file = DataFile(
content=DataFileContent.DATA,
file_path=file_loc,
file_format=FileFormat.PARQUET,
partition=Record(partition_id=1),
file_size_in_bytes=os.path.getsize(file_loc),
sort_order_id=None,
spec_id=table.metadata.default_spec_id,
equality_ids=None,
key_metadata=None,
**statistics.to_serialized_dict(),
)

with table.transaction() as transaction:
with transaction.update_snapshot().overwrite() as update:
update.append_data_file(unpartitioned_file)

assert (
str(table.scan().to_arrow())
== """pyarrow.Table
other_field: large_string
partition_id: int64
----
other_field: [["foo"]]
partition_id: [[1]]"""
)


@pytest.fixture
def catalog() -> InMemoryCatalog:
return InMemoryCatalog("test.in_memory.catalog", **{"test.key": "test.value"})


def test_projection_filter(schema_int: Schema, file_int: str) -> None:
result_table = project(schema_int, [file_int], GreaterThan("id", 4))
assert len(result_table.columns[0]) == 0