diff --git a/datafusion/functions/src/string/concat.rs b/datafusion/functions/src/string/concat.rs index e674541253288..c65b990f42397 100644 --- a/datafusion/functions/src/string/concat.rs +++ b/datafusion/functions/src/string/concat.rs @@ -88,19 +88,18 @@ impl ScalarUDFImpl for ConcatFunc { &self.signature } + /// Match the return type to the input types to avoid unnecessary casts. On + /// mixed inputs, prefer Utf8View; prefer LargeUtf8 over Utf8 to avoid + /// potential overflow on LargeUtf8 input. fn return_type(&self, arg_types: &[DataType]) -> Result { use DataType::*; - let mut dt = &Utf8; - arg_types.iter().for_each(|data_type| { - if data_type == &Utf8View { - dt = data_type; - } - if data_type == &LargeUtf8 && dt != &Utf8View { - dt = data_type; - } - }); - - Ok(dt.to_owned()) + if arg_types.contains(&Utf8View) { + Ok(Utf8View) + } else if arg_types.contains(&LargeUtf8) { + Ok(LargeUtf8) + } else { + Ok(Utf8) + } } /// Concatenates the text representations of all the arguments. NULL arguments are ignored. @@ -108,17 +107,14 @@ impl ScalarUDFImpl for ConcatFunc { fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { let ScalarFunctionArgs { args, .. } = args; - let mut return_datatype = DataType::Utf8; - args.iter().for_each(|col| { - if col.data_type() == DataType::Utf8View { - return_datatype = col.data_type(); - } - if col.data_type() == DataType::LargeUtf8 - && return_datatype != DataType::Utf8View - { - return_datatype = col.data_type(); - } - }); + let return_datatype = if args.iter().any(|c| c.data_type() == DataType::Utf8View) + { + DataType::Utf8View + } else if args.iter().any(|c| c.data_type() == DataType::LargeUtf8) { + DataType::LargeUtf8 + } else { + DataType::Utf8 + }; let array_len = args.iter().find_map(|x| match x { ColumnarValue::Array(array) => Some(array.len()), @@ -247,7 +243,7 @@ impl ScalarUDFImpl for ConcatFunc { builder.append_offset(); } - let string_array = builder.finish(); + let string_array = builder.finish(None); Ok(ColumnarValue::Array(Arc::new(string_array))) } DataType::LargeUtf8 => { diff --git a/datafusion/functions/src/string/concat_ws.rs b/datafusion/functions/src/string/concat_ws.rs index 9d3b32eedf8fd..80e11d286ed87 100644 --- a/datafusion/functions/src/string/concat_ws.rs +++ b/datafusion/functions/src/string/concat_ws.rs @@ -24,7 +24,9 @@ use arrow::datatypes::DataType; use crate::string::concat; use crate::string::concat::simplify_concat; use crate::string::concat_ws; -use crate::strings::{ColumnarValueRef, StringArrayBuilder}; +use crate::strings::{ + ColumnarValueRef, LargeStringArrayBuilder, StringArrayBuilder, StringViewArrayBuilder, +}; use datafusion_common::cast::{ as_large_string_array, as_string_array, as_string_view_array, }; @@ -97,12 +99,23 @@ impl ScalarUDFImpl for ConcatWsFunc { &self.signature } - fn return_type(&self, _arg_types: &[DataType]) -> Result { + /// Match the return type to the input types to avoid unnecessary casts. On + /// mixed inputs, prefer Utf8View; prefer LargeUtf8 over Utf8 to avoid + /// potential overflow on LargeUtf8 input. + fn return_type(&self, arg_types: &[DataType]) -> Result { use DataType::*; - Ok(Utf8) + if arg_types.contains(&Utf8View) { + Ok(Utf8View) + } else if arg_types.contains(&LargeUtf8) { + Ok(LargeUtf8) + } else { + Ok(Utf8) + } } - /// Concatenates all but the first argument, with separators. The first argument is used as the separator string, and should not be NULL. Other NULL arguments are ignored. + /// Concatenates all but the first argument, with separators. The first + /// argument is used as the separator string, and should not be NULL. Other + /// NULL arguments are ignored. /// concat_ws(',', 'abcde', 2, NULL, 22) = 'abcde,2,22' fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { let ScalarFunctionArgs { args, .. } = args; @@ -114,6 +127,15 @@ impl ScalarUDFImpl for ConcatWsFunc { ); } + let return_datatype = if args.iter().any(|c| c.data_type() == DataType::Utf8View) + { + DataType::Utf8View + } else if args.iter().any(|c| c.data_type() == DataType::LargeUtf8) { + DataType::LargeUtf8 + } else { + DataType::Utf8 + }; + let array_len = args.iter().find_map(|x| match x { ColumnarValue::Array(array) => Some(array.len()), _ => None, @@ -128,7 +150,15 @@ impl ScalarUDFImpl for ConcatWsFunc { Some(Some(s)) => s, Some(None) => { // null literal string - return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))); + return match return_datatype { + DataType::Utf8View => { + Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(None))) + } + DataType::LargeUtf8 => { + Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(None))) + } + _ => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))), + }; } None => return internal_err!("Expected string literal, got {scalar:?}"), }; @@ -149,7 +179,15 @@ impl ScalarUDFImpl for ConcatWsFunc { } let result = values.join(sep); - return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(result)))); + return match return_datatype { + DataType::Utf8View => { + Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some(result)))) + } + DataType::LargeUtf8 => { + Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(result)))) + } + _ => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(result)))), + }; } // Array @@ -164,7 +202,15 @@ impl ScalarUDFImpl for ConcatWsFunc { ColumnarValueRef::Scalar(s.as_bytes()) } Some(None) => { - return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))); + return match return_datatype { + DataType::Utf8View => { + Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(None))) + } + DataType::LargeUtf8 => { + Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(None))) + } + _ => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))), + }; } None => { return internal_err!("Expected string separator, got {scalar:?}"); @@ -268,28 +314,71 @@ impl ScalarUDFImpl for ConcatWsFunc { } } - let mut builder = StringArrayBuilder::with_capacity(len, data_size); - for i in 0..len { - if !sep.is_valid(i) { - builder.append_offset(); - continue; + match return_datatype { + DataType::Utf8View => { + let mut builder = StringViewArrayBuilder::with_capacity(len, data_size); + for i in 0..len { + if !sep.is_valid(i) { + builder.append_offset(); + continue; + } + let mut first = true; + for column in &columns { + if column.is_valid(i) { + if !first { + builder.write::(&sep, i); + } + builder.write::(column, i); + first = false; + } + } + builder.append_offset(); + } + Ok(ColumnarValue::Array(Arc::new(builder.finish(sep.nulls())))) } - - let mut first = true; - for column in &columns { - if column.is_valid(i) { - if !first { - builder.write::(&sep, i); + DataType::LargeUtf8 => { + let mut builder = LargeStringArrayBuilder::with_capacity(len, data_size); + for i in 0..len { + if !sep.is_valid(i) { + builder.append_offset(); + continue; } - builder.write::(column, i); - first = false; + let mut first = true; + for column in &columns { + if column.is_valid(i) { + if !first { + builder.write::(&sep, i); + } + builder.write::(column, i); + first = false; + } + } + builder.append_offset(); } + Ok(ColumnarValue::Array(Arc::new(builder.finish(sep.nulls())))) + } + _ => { + let mut builder = StringArrayBuilder::with_capacity(len, data_size); + for i in 0..len { + if !sep.is_valid(i) { + builder.append_offset(); + continue; + } + let mut first = true; + for column in &columns { + if column.is_valid(i) { + if !first { + builder.write::(&sep, i); + } + builder.write::(column, i); + first = false; + } + } + builder.append_offset(); + } + Ok(ColumnarValue::Array(Arc::new(builder.finish(sep.nulls())))) } - - builder.append_offset(); } - - Ok(ColumnarValue::Array(Arc::new(builder.finish(sep.nulls())))) } /// Simply the `concat_ws` function by @@ -314,6 +403,21 @@ impl ScalarUDFImpl for ConcatWsFunc { } fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result { + // Preserve the delimiter's string type for any new literals produced + // during simplification. + let delimiter_type = match delimiter { + Expr::Literal(v, _) => v.data_type(), + _ => DataType::Utf8, + }; + + let typed_lit = |s: String| -> Expr { + match delimiter_type { + DataType::LargeUtf8 => lit(ScalarValue::LargeUtf8(Some(s))), + DataType::Utf8View => lit(ScalarValue::Utf8View(Some(s))), + _ => lit(s), + } + }; + match delimiter { Expr::Literal( ScalarValue::Utf8(delimiter) @@ -322,8 +426,8 @@ fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result { match delimiter { - // when the delimiter is an empty string, - // we can use `concat` to replace `concat_ws` + // When the delimiter is the empty string, replace `concat_ws` + // with `concat` Some(delimiter) if delimiter.is_empty() => { match simplify_concat(args.to_vec())? { ExprSimplifyResult::Original(_) => { @@ -339,7 +443,7 @@ fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result { let mut new_args = Vec::with_capacity(args.len()); - new_args.push(lit(delimiter)); + new_args.push(typed_lit(delimiter.to_string())); let mut contiguous_scalar = None; for arg in args { match arg { @@ -373,7 +477,7 @@ fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result { if let Some(val) = contiguous_scalar { - new_args.push(lit(val)); + new_args.push(typed_lit(val)); } new_args.push(arg.clone()); contiguous_scalar = None; @@ -381,7 +485,7 @@ fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result Result Ok(ExprSimplifyResult::Simplified(Expr::Literal( - ScalarValue::Utf8(None), - None, - ))), + // If the delimiter is null, then the value of the whole expression is null. + None => { + let null_scalar = match delimiter_type { + DataType::LargeUtf8 => ScalarValue::LargeUtf8(None), + DataType::Utf8View => ScalarValue::Utf8View(None), + _ => ScalarValue::Utf8(None), + }; + Ok(ExprSimplifyResult::Simplified(Expr::Literal( + null_scalar, + None, + ))) + } } } Expr::Literal(d, _) => internal_err!( @@ -425,8 +536,8 @@ mod tests { use std::sync::Arc; use crate::string::concat_ws::ConcatWsFunc; - use arrow::array::{Array, ArrayRef, StringArray}; - use arrow::datatypes::DataType::Utf8; + use arrow::array::{Array, ArrayRef, LargeStringArray, StringArray, StringViewArray}; + use arrow::datatypes::DataType::{LargeUtf8, Utf8, Utf8View}; use arrow::datatypes::Field; use datafusion_common::Result; use datafusion_common::ScalarValue; @@ -579,7 +690,7 @@ mod tests { ]))); let arg_fields = vec![ - Field::new("a", Utf8, true).into(), + Field::new("a", Utf8View, true).into(), Field::new("a", Utf8, true).into(), Field::new("a", Utf8, true).into(), ]; @@ -587,13 +698,13 @@ mod tests { args: vec![c0, c1, c2], arg_fields, number_rows: 3, - return_field: Field::new("f", Utf8, true).into(), + return_field: Field::new("f", Utf8View, true).into(), config_options: Arc::new(ConfigOptions::default()), }; let result = ConcatWsFunc::new().invoke_with_args(args)?; let expected = - Arc::new(StringArray::from(vec!["foo,x", "bar", "baz,z"])) as ArrayRef; + Arc::new(StringViewArray::from(vec!["foo,x", "bar", "baz,z"])) as ArrayRef; match &result { ColumnarValue::Array(array) => { assert_eq!(&expected, array); @@ -616,7 +727,7 @@ mod tests { ]))); let arg_fields = vec![ - Field::new("a", Utf8, true).into(), + Field::new("a", LargeUtf8, true).into(), Field::new("a", Utf8, true).into(), Field::new("a", Utf8, true).into(), ]; @@ -624,13 +735,13 @@ mod tests { args: vec![c0, c1, c2], arg_fields, number_rows: 3, - return_field: Field::new("f", Utf8, true).into(), + return_field: Field::new("f", LargeUtf8, true).into(), config_options: Arc::new(ConfigOptions::default()), }; let result = ConcatWsFunc::new().invoke_with_args(args)?; let expected = - Arc::new(StringArray::from(vec!["foo,x", "bar", "baz,z"])) as ArrayRef; + Arc::new(LargeStringArray::from(vec!["foo,x", "bar", "baz,z"])) as ArrayRef; match &result { ColumnarValue::Array(array) => { assert_eq!(&expected, array); @@ -640,4 +751,191 @@ mod tests { Ok(()) } + + #[test] + fn concat_ws_utf8view_nullable_separator() -> Result<()> { + let c0 = ColumnarValue::Array(Arc::new(StringViewArray::from(vec![ + Some(","), + None, + Some("+"), + ]))); + let c1 = ColumnarValue::Array(Arc::new(StringViewArray::from(vec![ + "foo", "bar", "baz", + ]))); + let c2 = ColumnarValue::Array(Arc::new(StringViewArray::from(vec![ + Some("x"), + Some("y"), + Some("z"), + ]))); + + let arg_fields = vec![ + Field::new("a", Utf8View, true).into(), + Field::new("a", Utf8View, true).into(), + Field::new("a", Utf8View, true).into(), + ]; + let args = ScalarFunctionArgs { + args: vec![c0, c1, c2], + arg_fields, + number_rows: 3, + return_field: Field::new("f", Utf8View, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + + let result = ConcatWsFunc::new().invoke_with_args(args)?; + let expected = Arc::new(StringViewArray::from(vec![ + Some("foo,x"), + None, + Some("baz+z"), + ])) as ArrayRef; + match &result { + ColumnarValue::Array(array) => { + assert_eq!(&expected, array); + } + _ => panic!("Expected array result"), + } + + Ok(()) + } + + #[test] + fn concat_ws_largeutf8_arrays() -> Result<()> { + let c0 = ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(",".to_string()))); + let c1 = ColumnarValue::Array(Arc::new(LargeStringArray::from(vec![ + "foo", "bar", "baz", + ]))); + let c2 = ColumnarValue::Array(Arc::new(LargeStringArray::from(vec![ + Some("x"), + None, + Some("z"), + ]))); + + let arg_fields = vec![ + Field::new("a", LargeUtf8, true).into(), + Field::new("a", LargeUtf8, true).into(), + Field::new("a", LargeUtf8, true).into(), + ]; + let args = ScalarFunctionArgs { + args: vec![c0, c1, c2], + arg_fields, + number_rows: 3, + return_field: Field::new("f", LargeUtf8, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + + let result = ConcatWsFunc::new().invoke_with_args(args)?; + let expected = + Arc::new(LargeStringArray::from(vec!["foo,x", "bar", "baz,z"])) as ArrayRef; + match &result { + ColumnarValue::Array(array) => { + assert_eq!(&expected, array); + } + _ => panic!("Expected array result"), + } + + Ok(()) + } + + #[test] + fn concat_ws_utf8view_null_separator() -> Result<()> { + // All-scalar path: null Utf8View separator should return Utf8View(None) + let c0 = ColumnarValue::Scalar(ScalarValue::Utf8View(None)); + let c1 = ColumnarValue::Scalar(ScalarValue::Utf8View(Some("aa".to_string()))); + let c2 = ColumnarValue::Scalar(ScalarValue::Utf8View(Some("bb".to_string()))); + + let arg_fields = vec![ + Field::new("a", Utf8View, true).into(), + Field::new("a", Utf8View, true).into(), + Field::new("a", Utf8View, true).into(), + ]; + let args = ScalarFunctionArgs { + args: vec![c0, c1, c2], + arg_fields, + number_rows: 1, + return_field: Field::new("f", Utf8View, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + + let result = ConcatWsFunc::new().invoke_with_args(args)?; + match result { + ColumnarValue::Scalar(ScalarValue::Utf8View(None)) => {} + other => panic!("Expected Utf8View(None), got {other:?}"), + } + + // Array path: null Utf8View scalar separator with array args + let c0 = ColumnarValue::Scalar(ScalarValue::Utf8View(None)); + let c1 = + ColumnarValue::Array(Arc::new(StringViewArray::from(vec!["foo", "bar"]))); + + let arg_fields = vec![ + Field::new("a", Utf8View, true).into(), + Field::new("a", Utf8View, true).into(), + ]; + let args = ScalarFunctionArgs { + args: vec![c0, c1], + arg_fields, + number_rows: 2, + return_field: Field::new("f", Utf8View, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + + let result = ConcatWsFunc::new().invoke_with_args(args)?; + match result { + ColumnarValue::Scalar(ScalarValue::Utf8View(None)) => {} + other => panic!("Expected Utf8View(None), got {other:?}"), + } + + Ok(()) + } + + #[test] + fn concat_ws_largeutf8_null_separator() -> Result<()> { + // All-scalar path: null LargeUtf8 separator should return LargeUtf8(None) + let c0 = ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)); + let c1 = ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some("aa".to_string()))); + let c2 = ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some("bb".to_string()))); + + let arg_fields = vec![ + Field::new("a", LargeUtf8, true).into(), + Field::new("a", LargeUtf8, true).into(), + Field::new("a", LargeUtf8, true).into(), + ]; + let args = ScalarFunctionArgs { + args: vec![c0, c1, c2], + arg_fields, + number_rows: 1, + return_field: Field::new("f", LargeUtf8, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + + let result = ConcatWsFunc::new().invoke_with_args(args)?; + match result { + ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)) => {} + other => panic!("Expected LargeUtf8(None), got {other:?}"), + } + + // Array path: null LargeUtf8 scalar separator with array args + let c0 = ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)); + let c1 = + ColumnarValue::Array(Arc::new(LargeStringArray::from(vec!["foo", "bar"]))); + + let arg_fields = vec![ + Field::new("a", LargeUtf8, true).into(), + Field::new("a", LargeUtf8, true).into(), + ]; + let args = ScalarFunctionArgs { + args: vec![c0, c1], + arg_fields, + number_rows: 2, + return_field: Field::new("f", LargeUtf8, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + + let result = ConcatWsFunc::new().invoke_with_args(args)?; + match result { + ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)) => {} + other => panic!("Expected LargeUtf8(None), got {other:?}"), + } + + Ok(()) + } } diff --git a/datafusion/functions/src/strings.rs b/datafusion/functions/src/strings.rs index cfddf57b094b5..8e25a45cf62dd 100644 --- a/datafusion/functions/src/strings.rs +++ b/datafusion/functions/src/strings.rs @@ -182,8 +182,16 @@ impl StringViewArrayBuilder { self.block.clear(); } - pub fn finish(mut self) -> StringViewArray { - self.builder.finish() + pub fn finish(mut self, null_buffer: Option) -> StringViewArray { + let array = self.builder.finish(); + match null_buffer { + Some(nulls) => { + let array_data = array.into_data().into_builder().nulls(Some(nulls)); + // SAFETY: the underlying data is valid; we are only adding a null buffer + StringViewArray::from(unsafe { array_data.build_unchecked() }) + } + None => array, + } } } diff --git a/datafusion/sqllogictest/test_files/string/string_view.slt b/datafusion/sqllogictest/test_files/string/string_view.slt index 13b0aba653efb..dea4681a8401a 100644 --- a/datafusion/sqllogictest/test_files/string/string_view.slt +++ b/datafusion/sqllogictest/test_files/string/string_view.slt @@ -642,7 +642,17 @@ EXPLAIN SELECT FROM test; ---- logical_plan -01)Projection: concat_ws(Utf8(", "), test.column1_utf8view, test.column2_utf8view) AS c +01)Projection: concat_ws(Utf8View(", "), test.column1_utf8view, test.column2_utf8view) AS c +02)--TableScan: test projection=[column1_utf8view, column2_utf8view] + +## Ensure CONCAT_WS simplification preserves Utf8View for merged literals +query TT +EXPLAIN SELECT + concat_ws(', ', column1_utf8view, 'foo', 'bar', column2_utf8view) as c +FROM test; +---- +logical_plan +01)Projection: concat_ws(Utf8View(", "), test.column1_utf8view, Utf8View("foo, bar"), test.column2_utf8view) AS c 02)--TableScan: test projection=[column1_utf8view, column2_utf8view] ## Ensure no casts for CONTAINS