|
20 | 20 | from abc import ABC, abstractmethod
|
21 | 21 | from enum import IntEnum
|
22 | 22 | from functools import singledispatch
|
23 |
| -from typing import Any, Callable, Generic, Optional, TypeVar |
| 23 | +from typing import TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar |
24 | 24 | from typing import Literal as LiteralType
|
25 | 25 | from uuid import UUID
|
26 | 26 |
|
|
82 | 82 | from pyiceberg.utils.parsing import ParseNumberFromBrackets
|
83 | 83 | from pyiceberg.utils.singleton import Singleton
|
84 | 84 |
|
| 85 | +if TYPE_CHECKING: |
| 86 | + import pyarrow as pa |
| 87 | + |
85 | 88 | S = TypeVar("S")
|
86 | 89 | T = TypeVar("T")
|
87 | 90 |
|
@@ -175,6 +178,13 @@ def __eq__(self, other: Any) -> bool:
|
175 | 178 | return self.root == other.root
|
176 | 179 | return False
|
177 | 180 |
|
| 181 | + @property |
| 182 | + def supports_pyarrow_transform(self) -> bool: |
| 183 | + return False |
| 184 | + |
| 185 | + @abstractmethod |
| 186 | + def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": ... |
| 187 | + |
178 | 188 |
|
179 | 189 | class BucketTransform(Transform[S, int]):
|
180 | 190 | """Base Transform class to transform a value into a bucket partition value.
|
@@ -290,6 +300,9 @@ def __repr__(self) -> str:
|
290 | 300 | """Return the string representation of the BucketTransform class."""
|
291 | 301 | return f"BucketTransform(num_buckets={self._num_buckets})"
|
292 | 302 |
|
| 303 | + def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": |
| 304 | + raise NotImplementedError() |
| 305 | + |
293 | 306 |
|
294 | 307 | class TimeResolution(IntEnum):
|
295 | 308 | YEAR = 6
|
@@ -349,6 +362,10 @@ def dedup_name(self) -> str:
|
349 | 362 | def preserves_order(self) -> bool:
|
350 | 363 | return True
|
351 | 364 |
|
| 365 | + @property |
| 366 | + def supports_pyarrow_transform(self) -> bool: |
| 367 | + return True |
| 368 | + |
352 | 369 |
|
353 | 370 | class YearTransform(TimeTransform[S]):
|
354 | 371 | """Transforms a datetime value into a year value.
|
@@ -391,6 +408,21 @@ def __repr__(self) -> str:
|
391 | 408 | """Return the string representation of the YearTransform class."""
|
392 | 409 | return "YearTransform()"
|
393 | 410 |
|
| 411 | + def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": |
| 412 | + import pyarrow as pa |
| 413 | + import pyarrow.compute as pc |
| 414 | + |
| 415 | + if isinstance(source, DateType): |
| 416 | + epoch = datetime.EPOCH_DATE |
| 417 | + elif isinstance(source, TimestampType): |
| 418 | + epoch = datetime.EPOCH_TIMESTAMP |
| 419 | + elif isinstance(source, TimestamptzType): |
| 420 | + epoch = datetime.EPOCH_TIMESTAMPTZ |
| 421 | + else: |
| 422 | + raise ValueError(f"Cannot apply year transform for type: {source}") |
| 423 | + |
| 424 | + return lambda v: pc.years_between(pa.scalar(epoch), v) if v is not None else None |
| 425 | + |
394 | 426 |
|
395 | 427 | class MonthTransform(TimeTransform[S]):
|
396 | 428 | """Transforms a datetime value into a month value.
|
@@ -433,6 +465,27 @@ def __repr__(self) -> str:
|
433 | 465 | """Return the string representation of the MonthTransform class."""
|
434 | 466 | return "MonthTransform()"
|
435 | 467 |
|
| 468 | + def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": |
| 469 | + import pyarrow as pa |
| 470 | + import pyarrow.compute as pc |
| 471 | + |
| 472 | + if isinstance(source, DateType): |
| 473 | + epoch = datetime.EPOCH_DATE |
| 474 | + elif isinstance(source, TimestampType): |
| 475 | + epoch = datetime.EPOCH_TIMESTAMP |
| 476 | + elif isinstance(source, TimestamptzType): |
| 477 | + epoch = datetime.EPOCH_TIMESTAMPTZ |
| 478 | + else: |
| 479 | + raise ValueError(f"Cannot apply month transform for type: {source}") |
| 480 | + |
| 481 | + def month_func(v: pa.Array) -> pa.Array: |
| 482 | + return pc.add( |
| 483 | + pc.multiply(pc.years_between(pa.scalar(epoch), v), pa.scalar(12)), |
| 484 | + pc.add(pc.month(v), pa.scalar(-1)), |
| 485 | + ) |
| 486 | + |
| 487 | + return lambda v: month_func(v) if v is not None else None |
| 488 | + |
436 | 489 |
|
437 | 490 | class DayTransform(TimeTransform[S]):
|
438 | 491 | """Transforms a datetime value into a day value.
|
@@ -478,6 +531,21 @@ def __repr__(self) -> str:
|
478 | 531 | """Return the string representation of the DayTransform class."""
|
479 | 532 | return "DayTransform()"
|
480 | 533 |
|
| 534 | + def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": |
| 535 | + import pyarrow as pa |
| 536 | + import pyarrow.compute as pc |
| 537 | + |
| 538 | + if isinstance(source, DateType): |
| 539 | + epoch = datetime.EPOCH_DATE |
| 540 | + elif isinstance(source, TimestampType): |
| 541 | + epoch = datetime.EPOCH_TIMESTAMP |
| 542 | + elif isinstance(source, TimestamptzType): |
| 543 | + epoch = datetime.EPOCH_TIMESTAMPTZ |
| 544 | + else: |
| 545 | + raise ValueError(f"Cannot apply day transform for type: {source}") |
| 546 | + |
| 547 | + return lambda v: pc.days_between(pa.scalar(epoch), v) if v is not None else None |
| 548 | + |
481 | 549 |
|
482 | 550 | class HourTransform(TimeTransform[S]):
|
483 | 551 | """Transforms a datetime value into a hour value.
|
@@ -515,6 +583,19 @@ def __repr__(self) -> str:
|
515 | 583 | """Return the string representation of the HourTransform class."""
|
516 | 584 | return "HourTransform()"
|
517 | 585 |
|
| 586 | + def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": |
| 587 | + import pyarrow as pa |
| 588 | + import pyarrow.compute as pc |
| 589 | + |
| 590 | + if isinstance(source, TimestampType): |
| 591 | + epoch = datetime.EPOCH_TIMESTAMP |
| 592 | + elif isinstance(source, TimestamptzType): |
| 593 | + epoch = datetime.EPOCH_TIMESTAMPTZ |
| 594 | + else: |
| 595 | + raise ValueError(f"Cannot apply hour transform for type: {source}") |
| 596 | + |
| 597 | + return lambda v: pc.hours_between(pa.scalar(epoch), v) if v is not None else None |
| 598 | + |
518 | 599 |
|
519 | 600 | def _base64encode(buffer: bytes) -> str:
|
520 | 601 | """Convert bytes to base64 string."""
|
@@ -585,6 +666,13 @@ def __repr__(self) -> str:
|
585 | 666 | """Return the string representation of the IdentityTransform class."""
|
586 | 667 | return "IdentityTransform()"
|
587 | 668 |
|
| 669 | + def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": |
| 670 | + return lambda v: v |
| 671 | + |
| 672 | + @property |
| 673 | + def supports_pyarrow_transform(self) -> bool: |
| 674 | + return True |
| 675 | + |
588 | 676 |
|
589 | 677 | class TruncateTransform(Transform[S, S]):
|
590 | 678 | """A transform for truncating a value to a specified width.
|
@@ -725,6 +813,9 @@ def __repr__(self) -> str:
|
725 | 813 | """Return the string representation of the TruncateTransform class."""
|
726 | 814 | return f"TruncateTransform(width={self._width})"
|
727 | 815 |
|
| 816 | + def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": |
| 817 | + raise NotImplementedError() |
| 818 | + |
728 | 819 |
|
729 | 820 | @singledispatch
|
730 | 821 | def _human_string(value: Any, _type: IcebergType) -> str:
|
@@ -807,6 +898,9 @@ def __repr__(self) -> str:
|
807 | 898 | """Return the string representation of the UnknownTransform class."""
|
808 | 899 | return f"UnknownTransform(transform={repr(self._transform)})"
|
809 | 900 |
|
| 901 | + def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": |
| 902 | + raise NotImplementedError() |
| 903 | + |
810 | 904 |
|
811 | 905 | class VoidTransform(Transform[S, None], Singleton):
|
812 | 906 | """A transform that always returns None."""
|
@@ -835,6 +929,9 @@ def __repr__(self) -> str:
|
835 | 929 | """Return the string representation of the VoidTransform class."""
|
836 | 930 | return "VoidTransform()"
|
837 | 931 |
|
| 932 | + def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": |
| 933 | + raise NotImplementedError() |
| 934 | + |
838 | 935 |
|
839 | 936 | def _truncate_number(
|
840 | 937 | name: str, pred: BoundLiteralPredicate[L], transform: Callable[[Optional[L]], Optional[L]]
|
|
0 commit comments