@@ -109,7 +109,12 @@ impl CaseExpr {
109
109
110
110
macro_rules! if_then_else {
111
111
( $BUILDER_TYPE: ty, $ARRAY_TYPE: ty, $BOOLS: expr, $TRUE: expr, $FALSE: expr) => { {
112
- let true_values = $TRUE
112
+ let true_values = if $TRUE. data_type( ) == & DataType :: Null {
113
+ Arc :: new( <$ARRAY_TYPE>:: from( vec![ None ; $TRUE. len( ) ] ) )
114
+ } else {
115
+ $TRUE
116
+ } ;
117
+ let true_values = true_values
113
118
. as_ref( )
114
119
. as_any( )
115
120
. downcast_ref:: <$ARRAY_TYPE>( )
@@ -118,7 +123,12 @@ macro_rules! if_then_else {
118
123
stringify!( $ARRAY_TYPE)
119
124
) ) ;
120
125
121
- let false_values = $FALSE
126
+ let false_values = if $FALSE. data_type( ) == & DataType :: Null {
127
+ Arc :: new( <$ARRAY_TYPE>:: from( vec![ None ; $FALSE. len( ) ] ) )
128
+ } else {
129
+ $FALSE
130
+ } ;
131
+ let false_values = false_values
122
132
. as_ref( )
123
133
. as_any( )
124
134
. downcast_ref:: <$ARRAY_TYPE>( )
@@ -252,6 +262,23 @@ fn if_then_else(
252
262
}
253
263
254
264
impl CaseExpr {
265
+ /// This function returns the return type of CASE expression.
266
+ ///
267
+ /// The first non-Null THEN expr type is returned; if there are none, ELSE type is returned.
268
+ /// In the abscense of ELSE, Null is returned.
269
+ fn return_type ( & self , schema : & Schema ) -> Result < DataType > {
270
+ for ( _, then) in self . when_then_expr . iter ( ) {
271
+ match then. data_type ( schema) ? {
272
+ DataType :: Null => continue ,
273
+ dt => return Ok ( dt) ,
274
+ } ;
275
+ }
276
+ if let Some ( else_expr) = & self . else_expr {
277
+ return else_expr. data_type ( schema) ;
278
+ }
279
+ Ok ( DataType :: Null )
280
+ }
281
+
255
282
/// This function evaluates the form of CASE that matches an expression to fixed values.
256
283
///
257
284
/// CASE expression
@@ -260,7 +287,7 @@ impl CaseExpr {
260
287
/// [ELSE result]
261
288
/// END
262
289
fn case_when_with_expr ( & self , batch : & RecordBatch ) -> Result < ColumnarValue > {
263
- let return_type = self . when_then_expr [ 0 ] . 1 . data_type ( & batch. schema ( ) ) ?;
290
+ let return_type = self . return_type ( & batch. schema ( ) ) ?;
264
291
let expr = self . expr . as_ref ( ) . unwrap ( ) ;
265
292
let base_value = expr. evaluate ( batch) ?;
266
293
let base_value = base_value. into_array ( batch. num_rows ( ) ) ;
@@ -339,7 +366,7 @@ impl CaseExpr {
339
366
/// [ELSE result]
340
367
/// END
341
368
fn case_when_no_expr ( & self , batch : & RecordBatch ) -> Result < ColumnarValue > {
342
- let return_type = self . when_then_expr [ 0 ] . 1 . data_type ( & batch. schema ( ) ) ?;
369
+ let return_type = self . return_type ( & batch. schema ( ) ) ?;
343
370
344
371
// start with nulls as default output
345
372
let mut current_value = new_null_array ( & return_type, batch. num_rows ( ) ) ;
@@ -392,7 +419,7 @@ impl PhysicalExpr for CaseExpr {
392
419
}
393
420
394
421
fn data_type ( & self , input_schema : & Schema ) -> Result < DataType > {
395
- self . when_then_expr [ 0 ] . 1 . data_type ( input_schema)
422
+ self . return_type ( input_schema)
396
423
}
397
424
398
425
fn nullable ( & self , input_schema : & Schema ) -> Result < bool > {
0 commit comments