Skip to content

Commit 90de12a

Browse files
authored
Add stddev operator (#1525)
* Initial implementation of variance * get simple f64 type tests working * add math functions to ScalarValue, some tests * add to expressions and tests * add more tests * add test for ScalarValue add * add tests for scalar arithmetic * add test, finish variance * fix warnings * add more sql tests * add stddev and tests * add the hooks and expression * add more tests * fix lint and clippy * address comments and fix test errors * address comments * add population and sample for variance and stddev * address more comments * fmt * add test for less than 2 values * fix inconsistency in the merge logic * fix lint and clippy
1 parent d6d90e9 commit 90de12a

File tree

12 files changed

+1987
-5
lines changed

12 files changed

+1987
-5
lines changed

ballista/rust/core/proto/ballista.proto

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,10 @@ enum AggregateFunction {
169169
COUNT = 4;
170170
APPROX_DISTINCT = 5;
171171
ARRAY_AGG = 6;
172+
VARIANCE=7;
173+
VARIANCE_POP=8;
174+
STDDEV=9;
175+
STDDEV_POP=10;
172176
}
173177

174178
message AggregateExprNode {

ballista/rust/core/src/serde/logical_plan/to_proto.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1026,6 +1026,14 @@ impl TryInto<protobuf::LogicalExprNode> for &Expr {
10261026
AggregateFunction::Sum => protobuf::AggregateFunction::Sum,
10271027
AggregateFunction::Avg => protobuf::AggregateFunction::Avg,
10281028
AggregateFunction::Count => protobuf::AggregateFunction::Count,
1029+
AggregateFunction::Variance => protobuf::AggregateFunction::Variance,
1030+
AggregateFunction::VariancePop => {
1031+
protobuf::AggregateFunction::VariancePop
1032+
}
1033+
AggregateFunction::Stddev => protobuf::AggregateFunction::Stddev,
1034+
AggregateFunction::StddevPop => {
1035+
protobuf::AggregateFunction::StddevPop
1036+
}
10291037
};
10301038

10311039
let arg = &args[0];
@@ -1256,6 +1264,10 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction {
12561264
AggregateFunction::Count => Self::Count,
12571265
AggregateFunction::ApproxDistinct => Self::ApproxDistinct,
12581266
AggregateFunction::ArrayAgg => Self::ArrayAgg,
1267+
AggregateFunction::Variance => Self::Variance,
1268+
AggregateFunction::VariancePop => Self::VariancePop,
1269+
AggregateFunction::Stddev => Self::Stddev,
1270+
AggregateFunction::StddevPop => Self::StddevPop,
12591271
}
12601272
}
12611273
}

ballista/rust/core/src/serde/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,10 @@ impl From<protobuf::AggregateFunction> for AggregateFunction {
119119
AggregateFunction::ApproxDistinct
120120
}
121121
protobuf::AggregateFunction::ArrayAgg => AggregateFunction::ArrayAgg,
122+
protobuf::AggregateFunction::Variance => AggregateFunction::Variance,
123+
protobuf::AggregateFunction::VariancePop => AggregateFunction::VariancePop,
124+
protobuf::AggregateFunction::Stddev => AggregateFunction::Stddev,
125+
protobuf::AggregateFunction::StddevPop => AggregateFunction::StddevPop,
122126
}
123127
}
124128
}

datafusion/src/optimizer/simplify_expressions.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,7 @@ impl ConstEvaluator {
359359
}
360360

361361
/// Internal helper to evaluate an Expr
362-
fn evaluate_to_scalar(&self, expr: Expr) -> Result<ScalarValue> {
362+
pub(crate) fn evaluate_to_scalar(&self, expr: Expr) -> Result<ScalarValue> {
363363
if let Expr::Literal(s) = expr {
364364
return Ok(s);
365365
}

datafusion/src/physical_plan/aggregates.rs

Lines changed: 274 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@ use crate::physical_plan::coercion_rule::aggregate_rule::{coerce_exprs, coerce_t
3535
use crate::physical_plan::distinct_expressions;
3636
use crate::physical_plan::expressions;
3737
use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
38-
use expressions::{avg_return_type, sum_return_type};
38+
use expressions::{
39+
avg_return_type, stddev_return_type, sum_return_type, variance_return_type,
40+
};
3941
use std::{fmt, str::FromStr, sync::Arc};
4042

4143
/// the implementation of an aggregate function
@@ -64,6 +66,14 @@ pub enum AggregateFunction {
6466
ApproxDistinct,
6567
/// array_agg
6668
ArrayAgg,
69+
/// Variance (Sample)
70+
Variance,
71+
/// Variance (Population)
72+
VariancePop,
73+
/// Standard Deviation (Sample)
74+
Stddev,
75+
/// Standard Deviation (Population)
76+
StddevPop,
6777
}
6878

6979
impl fmt::Display for AggregateFunction {
@@ -84,6 +94,12 @@ impl FromStr for AggregateFunction {
8494
"sum" => AggregateFunction::Sum,
8595
"approx_distinct" => AggregateFunction::ApproxDistinct,
8696
"array_agg" => AggregateFunction::ArrayAgg,
97+
"var" => AggregateFunction::Variance,
98+
"var_samp" => AggregateFunction::Variance,
99+
"var_pop" => AggregateFunction::VariancePop,
100+
"stddev" => AggregateFunction::Stddev,
101+
"stddev_samp" => AggregateFunction::Stddev,
102+
"stddev_pop" => AggregateFunction::StddevPop,
87103
_ => {
88104
return Err(DataFusionError::Plan(format!(
89105
"There is no built-in function named {}",
@@ -116,6 +132,10 @@ pub fn return_type(
116132
Ok(coerced_data_types[0].clone())
117133
}
118134
AggregateFunction::Sum => sum_return_type(&coerced_data_types[0]),
135+
AggregateFunction::Variance => variance_return_type(&coerced_data_types[0]),
136+
AggregateFunction::VariancePop => variance_return_type(&coerced_data_types[0]),
137+
AggregateFunction::Stddev => stddev_return_type(&coerced_data_types[0]),
138+
AggregateFunction::StddevPop => stddev_return_type(&coerced_data_types[0]),
119139
AggregateFunction::Avg => avg_return_type(&coerced_data_types[0]),
120140
AggregateFunction::ArrayAgg => Ok(DataType::List(Box::new(Field::new(
121141
"item",
@@ -212,6 +232,48 @@ pub fn create_aggregate_expr(
212232
"AVG(DISTINCT) aggregations are not available".to_string(),
213233
));
214234
}
235+
(AggregateFunction::Variance, false) => Arc::new(expressions::Variance::new(
236+
coerced_phy_exprs[0].clone(),
237+
name,
238+
return_type,
239+
)),
240+
(AggregateFunction::Variance, true) => {
241+
return Err(DataFusionError::NotImplemented(
242+
"VAR(DISTINCT) aggregations are not available".to_string(),
243+
));
244+
}
245+
(AggregateFunction::VariancePop, false) => {
246+
Arc::new(expressions::VariancePop::new(
247+
coerced_phy_exprs[0].clone(),
248+
name,
249+
return_type,
250+
))
251+
}
252+
(AggregateFunction::VariancePop, true) => {
253+
return Err(DataFusionError::NotImplemented(
254+
"VAR_POP(DISTINCT) aggregations are not available".to_string(),
255+
));
256+
}
257+
(AggregateFunction::Stddev, false) => Arc::new(expressions::Stddev::new(
258+
coerced_phy_exprs[0].clone(),
259+
name,
260+
return_type,
261+
)),
262+
(AggregateFunction::Stddev, true) => {
263+
return Err(DataFusionError::NotImplemented(
264+
"STDDEV(DISTINCT) aggregations are not available".to_string(),
265+
));
266+
}
267+
(AggregateFunction::StddevPop, false) => Arc::new(expressions::StddevPop::new(
268+
coerced_phy_exprs[0].clone(),
269+
name,
270+
return_type,
271+
)),
272+
(AggregateFunction::StddevPop, true) => {
273+
return Err(DataFusionError::NotImplemented(
274+
"STDDEV_POP(DISTINCT) aggregations are not available".to_string(),
275+
));
276+
}
215277
})
216278
}
217279

@@ -256,7 +318,12 @@ pub fn signature(fun: &AggregateFunction) -> Signature {
256318
.collect::<Vec<_>>();
257319
Signature::uniform(1, valid, Volatility::Immutable)
258320
}
259-
AggregateFunction::Avg | AggregateFunction::Sum => {
321+
AggregateFunction::Avg
322+
| AggregateFunction::Sum
323+
| AggregateFunction::Variance
324+
| AggregateFunction::VariancePop
325+
| AggregateFunction::Stddev
326+
| AggregateFunction::StddevPop => {
260327
Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable)
261328
}
262329
}
@@ -267,7 +334,7 @@ mod tests {
267334
use super::*;
268335
use crate::error::Result;
269336
use crate::physical_plan::expressions::{
270-
ApproxDistinct, ArrayAgg, Avg, Count, Max, Min, Sum,
337+
ApproxDistinct, ArrayAgg, Avg, Count, Max, Min, Stddev, Sum, Variance,
271338
};
272339

273340
#[test]
@@ -450,6 +517,158 @@ mod tests {
450517
Ok(())
451518
}
452519

/// Builds a `Variance` aggregate over a column `c1` of each listed numeric
/// type and checks the resulting physical expression.
///
/// For every input type the expression must downcast to `Variance`, report the
/// name `c1`, and expose a field equal to
/// `Field::new("c1", DataType::Float64, true)`.
#[test]
fn test_variance_expr() -> Result<()> {
    let fun = AggregateFunction::Variance;
    let numeric_types = vec![
        DataType::UInt32,
        DataType::UInt64,
        DataType::Int32,
        DataType::Int64,
        DataType::Float32,
        DataType::Float64,
    ];

    for numeric_type in &numeric_types {
        let schema = Schema::new(vec![Field::new("c1", numeric_type.clone(), true)]);
        let column: Arc<dyn PhysicalExpr> =
            Arc::new(expressions::Column::new_with_schema("c1", &schema).unwrap());
        let args = vec![column];

        let agg_expr = create_aggregate_expr(&fun, false, &args[0..1], &schema, "c1")?;

        assert!(agg_expr.as_any().is::<Variance>());
        assert_eq!("c1", agg_expr.name());
        assert_eq!(
            Field::new("c1", DataType::Float64, true),
            agg_expr.field().unwrap()
        );
    }

    Ok(())
}
557+
/// Builds a `VariancePop` aggregate over a column `c1` of each listed numeric
/// type and checks the resulting physical expression.
///
/// For every input type the expression must downcast to
/// `expressions::VariancePop`, report the name `c1`, and expose a field equal
/// to `Field::new("c1", DataType::Float64, true)`.
#[test]
fn test_var_pop_expr() -> Result<()> {
    let fun = AggregateFunction::VariancePop;
    let data_types = vec![
        DataType::UInt32,
        DataType::UInt64,
        DataType::Int32,
        DataType::Int64,
        DataType::Float32,
        DataType::Float64,
    ];

    for data_type in &data_types {
        let input_schema = Schema::new(vec![Field::new("c1", data_type.clone(), true)]);
        let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> = vec![Arc::new(
            expressions::Column::new_with_schema("c1", &input_schema).unwrap(),
        )];

        let result_agg_phy_exprs = create_aggregate_expr(
            &fun,
            false,
            &input_phy_exprs[0..1],
            &input_schema,
            "c1",
        )?;

        // The assertions are unconditional: guarding them on
        // `fun == AggregateFunction::Variance` would skip them entirely,
        // because the only function under test is `VariancePop`.
        // The type is spelled through `expressions::` so no extra import
        // is needed in the test module.
        assert!(result_agg_phy_exprs
            .as_any()
            .is::<expressions::VariancePop>());
        assert_eq!("c1", result_agg_phy_exprs.name());
        assert_eq!(
            Field::new("c1", DataType::Float64, true),
            result_agg_phy_exprs.field().unwrap()
        );
    }

    Ok(())
}
595+
/// Builds a `Stddev` aggregate over a column `c1` of each listed numeric type
/// and checks the resulting physical expression.
///
/// For every input type the expression must downcast to `Stddev`, report the
/// name `c1`, and expose a field equal to
/// `Field::new("c1", DataType::Float64, true)`.
#[test]
fn test_stddev_expr() -> Result<()> {
    let fun = AggregateFunction::Stddev;
    let data_types = vec![
        DataType::UInt32,
        DataType::UInt64,
        DataType::Int32,
        DataType::Int64,
        DataType::Float32,
        DataType::Float64,
    ];

    for data_type in &data_types {
        let input_schema = Schema::new(vec![Field::new("c1", data_type.clone(), true)]);
        let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> = vec![Arc::new(
            expressions::Column::new_with_schema("c1", &input_schema).unwrap(),
        )];

        let result_agg_phy_exprs = create_aggregate_expr(
            &fun,
            false,
            &input_phy_exprs[0..1],
            &input_schema,
            "c1",
        )?;

        // The assertions are unconditional: guarding them on
        // `fun == AggregateFunction::Variance` would skip them entirely,
        // because the only function under test is `Stddev`.
        assert!(result_agg_phy_exprs.as_any().is::<Stddev>());
        assert_eq!("c1", result_agg_phy_exprs.name());
        assert_eq!(
            Field::new("c1", DataType::Float64, true),
            result_agg_phy_exprs.field().unwrap()
        );
    }

    Ok(())
}
633+
/// Builds a `StddevPop` aggregate over a column `c1` of each listed numeric
/// type and checks the resulting physical expression.
///
/// For every input type the expression must downcast to
/// `expressions::StddevPop`, report the name `c1`, and expose a field equal to
/// `Field::new("c1", DataType::Float64, true)`.
#[test]
fn test_stddev_pop_expr() -> Result<()> {
    let fun = AggregateFunction::StddevPop;
    let data_types = vec![
        DataType::UInt32,
        DataType::UInt64,
        DataType::Int32,
        DataType::Int64,
        DataType::Float32,
        DataType::Float64,
    ];

    for data_type in &data_types {
        let input_schema = Schema::new(vec![Field::new("c1", data_type.clone(), true)]);
        let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> = vec![Arc::new(
            expressions::Column::new_with_schema("c1", &input_schema).unwrap(),
        )];

        let result_agg_phy_exprs = create_aggregate_expr(
            &fun,
            false,
            &input_phy_exprs[0..1],
            &input_schema,
            "c1",
        )?;

        // The assertions are unconditional: guarding them on
        // `fun == AggregateFunction::Variance` would skip them entirely,
        // because the only function under test is `StddevPop`.
        // The expected concrete type is `StddevPop` (not `Stddev`); it is
        // spelled through `expressions::` so no extra import is needed in
        // the test module.
        assert!(result_agg_phy_exprs
            .as_any()
            .is::<expressions::StddevPop>());
        assert_eq!("c1", result_agg_phy_exprs.name());
        assert_eq!(
            Field::new("c1", DataType::Float64, true),
            result_agg_phy_exprs.field().unwrap()
        );
    }

    Ok(())
}
671+
453672
#[test]
454673
fn test_min_max() -> Result<()> {
455674
let observed = return_type(&AggregateFunction::Min, &[DataType::Utf8])?;
@@ -544,4 +763,56 @@ mod tests {
544763
let observed = return_type(&AggregateFunction::Avg, &[DataType::Utf8]);
545764
assert!(observed.is_err());
546765
}
766+
/// `return_type` for `AggregateFunction::Variance` must yield
/// `DataType::Float64` for each of the numeric input types listed below.
#[test]
fn test_variance_return_type() -> Result<()> {
    let input_types = vec![
        DataType::Float32,
        DataType::Float64,
        DataType::Int32,
        DataType::UInt32,
        DataType::Int64,
    ];

    for input_type in input_types {
        let observed = return_type(&AggregateFunction::Variance, &[input_type])?;
        assert_eq!(DataType::Float64, observed);
    }

    Ok(())
}
786+
/// `return_type` must return an error when `AggregateFunction::Variance` is
/// given a `Utf8` input.
#[test]
fn test_variance_no_utf8() {
    let utf8_input = [DataType::Utf8];
    assert!(return_type(&AggregateFunction::Variance, &utf8_input).is_err());
}
792+
/// `return_type` for `AggregateFunction::Stddev` must yield
/// `DataType::Float64` for each of the numeric input types listed below.
#[test]
fn test_stddev_return_type() -> Result<()> {
    let input_types = vec![
        DataType::Float32,
        DataType::Float64,
        DataType::Int32,
        DataType::UInt32,
        DataType::Int64,
    ];

    for input_type in input_types {
        let observed = return_type(&AggregateFunction::Stddev, &[input_type])?;
        assert_eq!(DataType::Float64, observed);
    }

    Ok(())
}
812+
/// `return_type` must return an error when `AggregateFunction::Stddev` is
/// given a `Utf8` input.
#[test]
fn test_stddev_no_utf8() {
    let utf8_input = [DataType::Utf8];
    assert!(return_type(&AggregateFunction::Stddev, &utf8_input).is_err());
}
547818
}

0 commit comments

Comments
 (0)