Skip to content

Commit 0dc81b4

Browse files
committed
fix: handle scalar predicates in CASE expressions to prevent internal errors for InfallibleExprOrNull eval method
1 parent e8ad5e5 commit 0dc81b4

File tree

1 file changed

+57
-1
lines changed
  • datafusion/physical-expr/src/expressions

1 file changed

+57
-1
lines changed

datafusion/physical-expr/src/expressions/case.rs

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,16 @@ impl CaseExpr {
344344
fn case_column_or_null(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
345345
let when_expr = &self.when_then_expr[0].0;
346346
let then_expr = &self.when_then_expr[0].1;
347-
if let ColumnarValue::Array(bit_mask) = when_expr.evaluate(batch)? {
347+
348+
let when_expr_value = when_expr.evaluate(batch)?;
349+
let when_expr_value = match when_expr_value {
350+
ColumnarValue::Scalar(_) => {
351+
ColumnarValue::Array(when_expr_value.into_array(batch.num_rows())?)
352+
}
353+
other => other,
354+
};
355+
356+
if let ColumnarValue::Array(bit_mask) = when_expr_value {
348357
let bit_mask = bit_mask
349358
.as_any()
350359
.downcast_ref::<BooleanArray>()
@@ -896,6 +905,53 @@ mod tests {
896905
Ok(())
897906
}
898907

908+
#[test]
909+
fn case_with_scalar_predicate() -> Result<()> {
910+
let batch = case_test_batch_nulls()?;
911+
let schema = batch.schema();
912+
913+
// SELECT CASE WHEN TRUE THEN load4 END
914+
let when = lit(true);
915+
let then = col("load4", &schema)?;
916+
let expr = generate_case_when_with_type_coercion(
917+
None,
918+
vec![(when, then)],
919+
None,
920+
schema.as_ref(),
921+
)?;
922+
923+
// many rows
924+
let result = expr
925+
.evaluate(&batch)?
926+
.into_array(batch.num_rows())
927+
.expect("Failed to convert to array");
928+
let result =
929+
as_float64_array(&result).expect("failed to downcast to Float64Array");
930+
let expected = &Float64Array::from(vec![
931+
Some(1.77),
932+
None,
933+
None,
934+
Some(1.78),
935+
None,
936+
Some(1.77),
937+
]);
938+
assert_eq!(expected, result);
939+
940+
// one row
941+
let expected = Float64Array::from(vec![Some(1.1)]);
942+
let batch =
943+
RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(expected.clone())])?;
944+
let result = expr
945+
.evaluate(&batch)?
946+
.into_array(batch.num_rows())
947+
.expect("Failed to convert to array");
948+
let result =
949+
as_float64_array(&result).expect("failed to downcast to Float64Array");
950+
assert_eq!(&expected, result);
951+
952+
Ok(())
953+
}
954+
899955
#[test]
900956
fn case_expr_matches_and_nulls() -> Result<()> {
901957
let batch = case_test_batch_nulls()?;

0 commit comments

Comments
 (0)