Skip to content

Commit d5ea583

Browse files
committed
fix: handle scalar predicates in CASE expressions to prevent internal errors for InfallibleExprOrNull eval method
1 parent e8ad5e5 commit d5ea583

File tree

1 file changed

+51
-1
lines changed
  • datafusion/physical-expr/src/expressions

1 file changed

+51
-1
lines changed

datafusion/physical-expr/src/expressions/case.rs

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,16 @@ impl CaseExpr {
344344
fn case_column_or_null(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
345345
let when_expr = &self.when_then_expr[0].0;
346346
let then_expr = &self.when_then_expr[0].1;
347-
if let ColumnarValue::Array(bit_mask) = when_expr.evaluate(batch)? {
347+
348+
let when_expr_value = when_expr.evaluate(batch)?;
349+
let when_expr_value = match when_expr_value {
350+
ColumnarValue::Scalar(_) => {
351+
ColumnarValue::Array(when_expr_value.into_array(batch.num_rows())?)
352+
}
353+
other => other,
354+
};
355+
356+
if let ColumnarValue::Array(bit_mask) = when_expr_value {
348357
let bit_mask = bit_mask
349358
.as_any()
350359
.downcast_ref::<BooleanArray>()
@@ -896,6 +905,47 @@ mod tests {
896905
Ok(())
897906
}
898907

908+
#[test]
909+
fn case_with_scalar_predicate() -> Result<()> {
910+
let batch = case_test_batch_nulls()?;
911+
let schema = batch.schema();
912+
913+
// SELECT CASE WHEN TRUE THEN load4 END
914+
let when = lit(true);
915+
let then = col("load4", &schema)?;
916+
let expr = generate_case_when_with_type_coercion(
917+
None,
918+
vec![(when, then)],
919+
None,
920+
schema.as_ref(),
921+
)?;
922+
923+
// many rows
924+
let result = expr
925+
.evaluate(&batch)?
926+
.into_array(batch.num_rows())
927+
.expect("Failed to convert to array");
928+
let result =
929+
as_float64_array(&result).expect("failed to downcast to Float64Array");
930+
let expected =
931+
&Float64Array::from(vec![Some(1.77), None, None, Some(1.78), None, Some(1.77)]);
932+
assert_eq!(expected, result);
933+
934+
// one row
935+
let expected = Float64Array::from(vec![Some(1.1)]);
936+
let batch = RecordBatch::try_new(
937+
Arc::clone(&schema), vec![Arc::new(expected.clone())])?;
938+
let result = expr
939+
.evaluate(&batch)?
940+
.into_array(batch.num_rows())
941+
.expect("Failed to convert to array");
942+
let result =
943+
as_float64_array(&result).expect("failed to downcast to Float64Array");
944+
assert_eq!(&expected, result);
945+
946+
Ok(())
947+
}
948+
899949
#[test]
900950
fn case_expr_matches_and_nulls() -> Result<()> {
901951
let batch = case_test_batch_nulls()?;

0 commit comments

Comments
 (0)