diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index ad7e4f4f85..1ce0842844 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -136,6 +136,7 @@ visit, visit_with_partner, ) +from pyiceberg.table.locations import load_location_provider from pyiceberg.table.metadata import TableMetadata from pyiceberg.table.name_mapping import NameMapping, apply_name_mapping from pyiceberg.transforms import TruncateTransform @@ -2305,6 +2306,7 @@ def write_file(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[WriteT property_name=TableProperties.PARQUET_ROW_GROUP_LIMIT, default=TableProperties.PARQUET_ROW_GROUP_LIMIT_DEFAULT, ) + location_provider = load_location_provider(table_location=table_metadata.location, table_properties=table_metadata.properties) def write_parquet(task: WriteTask) -> DataFile: table_schema = table_metadata.schema() @@ -2327,7 +2329,10 @@ def write_parquet(task: WriteTask) -> DataFile: for batch in task.record_batches ] arrow_table = pa.Table.from_batches(batches) - file_path = f"{table_metadata.location}/data/{task.generate_data_file_path('parquet')}" + file_path = location_provider.new_data_location( + data_file_name=task.generate_data_file_filename("parquet"), + partition_key=task.partition_key, + ) fo = io.new_output(file_path) with fo.create(overwrite=True) as fos: with pq.ParquetWriter(fos, schema=arrow_table.schema, **parquet_writer_kwargs) as writer: diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 7bc3fe838b..0c8c848c43 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -187,6 +187,14 @@ class TableProperties: WRITE_PARTITION_SUMMARY_LIMIT = "write.summary.partition-limit" WRITE_PARTITION_SUMMARY_LIMIT_DEFAULT = 0 + WRITE_PY_LOCATION_PROVIDER_IMPL = "write.py-location-provider.impl" + + OBJECT_STORE_ENABLED = "write.object-storage.enabled" + OBJECT_STORE_ENABLED_DEFAULT = False + + WRITE_OBJECT_STORE_PARTITIONED_PATHS = "write.object-storage.partitioned-paths" + WRITE_OBJECT_STORE_PARTITIONED_PATHS_DEFAULT = True + DELETE_MODE = "write.delete.mode" DELETE_MODE_COPY_ON_WRITE = "copy-on-write" DELETE_MODE_MERGE_ON_READ = "merge-on-read" @@ -1613,13 +1621,6 @@ def generate_data_file_filename(self, extension: str) -> str: # https://github.com/apache/iceberg/blob/a582968975dd30ff4917fbbe999f1be903efac02/core/src/main/java/org/apache/iceberg/io/OutputFileFactory.java#L92-L101 return f"00000-{self.task_id}-{self.write_uuid}.{extension}" - def generate_data_file_path(self, extension: str) -> str: - if self.partition_key: - file_path = f"{self.partition_key.to_path()}/{self.generate_data_file_filename(extension)}" - return file_path - else: - return self.generate_data_file_filename(extension) - @dataclass(frozen=True) class AddFileTask: diff --git a/pyiceberg/table/locations.py b/pyiceberg/table/locations.py new file mode 100644 index 0000000000..046ee32527 --- /dev/null +++ b/pyiceberg/table/locations.py @@ -0,0 +1,145 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import importlib +import logging +from abc import ABC, abstractmethod +from typing import Optional + +import mmh3 + +from pyiceberg.partitioning import PartitionKey +from pyiceberg.table import TableProperties +from pyiceberg.typedef import Properties +from pyiceberg.utils.properties import property_as_bool + +logger = logging.getLogger(__name__) + + +class LocationProvider(ABC): + """A base class for location providers, that provide data file locations for write tasks.""" + + table_location: str + table_properties: Properties + + def __init__(self, table_location: str, table_properties: Properties): + self.table_location = table_location + self.table_properties = table_properties + + @abstractmethod + def new_data_location(self, data_file_name: str, partition_key: Optional[PartitionKey] = None) -> str: + """Return a fully-qualified data file location for the given filename. + + Args: + data_file_name (str): The name of the data file. + partition_key (Optional[PartitionKey]): The data file's partition key. If None, the data is not partitioned. + + Returns: + str: A fully-qualified location URI for the data file. + """ + + +class SimpleLocationProvider(LocationProvider): + def __init__(self, table_location: str, table_properties: Properties): + super().__init__(table_location, table_properties) + + def new_data_location(self, data_file_name: str, partition_key: Optional[PartitionKey] = None) -> str: + prefix = f"{self.table_location}/data" + return f"{prefix}/{partition_key.to_path()}/{data_file_name}" if partition_key else f"{prefix}/{data_file_name}" + + +class ObjectStoreLocationProvider(LocationProvider): + HASH_BINARY_STRING_BITS = 20 + ENTROPY_DIR_LENGTH = 4 + ENTROPY_DIR_DEPTH = 3 + + _include_partition_paths: bool + + def __init__(self, table_location: str, table_properties: Properties): + super().__init__(table_location, table_properties) + self._include_partition_paths = property_as_bool( + self.table_properties, + TableProperties.WRITE_OBJECT_STORE_PARTITIONED_PATHS, + TableProperties.WRITE_OBJECT_STORE_PARTITIONED_PATHS_DEFAULT, + ) + + def new_data_location(self, data_file_name: str, partition_key: Optional[PartitionKey] = None) -> str: + if self._include_partition_paths and partition_key: + return self.new_data_location(f"{partition_key.to_path()}/{data_file_name}") + + prefix = f"{self.table_location}/data" + hashed_path = self._compute_hash(data_file_name) + + return ( + f"{prefix}/{hashed_path}/{data_file_name}" + if self._include_partition_paths + else f"{prefix}/{hashed_path}-{data_file_name}" + ) + + @staticmethod + def _compute_hash(data_file_name: str) -> str: + # Bitwise AND to combat sign-extension; bitwise OR to preserve leading zeroes that `bin` would otherwise strip. + top_mask = 1 << ObjectStoreLocationProvider.HASH_BINARY_STRING_BITS + hash_code = mmh3.hash(data_file_name) & (top_mask - 1) | top_mask + return ObjectStoreLocationProvider._dirs_from_hash(bin(hash_code)[-ObjectStoreLocationProvider.HASH_BINARY_STRING_BITS :]) + + @staticmethod + def _dirs_from_hash(file_hash: str) -> str: + """Divides hash into directories for optimized orphan removal operation using ENTROPY_DIR_DEPTH and ENTROPY_DIR_LENGTH.""" + total_entropy_length = ObjectStoreLocationProvider.ENTROPY_DIR_DEPTH * ObjectStoreLocationProvider.ENTROPY_DIR_LENGTH + + hash_with_dirs = [] + for i in range(0, total_entropy_length, ObjectStoreLocationProvider.ENTROPY_DIR_LENGTH): + hash_with_dirs.append(file_hash[i : i + ObjectStoreLocationProvider.ENTROPY_DIR_LENGTH]) + + if len(file_hash) > total_entropy_length: + hash_with_dirs.append(file_hash[total_entropy_length:]) + + return "/".join(hash_with_dirs) + + +def _import_location_provider( + location_provider_impl: str, table_location: str, table_properties: Properties +) -> Optional[LocationProvider]: + try: + path_parts = location_provider_impl.split(".") + if len(path_parts) < 2: + raise ValueError( + f"{TableProperties.WRITE_PY_LOCATION_PROVIDER_IMPL} should be full path (module.CustomLocationProvider), got: {location_provider_impl}" + ) + module_name, class_name = ".".join(path_parts[:-1]), path_parts[-1] + module = importlib.import_module(module_name) + class_ = getattr(module, class_name) + return class_(table_location, table_properties) + except ModuleNotFoundError: + logger.warning("Could not initialize LocationProvider: %s", location_provider_impl) + return None + + +def load_location_provider(table_location: str, table_properties: Properties) -> LocationProvider: + table_location = table_location.rstrip("/") + + if location_provider_impl := table_properties.get(TableProperties.WRITE_PY_LOCATION_PROVIDER_IMPL): + if location_provider := _import_location_provider(location_provider_impl, table_location, table_properties): + logger.info("Loaded LocationProvider: %s", location_provider_impl) + return location_provider + else: + raise ValueError(f"Could not initialize LocationProvider: {location_provider_impl}") + + if property_as_bool(table_properties, TableProperties.OBJECT_STORE_ENABLED, TableProperties.OBJECT_STORE_ENABLED_DEFAULT): + return ObjectStoreLocationProvider(table_location, table_properties) + else: + return SimpleLocationProvider(table_location, table_properties) diff --git a/tests/integration/test_writes/test_partitioned_writes.py b/tests/integration/test_writes/test_partitioned_writes.py index 8a3a5c9acc..50a1bc8c38 100644 --- a/tests/integration/test_writes/test_partitioned_writes.py +++ b/tests/integration/test_writes/test_partitioned_writes.py @@ -28,6 +28,7 @@ from pyiceberg.exceptions import NoSuchTableError from pyiceberg.partitioning import PartitionField, PartitionSpec from pyiceberg.schema import Schema +from pyiceberg.table import TableProperties from pyiceberg.transforms import ( BucketTransform, DayTransform, @@ -280,6 +281,44 @@ def test_query_filter_v1_v2_append_null( assert df.where(f"{col} is null").count() == 2, f"Expected 2 null rows for {col}" +@pytest.mark.integration +@pytest.mark.parametrize( + "part_col", ["int", "bool", "string", "string_long", "long", "float", "double", "date", "timestamp", "timestamptz", "binary"] +) +@pytest.mark.parametrize("format_version", [1, 2]) +def test_object_storage_location_provider_excludes_partition_path( + session_catalog: Catalog, spark: SparkSession, arrow_table_with_null: pa.Table, part_col: str, format_version: int +) -> None: + nested_field = TABLE_SCHEMA.find_field(part_col) + partition_spec = PartitionSpec( + PartitionField(source_id=nested_field.field_id, field_id=1001, transform=IdentityTransform(), name=part_col) + ) + + tbl = _create_table( + session_catalog=session_catalog, + identifier=f"default.arrow_table_v{format_version}_with_null_partitioned_on_col_{part_col}", + # write.object-storage.partitioned-paths defaults to True + properties={"format-version": str(format_version), TableProperties.OBJECT_STORE_ENABLED: True}, + data=[arrow_table_with_null], + partition_spec=partition_spec, + ) + + original_paths = tbl.inspect.data_files().to_pydict()["file_path"] + assert len(original_paths) == 3 + + # Update props to exclude partitioned paths and append data + with tbl.transaction() as tx: + tx.set_properties({TableProperties.WRITE_OBJECT_STORE_PARTITIONED_PATHS: False}) + tbl.append(arrow_table_with_null) + + added_paths = set(tbl.inspect.data_files().to_pydict()["file_path"]) - set(original_paths) + assert len(added_paths) == 3 + + # All paths before the props update should contain the partition, while all paths after should not + assert all(f"{part_col}=" in path for path in original_paths) + assert all(f"{part_col}=" not in path for path in added_paths) + + @pytest.mark.integration @pytest.mark.parametrize( "spec", diff --git a/tests/integration/test_writes/test_writes.py b/tests/integration/test_writes/test_writes.py index c23e836554..fff48b9373 100644 --- a/tests/integration/test_writes/test_writes.py +++ b/tests/integration/test_writes/test_writes.py @@ -285,6 +285,33 @@ def test_data_files(spark: SparkSession, session_catalog: Catalog, arrow_table_w assert [row.deleted_data_files_count for row in rows] == [0, 1, 0, 0, 0] +@pytest.mark.integration +@pytest.mark.parametrize("format_version", [1, 2]) +def test_object_storage_data_files( + spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int +) -> None: + tbl = _create_table( + session_catalog=session_catalog, + identifier="default.object_stored", + properties={"format-version": format_version, TableProperties.OBJECT_STORE_ENABLED: True}, + data=[arrow_table_with_null], + ) + tbl.append(arrow_table_with_null) + + paths = tbl.inspect.data_files().to_pydict()["file_path"] + assert len(paths) == 2 + + for location in paths: + assert location.startswith("s3://warehouse/default/object_stored/data/") + parts = location.split("/") + assert len(parts) == 11 + + # Entropy binary directories should have been injected + for dir_name in parts[6:10]: + assert dir_name + assert all(c in "01" for c in dir_name) + + @pytest.mark.integration def test_python_writes_with_spark_snapshot_reads( spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table diff --git a/tests/table/test_locations.py b/tests/table/test_locations.py new file mode 100644 index 0000000000..bda2442aca --- /dev/null +++ b/tests/table/test_locations.py @@ -0,0 +1,130 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Optional + +import pytest + +from pyiceberg.partitioning import PartitionField, PartitionFieldValue, PartitionKey, PartitionSpec +from pyiceberg.schema import Schema +from pyiceberg.table.locations import LocationProvider, load_location_provider +from pyiceberg.transforms import IdentityTransform +from pyiceberg.typedef import EMPTY_DICT +from pyiceberg.types import NestedField, StringType + +PARTITION_FIELD = PartitionField(source_id=1, field_id=1002, transform=IdentityTransform(), name="string_field") +PARTITION_KEY = PartitionKey( + raw_partition_field_values=[PartitionFieldValue(PARTITION_FIELD, "example_string")], + partition_spec=PartitionSpec(PARTITION_FIELD), + schema=Schema(NestedField(field_id=1, name="string_field", field_type=StringType(), required=False)), +) + + +class CustomLocationProvider(LocationProvider): + def new_data_location(self, data_file_name: str, partition_key: Optional[PartitionKey] = None) -> str: + return f"custom_location_provider/{data_file_name}" + + +def test_default_location_provider() -> None: + provider = load_location_provider(table_location="table_location", table_properties=EMPTY_DICT) + + assert provider.new_data_location("my_file") == "table_location/data/my_file" + + +def test_custom_location_provider() -> None: + qualified_name = CustomLocationProvider.__module__ + "." + CustomLocationProvider.__name__ + provider = load_location_provider( + table_location="table_location", table_properties={"write.py-location-provider.impl": qualified_name} + ) + + assert provider.new_data_location("my_file") == "custom_location_provider/my_file" + + +def test_custom_location_provider_single_path() -> None: + with pytest.raises(ValueError, match=r"write\.py-location-provider\.impl should be full path"): + load_location_provider(table_location="table_location", table_properties={"write.py-location-provider.impl": "not_found"}) + + +def test_custom_location_provider_not_found() -> None: + with pytest.raises(ValueError, match=r"Could not initialize LocationProvider"): + load_location_provider( + table_location="table_location", table_properties={"write.py-location-provider.impl": "module.not_found"} + ) + + +def test_object_storage_injects_entropy() -> None: + provider = load_location_provider(table_location="table_location", table_properties={"write.object-storage.enabled": "true"}) + + location = provider.new_data_location("test.parquet") + parts = location.split("/") + + assert len(parts) == 7 + assert parts[0] == "table_location" + assert parts[1] == "data" + assert parts[-1] == "test.parquet" + + # Entropy directories in the middle + for dir_name in parts[2:-1]: + assert dir_name + assert all(c in "01" for c in dir_name) + + +@pytest.mark.parametrize("object_storage", [True, False]) +def test_partition_value_in_path(object_storage: bool) -> None: + provider = load_location_provider( + table_location="table_location", + table_properties={ + "write.object-storage.enabled": str(object_storage), + }, + ) + + location = provider.new_data_location("test.parquet", PARTITION_KEY) + partition_segment = location.split("/")[-2] + + assert partition_segment == "string_field=example_string" + + +# NB: We test here with None partition key too because disabling partitioned paths still replaces final / with - even in +# paths of un-partitioned files. This matches the behaviour of the Java implementation. +@pytest.mark.parametrize("partition_key", [PARTITION_KEY, None]) +def test_object_storage_partitioned_paths_disabled(partition_key: Optional[PartitionKey]) -> None: + provider = load_location_provider( + table_location="table_location", + table_properties={ + "write.object-storage.enabled": "true", + "write.object-storage.partitioned-paths": "false", + }, + ) + + location = provider.new_data_location("test.parquet", partition_key) + + # No partition values included in the path and last part of entropy is separated with "-" + assert location == "table_location/data/0110/1010/0011/11101000-test.parquet" + + +@pytest.mark.parametrize( + ["data_file_name", "expected_hash"], + [ + ("a", "0101/0110/1001/10110010"), + ("b", "1110/0111/1110/00000011"), + ("c", "0010/1101/0110/01011111"), + ("d", "1001/0001/0100/01110011"), + ], +) +def test_hash_injection(data_file_name: str, expected_hash: str) -> None: + provider = load_location_provider(table_location="table_location", table_properties={"write.object-storage.enabled": "true"}) + + assert provider.new_data_location(data_file_name) == f"table_location/data/{expected_hash}/{data_file_name}"