@@ -860,9 +860,7 @@ class _ColumnNameTranslator(BooleanExpressionVisitor[BooleanExpression]):
860
860
861
861
Args:
862
862
file_schema (Schema): The schema of the file.
863
- projected_schema (Schema): The schema to project onto the data files.
864
863
case_sensitive (bool): Whether to consider case when binding a reference to a field in a schema, defaults to True.
865
- projected_missing_fields(dict[str, Any]): Map of fields missing in file_schema, but present as partition values.
866
864
867
865
Raises:
868
866
TypeError: In the case of an UnboundPredicate.
@@ -872,13 +870,9 @@ class _ColumnNameTranslator(BooleanExpressionVisitor[BooleanExpression]):
872
870
file_schema : Schema
873
871
case_sensitive : bool
874
872
875
- def __init__ (
876
- self , file_schema : Schema , projected_schema : Schema , case_sensitive : bool , projected_missing_fields : dict [str , Any ]
877
- ) -> None :
873
+ def __init__ (self , file_schema : Schema , case_sensitive : bool ) -> None :
878
874
self .file_schema = file_schema
879
- self .projected_schema = projected_schema
880
875
self .case_sensitive = case_sensitive
881
- self .projected_missing_fields = projected_missing_fields
882
876
883
877
def visit_true (self ) -> BooleanExpression :
884
878
return AlwaysTrue ()
@@ -906,24 +900,6 @@ def visit_bound_predicate(self, predicate: BoundPredicate[L]) -> BooleanExpressi
906
900
# in the file schema when reading older data
907
901
if isinstance (predicate , BoundIsNull ):
908
902
return AlwaysTrue ()
909
- # Evaluate projected field by value extracted from partition
910
- elif (field_name := predicate .term .ref ().field .name ) in self .projected_missing_fields :
911
- unbound_predicate : BooleanExpression
912
- if isinstance (predicate , BoundUnaryPredicate ):
913
- unbound_predicate = predicate .as_unbound (field_name )
914
- elif isinstance (predicate , BoundLiteralPredicate ):
915
- unbound_predicate = predicate .as_unbound (field_name , predicate .literal )
916
- elif isinstance (predicate , BoundSetPredicate ):
917
- unbound_predicate = predicate .as_unbound (field_name , predicate .literals )
918
- else :
919
- raise ValueError (f"Unsupported predicate: { predicate } " )
920
- field = self .projected_schema .find_field (field_name )
921
- schema = Schema (field )
922
- evaluator = expression_evaluator (schema , unbound_predicate , self .case_sensitive )
923
- if evaluator (Record (self .projected_missing_fields [field_name ])):
924
- return AlwaysTrue ()
925
- else :
926
- return AlwaysFalse ()
927
903
else :
928
904
return AlwaysFalse ()
929
905
@@ -937,14 +913,84 @@ def visit_bound_predicate(self, predicate: BoundPredicate[L]) -> BooleanExpressi
937
913
raise ValueError (f"Unsupported predicate: { predicate } " )
938
914
939
915
940
- def translate_column_names (
916
+ def translate_column_names (expr : BooleanExpression , file_schema : Schema , case_sensitive : bool ) -> BooleanExpression :
917
+ return visit (expr , _ColumnNameTranslator (file_schema , case_sensitive ))
918
+
919
+
920
+ class _ProjectedColumnsEvaluator (BooleanExpressionVisitor [BooleanExpression ]):
921
+ """Evaluate predicates that involve projected columns missing from the file.
922
+
923
+ Args:
924
+ file_schema (Schema): The schema of the file.
925
+ projected_schema (Schema): The schema to project onto the data files.
926
+ case_sensitive (bool): Whether to consider case when binding a reference to a field in a schema.
927
+ projected_missing_fields (dict[str, Any]): Map from field name to value for fields missing in file_schema, but present as partition values.
928
+
929
+ Raises:
930
+ TypeError: In the case of an UnboundPredicate.
931
+ """
932
+
933
+ file_schema : Schema
934
+ case_sensitive : bool
935
+
936
+ def __init__ (
937
+ self , file_schema : Schema , projected_schema : Schema , case_sensitive : bool , projected_missing_fields : dict [str , Any ]
938
+ ) -> None :
939
+ self .file_schema = file_schema
940
+ self .projected_schema = projected_schema
941
+ self .case_sensitive = case_sensitive
942
+ self .projected_missing_fields = projected_missing_fields
943
+
944
+ def visit_true (self ) -> BooleanExpression :
945
+ return AlwaysTrue ()
946
+
947
+ def visit_false (self ) -> BooleanExpression :
948
+ return AlwaysFalse ()
949
+
950
+ def visit_not (self , child_result : BooleanExpression ) -> BooleanExpression :
951
+ return Not (child = child_result )
952
+
953
+ def visit_and (self , left_result : BooleanExpression , right_result : BooleanExpression ) -> BooleanExpression :
954
+ return And (left = left_result , right = right_result )
955
+
956
+ def visit_or (self , left_result : BooleanExpression , right_result : BooleanExpression ) -> BooleanExpression :
957
+ return Or (left = left_result , right = right_result )
958
+
959
+ def visit_unbound_predicate (self , predicate : UnboundPredicate [L ]) -> BooleanExpression :
960
+ raise TypeError (f"Expected Bound Predicate, got: { predicate .term } " )
961
+
962
+ def visit_bound_predicate (self , predicate : BoundPredicate [L ]) -> BooleanExpression :
963
+ file_column_name = self .file_schema .find_column_name (predicate .term .ref ().field .field_id )
964
+
965
+ if file_column_name is None and (field_name := predicate .term .ref ().field .name ) in self .projected_missing_fields :
966
+ unbound_predicate : BooleanExpression
967
+ if isinstance (predicate , BoundUnaryPredicate ):
968
+ unbound_predicate = predicate .as_unbound (field_name )
969
+ elif isinstance (predicate , BoundLiteralPredicate ):
970
+ unbound_predicate = predicate .as_unbound (field_name , predicate .literal )
971
+ elif isinstance (predicate , BoundSetPredicate ):
972
+ unbound_predicate = predicate .as_unbound (field_name , predicate .literals )
973
+ else :
974
+ raise ValueError (f"Unsupported predicate: { predicate } " )
975
+ field = self .projected_schema .find_field (field_name )
976
+ schema = Schema (field )
977
+ evaluator = expression_evaluator (schema , unbound_predicate , self .case_sensitive )
978
+ if evaluator (Record (self .projected_missing_fields [field_name ])):
979
+ return AlwaysTrue ()
980
+ else :
981
+ return AlwaysFalse ()
982
+
983
+ return predicate
984
+
985
+
986
+ def evaluate_projected_columns (
941
987
expr : BooleanExpression ,
942
988
file_schema : Schema ,
943
989
projected_schema : Schema ,
944
990
case_sensitive : bool ,
945
991
projected_missing_fields : dict [str , Any ],
946
992
) -> BooleanExpression :
947
- return visit (expr , _ColumnNameTranslator (file_schema , projected_schema , case_sensitive , projected_missing_fields ))
993
+ return visit (expr , _ProjectedColumnsEvaluator (file_schema , projected_schema , case_sensitive , projected_missing_fields ))
948
994
949
995
950
996
class _ExpressionFieldIDs (BooleanExpressionVisitor [Set [int ]]):
0 commit comments