From 50c33aa0119d9e2478b3865d864ec23a7c45b1d7 Mon Sep 17 00:00:00 2001 From: Sung Yun <107272191+sungwy@users.noreply.github.com> Date: Thu, 16 Jan 2025 10:54:37 -0500 Subject: [PATCH] feat: Support Bucket and Truncate transforms on write (#1345) * introduce bucket transform * include pyiceberg-core * introduce bucket transform * include pyiceberg-core * resolve poetry conflict * support truncate transforms * Remove stale comment * fix poetry hash * avoid codespell error for truncate transform * adopt nits --- poetry.lock | 18 +- pyiceberg/transforms.py | 39 +++- pyproject.toml | 6 + .../test_writes/test_partitioned_writes.py | 170 ++++++++++++++++-- tests/test_transforms.py | 46 ++++- 5 files changed, 259 insertions(+), 20 deletions(-) diff --git a/poetry.lock b/poetry.lock index 1d17ba6b52..1c94a5f29a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3717,6 +3717,21 @@ files = [ [package.extras] windows-terminal = ["colorama (>=0.4.6)"] +[[package]] +name = "pyiceberg-core" +version = "0.4.0" +description = "" +optional = true +python-versions = "*" +files = [ + {file = "pyiceberg_core-0.4.0-cp39-abi3-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:5aec569271c96e18428d542f9b7007117a7232c06017f95cb239d42e952ad3b4"}, + {file = "pyiceberg_core-0.4.0-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5e74773e58efa4df83aba6f6265cdd41e446fa66fa4e343ca86395fed9f209ae"}, + {file = "pyiceberg_core-0.4.0-cp39-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:7675d21a54bf3753c740d8df78ad7efe33f438096844e479d4f3493f84830925"}, + {file = "pyiceberg_core-0.4.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7058ad935a40b1838e4cdc5febd768878c1a51f83dca005d5a52a7fa280a2489"}, + {file = "pyiceberg_core-0.4.0-cp39-abi3-win_amd64.whl", hash = "sha256:a83eb4c2307ae3dd321a9360828fb043a4add2cc9797bef0bafa20894488fb07"}, + {file = "pyiceberg_core-0.4.0.tar.gz", hash = "sha256:d2e6138707868477b806ed354aee9c476e437913a331cb9ad9ad46b4054cd11f"}, +] + [[package]] name = "pyjwt" version = "2.10.1" @@ -5346,6 +5361,7 @@ glue = ["boto3", "mypy-boto3-glue"] hive = ["thrift"] pandas = ["pandas", "pyarrow"] pyarrow = ["pyarrow"] +pyiceberg-core = ["pyiceberg-core"] ray = ["pandas", "pyarrow", "ray", "ray"] rest-sigv4 = ["boto3"] s3fs = ["s3fs"] @@ -5357,4 +5373,4 @@ zstandard = ["zstandard"] [metadata] lock-version = "2.0" python-versions = "^3.9, !=3.9.7" -content-hash = "306213628bcc69346e14742843c8e6bccf19c2615886943c2e1482a954a388ec" +content-hash = "cc789ef423714710f51e5452de7071642f4512511b1d205f77b952bb1df63a64" diff --git a/pyiceberg/transforms.py b/pyiceberg/transforms.py index 84e1c942d3..22dcdfe88a 100644 --- a/pyiceberg/transforms.py +++ b/pyiceberg/transforms.py @@ -85,6 +85,8 @@ if TYPE_CHECKING: import pyarrow as pa + ArrayLike = TypeVar("ArrayLike", pa.Array, pa.ChunkedArray) + S = TypeVar("S") T = TypeVar("T") @@ -193,6 +195,27 @@ def supports_pyarrow_transform(self) -> bool: @abstractmethod def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": ... + def _pyiceberg_transform_wrapper( + self, transform_func: Callable[["ArrayLike", Any], "ArrayLike"], *args: Any + ) -> Callable[["ArrayLike"], "ArrayLike"]: + try: + import pyarrow as pa + except ModuleNotFoundError as e: + raise ModuleNotFoundError("For bucket/truncate transforms, PyArrow needs to be installed") from e + + def _transform(array: "ArrayLike") -> "ArrayLike": + if isinstance(array, pa.Array): + return transform_func(array, *args) + elif isinstance(array, pa.ChunkedArray): + result_chunks = [] + for arr in array.iterchunks(): + result_chunks.append(transform_func(arr, *args)) + return pa.chunked_array(result_chunks) + else: + raise ValueError(f"PyArrow array can only be of type pa.Array or pa.ChunkedArray, but found {type(array)}") + + return _transform + class BucketTransform(Transform[S, int]): """Base Transform class to transform a value into a bucket partition value. @@ -309,7 +332,13 @@ def __repr__(self) -> str: return f"BucketTransform(num_buckets={self._num_buckets})" def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": - raise NotImplementedError() + from pyiceberg_core import transform as pyiceberg_core_transform + + return self._pyiceberg_transform_wrapper(pyiceberg_core_transform.bucket, self._num_buckets) + + @property + def supports_pyarrow_transform(self) -> bool: + return True class TimeResolution(IntEnum): @@ -827,7 +856,13 @@ def __repr__(self) -> str: return f"TruncateTransform(width={self._width})" def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": - raise NotImplementedError() + from pyiceberg_core import transform as pyiceberg_core_transform + + return self._pyiceberg_transform_wrapper(pyiceberg_core_transform.truncate, self._width) + + @property + def supports_pyarrow_transform(self) -> bool: + return True @singledispatch diff --git a/pyproject.toml b/pyproject.toml index 4b425141b5..5d2808db94 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -79,6 +79,7 @@ psycopg2-binary = { version = ">=2.9.6", optional = true } sqlalchemy = { version = "^2.0.18", optional = true } getdaft = { version = ">=0.2.12", optional = true } cachetools = "^5.5.0" +pyiceberg-core = { version = "^0.4.0", optional = true } [tool.poetry.group.dev.dependencies] pytest = "7.4.4" @@ -842,6 +843,10 @@ ignore_missing_imports = true module = "daft.*" ignore_missing_imports = true +[[tool.mypy.overrides]] +module = "pyiceberg_core.*" +ignore_missing_imports = true + [[tool.mypy.overrides]] module = "pyparsing.*" ignore_missing_imports = true @@ -1206,6 +1211,7 @@ sql-postgres = ["sqlalchemy", "psycopg2-binary"] sql-sqlite = ["sqlalchemy"] gcsfs = ["gcsfs"] rest-sigv4 = ["boto3"] +pyiceberg-core = ["pyiceberg-core"] [tool.pytest.ini_options] markers = [ diff --git a/tests/integration/test_writes/test_partitioned_writes.py b/tests/integration/test_writes/test_partitioned_writes.py index 9e7632852c..1e6ea1b797 100644 --- a/tests/integration/test_writes/test_partitioned_writes.py +++ b/tests/integration/test_writes/test_partitioned_writes.py @@ -412,6 +412,12 @@ def test_dynamic_partition_overwrite_unpartitioned_evolve_to_identity_transform( spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, part_col: str, format_version: int ) -> None: identifier = f"default.unpartitioned_table_v{format_version}_evolve_into_identity_transformed_partition_field_{part_col}" + + try: + session_catalog.drop_table(identifier=identifier) + except NoSuchTableError: + pass + tbl = session_catalog.create_table( identifier=identifier, schema=TABLE_SCHEMA, @@ -756,6 +762,55 @@ def test_invalid_arguments(spark: SparkSession, session_catalog: Catalog) -> Non tbl.append("not a df") +@pytest.mark.integration +@pytest.mark.parametrize( + "spec", + [ + (PartitionSpec(PartitionField(source_id=4, field_id=1001, transform=TruncateTransform(2), name="int_trunc"))), + (PartitionSpec(PartitionField(source_id=5, field_id=1001, transform=TruncateTransform(2), name="long_trunc"))), + (PartitionSpec(PartitionField(source_id=2, field_id=1001, transform=TruncateTransform(2), name="string_trunc"))), + ], +) +@pytest.mark.parametrize("format_version", [1, 2]) +def test_truncate_transform( + spec: PartitionSpec, + spark: SparkSession, + session_catalog: Catalog, + arrow_table_with_null: pa.Table, + format_version: int, +) -> None: + identifier = "default.truncate_transform" + + try: + session_catalog.drop_table(identifier=identifier) + except NoSuchTableError: + pass + + tbl = _create_table( + session_catalog=session_catalog, + identifier=identifier, + properties={"format-version": str(format_version)}, + data=[arrow_table_with_null], + partition_spec=spec, + ) + + assert tbl.format_version == format_version, f"Expected v{format_version}, got: v{tbl.format_version}" + df = spark.table(identifier) + assert df.count() == 3, f"Expected 3 total rows for {identifier}" + for col in arrow_table_with_null.column_names: + assert df.where(f"{col} is not null").count() == 2, f"Expected 2 non-null rows for {col}" + assert df.where(f"{col} is null").count() == 1, f"Expected 1 null row for {col} is null" + + assert tbl.inspect.partitions().num_rows == 3 + files_df = spark.sql( + f""" + SELECT * + FROM {identifier}.files + """ + ) + assert files_df.count() == 3 + + @pytest.mark.integration @pytest.mark.parametrize( "spec", @@ -767,18 +822,52 @@ def test_invalid_arguments(spark: SparkSession, session_catalog: Catalog) -> Non PartitionField(source_id=1, field_id=1002, transform=IdentityTransform(), name="bool"), ) ), - # none of non-identity is supported - (PartitionSpec(PartitionField(source_id=4, field_id=1001, transform=BucketTransform(2), name="int_bucket"))), - (PartitionSpec(PartitionField(source_id=5, field_id=1001, transform=BucketTransform(2), name="long_bucket"))), - (PartitionSpec(PartitionField(source_id=10, field_id=1001, transform=BucketTransform(2), name="date_bucket"))), - (PartitionSpec(PartitionField(source_id=8, field_id=1001, transform=BucketTransform(2), name="timestamp_bucket"))), - (PartitionSpec(PartitionField(source_id=9, field_id=1001, transform=BucketTransform(2), name="timestamptz_bucket"))), - (PartitionSpec(PartitionField(source_id=2, field_id=1001, transform=BucketTransform(2), name="string_bucket"))), - (PartitionSpec(PartitionField(source_id=12, field_id=1001, transform=BucketTransform(2), name="fixed_bucket"))), - (PartitionSpec(PartitionField(source_id=11, field_id=1001, transform=BucketTransform(2), name="binary_bucket"))), - (PartitionSpec(PartitionField(source_id=4, field_id=1001, transform=TruncateTransform(2), name="int_trunc"))), - (PartitionSpec(PartitionField(source_id=5, field_id=1001, transform=TruncateTransform(2), name="long_trunc"))), - (PartitionSpec(PartitionField(source_id=2, field_id=1001, transform=TruncateTransform(2), name="string_trunc"))), + ], +) +@pytest.mark.parametrize("format_version", [1, 2]) +def test_identity_and_bucket_transform_spec( + spec: PartitionSpec, + spark: SparkSession, + session_catalog: Catalog, + arrow_table_with_null: pa.Table, + format_version: int, +) -> None: + identifier = "default.identity_and_bucket_transform" + + try: + session_catalog.drop_table(identifier=identifier) + except NoSuchTableError: + pass + + tbl = _create_table( + session_catalog=session_catalog, + identifier=identifier, + properties={"format-version": str(format_version)}, + data=[arrow_table_with_null], + partition_spec=spec, + ) + + assert tbl.format_version == format_version, f"Expected v{format_version}, got: v{tbl.format_version}" + df = spark.table(identifier) + assert df.count() == 3, f"Expected 3 total rows for {identifier}" + for col in arrow_table_with_null.column_names: + assert df.where(f"{col} is not null").count() == 2, f"Expected 2 non-null rows for {col}" + assert df.where(f"{col} is null").count() == 1, f"Expected 1 null row for {col} is null" + + assert tbl.inspect.partitions().num_rows == 3 + files_df = spark.sql( + f""" + SELECT * + FROM {identifier}.files + """ + ) + assert files_df.count() == 3 + + +@pytest.mark.integration +@pytest.mark.parametrize( + "spec", + [ (PartitionSpec(PartitionField(source_id=11, field_id=1001, transform=TruncateTransform(2), name="binary_trunc"))), ], ) @@ -801,11 +890,66 @@ def test_unsupported_transform( with pytest.raises( ValueError, - match="Not all partition types are supported for writes. Following partitions cannot be written using pyarrow: *", + match="FeatureUnsupported => Unsupported data type for truncate transform: LargeBinary", ): tbl.append(arrow_table_with_null) +@pytest.mark.integration +@pytest.mark.parametrize( + "spec, expected_rows", + [ + (PartitionSpec(PartitionField(source_id=4, field_id=1001, transform=BucketTransform(2), name="int_bucket")), 3), + (PartitionSpec(PartitionField(source_id=5, field_id=1001, transform=BucketTransform(2), name="long_bucket")), 3), + (PartitionSpec(PartitionField(source_id=10, field_id=1001, transform=BucketTransform(2), name="date_bucket")), 3), + (PartitionSpec(PartitionField(source_id=8, field_id=1001, transform=BucketTransform(2), name="timestamp_bucket")), 3), + (PartitionSpec(PartitionField(source_id=9, field_id=1001, transform=BucketTransform(2), name="timestamptz_bucket")), 3), + (PartitionSpec(PartitionField(source_id=2, field_id=1001, transform=BucketTransform(2), name="string_bucket")), 3), + (PartitionSpec(PartitionField(source_id=12, field_id=1001, transform=BucketTransform(2), name="fixed_bucket")), 2), + (PartitionSpec(PartitionField(source_id=11, field_id=1001, transform=BucketTransform(2), name="binary_bucket")), 2), + ], +) +@pytest.mark.parametrize("format_version", [1, 2]) +def test_bucket_transform( + spark: SparkSession, + session_catalog: Catalog, + arrow_table_with_null: pa.Table, + spec: PartitionSpec, + expected_rows: int, + format_version: int, +) -> None: + identifier = "default.bucket_transform" + + try: + session_catalog.drop_table(identifier=identifier) + except NoSuchTableError: + pass + + tbl = _create_table( + session_catalog=session_catalog, + identifier=identifier, + properties={"format-version": str(format_version)}, + data=[arrow_table_with_null], + partition_spec=spec, + ) + + assert tbl.format_version == format_version, f"Expected v{format_version}, got: v{tbl.format_version}" + df = spark.table(identifier) + assert df.count() == 3, f"Expected 3 total rows for {identifier}" + for col in arrow_table_with_null.column_names: + assert df.where(f"{col} is not null").count() == 2, f"Expected 2 non-null rows for {col}" + assert df.where(f"{col} is null").count() == 1, f"Expected 1 null row for {col} is null" + + assert tbl.inspect.partitions().num_rows == expected_rows + files_df = spark.sql( + f""" + SELECT * + FROM {identifier}.files + """ + ) + assert files_df.count() == expected_rows + + @pytest.mark.integration @pytest.mark.parametrize( "transform,expected_rows", diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 6d04a1e4ce..3088719a06 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -18,10 +18,11 @@ # pylint: disable=eval-used,protected-access,redefined-outer-name from datetime import date from decimal import Decimal -from typing import TYPE_CHECKING, Any, Callable, Optional +from typing import Any, Callable, Optional, Union from uuid import UUID import mmh3 as mmh3 +import pyarrow as pa import pytest from pydantic import ( BeforeValidator, @@ -116,9 +117,6 @@ timestamptz_to_micros, ) -if TYPE_CHECKING: - import pyarrow as pa - @pytest.mark.parametrize( "test_input,test_type,expected", @@ -1563,3 +1561,43 @@ def test_ymd_pyarrow_transforms( else: with pytest.raises(ValueError): transform.pyarrow_transform(DateType())(arrow_table_date_timestamps[source_col]) + + +@pytest.mark.parametrize( + "source_type, input_arr, expected, num_buckets", + [ + (IntegerType(), pa.array([1, 2]), pa.array([6, 2], type=pa.int32()), 10), + ( + IntegerType(), + pa.chunked_array([pa.array([1, 2]), pa.array([3, 4])]), + pa.chunked_array([pa.array([6, 2], type=pa.int32()), pa.array([5, 0], type=pa.int32())]), + 10, + ), + (IntegerType(), pa.array([1, 2]), pa.array([6, 2], type=pa.int32()), 10), + ], +) +def test_bucket_pyarrow_transforms( + source_type: PrimitiveType, + input_arr: Union[pa.Array, pa.ChunkedArray], + expected: Union[pa.Array, pa.ChunkedArray], + num_buckets: int, +) -> None: + transform: Transform[Any, Any] = BucketTransform(num_buckets=num_buckets) + assert expected == transform.pyarrow_transform(source_type)(input_arr) + + +@pytest.mark.parametrize( + "source_type, input_arr, expected, width", + [ + (StringType(), pa.array(["developer", "iceberg"]), pa.array(["dev", "ice"]), 3), + (IntegerType(), pa.array([1, -1]), pa.array([0, -10]), 10), + ], +) +def test_truncate_pyarrow_transforms( + source_type: PrimitiveType, + input_arr: Union[pa.Array, pa.ChunkedArray], + expected: Union[pa.Array, pa.ChunkedArray], + width: int, +) -> None: + transform: Transform[Any, Any] = TruncateTransform(width=width) + assert expected == transform.pyarrow_transform(source_type)(input_arr)