Skip to content

Commit 0e72d90

Browse files
committed
adopt nits
1 parent 1163c2a commit 0e72d90

File tree

2 files changed

+94
-1
lines changed

2 files changed

+94
-1
lines changed

pyiceberg/transforms.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,10 @@ def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Arr
198198
def _pyiceberg_transform_wrapper(
199199
self, transform_func: Callable[["ArrayLike", Any], "ArrayLike"], *args: Any
200200
) -> Callable[["ArrayLike"], "ArrayLike"]:
201-
import pyarrow as pa
201+
try:
202+
import pyarrow as pa
203+
except ModuleNotFoundError as e:
204+
raise ModuleNotFoundError("For bucket/truncate transforms, PyArrow needs to be installed") from e
202205

203206
def _transform(array: "ArrayLike") -> "ArrayLike":
204207
if isinstance(array, pa.Array):

tests/integration/test_writes/test_partitioned_writes.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,12 @@ def test_dynamic_partition_overwrite_unpartitioned_evolve_to_identity_transform(
412412
spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, part_col: str, format_version: int
413413
) -> None:
414414
identifier = f"default.unpartitioned_table_v{format_version}_evolve_into_identity_transformed_partition_field_{part_col}"
415+
416+
try:
417+
session_catalog.drop_table(identifier=identifier)
418+
except NoSuchTableError:
419+
pass
420+
415421
tbl = session_catalog.create_table(
416422
identifier=identifier,
417423
schema=TABLE_SCHEMA,
@@ -805,6 +811,90 @@ def test_truncate_transform(
805811
assert files_df.count() == 3
806812

807813

814+
@pytest.mark.integration
@pytest.mark.parametrize(
    "spec",
    [
        # mixed with non-identity is not supported
        (
            PartitionSpec(
                PartitionField(source_id=4, field_id=1001, transform=BucketTransform(2), name="int_bucket"),
                PartitionField(source_id=1, field_id=1002, transform=IdentityTransform(), name="bool"),
            )
        ),
    ],
)
@pytest.mark.parametrize("format_version", [1, 2])
def test_identity_and_bucket_transform_spec(
    spec: PartitionSpec,
    spark: SparkSession,
    session_catalog: Catalog,
    arrow_table_with_null: pa.Table,
    format_version: int,
) -> None:
    """Write a table partitioned by a mixed identity + bucket spec and verify counts.

    Creates a table (v1 or v2) partitioned by both an identity transform and a
    bucket transform, appends the null-bearing fixture table, then checks the
    row counts, per-column null counts, partition count, and data-file count
    through Spark.
    """
    # Embed the format version in the table name so the two parametrized runs
    # use distinct tables — consistent with the other versioned tests in this
    # module (e.g. the evolve-to-identity test above).
    identifier = f"default.identity_and_bucket_transform_v{format_version}"

    # Best-effort cleanup from a previous run; absence is fine.
    try:
        session_catalog.drop_table(identifier=identifier)
    except NoSuchTableError:
        pass

    tbl = _create_table(
        session_catalog=session_catalog,
        identifier=identifier,
        properties={"format-version": str(format_version)},
        data=[arrow_table_with_null],
        partition_spec=spec,
    )

    assert tbl.format_version == format_version, f"Expected v{format_version}, got: v{tbl.format_version}"
    df = spark.table(identifier)
    # The fixture table has 3 rows: 2 fully populated and 1 all-null.
    assert df.count() == 3, f"Expected 3 total rows for {identifier}"
    for col in arrow_table_with_null.column_names:
        assert df.where(f"{col} is not null").count() == 2, f"Expected 2 non-null rows for {col}"
        assert df.where(f"{col} is null").count() == 1, f"Expected 1 null row for {col} is null"

    # Each of the 3 rows lands in its own partition, hence 3 data files.
    assert tbl.inspect.partitions().num_rows == 3
    files_df = spark.sql(
        f"""
        SELECT *
        FROM {identifier}.files
        """
    )
    assert files_df.count() == 3
865+
866+
867+
@pytest.mark.integration
@pytest.mark.parametrize(
    "spec",
    [
        (PartitionSpec(PartitionField(source_id=11, field_id=1001, transform=TruncateTransform(2), name="binary_trunc"))),
    ],
)
def test_unsupported_transform(
    spec: PartitionSpec, spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table
) -> None:
    """Appending to a table whose spec truncates a LargeBinary column must raise.

    The truncate transform is not implemented for LargeBinary, so the write
    path is expected to reject the append with a FeatureUnsupported error.
    """
    identifier = "default.unsupported_transform"

    # Start from a clean slate; the table may linger from a previous run.
    try:
        session_catalog.drop_table(identifier=identifier)
    except NoSuchTableError:
        pass

    table = session_catalog.create_table(
        identifier=identifier,
        schema=TABLE_SCHEMA,
        properties={"format-version": "1"},
        partition_spec=spec,
    )

    expected = "FeatureUnsupported => Unsupported data type for truncate transform: LargeBinary"
    with pytest.raises(ValueError, match=expected):
        table.append(arrow_table_with_null)
896+
897+
808898
@pytest.mark.integration
809899
@pytest.mark.parametrize(
810900
"spec, expected_rows",

0 commit comments

Comments
 (0)