@@ -344,7 +344,16 @@ impl CaseExpr {
344
344
fn case_column_or_null ( & self , batch : & RecordBatch ) -> Result < ColumnarValue > {
345
345
let when_expr = & self . when_then_expr [ 0 ] . 0 ;
346
346
let then_expr = & self . when_then_expr [ 0 ] . 1 ;
347
- if let ColumnarValue :: Array ( bit_mask) = when_expr. evaluate ( batch) ? {
347
+
348
+ let when_expr_value = when_expr. evaluate ( batch) ?;
349
+ let when_expr_value = match when_expr_value {
350
+ ColumnarValue :: Scalar ( _) => {
351
+ ColumnarValue :: Array ( when_expr_value. into_array ( batch. num_rows ( ) ) ?)
352
+ }
353
+ other => other,
354
+ } ;
355
+
356
+ if let ColumnarValue :: Array ( bit_mask) = when_expr_value {
348
357
let bit_mask = bit_mask
349
358
. as_any ( )
350
359
. downcast_ref :: < BooleanArray > ( )
@@ -896,6 +905,53 @@ mod tests {
896
905
Ok ( ( ) )
897
906
}
898
907
908
+ #[ test]
909
+ fn case_with_scalar_predicate ( ) -> Result < ( ) > {
910
+ let batch = case_test_batch_nulls ( ) ?;
911
+ let schema = batch. schema ( ) ;
912
+
913
+ // SELECT CASE WHEN TRUE THEN load4 END
914
+ let when = lit ( true ) ;
915
+ let then = col ( "load4" , & schema) ?;
916
+ let expr = generate_case_when_with_type_coercion (
917
+ None ,
918
+ vec ! [ ( when, then) ] ,
919
+ None ,
920
+ schema. as_ref ( ) ,
921
+ ) ?;
922
+
923
+ // many rows
924
+ let result = expr
925
+ . evaluate ( & batch) ?
926
+ . into_array ( batch. num_rows ( ) )
927
+ . expect ( "Failed to convert to array" ) ;
928
+ let result =
929
+ as_float64_array ( & result) . expect ( "failed to downcast to Float64Array" ) ;
930
+ let expected = & Float64Array :: from ( vec ! [
931
+ Some ( 1.77 ) ,
932
+ None ,
933
+ None ,
934
+ Some ( 1.78 ) ,
935
+ None ,
936
+ Some ( 1.77 ) ,
937
+ ] ) ;
938
+ assert_eq ! ( expected, result) ;
939
+
940
+ // one row
941
+ let expected = Float64Array :: from ( vec ! [ Some ( 1.1 ) ] ) ;
942
+ let batch =
943
+ RecordBatch :: try_new ( Arc :: clone ( & schema) , vec ! [ Arc :: new( expected. clone( ) ) ] ) ?;
944
+ let result = expr
945
+ . evaluate ( & batch) ?
946
+ . into_array ( batch. num_rows ( ) )
947
+ . expect ( "Failed to convert to array" ) ;
948
+ let result =
949
+ as_float64_array ( & result) . expect ( "failed to downcast to Float64Array" ) ;
950
+ assert_eq ! ( & expected, result) ;
951
+
952
+ Ok ( ( ) )
953
+ }
954
+
899
955
#[ test]
900
956
fn case_expr_matches_and_nulls ( ) -> Result < ( ) > {
901
957
let batch = case_test_batch_nulls ( ) ?;
0 commit comments