Skip to content

Commit

Permalink
Split out field check
Browse files Browse the repository at this point in the history
  • Loading branch information
Fokko committed Mar 15, 2024
1 parent c36e6dc commit b804bb9
Showing 1 changed file with 16 additions and 35 deletions.
51 changes: 16 additions & 35 deletions pyiceberg/expressions/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
DoubleType,
FloatType,
IcebergType,
NestedField,
PrimitiveType,
StructType,
TimestampType,
Expand Down Expand Up @@ -1470,9 +1471,6 @@ def visit_is_null(self, term: BoundTerm[L]) -> bool:
# no need to check whether the field is required because binding evaluates that case
# if the column has any non-null values, the expression does not match
field_id = term.ref().field.field_id
field = self.struct.field(field_id=field_id)
if field is None:
raise ValueError(f"Cannot find field, might be nested or missing: {field_id}")

if self._contains_nulls_only(field_id):
return ROWS_MUST_MATCH
Expand All @@ -1483,9 +1481,6 @@ def visit_not_null(self, term: BoundTerm[L]) -> bool:
# no need to check whether the field is required because binding evaluates that case
# if the column has any non-null values, the expression does not match
field_id = term.ref().field.field_id
field = self.struct.field(field_id=field_id)
if field is None:
raise ValueError(f"Cannot find field, might be nested or missing: {field_id}")

if (null_count := self.null_counts.get(field_id)) is not None and null_count == 0:
return ROWS_MUST_MATCH
Expand Down Expand Up @@ -1520,10 +1515,7 @@ def visit_less_than(self, term: BoundTerm[L], literal: Literal[L]) -> bool:
return ROWS_MIGHT_NOT_MATCH

if upper_bytes := self.upper_bounds.get(field_id):
field = self.struct.field(field_id=field_id)
if field is None:
raise ValueError(f"Cannot find field, might be nested or missing: {field_id}")

field = self._get_field(field_id)
upper = _from_byte_buffer(field.field_type, upper_bytes)

if upper < literal.value:
Expand All @@ -1540,10 +1532,7 @@ def visit_less_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> b
return ROWS_MIGHT_NOT_MATCH

if upper_bytes := self.upper_bounds.get(field_id):
field = self.struct.field(field_id=field_id)
if field is None:
raise ValueError(f"Cannot find field, might be nested or missing: {field_id}")

field = self._get_field(field_id)
upper = _from_byte_buffer(field.field_type, upper_bytes)

if upper <= literal.value:
Expand All @@ -1560,10 +1549,7 @@ def visit_greater_than(self, term: BoundTerm[L], literal: Literal[L]) -> bool:
return ROWS_MIGHT_NOT_MATCH

if lower_bytes := self.lower_bounds.get(field_id):
field = self.struct.field(field_id=field_id)
if field is None:
raise ValueError(f"Cannot find field, might be nested or missing: {field_id}")

field = self._get_field(field_id)
lower = _from_byte_buffer(field.field_type, lower_bytes)

if self._is_nan(lower):
Expand All @@ -1584,10 +1570,7 @@ def visit_greater_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -
return ROWS_MIGHT_NOT_MATCH

if lower_bytes := self.lower_bounds.get(field_id):
field = self.struct.field(field_id=field_id)
if field is None:
raise ValueError(f"Cannot find field, might be nested or missing: {field_id}")

field = self._get_field(field_id)
lower = _from_byte_buffer(field.field_type, lower_bytes)

if self._is_nan(lower):
Expand All @@ -1608,10 +1591,7 @@ def visit_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool:
return ROWS_MIGHT_NOT_MATCH

if (lower_bytes := self.lower_bounds.get(field_id)) and (upper_bytes := self.upper_bounds.get(field_id)):
field = self.struct.field(field_id=field_id)
if field is None:
raise ValueError(f"Cannot find field, might be nested or missing: {field_id}")

field = self._get_field(field_id)
lower = _from_byte_buffer(field.field_type, lower_bytes)
upper = _from_byte_buffer(field.field_type, upper_bytes)

Expand All @@ -1629,9 +1609,7 @@ def visit_not_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool:
if self._can_contain_nulls(field_id) or self._can_contain_nans(field_id):
return ROWS_MUST_MATCH

field = self.struct.field(field_id=field_id)
if field is None:
raise ValueError(f"Cannot find field, might be nested or missing: {field_id}")
field = self._get_field(field_id)

if lower_bytes := self.lower_bounds.get(field_id):
lower = _from_byte_buffer(field.field_type, lower_bytes)
Expand All @@ -1658,9 +1636,7 @@ def visit_in(self, term: BoundTerm[L], literals: Set[L]) -> bool:
if self._can_contain_nulls(field_id) or self._can_contain_nans(field_id):
return ROWS_MIGHT_NOT_MATCH

field = self.struct.field(field_id=field_id)
if field is None:
raise ValueError(f"Cannot find field, might be nested or missing: {field_id}")
field = self._get_field(field_id)

if (lower_bytes := self.lower_bounds.get(field_id)) and (upper_bytes := self.upper_bounds.get(field_id)):
# similar to the implementation in eq, first check if the lower bound is in the set
Expand Down Expand Up @@ -1689,9 +1665,7 @@ def visit_not_in(self, term: BoundTerm[L], literals: Set[L]) -> bool:
if self._can_contain_nulls(field_id) or self._can_contain_nans(field_id):
return ROWS_MUST_MATCH

field = self.struct.field(field_id=field_id)
if field is None:
raise ValueError(f"Cannot find field, might be nested or missing: {field_id}")
field = self._get_field(field_id)

if lower_bytes := self.lower_bounds.get(field_id):
lower = _from_byte_buffer(field.field_type, lower_bytes)
Expand Down Expand Up @@ -1721,6 +1695,13 @@ def visit_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> bool:
def visit_not_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> bool:
    """Evaluate a NOT STARTS WITH predicate.

    Neither ``term`` nor ``literal`` is inspected; the result is always
    ``ROWS_MIGHT_NOT_MATCH``.
    """
    return ROWS_MIGHT_NOT_MATCH

def _get_field(self, field_id: int) -> NestedField:
field = self.struct.field(field_id=field_id)
if field is None:
raise ValueError(f"Cannot find field, might be nested or missing: {field_id}")

return field

def _can_contain_nulls(self, field_id: int) -> bool:
return (null_count := self.null_counts.get(field_id)) is not None and null_count > 0

Expand Down

0 comments on commit b804bb9

Please sign in to comment.