diff --git a/datafusion/functions-aggregate/src/string_agg.rs b/datafusion/functions-aggregate/src/string_agg.rs index 64314ef6df68..a7594b9ccb01 100644 --- a/datafusion/functions-aggregate/src/string_agg.rs +++ b/datafusion/functions-aggregate/src/string_agg.rs @@ -17,15 +17,17 @@ //! [`StringAgg`] accumulator for the `string_agg` function +use crate::array_agg::ArrayAgg; use arrow::array::ArrayRef; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use datafusion_common::cast::as_generic_string_array; use datafusion_common::Result; -use datafusion_common::{not_impl_err, ScalarValue}; +use datafusion_common::{internal_err, not_impl_err, ScalarValue}; use datafusion_expr::function::AccumulatorArgs; use datafusion_expr::{ Accumulator, AggregateUDFImpl, Documentation, Signature, TypeSignature, Volatility, }; +use datafusion_functions_aggregate_common::accumulator::StateFieldsArgs; use datafusion_macros::user_doc; use datafusion_physical_expr::expressions::Literal; use std::any::Any; @@ -41,15 +43,31 @@ make_udaf_expr_and_func!( #[user_doc( doc_section(label = "General Functions"), - description = "Concatenates the values of string expressions and places separator values between them.", - syntax_example = "string_agg(expression, delimiter)", + description = "Concatenates the values of string expressions and places separator values between them. \ +If ordering is required, strings are concatenated in the specified order. \ +This aggregation function can only mix DISTINCT and ORDER BY if the ordering expression is exactly the same as the first argument expression.", + syntax_example = "string_agg([DISTINCT] expression, delimiter [ORDER BY expression])", sql_example = r#"```sql > SELECT string_agg(name, ', ') AS names_list FROM employee; +--------------------------+ | names_list | +--------------------------+ -| Alice, Bob, Charlie | +| Alice, Bob, Bob, Charlie | ++--------------------------+ +> SELECT string_agg(name, ', ' ORDER BY name DESC) AS names_list + FROM employee; ++--------------------------+ +| names_list | ++--------------------------+ +| Charlie, Bob, Bob, Alice | ++--------------------------+ +> SELECT string_agg(DISTINCT name, ', ' ORDER BY name DESC) AS names_list + FROM employee; ++--------------------------+ +| names_list | ++--------------------------+ +| Charlie, Bob, Alice | +--------------------------+ ```"#, argument( @@ -65,6 +83,7 @@ make_udaf_expr_and_func!( #[derive(Debug)] pub struct StringAgg { signature: Signature, + array_agg: ArrayAgg, } impl StringAgg { @@ -76,9 +95,13 @@ impl StringAgg { TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8]), TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::LargeUtf8]), TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Null]), + TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8]), + TypeSignature::Exact(vec![DataType::Utf8, DataType::LargeUtf8]), + TypeSignature::Exact(vec![DataType::Utf8, DataType::Null]), ], Volatility::Immutable, ), + array_agg: Default::default(), } } } @@ -106,20 +129,40 @@ impl AggregateUDFImpl for StringAgg { Ok(DataType::LargeUtf8) } + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + self.array_agg.state_fields(args) + } + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { - if let Some(lit) = acc_args.exprs[1].as_any().downcast_ref::() { - return match lit.value().try_as_str() { - Some(Some(delimiter)) => { - Ok(Box::new(StringAggAccumulator::new(delimiter))) - } - Some(None) => Ok(Box::new(StringAggAccumulator::new(""))), - None => { - not_impl_err!("StringAgg not supported for delimiter {}", lit.value()) - } - }; - } + let Some(lit) = acc_args.exprs[1].as_any().downcast_ref::() else { + return not_impl_err!( + "The second argument of the string_agg function must be a string literal" + ); + }; + + let delimiter = if lit.value().is_null() { + // If the second argument (the delimiter that joins strings) is NULL, join + // on an empty string. (e.g. [a, b, c] => "abc"). + "" + } else if let Some(lit_string) = lit.value().try_as_str() { + lit_string.unwrap_or("") + } else { + return not_impl_err!( + "StringAgg not supported for delimiter \"{}\"", + lit.value() + ); + }; + + let array_agg_acc = self.array_agg.accumulator(AccumulatorArgs { + return_type: &DataType::new_list(acc_args.return_type.clone(), true), + exprs: &filter_index(acc_args.exprs, 1), + ..acc_args + })?; - not_impl_err!("expect literal") + Ok(Box::new(StringAggAccumulator::new( + array_agg_acc, + delimiter, + ))) } fn documentation(&self) -> Option<&Documentation> { @@ -129,14 +172,14 @@ impl AggregateUDFImpl for StringAgg { #[derive(Debug)] pub(crate) struct StringAggAccumulator { - values: Option, + array_agg_acc: Box, delimiter: String, } impl StringAggAccumulator { - pub fn new(delimiter: &str) -> Self { + pub fn new(array_agg_acc: Box, delimiter: &str) -> Self { Self { - values: None, + array_agg_acc, delimiter: delimiter.to_string(), } } @@ -144,37 +187,311 @@ impl StringAggAccumulator { impl Accumulator for StringAggAccumulator { fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let string_array: Vec<_> = as_generic_string_array::(&values[0])? - .iter() - .filter_map(|v| v.as_ref().map(ToString::to_string)) - .collect(); - if !string_array.is_empty() { - let s = string_array.join(self.delimiter.as_str()); - let v = self.values.get_or_insert("".to_string()); - if !v.is_empty() { - v.push_str(self.delimiter.as_str()); + self.array_agg_acc.update_batch(&filter_index(values, 1)) + } + + fn evaluate(&mut self) -> Result { + let scalar = self.array_agg_acc.evaluate()?; + + let ScalarValue::List(list) = scalar else { + return internal_err!("Expected a DataType::List while evaluating underlying ArrayAggAccumulator, but got {}", scalar.data_type()); + }; + + let string_arr: Vec<_> = match list.value_type() { + DataType::LargeUtf8 => as_generic_string_array::(list.values())? + .iter() + .flatten() + .collect(), + DataType::Utf8 => as_generic_string_array::(list.values())? + .iter() + .flatten() + .collect(), + _ => { + return internal_err!( + "Expected elements to of type Utf8 or LargeUtf8, but got {}", + list.value_type() + ) } - v.push_str(s.as_str()); + }; + + if string_arr.is_empty() { + return Ok(ScalarValue::LargeUtf8(None)); } - Ok(()) + + Ok(ScalarValue::LargeUtf8(Some( + string_arr.join(&self.delimiter), + ))) + } + + fn size(&self) -> usize { + size_of_val(self) - size_of_val(&self.array_agg_acc) + + self.array_agg_acc.size() + + self.delimiter.capacity() + } + + fn state(&mut self) -> Result> { + self.array_agg_acc.state() } fn merge_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - self.update_batch(values)?; + self.array_agg_acc.merge_batch(values) + } +} + +fn filter_index(values: &[T], index: usize) -> Vec { + values + .iter() + .enumerate() + .filter(|(i, _)| *i != index) + .map(|(_, v)| v) + .cloned() + .collect::>() +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::LargeStringArray; + use arrow::compute::SortOptions; + use arrow::datatypes::{Fields, Schema}; + use datafusion_common::internal_err; + use datafusion_physical_expr::expressions::Column; + use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; + use std::sync::Arc; + + #[test] + fn no_duplicates_no_distinct() -> Result<()> { + let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",").build_two()?; + + acc1.update_batch(&[data(["a", "b", "c"]), data([","])])?; + acc2.update_batch(&[data(["d", "e", "f"]), data([","])])?; + acc1 = merge(acc1, acc2)?; + + let result = some_str(acc1.evaluate()?); + + assert_eq!(result, "a,b,c,d,e,f"); + Ok(()) } - fn state(&mut self) -> Result> { - Ok(vec![self.evaluate()?]) + #[test] + fn no_duplicates_distinct() -> Result<()> { + let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",") + .distinct() + .build_two()?; + + acc1.update_batch(&[data(["a", "b", "c"]), data([","])])?; + acc2.update_batch(&[data(["d", "e", "f"]), data([","])])?; + acc1 = merge(acc1, acc2)?; + + let result = some_str_sorted(acc1.evaluate()?, ","); + + assert_eq!(result, "a,b,c,d,e,f"); + + Ok(()) } - fn evaluate(&mut self) -> Result { - Ok(ScalarValue::LargeUtf8(self.values.clone())) + #[test] + fn duplicates_no_distinct() -> Result<()> { + let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",").build_two()?; + + acc1.update_batch(&[data(["a", "b", "c"]), data([","])])?; + acc2.update_batch(&[data(["a", "b", "c"]), data([","])])?; + acc1 = merge(acc1, acc2)?; + + let result = some_str(acc1.evaluate()?); + + assert_eq!(result, "a,b,c,a,b,c"); + + Ok(()) } - fn size(&self) -> usize { - size_of_val(self) - + self.values.as_ref().map(|v| v.capacity()).unwrap_or(0) - + self.delimiter.capacity() + #[test] + fn duplicates_distinct() -> Result<()> { + let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",") + .distinct() + .build_two()?; + + acc1.update_batch(&[data(["a", "b", "c"]), data([","])])?; + acc2.update_batch(&[data(["a", "b", "c"]), data([","])])?; + acc1 = merge(acc1, acc2)?; + + let result = some_str_sorted(acc1.evaluate()?, ","); + + assert_eq!(result, "a,b,c"); + + Ok(()) + } + + #[test] + fn no_duplicates_distinct_sort_asc() -> Result<()> { + let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",") + .distinct() + .order_by_col("col", SortOptions::new(false, false)) + .build_two()?; + + acc1.update_batch(&[data(["e", "b", "d"]), data([","])])?; + acc2.update_batch(&[data(["f", "a", "c"]), data([","])])?; + acc1 = merge(acc1, acc2)?; + + let result = some_str(acc1.evaluate()?); + + assert_eq!(result, "a,b,c,d,e,f"); + + Ok(()) + } + + #[test] + fn no_duplicates_distinct_sort_desc() -> Result<()> { + let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",") + .distinct() + .order_by_col("col", SortOptions::new(true, false)) + .build_two()?; + + acc1.update_batch(&[data(["e", "b", "d"]), data([","])])?; + acc2.update_batch(&[data(["f", "a", "c"]), data([","])])?; + acc1 = merge(acc1, acc2)?; + + let result = some_str(acc1.evaluate()?); + + assert_eq!(result, "f,e,d,c,b,a"); + + Ok(()) + } + + #[test] + fn duplicates_distinct_sort_asc() -> Result<()> { + let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",") + .distinct() + .order_by_col("col", SortOptions::new(false, false)) + .build_two()?; + + acc1.update_batch(&[data(["a", "c", "b"]), data([","])])?; + acc2.update_batch(&[data(["b", "c", "a"]), data([","])])?; + acc1 = merge(acc1, acc2)?; + + let result = some_str(acc1.evaluate()?); + + assert_eq!(result, "a,b,c"); + + Ok(()) + } + + #[test] + fn duplicates_distinct_sort_desc() -> Result<()> { + let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",") + .distinct() + .order_by_col("col", SortOptions::new(true, false)) + .build_two()?; + + acc1.update_batch(&[data(["a", "c", "b"]), data([","])])?; + acc2.update_batch(&[data(["b", "c", "a"]), data([","])])?; + acc1 = merge(acc1, acc2)?; + + let result = some_str(acc1.evaluate()?); + + assert_eq!(result, "c,b,a"); + + Ok(()) + } + + struct StringAggAccumulatorBuilder { + sep: String, + distinct: bool, + ordering: LexOrdering, + schema: Schema, + } + + impl StringAggAccumulatorBuilder { + fn new(sep: &str) -> Self { + Self { + sep: sep.to_string(), + distinct: Default::default(), + ordering: Default::default(), + schema: Schema { + fields: Fields::from(vec![Field::new( + "col", + DataType::LargeUtf8, + true, + )]), + metadata: Default::default(), + }, + } + } + fn distinct(mut self) -> Self { + self.distinct = true; + self + } + + fn order_by_col(mut self, col: &str, sort_options: SortOptions) -> Self { + self.ordering.extend([PhysicalSortExpr::new( + Arc::new( + Column::new_with_schema(col, &self.schema) + .expect("column not available in schema"), + ), + sort_options, + )]); + self + } + + fn build(&self) -> Result> { + StringAgg::new().accumulator(AccumulatorArgs { + return_type: &DataType::LargeUtf8, + schema: &self.schema, + ignore_nulls: false, + ordering_req: &self.ordering, + is_reversed: false, + name: "", + is_distinct: self.distinct, + exprs: &[ + Arc::new(Column::new("col", 0)), + Arc::new(Literal::new(ScalarValue::Utf8(Some(self.sep.to_string())))), + ], + }) + } + + fn build_two(&self) -> Result<(Box, Box)> { + Ok((self.build()?, self.build()?)) + } + } + + fn some_str(value: ScalarValue) -> String { + str(value) + .expect("ScalarValue was not a String") + .expect("ScalarValue was None") + } + + fn some_str_sorted(value: ScalarValue, sep: &str) -> String { + let value = some_str(value); + let mut parts: Vec<&str> = value.split(sep).collect(); + parts.sort(); + parts.join(sep) + } + + fn str(value: ScalarValue) -> Result> { + match value { + ScalarValue::LargeUtf8(v) => Ok(v), + _ => internal_err!( + "Expected ScalarValue::LargeUtf8, got {}", + value.data_type() + ), + } + } + + fn data(list: [&str; N]) -> ArrayRef { + Arc::new(LargeStringArray::from(list.to_vec())) + } + + fn merge( + mut acc1: Box, + mut acc2: Box, + ) -> Result> { + let intermediate_state = acc2.state().and_then(|e| { + e.iter() + .map(|v| v.to_array()) + .collect::>>() + })?; + acc1.merge_batch(&intermediate_state)?; + Ok(acc1) } } diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 004846bc369e..a55ac079aa74 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -5620,6 +5620,11 @@ SELECT STRING_AGG(column1, '|') FROM (values (''), (null), ('')); ---- | +query T +SELECT STRING_AGG(DISTINCT column1, '|') FROM (values (''), (null), ('')); +---- +(empty) + statement ok CREATE TABLE strings(g INTEGER, x VARCHAR, y VARCHAR) @@ -5641,6 +5646,22 @@ SELECT STRING_AGG(x,',') FROM strings WHERE g > 100 ---- NULL +query T +SELECT STRING_AGG(DISTINCT x,',') FROM strings WHERE g > 100 +---- +NULL + +query T +SELECT STRING_AGG(DISTINCT x,'|' ORDER BY x) FROM strings +---- +a|b|i|j|p|x|y|z + +query error This feature is not implemented: The second argument of the string_agg function must be a string literal +SELECT STRING_AGG(DISTINCT x,y) FROM strings + +query error Execution error: In an aggregate with DISTINCT, ORDER BY expressions must appear in argument list +SELECT STRING_AGG(DISTINCT x,'|' ORDER BY y) FROM strings + statement ok drop table strings @@ -5655,6 +5676,17 @@ FROM my_data ---- text1, text1, text1 +query T +WITH my_data as ( +SELECT 'text1'::varchar(1000) as my_column union all +SELECT 'text1'::varchar(1000) as my_column union all +SELECT 'text1'::varchar(1000) as my_column +) +SELECT string_agg(DISTINCT my_column,', ') as my_string_agg +FROM my_data +---- +text1 + query T WITH my_data as ( SELECT 1 as dummy, 'text1'::varchar(1000) as my_column union all @@ -5667,6 +5699,18 @@ GROUP BY dummy ---- text1, text1, text1 +query T +WITH my_data as ( +SELECT 1 as dummy, 'text1'::varchar(1000) as my_column union all +SELECT 1 as dummy, 'text1'::varchar(1000) as my_column union all +SELECT 1 as dummy, 'text1'::varchar(1000) as my_column +) +SELECT string_agg(DISTINCT my_column,', ') as my_string_agg +FROM my_data +GROUP BY dummy +---- +text1 + # Tests for aggregating with NaN values statement ok CREATE TABLE float_table ( diff --git a/datafusion/sqllogictest/test_files/information_schema.slt b/datafusion/sqllogictest/test_files/information_schema.slt index 496f24abf6ed..32401529da54 100644 --- a/datafusion/sqllogictest/test_files/information_schema.slt +++ b/datafusion/sqllogictest/test_files/information_schema.slt @@ -649,7 +649,7 @@ datafusion public date_trunc datafusion public date_trunc FUNCTION true Timestam datafusion public date_trunc datafusion public date_trunc FUNCTION true Timestamp(Second, None) SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) datafusion public date_trunc datafusion public date_trunc FUNCTION true Timestamp(Second, Some("+TZ")) SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) datafusion public rank datafusion public rank FUNCTION true NULL WINDOW Returns the rank of the current row within its partition, allowing gaps between ranks. This function provides a ranking similar to `row_number`, but skips ranks for identical values. rank() -datafusion public string_agg datafusion public string_agg FUNCTION true LargeUtf8 AGGREGATE Concatenates the values of string expressions and places separator values between them. string_agg(expression, delimiter) +datafusion public string_agg datafusion public string_agg FUNCTION true LargeUtf8 AGGREGATE Concatenates the values of string expressions and places separator values between them. If ordering is required, strings are concatenated in the specified order. This aggregation function can only mix DISTINCT and ORDER BY if the ordering expression is exactly the same as the first argument expression. string_agg([DISTINCT] expression, delimiter [ORDER BY expression]) query B select is_deterministic from information_schema.routines where routine_name = 'now'; @@ -717,6 +717,15 @@ datafusion public string_agg 1 OUT NULL LargeUtf8 NULL false 1 datafusion public string_agg 1 IN expression LargeUtf8 NULL false 2 datafusion public string_agg 2 IN delimiter Null NULL false 2 datafusion public string_agg 1 OUT NULL LargeUtf8 NULL false 2 +datafusion public string_agg 1 IN expression Utf8 NULL false 3 +datafusion public string_agg 2 IN delimiter Utf8 NULL false 3 +datafusion public string_agg 1 OUT NULL LargeUtf8 NULL false 3 +datafusion public string_agg 1 IN expression Utf8 NULL false 4 +datafusion public string_agg 2 IN delimiter LargeUtf8 NULL false 4 +datafusion public string_agg 1 OUT NULL LargeUtf8 NULL false 4 +datafusion public string_agg 1 IN expression Utf8 NULL false 5 +datafusion public string_agg 2 IN delimiter Null NULL false 5 +datafusion public string_agg 1 OUT NULL LargeUtf8 NULL false 5 # test variable length arguments query TTTBI rowsort diff --git a/docs/source/user-guide/sql/aggregate_functions.md b/docs/source/user-guide/sql/aggregate_functions.md index c7f5c5f67442..684db52e6323 100644 --- a/docs/source/user-guide/sql/aggregate_functions.md +++ b/docs/source/user-guide/sql/aggregate_functions.md @@ -371,10 +371,10 @@ min(expression) ### `string_agg` -Concatenates the values of string expressions and places separator values between them. +Concatenates the values of string expressions and places separator values between them. If ordering is required, strings are concatenated in the specified order. This aggregation function can only mix DISTINCT and ORDER BY if the ordering expression is exactly the same as the first argument expression. ```sql -string_agg(expression, delimiter) +string_agg([DISTINCT] expression, delimiter [ORDER BY expression]) ``` #### Arguments @@ -390,7 +390,21 @@ string_agg(expression, delimiter) +--------------------------+ | names_list | +--------------------------+ -| Alice, Bob, Charlie | +| Alice, Bob, Bob, Charlie | ++--------------------------+ +> SELECT string_agg(name, ', ' ORDER BY name DESC) AS names_list + FROM employee; ++--------------------------+ +| names_list | ++--------------------------+ +| Charlie, Bob, Bob, Alice | ++--------------------------+ +> SELECT string_agg(DISTINCT name, ', ' ORDER BY name DESC) AS names_list + FROM employee; ++--------------------------+ +| names_list | ++--------------------------+ +| Charlie, Bob, Alice | +--------------------------+ ```