diff --git a/datafusion/core/src/execution/session_state_defaults.rs b/datafusion/core/src/execution/session_state_defaults.rs index a241738bd3a42..a1a4e19cd2afe 100644 --- a/datafusion/core/src/execution/session_state_defaults.rs +++ b/datafusion/core/src/execution/session_state_defaults.rs @@ -81,6 +81,7 @@ impl SessionStateDefaults { default_catalog } + /// NOTE! /// returns the list of default [`ExprPlanner`]s pub fn default_expr_planners() -> Vec> { let expr_planners: Vec> = vec![ diff --git a/datafusion/expr-common/src/type_coercion/binary.rs b/datafusion/expr-common/src/type_coercion/binary.rs index 9264a2940dd1b..fec3e25c6a665 100644 --- a/datafusion/expr-common/src/type_coercion/binary.rs +++ b/datafusion/expr-common/src/type_coercion/binary.rs @@ -178,124 +178,129 @@ impl<'a> BinaryTypeCoercer<'a> { use arrow::datatypes::DataType::*; use Operator::*; let result = match self.op { - Eq | - NotEq | - Lt | - LtEq | - Gt | - GtEq | - IsDistinctFrom | - IsNotDistinctFrom => { - comparison_coercion(lhs, rhs).map(Signature::comparison).ok_or_else(|| { - plan_datafusion_err!( - "Cannot infer common argument type for comparison operation {} {} {}", - self.lhs, - self.op, - self.rhs - ) - }) - } - And | Or => if matches!((lhs, rhs), (Boolean | Null, Boolean | Null)) { - // Logical binary boolean operators can only be evaluated for - // boolean or null arguments. 
- Ok(Signature::uniform(Boolean)) - } else { - plan_err!( - "Cannot infer common argument type for logical boolean operation {} {} {}", self.lhs, self.op, self.rhs - ) - } - RegexMatch | RegexIMatch | RegexNotMatch | RegexNotIMatch => { - regex_coercion(lhs, rhs).map(Signature::comparison).ok_or_else(|| { - plan_datafusion_err!( - "Cannot infer common argument type for regex operation {} {} {}", self.lhs, self.op, self.rhs - ) - }) - } - LikeMatch | ILikeMatch | NotLikeMatch | NotILikeMatch => { - regex_coercion(lhs, rhs).map(Signature::comparison).ok_or_else(|| { - plan_datafusion_err!( - "Cannot infer common argument type for regex operation {} {} {}", self.lhs, self.op, self.rhs - ) - }) - } - BitwiseAnd | BitwiseOr | BitwiseXor | BitwiseShiftRight | BitwiseShiftLeft => { - bitwise_coercion(lhs, rhs).map(Signature::uniform).ok_or_else(|| { - plan_datafusion_err!( - "Cannot infer common type for bitwise operation {} {} {}", self.lhs, self.op, self.rhs - ) - }) - } - StringConcat => { - string_concat_coercion(lhs, rhs).map(Signature::uniform).ok_or_else(|| { - plan_datafusion_err!( - "Cannot infer common string type for string concat operation {} {} {}", self.lhs, self.op, self.rhs - ) - }) - } - AtArrow | ArrowAt => { - // Array contains or search (similar to LIKE) operation - array_coercion(lhs, rhs) - .or_else(|| like_coercion(lhs, rhs)).map(Signature::comparison).ok_or_else(|| { + Eq | NotEq | Lt | LtEq | Gt | GtEq | IsDistinctFrom | IsNotDistinctFrom => { + comparison_coercion(lhs, rhs).map(Signature::comparison).ok_or_else(|| { plan_datafusion_err!( - "Cannot infer common argument type for operation {} {} {}", self.lhs, self.op, self.rhs + "Cannot infer common argument type for comparison operation {} {} {}", + self.lhs, + self.op, + self.rhs ) }) - } - AtAt => { - // text search has similar signature to LIKE - like_coercion(lhs, rhs).map(Signature::comparison).ok_or_else(|| { - plan_datafusion_err!( - "Cannot infer common argument type for AtAt operation {} 
{} {}", self.lhs, self.op, self.rhs + } + And | Or => if matches!((lhs, rhs), (Boolean | Null, Boolean | Null)) { + // Logical binary boolean operators can only be evaluated for + // boolean or null arguments. + Ok(Signature::uniform(Boolean)) + } else { + plan_err!( + "Cannot infer common argument type for logical boolean operation {} {} {}", self.lhs, self.op, self.rhs ) - }) - } - Plus | Minus | Multiply | Divide | Modulo => { - if let Ok(ret) = self.get_result(lhs, rhs) { - // Temporal arithmetic, e.g. Date32 + Interval - Ok(Signature{ - lhs: lhs.clone(), - rhs: rhs.clone(), - ret, + } + RegexMatch | RegexIMatch | RegexNotMatch | RegexNotIMatch => { + regex_coercion(lhs, rhs).map(Signature::comparison).ok_or_else(|| { + plan_datafusion_err!( + "Cannot infer common argument type for regex operation {} {} {}", self.lhs, self.op, self.rhs + ) }) - } else if let Some(coerced) = temporal_coercion_strict_timezone(lhs, rhs) { - // Temporal arithmetic by first coercing to a common time representation - // e.g. Date32 - Timestamp - let ret = self.get_result(&coerced, &coerced).map_err(|e| { + } + LikeMatch | ILikeMatch | NotLikeMatch | NotILikeMatch => { + regex_coercion(lhs, rhs).map(Signature::comparison).ok_or_else(|| { plan_datafusion_err!( - "Cannot get result type for temporal operation {coerced} {} {coerced}: {e}", self.op + "Cannot infer common argument type for regex operation {} {} {}", self.lhs, self.op, self.rhs ) - })?; - Ok(Signature{ - lhs: coerced.clone(), - rhs: coerced, - ret, }) - } else if let Some((lhs, rhs)) = math_decimal_coercion(lhs, rhs) { - // Decimal arithmetic, e.g. 
Decimal(10, 2) + Decimal(10, 0) - let ret = self.get_result(&lhs, &rhs).map_err(|e| { + } + BitwiseAnd | BitwiseOr | BitwiseXor | BitwiseShiftRight | BitwiseShiftLeft => { + bitwise_coercion(lhs, rhs).map(Signature::uniform).ok_or_else(|| { plan_datafusion_err!( - "Cannot get result type for decimal operation {} {} {}: {e}", self.lhs, self.op, self.rhs + "Cannot infer common type for bitwise operation {} {} {}", self.lhs, self.op, self.rhs ) - })?; - Ok(Signature{ - lhs, - rhs, - ret, }) - } else if let Some(numeric) = mathematics_numerical_coercion(lhs, rhs) { - // Numeric arithmetic, e.g. Int32 + Int32 - Ok(Signature::uniform(numeric)) - } else { - plan_err!( - "Cannot coerce arithmetic expression {} {} {} to valid types", self.lhs, self.op, self.rhs - ) } - }, - IntegerDivide | Arrow | LongArrow | HashArrow | HashLongArrow - | HashMinus | AtQuestion | Question | QuestionAnd | QuestionPipe => { - not_impl_err!("Operator {} is not yet supported", self.op) - } - }; + StringConcat => { + string_concat_coercion(lhs, rhs).map(Signature::uniform).ok_or_else(|| { + plan_datafusion_err!( + "Cannot infer common string type for string concat operation {} {} {}", self.lhs, self.op, self.rhs + ) + }) + } + AtArrow | ArrowAt => { + // Array contains or search (similar to LIKE) operation + array_coercion(lhs, rhs) + .or_else(|| like_coercion(lhs, rhs)).map(Signature::comparison).ok_or_else(|| { + plan_datafusion_err!( + "Cannot infer common argument type for operation {} {} {}", self.lhs, self.op, self.rhs + ) + }) + } + AtAt => { + // text search has similar signature to LIKE + like_coercion(lhs, rhs).map(Signature::comparison).ok_or_else(|| { + plan_datafusion_err!( + "Cannot infer common argument type for AtAt operation {} {} {}", self.lhs, self.op, self.rhs + ) + }) + } + Plus | Minus | Multiply | Divide | Modulo => { + if let Ok(ret) = self.get_result(lhs, rhs) { + // Temporal arithmetic, e.g. 
Date32 + Interval + Ok(Signature{ + lhs: lhs.clone(), + rhs: rhs.clone(), + ret, + }) + } else if let Some(coerced) = temporal_coercion_strict_timezone(lhs, rhs) { + // Temporal arithmetic by first coercing to a common time representation + // e.g. Date32 - Timestamp + let ret = self.get_result(&coerced, &coerced).map_err(|e| { + plan_datafusion_err!( + "Cannot get result type for temporal operation {coerced} {} {coerced}: {e}", self.op + ) + })?; + Ok(Signature{ + lhs: coerced.clone(), + rhs: coerced, + ret, + }) + } else if let Some((lhs, rhs)) = temporal_coercion_resolve_ints_to_intervals(lhs, rhs) { + // Date arithmetic with a bare integer, where the integer side is coerced to a day-time interval, e.g. Date32 + Int32 + let ret = self.get_result(&lhs, &rhs).map_err(|e| { + plan_datafusion_err!( + "Cannot get result type for temporal operation {} {} {}: {e}", self.lhs, self.op, self.rhs + ) + })?; + Ok(Signature{ + lhs: lhs.clone(), + rhs: rhs.clone(), + ret, + }) + } else if let Some((lhs, rhs)) = math_decimal_coercion(lhs, rhs) { + // Decimal arithmetic, e.g. Decimal(10, 2) + Decimal(10, 0) + let ret = self.get_result(&lhs, &rhs).map_err(|e| { + plan_datafusion_err!( + "Cannot get result type for decimal operation {} {} {}: {e}", self.lhs, self.op, self.rhs + ) + })?; + Ok(Signature{ + lhs, + rhs, + ret, + }) + } else if let Some(numeric) = mathematics_numerical_coercion(lhs, rhs) { + // Numeric arithmetic, e.g.
Int32 + Int32 + Ok(Signature::uniform(numeric)) + } else { + plan_err!( + "Cannot coerce arithmetic expression {} {} {} to valid types", self.lhs, self.op, self.rhs + ) + } + }, + IntegerDivide | Arrow | LongArrow | HashArrow | HashLongArrow | HashMinus | AtQuestion | Question | QuestionAnd | QuestionPipe => { + not_impl_err!("Operator {} is not yet supported", self.op) + } + }; + result.map_err(|err| { let diagnostic = Diagnostic::new_error("expressions have incompatible types", self.span()) @@ -1433,6 +1438,18 @@ fn temporal_coercion_nonstrict_timezone( } } +fn temporal_coercion_resolve_ints_to_intervals( + lhs: &DataType, + rhs: &DataType, +) -> Option<(DataType, DataType)> { + use arrow::datatypes::DataType::{Date32, Int32, Int64, Interval}; + use arrow::datatypes::IntervalUnit::DayTime; + match (lhs, rhs) { + (Date32, Int32 | Int64) => Some((lhs.clone(), Interval(DayTime))), + _ => None, + } +} + /// Strict Timezone coercion is useful in scenarios where we cannot guarantee a stable relationship /// between two timestamps with different timezones or do not want implicit coercion between them. 
/// diff --git a/datafusion/functions-nested/src/planner.rs b/datafusion/functions-nested/src/planner.rs index 369eaecb1905f..cd99d54efdb95 100644 --- a/datafusion/functions-nested/src/planner.rs +++ b/datafusion/functions-nested/src/planner.rs @@ -30,6 +30,7 @@ use datafusion_expr::{ use datafusion_functions::core::get_field as get_field_inner; use datafusion_functions::expr_fn::get_field; use datafusion_functions_aggregate::nth_value::nth_value_udaf; +use sqlparser::ast::BinaryOperator; use std::sync::Arc; use crate::map::map_udf; @@ -51,7 +52,7 @@ impl ExprPlanner for NestedFunctionPlanner { ) -> Result> { let RawBinaryExpr { op, left, right } = expr; - if op == sqlparser::ast::BinaryOperator::StringConcat { + if op == BinaryOperator::StringConcat { let left_type = left.get_type(schema)?; let right_type = right.get_type(schema)?; let left_list_ndims = list_ndims(&left_type); @@ -75,18 +76,14 @@ impl ExprPlanner for NestedFunctionPlanner { } else if left_list_ndims < right_list_ndims { return Ok(PlannerResult::Planned(array_prepend(left, right))); } - } else if matches!( - op, - sqlparser::ast::BinaryOperator::AtArrow - | sqlparser::ast::BinaryOperator::ArrowAt - ) { + } else if matches!(op, BinaryOperator::AtArrow | BinaryOperator::ArrowAt) { let left_type = left.get_type(schema)?; let right_type = right.get_type(schema)?; let left_list_ndims = list_ndims(&left_type); let right_list_ndims = list_ndims(&right_type); // if both are list if left_list_ndims > 0 && right_list_ndims > 0 { - if op == sqlparser::ast::BinaryOperator::AtArrow { + if op == BinaryOperator::AtArrow { // array1 @> array2 -> array_has_all(array1, array2) return Ok(PlannerResult::Planned(array_has_all(left, right))); } else { @@ -94,6 +91,46 @@ impl ExprPlanner for NestedFunctionPlanner { return Ok(PlannerResult::Planned(array_has_all(right, left))); } } + // } else if matches!(op, BinaryOperator::Plus | BinaryOperator::Minus) + // && matches!(left.get_type(schema)?, DataType::Date32) + // && 
matches!(right.get_type(schema)?, DataType::Int32 | DataType::Int64) + // { + // use arrow::datatypes::IntervalDayTime; + // use datafusion_common::ScalarValue; + // use datafusion_expr::BinaryExpr; + // use datafusion_expr::Operator; + // use sqlparser::ast::BinaryOperator; + // + // let op: Operator = match op { + // BinaryOperator::Plus => Operator::Plus, + // BinaryOperator::Minus => Operator::Minus, + // _ => unreachable!(), + // }; + // + // let new_right: Expr = match right { + // Expr::Literal(ScalarValue::Int32(Some(i)), meta) => Expr::Literal( + // ScalarValue::IntervalDayTime(Some(IntervalDayTime { + // days: i, + // milliseconds: 0, + // })), + // meta, + // ), + // Expr::Literal(ScalarValue::Int64(Some(i)), meta) => Expr::Literal( + // ScalarValue::IntervalDayTime(Some(IntervalDayTime { + // days: i as i32, + // milliseconds: 0, + // })), + // meta, + // ), + // _ => unreachable!(), + // }; + // + // let planned = Expr::BinaryExpr(BinaryExpr { + // left: Box::new(left.clone()), + // right: Box::new(new_right), + // op, + // }); + // return Ok(PlannerResult::Planned(planned)); } Ok(PlannerResult::Original(RawBinaryExpr { op, left, right })) @@ -123,7 +160,7 @@ impl ExprPlanner for NestedFunctionPlanner { } fn plan_any(&self, expr: RawBinaryExpr) -> Result> { - if expr.op == sqlparser::ast::BinaryOperator::Eq { + if expr.op == BinaryOperator::Eq { Ok(PlannerResult::Planned(Expr::ScalarFunction( ScalarFunction::new_udf( array_has_udf(), diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 9a3a8bcd23a7f..9385d19814972 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -45,6 +45,7 @@ use datafusion_expr::{simplify::ExprSimplifyResult, Cast, TryCast}; use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionProps}; use 
super::inlist_simplifier::ShortenInListSimplifier; +use super::make_interval_simplifier::MakeIntervalSimplifier; use super::utils::*; use crate::analyzer::type_coercion::TypeCoercionRewriter; use crate::simplify_expressions::guarantees::GuaranteeRewriter; @@ -55,6 +56,7 @@ use crate::simplify_expressions::unwrap_cast::{ unwrap_cast_in_comparison_for_binary, }; use crate::simplify_expressions::SimplifyInfo; +use arrow::datatypes::IntervalDayTime; use datafusion_expr::expr::FieldMetadata; use datafusion_expr_common::casts::try_cast_literal_to_type; use indexmap::IndexSet; @@ -256,6 +258,10 @@ impl ExprSimplifier { } // shorten inlist should be started after other inlist rules are applied expr = expr.rewrite(&mut shorten_in_list_simplifier).data()?; + + let mut make_int_interval = MakeIntervalSimplifier::new(); + expr = expr.rewrite(&mut make_int_interval).data()?; + Ok(( Transformed::new_transformed(expr, has_transformed), num_cycles, @@ -767,8 +773,8 @@ impl TreeNodeRewriter for Simplifier<'_, S> { fn f_up(&mut self, expr: Expr) -> Result> { use datafusion_expr::Operator::{ And, BitwiseAnd, BitwiseOr, BitwiseShiftLeft, BitwiseShiftRight, BitwiseXor, - Divide, Eq, Modulo, Multiply, NotEq, Or, RegexIMatch, RegexMatch, - RegexNotIMatch, RegexNotMatch, + Divide, Eq, Minus, Modulo, Multiply, NotEq, Or, Plus, RegexIMatch, + RegexMatch, RegexNotIMatch, RegexNotMatch, }; let info = self.info; diff --git a/datafusion/optimizer/src/simplify_expressions/guarantees.rs b/datafusion/optimizer/src/simplify_expressions/guarantees.rs index bbb023cfbad9f..5ccc05dd251bb 100644 --- a/datafusion/optimizer/src/simplify_expressions/guarantees.rs +++ b/datafusion/optimizer/src/simplify_expressions/guarantees.rs @@ -26,6 +26,11 @@ use datafusion_common::{DataFusionError, Result}; use datafusion_expr::interval_arithmetic::{Interval, NullableInterval}; use datafusion_expr::{expr::InList, lit, Between, BinaryExpr, Expr}; +use datafusion_common::ScalarValue; +use 
datafusion_expr_common::operator::Operator; + +use arrow::datatypes::IntervalDayTime; + /// Rewrite expressions to incorporate guarantees. /// /// Guarantees are a mapping from an expression (which currently is always a @@ -108,6 +113,38 @@ impl TreeNodeRewriter for GuaranteeRewriter<'_> { } Expr::BinaryExpr(BinaryExpr { left, op, right }) => { + if *op == Operator::Plus || *op == Operator::Minus { + if let Expr::Literal(ScalarValue::Date32(_), _) = &**left { + let interval_right = match &**right { + Expr::Literal(ScalarValue::Int32(Some(i)), meta) => { + Expr::Literal( + ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days: *i, + milliseconds: 0, + })), + meta.clone(), + ) + } + Expr::Literal(ScalarValue::Int64(Some(i)), meta) => { + Expr::Literal( + ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days: *i as i32, + milliseconds: 0, + })), + meta.clone(), + ) + } + _ => *right.clone(), + }; + + return Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr { + left: left.clone(), + right: Box::new(interval_right), + op: *op, + }))); + } + } + // The left or right side of expression might either have a guarantee // or be a literal. Either way, we can resolve them to a NullableInterval. let left_interval = self diff --git a/datafusion/optimizer/src/simplify_expressions/make_interval_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/make_interval_simplifier.rs new file mode 100644 index 0000000000000..d42402992d72f --- /dev/null +++ b/datafusion/optimizer/src/simplify_expressions/make_interval_simplifier.rs @@ -0,0 +1,77 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This module implements a rule that rewrites `date literal +/- integer literal` so that the integer literal becomes a day-based `IntervalDayTime` literal + +use arrow::datatypes::IntervalDayTime; +use datafusion_common::tree_node::{Transformed, TreeNodeRewriter}; +use datafusion_common::Result; +use datafusion_common::ScalarValue; +use datafusion_expr::BinaryExpr; +use datafusion_expr::Expr; +use datafusion_expr::Operator::{Minus, Plus}; + +pub(super) struct MakeIntervalSimplifier {} + +impl MakeIntervalSimplifier { + pub(super) fn new() -> Self { + Self {} + } +} + +impl TreeNodeRewriter for MakeIntervalSimplifier { + type Node = Expr; + + fn f_up(&mut self, expr: Expr) -> Result> { + if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = &expr { + if matches!( + &**left, + Expr::Literal(ScalarValue::Date32(_), _) + | Expr::Literal(ScalarValue::Date64(_), _) + ) && matches!(op, Plus | Minus) + && (matches!(&**right, Expr::Literal(ScalarValue::Int32(_), _)) + || matches!(&**right, Expr::Literal(ScalarValue::Int64(_), _))) + { + let new_right: Expr = match &**right { + Expr::Literal(ScalarValue::Int32(Some(i)), meta) => Expr::Literal( + ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days: *i, + milliseconds: 0, + })), + meta.clone(), + ), + Expr::Literal(ScalarValue::Int64(Some(i)), meta) => Expr::Literal( + ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days: *i as i32, + milliseconds: 0, + })), + meta.clone(), + ), + _ => unreachable!(), + }; + + return Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr { + left: Box::new(*left.clone()), + right: Box::new(new_right), + op: *op, + }))); + } + } + +
Ok(Transformed::no(expr)) + } +} diff --git a/datafusion/optimizer/src/simplify_expressions/mod.rs b/datafusion/optimizer/src/simplify_expressions/mod.rs index 7ae38eec9a3ad..b23a75b72c0c8 100644 --- a/datafusion/optimizer/src/simplify_expressions/mod.rs +++ b/datafusion/optimizer/src/simplify_expressions/mod.rs @@ -21,6 +21,7 @@ pub mod expr_simplifier; mod guarantees; mod inlist_simplifier; +mod make_interval_simplifier; mod regex; pub mod simplify_exprs; mod simplify_predicates; diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index e92869873731f..42ba4fb5d2bc9 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -102,10 +102,10 @@ impl SqlToRel<'_, S> { } } } - StackEntry::Operator(op) => { + StackEntry::Operator(binaryOp) => { let right = eval_stack.pop().unwrap(); let left = eval_stack.pop().unwrap(); - let expr = self.build_logical_expr(op, left, right, schema)?; + let expr = self.build_logical_expr(binaryOp, left, right, schema)?; eval_stack.push(expr); } } @@ -137,6 +137,7 @@ impl SqlToRel<'_, S> { } let RawBinaryExpr { op, left, right } = binary_expr; + Ok(Expr::BinaryExpr(BinaryExpr::new( Box::new(left), self.parse_sql_binary_op(op)?, diff --git a/parquet-testing b/parquet-testing index 107b36603e051..f4d7ed772a62a 160000 --- a/parquet-testing +++ b/parquet-testing @@ -1 +1 @@ -Subproject commit 107b36603e051aee26bd93e04b871034f6c756c0 +Subproject commit f4d7ed772a62a95111db50fbcad2460833e8c882