
Implement column projection #1443

Open · wants to merge 8 commits into base: main
29 changes: 25 additions & 4 deletions pyiceberg/io/pyarrow.py
@@ -138,7 +138,7 @@
)
from pyiceberg.table.metadata import TableMetadata
from pyiceberg.table.name_mapping import NameMapping, apply_name_mapping
from pyiceberg.transforms import TruncateTransform
from pyiceberg.transforms import IdentityTransform, TruncateTransform
from pyiceberg.typedef import EMPTY_DICT, Properties, Record
from pyiceberg.types import (
BinaryType,
@@ -1226,6 +1226,7 @@ def _task_to_record_batches(
case_sensitive: bool,
name_mapping: Optional[NameMapping] = None,
use_large_types: bool = True,
partition_spec: Optional[PartitionSpec] = None,
) -> Iterator[pa.RecordBatch]:
_, _, path = _parse_location(task.file.file_path)
arrow_format = ds.ParquetFileFormat(pre_buffer=True, buffer_size=(ONE_MEGABYTE * 8))
@@ -1237,16 +1238,29 @@
# When V3 support is introduced, we will update `downcast_ns_timestamp_to_us` flag based on
# the table format version.
file_schema = pyarrow_to_schema(physical_schema, name_mapping, downcast_ns_timestamp_to_us=True)

if file_schema is None:
raise ValueError(f"Missing Iceberg schema in Metadata for file: {path}")

pyarrow_filter = None
if bound_row_filter is not AlwaysTrue():
translated_row_filter = translate_column_names(bound_row_filter, file_schema, case_sensitive=case_sensitive)
bound_file_filter = bind(file_schema, translated_row_filter, case_sensitive=case_sensitive)
pyarrow_filter = expression_to_pyarrow(bound_file_filter)

# Apply column projection rules for missing partitions and default values
# https://iceberg.apache.org/spec/#column-projection
file_project_schema = prune_columns(file_schema, projected_field_ids, select_full_types=False)
projected_missing_fields = {}

if file_schema is None:
raise ValueError(f"Missing Iceberg schema in Metadata for file: {path}")
for field_id in projected_field_ids.difference(file_project_schema.field_ids):
if nested_field := projected_schema.find_field(field_id):
if nested_field.initial_default is not None:
projected_missing_fields[nested_field.name] = nested_field.initial_default
if partition_spec is not None:
for partition_field in partition_spec.fields_by_source_id(field_id):
if isinstance(partition_field.transform, IdentityTransform) and task.file.partition is not None:
projected_missing_fields[partition_field.name] = task.file.partition[0]

fragment_scanner = ds.Scanner.from_fragment(
fragment=fragment,
@@ -1286,14 +1300,20 @@ def _task_to_record_batches(
continue
output_batches = arrow_table.to_batches()
for output_batch in output_batches:
yield _to_requested_schema(
result_batch = _to_requested_schema(
projected_schema,
file_project_schema,
output_batch,
downcast_ns_timestamp_to_us=True,
use_large_types=use_large_types,
)

# Inject projected column values if available
for name, value in projected_missing_fields.items():
result_batch = result_batch.set_column(result_batch.schema.get_field_index(name), name, [value])

yield result_batch


def _task_to_table(
fs: FileSystem,
@@ -1517,6 +1537,7 @@ def _record_batches_from_scan_tasks_and_deletes(
self._case_sensitive,
self._table_metadata.name_mapping(),
self._use_large_types,
self._table_metadata.spec(),
)
for batch in batches:
if self._limit is not None:
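For context on the change above: the injection step amounts to overwriting a column of each Arrow record batch with a constant taken from the file's metadata. Below is a minimal standalone sketch of that mechanism, not part of the diff; the helper name inject_constant_column is made up for illustration, and the value is repeated so the replacement column matches the batch length.

import pyarrow as pa


def inject_constant_column(batch: pa.RecordBatch, name: str, value: object) -> pa.RecordBatch:
    """Replace the column `name` with `value` repeated once per row of the batch."""
    index = batch.schema.get_field_index(name)
    column = pa.array([value] * batch.num_rows, type=batch.schema.field(index).type)
    return batch.set_column(index, name, column)


# A two-row batch whose partition column is absent from the data file and gets filled from metadata.
batch = pa.RecordBatch.from_pydict({"partition_field": pa.nulls(2, pa.int64()), "other_field": ["x", "y"]})
print(inject_constant_column(batch, "partition_field", 123456).to_pydict())
# {'partition_field': [123456, 123456], 'other_field': ['x', 'y']}
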
105 changes: 105 additions & 0 deletions tests/io/test_pyarrow.py
@@ -77,6 +77,7 @@
from pyiceberg.schema import Schema, make_compatible_name, visit
from pyiceberg.table import FileScanTask, TableProperties
from pyiceberg.table.metadata import TableMetadataV2
from pyiceberg.table.name_mapping import create_mapping_from_schema
from pyiceberg.transforms import IdentityTransform
from pyiceberg.typedef import UTF8, Properties, Record
from pyiceberg.types import (
@@ -1122,6 +1123,110 @@ def test_projection_concat_files(schema_int: Schema, file_int: str) -> None:
assert repr(result_table.schema) == "id: int32"


def test_projection_partition_inference(tmp_path: str) -> None:
schema = Schema(
NestedField(1, "partition_field", IntegerType(), required=False),
NestedField(2, "other_field", StringType(), required=False),
)

partition_spec = PartitionSpec(PartitionField(1, 1000, IdentityTransform(), "partition_field"))

table = TableMetadataV2(
location="file://a/b/c.json",
last_column_id=2,
format_version=2,
current_schema_id=0,
schemas=[schema],
partition_specs=[partition_spec],
properties={TableProperties.DEFAULT_NAME_MAPPING: create_mapping_from_schema(schema).model_dump_json()},
)

pa_schema = pa.schema([pa.field("other_field", pa.string())])
pa_table = pa.table({"other_field": ["x"]}, schema=pa_schema)
pq.write_table(pa_table, f"{tmp_path}/datafile.parquet")

data_file = DataFile(
content=DataFileContent.DATA,
file_path=f"{tmp_path}/datafile.parquet",
file_format=FileFormat.PARQUET,
partition=Record(partition_id=123456),
file_size_in_bytes=os.path.getsize(f"{tmp_path}/datafile.parquet"),
sort_order_id=None,
spec_id=0,
equality_ids=None,
key_metadata=None,
)

table_result_scan = ArrowScan(
table_metadata=table,
io=load_file_io(),
projected_schema=schema,
row_filter=AlwaysTrue(),
).to_table(tasks=[FileScanTask(data_file=data_file)])

assert (
str(table_result_scan)
== """pyarrow.Table
partition_field: int64
other_field: large_string
----
partition_field: [[123456]]
other_field: [["x"]]"""
)


def test_projection_initial_default_inference(tmp_path: str) -> None:
schema = Schema(
NestedField(1, "other_field", StringType(), required=False),
NestedField(2, "other_field_1", StringType(), required=False, initial_default="foo"),
)

table = TableMetadataV2(
location="file://a/b/c.json",
last_column_id=2,
format_version=2,
current_schema_id=0,
schemas=[schema],
partition_specs=[PartitionSpec()],
properties={TableProperties.DEFAULT_NAME_MAPPING: create_mapping_from_schema(schema).model_dump_json()},
)

pa_schema = pa.schema([pa.field("other_field", pa.string())])
pa_table = pa.table({"other_field": ["x"]}, schema=pa_schema)
pq.write_table(pa_table, f"{tmp_path}/datafile.parquet")

data_file = DataFile(
content=DataFileContent.DATA,
file_path=f"{tmp_path}/datafile.parquet",
file_format=FileFormat.PARQUET,
partition=Record(),
file_size_in_bytes=os.path.getsize(f"{tmp_path}/datafile.parquet"),
sort_order_id=None,
spec_id=0,
equality_ids=None,
key_metadata=None,
)

table_result_scan = ArrowScan(
table_metadata=table,
io=load_file_io(),
projected_schema=schema,
row_filter=AlwaysTrue(),
).to_table(tasks=[FileScanTask(data_file=data_file)])

print(str(table_result_scan))

assert (
str(table_result_scan)
== """pyarrow.Table
other_field: large_string
other_field_1: string
----
other_field: [["x"]]
other_field_1: [["foo"]]"""
)


def test_projection_filter(schema_int: Schema, file_int: str) -> None:
result_table = project(schema_int, [file_int], GreaterThan("id", 4))
assert len(result_table.columns[0]) == 0
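As a companion to the tests above, here is a minimal sketch of the lookup rule the partition-inference path relies on, built from the same pyiceberg classes the diff already imports. It mirrors test_projection_partition_inference and is an illustration only, not part of the diff.

from pyiceberg.partitioning import PartitionField, PartitionSpec
from pyiceberg.transforms import IdentityTransform
from pyiceberg.typedef import Record

# A spec with one identity-transformed partition field sourced from field id 1,
# and the partition tuple recorded in the data file's metadata.
spec = PartitionSpec(PartitionField(source_id=1, field_id=1000, transform=IdentityTransform(), name="partition_field"))
partition = Record(partition_id=123456)

# Per https://iceberg.apache.org/spec/#column-projection, a source field that is
# missing from the data file can be read from the partition tuple when it feeds an
# identity-transformed partition field.
missing_source_id = 1
for partition_field in spec.fields_by_source_id(missing_source_id):
    if isinstance(partition_field.transform, IdentityTransform):
        print(partition_field.name, partition[0])  # partition_field 123456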