diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index 3eb74eee1f..cdd132e8cb 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -268,12 +268,10 @@ def _apply(self, updates: Tuple[TableUpdate, ...], requirements: Tuple[TableRequ
 
         return self
 
-    def _scan(self, row_filter: Union[str, BooleanExpression] = ALWAYS_TRUE) -> DataScan:
+    def _scan(self, row_filter: Union[str, BooleanExpression] = ALWAYS_TRUE, case_sensitive: bool = True) -> DataScan:
         """Minimal data scan of the table with the current state of the transaction."""
         return DataScan(
-            table_metadata=self.table_metadata,
-            io=self._table.io,
-            row_filter=row_filter,
+            table_metadata=self.table_metadata, io=self._table.io, row_filter=row_filter, case_sensitive=case_sensitive
         )
 
     def upgrade_table_version(self, format_version: TableVersion) -> Transaction:
@@ -421,6 +419,7 @@ def overwrite(
         self,
         df: pa.Table,
         overwrite_filter: Union[BooleanExpression, str] = ALWAYS_TRUE,
+        case_sensitive: bool = True,
         snapshot_properties: Dict[str, str] = EMPTY_DICT,
     ) -> None:
         """
@@ -436,6 +435,7 @@ def overwrite(
             df: The Arrow dataframe that will be used to overwrite the table
             overwrite_filter: ALWAYS_TRUE when you overwrite all the data,
                               or a boolean expression in case of a partial overwrite
+            case_sensitive: A bool determining if the provided `overwrite_filter` is case-sensitive
             snapshot_properties: Custom properties to be added to the snapshot summary
         """
         try:
@@ -459,7 +459,7 @@ def overwrite(
             self.table_metadata.schema(), provided_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
         )
 
-        self.delete(delete_filter=overwrite_filter, snapshot_properties=snapshot_properties)
+        self.delete(delete_filter=overwrite_filter, case_sensitive=case_sensitive, snapshot_properties=snapshot_properties)
 
         with self.update_snapshot(snapshot_properties=snapshot_properties).fast_append() as update_snapshot:
             # skip writing data files if the dataframe is empty
@@ -470,17 +470,23 @@ def overwrite(
             for data_file in data_files:
                 update_snapshot.append_data_file(data_file)
 
-    def delete(self, delete_filter: Union[str, BooleanExpression], snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None:
+    def delete(
+        self,
+        delete_filter: Union[str, BooleanExpression],
+        case_sensitive: bool = True,
+        snapshot_properties: Dict[str, str] = EMPTY_DICT,
+    ) -> None:
         """
         Shorthand for deleting record from a table.
 
-        An deletee may produce zero or more snapshots based on the operation:
+        A delete may produce zero or more snapshots based on the operation:
 
             - DELETE: In case existing Parquet files can be dropped completely.
            - REPLACE: In case existing Parquet files need to be rewritten
 
         Args:
             delete_filter: A boolean expression to delete rows from a table
+            case_sensitive: A bool determining if the provided `delete_filter` is case-sensitive
             snapshot_properties: Custom properties to be added to the snapshot summary
         """
         from pyiceberg.io.pyarrow import (
@@ -499,14 +505,14 @@ def delete(self, delete_filter: Union[str, BooleanExpression], snapshot_properti
             delete_filter = _parse_row_filter(delete_filter)
 
         with self.update_snapshot(snapshot_properties=snapshot_properties).delete() as delete_snapshot:
-            delete_snapshot.delete_by_predicate(delete_filter)
+            delete_snapshot.delete_by_predicate(delete_filter, case_sensitive)
 
         # Check if there are any files that require an actual rewrite of a data file
         if delete_snapshot.rewrites_needed is True:
-            bound_delete_filter = bind(self.table_metadata.schema(), delete_filter, case_sensitive=True)
+            bound_delete_filter = bind(self.table_metadata.schema(), delete_filter, case_sensitive)
             preserve_row_filter = _expression_to_complementary_pyarrow(bound_delete_filter)
 
-            files = self._scan(row_filter=delete_filter).plan_files()
+            files = self._scan(row_filter=delete_filter, case_sensitive=case_sensitive).plan_files()
 
             commit_uuid = uuid.uuid4()
             counter = itertools.count(0)
@@ -987,6 +993,7 @@ def overwrite(
         self,
         df: pa.Table,
         overwrite_filter: Union[BooleanExpression, str] = ALWAYS_TRUE,
+        case_sensitive: bool = True,
         snapshot_properties: Dict[str, str] = EMPTY_DICT,
     ) -> None:
         """
@@ -1002,23 +1009,30 @@ def overwrite(
             df: The Arrow dataframe that will be used to overwrite the table
             overwrite_filter: ALWAYS_TRUE when you overwrite all the data,
                               or a boolean expression in case of a partial overwrite
+            case_sensitive: A bool determining if the provided `overwrite_filter` is case-sensitive
             snapshot_properties: Custom properties to be added to the snapshot summary
         """
         with self.transaction() as tx:
-            tx.overwrite(df=df, overwrite_filter=overwrite_filter, snapshot_properties=snapshot_properties)
+            tx.overwrite(
+                df=df, overwrite_filter=overwrite_filter, case_sensitive=case_sensitive, snapshot_properties=snapshot_properties
+            )
 
     def delete(
-        self, delete_filter: Union[BooleanExpression, str] = ALWAYS_TRUE, snapshot_properties: Dict[str, str] = EMPTY_DICT
+        self,
+        delete_filter: Union[BooleanExpression, str] = ALWAYS_TRUE,
+        case_sensitive: bool = True,
+        snapshot_properties: Dict[str, str] = EMPTY_DICT,
     ) -> None:
         """
         Shorthand for deleting rows from the table.
        Args:
            delete_filter: The predicate that used to remove rows
+            case_sensitive: A bool determining if the provided `delete_filter` is case-sensitive
             snapshot_properties: Custom properties to be added to the snapshot summary
         """
         with self.transaction() as tx:
-            tx.delete(delete_filter=delete_filter, snapshot_properties=snapshot_properties)
+            tx.delete(delete_filter=delete_filter, case_sensitive=case_sensitive, snapshot_properties=snapshot_properties)
 
     def add_files(
         self, file_paths: List[str], snapshot_properties: Dict[str, str] = EMPTY_DICT, check_duplicate_files: bool = True
@@ -1311,7 +1325,7 @@ def _match_deletes_to_data_file(data_entry: ManifestEntry, positional_delete_ent
 
 class DataScan(TableScan):
     def _build_partition_projection(self, spec_id: int) -> BooleanExpression:
-        project = inclusive_projection(self.table_metadata.schema(), self.table_metadata.specs()[spec_id])
+        project = inclusive_projection(self.table_metadata.schema(), self.table_metadata.specs()[spec_id], self.case_sensitive)
         return project(self.row_filter)
 
     @cached_property
diff --git a/pyiceberg/table/update/snapshot.py b/pyiceberg/table/update/snapshot.py
index 47e5fc55e3..c0d0056e7c 100644
--- a/pyiceberg/table/update/snapshot.py
+++ b/pyiceberg/table/update/snapshot.py
@@ -318,6 +318,7 @@ class _DeleteFiles(_SnapshotProducer["_DeleteFiles"]):
     """
 
     _predicate: BooleanExpression
+    _case_sensitive: bool
 
     def __init__(
         self,
@@ -329,6 +330,7 @@ def __init__(
     ):
         super().__init__(operation, transaction, io, commit_uuid, snapshot_properties)
         self._predicate = AlwaysFalse()
+        self._case_sensitive = True
 
     def _commit(self) -> UpdatesAndRequirements:
         # Only produce a commit when there is something to delete
@@ -340,7 +342,7 @@ def _commit(self) -> UpdatesAndRequirements:
     def _build_partition_projection(self, spec_id: int) -> BooleanExpression:
         schema = self._transaction.table_metadata.schema()
         spec = self._transaction.table_metadata.specs()[spec_id]
-        project = inclusive_projection(schema, spec)
+        project = inclusive_projection(schema, spec, self._case_sensitive)
         return project(self._predicate)
 
     @cached_property
@@ -350,10 +352,11 @@ def partition_filters(self) -> KeyDefaultDict[int, BooleanExpression]:
     def _build_manifest_evaluator(self, spec_id: int) -> Callable[[ManifestFile], bool]:
         schema = self._transaction.table_metadata.schema()
         spec = self._transaction.table_metadata.specs()[spec_id]
-        return manifest_evaluator(spec, schema, self.partition_filters[spec_id], case_sensitive=True)
+        return manifest_evaluator(spec, schema, self.partition_filters[spec_id], self._case_sensitive)
 
-    def delete_by_predicate(self, predicate: BooleanExpression) -> None:
+    def delete_by_predicate(self, predicate: BooleanExpression, case_sensitive: bool = True) -> None:
         self._predicate = Or(self._predicate, predicate)
+        self._case_sensitive = case_sensitive
 
     @cached_property
     def _compute_deletes(self) -> Tuple[List[ManifestFile], List[ManifestEntry], bool]:
@@ -376,8 +379,10 @@ def _copy_with_new_status(entry: ManifestEntry, status: ManifestEntryStatus) ->
         )
 
         manifest_evaluators: Dict[int, Callable[[ManifestFile], bool]] = KeyDefaultDict(self._build_manifest_evaluator)
-        strict_metrics_evaluator = _StrictMetricsEvaluator(schema, self._predicate, case_sensitive=True).eval
-        inclusive_metrics_evaluator = _InclusiveMetricsEvaluator(schema, self._predicate, case_sensitive=True).eval
+        strict_metrics_evaluator = _StrictMetricsEvaluator(schema, self._predicate, case_sensitive=self._case_sensitive).eval
+        inclusive_metrics_evaluator = _InclusiveMetricsEvaluator(
+            schema, self._predicate, case_sensitive=self._case_sensitive
+        ).eval
 
         existing_manifests = []
         total_deleted_entries = []
diff --git a/tests/integration/test_deletes.py b/tests/integration/test_deletes.py
index 2cdf9916ee..bee9544458 100644
--- a/tests/integration/test_deletes.py
+++ b/tests/integration/test_deletes.py
@@ -16,7 +16,7 @@
 # under the License.
 # pylint:disable=redefined-outer-name
 from datetime import datetime
-from typing import List
+from typing import Generator, List
 
 import pyarrow as pa
 import pytest
@@ -28,9 +28,10 @@
 from pyiceberg.manifest import ManifestEntryStatus
 from pyiceberg.partitioning import PartitionField, PartitionSpec
 from pyiceberg.schema import Schema
+from pyiceberg.table import Table
 from pyiceberg.table.snapshots import Operation, Summary
 from pyiceberg.transforms import IdentityTransform
-from pyiceberg.types import FloatType, IntegerType, LongType, NestedField, TimestampType
+from pyiceberg.types import FloatType, IntegerType, LongType, NestedField, StringType, TimestampType
 
 
 def run_spark_commands(spark: SparkSession, sqls: List[str]) -> None:
@@ -38,6 +39,24 @@ def run_spark_commands(spark: SparkSession, sqls: List[str]) -> None:
         spark.sql(sql)
 
 
+@pytest.fixture()
+def test_table(session_catalog: RestCatalog) -> Generator[Table, None, None]:
+    identifier = "default.__test_table"
+    arrow_table = pa.Table.from_arrays([pa.array([1, 2, 3, 4, 5]), pa.array(["a", "b", "c", "d", "e"])], names=["idx", "value"])
+    test_table = session_catalog.create_table(
+        identifier,
+        schema=Schema(
+            NestedField(1, "idx", LongType()),
+            NestedField(2, "value", StringType()),
+        ),
+    )
+    test_table.append(arrow_table)
+
+    yield test_table
+
+    session_catalog.drop_table(identifier)
+
+
 @pytest.mark.integration
 @pytest.mark.parametrize("format_version", [1, 2])
 def test_partitioned_table_delete_full_file(spark: SparkSession, session_catalog: RestCatalog, format_version: int) -> None:
@@ -770,3 +789,67 @@ def test_delete_after_partition_evolution_from_partitioned(session_catalog: Rest
 
     # Expect 8 records: 10 records - 2
     assert len(tbl.scan().to_arrow()) == 8
+
+
+@pytest.mark.integration
+def test_delete_with_filter_case_sensitive(test_table: Table) -> None:
+    assert {"idx": 2, "value": "b"} in test_table.scan().to_arrow().to_pylist()
+
+    with pytest.raises(ValueError) as e:
+        test_table.delete("Idx == 2", case_sensitive=True)
+    assert "Could not find field with name Idx" in str(e.value)
+    assert {"idx": 2, "value": "b"} in test_table.scan().to_arrow().to_pylist()
+
+    test_table.delete("idx == 2", case_sensitive=True)
+    assert {"idx": 2, "value": "b"} not in test_table.scan().to_arrow().to_pylist()
+
+
+@pytest.mark.integration
+def test_delete_with_filter_case_insensitive(test_table: Table) -> None:
+    assert {"idx": 2, "value": "b"} in test_table.scan().to_arrow().to_pylist()
+
+    test_table.delete("Idx == 2", case_sensitive=False)
+    assert {"idx": 2, "value": "b"} not in test_table.scan().to_arrow().to_pylist()
+
+    test_table.delete("idx == 3", case_sensitive=False)
+    assert {"idx": 3, "value": "c"} not in test_table.scan().to_arrow().to_pylist()
+
+
+@pytest.mark.integration
+def test_overwrite_with_filter_case_sensitive(test_table: Table) -> None:
+    assert {"idx": 2, "value": "b"} in test_table.scan().to_arrow().to_pylist()
+
+    new_table = pa.Table.from_arrays(
+        [
+            pa.array([10]),
+            pa.array(["x"]),
+        ],
+        names=["idx", "value"],
+    )
+
+    with pytest.raises(ValueError) as e:
+        test_table.overwrite(df=new_table, overwrite_filter="Idx == 2", case_sensitive=True)
+    assert "Could not find field with name Idx" in str(e.value)
+    assert {"idx": 2, "value": "b"} in test_table.scan().to_arrow().to_pylist()
+    assert {"idx": 10, "value": "x"} not in test_table.scan().to_arrow().to_pylist()
+
+    test_table.overwrite(df=new_table, overwrite_filter="idx == 2", case_sensitive=True)
+    assert {"idx": 2, "value": "b"} not in test_table.scan().to_arrow().to_pylist()
+    assert {"idx": 10, "value": "x"} in test_table.scan().to_arrow().to_pylist()
+
+
+@pytest.mark.integration
+def test_overwrite_with_filter_case_insensitive(test_table: Table) -> None:
+    assert {"idx": 2, "value": "b"} in test_table.scan().to_arrow().to_pylist()
+
+    new_table = pa.Table.from_arrays(
+        [
+            pa.array([10]),
+            pa.array(["x"]),
+        ],
+        names=["idx", "value"],
+    )
+
+    test_table.overwrite(df=new_table, overwrite_filter="Idx == 2", case_sensitive=False)
+    assert {"idx": 2, "value": "b"} not in test_table.scan().to_arrow().to_pylist()
+    assert {"idx": 10, "value": "x"} in test_table.scan().to_arrow().to_pylist()
diff --git a/tests/integration/test_reads.py b/tests/integration/test_reads.py
index f8bc57bb8c..dda9dfcf23 100644
--- a/tests/integration/test_reads.py
+++ b/tests/integration/test_reads.py
@@ -621,6 +621,35 @@ def test_filter_on_new_column(catalog: Catalog) -> None:
     assert arrow_table["b"].to_pylist() == [None]
 
 
+@pytest.mark.integration
+@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
+def test_filter_case_sensitive(catalog: Catalog) -> None:
+    test_table_add_column = catalog.load_table("default.test_table_add_column")
+    arrow_table = test_table_add_column.scan().to_arrow()
+    assert "2" in arrow_table["b"].to_pylist()
+
+    arrow_table = test_table_add_column.scan(row_filter="b == '2'", case_sensitive=True).to_arrow()
+    assert arrow_table["b"].to_pylist() == ["2"]
+
+    with pytest.raises(ValueError) as e:
+        _ = test_table_add_column.scan(row_filter="B == '2'", case_sensitive=True).to_arrow()
+    assert "Could not find field with name B" in str(e.value)
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
+def test_filter_case_insensitive(catalog: Catalog) -> None:
+    test_table_add_column = catalog.load_table("default.test_table_add_column")
+    arrow_table = test_table_add_column.scan().to_arrow()
+    assert "2" in arrow_table["b"].to_pylist()
+
+    arrow_table = test_table_add_column.scan(row_filter="b == '2'", case_sensitive=False).to_arrow()
+    assert arrow_table["b"].to_pylist() == ["2"]
+
+    arrow_table = test_table_add_column.scan(row_filter="B == '2'", case_sensitive=False).to_arrow()
+    assert arrow_table["b"].to_pylist() == ["2"]
+
+
 @pytest.mark.integration
 @pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
 def test_upgrade_table_version(catalog: Catalog) -> None:
diff --git a/tests/table/test_init.py b/tests/table/test_init.py
index 040c67034b..0cfd77e400 100644
--- a/tests/table/test_init.py
+++ b/tests/table/test_init.py
@@ -310,6 +310,19 @@ def test_table_scan_row_filter(table_v2: Table) -> None:
     assert scan.filter(EqualTo("x", 10)).filter(In("y", (10, 11))).row_filter == And(EqualTo("x", 10), In("y", (10, 11)))
 
 
+def test_table_scan_partition_filters_case_sensitive(table_v2: Table) -> None:
+    scan = table_v2.scan(row_filter=EqualTo("X", 10), case_sensitive=True)
+    with pytest.raises(ValueError):
+        for i in range(len(table_v2.metadata.specs())):
+            _ = scan.partition_filters[i]
+
+
+def test_table_scan_partition_filters_case_insensitive(table_v2: Table) -> None:
+    scan = table_v2.scan(row_filter=EqualTo("X", 10), case_sensitive=False)
+    for i in range(len(table_v2.metadata.specs())):
+        _ = scan.partition_filters[i]
+
+
 def test_table_scan_ref(table_v2: Table) -> None:
     scan = table_v2.scan()
     assert scan.use_ref("test").snapshot_id == 3051729675574597004
diff --git a/tests/table/test_snapshots.py b/tests/table/test_snapshots.py
index b4dde217d4..a9e87b4fd4 100644
--- a/tests/table/test_snapshots.py
+++ b/tests/table/test_snapshots.py
@@ -17,9 +17,12 @@
 # pylint:disable=redefined-outer-name,eval-used
 import pytest
 
+from pyiceberg.expressions import EqualTo
 from pyiceberg.manifest import DataFile, DataFileContent, ManifestContent, ManifestFile
 from pyiceberg.partitioning import PartitionField, PartitionSpec
 from pyiceberg.schema import Schema
+from pyiceberg.table import Table
+from pyiceberg.table.metadata import TableMetadataV1, TableMetadataV2
 from pyiceberg.table.snapshots import Operation, Snapshot, SnapshotSummaryCollector, Summary, update_snapshot_summaries
 from pyiceberg.transforms import IdentityTransform
 from pyiceberg.typedef import Record
@@ -341,3 +344,20 @@ def test_invalid_type() -> None:
         )
 
     assert "Could not parse summary property total-data-files to an int: abc" in str(e.value)
+
+
+@pytest.mark.parametrize("case_sensitive", [True, False])
+def test_delete_table_rows_case_sensitive(
+    case_sensitive: bool, table_v1: Table, table_v2: Table, monkeypatch: pytest.MonkeyPatch
+) -> None:
+    monkeypatch.setattr(TableMetadataV1, "current_snapshot", lambda _: None)
+    monkeypatch.setattr(TableMetadataV2, "current_snapshot", lambda _: None)
+    for table in [table_v1, table_v2]:
+        delete_file = table.transaction().update_snapshot().delete()
+        delete_file.delete_by_predicate(predicate=EqualTo("X", 10), case_sensitive=case_sensitive)
+        if case_sensitive:
+            with pytest.raises(ValueError) as e:
+                _ = delete_file._compute_deletes
+            assert "Could not find field with name X" in str(e.value)
+        else:
+            _ = delete_file._compute_deletes
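
For reference, a minimal usage sketch of the `case_sensitive` parameter introduced above (illustrative only, not part of the diff; the catalog name, table identifier, and column names are assumed):

    # Assumed setup: a configured catalog named "default" and a table whose schema
    # has lower-case "idx" and "value" columns, like the integration-test fixture above.
    import pyarrow as pa
    from pyiceberg.catalog import load_catalog

    catalog = load_catalog("default")
    table = catalog.load_table("default.my_table")  # hypothetical identifier

    # With case_sensitive=False, a filter such as "IDX == 2" binds to the "idx" column.
    rows = table.scan(row_filter="IDX == 2", case_sensitive=False).to_arrow()
    table.delete(delete_filter="IDX == 2", case_sensitive=False)

    new_rows = pa.Table.from_pylist([{"idx": 2, "value": "b"}])
    table.overwrite(df=new_rows, overwrite_filter="IDX == 2", case_sensitive=False)

    # The default stays case_sensitive=True, so the same filter against a lower-case
    # schema raises ValueError ("Could not find field with name IDX"), which keeps the
    # previous behavior for existing callers.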