Skip to content

Rewrite array operator to function in parser #11101

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion datafusion/core/tests/sql/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ async fn test_parameter_invalid_types() -> Result<()> {
.await;
assert_eq!(
results.unwrap_err().strip_backtrace(),
"Arrow error: Invalid argument error: Invalid comparison operation: List(Field { name: \"item\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) == List(Field { name: \"item\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })"
"type_coercion\ncaused by\nError during planning: Cannot infer common argument type for comparison operation List(Field { name: \"item\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) = Int32"
);
Ok(())
}
24 changes: 9 additions & 15 deletions datafusion/expr/src/type_coercion/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -889,21 +889,18 @@ fn dictionary_coercion(
/// 2. Data type of the other side should be able to cast to string type
fn string_concat_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
use arrow::datatypes::DataType::*;
string_coercion(lhs_type, rhs_type)
.or_else(|| list_coercion(lhs_type, rhs_type))
.or(match (lhs_type, rhs_type) {
(Utf8, from_type) | (from_type, Utf8) => {
string_concat_internal_coercion(from_type, &Utf8)
}
(LargeUtf8, from_type) | (from_type, LargeUtf8) => {
string_concat_internal_coercion(from_type, &LargeUtf8)
}
_ => None,
})
string_coercion(lhs_type, rhs_type).or(match (lhs_type, rhs_type) {
(Utf8, from_type) | (from_type, Utf8) => {
string_concat_internal_coercion(from_type, &Utf8)
}
(LargeUtf8, from_type) | (from_type, LargeUtf8) => {
string_concat_internal_coercion(from_type, &LargeUtf8)
}
_ => None,
})
}

fn array_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
// TODO: cast between array elements (#6558)
if lhs_type.equals_datatype(rhs_type) {
Some(lhs_type.to_owned())
} else {
Expand Down Expand Up @@ -952,10 +949,7 @@ fn numeric_string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<D
fn list_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
use arrow::datatypes::DataType::*;
match (lhs_type, rhs_type) {
// TODO: cast between array elements (#6558)
(List(_), List(_)) => Some(lhs_type.clone()),
(List(_), _) => Some(lhs_type.clone()),
(_, List(_)) => Some(rhs_type.clone()),
_ => None,
}
}
Expand Down
109 changes: 2 additions & 107 deletions datafusion/functions-array/src/rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,10 @@
//! Rewrites for using Array Functions

use crate::array_has::array_has_all;
use crate::concat::{array_append, array_concat, array_prepend};
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::Transformed;
use datafusion_common::utils::list_ndims;
use datafusion_common::DFSchema;
use datafusion_common::Result;
use datafusion_common::{Column, DFSchema};
use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::expr_rewriter::FunctionRewrite;
use datafusion_expr::{BinaryExpr, Expr, Operator};
Expand All @@ -39,7 +37,7 @@ impl FunctionRewrite for ArrayFunctionRewriter {
fn rewrite(
&self,
expr: Expr,
schema: &DFSchema,
_schema: &DFSchema,
_config: &ConfigOptions,
) -> Result<Transformed<Expr>> {
let transformed = match expr {
Expand All @@ -61,91 +59,6 @@ impl FunctionRewrite for ArrayFunctionRewriter {
Transformed::yes(array_has_all(*right, *left))
}

// Column cases:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we do the same thing for the other operators (AtArrow)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes!

// 1) array_prepend/append/concat || column
Expr::BinaryExpr(BinaryExpr { left, op, right })
if op == Operator::StringConcat
&& is_one_of_func(
&left,
&["array_append", "array_prepend", "array_concat"],
)
&& as_col(&right).is_some() =>
{
let c = as_col(&right).unwrap();
let d = schema.field_from_column(c)?.data_type();
let ndim = list_ndims(d);
match ndim {
0 => Transformed::yes(array_append(*left, *right)),
_ => Transformed::yes(array_concat(vec![*left, *right])),
}
}
// 2) select column1 || column2
Expr::BinaryExpr(BinaryExpr { left, op, right })
if op == Operator::StringConcat
&& as_col(&left).is_some()
&& as_col(&right).is_some() =>
{
let c1 = as_col(&left).unwrap();
let c2 = as_col(&right).unwrap();
let d1 = schema.field_from_column(c1)?.data_type();
let d2 = schema.field_from_column(c2)?.data_type();
let ndim1 = list_ndims(d1);
let ndim2 = list_ndims(d2);
match (ndim1, ndim2) {
(0, _) => Transformed::yes(array_prepend(*left, *right)),
(_, 0) => Transformed::yes(array_append(*left, *right)),
_ => Transformed::yes(array_concat(vec![*left, *right])),
}
}

// Chain concat operator (a || b) || array,
// (array_concat, array_append, array_prepend) || array -> array concat
Expr::BinaryExpr(BinaryExpr { left, op, right })
if op == Operator::StringConcat
&& is_one_of_func(
&left,
&["array_append", "array_prepend", "array_concat"],
)
&& is_func(&right, "make_array") =>
{
Transformed::yes(array_concat(vec![*left, *right]))
}

// Chain concat operator (a || b) || scalar,
// (array_concat, array_append, array_prepend) || scalar -> array append
Expr::BinaryExpr(BinaryExpr { left, op, right })
if op == Operator::StringConcat
&& is_one_of_func(
&left,
&["array_append", "array_prepend", "array_concat"],
) =>
{
Transformed::yes(array_append(*left, *right))
}

// array || array -> array concat
Expr::BinaryExpr(BinaryExpr { left, op, right })
if op == Operator::StringConcat
&& is_func(&left, "make_array")
&& is_func(&right, "make_array") =>
{
Transformed::yes(array_concat(vec![*left, *right]))
}

// array || scalar -> array append
Expr::BinaryExpr(BinaryExpr { left, op, right })
if op == Operator::StringConcat && is_func(&left, "make_array") =>
{
Transformed::yes(array_append(*left, *right))
}

// scalar || array -> array prepend
Expr::BinaryExpr(BinaryExpr { left, op, right })
if op == Operator::StringConcat && is_func(&right, "make_array") =>
{
Transformed::yes(array_prepend(*left, *right))
}

_ => Transformed::no(expr),
};
Ok(transformed)
Expand All @@ -161,21 +74,3 @@ fn is_func(expr: &Expr, func_name: &str) -> bool {

func.name() == func_name
}

/// Returns true if expr is a function call with one of the specified names
fn is_one_of_func(expr: &Expr, func_names: &[&str]) -> bool {
let Expr::ScalarFunction(ScalarFunction { func, args: _ }) = expr else {
return false;
};

func_names.contains(&func.name())
}

/// returns Some(col) if this is Expr::Column
fn as_col(expr: &Expr) -> Option<&Column> {
if let Expr::Column(c) = expr {
Some(c)
} else {
None
}
}
72 changes: 65 additions & 7 deletions datafusion/sql/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

use arrow_schema::DataType;
use arrow_schema::TimeUnit;
use datafusion_common::utils::list_ndims;
use sqlparser::ast::{CastKind, Expr as SQLExpr, Subscript, TrimWhereField, Value};

use datafusion_common::{
Expand Down Expand Up @@ -86,13 +87,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
StackEntry::Operator(op) => {
let right = eval_stack.pop().unwrap();
let left = eval_stack.pop().unwrap();

let expr = Expr::BinaryExpr(BinaryExpr::new(
Box::new(left),
op,
Box::new(right),
));

let expr = self.build_logical_expr(op, left, right, schema)?;
eval_stack.push(expr);
}
}
Expand All @@ -103,6 +98,69 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
Ok(expr)
}

fn build_logical_expr(
&self,
op: Operator,
left: Expr,
right: Expr,
schema: &DFSchema,
) -> Result<Expr> {
// Rewrite string concat operator to function based on types
// if we get list || list then we rewrite it to array_concat()
// if we get list || non-list then we rewrite it to array_append()
// if we get non-list || list then we rewrite it to array_prepend()
// if we get string || string then we rewrite it to concat()
if op == Operator::StringConcat {
let left_type = left.get_type(schema)?;
let right_type = right.get_type(schema)?;
let left_list_ndims = list_ndims(&left_type);
let right_list_ndims = list_ndims(&right_type);

// We determine the target function to rewrite based on the list n-dimension, the check is not exact but sufficient.
// The exact validity check is handled in the actual function, so even if there is 3d list appended with 1d list, it is also fine to rewrite.
if left_list_ndims + right_list_ndims == 0 {
// TODO: concat function ignore null, but string concat takes null into consideration
// we can rewrite it to concat if we can configure the behaviour of concat function to the one like `string concat operator`
} else if left_list_ndims == right_list_ndims {
if let Some(udf) = self.context_provider.get_function_meta("array_concat")
{
return Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
udf,
vec![left, right],
)));
} else {
return internal_err!("array_concat not found");
}
} else if left_list_ndims > right_list_ndims {
if let Some(udf) = self.context_provider.get_function_meta("array_append")
{
return Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
udf,
vec![left, right],
)));
} else {
return internal_err!("array_append not found");
}
} else if left_list_ndims < right_list_ndims {
if let Some(udf) =
self.context_provider.get_function_meta("array_prepend")
{
return Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
udf,
vec![left, right],
)));
} else {
return internal_err!("array_append not found");
}
}
}
Ok(Expr::BinaryExpr(BinaryExpr::new(
Box::new(left),
op,
Box::new(right),
)))
}

/// Generate a relational expression from a SQL expression
pub fn sql_to_expr(
&self,
Expand Down