@@ -1351,15 +1351,20 @@ mod tests {
1351
1351
use datafusion_common:: { ColumnStatistics , Result , Statistics } ;
1352
1352
use datafusion_expr:: type_coercion:: binary:: get_input_types;
1353
1353
1354
- // Create a binary expression without coercion. Used here when we do not want to coerce the expressions
1355
- // to valid types. Usage can result in an execution (after plan) error.
1356
- fn binary_simple (
1357
- l : Arc < dyn PhysicalExpr > ,
1354
+ /// Performs a binary operation, applying any type coercion necessary
1355
+ fn binary_op (
1356
+ left : Arc < dyn PhysicalExpr > ,
1358
1357
op : Operator ,
1359
- r : Arc < dyn PhysicalExpr > ,
1360
- input_schema : & Schema ,
1361
- ) -> Arc < dyn PhysicalExpr > {
1362
- binary ( l, op, r, input_schema) . unwrap ( )
1358
+ right : Arc < dyn PhysicalExpr > ,
1359
+ schema : & Schema ,
1360
+ ) -> Result < Arc < dyn PhysicalExpr > > {
1361
+ let left_type = left. data_type ( schema) ?;
1362
+ let right_type = right. data_type ( schema) ?;
1363
+ let ( lhs, rhs) = get_input_types ( & left_type, & op, & right_type) ?;
1364
+
1365
+ let left_expr = try_cast ( left, schema, lhs) ?;
1366
+ let right_expr = try_cast ( right, schema, rhs) ?;
1367
+ binary ( left_expr, op, right_expr, schema)
1363
1368
}
1364
1369
1365
1370
#[ test]
@@ -1372,12 +1377,12 @@ mod tests {
1372
1377
let b = Int32Array :: from ( vec ! [ 1 , 2 , 4 , 8 , 16 ] ) ;
1373
1378
1374
1379
// expression: "a < b"
1375
- let lt = binary_simple (
1380
+ let lt = binary (
1376
1381
col ( "a" , & schema) ?,
1377
1382
Operator :: Lt ,
1378
1383
col ( "b" , & schema) ?,
1379
1384
& schema,
1380
- ) ;
1385
+ ) ? ;
1381
1386
let batch =
1382
1387
RecordBatch :: try_new ( Arc :: new ( schema) , vec ! [ Arc :: new( a) , Arc :: new( b) ] ) ?;
1383
1388
@@ -1404,22 +1409,22 @@ mod tests {
1404
1409
let b = Int32Array :: from ( vec ! [ 2 , 5 , 4 , 8 , 8 ] ) ;
1405
1410
1406
1411
// expression: "a < b OR a == b"
1407
- let expr = binary_simple (
1408
- binary_simple (
1412
+ let expr = binary (
1413
+ binary (
1409
1414
col ( "a" , & schema) ?,
1410
1415
Operator :: Lt ,
1411
1416
col ( "b" , & schema) ?,
1412
1417
& schema,
1413
- ) ,
1418
+ ) ? ,
1414
1419
Operator :: Or ,
1415
- binary_simple (
1420
+ binary (
1416
1421
col ( "a" , & schema) ?,
1417
1422
Operator :: Eq ,
1418
1423
col ( "b" , & schema) ?,
1419
1424
& schema,
1420
- ) ,
1425
+ ) ? ,
1421
1426
& schema,
1422
- ) ;
1427
+ ) ? ;
1423
1428
let batch =
1424
1429
RecordBatch :: try_new ( Arc :: new ( schema) , vec ! [ Arc :: new( a) , Arc :: new( b) ] ) ?;
1425
1430
@@ -1492,7 +1497,7 @@ mod tests {
1492
1497
}
1493
1498
1494
1499
#[ test]
1495
- fn test_type_coersion ( ) -> Result < ( ) > {
1500
+ fn test_type_coercion ( ) -> Result < ( ) > {
1496
1501
test_coercion ! (
1497
1502
Int32Array ,
1498
1503
DataType :: Int32 ,
@@ -1814,8 +1819,7 @@ mod tests {
1814
1819
// is no way at the time of this writing to create a dictionary
1815
1820
// array using the `From` trait
1816
1821
#[ test]
1817
- #[ cfg( feature = "dictionary_expressions" ) ]
1818
- fn test_dictionary_type_to_array_coersion ( ) -> Result < ( ) > {
1822
+ fn test_dictionary_type_to_array_coercion ( ) -> Result < ( ) > {
1819
1823
// Test string a string dictionary
1820
1824
let dict_type =
1821
1825
DataType :: Dictionary ( Box :: new ( DataType :: Int32 ) , Box :: new ( DataType :: Utf8 ) ) ;
@@ -1878,7 +1882,6 @@ mod tests {
1878
1882
}
1879
1883
1880
1884
#[ test]
1881
- #[ cfg( feature = "dictionary_expressions" ) ]
1882
1885
fn plus_op_dict ( ) -> Result < ( ) > {
1883
1886
let schema = Schema :: new ( vec ! [
1884
1887
Field :: new(
@@ -1912,7 +1915,6 @@ mod tests {
1912
1915
}
1913
1916
1914
1917
#[ test]
1915
- #[ cfg( feature = "dictionary_expressions" ) ]
1916
1918
fn plus_op_dict_decimal ( ) -> Result < ( ) > {
1917
1919
let schema = Schema :: new ( vec ! [
1918
1920
Field :: new(
@@ -2096,7 +2098,6 @@ mod tests {
2096
2098
}
2097
2099
2098
2100
#[ test]
2099
- #[ cfg( feature = "dictionary_expressions" ) ]
2100
2101
fn minus_op_dict ( ) -> Result < ( ) > {
2101
2102
let schema = Schema :: new ( vec ! [
2102
2103
Field :: new(
@@ -2130,7 +2131,6 @@ mod tests {
2130
2131
}
2131
2132
2132
2133
#[ test]
2133
- #[ cfg( feature = "dictionary_expressions" ) ]
2134
2134
fn minus_op_dict_decimal ( ) -> Result < ( ) > {
2135
2135
let schema = Schema :: new ( vec ! [
2136
2136
Field :: new(
@@ -2306,7 +2306,6 @@ mod tests {
2306
2306
}
2307
2307
2308
2308
#[ test]
2309
- #[ cfg( feature = "dictionary_expressions" ) ]
2310
2309
fn multiply_op_dict ( ) -> Result < ( ) > {
2311
2310
let schema = Schema :: new ( vec ! [
2312
2311
Field :: new(
@@ -2340,7 +2339,6 @@ mod tests {
2340
2339
}
2341
2340
2342
2341
#[ test]
2343
- #[ cfg( feature = "dictionary_expressions" ) ]
2344
2342
fn multiply_op_dict_decimal ( ) -> Result < ( ) > {
2345
2343
let schema = Schema :: new ( vec ! [
2346
2344
Field :: new(
@@ -2514,7 +2512,6 @@ mod tests {
2514
2512
}
2515
2513
2516
2514
#[ test]
2517
- #[ cfg( feature = "dictionary_expressions" ) ]
2518
2515
fn divide_op_dict ( ) -> Result < ( ) > {
2519
2516
let schema = Schema :: new ( vec ! [
2520
2517
Field :: new(
@@ -2554,7 +2551,6 @@ mod tests {
2554
2551
}
2555
2552
2556
2553
#[ test]
2557
- #[ cfg( feature = "dictionary_expressions" ) ]
2558
2554
fn divide_op_dict_decimal ( ) -> Result < ( ) > {
2559
2555
let schema = Schema :: new ( vec ! [
2560
2556
Field :: new(
@@ -2740,7 +2736,6 @@ mod tests {
2740
2736
}
2741
2737
2742
2738
#[ test]
2743
- #[ cfg( feature = "dictionary_expressions" ) ]
2744
2739
fn modulus_op_dict ( ) -> Result < ( ) > {
2745
2740
let schema = Schema :: new ( vec ! [
2746
2741
Field :: new(
@@ -2780,7 +2775,6 @@ mod tests {
2780
2775
}
2781
2776
2782
2777
#[ test]
2783
- #[ cfg( feature = "dictionary_expressions" ) ]
2784
2778
fn modulus_op_dict_decimal ( ) -> Result < ( ) > {
2785
2779
let schema = Schema :: new ( vec ! [
2786
2780
Field :: new(
@@ -2937,7 +2931,7 @@ mod tests {
2937
2931
expected : PrimitiveArray < T > ,
2938
2932
) -> Result < ( ) > {
2939
2933
let arithmetic_op =
2940
- binary_simple ( col ( "a" , & schema) ?, op, col ( "b" , & schema) ?, & schema) ;
2934
+ binary_op ( col ( "a" , & schema) ?, op, col ( "b" , & schema) ?, & schema) ? ;
2941
2935
let batch = RecordBatch :: try_new ( schema, data) ?;
2942
2936
let result = arithmetic_op. evaluate ( & batch) ?. into_array ( batch. num_rows ( ) ) ;
2943
2937
@@ -2953,7 +2947,7 @@ mod tests {
2953
2947
expected : ArrayRef ,
2954
2948
) -> Result < ( ) > {
2955
2949
let lit = Arc :: new ( Literal :: new ( literal) ) ;
2956
- let arithmetic_op = binary_simple ( col ( "a" , & schema) ?, op, lit, & schema) ;
2950
+ let arithmetic_op = binary_op ( col ( "a" , & schema) ?, op, lit, & schema) ? ;
2957
2951
let batch = RecordBatch :: try_new ( schema, data) ?;
2958
2952
let result = arithmetic_op. evaluate ( & batch) ?. into_array ( batch. num_rows ( ) ) ;
2959
2953
@@ -2968,16 +2962,10 @@ mod tests {
2968
2962
op : Operator ,
2969
2963
expected : BooleanArray ,
2970
2964
) -> Result < ( ) > {
2971
- let left_type = left. data_type ( ) ;
2972
- let right_type = right. data_type ( ) ;
2973
- let ( lhs, rhs) = get_input_types ( left_type, & op, right_type) ?;
2974
-
2975
- let left_expr = try_cast ( col ( "a" , schema) ?, schema, lhs) ?;
2976
- let right_expr = try_cast ( col ( "b" , schema) ?, schema, rhs) ?;
2977
- let arithmetic_op = binary_simple ( left_expr, op, right_expr, schema) ;
2965
+ let op = binary_op ( col ( "a" , schema) ?, op, col ( "b" , schema) ?, schema) ?;
2978
2966
let data: Vec < ArrayRef > = vec ! [ left. clone( ) , right. clone( ) ] ;
2979
2967
let batch = RecordBatch :: try_new ( schema. clone ( ) , data) ?;
2980
- let result = arithmetic_op . evaluate ( & batch) ?. into_array ( batch. num_rows ( ) ) ;
2968
+ let result = op . evaluate ( & batch) ?. into_array ( batch. num_rows ( ) ) ;
2981
2969
2982
2970
assert_eq ! ( result. as_ref( ) , & expected) ;
2983
2971
Ok ( ( ) )
@@ -2992,14 +2980,9 @@ mod tests {
2992
2980
expected : & BooleanArray ,
2993
2981
) -> Result < ( ) > {
2994
2982
let scalar = lit ( scalar. clone ( ) ) ;
2995
- let ( lhs, rhs) =
2996
- get_input_types ( & scalar. data_type ( schema) ?, & op, arr. data_type ( ) ) ?;
2997
- let left_expr = try_cast ( scalar, schema, lhs) ?;
2998
- let right_expr = try_cast ( col ( "a" , schema) ?, schema, rhs) ?;
2999
-
3000
- let arithmetic_op = binary_simple ( left_expr, op, right_expr, schema) ;
2983
+ let op = binary_op ( scalar, op, col ( "a" , schema) ?, schema) ?;
3001
2984
let batch = RecordBatch :: try_new ( Arc :: clone ( schema) , vec ! [ Arc :: clone( arr) ] ) ?;
3002
- let result = arithmetic_op . evaluate ( & batch) ?. into_array ( batch. num_rows ( ) ) ;
2985
+ let result = op . evaluate ( & batch) ?. into_array ( batch. num_rows ( ) ) ;
3003
2986
assert_eq ! ( result. as_ref( ) , expected) ;
3004
2987
3005
2988
Ok ( ( ) )
@@ -3014,14 +2997,9 @@ mod tests {
3014
2997
expected : & BooleanArray ,
3015
2998
) -> Result < ( ) > {
3016
2999
let scalar = lit ( scalar. clone ( ) ) ;
3017
- let ( lhs, rhs) =
3018
- get_input_types ( arr. data_type ( ) , & op, & scalar. data_type ( schema) ?) ?;
3019
- let left_expr = try_cast ( col ( "a" , schema) ?, schema, lhs) ?;
3020
- let right_expr = try_cast ( scalar, schema, rhs) ?;
3021
-
3022
- let arithmetic_op = binary_simple ( left_expr, op, right_expr, schema) ;
3000
+ let op = binary_op ( col ( "a" , schema) ?, op, scalar, schema) ?;
3023
3001
let batch = RecordBatch :: try_new ( Arc :: clone ( schema) , vec ! [ Arc :: clone( arr) ] ) ?;
3024
- let result = arithmetic_op . evaluate ( & batch) ?. into_array ( batch. num_rows ( ) ) ;
3002
+ let result = op . evaluate ( & batch) ?. into_array ( batch. num_rows ( ) ) ;
3025
3003
assert_eq ! ( result. as_ref( ) , expected) ;
3026
3004
3027
3005
Ok ( ( ) )
@@ -3587,7 +3565,7 @@ mod tests {
3587
3565
let tree_depth: i32 = 100 ;
3588
3566
let expr = ( 0 ..tree_depth)
3589
3567
. map ( |_| col ( "a" , schema. as_ref ( ) ) . unwrap ( ) )
3590
- . reduce ( |l, r| binary_simple ( l, Operator :: Plus , r, & schema) )
3568
+ . reduce ( |l, r| binary ( l, Operator :: Plus , r, & schema) . unwrap ( ) )
3591
3569
. unwrap ( ) ;
3592
3570
3593
3571
let result = expr
@@ -4069,26 +4047,7 @@ mod tests {
4069
4047
op : Operator ,
4070
4048
expected : ArrayRef ,
4071
4049
) -> Result < ( ) > {
4072
- let ( lhs_type, rhs_type) =
4073
- get_input_types ( left. data_type ( ) , & op, right. data_type ( ) ) . unwrap ( ) ;
4074
-
4075
- let left_expr = try_cast ( col ( "a" , schema) ?, schema, lhs_type. clone ( ) ) ?;
4076
- let right_expr = try_cast ( col ( "b" , schema) ?, schema, rhs_type. clone ( ) ) ?;
4077
-
4078
- let coerced_schema = Schema :: new ( vec ! [
4079
- Field :: new(
4080
- schema. field( 0 ) . name( ) ,
4081
- lhs_type,
4082
- schema. field( 0 ) . is_nullable( ) ,
4083
- ) ,
4084
- Field :: new(
4085
- schema. field( 1 ) . name( ) ,
4086
- rhs_type,
4087
- schema. field( 1 ) . is_nullable( ) ,
4088
- ) ,
4089
- ] ) ;
4090
-
4091
- let arithmetic_op = binary_simple ( left_expr, op, right_expr, & coerced_schema) ;
4050
+ let arithmetic_op = binary_op ( col ( "a" , schema) ?, op, col ( "b" , schema) ?, schema) ?;
4092
4051
let data: Vec < ArrayRef > = vec ! [ left. clone( ) , right. clone( ) ] ;
4093
4052
let batch = RecordBatch :: try_new ( schema. clone ( ) , data) ?;
4094
4053
let result = arithmetic_op. evaluate ( & batch) ?. into_array ( batch. num_rows ( ) ) ;
@@ -4761,12 +4720,12 @@ mod tests {
4761
4720
4762
4721
// expression: "a >= 25"
4763
4722
let a = col ( "a" , & schema) . unwrap ( ) ;
4764
- let gt = binary_simple (
4723
+ let gt = binary (
4765
4724
a. clone ( ) ,
4766
4725
Operator :: GtEq ,
4767
4726
lit ( ScalarValue :: from ( 25 ) ) ,
4768
4727
& schema,
4769
- ) ;
4728
+ ) ? ;
4770
4729
4771
4730
let context = AnalysisContext :: from_statistics ( & schema, & statistics) ;
4772
4731
let predicate_boundaries = gt
@@ -4790,12 +4749,12 @@ mod tests {
4790
4749
4791
4750
// expression: "50 >= a"
4792
4751
let a = col ( "a" , & schema) . unwrap ( ) ;
4793
- let gt = binary_simple (
4752
+ let gt = binary (
4794
4753
lit ( ScalarValue :: from ( 50 ) ) ,
4795
4754
Operator :: GtEq ,
4796
4755
a. clone ( ) ,
4797
4756
& schema,
4798
- ) ;
4757
+ ) ? ;
4799
4758
4800
4759
let context = AnalysisContext :: from_statistics ( & schema, & statistics) ;
4801
4760
let predicate_boundaries = gt
0 commit comments