diff --git a/pyiceberg/transforms.py b/pyiceberg/transforms.py
index 44d32c9449..12f25ed7ac 100644
--- a/pyiceberg/transforms.py
+++ b/pyiceberg/transforms.py
@@ -85,6 +85,8 @@
 if TYPE_CHECKING:
     import pyarrow as pa
 
+    ArrayLike = TypeVar("ArrayLike", pa.Array, pa.ChunkedArray)
+
 S = TypeVar("S")
 T = TypeVar("T")
 
@@ -193,6 +195,24 @@ def supports_pyarrow_transform(self) -> bool:
     @abstractmethod
     def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": ...
 
+    def _pyiceberg_transform_wrapper(
+        self, transform_func: Callable[["ArrayLike", Any], "ArrayLike"], *args: Any
+    ) -> Callable[["ArrayLike"], "ArrayLike"]:
+        import pyarrow as pa
+
+        def _transform(array: "ArrayLike") -> "ArrayLike":
+            if isinstance(array, pa.Array):
+                return transform_func(array, *args)
+            elif isinstance(array, pa.ChunkedArray):
+                result_chunks = []
+                for arr in array.iterchunks():
+                    result_chunks.append(transform_func(arr, *args))
+                return pa.chunked_array(result_chunks)
+            else:
+                raise ValueError(f"PyArrow array can only be of type pa.Array or pa.ChunkedArray, but found {type(array)}")
+
+        return _transform
+
 
 class BucketTransform(Transform[S, int]):
     """Base Transform class to transform a value into a bucket partition value.
@@ -309,23 +329,9 @@ def __repr__(self) -> str:
         return f"BucketTransform(num_buckets={self._num_buckets})"
 
     def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
-        import pyarrow as pa
         from pyiceberg_core import transform as pyiceberg_core_transform
 
-        ArrayLike = TypeVar("ArrayLike", pa.Array, pa.ChunkedArray)
-
-        def bucket(array: ArrayLike) -> ArrayLike:
-            if isinstance(array, pa.Array):
-                return pyiceberg_core_transform.bucket(array, self._num_buckets)
-            elif isinstance(array, pa.ChunkedArray):
-                result_chunks = []
-                for arr in array.iterchunks():
-                    result_chunks.append(pyiceberg_core_transform.bucket(arr, self._num_buckets))
-                return pa.chunked_array(result_chunks)
-            else:
-                raise ValueError(f"PyArrow array can only be of type pa.Array or pa.ChunkedArray, but found {type(array)}")
-
-        return bucket
+        return self._pyiceberg_transform_wrapper(pyiceberg_core_transform.bucket, self._num_buckets)
 
     @property
     def supports_pyarrow_transform(self) -> bool:
@@ -847,7 +853,13 @@ def __repr__(self) -> str:
         return f"TruncateTransform(width={self._width})"
 
     def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
-        raise NotImplementedError()
+        from pyiceberg_core import transform as pyiceberg_core_transform
+
+        return self._pyiceberg_transform_wrapper(pyiceberg_core_transform.truncate, self._width)
+
+    @property
+    def supports_pyarrow_transform(self) -> bool:
+        return True
 
 
 @singledispatch
diff --git a/tests/integration/test_writes/test_partitioned_writes.py b/tests/integration/test_writes/test_partitioned_writes.py
index b92c338931..16b668fd85 100644
--- a/tests/integration/test_writes/test_partitioned_writes.py
+++ b/tests/integration/test_writes/test_partitioned_writes.py
@@ -719,50 +719,105 @@ def test_invalid_arguments(spark: SparkSession, session_catalog: Catalog) -> Non
 @pytest.mark.parametrize(
     "spec",
     [
-        # mixed with non-identity is not supported
-        (
-            PartitionSpec(
-                PartitionField(source_id=4, field_id=1001, transform=BucketTransform(2), name="int_bucket"),
-                PartitionField(source_id=1, field_id=1002, transform=IdentityTransform(), name="bool"),
-            )
-        ),
-        # none of non-identity is supported
-        (PartitionSpec(PartitionField(source_id=4, field_id=1001, transform=BucketTransform(2), name="int_bucket"))),
-        (PartitionSpec(PartitionField(source_id=5, field_id=1001, transform=BucketTransform(2), name="long_bucket"))),
-        (PartitionSpec(PartitionField(source_id=10, field_id=1001, transform=BucketTransform(2), name="date_bucket"))),
-        (PartitionSpec(PartitionField(source_id=8, field_id=1001, transform=BucketTransform(2), name="timestamp_bucket"))),
-        (PartitionSpec(PartitionField(source_id=9, field_id=1001, transform=BucketTransform(2), name="timestamptz_bucket"))),
-        (PartitionSpec(PartitionField(source_id=2, field_id=1001, transform=BucketTransform(2), name="string_bucket"))),
-        (PartitionSpec(PartitionField(source_id=12, field_id=1001, transform=BucketTransform(2), name="fixed_bucket"))),
-        (PartitionSpec(PartitionField(source_id=11, field_id=1001, transform=BucketTransform(2), name="binary_bucket"))),
         (PartitionSpec(PartitionField(source_id=4, field_id=1001, transform=TruncateTransform(2), name="int_trunc"))),
         (PartitionSpec(PartitionField(source_id=5, field_id=1001, transform=TruncateTransform(2), name="long_trunc"))),
         (PartitionSpec(PartitionField(source_id=2, field_id=1001, transform=TruncateTransform(2), name="string_trunc"))),
-        (PartitionSpec(PartitionField(source_id=11, field_id=1001, transform=TruncateTransform(2), name="binary_trunc"))),
     ],
 )
-def test_unsupported_transform(
-    spec: PartitionSpec, spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_truncate_transform(
+    spec: PartitionSpec,
+    spark: SparkSession,
+    session_catalog: Catalog,
+    arrow_table_with_null: pa.Table,
+    format_version: int,
 ) -> None:
-    identifier = "default.unsupported_transform"
+    identifier = "default.truncate_transform"
 
     try:
         session_catalog.drop_table(identifier=identifier)
     except NoSuchTableError:
         pass
 
-    tbl = session_catalog.create_table(
+    tbl = _create_table(
+        session_catalog=session_catalog,
         identifier=identifier,
-        schema=TABLE_SCHEMA,
+        properties={"format-version": str(format_version)},
+        data=[arrow_table_with_null],
         partition_spec=spec,
-        properties={"format-version": "1"},
     )
 
-    with pytest.raises(
-        ValueError,
-        match="Not all partition types are supported for writes. Following partitions cannot be written using pyarrow: *",
-    ):
-        tbl.append(arrow_table_with_null)
+    assert tbl.format_version == format_version, f"Expected v{format_version}, got: v{tbl.format_version}"
+    df = spark.table(identifier)
+    assert df.count() == 3, f"Expected 3 total rows for {identifier}"
+    for col in arrow_table_with_null.column_names:
+        assert df.where(f"{col} is not null").count() == 2, f"Expected 2 non-null rows for {col}"
+        assert df.where(f"{col} is null").count() == 1, f"Expected 1 null row for {col} is null"
+
+    assert tbl.inspect.partitions().num_rows == 3
+    files_df = spark.sql(
+        f"""
+            SELECT *
+            FROM {identifier}.files
+        """
+    )
+    assert files_df.count() == 3
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize(
+    "spec, expected_rows",
+    [
+        # none of non-identity is supported
+        (PartitionSpec(PartitionField(source_id=4, field_id=1001, transform=BucketTransform(2), name="int_bucket")), 3),
+        (PartitionSpec(PartitionField(source_id=5, field_id=1001, transform=BucketTransform(2), name="long_bucket")), 3),
+        (PartitionSpec(PartitionField(source_id=10, field_id=1001, transform=BucketTransform(2), name="date_bucket")), 3),
+        (PartitionSpec(PartitionField(source_id=8, field_id=1001, transform=BucketTransform(2), name="timestamp_bucket")), 3),
+        (PartitionSpec(PartitionField(source_id=9, field_id=1001, transform=BucketTransform(2), name="timestamptz_bucket")), 3),
+        (PartitionSpec(PartitionField(source_id=2, field_id=1001, transform=BucketTransform(2), name="string_bucket")), 3),
+        (PartitionSpec(PartitionField(source_id=12, field_id=1001, transform=BucketTransform(2), name="fixed_bucket")), 2),
+        (PartitionSpec(PartitionField(source_id=11, field_id=1001, transform=BucketTransform(2), name="binary_bucket")), 2),
+    ],
+)
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_bucket_transform(
+    spark: SparkSession,
+    session_catalog: Catalog,
+    arrow_table_with_null: pa.Table,
+    spec: PartitionSpec,
+    expected_rows: int,
+    format_version: int,
+) -> None:
+    identifier = "default.bucket_transform"
+
+    try:
+        session_catalog.drop_table(identifier=identifier)
+    except NoSuchTableError:
+        pass
+
+    tbl = _create_table(
+        session_catalog=session_catalog,
+        identifier=identifier,
+        properties={"format-version": str(format_version)},
+        data=[arrow_table_with_null],
+        partition_spec=spec,
+    )
+
+    assert tbl.format_version == format_version, f"Expected v{format_version}, got: v{tbl.format_version}"
+    df = spark.table(identifier)
+    assert df.count() == 3, f"Expected 3 total rows for {identifier}"
+    for col in arrow_table_with_null.column_names:
+        assert df.where(f"{col} is not null").count() == 2, f"Expected 2 non-null rows for {col}"
+        assert df.where(f"{col} is null").count() == 1, f"Expected 1 null row for {col} is null"
+
+    assert tbl.inspect.partitions().num_rows == expected_rows
+    files_df = spark.sql(
+        f"""
+            SELECT *
+            FROM {identifier}.files
+        """
+    )
+    assert files_df.count() == expected_rows
 
 
 @pytest.mark.integration
diff --git a/tests/test_transforms.py b/tests/test_transforms.py
index 50ed775272..2fa459527e 100644
--- a/tests/test_transforms.py
+++ b/tests/test_transforms.py
@@ -1584,3 +1584,20 @@ def test_bucket_pyarrow_transforms(
 ) -> None:
     transform: Transform[Any, Any] = BucketTransform(num_buckets=num_buckets)
     assert expected == transform.pyarrow_transform(source_type)(input_arr)
+
+
+@pytest.mark.parametrize(
+    "source_type, input_arr, expected, width",
+    [
+        (StringType(), pa.array(["hello", "iceberg"]), pa.array(["hel", "ice"]), 3),
+        (IntegerType(), pa.array([1, -1]), pa.array([0, -10]), 10),
+    ],
+)
+def test_truncate_pyarrow_transforms(
+    source_type: PrimitiveType,
+    input_arr: Union[pa.Array, pa.ChunkedArray],
+    expected: Union[pa.Array, pa.ChunkedArray],
+    width: int,
+) -> None:
+    transform: Transform[Any, Any] = TruncateTransform(width=width)
+    assert expected == transform.pyarrow_transform(source_type)(input_arr)
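
Below is a minimal usage sketch (not part of the diff) of the refactored pyarrow_transform path, assuming pyiceberg_core is installed. The truncate outputs mirror the parametrized cases in tests/test_transforms.py above; the bucket values are illustrative only, since they depend on the bucket hash.

import pyarrow as pa

from pyiceberg.transforms import BucketTransform, TruncateTransform
from pyiceberg.types import IntegerType, StringType

# TruncateTransform.pyarrow_transform now routes through _pyiceberg_transform_wrapper,
# so it accepts both pa.Array and pa.ChunkedArray and preserves chunking.
truncate = TruncateTransform(width=3).pyarrow_transform(StringType())
print(truncate(pa.array(["hello", "iceberg"])))              # -> ["hel", "ice"]
print(truncate(pa.chunked_array([["hello"], ["iceberg"]])))  # -> chunks ["hel"], ["ice"]

# BucketTransform uses the same wrapper, passing num_buckets as the extra argument.
bucket = BucketTransform(num_buckets=2).pyarrow_transform(IntegerType())
print(bucket(pa.array([1, 2, 3])))  # bucket ids in [0, 2); exact values depend on the hash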