16
16
// under the License.
17
17
18
18
use crate :: equivalence:: EquivalentClass ;
19
- use crate :: expressions:: { BinaryExpr , Column , UnKnownColumn } ;
19
+ use crate :: expressions:: { BinaryExpr , Column , InListExpr , UnKnownColumn } ;
20
20
use crate :: {
21
21
EquivalenceProperties , PhysicalExpr , PhysicalSortExpr , PhysicalSortRequirement ,
22
22
} ;
@@ -573,8 +573,10 @@ pub fn reassign_predicate_columns(
573
573
schema : & SchemaRef ,
574
574
ignore_not_found : bool ,
575
575
) -> Result < Arc < dyn PhysicalExpr > > {
576
- pred. transform ( & |expr| {
577
- if let Some ( column) = expr. as_any ( ) . downcast_ref :: < Column > ( ) {
576
+ pred. transform_down ( & |expr| {
577
+ let expr_any = expr. as_any ( ) ;
578
+
579
+ if let Some ( column) = expr_any. downcast_ref :: < Column > ( ) {
578
580
let index = match schema. index_of ( column. name ( ) ) {
579
581
Ok ( idx) => idx,
580
582
Err ( _) if ignore_not_found => usize:: MAX ,
@@ -584,6 +586,26 @@ pub fn reassign_predicate_columns(
584
586
column. name ( ) ,
585
587
index,
586
588
) ) ) ) ;
589
+ } else if let Some ( in_list) = expr_any. downcast_ref :: < InListExpr > ( ) {
590
+ // transform child first
591
+ let expr = reassign_predicate_columns (
592
+ in_list. expr ( ) . clone ( ) ,
593
+ schema,
594
+ ignore_not_found,
595
+ ) ?;
596
+ let list = in_list
597
+ . list ( )
598
+ . iter ( )
599
+ . map ( |expr| {
600
+ reassign_predicate_columns ( expr. clone ( ) , schema, ignore_not_found)
601
+ } )
602
+ . collect :: < Result < Vec < _ > > > ( ) ?;
603
+ return Ok ( Transformed :: Yes ( Arc :: new ( InListExpr :: new (
604
+ expr,
605
+ list,
606
+ in_list. negated ( ) ,
607
+ schema. as_ref ( ) ,
608
+ ) ) ) ) ;
587
609
}
588
610
589
611
Ok ( Transformed :: No ( expr) )
@@ -593,7 +615,7 @@ pub fn reassign_predicate_columns(
593
615
#[ cfg( test) ]
594
616
mod tests {
595
617
use super :: * ;
596
- use crate :: expressions:: { binary, cast, col, lit, Column , Literal } ;
618
+ use crate :: expressions:: { binary, cast, col, in_list , lit, Column , Literal } ;
597
619
use crate :: PhysicalSortExpr ;
598
620
use arrow:: compute:: SortOptions ;
599
621
use datafusion_common:: { Result , ScalarValue } ;
@@ -918,4 +940,41 @@ mod tests {
918
940
} ) ) ;
919
941
Ok ( ( ) )
920
942
}
943
+
944
+ #[ test]
945
+ fn test_reassign_predicate_columns_in_list ( ) {
946
+ let int_field = Field :: new ( "should_not_matter" , DataType :: Int64 , true ) ;
947
+ let dict_field = Field :: new (
948
+ "id" ,
949
+ DataType :: Dictionary ( Box :: new ( DataType :: Int32 ) , Box :: new ( DataType :: Utf8 ) ) ,
950
+ true ,
951
+ ) ;
952
+ let schema_small = Arc :: new ( Schema :: new ( vec ! [ dict_field. clone( ) ] ) ) ;
953
+ let schema_big = Arc :: new ( Schema :: new ( vec ! [ int_field, dict_field] ) ) ;
954
+ let pred = in_list (
955
+ Arc :: new ( Column :: new_with_schema ( "id" , & schema_big) . unwrap ( ) ) ,
956
+ vec ! [ lit( ScalarValue :: Dictionary (
957
+ Box :: new( DataType :: Int32 ) ,
958
+ Box :: new( ScalarValue :: from( "2" ) ) ,
959
+ ) ) ] ,
960
+ & false ,
961
+ & schema_big,
962
+ )
963
+ . unwrap ( ) ;
964
+
965
+ let actual = reassign_predicate_columns ( pred, & schema_small, false ) . unwrap ( ) ;
966
+
967
+ let expected = in_list (
968
+ Arc :: new ( Column :: new_with_schema ( "id" , & schema_small) . unwrap ( ) ) ,
969
+ vec ! [ lit( ScalarValue :: Dictionary (
970
+ Box :: new( DataType :: Int32 ) ,
971
+ Box :: new( ScalarValue :: from( "2" ) ) ,
972
+ ) ) ] ,
973
+ & false ,
974
+ & schema_small,
975
+ )
976
+ . unwrap ( ) ;
977
+
978
+ assert_eq ! ( actual. as_ref( ) , expected. as_any( ) ) ;
979
+ }
921
980
}
0 commit comments