feat: Support Bucket and Truncate transforms on write (#1345)
* introduce bucket transform

* include pyiceberg-core

* introduce bucket transform

* include pyiceberg-core

* resolve poetry conflict

* support truncate transforms

* Remove stale comment

* fix poetry hash

* avoid codespell error for truncate transform

* adopt nits
sungwy authored Jan 16, 2025
1 parent 0a3a886 commit 50c33aa
Showing 5 changed files with 259 additions and 20 deletions.
18 changes: 17 additions & 1 deletion poetry.lock

Some generated files are not rendered by default.

39 changes: 37 additions & 2 deletions pyiceberg/transforms.py
@@ -85,6 +85,8 @@
if TYPE_CHECKING:
import pyarrow as pa

ArrayLike = TypeVar("ArrayLike", pa.Array, pa.ChunkedArray)

S = TypeVar("S")
T = TypeVar("T")

@@ -193,6 +195,27 @@ def supports_pyarrow_transform(self) -> bool:
@abstractmethod
def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": ...

def _pyiceberg_transform_wrapper(
self, transform_func: Callable[["ArrayLike", Any], "ArrayLike"], *args: Any
) -> Callable[["ArrayLike"], "ArrayLike"]:
try:
import pyarrow as pa
except ModuleNotFoundError as e:
raise ModuleNotFoundError("For bucket/truncate transforms, PyArrow needs to be installed") from e

def _transform(array: "ArrayLike") -> "ArrayLike":
if isinstance(array, pa.Array):
return transform_func(array, *args)
elif isinstance(array, pa.ChunkedArray):
result_chunks = []
for arr in array.iterchunks():
result_chunks.append(transform_func(arr, *args))
return pa.chunked_array(result_chunks)
else:
raise ValueError(f"PyArrow array can only be of type pa.Array or pa.ChunkedArray, but found {type(array)}")

return _transform


class BucketTransform(Transform[S, int]):
"""Base Transform class to transform a value into a bucket partition value.
@@ -309,7 +332,13 @@ def __repr__(self) -> str:
return f"BucketTransform(num_buckets={self._num_buckets})"

def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
raise NotImplementedError()
from pyiceberg_core import transform as pyiceberg_core_transform

return self._pyiceberg_transform_wrapper(pyiceberg_core_transform.bucket, self._num_buckets)

@property
def supports_pyarrow_transform(self) -> bool:
return True


class TimeResolution(IntEnum):
@@ -827,7 +856,13 @@ def __repr__(self) -> str:
return f"TruncateTransform(width={self._width})"

def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
raise NotImplementedError()
from pyiceberg_core import transform as pyiceberg_core_transform

return self._pyiceberg_transform_wrapper(pyiceberg_core_transform.truncate, self._width)

@property
def supports_pyarrow_transform(self) -> bool:
return True


@singledispatch
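To illustrate the new pyarrow_transform hooks, here is a minimal sketch (not part of the diff) that assumes the optional pyarrow and pyiceberg-core dependencies are installed; the expected values come from the unit tests in tests/test_transforms.py further down.

import pyarrow as pa

from pyiceberg.transforms import BucketTransform, TruncateTransform
from pyiceberg.types import IntegerType, StringType

# Bucket ints into 10 buckets; the returned callable accepts pa.Array or pa.ChunkedArray.
bucket = BucketTransform(num_buckets=10).pyarrow_transform(IntegerType())
print(bucket(pa.array([1, 2])))  # expected: [6, 2] as int32
print(bucket(pa.chunked_array([pa.array([1, 2]), pa.array([3, 4])])))  # expected chunks: [6, 2] and [5, 0]

# Truncate strings to a width of 3 characters.
truncate = TruncateTransform(width=3).pyarrow_transform(StringType())
print(truncate(pa.array(["developer", "iceberg"])))  # expected: ["dev", "ice"]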
6 changes: 6 additions & 0 deletions pyproject.toml
@@ -79,6 +79,7 @@ psycopg2-binary = { version = ">=2.9.6", optional = true }
sqlalchemy = { version = "^2.0.18", optional = true }
getdaft = { version = ">=0.2.12", optional = true }
cachetools = "^5.5.0"
pyiceberg-core = { version = "^0.4.0", optional = true }

[tool.poetry.group.dev.dependencies]
pytest = "7.4.4"
@@ -842,6 +843,10 @@ ignore_missing_imports = true
module = "daft.*"
ignore_missing_imports = true

[[tool.mypy.overrides]]
module = "pyiceberg_core.*"
ignore_missing_imports = true

[[tool.mypy.overrides]]
module = "pyparsing.*"
ignore_missing_imports = true
@@ -1206,6 +1211,7 @@ sql-postgres = ["sqlalchemy", "psycopg2-binary"]
sql-sqlite = ["sqlalchemy"]
gcsfs = ["gcsfs"]
rest-sigv4 = ["boto3"]
pyiceberg-core = ["pyiceberg-core"]

[tool.pytest.ini_options]
markers = [
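As a side note, a small sketch (an assumption, not part of the diff) of how the new optional dependency surfaces at runtime; the extra name comes from the [tool.poetry.extras] entry above, and the pip command is an assumption.

# Probe for the optional Rust bindings that the bucket/truncate write path imports lazily.
try:
    from pyiceberg_core import transform as pyiceberg_core_transform

    print("pyiceberg-core is available:", pyiceberg_core_transform is not None)
except ModuleNotFoundError:
    # Assumed install command, mirroring the extra declared above.
    print('pyiceberg-core is missing; try: pip install "pyiceberg[pyiceberg-core]"')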
170 changes: 157 additions & 13 deletions tests/integration/test_writes/test_partitioned_writes.py
@@ -412,6 +412,12 @@ def test_dynamic_partition_overwrite_unpartitioned_evolve_to_identity_transform(
spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, part_col: str, format_version: int
) -> None:
identifier = f"default.unpartitioned_table_v{format_version}_evolve_into_identity_transformed_partition_field_{part_col}"

try:
session_catalog.drop_table(identifier=identifier)
except NoSuchTableError:
pass

tbl = session_catalog.create_table(
identifier=identifier,
schema=TABLE_SCHEMA,
@@ -756,6 +762,55 @@ def test_invalid_arguments(spark: SparkSession, session_catalog: Catalog) -> Non
tbl.append("not a df")


@pytest.mark.integration
@pytest.mark.parametrize(
"spec",
[
(PartitionSpec(PartitionField(source_id=4, field_id=1001, transform=TruncateTransform(2), name="int_trunc"))),
(PartitionSpec(PartitionField(source_id=5, field_id=1001, transform=TruncateTransform(2), name="long_trunc"))),
(PartitionSpec(PartitionField(source_id=2, field_id=1001, transform=TruncateTransform(2), name="string_trunc"))),
],
)
@pytest.mark.parametrize("format_version", [1, 2])
def test_truncate_transform(
spec: PartitionSpec,
spark: SparkSession,
session_catalog: Catalog,
arrow_table_with_null: pa.Table,
format_version: int,
) -> None:
identifier = "default.truncate_transform"

try:
session_catalog.drop_table(identifier=identifier)
except NoSuchTableError:
pass

tbl = _create_table(
session_catalog=session_catalog,
identifier=identifier,
properties={"format-version": str(format_version)},
data=[arrow_table_with_null],
partition_spec=spec,
)

assert tbl.format_version == format_version, f"Expected v{format_version}, got: v{tbl.format_version}"
df = spark.table(identifier)
assert df.count() == 3, f"Expected 3 total rows for {identifier}"
for col in arrow_table_with_null.column_names:
assert df.where(f"{col} is not null").count() == 2, f"Expected 2 non-null rows for {col}"
assert df.where(f"{col} is null").count() == 1, f"Expected 1 null row for {col} is null"

assert tbl.inspect.partitions().num_rows == 3
files_df = spark.sql(
f"""
SELECT *
FROM {identifier}.files
"""
)
assert files_df.count() == 3


@pytest.mark.integration
@pytest.mark.parametrize(
"spec",
@@ -767,18 +822,52 @@ def test_invalid_arguments(spark: SparkSession, session_catalog: Catalog) -> Non
PartitionField(source_id=1, field_id=1002, transform=IdentityTransform(), name="bool"),
)
),
# none of non-identity is supported
(PartitionSpec(PartitionField(source_id=4, field_id=1001, transform=BucketTransform(2), name="int_bucket"))),
(PartitionSpec(PartitionField(source_id=5, field_id=1001, transform=BucketTransform(2), name="long_bucket"))),
(PartitionSpec(PartitionField(source_id=10, field_id=1001, transform=BucketTransform(2), name="date_bucket"))),
(PartitionSpec(PartitionField(source_id=8, field_id=1001, transform=BucketTransform(2), name="timestamp_bucket"))),
(PartitionSpec(PartitionField(source_id=9, field_id=1001, transform=BucketTransform(2), name="timestamptz_bucket"))),
(PartitionSpec(PartitionField(source_id=2, field_id=1001, transform=BucketTransform(2), name="string_bucket"))),
(PartitionSpec(PartitionField(source_id=12, field_id=1001, transform=BucketTransform(2), name="fixed_bucket"))),
(PartitionSpec(PartitionField(source_id=11, field_id=1001, transform=BucketTransform(2), name="binary_bucket"))),
(PartitionSpec(PartitionField(source_id=4, field_id=1001, transform=TruncateTransform(2), name="int_trunc"))),
(PartitionSpec(PartitionField(source_id=5, field_id=1001, transform=TruncateTransform(2), name="long_trunc"))),
(PartitionSpec(PartitionField(source_id=2, field_id=1001, transform=TruncateTransform(2), name="string_trunc"))),
],
)
@pytest.mark.parametrize("format_version", [1, 2])
def test_identity_and_bucket_transform_spec(
spec: PartitionSpec,
spark: SparkSession,
session_catalog: Catalog,
arrow_table_with_null: pa.Table,
format_version: int,
) -> None:
identifier = "default.identity_and_bucket_transform"

try:
session_catalog.drop_table(identifier=identifier)
except NoSuchTableError:
pass

tbl = _create_table(
session_catalog=session_catalog,
identifier=identifier,
properties={"format-version": str(format_version)},
data=[arrow_table_with_null],
partition_spec=spec,
)

assert tbl.format_version == format_version, f"Expected v{format_version}, got: v{tbl.format_version}"
df = spark.table(identifier)
assert df.count() == 3, f"Expected 3 total rows for {identifier}"
for col in arrow_table_with_null.column_names:
assert df.where(f"{col} is not null").count() == 2, f"Expected 2 non-null rows for {col}"
assert df.where(f"{col} is null").count() == 1, f"Expected 1 null row for {col} is null"

assert tbl.inspect.partitions().num_rows == 3
files_df = spark.sql(
f"""
SELECT *
FROM {identifier}.files
"""
)
assert files_df.count() == 3


@pytest.mark.integration
@pytest.mark.parametrize(
"spec",
[
(PartitionSpec(PartitionField(source_id=11, field_id=1001, transform=TruncateTransform(2), name="binary_trunc"))),
],
)
@@ -801,11 +890,66 @@ def test_unsupported_transform(

with pytest.raises(
ValueError,
match="Not all partition types are supported for writes. Following partitions cannot be written using pyarrow: *",
match="FeatureUnsupported => Unsupported data type for truncate transform: LargeBinary",
):
tbl.append(arrow_table_with_null)


@pytest.mark.integration
@pytest.mark.parametrize(
"spec, expected_rows",
[
(PartitionSpec(PartitionField(source_id=4, field_id=1001, transform=BucketTransform(2), name="int_bucket")), 3),
(PartitionSpec(PartitionField(source_id=5, field_id=1001, transform=BucketTransform(2), name="long_bucket")), 3),
(PartitionSpec(PartitionField(source_id=10, field_id=1001, transform=BucketTransform(2), name="date_bucket")), 3),
(PartitionSpec(PartitionField(source_id=8, field_id=1001, transform=BucketTransform(2), name="timestamp_bucket")), 3),
(PartitionSpec(PartitionField(source_id=9, field_id=1001, transform=BucketTransform(2), name="timestamptz_bucket")), 3),
(PartitionSpec(PartitionField(source_id=2, field_id=1001, transform=BucketTransform(2), name="string_bucket")), 3),
(PartitionSpec(PartitionField(source_id=12, field_id=1001, transform=BucketTransform(2), name="fixed_bucket")), 2),
(PartitionSpec(PartitionField(source_id=11, field_id=1001, transform=BucketTransform(2), name="binary_bucket")), 2),
],
)
@pytest.mark.parametrize("format_version", [1, 2])
def test_bucket_transform(
spark: SparkSession,
session_catalog: Catalog,
arrow_table_with_null: pa.Table,
spec: PartitionSpec,
expected_rows: int,
format_version: int,
) -> None:
identifier = "default.bucket_transform"

try:
session_catalog.drop_table(identifier=identifier)
except NoSuchTableError:
pass

tbl = _create_table(
session_catalog=session_catalog,
identifier=identifier,
properties={"format-version": str(format_version)},
data=[arrow_table_with_null],
partition_spec=spec,
)

assert tbl.format_version == format_version, f"Expected v{format_version}, got: v{tbl.format_version}"
df = spark.table(identifier)
assert df.count() == 3, f"Expected 3 total rows for {identifier}"
for col in arrow_table_with_null.column_names:
assert df.where(f"{col} is not null").count() == 2, f"Expected 2 non-null rows for {col}"
assert df.where(f"{col} is null").count() == 1, f"Expected 1 null row for {col} is null"

assert tbl.inspect.partitions().num_rows == expected_rows
files_df = spark.sql(
f"""
SELECT *
FROM {identifier}.files
"""
)
assert files_df.count() == expected_rows


@pytest.mark.integration
@pytest.mark.parametrize(
"transform,expected_rows",
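For context, a minimal end-to-end sketch of a partitioned write that exercises the new transform path, in the spirit of the integration tests above; the catalog name, table identifier, and schema are assumptions, and catalog configuration is left out.

import pyarrow as pa

from pyiceberg.catalog import load_catalog
from pyiceberg.partitioning import PartitionField, PartitionSpec
from pyiceberg.schema import Schema
from pyiceberg.transforms import BucketTransform
from pyiceberg.types import LongType, NestedField, StringType

catalog = load_catalog("default")  # hypothetical catalog name; any configured catalog works

schema = Schema(
    NestedField(field_id=1, name="id", field_type=LongType(), required=False),
    NestedField(field_id=2, name="name", field_type=StringType(), required=False),
)
# Partition by bucket(2, id); source_id points at the "id" field above.
spec = PartitionSpec(
    PartitionField(source_id=1, field_id=1001, transform=BucketTransform(2), name="id_bucket")
)

tbl = catalog.create_table("default.bucketed_demo", schema=schema, partition_spec=spec)
tbl.append(pa.table({"id": [1, 2, 3], "name": ["a", "b", "c"]}))  # rows land in bucketed data files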
46 changes: 42 additions & 4 deletions tests/test_transforms.py
@@ -18,10 +18,11 @@
# pylint: disable=eval-used,protected-access,redefined-outer-name
from datetime import date
from decimal import Decimal
from typing import TYPE_CHECKING, Any, Callable, Optional
from typing import Any, Callable, Optional, Union
from uuid import UUID

import mmh3 as mmh3
import pyarrow as pa
import pytest
from pydantic import (
BeforeValidator,
@@ -116,9 +117,6 @@
timestamptz_to_micros,
)

if TYPE_CHECKING:
import pyarrow as pa


@pytest.mark.parametrize(
"test_input,test_type,expected",
@@ -1563,3 +1561,43 @@ def test_ymd_pyarrow_transforms(
else:
with pytest.raises(ValueError):
transform.pyarrow_transform(DateType())(arrow_table_date_timestamps[source_col])


@pytest.mark.parametrize(
"source_type, input_arr, expected, num_buckets",
[
(IntegerType(), pa.array([1, 2]), pa.array([6, 2], type=pa.int32()), 10),
(
IntegerType(),
pa.chunked_array([pa.array([1, 2]), pa.array([3, 4])]),
pa.chunked_array([pa.array([6, 2], type=pa.int32()), pa.array([5, 0], type=pa.int32())]),
10,
),
(IntegerType(), pa.array([1, 2]), pa.array([6, 2], type=pa.int32()), 10),
],
)
def test_bucket_pyarrow_transforms(
source_type: PrimitiveType,
input_arr: Union[pa.Array, pa.ChunkedArray],
expected: Union[pa.Array, pa.ChunkedArray],
num_buckets: int,
) -> None:
transform: Transform[Any, Any] = BucketTransform(num_buckets=num_buckets)
assert expected == transform.pyarrow_transform(source_type)(input_arr)


@pytest.mark.parametrize(
"source_type, input_arr, expected, width",
[
(StringType(), pa.array(["developer", "iceberg"]), pa.array(["dev", "ice"]), 3),
(IntegerType(), pa.array([1, -1]), pa.array([0, -10]), 10),
],
)
def test_truncate_pyarrow_transforms(
source_type: PrimitiveType,
input_arr: Union[pa.Array, pa.ChunkedArray],
expected: Union[pa.Array, pa.ChunkedArray],
width: int,
) -> None:
transform: Transform[Any, Any] = TruncateTransform(width=width)
assert expected == transform.pyarrow_transform(source_type)(input_arr)
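A short note on the integer expectations above: Iceberg's truncate transform floors an integer to the nearest lower multiple of the width, which is what the [0, -10] expectation encodes. A pure-Python sketch of that arithmetic:

def truncate_int(width: int, value: int) -> int:
    # Integer truncation floors the value to a multiple of the width
    # (Python's % already uses floor-mod semantics).
    return value - (value % width)

assert truncate_int(10, 1) == 0
assert truncate_int(10, -1) == -10  # matches pa.array([0, -10]) in the test above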
