adopt nits
sungwy committed Jan 16, 2025
1 parent 1163c2a commit 0e72d90
Showing 2 changed files with 94 additions and 1 deletion.
5 changes: 4 additions & 1 deletion pyiceberg/transforms.py
@@ -198,7 +198,10 @@ def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
     def _pyiceberg_transform_wrapper(
         self, transform_func: Callable[["ArrayLike", Any], "ArrayLike"], *args: Any
     ) -> Callable[["ArrayLike"], "ArrayLike"]:
-        import pyarrow as pa
+        try:
+            import pyarrow as pa
+        except ModuleNotFoundError as e:
+            raise ModuleNotFoundError("For bucket/truncate transforms, PyArrow needs to be installed") from e

         def _transform(array: "ArrayLike") -> "ArrayLike":
             if isinstance(array, pa.Array):
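In isolation, the guarded import adopted above behaves like this minimal sketch (illustrative code, not PyIceberg's exact implementation; the ChunkedArray branch is an assumption based on the truncated context above):

from typing import Any, Callable


def make_transform(transform_func: Callable[..., Any], *args: Any) -> Callable[[Any], Any]:
    # Defer the pyarrow import until a transform is actually built, so users
    # without the optional dependency fail late, with an actionable message.
    try:
        import pyarrow as pa
    except ModuleNotFoundError as e:
        raise ModuleNotFoundError("For bucket/truncate transforms, PyArrow needs to be installed") from e

    def _transform(array: "pa.Array") -> "pa.Array":
        # Apply the wrapped function per chunk for ChunkedArray inputs (assumed
        # branch), otherwise directly on the Array.
        if isinstance(array, pa.ChunkedArray):
            return pa.chunked_array([transform_func(chunk, *args) for chunk in array.chunks])
        return transform_func(array, *args)

    return _transform

The payoff of the lazy import is that a user without pyarrow only hits the ModuleNotFoundError when a bucket/truncate transform is actually constructed, not at pyiceberg import time.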
90 changes: 90 additions & 0 deletions tests/integration/test_writes/test_partitioned_writes.py
@@ -412,6 +412,12 @@ def test_dynamic_partition_overwrite_unpartitioned_evolve_to_identity_transform(
     spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, part_col: str, format_version: int
 ) -> None:
     identifier = f"default.unpartitioned_table_v{format_version}_evolve_into_identity_transformed_partition_field_{part_col}"
+
+    try:
+        session_catalog.drop_table(identifier=identifier)
+    except NoSuchTableError:
+        pass
+
     tbl = session_catalog.create_table(
         identifier=identifier,
         schema=TABLE_SCHEMA,
@@ -805,6 +811,90 @@ def test_truncate_transform(
     assert files_df.count() == 3


+@pytest.mark.integration
+@pytest.mark.parametrize(
+    "spec",
+    [
+        # identity transform mixed with a non-identity (bucket) transform
+        (
+            PartitionSpec(
+                PartitionField(source_id=4, field_id=1001, transform=BucketTransform(2), name="int_bucket"),
+                PartitionField(source_id=1, field_id=1002, transform=IdentityTransform(), name="bool"),
+            )
+        ),
+    ],
+)
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_identity_and_bucket_transform_spec(
+    spec: PartitionSpec,
+    spark: SparkSession,
+    session_catalog: Catalog,
+    arrow_table_with_null: pa.Table,
+    format_version: int,
+) -> None:
+    identifier = "default.identity_and_bucket_transform"
+
+    try:
+        session_catalog.drop_table(identifier=identifier)
+    except NoSuchTableError:
+        pass
+
+    tbl = _create_table(
+        session_catalog=session_catalog,
+        identifier=identifier,
+        properties={"format-version": str(format_version)},
+        data=[arrow_table_with_null],
+        partition_spec=spec,
+    )
+
+    assert tbl.format_version == format_version, f"Expected v{format_version}, got: v{tbl.format_version}"
+    df = spark.table(identifier)
+    assert df.count() == 3, f"Expected 3 total rows for {identifier}"
+    for col in arrow_table_with_null.column_names:
+        assert df.where(f"{col} is not null").count() == 2, f"Expected 2 non-null rows for {col}"
+        assert df.where(f"{col} is null").count() == 1, f"Expected 1 null row for {col} is null"
+
+    assert tbl.inspect.partitions().num_rows == 3
+    files_df = spark.sql(
+        f"""
+            SELECT *
+            FROM {identifier}.files
+        """
+    )
+    assert files_df.count() == 3
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize(
+    "spec",
+    [
+        (PartitionSpec(PartitionField(source_id=11, field_id=1001, transform=TruncateTransform(2), name="binary_trunc"))),
+    ],
+)
+def test_unsupported_transform(
+    spec: PartitionSpec, spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table
+) -> None:
+    identifier = "default.unsupported_transform"
+
+    try:
+        session_catalog.drop_table(identifier=identifier)
+    except NoSuchTableError:
+        pass
+
+    tbl = session_catalog.create_table(
+        identifier=identifier,
+        schema=TABLE_SCHEMA,
+        partition_spec=spec,
+        properties={"format-version": "1"},
+    )
+
+    with pytest.raises(
+        ValueError,
+        match="FeatureUnsupported => Unsupported data type for truncate transform: LargeBinary",
+    ):
+        tbl.append(arrow_table_with_null)
+
+
 @pytest.mark.integration
 @pytest.mark.parametrize(
     "spec, expected_rows",

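For readers skimming the diff, here is a hedged end-to-end sketch of the behavior these tests pin down. The catalog name, table name, and schema are illustrative, and it assumes a PyIceberg build that includes this bucket-transform write support:

import pyarrow as pa

from pyiceberg.catalog import load_catalog
from pyiceberg.partitioning import PartitionField, PartitionSpec
from pyiceberg.schema import Schema
from pyiceberg.transforms import BucketTransform
from pyiceberg.types import LongType, NestedField

# Assumed catalog configuration; adjust to your environment.
catalog = load_catalog("default")

schema = Schema(NestedField(field_id=1, name="id", field_type=LongType(), required=False))
spec = PartitionSpec(PartitionField(source_id=1, field_id=1000, transform=BucketTransform(num_buckets=2), name="id_bucket"))

tbl = catalog.create_table("default.bucketed_demo", schema=schema, partition_spec=spec)

# With PyArrow installed, this append routes rows through the bucket transform;
# without it, the guarded import added in this commit raises:
#   "For bucket/truncate transforms, PyArrow needs to be installed"
tbl.append(pa.table({"id": pa.array([1, 2, 3], type=pa.int64())}))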