diff --git a/pyiceberg/transforms.py b/pyiceberg/transforms.py index 1056fa525b..e89142d780 100644 --- a/pyiceberg/transforms.py +++ b/pyiceberg/transforms.py @@ -309,7 +309,27 @@ def __repr__(self) -> str: return f"BucketTransform(num_buckets={self._num_buckets})" def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": - raise NotImplementedError() + import pyarrow as pa + from pyiceberg_core import transform as pyiceberg_core_transform + + ArrayLike = TypeVar("ArrayLike", pa.Array, pa.ChunkedArray) + + def bucket(array: ArrayLike) -> ArrayLike: + if isinstance(array, pa.Array): + return pyiceberg_core_transform.bucket(array, self._num_buckets) + elif isinstance(array, pa.ChunkedArray): + result_chunks = [] + for arr in array.iterchunks(): + result_chunks.append(pyiceberg_core_transform.bucket(arr, self._num_buckets)) + return pa.chunked_array(result_chunks) + else: + raise ValueError(f"PyArrow array can only be of type pa.Array or pa.ChunkedArray, but found {type(array)}") + + return bucket + + @property + def supports_pyarrow_transform(self) -> bool: + return True class TimeResolution(IntEnum): diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 3a9ffd6009..7de44de2c8 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -17,10 +17,11 @@ # pylint: disable=eval-used,protected-access,redefined-outer-name from datetime import date from decimal import Decimal -from typing import TYPE_CHECKING, Any, Callable, Optional +from typing import Any, Callable, Optional, Union from uuid import UUID import mmh3 as mmh3 +import pyarrow as pa import pytest from pydantic import ( BeforeValidator, @@ -112,9 +113,6 @@ timestamptz_to_micros, ) -if TYPE_CHECKING: - import pyarrow as pa - @pytest.mark.parametrize( "test_input,test_type,expected", @@ -1840,3 +1838,26 @@ def test_ymd_pyarrow_transforms( else: with pytest.raises(ValueError): transform.pyarrow_transform(DateType())(arrow_table_date_timestamps[source_col]) + + +@pytest.mark.parametrize( + "source_type, input_arr, expected, num_buckets", + [ + (IntegerType(), pa.array([1, 2]), pa.array([6, 2], type=pa.int32()), 10), + ( + IntegerType(), + pa.chunked_array([pa.array([1, 2]), pa.array([3, 4])]), + pa.chunked_array([pa.array([6, 2], type=pa.int32()), pa.array([5, 0], type=pa.int32())]), + 10, + ), + (IntegerType(), pa.array([1, 2]), pa.array([6, 2], type=pa.int32()), 10), + ], +) +def test_bucket_pyarrow_transforms( + source_type: PrimitiveType, + input_arr: Union[pa.Array, pa.ChunkedArray], + expected: Union[pa.Array, pa.ChunkedArray], + num_buckets: int, +) -> None: + transform: Transform[Any, Any] = BucketTransform(num_buckets=num_buckets) + assert expected == transform.pyarrow_transform(source_type)(input_arr)