Skip to content

Commit 4e4059a

Browse files
jayzhan211alamb
andauthored
Convert Binary Operator StringConcat to Function for array_concat, array_append and array_prepend (#8636)
* reuse function for string concat Signed-off-by: jayzhan211 <[email protected]> * remove casting in string concat Signed-off-by: jayzhan211 <[email protected]> * add test Signed-off-by: jayzhan211 <[email protected]> * operator to function rewrite Signed-off-by: jayzhan211 <[email protected]> * fix explain Signed-off-by: jayzhan211 <[email protected]> * add more test Signed-off-by: jayzhan211 <[email protected]> * add column cases Signed-off-by: jayzhan211 <[email protected]> * cleanup Signed-off-by: jayzhan211 <[email protected]> * presever name Signed-off-by: jayzhan211 <[email protected]> * Update datafusion/optimizer/src/analyzer/rewrite_expr.rs Co-authored-by: Andrew Lamb <[email protected]> * rename Signed-off-by: jayzhan211 <[email protected]> --------- Signed-off-by: jayzhan211 <[email protected]> Co-authored-by: Andrew Lamb <[email protected]>
1 parent 821db54 commit 4e4059a

File tree

7 files changed

+371
-11
lines changed

7 files changed

+371
-11
lines changed

datafusion/expr/src/type_coercion/binary.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -667,8 +667,6 @@ fn string_concat_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<Da
667667
(LargeUtf8, from_type) | (from_type, LargeUtf8) => {
668668
string_concat_internal_coercion(from_type, &LargeUtf8)
669669
}
670-
// TODO: cast between array elements (#6558)
671-
(List(_), from_type) | (from_type, List(_)) => Some(from_type.to_owned()),
672670
_ => None,
673671
})
674672
}

datafusion/optimizer/src/analyzer/mod.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
pub mod count_wildcard_rule;
1919
pub mod inline_table_scan;
20+
pub mod rewrite_expr;
2021
pub mod subquery;
2122
pub mod type_coercion;
2223

@@ -37,6 +38,8 @@ use log::debug;
3738
use std::sync::Arc;
3839
use std::time::Instant;
3940

41+
use self::rewrite_expr::OperatorToFunction;
42+
4043
/// [`AnalyzerRule`]s transform [`LogicalPlan`]s in some way to make
4144
/// the plan valid prior to the rest of the DataFusion optimization process.
4245
///
@@ -72,6 +75,9 @@ impl Analyzer {
7275
pub fn new() -> Self {
7376
let rules: Vec<Arc<dyn AnalyzerRule + Send + Sync>> = vec![
7477
Arc::new(InlineTableScan::new()),
78+
// OperatorToFunction should be run before TypeCoercion, since it rewrites based on the argument types (List or Scalar),
79+
// and TypeCoercion may cast the argument types from Scalar to List.
80+
Arc::new(OperatorToFunction::new()),
7581
Arc::new(TypeCoercion::new()),
7682
Arc::new(CountWildcardRule::new()),
7783
];
Lines changed: 321 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,321 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! Analyzer rule to replace operators with function calls (e.g. `||` to `array_concat`)
19+
20+
use std::sync::Arc;
21+
22+
use datafusion_common::config::ConfigOptions;
23+
use datafusion_common::tree_node::TreeNodeRewriter;
24+
use datafusion_common::utils::list_ndims;
25+
use datafusion_common::DFSchema;
26+
use datafusion_common::DFSchemaRef;
27+
use datafusion_common::Result;
28+
use datafusion_expr::expr::ScalarFunction;
29+
use datafusion_expr::expr_rewriter::rewrite_preserving_name;
30+
use datafusion_expr::utils::merge_schema;
31+
use datafusion_expr::BuiltinScalarFunction;
32+
use datafusion_expr::Operator;
33+
use datafusion_expr::ScalarFunctionDefinition;
34+
use datafusion_expr::{BinaryExpr, Expr, LogicalPlan};
35+
36+
use super::AnalyzerRule;
37+
38+
/// Analyzer rule that replaces operators with equivalent function calls.
///
/// The rule carries no state.
#[derive(Default)]
pub struct OperatorToFunction {}

impl OperatorToFunction {
    /// Creates a new instance of the rule.
    pub fn new() -> Self {
        Self::default()
    }
}
46+
47+
impl AnalyzerRule for OperatorToFunction {
48+
fn name(&self) -> &str {
49+
"operator_to_function"
50+
}
51+
52+
fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result<LogicalPlan> {
53+
analyze_internal(&plan)
54+
}
55+
}
56+
57+
fn analyze_internal(plan: &LogicalPlan) -> Result<LogicalPlan> {
58+
// optimize child plans first
59+
let new_inputs = plan
60+
.inputs()
61+
.iter()
62+
.map(|p| analyze_internal(p))
63+
.collect::<Result<Vec<_>>>()?;
64+
65+
// get schema representing all available input fields. This is used for data type
66+
// resolution only, so order does not matter here
67+
let mut schema = merge_schema(new_inputs.iter().collect());
68+
69+
if let LogicalPlan::TableScan(ts) = plan {
70+
let source_schema =
71+
DFSchema::try_from_qualified_schema(&ts.table_name, &ts.source.schema())?;
72+
schema.merge(&source_schema);
73+
}
74+
75+
let mut expr_rewrite = OperatorToFunctionRewriter {
76+
schema: Arc::new(schema),
77+
};
78+
79+
let new_expr = plan
80+
.expressions()
81+
.into_iter()
82+
.map(|expr| {
83+
// ensure names don't change:
84+
// https://github.com/apache/arrow-datafusion/issues/3555
85+
rewrite_preserving_name(expr, &mut expr_rewrite)
86+
})
87+
.collect::<Result<Vec<_>>>()?;
88+
89+
plan.with_new_exprs(new_expr, &new_inputs)
90+
}
91+
92+
pub(crate) struct OperatorToFunctionRewriter {
93+
pub(crate) schema: DFSchemaRef,
94+
}
95+
96+
impl TreeNodeRewriter for OperatorToFunctionRewriter {
97+
type N = Expr;
98+
99+
fn mutate(&mut self, expr: Expr) -> Result<Expr> {
100+
match expr {
101+
Expr::BinaryExpr(BinaryExpr {
102+
ref left,
103+
op,
104+
ref right,
105+
}) => {
106+
if let Some(fun) = rewrite_array_concat_operator_to_func_for_column(
107+
left.as_ref(),
108+
op,
109+
right.as_ref(),
110+
self.schema.as_ref(),
111+
)?
112+
.or_else(|| {
113+
rewrite_array_concat_operator_to_func(
114+
left.as_ref(),
115+
op,
116+
right.as_ref(),
117+
)
118+
}) {
119+
// Convert &Box<Expr> -> Expr
120+
let left = (**left).clone();
121+
let right = (**right).clone();
122+
return Ok(Expr::ScalarFunction(ScalarFunction {
123+
func_def: ScalarFunctionDefinition::BuiltIn(fun),
124+
args: vec![left, right],
125+
}));
126+
}
127+
128+
Ok(expr)
129+
}
130+
_ => Ok(expr),
131+
}
132+
}
133+
}
134+
135+
/// Summary of the logic below:
136+
///
137+
/// 1) array || array -> array concat
138+
///
139+
/// 2) array || scalar -> array append
140+
///
141+
/// 3) scalar || array -> array prepend
142+
///
143+
/// 4) (arry concat, array append, array prepend) || array -> array concat
144+
///
145+
/// 5) (arry concat, array append, array prepend) || scalar -> array append
146+
fn rewrite_array_concat_operator_to_func(
147+
left: &Expr,
148+
op: Operator,
149+
right: &Expr,
150+
) -> Option<BuiltinScalarFunction> {
151+
// Convert `Array StringConcat Array` to ScalarFunction::ArrayConcat
152+
153+
if op != Operator::StringConcat {
154+
return None;
155+
}
156+
157+
match (left, right) {
158+
// Chain concat operator (a || b) || array,
159+
// (arry concat, array append, array prepend) || array -> array concat
160+
(
161+
Expr::ScalarFunction(ScalarFunction {
162+
func_def:
163+
ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayConcat),
164+
args: _left_args,
165+
}),
166+
Expr::ScalarFunction(ScalarFunction {
167+
func_def:
168+
ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray),
169+
args: _right_args,
170+
}),
171+
)
172+
| (
173+
Expr::ScalarFunction(ScalarFunction {
174+
func_def:
175+
ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayAppend),
176+
args: _left_args,
177+
}),
178+
Expr::ScalarFunction(ScalarFunction {
179+
func_def:
180+
ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray),
181+
args: _right_args,
182+
}),
183+
)
184+
| (
185+
Expr::ScalarFunction(ScalarFunction {
186+
func_def:
187+
ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayPrepend),
188+
args: _left_args,
189+
}),
190+
Expr::ScalarFunction(ScalarFunction {
191+
func_def:
192+
ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray),
193+
args: _right_args,
194+
}),
195+
) => Some(BuiltinScalarFunction::ArrayConcat),
196+
// Chain concat operator (a || b) || scalar,
197+
// (arry concat, array append, array prepend) || scalar -> array append
198+
(
199+
Expr::ScalarFunction(ScalarFunction {
200+
func_def:
201+
ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayConcat),
202+
args: _left_args,
203+
}),
204+
_scalar,
205+
)
206+
| (
207+
Expr::ScalarFunction(ScalarFunction {
208+
func_def:
209+
ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayAppend),
210+
args: _left_args,
211+
}),
212+
_scalar,
213+
)
214+
| (
215+
Expr::ScalarFunction(ScalarFunction {
216+
func_def:
217+
ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayPrepend),
218+
args: _left_args,
219+
}),
220+
_scalar,
221+
) => Some(BuiltinScalarFunction::ArrayAppend),
222+
// array || array -> array concat
223+
(
224+
Expr::ScalarFunction(ScalarFunction {
225+
func_def:
226+
ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray),
227+
args: _left_args,
228+
}),
229+
Expr::ScalarFunction(ScalarFunction {
230+
func_def:
231+
ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray),
232+
args: _right_args,
233+
}),
234+
) => Some(BuiltinScalarFunction::ArrayConcat),
235+
// array || scalar -> array append
236+
(
237+
Expr::ScalarFunction(ScalarFunction {
238+
func_def:
239+
ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray),
240+
args: _left_args,
241+
}),
242+
_right_scalar,
243+
) => Some(BuiltinScalarFunction::ArrayAppend),
244+
// scalar || array -> array prepend
245+
(
246+
_left_scalar,
247+
Expr::ScalarFunction(ScalarFunction {
248+
func_def:
249+
ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray),
250+
args: _right_args,
251+
}),
252+
) => Some(BuiltinScalarFunction::ArrayPrepend),
253+
254+
_ => None,
255+
}
256+
}
257+
258+
/// Summary of the logic below:
259+
///
260+
/// 1) (arry concat, array append, array prepend) || column -> (array append, array concat)
261+
///
262+
/// 2) column1 || column2 -> (array prepend, array append, array concat)
263+
fn rewrite_array_concat_operator_to_func_for_column(
264+
left: &Expr,
265+
op: Operator,
266+
right: &Expr,
267+
schema: &DFSchema,
268+
) -> Result<Option<BuiltinScalarFunction>> {
269+
if op != Operator::StringConcat {
270+
return Ok(None);
271+
}
272+
273+
match (left, right) {
274+
// Column cases:
275+
// 1) array_prepend/append/concat || column
276+
(
277+
Expr::ScalarFunction(ScalarFunction {
278+
func_def:
279+
ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayPrepend),
280+
args: _left_args,
281+
}),
282+
Expr::Column(c),
283+
)
284+
| (
285+
Expr::ScalarFunction(ScalarFunction {
286+
func_def:
287+
ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayAppend),
288+
args: _left_args,
289+
}),
290+
Expr::Column(c),
291+
)
292+
| (
293+
Expr::ScalarFunction(ScalarFunction {
294+
func_def:
295+
ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayConcat),
296+
args: _left_args,
297+
}),
298+
Expr::Column(c),
299+
) => {
300+
let d = schema.field_from_column(c)?.data_type();
301+
let ndim = list_ndims(d);
302+
match ndim {
303+
0 => Ok(Some(BuiltinScalarFunction::ArrayAppend)),
304+
_ => Ok(Some(BuiltinScalarFunction::ArrayConcat)),
305+
}
306+
}
307+
// 2) select column1 || column2
308+
(Expr::Column(c1), Expr::Column(c2)) => {
309+
let d1 = schema.field_from_column(c1)?.data_type();
310+
let d2 = schema.field_from_column(c2)?.data_type();
311+
let ndim1 = list_ndims(d1);
312+
let ndim2 = list_ndims(d2);
313+
match (ndim1, ndim2) {
314+
(0, _) => Ok(Some(BuiltinScalarFunction::ArrayPrepend)),
315+
(_, 0) => Ok(Some(BuiltinScalarFunction::ArrayAppend)),
316+
_ => Ok(Some(BuiltinScalarFunction::ArrayConcat)),
317+
}
318+
}
319+
_ => Ok(None),
320+
}
321+
}

datafusion/physical-expr/src/expressions/binary.rs

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,7 @@ mod kernels;
2020
use std::hash::{Hash, Hasher};
2121
use std::{any::Any, sync::Arc};
2222

23-
use crate::array_expressions::{
24-
array_append, array_concat, array_has_all, array_prepend,
25-
};
23+
use crate::array_expressions::array_has_all;
2624
use crate::expressions::datum::{apply, apply_cmp};
2725
use crate::intervals::cp_solver::{propagate_arithmetic, propagate_comparison};
2826
use crate::physical_expr::down_cast_any_ref;
@@ -598,12 +596,7 @@ impl BinaryExpr {
598596
BitwiseXor => bitwise_xor_dyn(left, right),
599597
BitwiseShiftRight => bitwise_shift_right_dyn(left, right),
600598
BitwiseShiftLeft => bitwise_shift_left_dyn(left, right),
601-
StringConcat => match (left_data_type, right_data_type) {
602-
(DataType::List(_), DataType::List(_)) => array_concat(&[left, right]),
603-
(DataType::List(_), _) => array_append(&[left, right]),
604-
(_, DataType::List(_)) => array_prepend(&[left, right]),
605-
_ => binary_string_array_op!(left, right, concat_elements),
606-
},
599+
StringConcat => binary_string_array_op!(left, right, concat_elements),
607600
AtArrow => array_has_all(&[left, right]),
608601
ArrowAt => array_has_all(&[right, left]),
609602
}

datafusion/sql/src/expr/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,11 +98,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
9898
StackEntry::Operator(op) => {
9999
let right = eval_stack.pop().unwrap();
100100
let left = eval_stack.pop().unwrap();
101+
101102
let expr = Expr::BinaryExpr(BinaryExpr::new(
102103
Box::new(left),
103104
op,
104105
Box::new(right),
105106
));
107+
106108
eval_stack.push(expr);
107109
}
108110
}

0 commit comments

Comments
 (0)