diff --git a/pyiceberg/avro/writer.py b/pyiceberg/avro/writer.py
index 80e96b04ad..6fa485f21a 100644
--- a/pyiceberg/avro/writer.py
+++ b/pyiceberg/avro/writer.py
@@ -32,6 +32,7 @@
     List,
     Optional,
     Tuple,
+    Union,
 )
 from uuid import UUID
@@ -121,8 +122,11 @@ def write(self, encoder: BinaryEncoder, val: Any) -> None:

 @dataclass(frozen=True)
 class UUIDWriter(Writer):
-    def write(self, encoder: BinaryEncoder, val: UUID) -> None:
-        encoder.write(val.bytes)
+    def write(self, encoder: BinaryEncoder, val: Union[UUID, bytes]) -> None:
+        if isinstance(val, UUID):
+            encoder.write(val.bytes)
+        else:
+            encoder.write(val)


 @dataclass(frozen=True)
diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py
index 1aaab32dbe..bbac075dc2 100644
--- a/pyiceberg/io/pyarrow.py
+++ b/pyiceberg/io/pyarrow.py
@@ -684,7 +684,7 @@ def visit_string(self, _: StringType) -> pa.DataType:
         return pa.large_string()

     def visit_uuid(self, _: UUIDType) -> pa.DataType:
-        return pa.binary(16)
+        return pa.uuid()

     def visit_unknown(self, _: UnknownType) -> pa.DataType:
         return pa.null()
@@ -1252,6 +1252,8 @@ def primitive(self, primitive: pa.DataType) -> PrimitiveType:
             return FixedType(primitive.byte_width)
         elif pa.types.is_null(primitive):
             return UnknownType()
+        elif isinstance(primitive, pa.UuidType):
+            return UUIDType()

         raise TypeError(f"Unsupported type: {primitive}")
diff --git a/pyiceberg/partitioning.py b/pyiceberg/partitioning.py
index df07f94342..dd707cea14 100644
--- a/pyiceberg/partitioning.py
+++ b/pyiceberg/partitioning.py
@@ -467,8 +467,17 @@ def _(type: IcebergType, value: Optional[time]) -> Optional[int]:


 @_to_partition_representation.register(UUIDType)
-def _(type: IcebergType, value: Optional[uuid.UUID]) -> Optional[str]:
-    return str(value) if value is not None else None
+def _(type: IcebergType, value: Optional[Union[uuid.UUID, int, bytes]]) -> Optional[Union[bytes, int]]:
+    if value is None:
+        return None
+    elif isinstance(value, bytes):
+        return value  # IdentityTransform
+    elif isinstance(value, uuid.UUID):
+        return value.bytes  # IdentityTransform
+    elif isinstance(value, int):
+        return value  # BucketTransform
+    else:
+        raise ValueError(f"Type not recognized: {value}")


 @_to_partition_representation.register(PrimitiveType)
diff --git a/tests/conftest.py b/tests/conftest.py
index 729e29cb0c..74c803863d 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -2788,7 +2788,7 @@ def pyarrow_schema_with_promoted_types() -> "pa.Schema":
             pa.field("list", pa.list_(pa.int32()), nullable=False),  # can support upcasting integer to long
             pa.field("map", pa.map_(pa.string(), pa.int32()), nullable=False),  # can support upcasting integer to long
             pa.field("double", pa.float32(), nullable=True),  # can support upcasting float to double
-            pa.field("uuid", pa.binary(length=16), nullable=True),  # can support upcasting float to double
+            pa.field("uuid", pa.binary(length=16), nullable=True),  # can support upcasting fixed to uuid
         )
     )
@@ -2804,7 +2804,10 @@ def pyarrow_table_with_promoted_types(pyarrow_schema_with_promoted_types: "pa.Sc
             "list": [[1, 1], [2, 2]],
             "map": [{"a": 1}, {"b": 2}],
             "double": [1.1, 9.2],
-            "uuid": [b"qZx\xefNS@\x89\x9b\xf9:\xd0\xee\x9b\xf5E", b"\x97]\x87T^JDJ\x96\x97\xf4v\xe4\x03\x0c\xde"],
+            "uuid": [
+                uuid.UUID("00000000-0000-0000-0000-000000000000").bytes,
+                uuid.UUID("11111111-1111-1111-1111-111111111111").bytes,
+            ],
         },
         schema=pyarrow_schema_with_promoted_types,
     )
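The `_to_partition_representation` change above is the subtle part of this diff: identity-transformed UUID values now travel as their 16 raw bytes (whether they arrive as `uuid.UUID` or `bytes`), while a `BucketTransform` hands in an already-computed bucket number as `int`. A minimal standalone sketch of that contract (the helper name `uuid_partition_representation` is hypothetical, for illustration only; the branch logic mirrors the new dispatch):

```python
import uuid
from typing import Optional, Union


def uuid_partition_representation(value: Optional[Union[uuid.UUID, int, bytes]]) -> Optional[Union[bytes, int]]:
    """Mirror of the new UUIDType branch in _to_partition_representation."""
    if value is None:
        return None
    elif isinstance(value, bytes):
        return value  # IdentityTransform: already the raw 16 bytes
    elif isinstance(value, uuid.UUID):
        return value.bytes  # IdentityTransform: convert to the raw 16 bytes
    elif isinstance(value, int):
        return value  # BucketTransform: bucket number passes through
    else:
        raise ValueError(f"Type not recognized: {value}")


u = uuid.UUID("f47ac10b-58cc-4372-a567-0e02b2c3d479")
assert uuid_partition_representation(u) == u.bytes
assert uuid_partition_representation(u.bytes) == u.bytes
assert uuid_partition_representation(7) == 7
assert uuid_partition_representation(None) is None
```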
diff --git a/tests/integration/test_add_files.py b/tests/integration/test_add_files.py
index 3d36ffcf31..47e56be1f3 100644
--- a/tests/integration/test_add_files.py
+++ b/tests/integration/test_add_files.py
@@ -737,7 +737,7 @@ def test_add_files_with_valid_upcast(
         with pq.ParquetWriter(fos, schema=pyarrow_schema_with_promoted_types) as writer:
             writer.write_table(pyarrow_table_with_promoted_types)

-    tbl.add_files(file_paths=[file_path])
+    tbl.add_files(file_paths=[file_path], check_duplicate_files=False)
     # table's long field should cast to long on read
     written_arrow_table = tbl.scan().to_arrow()
     assert written_arrow_table == pyarrow_table_with_promoted_types.cast(
@@ -747,7 +747,7 @@ def test_add_files_with_valid_upcast(
                 pa.field("list", pa.list_(pa.int64()), nullable=False),
                 pa.field("map", pa.map_(pa.string(), pa.int64()), nullable=False),
                 pa.field("double", pa.float64(), nullable=True),
-                pa.field("uuid", pa.binary(length=16), nullable=True),  # can UUID is read as fixed length binary of length 16
+                pa.field("uuid", pa.uuid(), nullable=True),
             )
         )
     )
diff --git a/tests/integration/test_partitioning_key.py b/tests/integration/test_partitioning_key.py
index f9bdd4eead..1908ec16f3 100644
--- a/tests/integration/test_partitioning_key.py
+++ b/tests/integration/test_partitioning_key.py
@@ -15,7 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint:disable=redefined-outer-name
-import uuid
 from datetime import date, datetime, timedelta, timezone
 from decimal import Decimal
 from typing import Any, List
@@ -308,25 +307,6 @@
             (CAST('2023-01-01' AS DATE), 'Associated string value for date 2023-01-01')
             """,
         ),
-        (
-            [PartitionField(source_id=14, field_id=1001, transform=IdentityTransform(), name="uuid_field")],
-            [uuid.UUID("f47ac10b-58cc-4372-a567-0e02b2c3d479")],
-            Record("f47ac10b-58cc-4372-a567-0e02b2c3d479"),
-            "uuid_field=f47ac10b-58cc-4372-a567-0e02b2c3d479",
-            f"""CREATE TABLE {identifier} (
-                uuid_field string,
-                string_field string
-            )
-            USING iceberg
-            PARTITIONED BY (
-                identity(uuid_field)
-            )
-            """,
-            f"""INSERT INTO {identifier}
-            VALUES
-            ('f47ac10b-58cc-4372-a567-0e02b2c3d479', 'Associated string value for UUID f47ac10b-58cc-4372-a567-0e02b2c3d479')
-            """,
-        ),
         (
             [PartitionField(source_id=11, field_id=1001, transform=IdentityTransform(), name="binary_field")],
             [b"example"],
diff --git a/tests/integration/test_reads.py b/tests/integration/test_reads.py
index b417a43616..a9fb9246a9 100644
--- a/tests/integration/test_reads.py
+++ b/tests/integration/test_reads.py
@@ -588,15 +588,15 @@ def test_partitioned_tables(catalog: Catalog) -> None:
 def test_unpartitioned_uuid_table(catalog: Catalog) -> None:
     unpartitioned_uuid = catalog.load_table("default.test_uuid_and_fixed_unpartitioned")
     arrow_table_eq = unpartitioned_uuid.scan(row_filter="uuid_col == '102cb62f-e6f8-4eb0-9973-d9b012ff0967'").to_arrow()
-    assert arrow_table_eq["uuid_col"].to_pylist() == [uuid.UUID("102cb62f-e6f8-4eb0-9973-d9b012ff0967").bytes]
+    assert arrow_table_eq["uuid_col"].to_pylist() == [uuid.UUID("102cb62f-e6f8-4eb0-9973-d9b012ff0967")]

     arrow_table_neq = unpartitioned_uuid.scan(
         row_filter="uuid_col != '102cb62f-e6f8-4eb0-9973-d9b012ff0967' and uuid_col != '639cccce-c9d2-494a-a78c-278ab234f024'"
     ).to_arrow()
     assert arrow_table_neq["uuid_col"].to_pylist() == [
-        uuid.UUID("ec33e4b2-a834-4cc3-8c4a-a1d3bfc2f226").bytes,
-        uuid.UUID("c1b0d8e0-0b0e-4b1e-9b0a-0e0b0d0c0a0b").bytes,
-        uuid.UUID("923dae77-83d6-47cd-b4b0-d383e64ee57e").bytes,
+        uuid.UUID("ec33e4b2-a834-4cc3-8c4a-a1d3bfc2f226"),
+        uuid.UUID("c1b0d8e0-0b0e-4b1e-9b0a-0e0b0d0c0a0b"),
+        uuid.UUID("923dae77-83d6-47cd-b4b0-d383e64ee57e"),
     ]
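The rewritten `test_reads.py` assertions above depend on PyArrow's UUID extension type: once `visit_uuid` maps `UUIDType` to `pa.uuid()`, `to_pylist()` yields `uuid.UUID` objects instead of raw 16-byte strings. A small self-contained check of that round trip (assumes a PyArrow version that ships `pa.uuid()`, i.e. 18.0 or later):

```python
import uuid

import pyarrow as pa

# Build a uuid-typed array from fixed-size binary storage, as the read path now does.
values = [uuid.UUID("102cb62f-e6f8-4eb0-9973-d9b012ff0967")]
storage = pa.array([u.bytes for u in values], type=pa.binary(16))
uuid_array = pa.ExtensionArray.from_storage(pa.uuid(), storage)

# With the extension type attached, Python conversion produces uuid.UUID objects.
assert uuid_array.to_pylist() == values
assert isinstance(uuid_array.type, pa.UuidType)
```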
diff --git a/tests/integration/test_writes/test_writes.py b/tests/integration/test_writes/test_writes.py
index 493b163b95..d84450e173 100644
--- a/tests/integration/test_writes/test_writes.py
+++ b/tests/integration/test_writes/test_writes.py
@@ -19,6 +19,7 @@
 import os
 import random
 import time
+import uuid
 from datetime import date, datetime, timedelta
 from decimal import Decimal
 from pathlib import Path
@@ -49,7 +50,7 @@
 from pyiceberg.schema import Schema
 from pyiceberg.table import TableProperties
 from pyiceberg.table.sorting import SortDirection, SortField, SortOrder
-from pyiceberg.transforms import DayTransform, HourTransform, IdentityTransform
+from pyiceberg.transforms import DayTransform, HourTransform, IdentityTransform, Transform
 from pyiceberg.types import (
     DateType,
     DecimalType,
@@ -59,6 +60,7 @@
     LongType,
     NestedField,
     StringType,
+    UUIDType,
 )
 from utils import _create_table
@@ -1272,7 +1274,7 @@ def test_table_write_schema_with_valid_upcast(
                 pa.field("list", pa.list_(pa.int64()), nullable=False),
                 pa.field("map", pa.map_(pa.string(), pa.int64()), nullable=False),
                 pa.field("double", pa.float64(), nullable=True),  # can support upcasting float to double
-                pa.field("uuid", pa.binary(length=16), nullable=True),  # can UUID is read as fixed length binary of length 16
+                pa.field("uuid", pa.uuid(), nullable=True),
             )
         )
     )
@@ -1844,6 +1846,59 @@ def test_read_write_decimals(session_catalog: Catalog) -> None:
     assert tbl.scan().to_arrow() == arrow_table


+@pytest.mark.integration
+@pytest.mark.parametrize(
+    "transform",
+    [
+        IdentityTransform(),
+        # Bucket is disabled because of an issue in Iceberg Java:
+        # https://github.com/apache/iceberg/pull/13324
+        # BucketTransform(32)
+    ],
+)
+def test_uuid_partitioning(session_catalog: Catalog, spark: SparkSession, transform: Transform) -> None:  # type: ignore
+    identifier = f"default.test_uuid_partitioning_{str(transform).replace('[32]', '')}"
+
+    schema = Schema(NestedField(field_id=1, name="uuid", field_type=UUIDType(), required=True))
+
+    try:
+        session_catalog.drop_table(identifier=identifier)
+    except NoSuchTableError:
+        pass
+
+    partition_spec = PartitionSpec(PartitionField(source_id=1, field_id=1000, transform=transform, name="uuid_identity"))
+
+    import pyarrow as pa
+
+    arr_table = pa.Table.from_pydict(
+        {
+            "uuid": [
+                uuid.UUID("00000000-0000-0000-0000-000000000000").bytes,
+                uuid.UUID("11111111-1111-1111-1111-111111111111").bytes,
+            ],
+        },
+        schema=pa.schema(
+            [
+                # The uuid type is not yet supported here, so we have to stick with `binary(16)`:
+                # https://github.com/apache/arrow/issues/46468
+                pa.field("uuid", pa.binary(16), nullable=False),
+            ]
+        ),
+    )
+
+    tbl = session_catalog.create_table(
+        identifier=identifier,
+        schema=schema,
+        partition_spec=partition_spec,
+    )
+
+    tbl.append(arr_table)
+
+    lhs = [r[0] for r in spark.table(identifier).collect()]
+    rhs = [str(u.as_py()) for u in tbl.scan().to_arrow()["uuid"].combine_chunks()]
+    assert lhs == rhs
+
+
 @pytest.mark.integration
 def test_avro_compression_codecs(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None:
     identifier = "default.test_avro_compression_codecs"
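Taken together, these changes let a UUID-partitioned table be created and written end to end from Python. A hedged usage sketch condensed from `test_uuid_partitioning` above (the catalog name `default` and the `default.uuid_demo` identifier are placeholders; per apache/arrow#46468 the write path still feeds UUIDs as `binary(16)`, while scans hand them back as `pa.uuid()`):

```python
import uuid

import pyarrow as pa

from pyiceberg.catalog import load_catalog
from pyiceberg.partitioning import PartitionField, PartitionSpec
from pyiceberg.schema import Schema
from pyiceberg.transforms import IdentityTransform
from pyiceberg.types import NestedField, UUIDType

catalog = load_catalog("default")  # placeholder: any configured catalog

schema = Schema(NestedField(field_id=1, name="uuid", field_type=UUIDType(), required=True))
spec = PartitionSpec(PartitionField(source_id=1, field_id=1000, transform=IdentityTransform(), name="uuid_identity"))
tbl = catalog.create_table("default.uuid_demo", schema=schema, partition_spec=spec)

# UUIDs are appended as their 16 raw bytes; reads return a uuid-typed Arrow column.
data = pa.Table.from_pydict(
    {"uuid": [uuid.uuid4().bytes, uuid.uuid4().bytes]},
    schema=pa.schema([pa.field("uuid", pa.binary(16), nullable=False)]),
)
tbl.append(data)

print(tbl.scan().to_arrow()["uuid"].to_pylist())  # [UUID('...'), UUID('...')]
```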