Skip to content

Commit

Permalink
PyArrow: Avoid buffer-overflow by avoid doing a sort
Browse files Browse the repository at this point in the history
This was already being discussed back here:

#208 (comment)

This PR changes from doing a sort, and then a single pass over the
table to the the approach where we determine the unique partition tuples
then filter on them one by one.

Fixes #1491

Because the sort caused buffers to be joined where it would overflow
in Arrow. I think this is an issue on the Arrow side, and it should
automatically break up into smaller buffers. The `combine_chunks`
method does this correctly.

Now:

```
0.42877754200890195
Run 1 took: 0.2507691659993725
Run 2 took: 0.24833179199777078
Run 3 took: 0.24401691700040828
Run 4 took: 0.2419595829996979
Average runtime of 0.28 seconds
```

Before:

```
Run 0 took: 1.0768639159941813
Run 1 took: 0.8784021250030492
Run 2 took: 0.8486490420036716
Run 3 took: 0.8614017910003895
Run 4 took: 0.8497851670108503
Average runtime of 0.9 seconds
```

So it comes with a nice speedup as well :)
  • Loading branch information
Fokko committed Jan 20, 2025
1 parent 818cd15 commit 0043889
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 68 deletions.
108 changes: 42 additions & 66 deletions pyiceberg/io/pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@

import concurrent.futures
import fnmatch
import functools
import itertools
import logging
import operator
import os
import re
import uuid
Expand Down Expand Up @@ -2542,36 +2544,6 @@ class _TablePartition:
arrow_table_partition: pa.Table


def _get_table_partitions(
arrow_table: pa.Table,
partition_spec: PartitionSpec,
schema: Schema,
slice_instructions: list[dict[str, Any]],
) -> list[_TablePartition]:
sorted_slice_instructions = sorted(slice_instructions, key=lambda x: x["offset"])

partition_fields = partition_spec.fields

offsets = [inst["offset"] for inst in sorted_slice_instructions]
projected_and_filtered = {
partition_field.source_id: arrow_table[schema.find_field(name_or_id=partition_field.source_id).name]
.take(offsets)
.to_pylist()
for partition_field in partition_fields
}

table_partitions = []
for idx, inst in enumerate(sorted_slice_instructions):
partition_slice = arrow_table.slice(**inst)
fieldvalues = [
PartitionFieldValue(partition_field, projected_and_filtered[partition_field.source_id][idx])
for partition_field in partition_fields
]
partition_key = PartitionKey(raw_partition_field_values=fieldvalues, partition_spec=partition_spec, schema=schema)
table_partitions.append(_TablePartition(partition_key=partition_key, arrow_table_partition=partition_slice))
return table_partitions


def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.Table) -> List[_TablePartition]:
"""Based on the iceberg table partition spec, slice the arrow table into partitions with their keys.
Expand All @@ -2594,42 +2566,46 @@ def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.T
We then retrieve the partition keys by offsets.
And slice the arrow table by offsets and lengths of each partition.
"""
partition_columns: List[Tuple[PartitionField, NestedField]] = [
(partition_field, schema.find_field(partition_field.source_id)) for partition_field in spec.fields
]
partition_values_table = pa.table(
{
str(partition.field_id): partition.transform.pyarrow_transform(field.field_type)(arrow_table[field.name])
for partition, field in partition_columns
}
)
# Assign unique names to columns where the partition transform has been applied
# to avoid conflicts
partition_fields = [f"_partition_{field.name}" for field in spec.fields]

for partition, name in zip(spec.fields, partition_fields):
source_field = schema.find_field(partition.source_id)
arrow_table = arrow_table.append_column(
name, partition.transform.pyarrow_transform(source_field.field_type)(arrow_table[source_field.name])
)

unique_partition_fields = arrow_table.select(partition_fields).group_by(partition_fields).aggregate([])

table_partitions = []
# TODO: As a next step, we could also play around with yielding instead of materializing the full list
for unique_partition in unique_partition_fields.to_pylist():
partition_key = PartitionKey(
raw_partition_field_values=[
PartitionFieldValue(field=field, value=unique_partition[name])
for field, name in zip(spec.fields, partition_fields)
],
partition_spec=spec,
schema=schema,
)
filtered_table = arrow_table.filter(
functools.reduce(
operator.and_,
[
pc.field(partition_field_name) == unique_partition[partition_field_name]
if unique_partition[partition_field_name] is not None
else pc.field(partition_field_name).is_null()
for field, partition_field_name in zip(spec.fields, partition_fields)
],
)
)
filtered_table = filtered_table.drop_columns(partition_fields)

# Sort by partitions
sort_indices = pa.compute.sort_indices(
partition_values_table,
sort_keys=[(col, "ascending") for col in partition_values_table.column_names],
null_placement="at_end",
).to_pylist()
arrow_table = arrow_table.take(sort_indices)

# Get slice_instructions to group by partitions
partition_values_table = partition_values_table.take(sort_indices)
reversed_indices = pa.compute.sort_indices(
partition_values_table,
sort_keys=[(col, "descending") for col in partition_values_table.column_names],
null_placement="at_start",
).to_pylist()
slice_instructions: List[Dict[str, Any]] = []
last = len(reversed_indices)
reversed_indices_size = len(reversed_indices)
ptr = 0
while ptr < reversed_indices_size:
group_size = last - reversed_indices[ptr]
offset = reversed_indices[ptr]
slice_instructions.append({"offset": offset, "length": group_size})
last = reversed_indices[ptr]
ptr = ptr + group_size

table_partitions: List[_TablePartition] = _get_table_partitions(arrow_table, spec, schema, slice_instructions)
# The combine_chunks seems to be counter-intuitive to do, but it actually returns
# fresh buffers that don't interfere with each other when it is written out to file
table_partitions.append(
_TablePartition(partition_key=partition_key, arrow_table_partition=filtered_table.combine_chunks())
)

return table_partitions
10 changes: 8 additions & 2 deletions pyiceberg/partitioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
Optional,
Tuple,
TypeVar,
Union,
)
from urllib.parse import quote_plus

Expand Down Expand Up @@ -425,8 +426,13 @@ def _to_partition_representation(type: IcebergType, value: Any) -> Any:

@_to_partition_representation.register(TimestampType)
@_to_partition_representation.register(TimestamptzType)
def _(type: IcebergType, value: Optional[datetime]) -> Optional[int]:
return datetime_to_micros(value) if value is not None else None
def _(type: IcebergType, value: Optional[Union[datetime, int]]) -> Optional[int]:
if value is None:
return None
elif isinstance(value, int):
return value
else:
return datetime_to_micros(value)


@_to_partition_representation.register(DateType)
Expand Down
71 changes: 71 additions & 0 deletions tests/benchmark/test_benchmark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import statistics
import timeit
import urllib

import pyarrow as pa
import pyarrow.parquet as pq
import pytest

from pyiceberg.transforms import DayTransform


@pytest.fixture(scope="session")
def taxi_dataset(tmp_path_factory: pytest.TempPathFactory) -> pa.Table:
"""Reads the Taxi dataset to disk"""
taxi_dataset = "https://d37ci6vzurychx.cloudfront.net/trip-data/yellow_tripdata_2022-01.parquet"
taxi_dataset_dest = tmp_path_factory.mktemp("taxi_dataset") / "yellow_tripdata_2022-01.parquet"
urllib.request.urlretrieve(taxi_dataset, taxi_dataset_dest)

return pq.read_table(taxi_dataset_dest)


def test_partitioned_write(tmp_path_factory: pytest.TempPathFactory, taxi_dataset: pa.Table) -> None:
"""Tests writing to a partitioned table with something that would be close a production-like situation"""
from pyiceberg.catalog.sql import SqlCatalog

warehouse_path = str(tmp_path_factory.mktemp("warehouse"))
catalog = SqlCatalog(
"default",
uri=f"sqlite:///{warehouse_path}/pyiceberg_catalog.db",
warehouse=f"file://{warehouse_path}",
)

catalog.create_namespace("default")

tbl = catalog.create_table("default.taxi_partitioned", schema=taxi_dataset.schema)

with tbl.update_spec() as spec:
spec.add_field("tpep_pickup_datetime", DayTransform())

# Profiling can sometimes be handy as well
# with cProfile.Profile() as pr:
# tbl.append(taxi_dataset)
#
# pr.print_stats(sort=True)

runs = []
for run in range(5):
start_time = timeit.default_timer()
tbl.append(taxi_dataset)
elapsed = timeit.default_timer() - start_time

print(f"Run {run} took: {elapsed}")
runs.append(elapsed)

print(f"Average runtime of {round(statistics.mean(runs), 2)} seconds")
22 changes: 22 additions & 0 deletions tests/integration/test_writes/test_partitioned_writes.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# pylint:disable=redefined-outer-name


import random
from datetime import date
from typing import Any, Set

Expand Down Expand Up @@ -1126,3 +1127,24 @@ def test_append_multiple_partitions(
"""
)
assert files_df.count() == 6


def test_pyarrow_overflow(session_catalog: Catalog) -> None:
"""Test what happens when the offset is beyond 32 bits"""
identifier = "default.arrow_table_overflow"
try:
session_catalog.drop_table(identifier=identifier)
except NoSuchTableError:
pass

x = pa.array([random.randint(0, 999) for _ in range(30_000)])
ta = pa.chunked_array([x] * 10_000)
y = ["fixed_string"] * 30_000
tb = pa.chunked_array([y] * 10_000)
# Create pa.table
arrow_table = pa.table({"a": ta, "b": tb})

table = session_catalog.create_table(identifier, arrow_table.schema)
with table.update_spec() as update_spec:
update_spec.add_field("b", IdentityTransform(), "pb")
table.append(arrow_table)

0 comments on commit 0043889

Please sign in to comment.