@@ -35,7 +35,9 @@ use crate::physical_plan::coercion_rule::aggregate_rule::{coerce_exprs, coerce_t
35
35
use crate :: physical_plan:: distinct_expressions;
36
36
use crate :: physical_plan:: expressions;
37
37
use arrow:: datatypes:: { DataType , Field , Schema , TimeUnit } ;
38
- use expressions:: { avg_return_type, sum_return_type} ;
38
+ use expressions:: {
39
+ avg_return_type, stddev_return_type, sum_return_type, variance_return_type,
40
+ } ;
39
41
use std:: { fmt, str:: FromStr , sync:: Arc } ;
40
42
41
43
/// the implementation of an aggregate function
@@ -64,6 +66,14 @@ pub enum AggregateFunction {
64
66
ApproxDistinct ,
65
67
/// array_agg
66
68
ArrayAgg ,
69
+ /// Variance (Sample)
70
+ Variance ,
71
+ /// Variance (Population)
72
+ VariancePop ,
73
+ /// Standard Deviation (Sample)
74
+ Stddev ,
75
+ /// Standard Deviation (Population)
76
+ StddevPop ,
67
77
}
68
78
69
79
impl fmt:: Display for AggregateFunction {
@@ -84,6 +94,12 @@ impl FromStr for AggregateFunction {
84
94
"sum" => AggregateFunction :: Sum ,
85
95
"approx_distinct" => AggregateFunction :: ApproxDistinct ,
86
96
"array_agg" => AggregateFunction :: ArrayAgg ,
97
+ "var" => AggregateFunction :: Variance ,
98
+ "var_samp" => AggregateFunction :: Variance ,
99
+ "var_pop" => AggregateFunction :: VariancePop ,
100
+ "stddev" => AggregateFunction :: Stddev ,
101
+ "stddev_samp" => AggregateFunction :: Stddev ,
102
+ "stddev_pop" => AggregateFunction :: StddevPop ,
87
103
_ => {
88
104
return Err ( DataFusionError :: Plan ( format ! (
89
105
"There is no built-in function named {}" ,
@@ -116,6 +132,10 @@ pub fn return_type(
116
132
Ok ( coerced_data_types[ 0 ] . clone ( ) )
117
133
}
118
134
AggregateFunction :: Sum => sum_return_type ( & coerced_data_types[ 0 ] ) ,
135
+ AggregateFunction :: Variance => variance_return_type ( & coerced_data_types[ 0 ] ) ,
136
+ AggregateFunction :: VariancePop => variance_return_type ( & coerced_data_types[ 0 ] ) ,
137
+ AggregateFunction :: Stddev => stddev_return_type ( & coerced_data_types[ 0 ] ) ,
138
+ AggregateFunction :: StddevPop => stddev_return_type ( & coerced_data_types[ 0 ] ) ,
119
139
AggregateFunction :: Avg => avg_return_type ( & coerced_data_types[ 0 ] ) ,
120
140
AggregateFunction :: ArrayAgg => Ok ( DataType :: List ( Box :: new ( Field :: new (
121
141
"item" ,
@@ -212,6 +232,48 @@ pub fn create_aggregate_expr(
212
232
"AVG(DISTINCT) aggregations are not available" . to_string ( ) ,
213
233
) ) ;
214
234
}
235
+ ( AggregateFunction :: Variance , false ) => Arc :: new ( expressions:: Variance :: new (
236
+ coerced_phy_exprs[ 0 ] . clone ( ) ,
237
+ name,
238
+ return_type,
239
+ ) ) ,
240
+ ( AggregateFunction :: Variance , true ) => {
241
+ return Err ( DataFusionError :: NotImplemented (
242
+ "VAR(DISTINCT) aggregations are not available" . to_string ( ) ,
243
+ ) ) ;
244
+ }
245
+ ( AggregateFunction :: VariancePop , false ) => {
246
+ Arc :: new ( expressions:: VariancePop :: new (
247
+ coerced_phy_exprs[ 0 ] . clone ( ) ,
248
+ name,
249
+ return_type,
250
+ ) )
251
+ }
252
+ ( AggregateFunction :: VariancePop , true ) => {
253
+ return Err ( DataFusionError :: NotImplemented (
254
+ "VAR_POP(DISTINCT) aggregations are not available" . to_string ( ) ,
255
+ ) ) ;
256
+ }
257
+ ( AggregateFunction :: Stddev , false ) => Arc :: new ( expressions:: Stddev :: new (
258
+ coerced_phy_exprs[ 0 ] . clone ( ) ,
259
+ name,
260
+ return_type,
261
+ ) ) ,
262
+ ( AggregateFunction :: Stddev , true ) => {
263
+ return Err ( DataFusionError :: NotImplemented (
264
+ "STDDEV(DISTINCT) aggregations are not available" . to_string ( ) ,
265
+ ) ) ;
266
+ }
267
+ ( AggregateFunction :: StddevPop , false ) => Arc :: new ( expressions:: StddevPop :: new (
268
+ coerced_phy_exprs[ 0 ] . clone ( ) ,
269
+ name,
270
+ return_type,
271
+ ) ) ,
272
+ ( AggregateFunction :: StddevPop , true ) => {
273
+ return Err ( DataFusionError :: NotImplemented (
274
+ "STDDEV_POP(DISTINCT) aggregations are not available" . to_string ( ) ,
275
+ ) ) ;
276
+ }
215
277
} )
216
278
}
217
279
@@ -256,7 +318,12 @@ pub fn signature(fun: &AggregateFunction) -> Signature {
256
318
. collect :: < Vec < _ > > ( ) ;
257
319
Signature :: uniform ( 1 , valid, Volatility :: Immutable )
258
320
}
259
- AggregateFunction :: Avg | AggregateFunction :: Sum => {
321
+ AggregateFunction :: Avg
322
+ | AggregateFunction :: Sum
323
+ | AggregateFunction :: Variance
324
+ | AggregateFunction :: VariancePop
325
+ | AggregateFunction :: Stddev
326
+ | AggregateFunction :: StddevPop => {
260
327
Signature :: uniform ( 1 , NUMERICS . to_vec ( ) , Volatility :: Immutable )
261
328
}
262
329
}
@@ -267,7 +334,7 @@ mod tests {
267
334
use super :: * ;
268
335
use crate :: error:: Result ;
269
336
use crate :: physical_plan:: expressions:: {
270
- ApproxDistinct , ArrayAgg , Avg , Count , Max , Min , Sum ,
337
+ ApproxDistinct , ArrayAgg , Avg , Count , Max , Min , Stddev , Sum , Variance ,
271
338
} ;
272
339
273
340
#[ test]
@@ -450,6 +517,158 @@ mod tests {
450
517
Ok ( ( ) )
451
518
}
452
519
520
/// Builds a sample-variance aggregate over a single nullable column `c1` for
/// each numeric input type below and checks the concrete expression type, its
/// name, and its output field.
#[test]
fn test_variance_expr() -> Result<()> {
    let aggregates = [AggregateFunction::Variance];
    let numeric_types = [
        DataType::UInt32,
        DataType::UInt64,
        DataType::Int32,
        DataType::Int64,
        DataType::Float32,
        DataType::Float64,
    ];
    for aggregate in aggregates.iter() {
        for numeric_type in numeric_types.iter() {
            let schema = Schema::new(vec![Field::new("c1", numeric_type.clone(), true)]);
            let column: Arc<dyn PhysicalExpr> =
                Arc::new(expressions::Column::new_with_schema("c1", &schema).unwrap());
            let args = vec![column];
            let built =
                create_aggregate_expr(aggregate, false, &args[0..1], &schema, "c1")?;
            // Only the sample-variance variant is inspected by this test.
            if *aggregate != AggregateFunction::Variance {
                continue;
            }
            assert!(built.as_any().is::<Variance>());
            assert_eq!("c1", built.name());
            let expected_field = Field::new("c1", DataType::Float64, true);
            assert_eq!(expected_field, built.field().unwrap());
        }
    }
    Ok(())
}
557
+
558
/// Builds a population-variance aggregate over a single nullable column `c1`
/// for each numeric input type below and checks that the planner produced a
/// `VariancePop` expression named `c1` with a nullable `Float64` output field.
///
/// Returns an error (failing the test) if the column lookup, the aggregate
/// construction, or the field lookup fails.
#[test]
fn test_var_pop_expr() -> Result<()> {
    let data_types = vec![
        DataType::UInt32,
        DataType::UInt64,
        DataType::Int32,
        DataType::Int64,
        DataType::Float32,
        DataType::Float64,
    ];
    for data_type in &data_types {
        let input_schema = Schema::new(vec![Field::new("c1", data_type.clone(), true)]);
        let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> = vec![Arc::new(
            expressions::Column::new_with_schema("c1", &input_schema)?,
        )];
        let result_agg_phy_exprs = create_aggregate_expr(
            &AggregateFunction::VariancePop,
            false,
            &input_phy_exprs[0..1],
            &input_schema,
            "c1",
        )?;
        // Assert unconditionally: guarding on a different variant than the one
        // under test would silently skip every check below.
        assert!(result_agg_phy_exprs
            .as_any()
            .is::<expressions::VariancePop>());
        assert_eq!("c1", result_agg_phy_exprs.name());
        assert_eq!(
            Field::new("c1", DataType::Float64, true),
            result_agg_phy_exprs.field()?
        );
    }
    Ok(())
}
595
+
596
/// Builds a sample-standard-deviation aggregate over a single nullable column
/// `c1` for each numeric input type below and checks that the planner produced
/// a `Stddev` expression named `c1` with a nullable `Float64` output field.
///
/// Returns an error (failing the test) if the column lookup, the aggregate
/// construction, or the field lookup fails.
#[test]
fn test_stddev_expr() -> Result<()> {
    let data_types = vec![
        DataType::UInt32,
        DataType::UInt64,
        DataType::Int32,
        DataType::Int64,
        DataType::Float32,
        DataType::Float64,
    ];
    for data_type in &data_types {
        let input_schema = Schema::new(vec![Field::new("c1", data_type.clone(), true)]);
        let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> = vec![Arc::new(
            expressions::Column::new_with_schema("c1", &input_schema)?,
        )];
        let result_agg_phy_exprs = create_aggregate_expr(
            &AggregateFunction::Stddev,
            false,
            &input_phy_exprs[0..1],
            &input_schema,
            "c1",
        )?;
        // Assert unconditionally: guarding on a different variant than the one
        // under test would silently skip every check below.
        assert!(result_agg_phy_exprs.as_any().is::<Stddev>());
        assert_eq!("c1", result_agg_phy_exprs.name());
        assert_eq!(
            Field::new("c1", DataType::Float64, true),
            result_agg_phy_exprs.field()?
        );
    }
    Ok(())
}
633
+
634
/// Builds a population-standard-deviation aggregate over a single nullable
/// column `c1` for each numeric input type below and checks that the planner
/// produced a `StddevPop` expression named `c1` with a nullable `Float64`
/// output field.
///
/// Returns an error (failing the test) if the column lookup, the aggregate
/// construction, or the field lookup fails.
#[test]
fn test_stddev_pop_expr() -> Result<()> {
    let data_types = vec![
        DataType::UInt32,
        DataType::UInt64,
        DataType::Int32,
        DataType::Int64,
        DataType::Float32,
        DataType::Float64,
    ];
    for data_type in &data_types {
        let input_schema = Schema::new(vec![Field::new("c1", data_type.clone(), true)]);
        let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> = vec![Arc::new(
            expressions::Column::new_with_schema("c1", &input_schema)?,
        )];
        let result_agg_phy_exprs = create_aggregate_expr(
            &AggregateFunction::StddevPop,
            false,
            &input_phy_exprs[0..1],
            &input_schema,
            "c1",
        )?;
        // Assert unconditionally: guarding on a different variant than the one
        // under test would silently skip every check below.
        assert!(result_agg_phy_exprs
            .as_any()
            .is::<expressions::StddevPop>());
        assert_eq!("c1", result_agg_phy_exprs.name());
        assert_eq!(
            Field::new("c1", DataType::Float64, true),
            result_agg_phy_exprs.field()?
        );
    }
    Ok(())
}
671
+
453
672
#[ test]
454
673
fn test_min_max ( ) -> Result < ( ) > {
455
674
let observed = return_type ( & AggregateFunction :: Min , & [ DataType :: Utf8 ] ) ?;
@@ -544,4 +763,56 @@ mod tests {
544
763
let observed = return_type ( & AggregateFunction :: Avg , & [ DataType :: Utf8 ] ) ;
545
764
assert ! ( observed. is_err( ) ) ;
546
765
}
766
+
767
/// `return_type` for `Variance` must yield `Float64` for each numeric input
/// type exercised here.
#[test]
fn test_variance_return_type() -> Result<()> {
    let input_types = [
        DataType::Float32,
        DataType::Float64,
        DataType::Int32,
        DataType::UInt32,
        DataType::Int64,
    ];
    for input_type in input_types.iter() {
        let observed = return_type(&AggregateFunction::Variance, &[input_type.clone()])?;
        assert_eq!(DataType::Float64, observed);
    }

    Ok(())
}
786
+
787
/// `return_type` for `Variance` must reject a `Utf8` input.
#[test]
fn test_variance_no_utf8() {
    let utf8_input = [DataType::Utf8];
    assert!(return_type(&AggregateFunction::Variance, &utf8_input).is_err());
}
792
+
793
/// `return_type` for `Stddev` must yield `Float64` for each numeric input
/// type exercised here.
#[test]
fn test_stddev_return_type() -> Result<()> {
    let input_types = [
        DataType::Float32,
        DataType::Float64,
        DataType::Int32,
        DataType::UInt32,
        DataType::Int64,
    ];
    for input_type in input_types.iter() {
        let observed = return_type(&AggregateFunction::Stddev, &[input_type.clone()])?;
        assert_eq!(DataType::Float64, observed);
    }

    Ok(())
}
812
+
813
/// `return_type` for `Stddev` must reject a `Utf8` input.
#[test]
fn test_stddev_no_utf8() {
    let utf8_input = [DataType::Utf8];
    assert!(return_type(&AggregateFunction::Stddev, &utf8_input).is_err());
}
547
818
}
0 commit comments