Skip to content

Commit 88a911b

Browse files
committed
#17801 Improve nullability reporting of case expressions
1 parent 5bbdb7e commit 88a911b

File tree

3 files changed

+75
-6
lines changed

3 files changed

+75
-6
lines changed

datafusion/core/tests/tpcds_planning.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1052,9 +1052,10 @@ async fn regression_test(query_no: u8, create_physical: bool) -> Result<()> {
10521052
for sql in &sql {
10531053
let df = ctx.sql(sql).await?;
10541054
let (state, plan) = df.into_parts();
1055-
let plan = state.optimize(&plan)?;
10561055
if create_physical {
10571056
let _ = state.create_physical_plan(&plan).await?;
1057+
} else {
1058+
let _ = state.optimize(&plan)?;
10581059
}
10591060
}
10601061

datafusion/expr/src/expr_schema.rs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ use datafusion_common::{
3232
not_impl_err, plan_datafusion_err, plan_err, Column, DataFusionError, ExprSchema,
3333
Result, Spans, TableReference,
3434
};
35+
use datafusion_expr_common::operator::Operator;
3536
use datafusion_expr_common::type_coercion::binary::BinaryTypeCoercer;
3637
use datafusion_functions_window_common::field::WindowUDFFieldArgs;
3738
use std::sync::Arc;
@@ -283,6 +284,9 @@ impl ExprSchemable for Expr {
283284
let then_nullable = case
284285
.when_then_expr
285286
.iter()
287+
.filter(|(w, t)| {
288+
!always_false_when_value_is_null(w, t).unwrap_or(false)
289+
})
286290
.map(|(_, t)| t.nullable(input_schema))
287291
.collect::<Result<Vec<_>>>()?;
288292
if then_nullable.contains(&true) {
@@ -647,6 +651,33 @@ impl ExprSchemable for Expr {
647651
}
648652
}
649653

654+
fn always_false_when_value_is_null(predicate: &Expr, value: &Expr) -> Option<bool> {
655+
match predicate {
656+
Expr::IsNotNull(e) => Some(e.as_ref().eq(value)),
657+
Expr::Not(e) => always_false_when_value_is_null(e, value).map(|b| !b),
658+
Expr::BinaryExpr(BinaryExpr { left, op, right }) => match op {
659+
Operator::And => {
660+
let l = always_false_when_value_is_null(left, value);
661+
let r = always_false_when_value_is_null(right, value);
662+
match (l, r) {
663+
(Some(l), Some(r)) => Some(l || r),
664+
_ => None,
665+
}
666+
}
667+
Operator::Or => {
668+
let l = always_false_when_value_is_null(left, value);
669+
let r = always_false_when_value_is_null(right, value);
670+
match (l, r) {
671+
(Some(l), Some(r)) => Some(l && r),
672+
_ => None,
673+
}
674+
}
675+
_ => None,
676+
},
677+
_ => None,
678+
}
679+
}
680+
650681
impl Expr {
651682
/// Common method for window functions that applies type coercion
652683
/// to all arguments of the window function to check if it matches

datafusion/physical-expr/src/expressions/case.rs

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,8 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use crate::expressions::try_cast;
18+
use crate::expressions::{try_cast, BinaryExpr, IsNotNullExpr, NotExpr};
1919
use crate::PhysicalExpr;
20-
use std::borrow::Cow;
21-
use std::hash::Hash;
22-
use std::{any::Any, sync::Arc};
23-
2420
use arrow::array::*;
2521
use arrow::compute::kernels::zip::zip;
2622
use arrow::compute::{and, and_not, is_null, not, nullif, or, prep_null_mask_filter};
@@ -30,8 +26,12 @@ use datafusion_common::{
3026
exec_err, internal_datafusion_err, internal_err, DataFusionError, Result, ScalarValue,
3127
};
3228
use datafusion_expr::ColumnarValue;
29+
use std::borrow::Cow;
30+
use std::hash::Hash;
31+
use std::{any::Any, sync::Arc};
3332

3433
use super::{Column, Literal};
34+
use datafusion_expr_common::operator::Operator;
3535
use datafusion_physical_expr_common::datum::compare_with_eq;
3636
use itertools::Itertools;
3737

@@ -481,6 +481,9 @@ impl PhysicalExpr for CaseExpr {
481481
let then_nullable = self
482482
.when_then_expr
483483
.iter()
484+
.filter(|(w, t)| {
485+
!always_false_when_value_is_null(w.as_ref(), t.as_ref()).unwrap_or(false)
486+
})
484487
.map(|(_, t)| t.nullable(input_schema))
485488
.collect::<Result<Vec<_>>>()?;
486489
if then_nullable.contains(&true) {
@@ -588,6 +591,40 @@ impl PhysicalExpr for CaseExpr {
588591
}
589592
}
590593

594+
fn always_false_when_value_is_null(
595+
predicate: &dyn PhysicalExpr,
596+
value: &dyn PhysicalExpr,
597+
) -> Option<bool> {
598+
let predicate_any = predicate.as_any();
599+
if let Some(not_null) = predicate_any.downcast_ref::<IsNotNullExpr>() {
600+
Some(not_null.arg().as_ref().dyn_eq(value))
601+
} else if let Some(not) = predicate_any.downcast_ref::<NotExpr>() {
602+
always_false_when_value_is_null(not.arg().as_ref(), value).map(|b| !b)
603+
} else if let Some(binary) = predicate_any.downcast_ref::<BinaryExpr>() {
604+
match binary.op() {
605+
Operator::And => {
606+
let l = always_false_when_value_is_null(binary.left().as_ref(), value);
607+
let r = always_false_when_value_is_null(binary.right().as_ref(), value);
608+
match (l, r) {
609+
(Some(l), Some(r)) => Some(l || r),
610+
_ => None,
611+
}
612+
}
613+
Operator::Or => {
614+
let l = always_false_when_value_is_null(binary.left().as_ref(), value);
615+
let r = always_false_when_value_is_null(binary.right().as_ref(), value);
616+
match (l, r) {
617+
(Some(l), Some(r)) => Some(l && r),
618+
_ => None,
619+
}
620+
}
621+
_ => None,
622+
}
623+
} else {
624+
None
625+
}
626+
}
627+
591628
/// Create a CASE expression
592629
pub fn case(
593630
expr: Option<Arc<dyn PhysicalExpr>>,

0 commit comments

Comments
 (0)