|
67 | 67 | DoubleType,
|
68 | 68 | FloatType,
|
69 | 69 | IcebergType,
|
| 70 | + NestedField, |
70 | 71 | PrimitiveType,
|
71 | 72 | StructType,
|
72 | 73 | TimestampType,
|
@@ -534,7 +535,9 @@ def visit_or(self, left_result: bool, right_result: bool) -> bool:
|
534 | 535 |
|
535 | 536 |
|
536 | 537 | ROWS_MIGHT_MATCH = True
|
| 538 | +ROWS_MUST_MATCH = True |
537 | 539 | ROWS_CANNOT_MATCH = False
|
| 540 | +ROWS_MIGHT_NOT_MATCH = False |
538 | 541 | IN_PREDICATE_LIMIT = 200
|
539 | 542 |
|
540 | 543 |
|
@@ -1089,16 +1092,52 @@ def expression_to_plain_format(
|
1089 | 1092 | return [visit(expression, visitor) for expression in expressions]
|
1090 | 1093 |
|
1091 | 1094 |
|
1092 |
| -class _InclusiveMetricsEvaluator(BoundBooleanExpressionVisitor[bool]): |
1093 |
| - struct: StructType |
1094 |
| - expr: BooleanExpression |
1095 |
| - |
| 1095 | +class _MetricsEvaluator(BoundBooleanExpressionVisitor[bool], ABC): |
1096 | 1096 | value_counts: Dict[int, int]
|
1097 | 1097 | null_counts: Dict[int, int]
|
1098 | 1098 | nan_counts: Dict[int, int]
|
1099 | 1099 | lower_bounds: Dict[int, bytes]
|
1100 | 1100 | upper_bounds: Dict[int, bytes]
|
1101 | 1101 |
|
| 1102 | + def visit_true(self) -> bool: |
| 1103 | + # all rows match |
| 1104 | + return ROWS_MIGHT_MATCH |
| 1105 | + |
| 1106 | + def visit_false(self) -> bool: |
| 1107 | + # all rows fail |
| 1108 | + return ROWS_CANNOT_MATCH |
| 1109 | + |
| 1110 | + def visit_not(self, child_result: bool) -> bool: |
| 1111 | + raise ValueError(f"NOT should be rewritten: {child_result}") |
| 1112 | + |
| 1113 | + def visit_and(self, left_result: bool, right_result: bool) -> bool: |
| 1114 | + return left_result and right_result |
| 1115 | + |
| 1116 | + def visit_or(self, left_result: bool, right_result: bool) -> bool: |
| 1117 | + return left_result or right_result |
| 1118 | + |
| 1119 | + def _contains_nulls_only(self, field_id: int) -> bool: |
| 1120 | + if (value_count := self.value_counts.get(field_id)) and (null_count := self.null_counts.get(field_id)): |
| 1121 | + return value_count == null_count |
| 1122 | + return False |
| 1123 | + |
| 1124 | + def _contains_nans_only(self, field_id: int) -> bool: |
| 1125 | + if (nan_count := self.nan_counts.get(field_id)) and (value_count := self.value_counts.get(field_id)): |
| 1126 | + return nan_count == value_count |
| 1127 | + return False |
| 1128 | + |
| 1129 | + def _is_nan(self, val: Any) -> bool: |
| 1130 | + try: |
| 1131 | + return math.isnan(val) |
| 1132 | + except TypeError: |
| 1133 | + # In the case of None or other non-numeric types |
| 1134 | + return False |
| 1135 | + |
| 1136 | + |
| 1137 | +class _InclusiveMetricsEvaluator(_MetricsEvaluator): |
| 1138 | + struct: StructType |
| 1139 | + expr: BooleanExpression |
| 1140 | + |
1102 | 1141 | def __init__(
|
1103 | 1142 | self, schema: Schema, expr: BooleanExpression, case_sensitive: bool = True, include_empty_files: bool = False
|
1104 | 1143 | ) -> None:
|
@@ -1128,40 +1167,11 @@ def eval(self, file: DataFile) -> bool:
|
1128 | 1167 | def _may_contain_null(self, field_id: int) -> bool:
|
1129 | 1168 | return self.null_counts is None or (field_id in self.null_counts and self.null_counts.get(field_id) is not None)
|
1130 | 1169 |
|
1131 |
| - def _contains_nulls_only(self, field_id: int) -> bool: |
1132 |
| - if (value_count := self.value_counts.get(field_id)) and (null_count := self.null_counts.get(field_id)): |
1133 |
| - return value_count == null_count |
1134 |
| - return False |
1135 |
| - |
1136 | 1170 | def _contains_nans_only(self, field_id: int) -> bool:
|
1137 | 1171 | if (nan_count := self.nan_counts.get(field_id)) and (value_count := self.value_counts.get(field_id)):
|
1138 | 1172 | return nan_count == value_count
|
1139 | 1173 | return False
|
1140 | 1174 |
|
1141 |
| - def _is_nan(self, val: Any) -> bool: |
1142 |
| - try: |
1143 |
| - return math.isnan(val) |
1144 |
| - except TypeError: |
1145 |
| - # In the case of None or other non-numeric types |
1146 |
| - return False |
1147 |
| - |
1148 |
| - def visit_true(self) -> bool: |
1149 |
| - # all rows match |
1150 |
| - return ROWS_MIGHT_MATCH |
1151 |
| - |
1152 |
| - def visit_false(self) -> bool: |
1153 |
| - # all rows fail |
1154 |
| - return ROWS_CANNOT_MATCH |
1155 |
| - |
1156 |
| - def visit_not(self, child_result: bool) -> bool: |
1157 |
| - raise ValueError(f"NOT should be rewritten: {child_result}") |
1158 |
| - |
1159 |
| - def visit_and(self, left_result: bool, right_result: bool) -> bool: |
1160 |
| - return left_result and right_result |
1161 |
| - |
1162 |
| - def visit_or(self, left_result: bool, right_result: bool) -> bool: |
1163 |
| - return left_result or right_result |
1164 |
| - |
1165 | 1175 | def visit_is_null(self, term: BoundTerm[L]) -> bool:
|
1166 | 1176 | field_id = term.ref().field.field_id
|
1167 | 1177 |
|
@@ -1421,3 +1431,279 @@ def visit_not_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> bool
|
1421 | 1431 | return ROWS_CANNOT_MATCH
|
1422 | 1432 |
|
1423 | 1433 | return ROWS_MIGHT_MATCH
|
| 1434 | + |
| 1435 | + |
| 1436 | +class _StrictMetricsEvaluator(_MetricsEvaluator): |
| 1437 | + struct: StructType |
| 1438 | + expr: BooleanExpression |
| 1439 | + |
| 1440 | + def __init__( |
| 1441 | + self, schema: Schema, expr: BooleanExpression, case_sensitive: bool = True, include_empty_files: bool = False |
| 1442 | + ) -> None: |
| 1443 | + self.struct = schema.as_struct() |
| 1444 | + self.include_empty_files = include_empty_files |
| 1445 | + self.expr = bind(schema, rewrite_not(expr), case_sensitive) |
| 1446 | + |
| 1447 | + def eval(self, file: DataFile) -> bool: |
| 1448 | + """Test whether all records within the file match the expression. |
| 1449 | +
|
| 1450 | + Args: |
| 1451 | + file: A data file |
| 1452 | +
|
| 1453 | + Returns: false if the file may contain any row that doesn't match |
| 1454 | + the expression, true otherwise. |
| 1455 | + """ |
| 1456 | + if file.record_count <= 0: |
| 1457 | + # Older version don't correctly implement record count from avro file and thus |
| 1458 | + # set record count -1 when importing avro tables to iceberg tables. This should |
| 1459 | + # be updated once we implemented and set correct record count. |
| 1460 | + return ROWS_MUST_MATCH |
| 1461 | + |
| 1462 | + self.value_counts = file.value_counts or EMPTY_DICT |
| 1463 | + self.null_counts = file.null_value_counts or EMPTY_DICT |
| 1464 | + self.nan_counts = file.nan_value_counts or EMPTY_DICT |
| 1465 | + self.lower_bounds = file.lower_bounds or EMPTY_DICT |
| 1466 | + self.upper_bounds = file.upper_bounds or EMPTY_DICT |
| 1467 | + |
| 1468 | + return visit(self.expr, self) |
| 1469 | + |
| 1470 | + def visit_is_null(self, term: BoundTerm[L]) -> bool: |
| 1471 | + # no need to check whether the field is required because binding evaluates that case |
| 1472 | + # if the column has any non-null values, the expression does not match |
| 1473 | + field_id = term.ref().field.field_id |
| 1474 | + |
| 1475 | + if self._contains_nulls_only(field_id): |
| 1476 | + return ROWS_MUST_MATCH |
| 1477 | + else: |
| 1478 | + return ROWS_MIGHT_NOT_MATCH |
| 1479 | + |
| 1480 | + def visit_not_null(self, term: BoundTerm[L]) -> bool: |
| 1481 | + # no need to check whether the field is required because binding evaluates that case |
| 1482 | + # if the column has any non-null values, the expression does not match |
| 1483 | + field_id = term.ref().field.field_id |
| 1484 | + |
| 1485 | + if (null_count := self.null_counts.get(field_id)) is not None and null_count == 0: |
| 1486 | + return ROWS_MUST_MATCH |
| 1487 | + else: |
| 1488 | + return ROWS_MIGHT_NOT_MATCH |
| 1489 | + |
| 1490 | + def visit_is_nan(self, term: BoundTerm[L]) -> bool: |
| 1491 | + field_id = term.ref().field.field_id |
| 1492 | + |
| 1493 | + if self._contains_nans_only(field_id): |
| 1494 | + return ROWS_MUST_MATCH |
| 1495 | + else: |
| 1496 | + return ROWS_MIGHT_NOT_MATCH |
| 1497 | + |
| 1498 | + def visit_not_nan(self, term: BoundTerm[L]) -> bool: |
| 1499 | + field_id = term.ref().field.field_id |
| 1500 | + |
| 1501 | + if (nan_count := self.nan_counts.get(field_id)) is not None and nan_count == 0: |
| 1502 | + return ROWS_MUST_MATCH |
| 1503 | + |
| 1504 | + if self._contains_nulls_only(field_id): |
| 1505 | + return ROWS_MUST_MATCH |
| 1506 | + |
| 1507 | + return ROWS_MIGHT_NOT_MATCH |
| 1508 | + |
| 1509 | + def visit_less_than(self, term: BoundTerm[L], literal: Literal[L]) -> bool: |
| 1510 | + # Rows must match when: <----------Min----Max---X-------> |
| 1511 | + |
| 1512 | + field_id = term.ref().field.field_id |
| 1513 | + |
| 1514 | + if self._can_contain_nulls(field_id) or self._can_contain_nans(field_id): |
| 1515 | + return ROWS_MIGHT_NOT_MATCH |
| 1516 | + |
| 1517 | + if upper_bytes := self.upper_bounds.get(field_id): |
| 1518 | + field = self._get_field(field_id) |
| 1519 | + upper = _from_byte_buffer(field.field_type, upper_bytes) |
| 1520 | + |
| 1521 | + if upper < literal.value: |
| 1522 | + return ROWS_MUST_MATCH |
| 1523 | + |
| 1524 | + return ROWS_MIGHT_NOT_MATCH |
| 1525 | + |
| 1526 | + def visit_less_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: |
| 1527 | + # Rows must match when: <----------Min----Max---X-------> |
| 1528 | + |
| 1529 | + field_id = term.ref().field.field_id |
| 1530 | + |
| 1531 | + if self._can_contain_nulls(field_id) or self._can_contain_nans(field_id): |
| 1532 | + return ROWS_MIGHT_NOT_MATCH |
| 1533 | + |
| 1534 | + if upper_bytes := self.upper_bounds.get(field_id): |
| 1535 | + field = self._get_field(field_id) |
| 1536 | + upper = _from_byte_buffer(field.field_type, upper_bytes) |
| 1537 | + |
| 1538 | + if upper <= literal.value: |
| 1539 | + return ROWS_MUST_MATCH |
| 1540 | + |
| 1541 | + return ROWS_MIGHT_NOT_MATCH |
| 1542 | + |
| 1543 | + def visit_greater_than(self, term: BoundTerm[L], literal: Literal[L]) -> bool: |
| 1544 | + # Rows must match when: <-------X---Min----Max----------> |
| 1545 | + |
| 1546 | + field_id = term.ref().field.field_id |
| 1547 | + |
| 1548 | + if self._can_contain_nulls(field_id) or self._can_contain_nans(field_id): |
| 1549 | + return ROWS_MIGHT_NOT_MATCH |
| 1550 | + |
| 1551 | + if lower_bytes := self.lower_bounds.get(field_id): |
| 1552 | + field = self._get_field(field_id) |
| 1553 | + lower = _from_byte_buffer(field.field_type, lower_bytes) |
| 1554 | + |
| 1555 | + if self._is_nan(lower): |
| 1556 | + # NaN indicates unreliable bounds. |
| 1557 | + # See the _StrictMetricsEvaluator docs for more. |
| 1558 | + return ROWS_MIGHT_NOT_MATCH |
| 1559 | + |
| 1560 | + if lower > literal.value: |
| 1561 | + return ROWS_MUST_MATCH |
| 1562 | + |
| 1563 | + return ROWS_MIGHT_NOT_MATCH |
| 1564 | + |
| 1565 | + def visit_greater_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: |
| 1566 | + # Rows must match when: <-------X---Min----Max----------> |
| 1567 | + field_id = term.ref().field.field_id |
| 1568 | + |
| 1569 | + if self._can_contain_nulls(field_id) or self._can_contain_nans(field_id): |
| 1570 | + return ROWS_MIGHT_NOT_MATCH |
| 1571 | + |
| 1572 | + if lower_bytes := self.lower_bounds.get(field_id): |
| 1573 | + field = self._get_field(field_id) |
| 1574 | + lower = _from_byte_buffer(field.field_type, lower_bytes) |
| 1575 | + |
| 1576 | + if self._is_nan(lower): |
| 1577 | + # NaN indicates unreliable bounds. |
| 1578 | + # See the _StrictMetricsEvaluator docs for more. |
| 1579 | + return ROWS_MIGHT_NOT_MATCH |
| 1580 | + |
| 1581 | + if lower >= literal.value: |
| 1582 | + return ROWS_MUST_MATCH |
| 1583 | + |
| 1584 | + return ROWS_MIGHT_NOT_MATCH |
| 1585 | + |
| 1586 | + def visit_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: |
| 1587 | + # Rows must match when Min == X == Max |
| 1588 | + field_id = term.ref().field.field_id |
| 1589 | + |
| 1590 | + if self._can_contain_nulls(field_id) or self._can_contain_nans(field_id): |
| 1591 | + return ROWS_MIGHT_NOT_MATCH |
| 1592 | + |
| 1593 | + if (lower_bytes := self.lower_bounds.get(field_id)) and (upper_bytes := self.upper_bounds.get(field_id)): |
| 1594 | + field = self._get_field(field_id) |
| 1595 | + lower = _from_byte_buffer(field.field_type, lower_bytes) |
| 1596 | + upper = _from_byte_buffer(field.field_type, upper_bytes) |
| 1597 | + |
| 1598 | + if lower != literal.value or upper != literal.value: |
| 1599 | + return ROWS_MIGHT_NOT_MATCH |
| 1600 | + else: |
| 1601 | + return ROWS_MUST_MATCH |
| 1602 | + |
| 1603 | + return ROWS_MIGHT_NOT_MATCH |
| 1604 | + |
| 1605 | + def visit_not_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: |
| 1606 | + # Rows must match when X < Min or Max < X because it is not in the range |
| 1607 | + field_id = term.ref().field.field_id |
| 1608 | + |
| 1609 | + if self._can_contain_nulls(field_id) or self._can_contain_nans(field_id): |
| 1610 | + return ROWS_MUST_MATCH |
| 1611 | + |
| 1612 | + field = self._get_field(field_id) |
| 1613 | + |
| 1614 | + if lower_bytes := self.lower_bounds.get(field_id): |
| 1615 | + lower = _from_byte_buffer(field.field_type, lower_bytes) |
| 1616 | + |
| 1617 | + if self._is_nan(lower): |
| 1618 | + # NaN indicates unreliable bounds. |
| 1619 | + # See the _StrictMetricsEvaluator docs for more. |
| 1620 | + return ROWS_MIGHT_NOT_MATCH |
| 1621 | + |
| 1622 | + if lower > literal.value: |
| 1623 | + return ROWS_MUST_MATCH |
| 1624 | + |
| 1625 | + if upper_bytes := self.upper_bounds.get(field_id): |
| 1626 | + upper = _from_byte_buffer(field.field_type, upper_bytes) |
| 1627 | + |
| 1628 | + if upper < literal.value: |
| 1629 | + return ROWS_MUST_MATCH |
| 1630 | + |
| 1631 | + return ROWS_MIGHT_NOT_MATCH |
| 1632 | + |
| 1633 | + def visit_in(self, term: BoundTerm[L], literals: Set[L]) -> bool: |
| 1634 | + field_id = term.ref().field.field_id |
| 1635 | + |
| 1636 | + if self._can_contain_nulls(field_id) or self._can_contain_nans(field_id): |
| 1637 | + return ROWS_MIGHT_NOT_MATCH |
| 1638 | + |
| 1639 | + field = self._get_field(field_id) |
| 1640 | + |
| 1641 | + if (lower_bytes := self.lower_bounds.get(field_id)) and (upper_bytes := self.upper_bounds.get(field_id)): |
| 1642 | + # similar to the implementation in eq, first check if the lower bound is in the set |
| 1643 | + lower = _from_byte_buffer(field.field_type, lower_bytes) |
| 1644 | + if lower not in literals: |
| 1645 | + return ROWS_MIGHT_NOT_MATCH |
| 1646 | + |
| 1647 | + # check if the upper bound is in the set |
| 1648 | + upper = _from_byte_buffer(field.field_type, upper_bytes) |
| 1649 | + if upper not in literals: |
| 1650 | + return ROWS_MIGHT_NOT_MATCH |
| 1651 | + |
| 1652 | + # finally check if the lower bound and the upper bound are equal |
| 1653 | + if lower != upper: |
| 1654 | + return ROWS_MIGHT_NOT_MATCH |
| 1655 | + |
| 1656 | + # All values must be in the set if the lower bound and the upper bound are |
| 1657 | + # in the set and are equal. |
| 1658 | + return ROWS_MUST_MATCH |
| 1659 | + |
| 1660 | + return ROWS_MIGHT_NOT_MATCH |
| 1661 | + |
| 1662 | + def visit_not_in(self, term: BoundTerm[L], literals: Set[L]) -> bool: |
| 1663 | + field_id = term.ref().field.field_id |
| 1664 | + |
| 1665 | + if self._can_contain_nulls(field_id) or self._can_contain_nans(field_id): |
| 1666 | + return ROWS_MUST_MATCH |
| 1667 | + |
| 1668 | + field = self._get_field(field_id) |
| 1669 | + |
| 1670 | + if lower_bytes := self.lower_bounds.get(field_id): |
| 1671 | + lower = _from_byte_buffer(field.field_type, lower_bytes) |
| 1672 | + |
| 1673 | + if self._is_nan(lower): |
| 1674 | + # NaN indicates unreliable bounds. |
| 1675 | + # See the StrictMetricsEvaluator docs for more. |
| 1676 | + return ROWS_MIGHT_NOT_MATCH |
| 1677 | + |
| 1678 | + literals = {val for val in literals if lower <= val} |
| 1679 | + if len(literals) == 0: |
| 1680 | + return ROWS_MUST_MATCH |
| 1681 | + |
| 1682 | + if upper_bytes := self.upper_bounds.get(field_id): |
| 1683 | + upper = _from_byte_buffer(field.field_type, upper_bytes) |
| 1684 | + |
| 1685 | + literals = {val for val in literals if upper >= val} |
| 1686 | + |
| 1687 | + if len(literals) == 0: |
| 1688 | + return ROWS_MUST_MATCH |
| 1689 | + |
| 1690 | + return ROWS_MIGHT_NOT_MATCH |
| 1691 | + |
| 1692 | + def visit_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> bool: |
| 1693 | + return ROWS_MIGHT_NOT_MATCH |
| 1694 | + |
| 1695 | + def visit_not_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> bool: |
| 1696 | + return ROWS_MIGHT_NOT_MATCH |
| 1697 | + |
| 1698 | + def _get_field(self, field_id: int) -> NestedField: |
| 1699 | + field = self.struct.field(field_id=field_id) |
| 1700 | + if field is None: |
| 1701 | + raise ValueError(f"Cannot find field, might be nested or missing: {field_id}") |
| 1702 | + |
| 1703 | + return field |
| 1704 | + |
| 1705 | + def _can_contain_nulls(self, field_id: int) -> bool: |
| 1706 | + return (null_count := self.null_counts.get(field_id)) is not None and null_count > 0 |
| 1707 | + |
| 1708 | + def _can_contain_nans(self, field_id: int) -> bool: |
| 1709 | + return (nan_count := self.nan_counts.get(field_id)) is not None and nan_count > 0 |
0 commit comments