Skip to content

Commit 31daf25

Browse files
authored
Rewrite array operator to function in parser (#11101)
* rewrite func Signed-off-by: jayzhan211 <[email protected]> * remove rule in analyzer Signed-off-by: jayzhan211 <[email protected]> * fix test Signed-off-by: jayzhan211 <[email protected]> --------- Signed-off-by: jayzhan211 <[email protected]>
1 parent e266018 commit 31daf25

File tree

4 files changed

+77
-130
lines changed

4 files changed

+77
-130
lines changed

datafusion/core/tests/sql/select.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ async fn test_parameter_invalid_types() -> Result<()> {
246246
.await;
247247
assert_eq!(
248248
results.unwrap_err().strip_backtrace(),
249-
"Arrow error: Invalid argument error: Invalid comparison operation: List(Field { name: \"item\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) == List(Field { name: \"item\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })"
249+
"type_coercion\ncaused by\nError during planning: Cannot infer common argument type for comparison operation List(Field { name: \"item\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) = Int32"
250250
);
251251
Ok(())
252252
}

datafusion/expr/src/type_coercion/binary.rs

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -889,21 +889,18 @@ fn dictionary_coercion(
889889
/// 2. Data type of the other side should be able to cast to string type
890890
fn string_concat_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
891891
use arrow::datatypes::DataType::*;
892-
string_coercion(lhs_type, rhs_type)
893-
.or_else(|| list_coercion(lhs_type, rhs_type))
894-
.or(match (lhs_type, rhs_type) {
895-
(Utf8, from_type) | (from_type, Utf8) => {
896-
string_concat_internal_coercion(from_type, &Utf8)
897-
}
898-
(LargeUtf8, from_type) | (from_type, LargeUtf8) => {
899-
string_concat_internal_coercion(from_type, &LargeUtf8)
900-
}
901-
_ => None,
902-
})
892+
string_coercion(lhs_type, rhs_type).or(match (lhs_type, rhs_type) {
893+
(Utf8, from_type) | (from_type, Utf8) => {
894+
string_concat_internal_coercion(from_type, &Utf8)
895+
}
896+
(LargeUtf8, from_type) | (from_type, LargeUtf8) => {
897+
string_concat_internal_coercion(from_type, &LargeUtf8)
898+
}
899+
_ => None,
900+
})
903901
}
904902

905903
fn array_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
906-
// TODO: cast between array elements (#6558)
907904
if lhs_type.equals_datatype(rhs_type) {
908905
Some(lhs_type.to_owned())
909906
} else {
@@ -952,10 +949,7 @@ fn numeric_string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<D
952949
fn list_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
953950
use arrow::datatypes::DataType::*;
954951
match (lhs_type, rhs_type) {
955-
// TODO: cast between array elements (#6558)
956952
(List(_), List(_)) => Some(lhs_type.clone()),
957-
(List(_), _) => Some(lhs_type.clone()),
958-
(_, List(_)) => Some(rhs_type.clone()),
959953
_ => None,
960954
}
961955
}

datafusion/functions-array/src/rewrite.rs

Lines changed: 2 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,10 @@
1818
//! Rewrites for using Array Functions
1919
2020
use crate::array_has::array_has_all;
21-
use crate::concat::{array_append, array_concat, array_prepend};
2221
use datafusion_common::config::ConfigOptions;
2322
use datafusion_common::tree_node::Transformed;
24-
use datafusion_common::utils::list_ndims;
23+
use datafusion_common::DFSchema;
2524
use datafusion_common::Result;
26-
use datafusion_common::{Column, DFSchema};
2725
use datafusion_expr::expr::ScalarFunction;
2826
use datafusion_expr::expr_rewriter::FunctionRewrite;
2927
use datafusion_expr::{BinaryExpr, Expr, Operator};
@@ -39,7 +37,7 @@ impl FunctionRewrite for ArrayFunctionRewriter {
3937
fn rewrite(
4038
&self,
4139
expr: Expr,
42-
schema: &DFSchema,
40+
_schema: &DFSchema,
4341
_config: &ConfigOptions,
4442
) -> Result<Transformed<Expr>> {
4543
let transformed = match expr {
@@ -61,91 +59,6 @@ impl FunctionRewrite for ArrayFunctionRewriter {
6159
Transformed::yes(array_has_all(*right, *left))
6260
}
6361

64-
// Column cases:
65-
// 1) array_prepend/append/concat || column
66-
Expr::BinaryExpr(BinaryExpr { left, op, right })
67-
if op == Operator::StringConcat
68-
&& is_one_of_func(
69-
&left,
70-
&["array_append", "array_prepend", "array_concat"],
71-
)
72-
&& as_col(&right).is_some() =>
73-
{
74-
let c = as_col(&right).unwrap();
75-
let d = schema.field_from_column(c)?.data_type();
76-
let ndim = list_ndims(d);
77-
match ndim {
78-
0 => Transformed::yes(array_append(*left, *right)),
79-
_ => Transformed::yes(array_concat(vec![*left, *right])),
80-
}
81-
}
82-
// 2) select column1 || column2
83-
Expr::BinaryExpr(BinaryExpr { left, op, right })
84-
if op == Operator::StringConcat
85-
&& as_col(&left).is_some()
86-
&& as_col(&right).is_some() =>
87-
{
88-
let c1 = as_col(&left).unwrap();
89-
let c2 = as_col(&right).unwrap();
90-
let d1 = schema.field_from_column(c1)?.data_type();
91-
let d2 = schema.field_from_column(c2)?.data_type();
92-
let ndim1 = list_ndims(d1);
93-
let ndim2 = list_ndims(d2);
94-
match (ndim1, ndim2) {
95-
(0, _) => Transformed::yes(array_prepend(*left, *right)),
96-
(_, 0) => Transformed::yes(array_append(*left, *right)),
97-
_ => Transformed::yes(array_concat(vec![*left, *right])),
98-
}
99-
}
100-
101-
// Chain concat operator (a || b) || array,
102-
// (array_concat, array_append, array_prepend) || array -> array concat
103-
Expr::BinaryExpr(BinaryExpr { left, op, right })
104-
if op == Operator::StringConcat
105-
&& is_one_of_func(
106-
&left,
107-
&["array_append", "array_prepend", "array_concat"],
108-
)
109-
&& is_func(&right, "make_array") =>
110-
{
111-
Transformed::yes(array_concat(vec![*left, *right]))
112-
}
113-
114-
// Chain concat operator (a || b) || scalar,
115-
// (array_concat, array_append, array_prepend) || scalar -> array append
116-
Expr::BinaryExpr(BinaryExpr { left, op, right })
117-
if op == Operator::StringConcat
118-
&& is_one_of_func(
119-
&left,
120-
&["array_append", "array_prepend", "array_concat"],
121-
) =>
122-
{
123-
Transformed::yes(array_append(*left, *right))
124-
}
125-
126-
// array || array -> array concat
127-
Expr::BinaryExpr(BinaryExpr { left, op, right })
128-
if op == Operator::StringConcat
129-
&& is_func(&left, "make_array")
130-
&& is_func(&right, "make_array") =>
131-
{
132-
Transformed::yes(array_concat(vec![*left, *right]))
133-
}
134-
135-
// array || scalar -> array append
136-
Expr::BinaryExpr(BinaryExpr { left, op, right })
137-
if op == Operator::StringConcat && is_func(&left, "make_array") =>
138-
{
139-
Transformed::yes(array_append(*left, *right))
140-
}
141-
142-
// scalar || array -> array prepend
143-
Expr::BinaryExpr(BinaryExpr { left, op, right })
144-
if op == Operator::StringConcat && is_func(&right, "make_array") =>
145-
{
146-
Transformed::yes(array_prepend(*left, *right))
147-
}
148-
14962
_ => Transformed::no(expr),
15063
};
15164
Ok(transformed)
@@ -161,21 +74,3 @@ fn is_func(expr: &Expr, func_name: &str) -> bool {
16174

16275
func.name() == func_name
16376
}
164-
165-
/// Returns true if expr is a function call with one of the specified names
166-
fn is_one_of_func(expr: &Expr, func_names: &[&str]) -> bool {
167-
let Expr::ScalarFunction(ScalarFunction { func, args: _ }) = expr else {
168-
return false;
169-
};
170-
171-
func_names.contains(&func.name())
172-
}
173-
174-
/// returns Some(col) if this is Expr::Column
175-
fn as_col(expr: &Expr) -> Option<&Column> {
176-
if let Expr::Column(c) = expr {
177-
Some(c)
178-
} else {
179-
None
180-
}
181-
}

datafusion/sql/src/expr/mod.rs

Lines changed: 65 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
use arrow_schema::DataType;
1919
use arrow_schema::TimeUnit;
20+
use datafusion_common::utils::list_ndims;
2021
use sqlparser::ast::{CastKind, Expr as SQLExpr, Subscript, TrimWhereField, Value};
2122

2223
use datafusion_common::{
@@ -86,13 +87,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
8687
StackEntry::Operator(op) => {
8788
let right = eval_stack.pop().unwrap();
8889
let left = eval_stack.pop().unwrap();
89-
90-
let expr = Expr::BinaryExpr(BinaryExpr::new(
91-
Box::new(left),
92-
op,
93-
Box::new(right),
94-
));
95-
90+
let expr = self.build_logical_expr(op, left, right, schema)?;
9691
eval_stack.push(expr);
9792
}
9893
}
@@ -103,6 +98,69 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
10398
Ok(expr)
10499
}
105100

101+
fn build_logical_expr(
102+
&self,
103+
op: Operator,
104+
left: Expr,
105+
right: Expr,
106+
schema: &DFSchema,
107+
) -> Result<Expr> {
108+
// Rewrite string concat operator to function based on types
109+
// if we get list || list then we rewrite it to array_concat()
110+
// if we get list || non-list then we rewrite it to array_append()
111+
// if we get non-list || list then we rewrite it to array_prepend()
112+
// if we get string || string then we rewrite it to concat()
113+
if op == Operator::StringConcat {
114+
let left_type = left.get_type(schema)?;
115+
let right_type = right.get_type(schema)?;
116+
let left_list_ndims = list_ndims(&left_type);
117+
let right_list_ndims = list_ndims(&right_type);
118+
119+
// We determine the target function to rewrite based on the list n-dimension, the check is not exact but sufficient.
120+
// The exact validity check is handled in the actual function, so even if there is 3d list appended with 1d list, it is also fine to rewrite.
121+
if left_list_ndims + right_list_ndims == 0 {
122+
// TODO: concat function ignore null, but string concat takes null into consideration
123+
// we can rewrite it to concat if we can configure the behaviour of concat function to the one like `string concat operator`
124+
} else if left_list_ndims == right_list_ndims {
125+
if let Some(udf) = self.context_provider.get_function_meta("array_concat")
126+
{
127+
return Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
128+
udf,
129+
vec![left, right],
130+
)));
131+
} else {
132+
return internal_err!("array_concat not found");
133+
}
134+
} else if left_list_ndims > right_list_ndims {
135+
if let Some(udf) = self.context_provider.get_function_meta("array_append")
136+
{
137+
return Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
138+
udf,
139+
vec![left, right],
140+
)));
141+
} else {
142+
return internal_err!("array_append not found");
143+
}
144+
} else if left_list_ndims < right_list_ndims {
145+
if let Some(udf) =
146+
self.context_provider.get_function_meta("array_prepend")
147+
{
148+
return Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
149+
udf,
150+
vec![left, right],
151+
)));
152+
} else {
153+
return internal_err!("array_append not found");
154+
}
155+
}
156+
}
157+
Ok(Expr::BinaryExpr(BinaryExpr::new(
158+
Box::new(left),
159+
op,
160+
Box::new(right),
161+
)))
162+
}
163+
106164
/// Generate a relational expression from a SQL expression
107165
pub fn sql_to_expr(
108166
&self,

0 commit comments

Comments
 (0)