Skip to content

Commit 666a926

Browse files
authored
Refactor bucket transform types (#1562)
This aligns more closely with the spec and is also friendlier to the end user: ![image](https://github.com/user-attachments/assets/1ae955a6-635f-4988-b964-fee471ebdad9)
1 parent 2cd4e78 commit 666a926

File tree

2 files changed

+66
-1
lines changed

2 files changed

+66
-1
lines changed

pyiceberg/transforms.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
# under the License.
1717

1818
import base64
19+
import datetime as py_datetime
1920
import struct
2021
from abc import ABC, abstractmethod
2122
from enum import IntEnum
@@ -298,7 +299,31 @@ def can_transform(self, source: IcebergType) -> bool:
298299
)
299300

300301
def transform(self, source: IcebergType, bucket: bool = True) -> Callable[[Optional[Any]], Optional[int]]:
301-
if isinstance(source, (IntegerType, LongType, DateType, TimeType, TimestampType, TimestamptzType)):
302+
if isinstance(source, TimeType):
303+
304+
def hash_func(v: Any) -> int:
305+
if isinstance(v, py_datetime.time):
306+
v = datetime.time_to_micros(v)
307+
308+
return mmh3.hash(struct.pack("<q", v))
309+
310+
elif isinstance(source, DateType):
311+
312+
def hash_func(v: Any) -> int:
313+
if isinstance(v, py_datetime.date):
314+
v = datetime.date_to_days(v)
315+
316+
return mmh3.hash(struct.pack("<q", v))
317+
318+
elif isinstance(source, (TimestampType, TimestamptzType)):
319+
320+
def hash_func(v: Any) -> int:
321+
if isinstance(v, py_datetime.datetime):
322+
v = datetime.datetime_to_micros(v)
323+
324+
return mmh3.hash(struct.pack("<q", v))
325+
326+
elif isinstance(source, (IntegerType, LongType)):
302327

303328
def hash_func(v: Any) -> int:
304329
return mmh3.hash(struct.pack("<q", v))

tests/table/test_partitioning.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,32 @@
1414
# KIND, either express or implied. See the License for the
1515
# specific language governing permissions and limitations
1616
# under the License.
17+
import datetime
18+
from decimal import Decimal
19+
from typing import Any
20+
from uuid import UUID
21+
22+
import pytest
23+
1724
from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionField, PartitionSpec
1825
from pyiceberg.schema import Schema
1926
from pyiceberg.transforms import BucketTransform, IdentityTransform, TruncateTransform
2027
from pyiceberg.typedef import Record
2128
from pyiceberg.types import (
29+
BinaryType,
30+
DateType,
31+
DecimalType,
32+
FixedType,
2233
IntegerType,
34+
LongType,
2335
NestedField,
36+
PrimitiveType,
2437
StringType,
2538
StructType,
39+
TimestampType,
40+
TimestamptzType,
41+
TimeType,
42+
UUIDType,
2643
)
2744

2845

@@ -153,6 +170,29 @@ def test_partition_type(table_schema_simple: Schema) -> None:
153170
)
154171

155172

173+
@pytest.mark.parametrize(
    "source_type, value",
    [
        (IntegerType(), 22),
        (LongType(), 22),
        # Precision must be >= scale and wide enough to hold the value; build the
        # Decimal from a string so the literal does not go through a float.
        (DecimalType(9, 2), Decimal("19.25")),
        (DateType(), datetime.date(1925, 5, 22)),
        (TimeType(), datetime.time(19, 25, 00)),
        (TimestampType(), datetime.datetime(19, 5, 1, 22, 1, 1)),
        (TimestamptzType(), datetime.datetime(19, 5, 1, 22, 1, 1, tzinfo=datetime.timezone.utc)),
        (StringType(), "abc"),
        (UUIDType(), UUID("12345678-1234-5678-1234-567812345678").bytes),
        # Fixed values are raw bytes whose length matches the declared width.
        # The previous literal was a *str* with the b-prefix inside the quotes,
        # so the fixed-width binary path was never exercised with bytes.
        (FixedType(4), b"\x8e\xd1\x87\x01"),
        (BinaryType(), b"\x8e\xd1\x87\x01"),
    ],
)
def test_bucketing_function(source_type: PrimitiveType, value: Any) -> None:
    """Check that the row-wise and PyArrow bucket transforms agree.

    For every parametrized ``(source_type, value)`` pair, the bucket computed by
    ``BucketTransform.transform`` on the single Python value must equal the
    bucket computed by ``BucketTransform.pyarrow_transform`` on a one-element
    PyArrow array holding the same value.
    """
    # Kept as a function-local import, as in the original test.
    import pyarrow as pa

    bucket = BucketTransform(2)  # type: ignore

    row_wise_result = bucket.transform(source_type)(value)
    arrow_result = bucket.pyarrow_transform(source_type)(pa.array([value])).to_pylist()[0]

    assert row_wise_result == arrow_result
194+
195+
156196
def test_deserialize_partition_field_v2() -> None:
157197
json_partition_spec = """{"source-id": 1, "field-id": 1000, "transform": "truncate[19]", "name": "str_truncate"}"""
158198

0 commit comments

Comments
 (0)