From 361ee468a129c6169a22cc9f05cf0b6fbf20bc66 Mon Sep 17 00:00:00 2001 From: Fokko Date: Wed, 22 Jan 2025 21:05:56 +0100 Subject: [PATCH 1/2] Refactor `truncate` transform types --- pyiceberg/transforms.py | 27 ++++++++++++++++++++- tests/table/test_partitioning.py | 40 ++++++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+), 1 deletion(-) diff --git a/pyiceberg/transforms.py b/pyiceberg/transforms.py index 22dcdfe88a..40a78b2811 100644 --- a/pyiceberg/transforms.py +++ b/pyiceberg/transforms.py @@ -16,6 +16,7 @@ # under the License. import base64 +import datetime as py_datetime import struct from abc import ABC, abstractmethod from enum import IntEnum @@ -298,7 +299,31 @@ def can_transform(self, source: IcebergType) -> bool: ) def transform(self, source: IcebergType, bucket: bool = True) -> Callable[[Optional[Any]], Optional[int]]: - if isinstance(source, (IntegerType, LongType, DateType, TimeType, TimestampType, TimestamptzType)): + if isinstance(source, TimeType): + + def hash_func(v: Any) -> int: + if isinstance(v, py_datetime.time): + v = datetime.time_to_micros(v) + + return mmh3.hash(struct.pack(" int: + if isinstance(v, py_datetime.date): + v = datetime.date_to_days(v) + + return mmh3.hash(struct.pack(" int: + if isinstance(v, py_datetime.datetime): + v = datetime.datetime_to_micros(v) + + return mmh3.hash(struct.pack(" int: return mmh3.hash(struct.pack(" None: NestedField(field_id=1000, name="str_truncate", field_type=StringType(), required=False), NestedField(field_id=1001, name="int_bucket", field_type=IntegerType(), required=True), ) + + +@pytest.mark.parametrize( + "source_type, value", + [ + (IntegerType(), 22), + (LongType(), 22), + (DecimalType(5, 9), Decimal(19.25)), + (DateType(), datetime.date(1925, 5, 22)), + (TimeType(), datetime.time(19, 25, 00)), + (TimestampType(), datetime.datetime(19, 5, 1, 22, 1, 1)), + (TimestamptzType(), datetime.datetime(19, 5, 1, 22, 1, 1, tzinfo=datetime.timezone.utc)), + (StringType(), "abc"), + (UUIDType(), UUID("12345678-1234-5678-1234-567812345678").bytes), + (FixedType(5), 'b"\x8e\xd1\x87\x01"'), + (BinaryType(), b"\x8e\xd1\x87\x01"), + ], +) +def test_bucketing_function(source_type: PrimitiveType, value: Any) -> None: + bucket = BucketTransform(2) # type: ignore + import pyarrow as pa + + assert bucket.transform(source_type)(value) == bucket.pyarrow_transform(source_type)(pa.array([value])).to_pylist()[0] From 92baa2b10d4e04443c7175f86dcb3fb96eacc1ea Mon Sep 17 00:00:00 2001 From: Fokko Date: Wed, 22 Jan 2025 21:07:13 +0100 Subject: [PATCH 2/2] lint --- tests/table/test_partitioning.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/table/test_partitioning.py b/tests/table/test_partitioning.py index d0e2514665..bdd68ea7a2 100644 --- a/tests/table/test_partitioning.py +++ b/tests/table/test_partitioning.py @@ -191,6 +191,8 @@ def test_bucketing_function(source_type: PrimitiveType, value: Any) -> None: import pyarrow as pa assert bucket.transform(source_type)(value) == bucket.pyarrow_transform(source_type)(pa.array([value])).to_pylist()[0] + + def test_deserialize_partition_field_v2() -> None: json_partition_spec = """{"source-id": 1, "field-id": 1000, "transform": "truncate[19]", "name": "str_truncate"}"""