diff --git a/mkdocs/docs/configuration.md b/mkdocs/docs/configuration.md index 5346e82c25..76e1816c3a 100644 --- a/mkdocs/docs/configuration.md +++ b/mkdocs/docs/configuration.md @@ -288,6 +288,16 @@ catalog: region_name: ``` + + +| Key | Example | Description | +| ----------------- | ------------------------------------ | --------------------------------------------------------------------------------- | +| glue.id | 111111111111 | Configure the 12-digit ID of the Glue Catalog | +| glue.skip-archive | true | Configure whether to skip the archival of older table versions. Defaults to true | +| glue.endpoint | https://glue.us-east-1.amazonaws.com | Configure an alternative endpoint of the Glue service for GlueCatalog to access | + + + ## DynamoDB Catalog If you want to use AWS DynamoDB as the catalog, you can use the last two ways to configure the pyiceberg and refer diff --git a/mkdocs/docs/how-to-release.md b/mkdocs/docs/how-to-release.md index 99baec25ac..4824cb9994 100644 --- a/mkdocs/docs/how-to-release.md +++ b/mkdocs/docs/how-to-release.md @@ -23,6 +23,21 @@ The guide to release PyIceberg. The first step is to publish a release candidate (RC) and publish it to the public for testing and validation. Once the vote has passed on the RC, the RC turns into the new release. +## Preparing for a release + +Before running a release candidate, remove any APIs that were marked for removal under the `@deprecated` tag for this release. + +For example, the API with the following deprecation tag should be removed when preparing for the 0.2.0 release. + +```python + +@deprecated( + deprecated_in="0.1.0", + removed_in="0.2.0", + help_message="Please use load_something_else() instead", +) +``` + ## Running a release candidate Make sure that the version is correct in `pyproject.toml` and `pyiceberg/__init__.py`. Correct means that it reflects the version that you want to release. diff --git a/pyiceberg/catalog/glue.py b/pyiceberg/catalog/glue.py index b5ad85768a..26b487f507 100644 --- a/pyiceberg/catalog/glue.py +++ b/pyiceberg/catalog/glue.py @@ -116,6 +116,10 @@ GLUE_SKIP_ARCHIVE = "glue.skip-archive" GLUE_SKIP_ARCHIVE_DEFAULT = True +# Configure an alternative endpoint of the Glue service for GlueCatalog to access.
+# This could be used to use GlueCatalog with any glue-compatible metastore service that has a different endpoint +GLUE_CATALOG_ENDPOINT = "glue.endpoint" + ICEBERG_FIELD_ID = "iceberg.field.id" ICEBERG_FIELD_OPTIONAL = "iceberg.field.optional" ICEBERG_FIELD_CURRENT = "iceberg.field.current" @@ -313,7 +317,7 @@ def __init__(self, name: str, **properties: Any): properties, GLUE_SESSION_TOKEN, AWS_SESSION_TOKEN, DEPRECATED_SESSION_TOKEN ), ) - self.glue: GlueClient = session.client("glue") + self.glue: GlueClient = session.client("glue", endpoint_url=properties.get(GLUE_CATALOG_ENDPOINT)) if glue_catalog_id := properties.get(GLUE_ID): _register_glue_catalog_id_with_glue_client(self.glue, glue_catalog_id) diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index 80d20c0a99..62b887bc47 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -160,6 +160,7 @@ from pyiceberg.utils.concurrent import ExecutorFactory from pyiceberg.utils.config import Config from pyiceberg.utils.datetime import millis_to_datetime +from pyiceberg.utils.deprecated import deprecated from pyiceberg.utils.singleton import Singleton from pyiceberg.utils.truncate import truncate_upper_bound_binary_string, truncate_upper_bound_text_string @@ -178,6 +179,7 @@ MAP_KEY_NAME = "key" MAP_VALUE_NAME = "value" DOC = "doc" +UTC_ALIASES = {"UTC", "+00:00", "Etc/UTC", "Z"} T = TypeVar("T") @@ -943,7 +945,7 @@ def primitive(self, primitive: pa.DataType) -> PrimitiveType: else: raise TypeError(f"Unsupported precision for timestamp type: {primitive.unit}") - if primitive.tz == "UTC" or primitive.tz == "+00:00": + if primitive.tz in UTC_ALIASES: return TimestamptzType() elif primitive.tz is None: return TimestampType() @@ -1079,7 +1081,7 @@ def _task_to_record_batches( arrow_table = pa.Table.from_batches([batch]) arrow_table = arrow_table.filter(pyarrow_filter) batch = arrow_table.to_batches()[0] - yield to_requested_schema(projected_schema, file_project_schema, batch, downcast_ns_timestamp_to_us=True) + yield _to_requested_schema(projected_schema, file_project_schema, batch, downcast_ns_timestamp_to_us=True) current_index += len(batch) @@ -1284,7 +1286,24 @@ def project_batches( total_row_count += len(batch) -def to_requested_schema( +@deprecated( + deprecated_in="0.7.0", + removed_in="0.8.0", + help_message="The public API for 'to_requested_schema' is deprecated and is replaced by '_to_requested_schema'", +) +def to_requested_schema(requested_schema: Schema, file_schema: Schema, table: pa.Table) -> pa.Table: + struct_array = visit_with_partner(requested_schema, table, ArrowProjectionVisitor(file_schema), ArrowAccessor(file_schema)) + + arrays = [] + fields = [] + for pos, field in enumerate(requested_schema.fields): + array = struct_array.field(pos) + arrays.append(array) + fields.append(pa.field(field.name, array.type, field.optional)) + return pa.Table.from_arrays(arrays, schema=pa.schema(fields)) + + +def _to_requested_schema( requested_schema: Schema, file_schema: Schema, batch: pa.RecordBatch, @@ -1302,16 +1321,17 @@ def to_requested_schema( class ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array, Optional[pa.Array]]): - file_schema: Schema + _file_schema: Schema _include_field_ids: bool + _downcast_ns_timestamp_to_us: bool def __init__(self, file_schema: Schema, downcast_ns_timestamp_to_us: bool = False, include_field_ids: bool = False) -> None: - self.file_schema = file_schema + self._file_schema = file_schema self._include_field_ids = include_field_ids - self.downcast_ns_timestamp_to_us = 
downcast_ns_timestamp_to_us + self._downcast_ns_timestamp_to_us = downcast_ns_timestamp_to_us def _cast_if_needed(self, field: NestedField, values: pa.Array) -> pa.Array: - file_field = self.file_schema.find_field(field.field_id) + file_field = self._file_schema.find_field(field.field_id) if field.field_type.is_primitive: if field.field_type != file_field.field_type: @@ -1319,14 +1339,31 @@ def _cast_if_needed(self, field: NestedField, values: pa.Array) -> pa.Array: schema_to_pyarrow(promote(file_field.field_type, field.field_type), include_field_ids=self._include_field_ids) ) elif (target_type := schema_to_pyarrow(field.field_type, include_field_ids=self._include_field_ids)) != values.type: - # Downcasting of nanoseconds to microseconds - if ( - pa.types.is_timestamp(target_type) - and target_type.unit == "us" - and pa.types.is_timestamp(values.type) - and values.type.unit == "ns" - ): - return values.cast(target_type, safe=False) + if field.field_type == TimestampType(): + # Downcasting of nanoseconds to microseconds + if ( + pa.types.is_timestamp(target_type) + and not target_type.tz + and pa.types.is_timestamp(values.type) + and not values.type.tz + ): + if target_type.unit == "us" and values.type.unit == "ns" and self._downcast_ns_timestamp_to_us: + return values.cast(target_type, safe=False) + elif target_type.unit == "us" and values.type.unit in {"s", "ms"}: + return values.cast(target_type) + raise ValueError(f"Unsupported schema projection from {values.type} to {target_type}") + elif field.field_type == TimestamptzType(): + if ( + pa.types.is_timestamp(target_type) + and target_type.tz == "UTC" + and pa.types.is_timestamp(values.type) + and values.type.tz in UTC_ALIASES + ): + if target_type.unit == "us" and values.type.unit == "ns" and self._downcast_ns_timestamp_to_us: + return values.cast(target_type, safe=False) + elif target_type.unit == "us" and values.type.unit in {"s", "ms", "us"}: + return values.cast(target_type) + raise ValueError(f"Unsupported schema projection from {values.type} to {target_type}") return values def _construct_field(self, field: NestedField, arrow_type: pa.DataType) -> pa.Field: @@ -1421,6 +1458,8 @@ def field_partner(self, partner_struct: Optional[pa.Array], field_id: int, _: st if isinstance(partner_struct, pa.StructArray): return partner_struct.field(name) + elif isinstance(partner_struct, pa.Table): + return partner_struct.column(name).combine_chunks() elif isinstance(partner_struct, pa.RecordBatch): return partner_struct.column(name) else: @@ -1882,6 +1921,7 @@ def data_file_statistics_from_parquet_metadata( col_aggs = {} + invalidate_col: Set[int] = set() for r in range(parquet_metadata.num_row_groups): # References: # https://github.com/apache/iceberg/blob/fc381a81a1fdb8f51a0637ca27cd30673bd7aad3/parquet/src/main/java/org/apache/iceberg/parquet/ParquetUtil.java#L232 @@ -1897,8 +1937,6 @@ def data_file_statistics_from_parquet_metadata( else: split_offsets.append(data_offset) - invalidate_col: Set[int] = set() - for pos in range(parquet_metadata.num_columns): column = row_group.column(pos) field_id = parquet_column_mapping[column.path_in_schema] @@ -1977,7 +2015,7 @@ def write_parquet(task: WriteTask) -> DataFile: downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False batches = [ - to_requested_schema( + _to_requested_schema( requested_schema=file_schema, file_schema=table_schema, batch=batch, diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index e0c5ac3670..4164280a24 100644 --- 
a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -491,10 +491,6 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) _check_schema_compatible( self._table.schema(), other_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us ) - # cast if the two schemas are compatible but not equal - table_arrow_schema = self._table.schema().as_arrow() - if table_arrow_schema != df.schema: - df = df.cast(table_arrow_schema) manifest_merge_enabled = PropertyUtil.property_as_bool( self.table_metadata.properties, @@ -552,10 +548,6 @@ def overwrite( _check_schema_compatible( self._table.schema(), other_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us ) - # cast if the two schemas are compatible but not equal - table_arrow_schema = self._table.schema().as_arrow() - if table_arrow_schema != df.schema: - df = df.cast(table_arrow_schema) self.delete(delete_filter=overwrite_filter, snapshot_properties=snapshot_properties) diff --git a/pyiceberg/types.py b/pyiceberg/types.py index cd662c7387..97ddea0e57 100644 --- a/pyiceberg/types.py +++ b/pyiceberg/types.py @@ -67,7 +67,7 @@ def transform_dict_value_to_str(dict: Dict[str, Any]) -> Dict[str, str]: for key, value in dict.items(): if value is None: raise ValueError(f"None type is not a supported value in properties: {key}") - return {k: str(v) for k, v in dict.items()} + return {k: str(v).lower() if isinstance(v, bool) else str(v) for k, v in dict.items()} def _parse_decimal_type(decimal: Any) -> Tuple[int, int]: diff --git a/tests/catalog/integration_test_glue.py b/tests/catalog/integration_test_glue.py index c69bc86ca8..a5293e38f2 100644 --- a/tests/catalog/integration_test_glue.py +++ b/tests/catalog/integration_test_glue.py @@ -25,7 +25,7 @@ from botocore.exceptions import ClientError from pyiceberg.catalog import Catalog, MetastoreCatalog -from pyiceberg.catalog.glue import GlueCatalog +from pyiceberg.catalog.glue import GLUE_CATALOG_ENDPOINT, GlueCatalog from pyiceberg.exceptions import ( NamespaceAlreadyExistsError, NamespaceNotEmptyError, @@ -36,7 +36,7 @@ from pyiceberg.io.pyarrow import _dataframe_to_data_files, schema_to_pyarrow from pyiceberg.schema import Schema from pyiceberg.types import IntegerType -from tests.conftest import clean_up, get_bucket_name, get_s3_path +from tests.conftest import clean_up, get_bucket_name, get_glue_endpoint, get_s3_path # The number of tables/databases used in list_table/namespace test LIST_TEST_NUMBER = 2 @@ -51,7 +51,9 @@ def fixture_glue_client() -> boto3.client: @pytest.fixture(name="test_catalog", scope="module") def fixture_test_catalog() -> Generator[Catalog, None, None]: """Configure the pre- and post-setting of aws integration test.""" - test_catalog = GlueCatalog(CATALOG_NAME, warehouse=get_s3_path(get_bucket_name())) + test_catalog = GlueCatalog( + CATALOG_NAME, **{"warehouse": get_s3_path(get_bucket_name()), GLUE_CATALOG_ENDPOINT: get_glue_endpoint()} + ) yield test_catalog clean_up(test_catalog) diff --git a/tests/catalog/test_glue.py b/tests/catalog/test_glue.py index d43f5a3866..5ab9966a61 100644 --- a/tests/catalog/test_glue.py +++ b/tests/catalog/test_glue.py @@ -925,3 +925,13 @@ def test_register_table_with_given_location( table = test_catalog.register_table(identifier, location) assert table.identifier == (catalog_name,) + identifier assert test_catalog.table_exists(identifier) is True + + +@mock_aws +def test_glue_endpoint_override(_bucket_initialize: None, moto_endpoint_url: str, database_name: str) -> None: 
+ catalog_name = "glue" + test_endpoint = "https://test-endpoint" + test_catalog = GlueCatalog( + catalog_name, **{"s3.endpoint": moto_endpoint_url, "warehouse": f"s3://{BUCKET_NAME}", "glue.endpoint": test_endpoint} + ) + assert test_catalog.glue.meta.endpoint_url == test_endpoint diff --git a/tests/cli/test_console.py b/tests/cli/test_console.py index 92a7f80c7d..e55ff9a9ad 100644 --- a/tests/cli/test_console.py +++ b/tests/cli/test_console.py @@ -83,7 +83,7 @@ def mock_datetime_now(monkeypatch: pytest.MonkeyPatch) -> None: NestedField(3, "z", LongType(), required=True), ) TEST_TABLE_PARTITION_SPEC = PartitionSpec(PartitionField(name="x", transform=IdentityTransform(), source_id=1, field_id=1000)) -TEST_TABLE_PROPERTIES = {"read.split.target.size": "134217728"} +TEST_TABLE_PROPERTIES = {"read.split.target.size": "134217728", "write.parquet.bloom-filter-enabled.column.x": True} TEST_TABLE_UUID = uuid.UUID("d20125c8-7284-442c-9aea-15fee620737c") TEST_TIMESTAMP = 1602638573874 MOCK_ENVIRONMENT = {"PYICEBERG_CATALOG__PRODUCTION__URI": "test://doesnotexist"} @@ -367,7 +367,10 @@ def test_properties_get_table(catalog: InMemoryCatalog) -> None: runner = CliRunner() result = runner.invoke(run, ["properties", "get", "table", "default.my_table"]) assert result.exit_code == 0 - assert result.output == "read.split.target.size 134217728\n" + assert ( + result.output + == "read.split.target.size 134217728\nwrite.parquet.bloom-filter-enabled.column.x true \n" + ) def test_properties_get_table_specific_property(catalog: InMemoryCatalog) -> None: @@ -763,7 +766,7 @@ def test_json_properties_get_table(catalog: InMemoryCatalog) -> None: runner = CliRunner() result = runner.invoke(run, ["--output=json", "properties", "get", "table", "default.my_table"]) assert result.exit_code == 0 - assert result.output == """{"read.split.target.size": "134217728"}\n""" + assert result.output == """{"read.split.target.size": "134217728", "write.parquet.bloom-filter-enabled.column.x": "true"}\n""" def test_json_properties_get_table_specific_property(catalog: InMemoryCatalog) -> None: diff --git a/tests/conftest.py b/tests/conftest.py index 95e1128af6..91ab8f2e56 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2043,6 +2043,11 @@ def get_bucket_name() -> str: return bucket_name +def get_glue_endpoint() -> Optional[str]: + """Set the optional environment variable AWS_TEST_GLUE_ENDPOINT for a glue endpoint to test.""" + return os.getenv("AWS_TEST_GLUE_ENDPOINT") + + def get_s3_path(bucket_name: str, database_name: Optional[str] = None, table_name: Optional[str] = None) -> str: result_path = f"s3://{bucket_name}" if database_name is not None: @@ -2382,10 +2387,122 @@ def arrow_table_date_timestamps() -> "pa.Table": @pytest.fixture(scope="session") -def arrow_table_date_timestamps_schema() -> Schema: - """Pyarrow table Schema with only date, timestamp and timestamptz values.""" +def table_date_timestamps_schema() -> Schema: + """Iceberg table Schema with only date, timestamp and timestamptz values.""" return Schema( NestedField(field_id=1, name="date", field_type=DateType(), required=False), NestedField(field_id=2, name="timestamp", field_type=TimestampType(), required=False), NestedField(field_id=3, name="timestamptz", field_type=TimestamptzType(), required=False), ) + + +@pytest.fixture(scope="session") +def arrow_table_schema_with_all_timestamp_precisions() -> "pa.Schema": + """Pyarrow Schema with all supported timestamp types.""" + import pyarrow as pa + + return pa.schema([ + ("timestamp_s", 
pa.timestamp(unit="s")), + ("timestamptz_s", pa.timestamp(unit="s", tz="UTC")), + ("timestamp_ms", pa.timestamp(unit="ms")), + ("timestamptz_ms", pa.timestamp(unit="ms", tz="UTC")), + ("timestamp_us", pa.timestamp(unit="us")), + ("timestamptz_us", pa.timestamp(unit="us", tz="UTC")), + ("timestamp_ns", pa.timestamp(unit="ns")), + ("timestamptz_ns", pa.timestamp(unit="ns", tz="UTC")), + ("timestamptz_us_etc_utc", pa.timestamp(unit="us", tz="Etc/UTC")), + ("timestamptz_ns_z", pa.timestamp(unit="ns", tz="Z")), + ("timestamptz_s_0000", pa.timestamp(unit="s", tz="+00:00")), + ]) + + +@pytest.fixture(scope="session") +def arrow_table_with_all_timestamp_precisions(arrow_table_schema_with_all_timestamp_precisions: "pa.Schema") -> "pa.Table": + """Pyarrow table with all supported timestamp types.""" + import pandas as pd + import pyarrow as pa + + test_data = pd.DataFrame({ + "timestamp_s": [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)], + "timestamptz_s": [ + datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc), + None, + datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc), + ], + "timestamp_ms": [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)], + "timestamptz_ms": [ + datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc), + None, + datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc), + ], + "timestamp_us": [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)], + "timestamptz_us": [ + datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc), + None, + datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc), + ], + "timestamp_ns": [ + pd.Timestamp(year=2024, month=7, day=11, hour=3, minute=30, second=0, microsecond=12, nanosecond=6), + None, + pd.Timestamp(year=2024, month=7, day=11, hour=3, minute=30, second=0, microsecond=12, nanosecond=7), + ], + "timestamptz_ns": [ + datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc), + None, + datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc), + ], + "timestamptz_us_etc_utc": [ + datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc), + None, + datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc), + ], + "timestamptz_ns_z": [ + pd.Timestamp(year=2024, month=7, day=11, hour=3, minute=30, second=0, microsecond=12, nanosecond=6, tz="UTC"), + None, + pd.Timestamp(year=2024, month=7, day=11, hour=3, minute=30, second=0, microsecond=12, nanosecond=7, tz="UTC"), + ], + "timestamptz_s_0000": [ + datetime(2023, 1, 1, 19, 25, 1, tzinfo=timezone.utc), + None, + datetime(2023, 3, 1, 19, 25, 1, tzinfo=timezone.utc), + ], + }) + return pa.Table.from_pandas(test_data, schema=arrow_table_schema_with_all_timestamp_precisions) + + +@pytest.fixture(scope="session") +def arrow_table_schema_with_all_microseconds_timestamp_precisions() -> "pa.Schema": + """Pyarrow Schema with all microseconds timestamp.""" + import pyarrow as pa + + return pa.schema([ + ("timestamp_s", pa.timestamp(unit="us")), + ("timestamptz_s", pa.timestamp(unit="us", tz="UTC")), + ("timestamp_ms", pa.timestamp(unit="us")), + ("timestamptz_ms", pa.timestamp(unit="us", tz="UTC")), + ("timestamp_us", pa.timestamp(unit="us")), + ("timestamptz_us", pa.timestamp(unit="us", tz="UTC")), + ("timestamp_ns", pa.timestamp(unit="us")), + ("timestamptz_ns", pa.timestamp(unit="us", tz="UTC")), + ("timestamptz_us_etc_utc", pa.timestamp(unit="us", tz="UTC")), + ("timestamptz_ns_z", pa.timestamp(unit="us", tz="UTC")), + ("timestamptz_s_0000", pa.timestamp(unit="us", tz="UTC")), + ]) + + +@pytest.fixture(scope="session") +def 
table_schema_with_all_microseconds_timestamp_precision() -> Schema: + """Iceberg table Schema with only date, timestamp and timestamptz values.""" + return Schema( + NestedField(field_id=1, name="timestamp_s", field_type=TimestampType(), required=False), + NestedField(field_id=2, name="timestamptz_s", field_type=TimestamptzType(), required=False), + NestedField(field_id=3, name="timestamp_ms", field_type=TimestampType(), required=False), + NestedField(field_id=4, name="timestamptz_ms", field_type=TimestamptzType(), required=False), + NestedField(field_id=5, name="timestamp_us", field_type=TimestampType(), required=False), + NestedField(field_id=6, name="timestamptz_us", field_type=TimestamptzType(), required=False), + NestedField(field_id=7, name="timestamp_ns", field_type=TimestampType(), required=False), + NestedField(field_id=8, name="timestamptz_ns", field_type=TimestamptzType(), required=False), + NestedField(field_id=9, name="timestamptz_us_etc_utc", field_type=TimestamptzType(), required=False), + NestedField(field_id=10, name="timestamptz_ns_z", field_type=TimestamptzType(), required=False), + NestedField(field_id=11, name="timestamptz_s_0000", field_type=TimestamptzType(), required=False), + ) diff --git a/tests/integration/test_writes/test_partitioned_writes.py b/tests/integration/test_writes/test_partitioned_writes.py index 12da9c928b..b199f00210 100644 --- a/tests/integration/test_writes/test_partitioned_writes.py +++ b/tests/integration/test_writes/test_partitioned_writes.py @@ -461,7 +461,7 @@ def test_append_transform_partition_verify_partitions_count( session_catalog: Catalog, spark: SparkSession, arrow_table_date_timestamps: pa.Table, - arrow_table_date_timestamps_schema: Schema, + table_date_timestamps_schema: Schema, transform: Transform[Any, Any], expected_partitions: Set[Any], format_version: int, @@ -469,7 +469,7 @@ def test_append_transform_partition_verify_partitions_count( # Given part_col = "timestamptz" identifier = f"default.arrow_table_v{format_version}_with_{str(transform)}_transform_partitioned_on_col_{part_col}" - nested_field = arrow_table_date_timestamps_schema.find_field(part_col) + nested_field = table_date_timestamps_schema.find_field(part_col) partition_spec = PartitionSpec( PartitionField(source_id=nested_field.field_id, field_id=1001, transform=transform, name=part_col), ) @@ -481,7 +481,7 @@ def test_append_transform_partition_verify_partitions_count( properties={"format-version": str(format_version)}, data=[arrow_table_date_timestamps], partition_spec=partition_spec, - schema=arrow_table_date_timestamps_schema, + schema=table_date_timestamps_schema, ) # Then @@ -510,20 +510,20 @@ def test_append_multiple_partitions( session_catalog: Catalog, spark: SparkSession, arrow_table_date_timestamps: pa.Table, - arrow_table_date_timestamps_schema: Schema, + table_date_timestamps_schema: Schema, format_version: int, ) -> None: # Given identifier = f"default.arrow_table_v{format_version}_with_multiple_partitions" partition_spec = PartitionSpec( PartitionField( - source_id=arrow_table_date_timestamps_schema.find_field("date").field_id, + source_id=table_date_timestamps_schema.find_field("date").field_id, field_id=1001, transform=YearTransform(), name="date_year", ), PartitionField( - source_id=arrow_table_date_timestamps_schema.find_field("timestamptz").field_id, + source_id=table_date_timestamps_schema.find_field("timestamptz").field_id, field_id=1000, transform=HourTransform(), name="timestamptz_hour", @@ -537,7 +537,7 @@ def test_append_multiple_partitions( 
properties={"format-version": str(format_version)}, data=[arrow_table_date_timestamps], partition_spec=partition_spec, - schema=arrow_table_date_timestamps_schema, + schema=table_date_timestamps_schema, ) # Then diff --git a/tests/integration/test_writes/test_writes.py b/tests/integration/test_writes/test_writes.py index af626718f7..41bc6fb5bf 100644 --- a/tests/integration/test_writes/test_writes.py +++ b/tests/integration/test_writes/test_writes.py @@ -18,11 +18,12 @@ import math import os import time -from datetime import date, datetime, timezone +from datetime import date, datetime from pathlib import Path from typing import Any, Dict from urllib.parse import urlparse +import pandas as pd import pyarrow as pa import pyarrow.parquet as pq import pytest @@ -977,69 +978,43 @@ def table_write_subset_of_schema(session_catalog: Catalog, arrow_table_with_null @pytest.mark.integration @pytest.mark.parametrize("format_version", [1, 2]) -def test_write_all_timestamp_precision(mocker: MockerFixture, session_catalog: Catalog, format_version: int) -> None: +def test_write_all_timestamp_precision( + mocker: MockerFixture, + spark: SparkSession, + session_catalog: Catalog, + format_version: int, + arrow_table_schema_with_all_timestamp_precisions: pa.Schema, + arrow_table_with_all_timestamp_precisions: pa.Table, + arrow_table_schema_with_all_microseconds_timestamp_precisions: pa.Schema, +) -> None: identifier = "default.table_all_timestamp_precision" - arrow_table_schema_with_all_timestamp_precisions = pa.schema([ - ("timestamp_s", pa.timestamp(unit="s")), - ("timestamptz_s", pa.timestamp(unit="s", tz="UTC")), - ("timestamp_ms", pa.timestamp(unit="ms")), - ("timestamptz_ms", pa.timestamp(unit="ms", tz="UTC")), - ("timestamp_us", pa.timestamp(unit="us")), - ("timestamptz_us", pa.timestamp(unit="us", tz="UTC")), - ("timestamp_ns", pa.timestamp(unit="ns")), - ("timestamptz_ns", pa.timestamp(unit="ns", tz="UTC")), - ]) - TEST_DATA_WITH_NULL = { - "timestamp_s": [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)], - "timestamptz_s": [ - datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc), - None, - datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc), - ], - "timestamp_ms": [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)], - "timestamptz_ms": [ - datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc), - None, - datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc), - ], - "timestamp_us": [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)], - "timestamptz_us": [ - datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc), - None, - datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc), - ], - "timestamp_ns": [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)], - "timestamptz_ns": [ - datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc), - None, - datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc), - ], - } - input_arrow_table = pa.Table.from_pydict(TEST_DATA_WITH_NULL, schema=arrow_table_schema_with_all_timestamp_precisions) mocker.patch.dict(os.environ, values={"PYICEBERG_DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE": "True"}) tbl = _create_table( session_catalog, identifier, {"format-version": format_version}, - data=[input_arrow_table], + data=[arrow_table_with_all_timestamp_precisions], schema=arrow_table_schema_with_all_timestamp_precisions, ) - tbl.overwrite(input_arrow_table) + tbl.overwrite(arrow_table_with_all_timestamp_precisions) written_arrow_table = tbl.scan().to_arrow() - expected_schema_in_all_us = 
pa.schema([ - ("timestamp_s", pa.timestamp(unit="us")), - ("timestamptz_s", pa.timestamp(unit="us", tz="UTC")), - ("timestamp_ms", pa.timestamp(unit="us")), - ("timestamptz_ms", pa.timestamp(unit="us", tz="UTC")), - ("timestamp_us", pa.timestamp(unit="us")), - ("timestamptz_us", pa.timestamp(unit="us", tz="UTC")), - ("timestamp_ns", pa.timestamp(unit="us")), - ("timestamptz_ns", pa.timestamp(unit="us", tz="UTC")), - ]) - assert written_arrow_table.schema == expected_schema_in_all_us - assert written_arrow_table == input_arrow_table.cast(expected_schema_in_all_us) + assert written_arrow_table.schema == arrow_table_schema_with_all_microseconds_timestamp_precisions + assert written_arrow_table == arrow_table_with_all_timestamp_precisions.cast( + arrow_table_schema_with_all_microseconds_timestamp_precisions, safe=False + ) + lhs = spark.table(f"{identifier}").toPandas() + rhs = written_arrow_table.to_pandas() + + for column in written_arrow_table.column_names: + for left, right in zip(lhs[column].to_list(), rhs[column].to_list()): + if pd.isnull(left): + assert pd.isnull(right) + else: + # Check only upto microsecond precision since Spark loaded dtype is timezone unaware + # and supports upto microsecond precision + assert left.timestamp() == right.timestamp(), f"Difference in column {column}: {left} != {right}" @pytest.mark.integration diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py index 326eeff195..37198b7edb 100644 --- a/tests/io/test_pyarrow.py +++ b/tests/io/test_pyarrow.py @@ -65,6 +65,7 @@ _determine_partitions, _primitive_to_physical, _read_deletes, + _to_requested_schema, bin_pack_arrow_table, expression_to_pyarrow, project_table, @@ -1889,3 +1890,35 @@ def test_identity_partition_on_multi_columns() -> None: ("n_legs", "ascending"), ("animal", "ascending"), ]) == arrow_table.sort_by([("born_year", "ascending"), ("n_legs", "ascending"), ("animal", "ascending")]) + + +def test__to_requested_schema_timestamps( + arrow_table_schema_with_all_timestamp_precisions: pa.Schema, + arrow_table_with_all_timestamp_precisions: pa.Table, + arrow_table_schema_with_all_microseconds_timestamp_precisions: pa.Schema, + table_schema_with_all_microseconds_timestamp_precision: Schema, +) -> None: + requested_schema = table_schema_with_all_microseconds_timestamp_precision + file_schema = requested_schema + batch = arrow_table_with_all_timestamp_precisions.to_batches()[0] + result = _to_requested_schema(requested_schema, file_schema, batch, downcast_ns_timestamp_to_us=True, include_field_ids=False) + + expected = arrow_table_with_all_timestamp_precisions.cast( + arrow_table_schema_with_all_microseconds_timestamp_precisions, safe=False + ).to_batches()[0] + assert result == expected + + +def test__to_requested_schema_timestamps_without_downcast_raises_exception( + arrow_table_schema_with_all_timestamp_precisions: pa.Schema, + arrow_table_with_all_timestamp_precisions: pa.Table, + arrow_table_schema_with_all_microseconds_timestamp_precisions: pa.Schema, + table_schema_with_all_microseconds_timestamp_precision: Schema, +) -> None: + requested_schema = table_schema_with_all_microseconds_timestamp_precision + file_schema = requested_schema + batch = arrow_table_with_all_timestamp_precisions.to_batches()[0] + with pytest.raises(ValueError) as exc_info: + _to_requested_schema(requested_schema, file_schema, batch, downcast_ns_timestamp_to_us=False, include_field_ids=False) + + assert "Unsupported schema projection from timestamp[ns] to timestamp[us]" in str(exc_info.value) diff --git 
a/tests/test_types.py b/tests/test_types.py index 1e386bb748..52bdce4de8 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -44,6 +44,7 @@ TimeType, UUIDType, strtobool, + transform_dict_value_to_str, ) non_parameterized_types = [ @@ -649,3 +650,14 @@ def test_strtobool() -> None: for val in invalid_values: with pytest.raises(ValueError, match=f"Invalid truth value: {val!r}"): strtobool(val) + + +def test_transform_dict_value_to_str() -> None: + input_dict = {"key1": 1, "key2": 2.0, "key3": "3", "key4: ": True, "key5": False} + expected_dict = {"key1": "1", "key2": "2.0", "key3": "3", "key4: ": "true", "key5": "false"} + # valid values + assert transform_dict_value_to_str(input_dict) == expected_dict + # Null value not allowed, should raise ValueError + input_dict["key6"] = None + with pytest.raises(ValueError, match="None type is not a supported value in properties: key6"): + transform_dict_value_to_str(input_dict)
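
For reference, a minimal sketch (not part of the patch) of how the new `glue.endpoint` property wired up above could be supplied when constructing a `GlueCatalog` directly, mirroring `test_glue_endpoint_override`. The catalog name, warehouse path, endpoint URL, and region are placeholders, and credentials are assumed to come from the usual AWS configuration chain.

```python
from pyiceberg.catalog.glue import GlueCatalog

# All values below are illustrative placeholders.
catalog = GlueCatalog(
    "glue",
    **{
        "warehouse": "s3://my-bucket/warehouse",
        "glue.endpoint": "https://glue.us-east-1.amazonaws.com",
        "glue.region": "us-east-1",
    },
)

# As in the new unit test, the boto3 Glue client is created against the
# configured endpoint rather than the default endpoint for the region.
assert catalog.glue.meta.endpoint_url == "https://glue.us-east-1.amazonaws.com"
```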
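The reworked `ArrowProjectionVisitor` logic above only downcasts nanosecond timestamps to microseconds when `downcast_ns_timestamp_to_us` is set, and otherwise raises `ValueError: Unsupported schema projection from timestamp[ns] to timestamp[us]`. A small, self-contained illustration of that lossy cast using plain PyArrow (not PyIceberg's internal API):

```python
import pyarrow as pa

# Epoch value with non-zero nanoseconds; it cannot be represented exactly in microseconds.
ns_array = pa.array([1_720_668_600_000_000_006], type=pa.timestamp("ns"))

# safe=True would raise because data would be lost; safe=False truncates the
# nanoseconds, which is what the projection does when downcast_ns_timestamp_to_us=True.
us_array = ns_array.cast(pa.timestamp("us"), safe=False)
```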
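Finally, a short sketch of the property-value coercion change in `pyiceberg/types.py`: boolean values are now rendered as lowercase strings, which is what the updated CLI tests above expect. The dictionary contents are illustrative.

```python
from pyiceberg.types import transform_dict_value_to_str

properties = {
    "read.split.target.size": 134217728,
    "write.parquet.bloom-filter-enabled.column.x": True,
}

# Booleans become "true"/"false" rather than Python's "True"/"False".
assert transform_dict_value_to_str(properties) == {
    "read.split.target.size": "134217728",
    "write.parquet.bloom-filter-enabled.column.x": "true",
}
```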