From 0e72d908de3f2e14548f9e2b7e82f7a019251afd Mon Sep 17 00:00:00 2001
From: Sung Yun <107272191+sungwy@users.noreply.github.com>
Date: Thu, 16 Jan 2025 15:32:58 +0000
Subject: [PATCH] adopt nits

---
 pyiceberg/transforms.py                            |  5 +-
 .../test_writes/test_partitioned_writes.py         | 90 +++++++++++++++++++
 2 files changed, 94 insertions(+), 1 deletion(-)

diff --git a/pyiceberg/transforms.py b/pyiceberg/transforms.py
index 12f25ed7ac..22dcdfe88a 100644
--- a/pyiceberg/transforms.py
+++ b/pyiceberg/transforms.py
@@ -198,7 +198,10 @@ def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Arr
     def _pyiceberg_transform_wrapper(
         self, transform_func: Callable[["ArrayLike", Any], "ArrayLike"], *args: Any
     ) -> Callable[["ArrayLike"], "ArrayLike"]:
-        import pyarrow as pa
+        try:
+            import pyarrow as pa
+        except ModuleNotFoundError as e:
+            raise ModuleNotFoundError("For bucket/truncate transforms, PyArrow needs to be installed") from e
 
         def _transform(array: "ArrayLike") -> "ArrayLike":
             if isinstance(array, pa.Array):
diff --git a/tests/integration/test_writes/test_partitioned_writes.py b/tests/integration/test_writes/test_partitioned_writes.py
index 3c59897b07..1e6ea1b797 100644
--- a/tests/integration/test_writes/test_partitioned_writes.py
+++ b/tests/integration/test_writes/test_partitioned_writes.py
@@ -412,6 +412,12 @@ def test_dynamic_partition_overwrite_unpartitioned_evolve_to_identity_transform(
     spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, part_col: str, format_version: int
 ) -> None:
     identifier = f"default.unpartitioned_table_v{format_version}_evolve_into_identity_transformed_partition_field_{part_col}"
+
+    try:
+        session_catalog.drop_table(identifier=identifier)
+    except NoSuchTableError:
+        pass
+
     tbl = session_catalog.create_table(
         identifier=identifier,
         schema=TABLE_SCHEMA,
@@ -805,6 +811,90 @@ def test_truncate_transform(
     assert files_df.count() == 3
 
 
+@pytest.mark.integration
+@pytest.mark.parametrize(
+    "spec",
+    [
+        # identity transform mixed with a bucket transform is supported
+        (
+            PartitionSpec(
+                PartitionField(source_id=4, field_id=1001, transform=BucketTransform(2), name="int_bucket"),
+                PartitionField(source_id=1, field_id=1002, transform=IdentityTransform(), name="bool"),
+            )
+        ),
+    ],
+)
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_identity_and_bucket_transform_spec(
+    spec: PartitionSpec,
+    spark: SparkSession,
+    session_catalog: Catalog,
+    arrow_table_with_null: pa.Table,
+    format_version: int,
+) -> None:
+    identifier = "default.identity_and_bucket_transform"
+
+    try:
+        session_catalog.drop_table(identifier=identifier)
+    except NoSuchTableError:
+        pass
+
+    tbl = _create_table(
+        session_catalog=session_catalog,
+        identifier=identifier,
+        properties={"format-version": str(format_version)},
+        data=[arrow_table_with_null],
+        partition_spec=spec,
+    )
+
+    assert tbl.format_version == format_version, f"Expected v{format_version}, got: v{tbl.format_version}"
+    df = spark.table(identifier)
+    assert df.count() == 3, f"Expected 3 total rows for {identifier}"
+    for col in arrow_table_with_null.column_names:
+        assert df.where(f"{col} is not null").count() == 2, f"Expected 2 non-null rows for {col}"
+        assert df.where(f"{col} is null").count() == 1, f"Expected 1 null row for {col} is null"
+
+    assert tbl.inspect.partitions().num_rows == 3
+    files_df = spark.sql(
+        f"""
+        SELECT *
+        FROM {identifier}.files
+    """
+    )
+    assert files_df.count() == 3
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize(
+    "spec",
+    [
+        (PartitionSpec(PartitionField(source_id=11, field_id=1001, transform=TruncateTransform(2), name="binary_trunc"))),
+    ],
+)
+def test_unsupported_transform(
+    spec: PartitionSpec, spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table
+) -> None:
+    identifier = "default.unsupported_transform"
+
+    try:
+        session_catalog.drop_table(identifier=identifier)
+    except NoSuchTableError:
+        pass
+
+    tbl = session_catalog.create_table(
+        identifier=identifier,
+        schema=TABLE_SCHEMA,
+        partition_spec=spec,
+        properties={"format-version": "1"},
+    )
+
+    with pytest.raises(
+        ValueError,
+        match="FeatureUnsupported => Unsupported data type for truncate transform: LargeBinary",
+    ):
+        tbl.append(arrow_table_with_null)
+
+
 @pytest.mark.integration
 @pytest.mark.parametrize(
     "spec, expected_rows",
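Note on the transforms.py hunk above: moving `import pyarrow as pa` inside a try/except makes the optional
PyArrow dependency fail lazily, with an actionable message, only when a bucket or truncate transform is
actually evaluated. Below is a minimal, self-contained sketch of that lazy-import guard pattern; the helper
name `load_pyarrow` is illustrative only and is not part of the pyiceberg API.

    # Sketch of the lazy optional-dependency guard adopted in _pyiceberg_transform_wrapper.
    # `load_pyarrow` is a hypothetical helper used here for illustration only.
    from types import ModuleType


    def load_pyarrow() -> ModuleType:
        try:
            import pyarrow as pa
        except ModuleNotFoundError as e:
            # Chain the original exception so the underlying import failure stays visible.
            raise ModuleNotFoundError("For bucket/truncate transforms, PyArrow needs to be installed") from e
        return pa


    pa = load_pyarrow()          # raises the message above when pyarrow is absent
    print(pa.array([1, 2, 3]))   # works as usual when pyarrow is installed

Using `raise ... from e` preserves the original cause in the traceback, so users still see which
module import actually failed while getting a hint about why PyArrow is needed here.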