Skip to content

Commit b447461

Browse files
authored
Add StrictMetricsEvaluator (#518)
* Add StrictMetricsEvaluator This will enable us to delete whole data files by evaluating the metrics, without needing to open the Parquet files. * Split out field check
1 parent 781096e commit b447461

File tree

2 files changed

+847
-34
lines changed

2 files changed

+847
-34
lines changed

pyiceberg/expressions/visitors.py

Lines changed: 319 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
DoubleType,
6868
FloatType,
6969
IcebergType,
70+
NestedField,
7071
PrimitiveType,
7172
StructType,
7273
TimestampType,
@@ -534,7 +535,9 @@ def visit_or(self, left_result: bool, right_result: bool) -> bool:
534535

535536

536537
# Boolean outcomes returned by the metrics evaluators in this module.
# ROWS_MIGHT_MATCH / ROWS_CANNOT_MATCH are the "true"/"false" results of the shared
# visit_true / visit_false handlers.
ROWS_MIGHT_MATCH = True
538+
# ROWS_MUST_MATCH / ROWS_MIGHT_NOT_MATCH are the results used by _StrictMetricsEvaluator.
# They share the same underlying booleans as the pair above.
ROWS_MUST_MATCH = True
537539
ROWS_CANNOT_MATCH = False
540+
ROWS_MIGHT_NOT_MATCH = False
538541
IN_PREDICATE_LIMIT = 200
539542

540543

@@ -1089,16 +1092,52 @@ def expression_to_plain_format(
10891092
return [visit(expression, visitor) for expression in expressions]
10901093

10911094

1092-
class _InclusiveMetricsEvaluator(BoundBooleanExpressionVisitor[bool]):
1093-
struct: StructType
1094-
expr: BooleanExpression
1095-
1095+
class _MetricsEvaluator(BoundBooleanExpressionVisitor[bool], ABC):
    """Common base for visitors that evaluate a bound expression against column metrics.

    Subclasses populate the metric dictionaries below (each keyed by field id) before
    visiting an expression; this class supplies the boolean combinators and a few
    metric helpers that the concrete evaluators share.
    """

    # Per-column metrics, keyed by field id.
    value_counts: Dict[int, int]
    null_counts: Dict[int, int]
    nan_counts: Dict[int, int]
    lower_bounds: Dict[int, bytes]
    upper_bounds: Dict[int, bytes]

    def visit_true(self) -> bool:
        """Handle the always-true expression: every row matches."""
        return ROWS_MIGHT_MATCH

    def visit_false(self) -> bool:
        """Handle the always-false expression: every row fails."""
        return ROWS_CANNOT_MATCH

    def visit_not(self, child_result: bool) -> bool:
        """Reject NOT nodes.

        Raises:
            ValueError: always, since NOT is expected to be rewritten away beforehand.
        """
        raise ValueError(f"NOT should be rewritten: {child_result}")

    def visit_and(self, left_result: bool, right_result: bool) -> bool:
        """Combine two child results with logical AND."""
        if not left_result:
            return left_result
        return right_result

    def visit_or(self, left_result: bool, right_result: bool) -> bool:
        """Combine two child results with logical OR."""
        if left_result:
            return left_result
        return right_result

    def _contains_nulls_only(self, field_id: int) -> bool:
        """Return True when the value count and null count are both known, non-zero and equal."""
        total = self.value_counts.get(field_id)
        if not total:
            return False
        nulls = self.null_counts.get(field_id)
        if not nulls:
            return False
        return total == nulls

    def _contains_nans_only(self, field_id: int) -> bool:
        """Return True when the NaN count and value count are both known, non-zero and equal."""
        nans = self.nan_counts.get(field_id)
        if not nans:
            return False
        total = self.value_counts.get(field_id)
        if not total:
            return False
        return nans == total

    def _is_nan(self, val: Any) -> bool:
        """Return True if ``val`` is NaN; None and other non-numeric values yield False."""
        try:
            outcome = math.isnan(val)
        except TypeError:
            # In the case of None or other non-numeric types
            return False
        return outcome
1135+
1136+
1137+
class _InclusiveMetricsEvaluator(_MetricsEvaluator):
1138+
struct: StructType
1139+
expr: BooleanExpression
1140+
11021141
def __init__(
11031142
self, schema: Schema, expr: BooleanExpression, case_sensitive: bool = True, include_empty_files: bool = False
11041143
) -> None:
@@ -1128,40 +1167,11 @@ def eval(self, file: DataFile) -> bool:
11281167
def _may_contain_null(self, field_id: int) -> bool:
11291168
return self.null_counts is None or (field_id in self.null_counts and self.null_counts.get(field_id) is not None)
11301169

1131-
def _contains_nulls_only(self, field_id: int) -> bool:
1132-
if (value_count := self.value_counts.get(field_id)) and (null_count := self.null_counts.get(field_id)):
1133-
return value_count == null_count
1134-
return False
1135-
11361170
def _contains_nans_only(self, field_id: int) -> bool:
11371171
if (nan_count := self.nan_counts.get(field_id)) and (value_count := self.value_counts.get(field_id)):
11381172
return nan_count == value_count
11391173
return False
11401174

1141-
def _is_nan(self, val: Any) -> bool:
1142-
try:
1143-
return math.isnan(val)
1144-
except TypeError:
1145-
# In the case of None or other non-numeric types
1146-
return False
1147-
1148-
def visit_true(self) -> bool:
1149-
# all rows match
1150-
return ROWS_MIGHT_MATCH
1151-
1152-
def visit_false(self) -> bool:
1153-
# all rows fail
1154-
return ROWS_CANNOT_MATCH
1155-
1156-
def visit_not(self, child_result: bool) -> bool:
1157-
raise ValueError(f"NOT should be rewritten: {child_result}")
1158-
1159-
def visit_and(self, left_result: bool, right_result: bool) -> bool:
1160-
return left_result and right_result
1161-
1162-
def visit_or(self, left_result: bool, right_result: bool) -> bool:
1163-
return left_result or right_result
1164-
11651175
def visit_is_null(self, term: BoundTerm[L]) -> bool:
11661176
field_id = term.ref().field.field_id
11671177

@@ -1421,3 +1431,279 @@ def visit_not_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> bool
14211431
return ROWS_CANNOT_MATCH
14221432

14231433
return ROWS_MIGHT_MATCH
1434+
1435+
1436+
class _StrictMetricsEvaluator(_MetricsEvaluator):
    """Decide from column metrics alone whether every row of a data file matches an expression.

    Each predicate visitor returns ROWS_MUST_MATCH only when the file's metrics (value,
    null and NaN counts plus lower/upper bounds) prove that all rows satisfy the
    predicate, and ROWS_MIGHT_NOT_MATCH in every other case.

    A lower bound that decodes to NaN is treated as unreliable: predicates that would
    rely on it return ROWS_MIGHT_NOT_MATCH instead of comparing against it.
    """

    struct: StructType
    expr: BooleanExpression

    def __init__(
        self, schema: Schema, expr: BooleanExpression, case_sensitive: bool = True, include_empty_files: bool = False
    ) -> None:
        """Bind ``expr`` to ``schema`` after rewriting NOT nodes away.

        Args:
            schema: Schema the expression is bound against.
            expr: Unbound boolean expression to evaluate.
            case_sensitive: Passed through to ``bind``.
            include_empty_files: Stored on the instance; not consulted by this class.
        """
        self.struct = schema.as_struct()
        self.include_empty_files = include_empty_files
        self.expr = bind(schema, rewrite_not(expr), case_sensitive)

    def eval(self, file: DataFile) -> bool:
        """Test whether all records within the file match the expression.

        Args:
            file: A data file

        Returns: false if the file may contain any row that doesn't match
            the expression, true otherwise.
        """
        if file.record_count <= 0:
            # Older version don't correctly implement record count from avro file and thus
            # set record count -1 when importing avro tables to iceberg tables. This should
            # be updated once we implemented and set correct record count.
            return ROWS_MUST_MATCH

        # Missing metric maps are replaced by an empty mapping so the visitors can use .get()
        self.value_counts = file.value_counts or EMPTY_DICT
        self.null_counts = file.null_value_counts or EMPTY_DICT
        self.nan_counts = file.nan_value_counts or EMPTY_DICT
        self.lower_bounds = file.lower_bounds or EMPTY_DICT
        self.upper_bounds = file.upper_bounds or EMPTY_DICT

        return visit(self.expr, self)

    def visit_is_null(self, term: BoundTerm[L]) -> bool:
        # no need to check whether the field is required because binding evaluates that case
        # if the column has any non-null values, the expression does not match
        field_id = term.ref().field.field_id

        if self._contains_nulls_only(field_id):
            return ROWS_MUST_MATCH

        return ROWS_MIGHT_NOT_MATCH

    def visit_not_null(self, term: BoundTerm[L]) -> bool:
        # no need to check whether the field is required because binding evaluates that case
        # if the column has any null values, the expression does not match
        field_id = term.ref().field.field_id

        if (null_count := self.null_counts.get(field_id)) is not None and null_count == 0:
            return ROWS_MUST_MATCH

        return ROWS_MIGHT_NOT_MATCH

    def visit_is_nan(self, term: BoundTerm[L]) -> bool:
        field_id = term.ref().field.field_id

        if self._contains_nans_only(field_id):
            return ROWS_MUST_MATCH

        return ROWS_MIGHT_NOT_MATCH

    def visit_not_nan(self, term: BoundTerm[L]) -> bool:
        field_id = term.ref().field.field_id

        # A known NaN count of zero proves no row is NaN
        if (nan_count := self.nan_counts.get(field_id)) is not None and nan_count == 0:
            return ROWS_MUST_MATCH

        # A column holding only nulls cannot hold a NaN either
        if self._contains_nulls_only(field_id):
            return ROWS_MUST_MATCH

        return ROWS_MIGHT_NOT_MATCH

    def visit_less_than(self, term: BoundTerm[L], literal: Literal[L]) -> bool:
        # Rows must match when: <----------Min----Max---X------->

        field_id = term.ref().field.field_id

        # A null or NaN row would not satisfy the comparison
        if self._can_contain_nulls(field_id) or self._can_contain_nans(field_id):
            return ROWS_MIGHT_NOT_MATCH

        if upper_bytes := self.upper_bounds.get(field_id):
            field = self._get_field(field_id)
            upper = _from_byte_buffer(field.field_type, upper_bytes)

            if upper < literal.value:
                return ROWS_MUST_MATCH

        return ROWS_MIGHT_NOT_MATCH

    def visit_less_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool:
        # Rows must match when: <----------Min----Max---X------->

        field_id = term.ref().field.field_id

        if self._can_contain_nulls(field_id) or self._can_contain_nans(field_id):
            return ROWS_MIGHT_NOT_MATCH

        if upper_bytes := self.upper_bounds.get(field_id):
            field = self._get_field(field_id)
            upper = _from_byte_buffer(field.field_type, upper_bytes)

            if upper <= literal.value:
                return ROWS_MUST_MATCH

        return ROWS_MIGHT_NOT_MATCH

    def visit_greater_than(self, term: BoundTerm[L], literal: Literal[L]) -> bool:
        # Rows must match when: <-------X---Min----Max---------->

        field_id = term.ref().field.field_id

        if self._can_contain_nulls(field_id) or self._can_contain_nans(field_id):
            return ROWS_MIGHT_NOT_MATCH

        if lower_bytes := self.lower_bounds.get(field_id):
            field = self._get_field(field_id)
            lower = _from_byte_buffer(field.field_type, lower_bytes)

            if self._is_nan(lower):
                # NaN indicates unreliable bounds.
                # See the _StrictMetricsEvaluator docs for more.
                return ROWS_MIGHT_NOT_MATCH

            if lower > literal.value:
                return ROWS_MUST_MATCH

        return ROWS_MIGHT_NOT_MATCH

    def visit_greater_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool:
        # Rows must match when: <-------X---Min----Max---------->
        field_id = term.ref().field.field_id

        if self._can_contain_nulls(field_id) or self._can_contain_nans(field_id):
            return ROWS_MIGHT_NOT_MATCH

        if lower_bytes := self.lower_bounds.get(field_id):
            field = self._get_field(field_id)
            lower = _from_byte_buffer(field.field_type, lower_bytes)

            if self._is_nan(lower):
                # NaN indicates unreliable bounds.
                # See the _StrictMetricsEvaluator docs for more.
                return ROWS_MIGHT_NOT_MATCH

            if lower >= literal.value:
                return ROWS_MUST_MATCH

        return ROWS_MIGHT_NOT_MATCH

    def visit_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool:
        # Rows must match when Min == X == Max
        field_id = term.ref().field.field_id

        if self._can_contain_nulls(field_id) or self._can_contain_nans(field_id):
            return ROWS_MIGHT_NOT_MATCH

        if (lower_bytes := self.lower_bounds.get(field_id)) and (upper_bytes := self.upper_bounds.get(field_id)):
            field = self._get_field(field_id)
            lower = _from_byte_buffer(field.field_type, lower_bytes)
            upper = _from_byte_buffer(field.field_type, upper_bytes)

            if lower != literal.value or upper != literal.value:
                return ROWS_MIGHT_NOT_MATCH

            return ROWS_MUST_MATCH

        return ROWS_MIGHT_NOT_MATCH

    def visit_not_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool:
        # Rows must match when X < Min or Max < X because it is not in the range
        field_id = term.ref().field.field_id

        # Only an all-null or all-NaN column proves the match outright. A column that merely
        # contains some nulls/NaNs may still hold a value equal to the literal, so that case
        # has to fall through to the bounds checks below.
        if self._contains_nulls_only(field_id) or self._contains_nans_only(field_id):
            return ROWS_MUST_MATCH

        field = self._get_field(field_id)

        if lower_bytes := self.lower_bounds.get(field_id):
            lower = _from_byte_buffer(field.field_type, lower_bytes)

            if self._is_nan(lower):
                # NaN indicates unreliable bounds.
                # See the _StrictMetricsEvaluator docs for more.
                return ROWS_MIGHT_NOT_MATCH

            if lower > literal.value:
                return ROWS_MUST_MATCH

        if upper_bytes := self.upper_bounds.get(field_id):
            upper = _from_byte_buffer(field.field_type, upper_bytes)

            if upper < literal.value:
                return ROWS_MUST_MATCH

        return ROWS_MIGHT_NOT_MATCH

    def visit_in(self, term: BoundTerm[L], literals: Set[L]) -> bool:
        field_id = term.ref().field.field_id

        if self._can_contain_nulls(field_id) or self._can_contain_nans(field_id):
            return ROWS_MIGHT_NOT_MATCH

        field = self._get_field(field_id)

        if (lower_bytes := self.lower_bounds.get(field_id)) and (upper_bytes := self.upper_bounds.get(field_id)):
            # similar to the implementation in eq, first check if the lower bound is in the set
            lower = _from_byte_buffer(field.field_type, lower_bytes)
            if lower not in literals:
                return ROWS_MIGHT_NOT_MATCH

            # check if the upper bound is in the set
            upper = _from_byte_buffer(field.field_type, upper_bytes)
            if upper not in literals:
                return ROWS_MIGHT_NOT_MATCH

            # finally check if the lower bound and the upper bound are equal
            if lower != upper:
                return ROWS_MIGHT_NOT_MATCH

            # All values must be in the set if the lower bound and the upper bound are
            # in the set and are equal.
            return ROWS_MUST_MATCH

        return ROWS_MIGHT_NOT_MATCH

    def visit_not_in(self, term: BoundTerm[L], literals: Set[L]) -> bool:
        field_id = term.ref().field.field_id

        # Only an all-null or all-NaN column proves the match outright. A column that merely
        # contains some nulls/NaNs may still hold one of the literals, so that case has to
        # fall through to the bounds checks below.
        if self._contains_nulls_only(field_id) or self._contains_nans_only(field_id):
            return ROWS_MUST_MATCH

        field = self._get_field(field_id)

        # Literals that could still be present in the file; narrowed by each known bound.
        candidates = literals

        if lower_bytes := self.lower_bounds.get(field_id):
            lower = _from_byte_buffer(field.field_type, lower_bytes)

            if self._is_nan(lower):
                # NaN indicates unreliable bounds.
                # See the _StrictMetricsEvaluator docs for more.
                return ROWS_MIGHT_NOT_MATCH

            # Drop literals below the lower bound; they cannot occur in the file
            candidates = {val for val in candidates if lower <= val}
            if len(candidates) == 0:
                return ROWS_MUST_MATCH

        if upper_bytes := self.upper_bounds.get(field_id):
            upper = _from_byte_buffer(field.field_type, upper_bytes)

            # Drop literals above the upper bound; they cannot occur in the file
            candidates = {val for val in candidates if upper >= val}

            if len(candidates) == 0:
                return ROWS_MUST_MATCH

        return ROWS_MIGHT_NOT_MATCH

    def visit_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> bool:
        # No metrics-based proof is attempted for prefix predicates
        return ROWS_MIGHT_NOT_MATCH

    def visit_not_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> bool:
        # No metrics-based proof is attempted for prefix predicates
        return ROWS_MIGHT_NOT_MATCH

    def _get_field(self, field_id: int) -> NestedField:
        """Look up a field of the bound struct by id.

        Raises:
            ValueError: if the struct has no field with this id.
        """
        field = self.struct.field(field_id=field_id)
        if field is None:
            raise ValueError(f"Cannot find field, might be nested or missing: {field_id}")

        return field

    def _can_contain_nulls(self, field_id: int) -> bool:
        """Return True when a null count is recorded for the field and it is positive."""
        return (null_count := self.null_counts.get(field_id)) is not None and null_count > 0

    def _can_contain_nans(self, field_id: int) -> bool:
        """Return True when a NaN count is recorded for the field and it is positive."""
        return (nan_count := self.nan_counts.get(field_id)) is not None and nan_count > 0

0 commit comments

Comments
 (0)