Skip to content

Commit a93b4eb

Browse files
author
Roman Shanin
committed
extract projected columns evaluator into separate visitor
1 parent befa05c commit a93b4eb

File tree

3 files changed

+85
-36
lines changed

3 files changed

+85
-36
lines changed

pyiceberg/expressions/visitors.py

Lines changed: 73 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -860,9 +860,7 @@ class _ColumnNameTranslator(BooleanExpressionVisitor[BooleanExpression]):
860860
861861
Args:
862862
file_schema (Schema): The schema of the file.
863-
projected_schema (Schema): The schema to project onto the data files.
864863
case_sensitive (bool): Whether to consider case when binding a reference to a field in a schema.
865-
projected_missing_fields(dict[str, Any]): Map of fields missing in file_schema, but present as partition values.
866864
867865
Raises:
868866
TypeError: In the case of an UnboundPredicate.
@@ -872,13 +870,9 @@ class _ColumnNameTranslator(BooleanExpressionVisitor[BooleanExpression]):
872870
file_schema: Schema
873871
case_sensitive: bool
874872

875-
def __init__(
876-
self, file_schema: Schema, projected_schema: Schema, case_sensitive: bool, projected_missing_fields: dict[str, Any]
877-
) -> None:
873+
def __init__(self, file_schema: Schema, case_sensitive: bool) -> None:
878874
self.file_schema = file_schema
879-
self.projected_schema = projected_schema
880875
self.case_sensitive = case_sensitive
881-
self.projected_missing_fields = projected_missing_fields
882876

883877
def visit_true(self) -> BooleanExpression:
884878
return AlwaysTrue()
@@ -906,24 +900,6 @@ def visit_bound_predicate(self, predicate: BoundPredicate[L]) -> BooleanExpressi
906900
# in the file schema when reading older data
907901
if isinstance(predicate, BoundIsNull):
908902
return AlwaysTrue()
909-
# Evaluate projected field by value extracted from partition
910-
elif (field_name := predicate.term.ref().field.name) in self.projected_missing_fields:
911-
unbound_predicate: BooleanExpression
912-
if isinstance(predicate, BoundUnaryPredicate):
913-
unbound_predicate = predicate.as_unbound(field_name)
914-
elif isinstance(predicate, BoundLiteralPredicate):
915-
unbound_predicate = predicate.as_unbound(field_name, predicate.literal)
916-
elif isinstance(predicate, BoundSetPredicate):
917-
unbound_predicate = predicate.as_unbound(field_name, predicate.literals)
918-
else:
919-
raise ValueError(f"Unsupported predicate: {predicate}")
920-
field = self.projected_schema.find_field(field_name)
921-
schema = Schema(field)
922-
evaluator = expression_evaluator(schema, unbound_predicate, self.case_sensitive)
923-
if evaluator(Record(self.projected_missing_fields[field_name])):
924-
return AlwaysTrue()
925-
else:
926-
return AlwaysFalse()
927903
else:
928904
return AlwaysFalse()
929905

@@ -937,14 +913,84 @@ def visit_bound_predicate(self, predicate: BoundPredicate[L]) -> BooleanExpressi
937913
raise ValueError(f"Unsupported predicate: {predicate}")
938914

939915

940-
def translate_column_names(
916+
def translate_column_names(expr: BooleanExpression, file_schema: Schema, case_sensitive: bool) -> BooleanExpression:
917+
return visit(expr, _ColumnNameTranslator(file_schema, case_sensitive))
918+
919+
920+
class _ProjectedColumnsEvaluator(BooleanExpressionVisitor[BooleanExpression]):
921+
"""Evaluate predicates that involve projected columns missing from the file.
922+
923+
Args:
924+
file_schema (Schema): The schema of the file.
925+
projected_schema (Schema): The schema to project onto the data files.
926+
case_sensitive (bool): Whether to consider case when binding a reference to a field in a schema.
927+
projected_missing_fields (dict[str, Any]): Map from field name to partition value for fields missing in file_schema, but present as partition values.
928+
929+
Raises:
930+
TypeError: In the case of an UnboundPredicate.
931+
"""
932+
933+
file_schema: Schema
934+
case_sensitive: bool
935+
936+
def __init__(
937+
self, file_schema: Schema, projected_schema: Schema, case_sensitive: bool, projected_missing_fields: dict[str, Any]
938+
) -> None:
939+
self.file_schema = file_schema
940+
self.projected_schema = projected_schema
941+
self.case_sensitive = case_sensitive
942+
self.projected_missing_fields = projected_missing_fields
943+
944+
def visit_true(self) -> BooleanExpression:
945+
return AlwaysTrue()
946+
947+
def visit_false(self) -> BooleanExpression:
948+
return AlwaysFalse()
949+
950+
def visit_not(self, child_result: BooleanExpression) -> BooleanExpression:
951+
return Not(child=child_result)
952+
953+
def visit_and(self, left_result: BooleanExpression, right_result: BooleanExpression) -> BooleanExpression:
954+
return And(left=left_result, right=right_result)
955+
956+
def visit_or(self, left_result: BooleanExpression, right_result: BooleanExpression) -> BooleanExpression:
957+
return Or(left=left_result, right=right_result)
958+
959+
def visit_unbound_predicate(self, predicate: UnboundPredicate[L]) -> BooleanExpression:
960+
raise TypeError(f"Expected Bound Predicate, got: {predicate.term}")
961+
962+
def visit_bound_predicate(self, predicate: BoundPredicate[L]) -> BooleanExpression:
963+
file_column_name = self.file_schema.find_column_name(predicate.term.ref().field.field_id)
964+
965+
if file_column_name is None and (field_name := predicate.term.ref().field.name) in self.projected_missing_fields:
966+
unbound_predicate: BooleanExpression
967+
if isinstance(predicate, BoundUnaryPredicate):
968+
unbound_predicate = predicate.as_unbound(field_name)
969+
elif isinstance(predicate, BoundLiteralPredicate):
970+
unbound_predicate = predicate.as_unbound(field_name, predicate.literal)
971+
elif isinstance(predicate, BoundSetPredicate):
972+
unbound_predicate = predicate.as_unbound(field_name, predicate.literals)
973+
else:
974+
raise ValueError(f"Unsupported predicate: {predicate}")
975+
field = self.projected_schema.find_field(field_name)
976+
schema = Schema(field)
977+
evaluator = expression_evaluator(schema, unbound_predicate, self.case_sensitive)
978+
if evaluator(Record(self.projected_missing_fields[field_name])):
979+
return AlwaysTrue()
980+
else:
981+
return AlwaysFalse()
982+
983+
return predicate
984+
985+
986+
def evaluate_projected_columns(
941987
expr: BooleanExpression,
942988
file_schema: Schema,
943989
projected_schema: Schema,
944990
case_sensitive: bool,
945991
projected_missing_fields: dict[str, Any],
946992
) -> BooleanExpression:
947-
return visit(expr, _ColumnNameTranslator(file_schema, projected_schema, case_sensitive, projected_missing_fields))
993+
return visit(expr, _ProjectedColumnsEvaluator(file_schema, projected_schema, case_sensitive, projected_missing_fields))
948994

949995

950996
class _ExpressionFieldIDs(BooleanExpressionVisitor[Set[int]]):

pyiceberg/io/pyarrow.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@
7878
from pyiceberg.expressions.visitors import (
7979
BoundBooleanExpressionVisitor,
8080
bind,
81+
evaluate_projected_columns,
8182
extract_field_ids,
8283
translate_column_names,
8384
)
@@ -1466,13 +1467,18 @@ def _task_to_record_batches(
14661467

14671468
pyarrow_filter = None
14681469
if bound_row_filter is not AlwaysTrue():
1469-
translated_row_filter = translate_column_names(
1470+
evaluated_projected_columns_filter = evaluate_projected_columns(
14701471
bound_row_filter,
14711472
file_schema,
14721473
projected_schema,
14731474
case_sensitive=case_sensitive,
14741475
projected_missing_fields=projected_missing_fields,
14751476
)
1477+
translated_row_filter = translate_column_names(
1478+
evaluated_projected_columns_filter,
1479+
file_schema,
1480+
case_sensitive=case_sensitive,
1481+
)
14761482
bound_file_filter = bind(file_schema, translated_row_filter, case_sensitive=case_sensitive)
14771483
pyarrow_filter = expression_to_pyarrow(bound_file_filter)
14781484

tests/expressions/test_visitors.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -69,11 +69,11 @@
6969
BoundBooleanExpressionVisitor,
7070
_ManifestEvalVisitor,
7171
bind,
72+
evaluate_projected_columns,
7273
expression_evaluator,
7374
expression_to_plain_format,
7475
rewrite_not,
7576
rewrite_to_dnf,
76-
translate_column_names,
7777
visit,
7878
visit_bound_predicate,
7979
)
@@ -1652,13 +1652,10 @@ def test_expression_evaluator_null() -> None:
16521652
),
16531653
],
16541654
)
1655-
def test_translate_column_names_eval_projected_fields(
1656-
schema: Schema, before_expression: BooleanExpression, after_expression: BooleanExpression
1657-
) -> None:
1655+
def test_eval_projected_fields(schema: Schema, before_expression: BooleanExpression, after_expression: BooleanExpression) -> None:
16581656
# exclude id from file_schema pretending that it's part of partition values
16591657
file_schema = Schema(*[field for field in schema.columns if field.name != "id"])
16601658
projected_missing_fields = {"id": 1}
1661-
assert (
1662-
translate_column_names(bind(schema, before_expression, True), file_schema, schema, True, projected_missing_fields)
1663-
== after_expression
1664-
)
1659+
assert evaluate_projected_columns(
1660+
bind(schema, before_expression, True), file_schema, schema, True, projected_missing_fields
1661+
) == bind(schema, after_expression, True)

0 commit comments

Comments
 (0)