Commit 65a03d2

Support Appends with TimeTransform Partitions (#784)
* checkpoint
* checkpoint2
* todo: sort with pyarrow_transform vals
* checkpoint
* checkpoint
* fix
* tests
* more tests
* adopt review feedback
* comment
* rebase
1 parent 20f6afd commit 65a03d2

6 files changed: +392 -54 lines changed

pyiceberg/partitioning.py

Lines changed: 1 addition & 1 deletion
@@ -387,7 +387,7 @@ def partition(self) -> Record:  # partition key transformed with iceberg internal representation
         for raw_partition_field_value in self.raw_partition_field_values:
             partition_fields = self.partition_spec.source_id_to_fields_map[raw_partition_field_value.field.source_id]
             if len(partition_fields) != 1:
-                raise ValueError("partition_fields must contain exactly one field.")
+                raise ValueError(f"Cannot have redundant partitions: {partition_fields}")
             partition_field = partition_fields[0]
             iceberg_typed_key_values[partition_field.name] = partition_record_value(
                 partition_field=partition_field,
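
The new message surfaces when a spec maps one source column onto more than one partition field. A minimal sketch of a spec that would trip it (the field ids and names below are illustrative, not from this commit):

```python
from pyiceberg.partitioning import PartitionField, PartitionSpec
from pyiceberg.transforms import MonthTransform, YearTransform

# Two partition fields derived from the same source column (source_id=1).
spec = PartitionSpec(
    PartitionField(source_id=1, field_id=1000, transform=YearTransform(), name="ts_year"),
    PartitionField(source_id=1, field_id=1001, transform=MonthTransform(), name="ts_month"),
)
# source_id_to_fields_map[1] now holds two fields, so building a partition
# key for a row would raise: ValueError("Cannot have redundant partitions: [...]")
```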

pyiceberg/table/__init__.py

Lines changed: 29 additions & 38 deletions
@@ -392,10 +392,11 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT)
         if not isinstance(df, pa.Table):
             raise ValueError(f"Expected PyArrow table, got: {df}")
 
-        supported_transforms = {IdentityTransform}
-        if not all(type(field.transform) in supported_transforms for field in self.table_metadata.spec().fields):
+        if unsupported_partitions := [
+            field for field in self.table_metadata.spec().fields if not field.transform.supports_pyarrow_transform
+        ]:
             raise ValueError(
-                f"All transforms are not supported, expected: {supported_transforms}, but get: {[str(field) for field in self.table_metadata.spec().fields if field.transform not in supported_transforms]}."
+                f"Not all partition types are supported for writes. Following partitions cannot be written using pyarrow: {unsupported_partitions}."
             )
 
         _check_schema_compatible(self._table.schema(), other_schema=df.schema)
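
The walrus-expression guard keys off the new `supports_pyarrow_transform` property instead of a hard-coded set of transform classes. A quick sketch of what the check now accepts and rejects, based on the property values added in pyiceberg/transforms.py below:

```python
from pyiceberg.transforms import BucketTransform, IdentityTransform, MonthTransform

# Identity and the time-based transforms can be evaluated with pyarrow.compute,
# so appends to partitions using them are allowed; bucket is still rejected.
assert IdentityTransform().supports_pyarrow_transform
assert MonthTransform().supports_pyarrow_transform
assert not BucketTransform(num_buckets=16).supports_pyarrow_transform
```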
@@ -3643,33 +3644,6 @@ class TablePartition:
     arrow_table_partition: pa.Table
 
 
-def _get_partition_sort_order(partition_columns: list[str], reverse: bool = False) -> dict[str, Any]:
-    order = "ascending" if not reverse else "descending"
-    null_placement = "at_start" if reverse else "at_end"
-    return {"sort_keys": [(column_name, order) for column_name in partition_columns], "null_placement": null_placement}
-
-
-def group_by_partition_scheme(arrow_table: pa.Table, partition_columns: list[str]) -> pa.Table:
-    """Given a table, sort it by current partition scheme."""
-    # only works for identity for now
-    sort_options = _get_partition_sort_order(partition_columns, reverse=False)
-    sorted_arrow_table = arrow_table.sort_by(sorting=sort_options["sort_keys"], null_placement=sort_options["null_placement"])
-    return sorted_arrow_table
-
-
-def get_partition_columns(
-    spec: PartitionSpec,
-    schema: Schema,
-) -> list[str]:
-    partition_cols = []
-    for partition_field in spec.fields:
-        column_name = schema.find_column_name(partition_field.source_id)
-        if not column_name:
-            raise ValueError(f"{partition_field=} could not be found in {schema}.")
-        partition_cols.append(column_name)
-    return partition_cols
-
-
 def _get_table_partitions(
     arrow_table: pa.Table,
     partition_spec: PartitionSpec,
@@ -3724,13 +3698,30 @@ def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.T
     """
     import pyarrow as pa
 
-    partition_columns = get_partition_columns(spec=spec, schema=schema)
-    arrow_table = group_by_partition_scheme(arrow_table, partition_columns)
-
-    reversing_sort_order_options = _get_partition_sort_order(partition_columns, reverse=True)
-    reversed_indices = pa.compute.sort_indices(arrow_table, **reversing_sort_order_options).to_pylist()
-
-    slice_instructions: list[dict[str, Any]] = []
+    partition_columns: List[Tuple[PartitionField, NestedField]] = [
+        (partition_field, schema.find_field(partition_field.source_id)) for partition_field in spec.fields
+    ]
+    partition_values_table = pa.table({
+        str(partition.field_id): partition.transform.pyarrow_transform(field.field_type)(arrow_table[field.name])
+        for partition, field in partition_columns
+    })
+
+    # Sort by partitions
+    sort_indices = pa.compute.sort_indices(
+        partition_values_table,
+        sort_keys=[(col, "ascending") for col in partition_values_table.column_names],
+        null_placement="at_end",
+    ).to_pylist()
+    arrow_table = arrow_table.take(sort_indices)
+
+    # Get slice_instructions to group by partitions
+    partition_values_table = partition_values_table.take(sort_indices)
+    reversed_indices = pa.compute.sort_indices(
+        partition_values_table,
+        sort_keys=[(col, "descending") for col in partition_values_table.column_names],
+        null_placement="at_start",
+    ).to_pylist()
+    slice_instructions: List[Dict[str, Any]] = []
     last = len(reversed_indices)
     reversed_indices_size = len(reversed_indices)
     ptr = 0
@@ -3741,6 +3732,6 @@ def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.T
         last = reversed_indices[ptr]
         ptr = ptr + group_size
 
-    table_partitions: list[TablePartition] = _get_table_partitions(arrow_table, spec, schema, slice_instructions)
+    table_partitions: List[TablePartition] = _get_table_partitions(arrow_table, spec, schema, slice_instructions)
 
     return table_partitions
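
The ascending/descending double sort is what turns equal partition values into contiguous slices: the ascending sort groups duplicates into runs, and a stable descending sort over the already-sorted values then exposes each run's start offset. A standalone sketch of the trick on a toy single-column table (the column name and values are illustrative):

```python
import pyarrow as pa
import pyarrow.compute as pc

partition_values = pa.table({"1000": [2024, 2023, 2024, 2023]})

# Ascending sort groups equal partition values into contiguous runs.
asc = pc.sort_indices(
    partition_values, sort_keys=[("1000", "ascending")], null_placement="at_end"
).to_pylist()
partition_values = partition_values.take(asc)  # values: 2023, 2023, 2024, 2024

# Stable descending sort of the sorted table: each run's start offset
# appears group_size times, beginning with the run of the largest value.
reversed_indices = pc.sort_indices(
    partition_values, sort_keys=[("1000", "descending")], null_placement="at_start"
).to_pylist()  # [2, 3, 0, 1]

# Same loop as _determine_partitions: walk the runs back to front.
slice_instructions = []
last = len(reversed_indices)
ptr = 0
while ptr < len(reversed_indices):
    group_size = last - reversed_indices[ptr]
    offset = reversed_indices[ptr]
    slice_instructions.append({"offset": offset, "length": group_size})
    last = offset
    ptr = ptr + group_size

print(slice_instructions)  # [{'offset': 2, 'length': 2}, {'offset': 0, 'length': 2}]
```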

pyiceberg/transforms.py

Lines changed: 98 additions & 1 deletion
@@ -20,7 +20,7 @@
 from abc import ABC, abstractmethod
 from enum import IntEnum
 from functools import singledispatch
-from typing import Any, Callable, Generic, Optional, TypeVar
+from typing import TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar
 from typing import Literal as LiteralType
 from uuid import UUID
 
@@ -82,6 +82,9 @@
 from pyiceberg.utils.parsing import ParseNumberFromBrackets
 from pyiceberg.utils.singleton import Singleton
 
+if TYPE_CHECKING:
+    import pyarrow as pa
+
 S = TypeVar("S")
 T = TypeVar("T")
 
@@ -175,6 +178,13 @@ def __eq__(self, other: Any) -> bool:
             return self.root == other.root
         return False
 
+    @property
+    def supports_pyarrow_transform(self) -> bool:
+        return False
+
+    @abstractmethod
+    def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": ...
+
 
 class BucketTransform(Transform[S, int]):
     """Base Transform class to transform a value into a bucket partition value.
@@ -290,6 +300,9 @@ def __repr__(self) -> str:
         """Return the string representation of the BucketTransform class."""
         return f"BucketTransform(num_buckets={self._num_buckets})"
 
+    def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
+        raise NotImplementedError()
+
 
 class TimeResolution(IntEnum):
     YEAR = 6
@@ -349,6 +362,10 @@ def dedup_name(self) -> str:
     def preserves_order(self) -> bool:
         return True
 
+    @property
+    def supports_pyarrow_transform(self) -> bool:
+        return True
+
 
 class YearTransform(TimeTransform[S]):
     """Transforms a datetime value into a year value.
@@ -391,6 +408,21 @@ def __repr__(self) -> str:
         """Return the string representation of the YearTransform class."""
         return "YearTransform()"
 
+    def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
+        import pyarrow as pa
+        import pyarrow.compute as pc
+
+        if isinstance(source, DateType):
+            epoch = datetime.EPOCH_DATE
+        elif isinstance(source, TimestampType):
+            epoch = datetime.EPOCH_TIMESTAMP
+        elif isinstance(source, TimestamptzType):
+            epoch = datetime.EPOCH_TIMESTAMPTZ
+        else:
+            raise ValueError(f"Cannot apply year transform for type: {source}")
+
+        return lambda v: pc.years_between(pa.scalar(epoch), v) if v is not None else None
+
 
 class MonthTransform(TimeTransform[S]):
     """Transforms a datetime value into a month value.
@@ -433,6 +465,27 @@ def __repr__(self) -> str:
         """Return the string representation of the MonthTransform class."""
         return "MonthTransform()"
 
+    def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
+        import pyarrow as pa
+        import pyarrow.compute as pc
+
+        if isinstance(source, DateType):
+            epoch = datetime.EPOCH_DATE
+        elif isinstance(source, TimestampType):
+            epoch = datetime.EPOCH_TIMESTAMP
+        elif isinstance(source, TimestamptzType):
+            epoch = datetime.EPOCH_TIMESTAMPTZ
+        else:
+            raise ValueError(f"Cannot apply month transform for type: {source}")
+
+        def month_func(v: pa.Array) -> pa.Array:
+            return pc.add(
+                pc.multiply(pc.years_between(pa.scalar(epoch), v), pa.scalar(12)),
+                pc.add(pc.month(v), pa.scalar(-1)),
+            )
+
+        return lambda v: month_func(v) if v is not None else None
+
 
 class DayTransform(TimeTransform[S]):
     """Transforms a datetime value into a day value.
478531
"""Return the string representation of the DayTransform class."""
479532
return "DayTransform()"
480533

534+
def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
535+
import pyarrow as pa
536+
import pyarrow.compute as pc
537+
538+
if isinstance(source, DateType):
539+
epoch = datetime.EPOCH_DATE
540+
elif isinstance(source, TimestampType):
541+
epoch = datetime.EPOCH_TIMESTAMP
542+
elif isinstance(source, TimestamptzType):
543+
epoch = datetime.EPOCH_TIMESTAMPTZ
544+
else:
545+
raise ValueError(f"Cannot apply day transform for type: {source}")
546+
547+
return lambda v: pc.days_between(pa.scalar(epoch), v) if v is not None else None
548+
481549

482550
class HourTransform(TimeTransform[S]):
483551
"""Transforms a datetime value into a hour value.
@@ -515,6 +583,19 @@ def __repr__(self) -> str:
515583
"""Return the string representation of the HourTransform class."""
516584
return "HourTransform()"
517585

586+
def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
587+
import pyarrow as pa
588+
import pyarrow.compute as pc
589+
590+
if isinstance(source, TimestampType):
591+
epoch = datetime.EPOCH_TIMESTAMP
592+
elif isinstance(source, TimestamptzType):
593+
epoch = datetime.EPOCH_TIMESTAMPTZ
594+
else:
595+
raise ValueError(f"Cannot apply hour transform for type: {source}")
596+
597+
return lambda v: pc.hours_between(pa.scalar(epoch), v) if v is not None else None
598+
518599

519600
def _base64encode(buffer: bytes) -> str:
520601
"""Convert bytes to base64 string."""
@@ -585,6 +666,13 @@ def __repr__(self) -> str:
         """Return the string representation of the IdentityTransform class."""
         return "IdentityTransform()"
 
+    def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
+        return lambda v: v
+
+    @property
+    def supports_pyarrow_transform(self) -> bool:
+        return True
+
 
 class TruncateTransform(Transform[S, S]):
     """A transform for truncating a value to a specified width.
@@ -725,6 +813,9 @@ def __repr__(self) -> str:
         """Return the string representation of the TruncateTransform class."""
         return f"TruncateTransform(width={self._width})"
 
+    def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
+        raise NotImplementedError()
+
 
 @singledispatch
 def _human_string(value: Any, _type: IcebergType) -> str:
@@ -807,6 +898,9 @@ def __repr__(self) -> str:
         """Return the string representation of the UnknownTransform class."""
         return f"UnknownTransform(transform={repr(self._transform)})"
 
+    def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
+        raise NotImplementedError()
+
 
 class VoidTransform(Transform[S, None], Singleton):
     """A transform that always returns None."""
@@ -835,6 +929,9 @@ def __repr__(self) -> str:
         """Return the string representation of the VoidTransform class."""
         return "VoidTransform()"
 
+    def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
+        raise NotImplementedError()
+
 
 def _truncate_number(
     name: str, pred: BoundLiteralPredicate[L], transform: Callable[[Optional[L]], Optional[L]]

tests/conftest.py

Lines changed: 43 additions & 0 deletions
@@ -2158,3 +2158,46 @@ def arrow_table_with_only_nulls(pa_schema: "pa.Schema") -> "pa.Table":
     import pyarrow as pa
 
     return pa.Table.from_pylist([{}, {}], schema=pa_schema)
+
+
+@pytest.fixture(scope="session")
+def arrow_table_date_timestamps() -> "pa.Table":
+    """Pyarrow table with only date, timestamp and timestamptz values."""
+    import pyarrow as pa
+
+    return pa.Table.from_pydict(
+        {
+            "date": [date(2023, 12, 31), date(2024, 1, 1), date(2024, 1, 31), date(2024, 2, 1), date(2024, 2, 1), None],
+            "timestamp": [
+                datetime(2023, 12, 31, 0, 0, 0),
+                datetime(2024, 1, 1, 0, 0, 0),
+                datetime(2024, 1, 31, 0, 0, 0),
+                datetime(2024, 2, 1, 0, 0, 0),
+                datetime(2024, 2, 1, 6, 0, 0),
+                None,
+            ],
+            "timestamptz": [
+                datetime(2023, 12, 31, 0, 0, 0, tzinfo=timezone.utc),
+                datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc),
+                datetime(2024, 1, 31, 0, 0, 0, tzinfo=timezone.utc),
+                datetime(2024, 2, 1, 0, 0, 0, tzinfo=timezone.utc),
+                datetime(2024, 2, 1, 6, 0, 0, tzinfo=timezone.utc),
+                None,
+            ],
+        },
+        schema=pa.schema([
+            ("date", pa.date32()),
+            ("timestamp", pa.timestamp(unit="us")),
+            ("timestamptz", pa.timestamp(unit="us", tz="UTC")),
+        ]),
+    )
+
+
+@pytest.fixture(scope="session")
+def arrow_table_date_timestamps_schema() -> Schema:
+    """Pyarrow table Schema with only date, timestamp and timestamptz values."""
+    return Schema(
+        NestedField(field_id=1, name="date", field_type=DateType(), required=False),
+        NestedField(field_id=2, name="timestamp", field_type=TimestampType(), required=False),
+        NestedField(field_id=3, name="timestamptz", field_type=TimestamptzType(), required=False),
+    )
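
These fixtures line up with the new write path: every column can back a time-transform partition spec. A hedged sketch of a test that could consume them (the `catalog` fixture, table name, and assertion are hypothetical, not part of this commit):

```python
from pyiceberg.partitioning import PartitionField, PartitionSpec
from pyiceberg.transforms import MonthTransform


def test_append_month_partitioned(catalog, arrow_table_date_timestamps, arrow_table_date_timestamps_schema):
    # Partition on the 'timestamp' column (field_id=2 in the fixture schema).
    spec = PartitionSpec(
        PartitionField(source_id=2, field_id=1000, transform=MonthTransform(), name="timestamp_month"),
    )
    tbl = catalog.create_table(
        "default.dates", schema=arrow_table_date_timestamps_schema, partition_spec=spec
    )
    tbl.append(arrow_table_date_timestamps)
    assert len(tbl.scan().to_arrow()) == len(arrow_table_date_timestamps)
```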
