Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion datafusion/core/tests/tpcds_planning.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1052,9 +1052,12 @@ async fn regression_test(query_no: u8, create_physical: bool) -> Result<()> {
for sql in &sql {
let df = ctx.sql(sql).await?;
let (state, plan) = df.into_parts();
let plan = state.optimize(&plan)?;
if create_physical {
let _ = state.create_physical_plan(&plan).await?;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I still don't get this part. Previously the physical plan was created on top of the optimized logical plan; why is this step skipped? 🤔

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

create_physical_plan already calls optimize itself. Because the logical plan was being optimized twice, the logical vs. physical schema mismatch was being hidden.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To see what I'm talking about, comment out this line on main and run tpcds_physical_q75. You'll get

Error: Internal("Physical input schema should be the same as the one converted from logical input schema. Differences: \n\t- field nullability at index 5 [sales_cnt]: (physical) true vs (logical) false\n\t- field nullability at index 6 [sales_amt]: (physical) true vs (logical) false")

} else {
// Run the logical optimizer even if we are not creating the physical plan
// to ensure it will properly succeed
let _ = state.optimize(&plan)?;
}
}

Expand Down
5 changes: 5 additions & 0 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,11 @@ pub fn is_null(expr: Expr) -> Expr {
Expr::IsNull(Box::new(expr))
}

/// Build an `IS NOT NULL` expression whose operand is `expr`.
pub fn is_not_null(expr: Expr) -> Expr {
    let operand = Box::new(expr);
    Expr::IsNotNull(operand)
}

/// Create is true expression
pub fn is_true(expr: Expr) -> Expr {
Expr::IsTrue(Box::new(expr))
Expand Down
183 changes: 176 additions & 7 deletions datafusion/expr/src/expr_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@
// specific language governing permissions and limitations
// under the License.

use super::{Between, Expr, Like};
use super::{predicate_eval, Between, Expr, Like};
use crate::expr::{
AggregateFunction, AggregateFunctionParams, Alias, BinaryExpr, Cast, FieldMetadata,
InList, InSubquery, Placeholder, ScalarFunction, TryCast, Unnest, WindowFunction,
WindowFunctionParams,
};
use crate::predicate_eval::TriStateBool;
use crate::type_coercion::functions::{
data_types_with_scalar_udf, fields_with_aggregate_udf, fields_with_window_udf,
};
Expand Down Expand Up @@ -279,13 +280,50 @@ impl ExprSchemable for Expr {
Expr::OuterReferenceColumn(field, _) => Ok(field.is_nullable()),
Expr::Literal(value, _) => Ok(value.is_null()),
Expr::Case(case) => {
// This expression is nullable if any of the input expressions are nullable
let then_nullable = case
// This expression is nullable if any of the then expressions are nullable
let any_nullable_thens = !case
.when_then_expr
.iter()
.map(|(_, t)| t.nullable(input_schema))
.collect::<Result<Vec<_>>>()?;
if then_nullable.contains(&true) {
.filter_map(|(w, t)| {
match t.nullable(input_schema) {
// Branches with a then expression that is not nullable can be skipped
Ok(false) => None,
// Pass error determining nullability on verbatim
Err(e) => Some(Err(e)),
// For branches with a nullable then expression, try to determine
// using limited const evaluation whether the branch will be taken when
// the then expression evaluates to null.
Ok(true) => {
let const_result = predicate_eval::const_eval_predicate(
w,
input_schema,
|expr| {
if expr.eq(t) {
TriStateBool::True
} else {
TriStateBool::Uncertain
}
},
);

match const_result {
// Const evaluation was inconclusive or determined the branch
// would be taken
None | Some(TriStateBool::True) => Some(Ok(())),
// Const evaluation proves the branch will never be taken.
// The most common pattern for this is
// `WHEN x IS NOT NULL THEN x`.
Some(TriStateBool::False)
| Some(TriStateBool::Uncertain) => None,
}
}
}
})
.collect::<Result<Vec<_>>>()?
.is_empty();

if any_nullable_thens {
// There is at least one reachable nullable then
Ok(true)
} else if let Some(e) = &case.else_expr {
e.nullable(input_schema)
Expand Down Expand Up @@ -777,7 +815,7 @@ pub fn cast_subquery(subquery: Subquery, cast_to_type: &DataType) -> Result<Subq
#[cfg(test)]
mod tests {
use super::*;
use crate::{col, lit, out_ref_col_with_metadata};
use crate::{and, col, lit, not, or, out_ref_col_with_metadata, when};

use datafusion_common::{internal_err, DFSchema, HashMap, ScalarValue};

Expand Down Expand Up @@ -830,6 +868,137 @@ mod tests {
assert!(expr.nullable(&get_schema(false)).unwrap());
}

/// Test helper: computes the nullability of `expr` against `schema` and
/// asserts that it equals `expected`.
///
/// Panics if computing the nullability returns an error, or if the computed
/// value differs from `expected`.
fn assert_nullability(expr: &Expr, schema: &dyn ExprSchema, expected: bool) {
    let actual = expr.nullable(schema).unwrap();
    assert_eq!(
        actual, expected,
        "Nullability of '{expr}' should be {expected}"
    );
}

/// Test helper: asserts that `expr` is reported as NOT nullable under `schema`.
fn assert_not_nullable(expr: &Expr, schema: &dyn ExprSchema) {
    let expected_nullable = false;
    assert_nullability(expr, schema, expected_nullable);
}

/// Test helper: asserts that `expr` is reported as nullable under `schema`.
fn assert_nullable(expr: &Expr, schema: &dyn ExprSchema) {
    let expected_nullable = true;
    assert_nullability(expr, schema, expected_nullable);
}

#[test]
fn test_case_expression_nullability() -> Result<()> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is an impressive list of tests

let nullable_schema = MockExprSchema::new()
.with_data_type(DataType::Int32)
.with_nullable(true);

let not_nullable_schema = MockExprSchema::new()
.with_data_type(DataType::Int32)
.with_nullable(false);

// CASE WHEN x IS NOT NULL THEN x ELSE 0
let e = when(col("x").is_not_null(), col("x")).otherwise(lit(0))?;
assert_not_nullable(&e, &nullable_schema);
assert_not_nullable(&e, &not_nullable_schema);

// CASE WHEN NOT x IS NULL THEN x ELSE 0
let e = when(not(col("x").is_null()), col("x")).otherwise(lit(0))?;
assert_not_nullable(&e, &nullable_schema);
assert_not_nullable(&e, &not_nullable_schema);

// CASE WHEN X = 5 THEN x ELSE 0
let e = when(col("x").eq(lit(5)), col("x")).otherwise(lit(0))?;
assert_not_nullable(&e, &nullable_schema);
assert_not_nullable(&e, &not_nullable_schema);

// CASE WHEN x IS NOT NULL AND x = 5 THEN x ELSE 0
let e = when(and(col("x").is_not_null(), col("x").eq(lit(5))), col("x"))
.otherwise(lit(0))?;
assert_not_nullable(&e, &nullable_schema);
assert_not_nullable(&e, &not_nullable_schema);

// CASE WHEN x = 5 AND x IS NOT NULL THEN x ELSE 0
let e = when(and(col("x").eq(lit(5)), col("x").is_not_null()), col("x"))
.otherwise(lit(0))?;
assert_not_nullable(&e, &nullable_schema);
assert_not_nullable(&e, &not_nullable_schema);

// CASE WHEN x IS NOT NULL OR x = 5 THEN x ELSE 0
let e = when(or(col("x").is_not_null(), col("x").eq(lit(5))), col("x"))
.otherwise(lit(0))?;
assert_not_nullable(&e, &nullable_schema);
assert_not_nullable(&e, &not_nullable_schema);

// CASE WHEN x = 5 OR x IS NOT NULL THEN x ELSE 0
let e = when(or(col("x").eq(lit(5)), col("x").is_not_null()), col("x"))
.otherwise(lit(0))?;
assert_not_nullable(&e, &nullable_schema);
assert_not_nullable(&e, &not_nullable_schema);

// CASE WHEN (x = 5 AND x IS NOT NULL) OR (x = bar AND x IS NOT NULL) THEN x ELSE 0
let e = when(
or(
and(col("x").eq(lit(5)), col("x").is_not_null()),
and(col("x").eq(col("bar")), col("x").is_not_null()),
),
col("x"),
)
.otherwise(lit(0))?;
assert_not_nullable(&e, &nullable_schema);
assert_not_nullable(&e, &not_nullable_schema);

// CASE WHEN x = 5 OR x IS NULL THEN x ELSE 0
let e = when(or(col("x").eq(lit(5)), col("x").is_null()), col("x"))
.otherwise(lit(0))?;
assert_nullable(&e, &nullable_schema);
assert_not_nullable(&e, &not_nullable_schema);

// CASE WHEN x IS TRUE THEN x ELSE 0
let e = when(col("x").is_true(), col("x")).otherwise(lit(0))?;
assert_not_nullable(&e, &nullable_schema);
assert_not_nullable(&e, &not_nullable_schema);

// CASE WHEN x IS NOT TRUE THEN x ELSE 0
let e = when(col("x").is_not_true(), col("x")).otherwise(lit(0))?;
assert_nullable(&e, &nullable_schema);
assert_not_nullable(&e, &not_nullable_schema);

// CASE WHEN x IS FALSE THEN x ELSE 0
let e = when(col("x").is_false(), col("x")).otherwise(lit(0))?;
assert_not_nullable(&e, &nullable_schema);
assert_not_nullable(&e, &not_nullable_schema);

// CASE WHEN x IS NOT FALSE THEN x ELSE 0
let e = when(col("x").is_not_false(), col("x")).otherwise(lit(0))?;
assert_nullable(&e, &nullable_schema);
assert_not_nullable(&e, &not_nullable_schema);

// CASE WHEN x IS UNKNOWN THEN x ELSE 0
let e = when(col("x").is_unknown(), col("x")).otherwise(lit(0))?;
assert_nullable(&e, &nullable_schema);
assert_not_nullable(&e, &not_nullable_schema);

// CASE WHEN x IS NOT UNKNOWN THEN x ELSE 0
let e = when(col("x").is_not_unknown(), col("x")).otherwise(lit(0))?;
assert_not_nullable(&e, &nullable_schema);
assert_not_nullable(&e, &not_nullable_schema);

// CASE WHEN x LIKE 'x' THEN x ELSE 0
let e = when(col("x").like(lit("x")), col("x")).otherwise(lit(0))?;
assert_not_nullable(&e, &nullable_schema);
assert_not_nullable(&e, &not_nullable_schema);

// CASE WHEN 0 THEN x ELSE 0
let e = when(lit(0), col("x")).otherwise(lit(0))?;
assert_not_nullable(&e, &nullable_schema);
assert_not_nullable(&e, &not_nullable_schema);

// CASE WHEN 1 THEN x ELSE 0
let e = when(lit(1), col("x")).otherwise(lit(0))?;
assert_nullable(&e, &nullable_schema);
assert_not_nullable(&e, &not_nullable_schema);

Ok(())
}

#[test]
fn test_inlist_nullability() {
let get_schema = |nullable| {
Expand Down
1 change: 1 addition & 0 deletions datafusion/expr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ pub mod async_udf;
/// Re-exports every public item of `datafusion_expr_common::statistics`
/// under this crate's `statistics` module path.
pub mod statistics {
pub use datafusion_expr_common::statistics::*;
}
mod predicate_eval;
pub mod ptr_eq;
pub mod test;
pub mod tree_node;
Expand Down
Loading