diff --git a/pyiceberg/expressions/visitors.py b/pyiceberg/expressions/visitors.py index 244e4d2e87..9329b9940d 100644 --- a/pyiceberg/expressions/visitors.py +++ b/pyiceberg/expressions/visitors.py @@ -67,6 +67,7 @@ DoubleType, FloatType, IcebergType, + NestedField, PrimitiveType, StructType, TimestampType, @@ -1470,9 +1471,6 @@ def visit_is_null(self, term: BoundTerm[L]) -> bool: # no need to check whether the field is required because binding evaluates that case # if the column has any non-null values, the expression does not match field_id = term.ref().field.field_id - field = self.struct.field(field_id=field_id) - if field is None: - raise ValueError(f"Cannot find field, might be nested or missing: {field_id}") if self._contains_nulls_only(field_id): return ROWS_MUST_MATCH @@ -1483,9 +1481,6 @@ def visit_not_null(self, term: BoundTerm[L]) -> bool: # no need to check whether the field is required because binding evaluates that case # if the column has any non-null values, the expression does not match field_id = term.ref().field.field_id - field = self.struct.field(field_id=field_id) - if field is None: - raise ValueError(f"Cannot find field, might be nested or missing: {field_id}") if (null_count := self.null_counts.get(field_id)) is not None and null_count == 0: return ROWS_MUST_MATCH @@ -1520,10 +1515,7 @@ def visit_less_than(self, term: BoundTerm[L], literal: Literal[L]) -> bool: return ROWS_MIGHT_NOT_MATCH if upper_bytes := self.upper_bounds.get(field_id): - field = self.struct.field(field_id=field_id) - if field is None: - raise ValueError(f"Cannot find field, might be nested or missing: {field_id}") - + field = self._get_field(field_id) upper = _from_byte_buffer(field.field_type, upper_bytes) if upper < literal.value: @@ -1540,10 +1532,7 @@ def visit_less_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> b return ROWS_MIGHT_NOT_MATCH if upper_bytes := self.upper_bounds.get(field_id): - field = self.struct.field(field_id=field_id) - if field is None: - raise ValueError(f"Cannot find field, might be nested or missing: {field_id}") - + field = self._get_field(field_id) upper = _from_byte_buffer(field.field_type, upper_bytes) if upper <= literal.value: @@ -1560,10 +1549,7 @@ def visit_greater_than(self, term: BoundTerm[L], literal: Literal[L]) -> bool: return ROWS_MIGHT_NOT_MATCH if lower_bytes := self.lower_bounds.get(field_id): - field = self.struct.field(field_id=field_id) - if field is None: - raise ValueError(f"Cannot find field, might be nested or missing: {field_id}") - + field = self._get_field(field_id) lower = _from_byte_buffer(field.field_type, lower_bytes) if self._is_nan(lower): @@ -1584,10 +1570,7 @@ def visit_greater_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) - return ROWS_MIGHT_NOT_MATCH if lower_bytes := self.lower_bounds.get(field_id): - field = self.struct.field(field_id=field_id) - if field is None: - raise ValueError(f"Cannot find field, might be nested or missing: {field_id}") - + field = self._get_field(field_id) lower = _from_byte_buffer(field.field_type, lower_bytes) if self._is_nan(lower): @@ -1608,10 +1591,7 @@ def visit_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: return ROWS_MIGHT_NOT_MATCH if (lower_bytes := self.lower_bounds.get(field_id)) and (upper_bytes := self.upper_bounds.get(field_id)): - field = self.struct.field(field_id=field_id) - if field is None: - raise ValueError(f"Cannot find field, might be nested or missing: {field_id}") - + field = self._get_field(field_id) lower = _from_byte_buffer(field.field_type, lower_bytes) upper = _from_byte_buffer(field.field_type, upper_bytes) @@ -1629,9 +1609,7 @@ def visit_not_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: if self._can_contain_nulls(field_id) or self._can_contain_nans(field_id): return ROWS_MUST_MATCH - field = self.struct.field(field_id=field_id) - if field is None: - raise ValueError(f"Cannot find field, might be nested or missing: {field_id}") + field = self._get_field(field_id) if lower_bytes := self.lower_bounds.get(field_id): lower = _from_byte_buffer(field.field_type, lower_bytes) @@ -1658,9 +1636,7 @@ def visit_in(self, term: BoundTerm[L], literals: Set[L]) -> bool: if self._can_contain_nulls(field_id) or self._can_contain_nans(field_id): return ROWS_MIGHT_NOT_MATCH - field = self.struct.field(field_id=field_id) - if field is None: - raise ValueError(f"Cannot find field, might be nested or missing: {field_id}") + field = self._get_field(field_id) if (lower_bytes := self.lower_bounds.get(field_id)) and (upper_bytes := self.upper_bounds.get(field_id)): # similar to the implementation in eq, first check if the lower bound is in the set @@ -1689,9 +1665,7 @@ def visit_not_in(self, term: BoundTerm[L], literals: Set[L]) -> bool: if self._can_contain_nulls(field_id) or self._can_contain_nans(field_id): return ROWS_MUST_MATCH - field = self.struct.field(field_id=field_id) - if field is None: - raise ValueError(f"Cannot find field, might be nested or missing: {field_id}") + field = self._get_field(field_id) if lower_bytes := self.lower_bounds.get(field_id): lower = _from_byte_buffer(field.field_type, lower_bytes) @@ -1721,6 +1695,13 @@ def visit_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> bool: def visit_not_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> bool: return ROWS_MIGHT_NOT_MATCH + def _get_field(self, field_id: int) -> NestedField: + field = self.struct.field(field_id=field_id) + if field is None: + raise ValueError(f"Cannot find field, might be nested or missing: {field_id}") + + return field + def _can_contain_nulls(self, field_id: int) -> bool: return (null_count := self.null_counts.get(field_id)) is not None and null_count > 0