@@ -344,7 +344,16 @@ impl CaseExpr {
344
344
fn case_column_or_null ( & self , batch : & RecordBatch ) -> Result < ColumnarValue > {
345
345
let when_expr = & self . when_then_expr [ 0 ] . 0 ;
346
346
let then_expr = & self . when_then_expr [ 0 ] . 1 ;
347
- if let ColumnarValue :: Array ( bit_mask) = when_expr. evaluate ( batch) ? {
347
+
348
+ let when_expr_value = when_expr. evaluate ( batch) ?;
349
+ let when_expr_value = match when_expr_value {
350
+ ColumnarValue :: Scalar ( _) => {
351
+ ColumnarValue :: Array ( when_expr_value. into_array ( batch. num_rows ( ) ) ?)
352
+ }
353
+ other => other,
354
+ } ;
355
+
356
+ if let ColumnarValue :: Array ( bit_mask) = when_expr_value {
348
357
let bit_mask = bit_mask
349
358
. as_any ( )
350
359
. downcast_ref :: < BooleanArray > ( )
@@ -896,6 +905,47 @@ mod tests {
896
905
Ok ( ( ) )
897
906
}
898
907
908
+ #[ test]
909
+ fn case_with_scalar_predicate ( ) -> Result < ( ) > {
910
+ let batch = case_test_batch_nulls ( ) ?;
911
+ let schema = batch. schema ( ) ;
912
+
913
+ // SELECT CASE WHEN TRUE THEN load4 END
914
+ let when = lit ( true ) ;
915
+ let then = col ( "load4" , & schema) ?;
916
+ let expr = generate_case_when_with_type_coercion (
917
+ None ,
918
+ vec ! [ ( when, then) ] ,
919
+ None ,
920
+ schema. as_ref ( ) ,
921
+ ) ?;
922
+
923
+ // many rows
924
+ let result = expr
925
+ . evaluate ( & batch) ?
926
+ . into_array ( batch. num_rows ( ) )
927
+ . expect ( "Failed to convert to array" ) ;
928
+ let result =
929
+ as_float64_array ( & result) . expect ( "failed to downcast to Float64Array" ) ;
930
+ let expected =
931
+ & Float64Array :: from ( vec ! [ Some ( 1.77 ) , None , None , Some ( 1.78 ) , None , Some ( 1.77 ) ] ) ;
932
+ assert_eq ! ( expected, result) ;
933
+
934
+ // one row
935
+ let expected = Float64Array :: from ( vec ! [ Some ( 1.1 ) ] ) ;
936
+ let batch = RecordBatch :: try_new (
937
+ Arc :: clone ( & schema) , vec ! [ Arc :: new( expected. clone( ) ) ] ) ?;
938
+ let result = expr
939
+ . evaluate ( & batch) ?
940
+ . into_array ( batch. num_rows ( ) )
941
+ . expect ( "Failed to convert to array" ) ;
942
+ let result =
943
+ as_float64_array ( & result) . expect ( "failed to downcast to Float64Array" ) ;
944
+ assert_eq ! ( & expected, result) ;
945
+
946
+ Ok ( ( ) )
947
+ }
948
+
899
949
#[ test]
900
950
fn case_expr_matches_and_nulls ( ) -> Result < ( ) > {
901
951
let batch = case_test_batch_nulls ( ) ?;
0 commit comments