Skip to content

Commit 6b4bbd0

Browse files
authored
sum(distinct) support (#2405)
* sum(distinct) support * fix clippy * merge state() code logic * revise annotation * remove u64->i63 coercion
1 parent 21a2973 commit 6b4bbd0

File tree

6 files changed

+367
-5
lines changed

6 files changed

+367
-5
lines changed

datafusion/core/tests/sql/aggregates.rs

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1235,6 +1235,63 @@ async fn simple_avg() -> Result<()> {
12351235
Ok(())
12361236
}
12371237

1238+
#[tokio::test]
1239+
async fn query_sum_distinct() -> Result<()> {
1240+
let schema = Arc::new(Schema::new(vec![
1241+
Field::new("c1", DataType::Int64, true),
1242+
Field::new("c2", DataType::Int64, true),
1243+
]));
1244+
1245+
let data = RecordBatch::try_new(
1246+
schema.clone(),
1247+
vec![
1248+
Arc::new(Int64Array::from(vec![
1249+
Some(0),
1250+
Some(1),
1251+
None,
1252+
Some(3),
1253+
Some(3),
1254+
])),
1255+
Arc::new(Int64Array::from(vec![
1256+
None,
1257+
Some(1),
1258+
Some(1),
1259+
Some(2),
1260+
Some(2),
1261+
])),
1262+
],
1263+
)?;
1264+
1265+
let table = MemTable::try_new(schema, vec![vec![data]])?;
1266+
let ctx = SessionContext::new();
1267+
ctx.register_table("test", Arc::new(table))?;
1268+
1269+
// 2 different aggregate functions: avg and sum(distinct)
1270+
let sql = "SELECT AVG(c1), SUM(DISTINCT c2) FROM test";
1271+
let actual = execute_to_batches(&ctx, sql).await;
1272+
let expected = vec![
1273+
"+--------------+-----------------------+",
1274+
"| AVG(test.c1) | SUM(DISTINCT test.c2) |",
1275+
"+--------------+-----------------------+",
1276+
"| 1.75 | 3 |",
1277+
"+--------------+-----------------------+",
1278+
];
1279+
assert_batches_eq!(expected, &actual);
1280+
1281+
// 2 sum(distinct) functions
1282+
let sql = "SELECT SUM(DISTINCT c1), SUM(DISTINCT c2) FROM test";
1283+
let actual = execute_to_batches(&ctx, sql).await;
1284+
let expected = vec![
1285+
"+-----------------------+-----------------------+",
1286+
"| SUM(DISTINCT test.c1) | SUM(DISTINCT test.c2) |",
1287+
"+-----------------------+-----------------------+",
1288+
"| 4 | 3 |",
1289+
"+-----------------------+-----------------------+",
1290+
];
1291+
assert_batches_eq!(expected, &actual);
1292+
Ok(())
1293+
}
1294+
12381295
#[tokio::test]
12391296
async fn query_count_distinct() -> Result<()> {
12401297
let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, true)]));

datafusion/physical-expr/src/aggregate/build_in.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -87,11 +87,11 @@ pub fn create_aggregate_expr(
8787
name,
8888
return_type,
8989
)),
90-
(AggregateFunction::Sum, true) => {
91-
return Err(DataFusionError::NotImplemented(
92-
"SUM(DISTINCT) aggregations are not available".to_string(),
93-
));
94-
}
90+
(AggregateFunction::Sum, true) => Arc::new(expressions::DistinctSum::new(
91+
vec![coerced_phy_exprs[0].clone()],
92+
name,
93+
return_type,
94+
)),
9595
(AggregateFunction::ApproxDistinct, _) => {
9696
Arc::new(expressions::ApproxDistinct::new(
9797
coerced_phy_exprs[0].clone(),

datafusion/physical-expr/src/aggregate/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ mod hyperloglog;
4242
pub(crate) mod stats;
4343
pub(crate) mod stddev;
4444
pub(crate) mod sum;
45+
pub(crate) mod sum_distinct;
4546
mod tdigest;
4647
pub(crate) mod variance;
4748

datafusion/physical-expr/src/aggregate/sum.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,15 @@ pub(crate) fn sum(lhs: &ScalarValue, rhs: &ScalarValue) -> Result<ScalarValue> {
297297
(ScalarValue::Int64(lhs), ScalarValue::Int8(rhs)) => {
298298
typed_sum!(lhs, rhs, Int64, i64)
299299
}
300+
(ScalarValue::Int64(lhs), ScalarValue::UInt32(rhs)) => {
301+
typed_sum!(lhs, rhs, Int64, i64)
302+
}
303+
(ScalarValue::Int64(lhs), ScalarValue::UInt16(rhs)) => {
304+
typed_sum!(lhs, rhs, Int64, i64)
305+
}
306+
(ScalarValue::Int64(lhs), ScalarValue::UInt8(rhs)) => {
307+
typed_sum!(lhs, rhs, Int64, i64)
308+
}
300309
e => {
301310
return Err(DataFusionError::Internal(format!(
302311
"Sum is not expected to receive a scalar {:?}",
Lines changed: 294 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,294 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use crate::aggregate::sum;
19+
use crate::expressions::format_state_name;
20+
use arrow::datatypes::{DataType, Field};
21+
use std::any::Any;
22+
use std::fmt::Debug;
23+
use std::sync::Arc;
24+
25+
use ahash::RandomState;
26+
use arrow::array::{Array, ArrayRef};
27+
use std::collections::HashSet;
28+
29+
use crate::{AggregateExpr, PhysicalExpr};
30+
use datafusion_common::ScalarValue;
31+
use datafusion_common::{DataFusionError, Result};
32+
use datafusion_expr::Accumulator;
33+
34+
/// Expression for a SUM(DISTINCT) aggregation.
35+
#[derive(Debug)]
36+
pub struct DistinctSum {
37+
/// Column name
38+
name: String,
39+
/// The DataType for the final sum
40+
data_type: DataType,
41+
/// The input arguments, only contains 1 item for sum
42+
exprs: Vec<Arc<dyn PhysicalExpr>>,
43+
}
44+
45+
impl DistinctSum {
46+
/// Create a SUM(DISTINCT) aggregate function.
47+
pub fn new(
48+
exprs: Vec<Arc<dyn PhysicalExpr>>,
49+
name: String,
50+
data_type: DataType,
51+
) -> Self {
52+
Self {
53+
name,
54+
data_type,
55+
exprs,
56+
}
57+
}
58+
}
59+
60+
impl AggregateExpr for DistinctSum {
61+
fn as_any(&self) -> &dyn Any {
62+
self
63+
}
64+
65+
fn field(&self) -> Result<Field> {
66+
Ok(Field::new(&self.name, self.data_type.clone(), true))
67+
}
68+
69+
fn state_fields(&self) -> Result<Vec<Field>> {
70+
// State field is a List which stores items to rebuild hash set.
71+
Ok(vec![Field::new(
72+
&format_state_name(&self.name, "sum distinct"),
73+
DataType::List(Box::new(Field::new("item", self.data_type.clone(), true))),
74+
false,
75+
)])
76+
}
77+
78+
fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
79+
self.exprs.clone()
80+
}
81+
82+
fn name(&self) -> &str {
83+
&self.name
84+
}
85+
86+
fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
87+
Ok(Box::new(DistinctSumAccumulator::try_new(&self.data_type)?))
88+
}
89+
}
90+
91+
#[derive(Debug)]
92+
struct DistinctSumAccumulator {
93+
hash_values: HashSet<ScalarValue, RandomState>,
94+
data_type: DataType,
95+
}
96+
impl DistinctSumAccumulator {
97+
pub fn try_new(data_type: &DataType) -> Result<Self> {
98+
Ok(Self {
99+
hash_values: HashSet::default(),
100+
data_type: data_type.clone(),
101+
})
102+
}
103+
104+
fn update(&mut self, values: &[ScalarValue]) -> Result<()> {
105+
values.iter().for_each(|v| {
106+
// If the value is NULL, it is not included in the final sum.
107+
if !v.is_null() {
108+
self.hash_values.insert(v.clone());
109+
}
110+
});
111+
112+
Ok(())
113+
}
114+
115+
fn merge(&mut self, states: &[ScalarValue]) -> Result<()> {
116+
if states.is_empty() {
117+
return Ok(());
118+
}
119+
120+
states.iter().try_for_each(|state| match state {
121+
ScalarValue::List(Some(values), _) => self.update(values.as_ref()),
122+
_ => Err(DataFusionError::Internal(format!(
123+
"Unexpected accumulator state {:?}",
124+
state
125+
))),
126+
})
127+
}
128+
}
129+
130+
impl Accumulator for DistinctSumAccumulator {
131+
fn state(&self) -> Result<Vec<ScalarValue>> {
132+
// 1. Stores aggregate state in `ScalarValue::List`
133+
// 2. Constructs `ScalarValue::List` state from distinct numeric stored in hash set
134+
let state_out = {
135+
let mut distinct_values = Box::new(Vec::new());
136+
let data_type = Box::new(self.data_type.clone());
137+
self.hash_values
138+
.iter()
139+
.for_each(|distinct_value| distinct_values.push(distinct_value.clone()));
140+
vec![ScalarValue::List(Some(distinct_values), data_type)]
141+
};
142+
Ok(state_out)
143+
}
144+
145+
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
146+
if values.is_empty() {
147+
return Ok(());
148+
}
149+
150+
let scalar_values = (0..values[0].len())
151+
.map(|index| ScalarValue::try_from_array(&values[0], index))
152+
.collect::<Result<Vec<_>>>()?;
153+
self.update(&scalar_values)
154+
}
155+
156+
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
157+
if states.is_empty() {
158+
return Ok(());
159+
}
160+
161+
(0..states[0].len()).try_for_each(|index| {
162+
let v = states
163+
.iter()
164+
.map(|array| ScalarValue::try_from_array(array, index))
165+
.collect::<Result<Vec<_>>>()?;
166+
self.merge(&v)
167+
})
168+
}
169+
170+
fn evaluate(&self) -> Result<ScalarValue> {
171+
let mut sum_value = ScalarValue::try_from(&self.data_type)?;
172+
self.hash_values.iter().for_each(|distinct_value| {
173+
sum_value = sum::sum(&sum_value, distinct_value).unwrap()
174+
});
175+
Ok(sum_value)
176+
}
177+
}
178+
179+
#[cfg(test)]
180+
mod tests {
181+
use super::*;
182+
use crate::expressions::col;
183+
use crate::expressions::tests::aggregate;
184+
use arrow::record_batch::RecordBatch;
185+
use arrow::{array::*, datatypes::*};
186+
use datafusion_common::Result;
187+
188+
fn run_update_batch(
189+
return_type: DataType,
190+
arrays: &[ArrayRef],
191+
) -> Result<(Vec<ScalarValue>, ScalarValue)> {
192+
let agg = DistinctSum::new(vec![], String::from("__col_name__"), return_type);
193+
194+
let mut accum = agg.create_accumulator()?;
195+
accum.update_batch(arrays)?;
196+
197+
Ok((accum.state()?, accum.evaluate()?))
198+
}
199+
200+
macro_rules! generic_test_sum_distinct {
201+
($ARRAY:expr, $DATATYPE:expr, $EXPECTED:expr, $EXPECTED_DATATYPE:expr) => {{
202+
let schema = Schema::new(vec![Field::new("a", $DATATYPE, false)]);
203+
204+
let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![$ARRAY])?;
205+
206+
let agg = Arc::new(DistinctSum::new(
207+
vec![col("a", &schema)?],
208+
"count_distinct_a".to_string(),
209+
$EXPECTED_DATATYPE,
210+
));
211+
let actual = aggregate(&batch, agg)?;
212+
let expected = ScalarValue::from($EXPECTED);
213+
214+
assert_eq!(expected, actual);
215+
216+
Ok(())
217+
}};
218+
}
219+
220+
#[test]
221+
fn sum_distinct_update_batch() -> Result<()> {
222+
let array_int64: ArrayRef = Arc::new(Int64Array::from(vec![1, 1, 3]));
223+
let arrays = vec![array_int64];
224+
let (states, result) = run_update_batch(DataType::Int64, &arrays)?;
225+
226+
assert_eq!(states.len(), 1);
227+
assert_eq!(result, ScalarValue::Int64(Some(4)));
228+
229+
Ok(())
230+
}
231+
232+
#[test]
233+
fn sum_distinct_i32_with_nulls() -> Result<()> {
234+
let array = Arc::new(Int32Array::from(vec![
235+
Some(1),
236+
Some(1),
237+
None,
238+
Some(2),
239+
Some(2),
240+
Some(3),
241+
]));
242+
generic_test_sum_distinct!(
243+
array,
244+
DataType::Int32,
245+
ScalarValue::from(6i64),
246+
DataType::Int64
247+
)
248+
}
249+
250+
#[test]
251+
fn sum_distinct_u32_with_nulls() -> Result<()> {
252+
let array: ArrayRef = Arc::new(UInt32Array::from(vec![
253+
Some(1_u32),
254+
Some(1_u32),
255+
Some(3_u32),
256+
Some(3_u32),
257+
None,
258+
]));
259+
generic_test_sum_distinct!(
260+
array,
261+
DataType::UInt32,
262+
ScalarValue::from(4i64),
263+
DataType::Int64
264+
)
265+
}
266+
267+
#[test]
268+
fn sum_distinct_f64() -> Result<()> {
269+
let array: ArrayRef =
270+
Arc::new(Float64Array::from(vec![1_f64, 1_f64, 3_f64, 3_f64, 3_f64]));
271+
generic_test_sum_distinct!(
272+
array,
273+
DataType::Float64,
274+
ScalarValue::from(4_f64),
275+
DataType::Float64
276+
)
277+
}
278+
279+
#[test]
280+
fn sum_distinct_decimal_with_nulls() -> Result<()> {
281+
let array: ArrayRef = Arc::new(
282+
(1..6)
283+
.map(|i| if i == 2 { None } else { Some(i % 2) })
284+
.collect::<DecimalArray>()
285+
.with_precision_and_scale(35, 0)?,
286+
);
287+
generic_test_sum_distinct!(
288+
array,
289+
DataType::Decimal(35, 0),
290+
ScalarValue::Decimal128(Some(1), 38, 0),
291+
DataType::Decimal(38, 0)
292+
)
293+
}
294+
}

0 commit comments

Comments
 (0)