PyArrow: Avoid buffer overflow by avoiding a sort (#1555)
Second attempt of #1539

This was already being discussed back here:
#208 (comment)

This PR changes the approach from doing a sort followed by a single pass over
the table to determining the unique partition tuples and filtering on them
individually.

Fixes #1491

The sort caused buffers to be joined in a way that overflows in Arrow. I think
this is an issue on the Arrow side: it should automatically break the data up
into smaller buffers. The `combine_chunks` method does this correctly.
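
The core of the new approach can be illustrated with a small standalone PyArrow sketch (the table and column names below are made up for illustration; the actual implementation lives in `_determine_partitions` in the diff below):

```
import pyarrow as pa
import pyarrow.compute as pc

tbl = pa.table({
    "year": [2020, 2022, 2022, 2021, 2022],
    "n_legs": [2, 2, 4, 4, 100],
    "animal": ["Flamingo", "Parrot", "Dog", "Horse", "Centipede"],
})
partition_cols = ["year", "n_legs"]

# group_by(...).aggregate([]) yields one row per unique partition tuple
unique_keys = tbl.select(partition_cols).group_by(partition_cols).aggregate([])

partitions = []
for key in unique_keys.to_pylist():
    expr = None
    for col in partition_cols:
        pred = pc.field(col) == key[col] if key[col] is not None else pc.field(col).is_null()
        expr = pred if expr is None else expr & pred
    # Table.filter with a compute Expression needs a reasonably recent PyArrow.
    # Per the commit message, combine_chunks() gives each partition fresh buffers
    # that don't interfere with the original table when written out.
    partitions.append((key, tbl.filter(expr).combine_chunks()))
```

Because no sort is performed, Arrow never has to concatenate buffers past their size limit; each partition is produced by an independent filter.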

Now:

```
Run 0 took: 0.42877754200890195
Run 1 took: 0.2507691659993725
Run 2 took: 0.24833179199777078
Run 3 took: 0.24401691700040828
Run 4 took: 0.2419595829996979
Average runtime of 0.28 seconds
```

Before:

```
Run 0 took: 1.0768639159941813
Run 1 took: 0.8784021250030492
Run 2 took: 0.8486490420036716
Run 3 took: 0.8614017910003895
Run 4 took: 0.8497851670108503
Average runtime of 0.9 seconds
```

So it comes with a nice speedup as well :)

---------

Co-authored-by: Kevin Liu <[email protected]>
Fokko and kevinjqliu authored Jan 23, 2025
1 parent 872a445 commit 36d383d
Showing 7 changed files with 805 additions and 743 deletions.
129 changes: 50 additions & 79 deletions pyiceberg/io/pyarrow.py
@@ -27,8 +27,10 @@

import concurrent.futures
import fnmatch
import functools
import itertools
import logging
import operator
import os
import re
import uuid
@@ -2174,7 +2176,10 @@ def _partition_value(self, partition_field: PartitionField, schema: Schema) -> A
raise ValueError(
f"Cannot infer partition value from parquet metadata as there are more than one partition values for Partition Field: {partition_field.name}. {lower_value=}, {upper_value=}"
)
return lower_value

source_field = schema.find_field(partition_field.source_id)
transform = partition_field.transform.transform(source_field.field_type)
return transform(lower_value)

def partition(self, partition_spec: PartitionSpec, schema: Schema) -> Record:
return Record(**{field.name: self._partition_value(field, schema) for field in partition_spec.fields})
@@ -2558,38 +2563,8 @@ class _TablePartition:
arrow_table_partition: pa.Table


def _get_table_partitions(
arrow_table: pa.Table,
partition_spec: PartitionSpec,
schema: Schema,
slice_instructions: list[dict[str, Any]],
) -> list[_TablePartition]:
sorted_slice_instructions = sorted(slice_instructions, key=lambda x: x["offset"])

partition_fields = partition_spec.fields

offsets = [inst["offset"] for inst in sorted_slice_instructions]
projected_and_filtered = {
partition_field.source_id: arrow_table[schema.find_field(name_or_id=partition_field.source_id).name]
.take(offsets)
.to_pylist()
for partition_field in partition_fields
}

table_partitions = []
for idx, inst in enumerate(sorted_slice_instructions):
partition_slice = arrow_table.slice(**inst)
fieldvalues = [
PartitionFieldValue(partition_field, projected_and_filtered[partition_field.source_id][idx])
for partition_field in partition_fields
]
partition_key = PartitionKey(raw_partition_field_values=fieldvalues, partition_spec=partition_spec, schema=schema)
table_partitions.append(_TablePartition(partition_key=partition_key, arrow_table_partition=partition_slice))
return table_partitions


def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.Table) -> List[_TablePartition]:
"""Based on the iceberg table partition spec, slice the arrow table into partitions with their keys.
"""Based on the iceberg table partition spec, filter the arrow table into partitions with their keys.
Example:
Input:
@@ -2598,54 +2573,50 @@ def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.T
'n_legs': [2, 2, 2, 4, 4, 4, 4, 5, 100],
'animal': ["Flamingo", "Parrot", "Parrot", "Dog", "Horse", "Horse", "Horse","Brittle stars", "Centipede"]}.
The algorithm:
Firstly we group the rows into partitions by sorting with sort order [('n_legs', 'descending'), ('year', 'descending')]
and null_placement of "at_end".
This gives the same table as raw input.
Then we sort_indices using reverse order of [('n_legs', 'descending'), ('year', 'descending')]
and null_placement : "at_start".
This gives:
[8, 7, 4, 5, 6, 3, 1, 2, 0]
Based on this we get partition groups of indices:
[{'offset': 8, 'length': 1}, {'offset': 7, 'length': 1}, {'offset': 4, 'length': 3}, {'offset': 3, 'length': 1}, {'offset': 1, 'length': 2}, {'offset': 0, 'length': 1}]
We then retrieve the partition keys by offsets.
And slice the arrow table by offsets and lengths of each partition.
- We determine the set of unique partition keys
- Then we produce a set of partitions by filtering on each of the combinations
- We combine the chunks to create a copy to avoid GIL congestion on the original table
"""
partition_columns: List[Tuple[PartitionField, NestedField]] = [
(partition_field, schema.find_field(partition_field.source_id)) for partition_field in spec.fields
]
partition_values_table = pa.table(
{
str(partition.field_id): partition.transform.pyarrow_transform(field.field_type)(arrow_table[field.name])
for partition, field in partition_columns
}
)
# Assign unique names to columns where the partition transform has been applied
# to avoid conflicts
partition_fields = [f"_partition_{field.name}" for field in spec.fields]

for partition, name in zip(spec.fields, partition_fields):
source_field = schema.find_field(partition.source_id)
arrow_table = arrow_table.append_column(
name, partition.transform.pyarrow_transform(source_field.field_type)(arrow_table[source_field.name])
)

unique_partition_fields = arrow_table.select(partition_fields).group_by(partition_fields).aggregate([])

table_partitions = []
# TODO: As a next step, we could also play around with yielding instead of materializing the full list
for unique_partition in unique_partition_fields.to_pylist():
partition_key = PartitionKey(
field_values=[
PartitionFieldValue(field=field, value=unique_partition[name])
for field, name in zip(spec.fields, partition_fields)
],
partition_spec=spec,
schema=schema,
)
filtered_table = arrow_table.filter(
functools.reduce(
operator.and_,
[
pc.field(partition_field_name) == unique_partition[partition_field_name]
if unique_partition[partition_field_name] is not None
else pc.field(partition_field_name).is_null()
for field, partition_field_name in zip(spec.fields, partition_fields)
],
)
)
filtered_table = filtered_table.drop_columns(partition_fields)

# Sort by partitions
sort_indices = pa.compute.sort_indices(
partition_values_table,
sort_keys=[(col, "ascending") for col in partition_values_table.column_names],
null_placement="at_end",
).to_pylist()
arrow_table = arrow_table.take(sort_indices)

# Get slice_instructions to group by partitions
partition_values_table = partition_values_table.take(sort_indices)
reversed_indices = pa.compute.sort_indices(
partition_values_table,
sort_keys=[(col, "descending") for col in partition_values_table.column_names],
null_placement="at_start",
).to_pylist()
slice_instructions: List[Dict[str, Any]] = []
last = len(reversed_indices)
reversed_indices_size = len(reversed_indices)
ptr = 0
while ptr < reversed_indices_size:
group_size = last - reversed_indices[ptr]
offset = reversed_indices[ptr]
slice_instructions.append({"offset": offset, "length": group_size})
last = reversed_indices[ptr]
ptr = ptr + group_size

table_partitions: List[_TablePartition] = _get_table_partitions(arrow_table, spec, schema, slice_instructions)
# The combine_chunks seems to be counter-intuitive to do, but it actually returns
# fresh buffers that don't interfere with each other when it is written out to file
table_partitions.append(
_TablePartition(partition_key=partition_key, arrow_table_partition=filtered_table.combine_chunks())
)

return table_partitions
39 changes: 30 additions & 9 deletions pyiceberg/partitioning.py
@@ -29,6 +29,7 @@
Optional,
Tuple,
TypeVar,
Union,
)
from urllib.parse import quote_plus

@@ -393,14 +394,14 @@ class PartitionFieldValue:

@dataclass(frozen=True)
class PartitionKey:
raw_partition_field_values: List[PartitionFieldValue]
field_values: List[PartitionFieldValue]
partition_spec: PartitionSpec
schema: Schema

@cached_property
def partition(self) -> Record: # partition key transformed with iceberg internal representation as input
iceberg_typed_key_values = {}
for raw_partition_field_value in self.raw_partition_field_values:
for raw_partition_field_value in self.field_values:
partition_fields = self.partition_spec.source_id_to_fields_map[raw_partition_field_value.field.source_id]
if len(partition_fields) != 1:
raise ValueError(f"Cannot have redundant partitions: {partition_fields}")
@@ -427,25 +428,45 @@ def partition_record_value(partition_field: PartitionField, value: Any, schema:
the final partition record value.
"""
iceberg_type = schema.find_field(name_or_id=partition_field.source_id).field_type
iceberg_typed_value = _to_partition_representation(iceberg_type, value)
transformed_value = partition_field.transform.transform(iceberg_type)(iceberg_typed_value)
return transformed_value
return _to_partition_representation(iceberg_type, value)


@singledispatch
def _to_partition_representation(type: IcebergType, value: Any) -> Any:
"""Strip the logical type into the physical type.
It can be that the value is already transformed into its physical type,
in this case it will return the original value. Keep in mind that the
bucket transform always will return an int, but an identity transform
can return date that still needs to be transformed into an int (days
since epoch).
"""
raise TypeError(f"Unsupported partition field type: {type}")


@_to_partition_representation.register(TimestampType)
@_to_partition_representation.register(TimestamptzType)
def _(type: IcebergType, value: Optional[datetime]) -> Optional[int]:
return datetime_to_micros(value) if value is not None else None
def _(type: IcebergType, value: Optional[Union[int, datetime]]) -> Optional[int]:
if value is None:
return None
elif isinstance(value, int):
return value
elif isinstance(value, datetime):
return datetime_to_micros(value)
else:
raise ValueError(f"Unknown type: {value}")


@_to_partition_representation.register(DateType)
def _(type: IcebergType, value: Optional[date]) -> Optional[int]:
return date_to_days(value) if value is not None else None
def _(type: IcebergType, value: Optional[Union[int, date]]) -> Optional[int]:
if value is None:
return None
elif isinstance(value, int):
return value
elif isinstance(value, date):
return date_to_days(value)
else:
raise ValueError(f"Unknown type: {value}")


@_to_partition_representation.register(TimeType)
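
As a small illustration of the docstring above (plain Python rather than the PyIceberg helpers): an identity-transformed `DateType` value still has to be converted to days since epoch, while a bucket transform has already produced an int that passes through unchanged.

```
from datetime import date

EPOCH = date(1970, 1, 1)

# Identity transform on a DateType column: the logical value (a date)
# must still become its physical representation, days since epoch.
assert (date(1970, 1, 11) - EPOCH).days == 10

# A bucket transform has already produced an int (e.g. bucket(4) -> 0..3),
# so _to_partition_representation can return it as-is.
```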
6 changes: 4 additions & 2 deletions pyiceberg/table/__init__.py
@@ -453,8 +453,10 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT)
with self._append_snapshot_producer(snapshot_properties) as append_files:
# skip writing data files if the dataframe is empty
if df.shape[0] > 0:
data_files = _dataframe_to_data_files(
table_metadata=self.table_metadata, write_uuid=append_files.commit_uuid, df=df, io=self._table.io
data_files = list(
_dataframe_to_data_files(
table_metadata=self.table_metadata, write_uuid=append_files.commit_uuid, df=df, io=self._table.io
)
)
for data_file in data_files:
append_files.append_data_file(data_file)
1 change: 1 addition & 0 deletions pyproject.toml
@@ -1220,6 +1220,7 @@ markers = [
"adls: marks a test as requiring access to adls compliant storage (use with --adls.account-name, --adls.account-key, and --adls.endpoint args)",
"integration: marks integration tests against Apache Spark",
"gcs: marks a test as requiring access to gcs compliant storage (use with --gs.token, --gs.project, and --gs.endpoint)",
"benchmark: collection of tests to validate read/write performance before and after a change"
]

# Turns a warning into an error
72 changes: 72 additions & 0 deletions tests/benchmark/test_benchmark.py
@@ -0,0 +1,72 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import statistics
import timeit
import urllib.request

import pyarrow as pa
import pyarrow.parquet as pq
import pytest

from pyiceberg.transforms import DayTransform


@pytest.fixture(scope="session")
def taxi_dataset(tmp_path_factory: pytest.TempPathFactory) -> pa.Table:
"""Reads the Taxi dataset to disk"""
taxi_dataset = "https://d37ci6vzurychx.cloudfront.net/trip-data/yellow_tripdata_2022-01.parquet"
taxi_dataset_dest = tmp_path_factory.mktemp("taxi_dataset") / "yellow_tripdata_2022-01.parquet"
urllib.request.urlretrieve(taxi_dataset, taxi_dataset_dest)

return pq.read_table(taxi_dataset_dest)


@pytest.mark.benchmark
def test_partitioned_write(tmp_path_factory: pytest.TempPathFactory, taxi_dataset: pa.Table) -> None:
"""Tests writing to a partitioned table with something that would be close a production-like situation"""
from pyiceberg.catalog.sql import SqlCatalog

warehouse_path = str(tmp_path_factory.mktemp("warehouse"))
catalog = SqlCatalog(
"default",
uri=f"sqlite:///{warehouse_path}/pyiceberg_catalog.db",
warehouse=f"file://{warehouse_path}",
)

catalog.create_namespace("default")

tbl = catalog.create_table("default.taxi_partitioned", schema=taxi_dataset.schema)

with tbl.update_spec() as spec:
spec.add_field("tpep_pickup_datetime", DayTransform())

# Profiling can sometimes be handy as well
# with cProfile.Profile() as pr:
# tbl.append(taxi_dataset)
#
# pr.print_stats(sort=True)

runs = []
for run in range(5):
start_time = timeit.default_timer()
tbl.append(taxi_dataset)
elapsed = timeit.default_timer() - start_time

print(f"Run {run} took: {elapsed}")
runs.append(elapsed)

print(f"Average runtime of {round(statistics.mean(runs), 2)} seconds")
