Local fix table scan case sensitive #2

Closed · wants to merge 6 commits

This change threads a case_sensitive flag (defaulting to True) from Table.delete, Table.overwrite, and the transaction-level delete path through to scan planning, partition projection, and the metrics evaluators, with integration and unit tests covering both modes.
42 changes: 28 additions & 14 deletions pyiceberg/table/__init__.py
@@ -268,12 +268,10 @@ def _apply(self, updates: Tuple[TableUpdate, ...], requirements: Tuple[TableRequ

return self

def _scan(self, row_filter: Union[str, BooleanExpression] = ALWAYS_TRUE) -> DataScan:
def _scan(self, row_filter: Union[str, BooleanExpression] = ALWAYS_TRUE, case_sensitive: bool = True) -> DataScan:
"""Minimal data scan of the table with the current state of the transaction."""
return DataScan(
table_metadata=self.table_metadata,
io=self._table.io,
row_filter=row_filter,
table_metadata=self.table_metadata, io=self._table.io, row_filter=row_filter, case_sensitive=case_sensitive
)

def upgrade_table_version(self, format_version: TableVersion) -> Transaction:
@@ -421,6 +419,7 @@ def overwrite(
self,
df: pa.Table,
overwrite_filter: Union[BooleanExpression, str] = ALWAYS_TRUE,
case_sensitive: bool = True,
snapshot_properties: Dict[str, str] = EMPTY_DICT,
) -> None:
"""
@@ -436,6 +435,7 @@ def overwrite(
df: The Arrow dataframe that will be used to overwrite the table
overwrite_filter: ALWAYS_TRUE when you overwrite all the data,
or a boolean expression in case of a partial overwrite
case_sensitive: A bool that determines whether the provided `overwrite_filter` is case-sensitive
snapshot_properties: Custom properties to be added to the snapshot summary
"""
try:
@@ -459,7 +459,7 @@
self.table_metadata.schema(), provided_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
)

self.delete(delete_filter=overwrite_filter, snapshot_properties=snapshot_properties)
self.delete(delete_filter=overwrite_filter, case_sensitive=case_sensitive, snapshot_properties=snapshot_properties)

with self.update_snapshot(snapshot_properties=snapshot_properties).fast_append() as update_snapshot:
# skip writing data files if the dataframe is empty
@@ -470,17 +470,23 @@
for data_file in data_files:
update_snapshot.append_data_file(data_file)

def delete(self, delete_filter: Union[str, BooleanExpression], snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None:
def delete(
self,
delete_filter: Union[str, BooleanExpression],
case_sensitive: bool = True,
snapshot_properties: Dict[str, str] = EMPTY_DICT,
) -> None:
"""
Shorthand for deleting records from a table.

An deletee may produce zero or more snapshots based on the operation:
A delete may produce zero or more snapshots based on the operation:

- DELETE: In case existing Parquet files can be dropped completely.
- REPLACE: In case existing Parquet files need to be rewritten

Args:
delete_filter: A boolean expression to delete rows from a table
case_sensitive: A bool that determines whether the provided `delete_filter` is case-sensitive
snapshot_properties: Custom properties to be added to the snapshot summary
"""
from pyiceberg.io.pyarrow import (
@@ -499,14 +505,14 @@ def delete(self, delete_filter: Union[str, BooleanExpression], snapshot_properti
delete_filter = _parse_row_filter(delete_filter)

with self.update_snapshot(snapshot_properties=snapshot_properties).delete() as delete_snapshot:
delete_snapshot.delete_by_predicate(delete_filter)
delete_snapshot.delete_by_predicate(delete_filter, case_sensitive)

# Check if there are any files that require an actual rewrite of a data file
if delete_snapshot.rewrites_needed is True:
bound_delete_filter = bind(self.table_metadata.schema(), delete_filter, case_sensitive=True)
bound_delete_filter = bind(self.table_metadata.schema(), delete_filter, case_sensitive)
preserve_row_filter = _expression_to_complementary_pyarrow(bound_delete_filter)

files = self._scan(row_filter=delete_filter).plan_files()
files = self._scan(row_filter=delete_filter, case_sensitive=case_sensitive).plan_files()

commit_uuid = uuid.uuid4()
counter = itertools.count(0)
@@ -987,6 +993,7 @@ def overwrite(
self,
df: pa.Table,
overwrite_filter: Union[BooleanExpression, str] = ALWAYS_TRUE,
case_sensitive: bool = True,
snapshot_properties: Dict[str, str] = EMPTY_DICT,
) -> None:
"""
@@ -1002,23 +1009,30 @@
df: The Arrow dataframe that will be used to overwrite the table
overwrite_filter: ALWAYS_TRUE when you overwrite all the data,
or a boolean expression in case of a partial overwrite
case_sensitive: A bool that determines whether the provided `overwrite_filter` is case-sensitive
snapshot_properties: Custom properties to be added to the snapshot summary
"""
with self.transaction() as tx:
tx.overwrite(df=df, overwrite_filter=overwrite_filter, snapshot_properties=snapshot_properties)
tx.overwrite(
df=df, overwrite_filter=overwrite_filter, case_sensitive=case_sensitive, snapshot_properties=snapshot_properties
)

def delete(
self, delete_filter: Union[BooleanExpression, str] = ALWAYS_TRUE, snapshot_properties: Dict[str, str] = EMPTY_DICT
self,
delete_filter: Union[BooleanExpression, str] = ALWAYS_TRUE,
case_sensitive: bool = True,
snapshot_properties: Dict[str, str] = EMPTY_DICT,
) -> None:
"""
Shorthand for deleting rows from the table.

Args:
delete_filter: The predicate used to remove rows
case_sensitive: A bool that determines whether the provided `delete_filter` is case-sensitive
snapshot_properties: Custom properties to be added to the snapshot summary
"""
with self.transaction() as tx:
tx.delete(delete_filter=delete_filter, snapshot_properties=snapshot_properties)
tx.delete(delete_filter=delete_filter, case_sensitive=case_sensitive, snapshot_properties=snapshot_properties)

def add_files(
self, file_paths: List[str], snapshot_properties: Dict[str, str] = EMPTY_DICT, check_duplicate_files: bool = True
@@ -1311,7 +1325,7 @@ def _match_deletes_to_data_file(data_entry: ManifestEntry, positional_delete_ent

class DataScan(TableScan):
def _build_partition_projection(self, spec_id: int) -> BooleanExpression:
project = inclusive_projection(self.table_metadata.schema(), self.table_metadata.specs()[spec_id])
project = inclusive_projection(self.table_metadata.schema(), self.table_metadata.specs()[spec_id], self.case_sensitive)
return project(self.row_filter)

@cached_property
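For context, a minimal usage sketch of the new flag on the public Table API. This is illustrative only: the catalog and table names are hypothetical, and the columns mirror the integration tests further down this diff.

import pyarrow as pa

from pyiceberg.catalog import load_catalog

# Hypothetical names: a configured catalog "default" holding a table with
# a long column "idx" and a string column "value".
table = load_catalog("default").load_table("default.my_table")

# The default (case_sensitive=True) keeps today's behaviour: the filter
# must match the column name's case exactly.
table.delete("idx == 2")

# With case_sensitive=False, "Idx" binds to the "idx" field instead of
# raising ValueError("Could not find field with name Idx").
table.delete("Idx == 3", case_sensitive=False)

# overwrite() forwards the same flag to its internal delete and table scan.
new_rows = pa.Table.from_arrays([pa.array([10]), pa.array(["x"])], names=["idx", "value"])
table.overwrite(df=new_rows, overwrite_filter="IDX == 4", case_sensitive=False)

Since every new parameter defaults to True, existing callers keep their current case-sensitive behaviour.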
15 changes: 10 additions & 5 deletions pyiceberg/table/update/snapshot.py
@@ -318,6 +318,7 @@ class _DeleteFiles(_SnapshotProducer["_DeleteFiles"]):
"""

_predicate: BooleanExpression
_case_sensitive: bool

def __init__(
self,
@@ -329,6 +330,7 @@
):
super().__init__(operation, transaction, io, commit_uuid, snapshot_properties)
self._predicate = AlwaysFalse()
self._case_sensitive = True

def _commit(self) -> UpdatesAndRequirements:
# Only produce a commit when there is something to delete
@@ -340,7 +342,7 @@ def _commit(self) -> UpdatesAndRequirements:
def _build_partition_projection(self, spec_id: int) -> BooleanExpression:
schema = self._transaction.table_metadata.schema()
spec = self._transaction.table_metadata.specs()[spec_id]
project = inclusive_projection(schema, spec)
project = inclusive_projection(schema, spec, self._case_sensitive)
return project(self._predicate)

@cached_property
@@ -350,10 +352,11 @@ def partition_filters(self) -> KeyDefaultDict[int, BooleanExpression]:
def _build_manifest_evaluator(self, spec_id: int) -> Callable[[ManifestFile], bool]:
schema = self._transaction.table_metadata.schema()
spec = self._transaction.table_metadata.specs()[spec_id]
return manifest_evaluator(spec, schema, self.partition_filters[spec_id], case_sensitive=True)
return manifest_evaluator(spec, schema, self.partition_filters[spec_id], self._case_sensitive)

def delete_by_predicate(self, predicate: BooleanExpression) -> None:
def delete_by_predicate(self, predicate: BooleanExpression, case_sensitive: bool = True) -> None:
self._predicate = Or(self._predicate, predicate)
self._case_sensitive = case_sensitive

@cached_property
def _compute_deletes(self) -> Tuple[List[ManifestFile], List[ManifestEntry], bool]:
@@ -376,8 +379,10 @@ def _copy_with_new_status(entry: ManifestEntry, status: ManifestEntryStatus) ->
)

manifest_evaluators: Dict[int, Callable[[ManifestFile], bool]] = KeyDefaultDict(self._build_manifest_evaluator)
strict_metrics_evaluator = _StrictMetricsEvaluator(schema, self._predicate, case_sensitive=True).eval
inclusive_metrics_evaluator = _InclusiveMetricsEvaluator(schema, self._predicate, case_sensitive=True).eval
strict_metrics_evaluator = _StrictMetricsEvaluator(schema, self._predicate, case_sensitive=self._case_sensitive).eval
inclusive_metrics_evaluator = _InclusiveMetricsEvaluator(
schema, self._predicate, case_sensitive=self._case_sensitive
).eval

existing_manifests = []
total_deleted_entries = []
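One level down, _DeleteFiles now remembers the flag next to its predicate and consults it in _compute_deletes when building the partition projection, the manifest evaluators, and both metrics evaluators. A sketch of that path, with the same hypothetical table as above:

from pyiceberg.catalog import load_catalog
from pyiceberg.expressions import EqualTo

# Hypothetical names, as in the previous sketch.
table = load_catalog("default").load_table("default.my_table")

with table.transaction() as tx:
    with tx.update_snapshot().delete() as delete_snapshot:
        # The flag is only stored here; binding happens later, when
        # _compute_deletes evaluates manifests and file metrics.
        delete_snapshot.delete_by_predicate(EqualTo("Idx", 2), case_sensitive=False)

As the delete() docstring above says, this can yield a DELETE snapshot (the strict metrics evaluator proves a whole file matches and drops it) and/or a REPLACE snapshot (the inclusive evaluator flags files that may contain matching rows and must be rewritten).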
87 changes: 85 additions & 2 deletions tests/integration/test_deletes.py
@@ -16,7 +16,7 @@
# under the License.
# pylint:disable=redefined-outer-name
from datetime import datetime
from typing import List
from typing import Generator, List

import pyarrow as pa
import pytest
@@ -28,16 +28,35 @@
from pyiceberg.manifest import ManifestEntryStatus
from pyiceberg.partitioning import PartitionField, PartitionSpec
from pyiceberg.schema import Schema
from pyiceberg.table import Table
from pyiceberg.table.snapshots import Operation, Summary
from pyiceberg.transforms import IdentityTransform
from pyiceberg.types import FloatType, IntegerType, LongType, NestedField, TimestampType
from pyiceberg.types import FloatType, IntegerType, LongType, NestedField, StringType, TimestampType


def run_spark_commands(spark: SparkSession, sqls: List[str]) -> None:
for sql in sqls:
spark.sql(sql)


@pytest.fixture()
def test_table(session_catalog: RestCatalog) -> Generator[Table, None, None]:
identifier = "default.__test_table"
arrow_table = pa.Table.from_arrays([pa.array([1, 2, 3, 4, 5]), pa.array(["a", "b", "c", "d", "e"])], names=["idx", "value"])
test_table = session_catalog.create_table(
identifier,
schema=Schema(
NestedField(1, "idx", LongType()),
NestedField(2, "value", StringType()),
),
)
test_table.append(arrow_table)

yield test_table

session_catalog.drop_table(identifier)


@pytest.mark.integration
@pytest.mark.parametrize("format_version", [1, 2])
def test_partitioned_table_delete_full_file(spark: SparkSession, session_catalog: RestCatalog, format_version: int) -> None:
@@ -770,3 +789,67 @@ def test_delete_after_partition_evolution_from_partitioned(session_catalog: Rest

# Expect 8 records: 10 records - 2
assert len(tbl.scan().to_arrow()) == 8


@pytest.mark.integration
def test_delete_with_filter_case_sensitive(test_table: Table) -> None:
assert {"idx": 2, "value": "b"} in test_table.scan().to_arrow().to_pylist()

with pytest.raises(ValueError) as e:
test_table.delete("Idx == 2", case_sensitive=True)
assert "Could not find field with name Idx" in str(e.value)
assert {"idx": 2, "value": "b"} in test_table.scan().to_arrow().to_pylist()

test_table.delete("idx == 2", case_sensitive=True)
assert {"idx": 2, "value": "b"} not in test_table.scan().to_arrow().to_pylist()


@pytest.mark.integration
def test_delete_with_filter_case_insensitive(test_table: Table) -> None:
assert {"idx": 2, "value": "b"} in test_table.scan().to_arrow().to_pylist()

test_table.delete("Idx == 2", case_sensitive=False)
assert {"idx": 2, "value": "b"} not in test_table.scan().to_arrow().to_pylist()

test_table.delete("idx == 3", case_sensitive=False)
assert {"idx": 3, "value": "c"} not in test_table.scan().to_arrow().to_pylist()


@pytest.mark.integration
def test_overwrite_with_filter_case_sensitive(test_table: Table) -> None:
assert {"idx": 2, "value": "b"} in test_table.scan().to_arrow().to_pylist()

new_table = pa.Table.from_arrays(
[
pa.array([10]),
pa.array(["x"]),
],
names=["idx", "value"],
)

with pytest.raises(ValueError) as e:
test_table.overwrite(df=new_table, overwrite_filter="Idx == 2", case_sensitive=True)
assert "Could not find field with name Idx" in str(e.value)
assert {"idx": 2, "value": "b"} in test_table.scan().to_arrow().to_pylist()
assert {"idx": 10, "value": "x"} not in test_table.scan().to_arrow().to_pylist()

test_table.overwrite(df=new_table, overwrite_filter="idx == 2", case_sensitive=True)
assert {"idx": 2, "value": "b"} not in test_table.scan().to_arrow().to_pylist()
assert {"idx": 10, "value": "x"} in test_table.scan().to_arrow().to_pylist()


@pytest.mark.integration
def test_overwrite_with_filter_case_insensitive(test_table: Table) -> None:
assert {"idx": 2, "value": "b"} in test_table.scan().to_arrow().to_pylist()

new_table = pa.Table.from_arrays(
[
pa.array([10]),
pa.array(["x"]),
],
names=["idx", "value"],
)

test_table.overwrite(df=new_table, overwrite_filter="Idx == 2", case_sensitive=False)
assert {"idx": 2, "value": "b"} not in test_table.scan().to_arrow().to_pylist()
assert {"idx": 10, "value": "x"} in test_table.scan().to_arrow().to_pylist()
29 changes: 29 additions & 0 deletions tests/integration/test_reads.py
@@ -621,6 +621,35 @@ def test_filter_on_new_column(catalog: Catalog) -> None:
assert arrow_table["b"].to_pylist() == [None]


@pytest.mark.integration
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
def test_filter_case_sensitive(catalog: Catalog) -> None:
test_table_add_column = catalog.load_table("default.test_table_add_column")
arrow_table = test_table_add_column.scan().to_arrow()
assert "2" in arrow_table["b"].to_pylist()

arrow_table = test_table_add_column.scan(row_filter="b == '2'", case_sensitive=True).to_arrow()
assert arrow_table["b"].to_pylist() == ["2"]

with pytest.raises(ValueError) as e:
_ = test_table_add_column.scan(row_filter="B == '2'", case_sensitive=True).to_arrow()
assert "Could not find field with name B" in str(e.value)


@pytest.mark.integration
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
def test_filter_case_insensitive(catalog: Catalog) -> None:
test_table_add_column = catalog.load_table("default.test_table_add_column")
arrow_table = test_table_add_column.scan().to_arrow()
assert "2" in arrow_table["b"].to_pylist()

arrow_table = test_table_add_column.scan(row_filter="b == '2'", case_sensitive=False).to_arrow()
assert arrow_table["b"].to_pylist() == ["2"]

arrow_table = test_table_add_column.scan(row_filter="B == '2'", case_sensitive=False).to_arrow()
assert arrow_table["b"].to_pylist() == ["2"]


@pytest.mark.integration
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
def test_upgrade_table_version(catalog: Catalog) -> None:
13 changes: 13 additions & 0 deletions tests/table/test_init.py
@@ -310,6 +310,19 @@ def test_table_scan_row_filter(table_v2: Table) -> None:
assert scan.filter(EqualTo("x", 10)).filter(In("y", (10, 11))).row_filter == And(EqualTo("x", 10), In("y", (10, 11)))


def test_table_scan_partition_filters_case_sensitive(table_v2: Table) -> None:
scan = table_v2.scan(row_filter=EqualTo("X", 10), case_sensitive=True)
with pytest.raises(ValueError):
for i in range(len(table_v2.metadata.specs())):
_ = scan.partition_filters[i]


def test_table_scan_partition_filters_case_insensitive(table_v2: Table) -> None:
scan = table_v2.scan(row_filter=EqualTo("X", 10), case_sensitive=False)
for i in range(len(table_v2.metadata.specs())):
_ = scan.partition_filters[i]


def test_table_scan_ref(table_v2: Table) -> None:
scan = table_v2.scan()
assert scan.use_ref("test").snapshot_id == 3051729675574597004
20 changes: 20 additions & 0 deletions tests/table/test_snapshots.py
@@ -17,9 +17,12 @@
# pylint:disable=redefined-outer-name,eval-used
import pytest

from pyiceberg.expressions import EqualTo
from pyiceberg.manifest import DataFile, DataFileContent, ManifestContent, ManifestFile
from pyiceberg.partitioning import PartitionField, PartitionSpec
from pyiceberg.schema import Schema
from pyiceberg.table import Table
from pyiceberg.table.metadata import TableMetadataV1, TableMetadataV2
from pyiceberg.table.snapshots import Operation, Snapshot, SnapshotSummaryCollector, Summary, update_snapshot_summaries
from pyiceberg.transforms import IdentityTransform
from pyiceberg.typedef import Record
@@ -341,3 +344,20 @@ def test_invalid_type() -> None:
)

assert "Could not parse summary property total-data-files to an int: abc" in str(e.value)


@pytest.mark.parametrize("case_sensitive", [True, False])
def test_delete_table_rows_case_sensitive(
case_sensitive: bool, table_v1: Table, table_v2: Table, monkeypatch: pytest.MonkeyPatch
) -> None:
monkeypatch.setattr(TableMetadataV1, "current_snapshot", lambda _: None)
monkeypatch.setattr(TableMetadataV2, "current_snapshot", lambda _: None)
for table in [table_v1, table_v2]:
delete_file = table.transaction().update_snapshot().delete()
delete_file.delete_by_predicate(predicate=EqualTo("X", 10), case_sensitive=case_sensitive)
if case_sensitive:
with pytest.raises(ValueError) as e:
_ = delete_file._compute_deletes
assert "Could not find field with name X" in str(e.value)
else:
_ = delete_file._compute_deletes