Skip to content

Commit f187706

Browse files
committed
fix: Make CASE expr use first non-Null type
1 parent b5de6f5 commit f187706

File tree

2 files changed

+37
-10
lines changed

2 files changed

+37
-10
lines changed

datafusion/core/tests/sql/expr.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,11 @@ async fn case_when_else_null() -> Result<()> {
7171
let sql = "SELECT CASE 'test' WHEN 'int4' THEN NULL ELSE 100 END";
7272
let actual = execute_to_batches(&ctx, sql).await;
7373
let expected = vec![
74-
"+-------------------------------------------------------------------------+",
75-
"| CASE Utf8(\"test\") WHEN Utf8(\"int4\") THEN Utf8(NULL) ELSE Int64(100) END |",
76-
"+-------------------------------------------------------------------------+",
77-
"| 100 |",
78-
"+-------------------------------------------------------------------------+",
74+
"+-------------------------------------------------------------------+",
75+
"| CASE Utf8(\"test\") WHEN Utf8(\"int4\") THEN NULL ELSE Int64(100) END |",
76+
"+-------------------------------------------------------------------+",
77+
"| 100 |",
78+
"+-------------------------------------------------------------------+",
7979
];
8080
assert_batches_eq!(expected, &actual);
8181
Ok(())

datafusion/physical-expr/src/expressions/case.rs

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,12 @@ impl CaseExpr {
109109

110110
macro_rules! if_then_else {
111111
($BUILDER_TYPE:ty, $ARRAY_TYPE:ty, $BOOLS:expr, $TRUE:expr, $FALSE:expr) => {{
112-
let true_values = $TRUE
112+
let true_values = if $TRUE.data_type() == &DataType::Null {
113+
Arc::new(<$ARRAY_TYPE>::from(vec![None; $TRUE.len()]))
114+
} else {
115+
$TRUE
116+
};
117+
let true_values = true_values
113118
.as_ref()
114119
.as_any()
115120
.downcast_ref::<$ARRAY_TYPE>()
@@ -118,7 +123,12 @@ macro_rules! if_then_else {
118123
stringify!($ARRAY_TYPE)
119124
));
120125

121-
let false_values = $FALSE
126+
let false_values = if $FALSE.data_type() == &DataType::Null {
127+
Arc::new(<$ARRAY_TYPE>::from(vec![None; $FALSE.len()]))
128+
} else {
129+
$FALSE
130+
};
131+
let false_values = false_values
122132
.as_ref()
123133
.as_any()
124134
.downcast_ref::<$ARRAY_TYPE>()
@@ -252,6 +262,23 @@ fn if_then_else(
252262
}
253263

254264
impl CaseExpr {
265+
/// This function returns the return type of CASE expression.
266+
///
267+
/// The first non-Null THEN expr type is returned; if there are none, ELSE type is returned.
268+
/// In the abscense of ELSE, Null is returned.
269+
fn return_type(&self, schema: &Schema) -> Result<DataType> {
270+
for (_, then) in self.when_then_expr.iter() {
271+
match then.data_type(schema)? {
272+
DataType::Null => continue,
273+
dt => return Ok(dt),
274+
};
275+
}
276+
if let Some(else_expr) = &self.else_expr {
277+
return else_expr.data_type(schema);
278+
}
279+
Ok(DataType::Null)
280+
}
281+
255282
/// This function evaluates the form of CASE that matches an expression to fixed values.
256283
///
257284
/// CASE expression
@@ -260,7 +287,7 @@ impl CaseExpr {
260287
/// [ELSE result]
261288
/// END
262289
fn case_when_with_expr(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
263-
let return_type = self.when_then_expr[0].1.data_type(&batch.schema())?;
290+
let return_type = self.return_type(&batch.schema())?;
264291
let expr = self.expr.as_ref().unwrap();
265292
let base_value = expr.evaluate(batch)?;
266293
let base_value = base_value.into_array(batch.num_rows());
@@ -339,7 +366,7 @@ impl CaseExpr {
339366
/// [ELSE result]
340367
/// END
341368
fn case_when_no_expr(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
342-
let return_type = self.when_then_expr[0].1.data_type(&batch.schema())?;
369+
let return_type = self.return_type(&batch.schema())?;
343370

344371
// start with nulls as default output
345372
let mut current_value = new_null_array(&return_type, batch.num_rows());
@@ -392,7 +419,7 @@ impl PhysicalExpr for CaseExpr {
392419
}
393420

394421
fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
395-
self.when_then_expr[0].1.data_type(input_schema)
422+
self.return_type(input_schema)
396423
}
397424

398425
fn nullable(&self, input_schema: &Schema) -> Result<bool> {

0 commit comments

Comments
 (0)