Skip to content

Commit ddf29f1

Browse files
authored
implement 'StringConcat' operator to support sql like "select 'aa' || 'b' " (#2142)
* implement stringconcat operator * snake case fix * support non-string concat & handle NULL * value -> array * string concat internal coercion * get NULL in right index of vec Co-authored-by: duripeng <[email protected]>
1 parent fa9e016 commit ddf29f1

File tree

5 files changed

+119
-0
lines changed

5 files changed

+119
-0
lines changed

datafusion/core/src/sql/planner.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1328,6 +1328,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
13281328
BinaryOperator::PGRegexNotIMatch => Ok(Operator::RegexNotIMatch),
13291329
BinaryOperator::BitwiseAnd => Ok(Operator::BitwiseAnd),
13301330
BinaryOperator::BitwiseOr => Ok(Operator::BitwiseOr),
1331+
BinaryOperator::StringConcat => Ok(Operator::StringConcat),
13311332
_ => Err(DataFusionError::NotImplemented(format!(
13321333
"Unsupported SQL binary operator {:?}",
13331334
op

datafusion/core/tests/sql/expr.rs

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,59 @@ async fn query_scalar_minus_array() -> Result<()> {
280280
Ok(())
281281
}
282282

283+
#[tokio::test]
284+
async fn test_string_concat_operator() -> Result<()> {
285+
let ctx = SessionContext::new();
286+
// concat 2 strings
287+
let sql = "SELECT 'aa' || 'b'";
288+
let actual = execute_to_batches(&ctx, sql).await;
289+
let expected = vec![
290+
"+-------------------------+",
291+
"| Utf8(\"aa\") || Utf8(\"b\") |",
292+
"+-------------------------+",
293+
"| aab |",
294+
"+-------------------------+",
295+
];
296+
assert_batches_eq!(expected, &actual);
297+
298+
// concat 4 strings as a string concat pipe.
299+
let sql = "SELECT 'aa' || 'b' || 'cc' || 'd'";
300+
let actual = execute_to_batches(&ctx, sql).await;
301+
let expected = vec![
302+
"+----------------------------------------------------+",
303+
"| Utf8(\"aa\") || Utf8(\"b\") || Utf8(\"cc\") || Utf8(\"d\") |",
304+
"+----------------------------------------------------+",
305+
"| aabccd |",
306+
"+----------------------------------------------------+",
307+
];
308+
assert_batches_eq!(expected, &actual);
309+
310+
// concat 2 strings and NULL, output should be NULL
311+
let sql = "SELECT 'aa' || NULL || 'd'";
312+
let actual = execute_to_batches(&ctx, sql).await;
313+
let expected = vec![
314+
"+---------------------------------------+",
315+
"| Utf8(\"aa\") || Utf8(NULL) || Utf8(\"d\") |",
316+
"+---------------------------------------+",
317+
"| |",
318+
"+---------------------------------------+",
319+
];
320+
assert_batches_eq!(expected, &actual);
321+
322+
// concat 1 strings and 2 numeric
323+
let sql = "SELECT 'a' || 42 || 23.3";
324+
let actual = execute_to_batches(&ctx, sql).await;
325+
let expected = vec![
326+
"+-----------------------------------------+",
327+
"| Utf8(\"a\") || Int64(42) || Float64(23.3) |",
328+
"+-----------------------------------------+",
329+
"| a4223.3 |",
330+
"+-----------------------------------------+",
331+
];
332+
assert_batches_eq!(expected, &actual);
333+
Ok(())
334+
}
335+
283336
#[tokio::test]
284337
async fn test_boolean_expressions() -> Result<()> {
285338
test_expression!("true", "true");

datafusion/expr/src/operator.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ pub enum Operator {
7171
BitwiseAnd,
7272
/// Bitwise or, like `|`
7373
BitwiseOr,
74+
/// String concat
75+
StringConcat,
7476
}
7577

7678
impl fmt::Display for Operator {
@@ -99,6 +101,7 @@ impl fmt::Display for Operator {
99101
Operator::IsNotDistinctFrom => "IS NOT DISTINCT FROM",
100102
Operator::BitwiseAnd => "&",
101103
Operator::BitwiseOr => "|",
104+
Operator::StringConcat => "||",
102105
};
103106
write!(f, "{}", display)
104107
}

datafusion/physical-expr/src/coercion_rule/binary_rule.rs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
//! Coercion rules for matching argument types for binary operators
1919
20+
use arrow::compute::can_cast_types;
2021
use arrow::datatypes::{DataType, DECIMAL_MAX_PRECISION, DECIMAL_MAX_SCALE};
2122
use datafusion_common::DataFusionError;
2223
use datafusion_common::Result;
@@ -58,6 +59,8 @@ pub(crate) fn coerce_types(
5859
| Operator::RegexIMatch
5960
| Operator::RegexNotMatch
6061
| Operator::RegexNotIMatch => string_coercion(lhs_type, rhs_type),
62+
// "||" operator has its own rules, and always return a string type
63+
Operator::StringConcat => string_concat_coercion(lhs_type, rhs_type),
6164
Operator::IsDistinctFrom | Operator::IsNotDistinctFrom => {
6265
eq_coercion(lhs_type, rhs_type)
6366
}
@@ -369,6 +372,34 @@ fn dictionary_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataT
369372
}
370373
}
371374

375+
/// Coercion rules for string concat.
376+
/// This is a union of string coercion rules and specified rules:
377+
/// 1. At lease one side of lhs and rhs should be string type (Utf8 / LargeUtf8)
378+
/// 2. Data type of the other side should be able to cast to string type
379+
fn string_concat_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
380+
use arrow::datatypes::DataType::*;
381+
string_coercion(lhs_type, rhs_type).or(match (lhs_type, rhs_type) {
382+
(Utf8, from_type) | (from_type, Utf8) => {
383+
string_concat_internal_coercion(from_type, &Utf8)
384+
}
385+
(LargeUtf8, from_type) | (from_type, LargeUtf8) => {
386+
string_concat_internal_coercion(from_type, &LargeUtf8)
387+
}
388+
_ => None,
389+
})
390+
}
391+
392+
fn string_concat_internal_coercion(
393+
from_type: &DataType,
394+
to_type: &DataType,
395+
) -> Option<DataType> {
396+
if can_cast_types(from_type, to_type) {
397+
Some(to_type.to_owned())
398+
} else {
399+
None
400+
}
401+
}
402+
372403
/// Coercion rules for Strings: the type that both lhs and rhs can be
373404
/// casted to for the purpose of a string computation
374405
fn string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {

datafusion/physical-expr/src/expressions/binary.rs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ use arrow::record_batch::RecordBatch;
5656

5757
use crate::coercion_rule::binary_rule::coerce_types;
5858
use crate::expressions::try_cast;
59+
use crate::string_expressions;
5960
use crate::PhysicalExpr;
6061
use datafusion_common::ScalarValue;
6162
use datafusion_common::{DataFusionError, Result};
@@ -416,6 +417,33 @@ fn bitwise_or(left: ArrayRef, right: ArrayRef) -> Result<ArrayRef> {
416417
}
417418
}
418419

420+
/// Use datafusion build-in expression `concat` to evaluate `StringConcat` operator.
421+
/// Besides, any `NULL` exists on lhs or rhs will come out result `NULL`
422+
/// 1. 'a' || 'b' || 32 = 'ab32'
423+
/// 2. 'a' || NULL = NULL
424+
fn string_concat(left: ArrayRef, right: ArrayRef) -> Result<ArrayRef> {
425+
let ignore_null = match string_expressions::concat(&[
426+
ColumnarValue::Array(left.clone()),
427+
ColumnarValue::Array(right.clone()),
428+
])? {
429+
ColumnarValue::Array(array_ref) => array_ref,
430+
scalar_value => scalar_value.into_array(left.clone().len()),
431+
};
432+
let ignore_null_array = ignore_null.as_any().downcast_ref::<StringArray>().unwrap();
433+
let result = (0..ignore_null_array.len())
434+
.into_iter()
435+
.map(|index| {
436+
if left.is_null(index) || right.is_null(index) {
437+
None
438+
} else {
439+
Some(ignore_null_array.value(index))
440+
}
441+
})
442+
.collect::<StringArray>();
443+
444+
Ok(Arc::new(result) as ArrayRef)
445+
}
446+
419447
fn bitwise_and_scalar(
420448
array: &dyn Array,
421449
scalar: ScalarValue,
@@ -1005,6 +1033,8 @@ pub fn binary_operator_data_type(
10051033
| Operator::Divide
10061034
| Operator::Multiply
10071035
| Operator::Modulo => Ok(result_type),
1036+
// string operations return the same values as the common coerced type
1037+
Operator::StringConcat => Ok(result_type),
10081038
}
10091039
}
10101040

@@ -1266,6 +1296,7 @@ impl BinaryExpr {
12661296
}
12671297
Operator::BitwiseAnd => bitwise_and(left, right),
12681298
Operator::BitwiseOr => bitwise_or(left, right),
1299+
Operator::StringConcat => string_concat(left, right),
12691300
}
12701301
}
12711302
}

0 commit comments

Comments
 (0)