diff --git a/datafusion/core/tests/tpcds_planning.rs b/datafusion/core/tests/tpcds_planning.rs index 252d76d0f9d9..3ad74962bc2c 100644 --- a/datafusion/core/tests/tpcds_planning.rs +++ b/datafusion/core/tests/tpcds_planning.rs @@ -1052,9 +1052,12 @@ async fn regression_test(query_no: u8, create_physical: bool) -> Result<()> { for sql in &sql { let df = ctx.sql(sql).await?; let (state, plan) = df.into_parts(); - let plan = state.optimize(&plan)?; if create_physical { let _ = state.create_physical_plan(&plan).await?; + } else { + // Run the logical optimizer even if we are not creating the physical plan + // to ensure it will properly succeed + let _ = state.optimize(&plan)?; } } diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 4666411dd540..98dd5234345f 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -341,6 +341,11 @@ pub fn is_null(expr: Expr) -> Expr { Expr::IsNull(Box::new(expr)) } +/// Create is not null expression +pub fn is_not_null(expr: Expr) -> Expr { + Expr::IsNotNull(Box::new(expr)) +} + /// Create is true expression pub fn is_true(expr: Expr) -> Expr { Expr::IsTrue(Box::new(expr)) diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index e803e3534130..aab01ad400af 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -15,12 +15,13 @@ // specific language governing permissions and limitations // under the License. -use super::{Between, Expr, Like}; +use super::{predicate_eval, Between, Expr, Like}; use crate::expr::{ AggregateFunction, AggregateFunctionParams, Alias, BinaryExpr, Cast, FieldMetadata, InList, InSubquery, Placeholder, ScalarFunction, TryCast, Unnest, WindowFunction, WindowFunctionParams, }; +use crate::predicate_eval::TriStateBool; use crate::type_coercion::functions::{ data_types_with_scalar_udf, fields_with_aggregate_udf, fields_with_window_udf, }; @@ -279,13 +280,50 @@ impl ExprSchemable for Expr { Expr::OuterReferenceColumn(field, _) => Ok(field.is_nullable()), Expr::Literal(value, _) => Ok(value.is_null()), Expr::Case(case) => { - // This expression is nullable if any of the input expressions are nullable - let then_nullable = case + // This expression is nullable if any of the then expressions are nullable + let any_nullable_thens = !case .when_then_expr .iter() - .map(|(_, t)| t.nullable(input_schema)) - .collect::>>()?; - if then_nullable.contains(&true) { + .filter_map(|(w, t)| { + match t.nullable(input_schema) { + // Branches with a then expression that is not nullable can be skipped + Ok(false) => None, + // Pass error determining nullability on verbatim + Err(e) => Some(Err(e)), + // For branches with a nullable then expressions try to determine + // using limited const evaluation if the branch will be taken when + // the then expression evaluates to null. + Ok(true) => { + let const_result = predicate_eval::const_eval_predicate( + w, + input_schema, + |expr| { + if expr.eq(t) { + TriStateBool::True + } else { + TriStateBool::Uncertain + } + }, + ); + + match const_result { + // Const evaluation was inconclusive or determined the branch + // would be taken + None | Some(TriStateBool::True) => Some(Ok(())), + // Const evaluation proves the branch will never be taken. + // The most common pattern for this is + // `WHEN x IS NOT NULL THEN x`. + Some(TriStateBool::False) + | Some(TriStateBool::Uncertain) => None, + } + } + } + }) + .collect::>>()? + .is_empty(); + + if any_nullable_thens { + // There is at least one reachable nullable then Ok(true) } else if let Some(e) = &case.else_expr { e.nullable(input_schema) @@ -777,7 +815,7 @@ pub fn cast_subquery(subquery: Subquery, cast_to_type: &DataType) -> Result Result<()> { + let nullable_schema = MockExprSchema::new() + .with_data_type(DataType::Int32) + .with_nullable(true); + + let not_nullable_schema = MockExprSchema::new() + .with_data_type(DataType::Int32) + .with_nullable(false); + + // CASE WHEN x IS NOT NULL THEN x ELSE 0 + let e = when(col("x").is_not_null(), col("x")).otherwise(lit(0))?; + assert_not_nullable(&e, &nullable_schema); + assert_not_nullable(&e, ¬_nullable_schema); + + // CASE WHEN NOT x IS NULL THEN x ELSE 0 + let e = when(not(col("x").is_null()), col("x")).otherwise(lit(0))?; + assert_not_nullable(&e, &nullable_schema); + assert_not_nullable(&e, ¬_nullable_schema); + + // CASE WHEN X = 5 THEN x ELSE 0 + let e = when(col("x").eq(lit(5)), col("x")).otherwise(lit(0))?; + assert_not_nullable(&e, &nullable_schema); + assert_not_nullable(&e, ¬_nullable_schema); + + // CASE WHEN x IS NOT NULL AND x = 5 THEN x ELSE 0 + let e = when(and(col("x").is_not_null(), col("x").eq(lit(5))), col("x")) + .otherwise(lit(0))?; + assert_not_nullable(&e, &nullable_schema); + assert_not_nullable(&e, ¬_nullable_schema); + + // CASE WHEN x = 5 AND x IS NOT NULL THEN x ELSE 0 + let e = when(and(col("x").eq(lit(5)), col("x").is_not_null()), col("x")) + .otherwise(lit(0))?; + assert_not_nullable(&e, &nullable_schema); + assert_not_nullable(&e, ¬_nullable_schema); + + // CASE WHEN x IS NOT NULL OR x = 5 THEN x ELSE 0 + let e = when(or(col("x").is_not_null(), col("x").eq(lit(5))), col("x")) + .otherwise(lit(0))?; + assert_not_nullable(&e, &nullable_schema); + assert_not_nullable(&e, ¬_nullable_schema); + + // CASE WHEN x = 5 OR x IS NOT NULL THEN x ELSE 0 + let e = when(or(col("x").eq(lit(5)), col("x").is_not_null()), col("x")) + .otherwise(lit(0))?; + assert_not_nullable(&e, &nullable_schema); + assert_not_nullable(&e, ¬_nullable_schema); + + // CASE WHEN (x = 5 AND x IS NOT NULL) OR (x = bar AND x IS NOT NULL) THEN x ELSE 0 + let e = when( + or( + and(col("x").eq(lit(5)), col("x").is_not_null()), + and(col("x").eq(col("bar")), col("x").is_not_null()), + ), + col("x"), + ) + .otherwise(lit(0))?; + assert_not_nullable(&e, &nullable_schema); + assert_not_nullable(&e, ¬_nullable_schema); + + // CASE WHEN x = 5 OR x IS NULL THEN x ELSE 0 + let e = when(or(col("x").eq(lit(5)), col("x").is_null()), col("x")) + .otherwise(lit(0))?; + assert_nullable(&e, &nullable_schema); + assert_not_nullable(&e, ¬_nullable_schema); + + // CASE WHEN x IS TRUE THEN x ELSE 0 + let e = when(col("x").is_true(), col("x")).otherwise(lit(0))?; + assert_not_nullable(&e, &nullable_schema); + assert_not_nullable(&e, ¬_nullable_schema); + + // CASE WHEN x IS NOT TRUE THEN x ELSE 0 + let e = when(col("x").is_not_true(), col("x")).otherwise(lit(0))?; + assert_nullable(&e, &nullable_schema); + assert_not_nullable(&e, ¬_nullable_schema); + + // CASE WHEN x IS FALSE THEN x ELSE 0 + let e = when(col("x").is_false(), col("x")).otherwise(lit(0))?; + assert_not_nullable(&e, &nullable_schema); + assert_not_nullable(&e, ¬_nullable_schema); + + // CASE WHEN x IS NOT FALSE THEN x ELSE 0 + let e = when(col("x").is_not_false(), col("x")).otherwise(lit(0))?; + assert_nullable(&e, &nullable_schema); + assert_not_nullable(&e, ¬_nullable_schema); + + // CASE WHEN x IS UNKNOWN THEN x ELSE 0 + let e = when(col("x").is_unknown(), col("x")).otherwise(lit(0))?; + assert_nullable(&e, &nullable_schema); + assert_not_nullable(&e, ¬_nullable_schema); + + // CASE WHEN x IS NOT UNKNOWN THEN x ELSE 0 + let e = when(col("x").is_not_unknown(), col("x")).otherwise(lit(0))?; + assert_not_nullable(&e, &nullable_schema); + assert_not_nullable(&e, ¬_nullable_schema); + + // CASE WHEN x LIKE 'x' THEN x ELSE 0 + let e = when(col("x").like(lit("x")), col("x")).otherwise(lit(0))?; + assert_not_nullable(&e, &nullable_schema); + assert_not_nullable(&e, ¬_nullable_schema); + + // CASE WHEN 0 THEN x ELSE 0 + let e = when(lit(0), col("x")).otherwise(lit(0))?; + assert_not_nullable(&e, &nullable_schema); + assert_not_nullable(&e, ¬_nullable_schema); + + // CASE WHEN 1 THEN x ELSE 0 + let e = when(lit(1), col("x")).otherwise(lit(0))?; + assert_nullable(&e, &nullable_schema); + assert_not_nullable(&e, ¬_nullable_schema); + + Ok(()) + } + #[test] fn test_inlist_nullability() { let get_schema = |nullable| { diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 346d373ff5b4..2345d662895f 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -69,6 +69,7 @@ pub mod async_udf; pub mod statistics { pub use datafusion_expr_common::statistics::*; } +mod predicate_eval; pub mod ptr_eq; pub mod test; pub mod tree_node; diff --git a/datafusion/expr/src/predicate_eval.rs b/datafusion/expr/src/predicate_eval.rs new file mode 100644 index 000000000000..238e0d711a55 --- /dev/null +++ b/datafusion/expr/src/predicate_eval.rs @@ -0,0 +1,259 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::predicate_eval::TriStateBool::{False, True, Uncertain}; +use crate::{BinaryExpr, Expr, ExprSchemable}; +use arrow::datatypes::DataType; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; +use datafusion_common::{DataFusionError, ExprSchema, ScalarValue}; +use datafusion_expr_common::operator::Operator; + +/// Represents the possible values for SQL's three valued logic. +/// `Option` is not used for this since `None` is used by +/// [const_eval_predicate] to represent inconclusive answers. +pub(super) enum TriStateBool { + True, + False, + Uncertain, +} + +impl TryFrom<&ScalarValue> for TriStateBool { + type Error = DataFusionError; + + fn try_from(value: &ScalarValue) -> Result { + match value { + ScalarValue::Null => { + // Literal null is equivalent to boolean uncertain + Ok(Uncertain) + } + ScalarValue::Boolean(b) => Ok(match b { + Some(true) => True, + Some(false) => False, + None => Uncertain, + }), + _ => Self::try_from(&value.cast_to(&DataType::Boolean)?), + } + } +} + +/// Attempts to partially constant-evaluate a predicate under SQL three-valued logic. +/// +/// Semantics of the return value: +/// - `Some(True)` => predicate is provably true +/// - `Some(False)` => predicate is provably false +/// - `Some(Uncertain)` => predicate is provably unknown (i.e., can only be NULL) +/// - `None` => inconclusive with available static information +/// +/// The evaluation is conservative and only uses: +/// - Expression nullability from `input_schema` +/// - Simple type checks (e.g. whether an expression is Boolean) +/// - Syntactic patterns (IS NULL/IS NOT NULL/IS TRUE/IS FALSE/etc.) +/// - Three-valued boolean algebra for AND/OR/NOT +/// +/// It does not evaluate user-defined functions. +pub(super) fn const_eval_predicate( + predicate: &Expr, + input_schema: &dyn ExprSchema, + evaluates_to_null: F, +) -> Option +where + F: Fn(&Expr) -> TriStateBool, +{ + PredicateConstEvaluator { + input_schema, + evaluates_to_null, + } + .const_eval_predicate(predicate) +} + +pub(super) struct PredicateConstEvaluator<'a, F> { + input_schema: &'a dyn ExprSchema, + evaluates_to_null: F, +} + +impl PredicateConstEvaluator<'_, F> +where + F: Fn(&Expr) -> TriStateBool, +{ + fn const_eval_predicate(&self, predicate: &Expr) -> Option { + match predicate { + Expr::Literal(scalar, _) => TriStateBool::try_from(scalar).ok(), + Expr::IsNotNull(e) => { + if let Ok(false) = e.nullable(self.input_schema) { + // If `e` is not nullable, then `e IS NOT NULL` is always true + return Some(True); + } + + match e.get_type(self.input_schema) { + Ok(DataType::Boolean) => match self.const_eval_predicate(e) { + Some(True) | Some(False) => Some(True), + Some(Uncertain) => Some(False), + None => None, + }, + Ok(_) => match self.is_null(e) { + True => Some(False), + False => Some(True), + Uncertain => None, + }, + Err(_) => None, + } + } + Expr::IsNull(e) => { + if let Ok(false) = e.nullable(self.input_schema) { + // If `e` is not nullable, then `e IS NULL` is always false + return Some(False); + } + + match e.get_type(self.input_schema) { + Ok(DataType::Boolean) => match self.const_eval_predicate(e) { + Some(True) | Some(False) => Some(False), + Some(Uncertain) => Some(True), + None => None, + }, + Ok(_) => match self.is_null(e) { + True => Some(True), + False => Some(False), + Uncertain => None, + }, + Err(_) => None, + } + } + Expr::IsTrue(e) => match self.const_eval_predicate(e) { + Some(True) => Some(True), + Some(_) => Some(False), + None => None, + }, + Expr::IsNotTrue(e) => match self.const_eval_predicate(e) { + Some(True) => Some(False), + Some(_) => Some(True), + None => None, + }, + Expr::IsFalse(e) => match self.const_eval_predicate(e) { + Some(False) => Some(True), + Some(_) => Some(False), + None => None, + }, + Expr::IsNotFalse(e) => match self.const_eval_predicate(e) { + Some(False) => Some(False), + Some(_) => Some(True), + None => None, + }, + Expr::IsUnknown(e) => match self.const_eval_predicate(e) { + Some(Uncertain) => Some(True), + Some(_) => Some(False), + None => None, + }, + Expr::IsNotUnknown(e) => match self.const_eval_predicate(e) { + Some(Uncertain) => Some(False), + Some(_) => Some(True), + None => None, + }, + Expr::Not(e) => match self.const_eval_predicate(e) { + Some(True) => Some(False), + Some(False) => Some(True), + Some(Uncertain) => Some(Uncertain), + None => None, + }, + Expr::BinaryExpr(BinaryExpr { + left, + op: Operator::And, + right, + }) => { + match ( + self.const_eval_predicate(left), + self.const_eval_predicate(right), + ) { + (Some(False), _) | (_, Some(False)) => Some(False), + (Some(True), Some(True)) => Some(True), + (Some(Uncertain), Some(_)) | (Some(_), Some(Uncertain)) => { + Some(Uncertain) + } + _ => None, + } + } + Expr::BinaryExpr(BinaryExpr { + left, + op: Operator::Or, + right, + }) => { + match ( + self.const_eval_predicate(left), + self.const_eval_predicate(right), + ) { + (Some(True), _) | (_, Some(True)) => Some(True), + (Some(False), Some(False)) => Some(False), + (Some(Uncertain), Some(_)) | (Some(_), Some(Uncertain)) => { + Some(Uncertain) + } + _ => None, + } + } + e => match self.is_null(e) { + True => Some(Uncertain), + _ => None, + }, + } + } + + /// Determines if the given expression evaluates to `NULL`. + /// + /// This function returns: + /// - `True` if `expr` is provably `NULL` + /// - `False` if `expr` can provably not `NULL` + /// - `Uncertain` if the result is inconclusive + fn is_null(&self, expr: &Expr) -> TriStateBool { + match expr { + Expr::Literal(ScalarValue::Null, _) => { + // Literal null is obviously null + True + } + Expr::Alias(_) + | Expr::Between(_) + | Expr::BinaryExpr(_) + | Expr::Cast(_) + | Expr::Like(_) + | Expr::Negative(_) + | Expr::Not(_) + | Expr::SimilarTo(_) => { + // These expressions are null if any of their direct children is null + // If any child is inconclusive, the result for this expression is also inconclusive + let mut is_null = False; + let _ = expr.apply_children(|child| match self.is_null(child) { + True => { + is_null = True; + Ok(TreeNodeRecursion::Stop) + } + False => Ok(TreeNodeRecursion::Continue), + Uncertain => { + is_null = Uncertain; + Ok(TreeNodeRecursion::Stop) + } + }); + is_null + } + e => { + if let Ok(false) = e.nullable(self.input_schema) { + // If `expr` is not nullable, we can be certain `expr` is not null + False + } else { + // Finally, ask the callback if it knows the nullability of `expr` + (self.evaluates_to_null)(e) + } + } + } + } +} diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index d14146a20d8b..06c177d77eb8 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -15,12 +15,8 @@ // specific language governing permissions and limitations // under the License. -use crate::expressions::try_cast; +use crate::expressions::{try_cast, BinaryExpr, IsNotNullExpr, IsNullExpr, NotExpr}; use crate::PhysicalExpr; -use std::borrow::Cow; -use std::hash::Hash; -use std::{any::Any, sync::Arc}; - use arrow::array::*; use arrow::compute::kernels::zip::zip; use arrow::compute::{and, and_not, is_null, not, nullif, or, prep_null_mask_filter}; @@ -30,8 +26,12 @@ use datafusion_common::{ exec_err, internal_datafusion_err, internal_err, DataFusionError, Result, ScalarValue, }; use datafusion_expr::ColumnarValue; +use std::borrow::Cow; +use std::hash::Hash; +use std::{any::Any, sync::Arc}; use super::{Column, Literal}; +use datafusion_expr_common::operator::Operator; use datafusion_physical_expr_common::datum::compare_with_eq; use itertools::Itertools; @@ -488,6 +488,11 @@ impl PhysicalExpr for CaseExpr { let then_nullable = self .when_then_expr .iter() + .filter(|(w, t)| { + // Disregard branches where we can determine statically that the predicate + // is always false when the then expression would evaluate to null + const_result_when_value_is_null(w.as_ref(), t.as_ref()).unwrap_or(true) + }) .map(|(_, t)| t.nullable(input_schema)) .collect::>>()?; if then_nullable.contains(&true) { @@ -595,6 +600,54 @@ impl PhysicalExpr for CaseExpr { } } +/// Determines if the given `predicate` can be const evaluated if `value` were to evaluate to `NULL`. +/// Returns a `Some` value containing the const result if so; otherwise returns `None`. +fn const_result_when_value_is_null( + predicate: &dyn PhysicalExpr, + value: &dyn PhysicalExpr, +) -> Option { + let predicate_any = predicate.as_any(); + if let Some(not_null) = predicate_any.downcast_ref::() { + if not_null.arg().as_ref().dyn_eq(value) { + Some(false) + } else { + None + } + } else if let Some(null) = predicate_any.downcast_ref::() { + if null.arg().as_ref().dyn_eq(value) { + Some(true) + } else { + None + } + } else if let Some(not) = predicate_any.downcast_ref::() { + const_result_when_value_is_null(not.arg().as_ref(), value).map(|b| !b) + } else if let Some(binary) = predicate_any.downcast_ref::() { + match binary.op() { + Operator::And => { + let l = const_result_when_value_is_null(binary.left().as_ref(), value); + let r = const_result_when_value_is_null(binary.right().as_ref(), value); + match (l, r) { + (Some(l), Some(r)) => Some(l && r), + (Some(l), None) => Some(l), + (None, Some(r)) => Some(r), + _ => None, + } + } + Operator::Or => { + let l = const_result_when_value_is_null(binary.left().as_ref(), value); + let r = const_result_when_value_is_null(binary.right().as_ref(), value); + match (l, r) { + (Some(l), Some(r)) => Some(l || r), + _ => None, + } + } + _ => None, + } + } else { + None + } +} + /// Create a CASE expression pub fn case( expr: Option>, @@ -608,7 +661,8 @@ pub fn case( mod tests { use super::*; - use crate::expressions::{binary, cast, col, lit, BinaryExpr}; + use crate::expressions; + use crate::expressions::{binary, cast, col, is_not_null, lit, BinaryExpr}; use arrow::buffer::Buffer; use arrow::datatypes::DataType::Float64; use arrow::datatypes::Field; @@ -616,7 +670,6 @@ mod tests { use datafusion_common::plan_err; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_expr::type_coercion::binary::comparison_coercion; - use datafusion_expr::Operator; use datafusion_physical_expr_common::physical_expr::fmt_sql; #[test] @@ -1442,4 +1495,188 @@ mod tests { Ok(()) } + + fn when_then_else( + when: &Arc, + then: &Arc, + els: &Arc, + ) -> Result> { + let case = CaseExpr::try_new( + None, + vec![(Arc::clone(when), Arc::clone(then))], + Some(Arc::clone(els)), + )?; + Ok(Arc::new(case)) + } + + #[test] + fn test_case_expression_nullability_with_nullable_column() -> Result<()> { + case_expression_nullability(true) + } + + #[test] + fn test_case_expression_nullability_with_not_nullable_column() -> Result<()> { + case_expression_nullability(false) + } + + fn case_expression_nullability(col_is_nullable: bool) -> Result<()> { + let schema = + Schema::new(vec![Field::new("foo", DataType::Int32, col_is_nullable)]); + + let foo = col("foo", &schema)?; + let foo_is_not_null = is_not_null(Arc::clone(&foo))?; + let foo_is_null = expressions::is_null(Arc::clone(&foo))?; + let not_foo_is_null = expressions::not(Arc::clone(&foo_is_null))?; + let zero = lit(0); + let foo_eq_zero = + binary(Arc::clone(&foo), Operator::Eq, Arc::clone(&zero), &schema)?; + + assert_not_nullable(when_then_else(&foo_is_not_null, &foo, &zero)?, &schema); + assert_not_nullable(when_then_else(¬_foo_is_null, &foo, &zero)?, &schema); + assert_nullability( + when_then_else(&foo_eq_zero, &foo, &zero)?, + &schema, + col_is_nullable, + ); + + assert_not_nullable( + when_then_else( + &binary( + Arc::clone(&foo_is_not_null), + Operator::And, + Arc::clone(&foo_eq_zero), + &schema, + )?, + &foo, + &zero, + )?, + &schema, + ); + + assert_not_nullable( + when_then_else( + &binary( + Arc::clone(&foo_eq_zero), + Operator::And, + Arc::clone(&foo_is_not_null), + &schema, + )?, + &foo, + &zero, + )?, + &schema, + ); + + assert_nullability( + when_then_else( + &binary( + Arc::clone(&foo_is_not_null), + Operator::Or, + Arc::clone(&foo_eq_zero), + &schema, + )?, + &foo, + &zero, + )?, + &schema, + col_is_nullable, + ); + + assert_nullability( + when_then_else( + &binary( + Arc::clone(&foo_eq_zero), + Operator::Or, + Arc::clone(&foo_is_not_null), + &schema, + )?, + &foo, + &zero, + )?, + &schema, + col_is_nullable, + ); + + assert_nullability( + when_then_else( + &binary( + Arc::clone(&foo_is_not_null), + Operator::Or, + Arc::clone(&foo_eq_zero), + &schema, + )?, + &foo, + &zero, + )?, + &schema, + col_is_nullable, + ); + + assert_nullability( + when_then_else( + &binary( + binary(Arc::clone(&foo), Operator::Eq, Arc::clone(&zero), &schema)?, + Operator::Or, + Arc::clone(&foo_is_not_null), + &schema, + )?, + &foo, + &zero, + )?, + &schema, + col_is_nullable, + ); + + assert_not_nullable( + when_then_else( + &binary( + binary( + binary( + Arc::clone(&foo), + Operator::Eq, + Arc::clone(&zero), + &schema, + )?, + Operator::And, + Arc::clone(&foo_is_not_null), + &schema, + )?, + Operator::Or, + binary( + binary( + Arc::clone(&foo), + Operator::Eq, + Arc::clone(&foo), + &schema, + )?, + Operator::And, + Arc::clone(&foo_is_not_null), + &schema, + )?, + &schema, + )?, + &foo, + &zero, + )?, + &schema, + ); + + Ok(()) + } + + fn assert_not_nullable(expr: Arc, schema: &Schema) { + assert!(!expr.nullable(schema).unwrap()); + } + + fn assert_nullable(expr: Arc, schema: &Schema) { + assert!(expr.nullable(schema).unwrap()); + } + + fn assert_nullability(expr: Arc, schema: &Schema, nullable: bool) { + if nullable { + assert_nullable(expr, schema); + } else { + assert_not_nullable(expr, schema); + } + } }