@@ -31,7 +31,9 @@ use datafusion::error::Result;
31
31
use datafusion:: prelude:: * ;
32
32
use datafusion_common:: { cast:: as_float64_array, ScalarValue } ;
33
33
use datafusion_expr:: {
34
- function:: { AccumulatorArgs , StateFieldsArgs } ,
34
+ expr:: AggregateFunction ,
35
+ function:: { AccumulatorArgs , AggregateFunctionSimplification , StateFieldsArgs } ,
36
+ simplify:: SimplifyInfo ,
35
37
Accumulator , AggregateUDF , AggregateUDFImpl , GroupsAccumulator , Signature ,
36
38
} ;
37
39
@@ -197,40 +199,6 @@ impl Accumulator for GeometricMean {
197
199
}
198
200
}
199
201
200
- // create local session context with an in-memory table
201
- fn create_context ( ) -> Result < SessionContext > {
202
- use datafusion:: datasource:: MemTable ;
203
- // define a schema.
204
- let schema = Arc :: new ( Schema :: new ( vec ! [
205
- Field :: new( "a" , DataType :: Float32 , false ) ,
206
- Field :: new( "b" , DataType :: Float32 , false ) ,
207
- ] ) ) ;
208
-
209
- // define data in two partitions
210
- let batch1 = RecordBatch :: try_new (
211
- schema. clone ( ) ,
212
- vec ! [
213
- Arc :: new( Float32Array :: from( vec![ 2.0 , 4.0 , 8.0 ] ) ) ,
214
- Arc :: new( Float32Array :: from( vec![ 2.0 , 2.0 , 2.0 ] ) ) ,
215
- ] ,
216
- ) ?;
217
- let batch2 = RecordBatch :: try_new (
218
- schema. clone ( ) ,
219
- vec ! [
220
- Arc :: new( Float32Array :: from( vec![ 64.0 ] ) ) ,
221
- Arc :: new( Float32Array :: from( vec![ 2.0 ] ) ) ,
222
- ] ,
223
- ) ?;
224
-
225
- // declare a new context. In spark API, this corresponds to a new spark SQLsession
226
- let ctx = SessionContext :: new ( ) ;
227
-
228
- // declare a table in memory. In spark API, this corresponds to createDataFrame(...).
229
- let provider = MemTable :: try_new ( schema, vec ! [ vec![ batch1] , vec![ batch2] ] ) ?;
230
- ctx. register_table ( "t" , Arc :: new ( provider) ) ?;
231
- Ok ( ctx)
232
- }
233
-
234
202
// Define a `GroupsAccumulator` for GeometricMean
235
203
/// which handles accumulator state for multiple groups at once.
236
204
/// This API is significantly more complicated than `Accumulator`, which manages
@@ -399,35 +367,146 @@ impl GroupsAccumulator for GeometricMeanGroupsAccumulator {
399
367
}
400
368
}
401
369
370
+ /// This example shows how to use the AggregateUDFImpl::simplify API to simplify/replace user
371
+ /// defined aggregate function with a different expression which is defined in the `simplify` method.
372
+ #[ derive( Debug , Clone ) ]
373
+ struct SimplifiedGeoMeanUdaf {
374
+ signature : Signature ,
375
+ }
376
+
377
+ impl SimplifiedGeoMeanUdaf {
378
+ fn new ( ) -> Self {
379
+ Self {
380
+ signature : Signature :: exact ( vec ! [ DataType :: Float64 ] , Volatility :: Immutable ) ,
381
+ }
382
+ }
383
+ }
384
+
385
+ impl AggregateUDFImpl for SimplifiedGeoMeanUdaf {
386
+ fn as_any ( & self ) -> & dyn Any {
387
+ self
388
+ }
389
+
390
+ fn name ( & self ) -> & str {
391
+ "simplified_geo_mean"
392
+ }
393
+
394
+ fn signature ( & self ) -> & Signature {
395
+ & self . signature
396
+ }
397
+
398
+ fn return_type ( & self , _arg_types : & [ DataType ] ) -> Result < DataType > {
399
+ Ok ( DataType :: Float64 )
400
+ }
401
+
402
+ fn accumulator ( & self , _acc_args : AccumulatorArgs ) -> Result < Box < dyn Accumulator > > {
403
+ unimplemented ! ( "should not be invoked" )
404
+ }
405
+
406
+ fn state_fields ( & self , _args : StateFieldsArgs ) -> Result < Vec < Field > > {
407
+ unimplemented ! ( "should not be invoked" )
408
+ }
409
+
410
+ fn groups_accumulator_supported ( & self , _args : AccumulatorArgs ) -> bool {
411
+ true
412
+ }
413
+
414
+ fn create_groups_accumulator (
415
+ & self ,
416
+ _args : AccumulatorArgs ,
417
+ ) -> Result < Box < dyn GroupsAccumulator > > {
418
+ unimplemented ! ( "should not get here" ) ;
419
+ }
420
+
421
+ /// Optionally replaces a UDAF with another expression during query optimization.
422
+ fn simplify ( & self ) -> Option < AggregateFunctionSimplification > {
423
+ let simplify = |aggregate_function : AggregateFunction , _: & dyn SimplifyInfo | {
424
+ // Replaces the UDAF with `GeoMeanUdaf` as a placeholder example to demonstrate the `simplify` method.
425
+ // In real-world scenarios, you might create UDFs from built-in expressions.
426
+ Ok ( Expr :: AggregateFunction ( AggregateFunction :: new_udf (
427
+ Arc :: new ( AggregateUDF :: from ( GeoMeanUdaf :: new ( ) ) ) ,
428
+ aggregate_function. args ,
429
+ aggregate_function. distinct ,
430
+ aggregate_function. filter ,
431
+ aggregate_function. order_by ,
432
+ aggregate_function. null_treatment ,
433
+ ) ) )
434
+ } ;
435
+ Some ( Box :: new ( simplify) )
436
+ }
437
+ }
438
+
439
+ // create local session context with an in-memory table
440
+ fn create_context ( ) -> Result < SessionContext > {
441
+ use datafusion:: datasource:: MemTable ;
442
+ // define a schema.
443
+ let schema = Arc :: new ( Schema :: new ( vec ! [
444
+ Field :: new( "a" , DataType :: Float32 , false ) ,
445
+ Field :: new( "b" , DataType :: Float32 , false ) ,
446
+ ] ) ) ;
447
+
448
+ // define data in two partitions
449
+ let batch1 = RecordBatch :: try_new (
450
+ schema. clone ( ) ,
451
+ vec ! [
452
+ Arc :: new( Float32Array :: from( vec![ 2.0 , 4.0 , 8.0 ] ) ) ,
453
+ Arc :: new( Float32Array :: from( vec![ 2.0 , 2.0 , 2.0 ] ) ) ,
454
+ ] ,
455
+ ) ?;
456
+ let batch2 = RecordBatch :: try_new (
457
+ schema. clone ( ) ,
458
+ vec ! [
459
+ Arc :: new( Float32Array :: from( vec![ 64.0 ] ) ) ,
460
+ Arc :: new( Float32Array :: from( vec![ 2.0 ] ) ) ,
461
+ ] ,
462
+ ) ?;
463
+
464
+ // declare a new context. In spark API, this corresponds to a new spark SQLsession
465
+ let ctx = SessionContext :: new ( ) ;
466
+
467
+ // declare a table in memory. In spark API, this corresponds to createDataFrame(...).
468
+ let provider = MemTable :: try_new ( schema, vec ! [ vec![ batch1] , vec![ batch2] ] ) ?;
469
+ ctx. register_table ( "t" , Arc :: new ( provider) ) ?;
470
+ Ok ( ctx)
471
+ }
472
+
402
473
#[ tokio:: main]
403
474
async fn main ( ) -> Result < ( ) > {
404
475
let ctx = create_context ( ) ?;
405
476
406
- // create the AggregateUDF
407
- let geometric_mean = AggregateUDF :: from ( GeoMeanUdaf :: new ( ) ) ;
408
- ctx. register_udaf ( geometric_mean. clone ( ) ) ;
477
+ let geo_mean_udf = AggregateUDF :: from ( GeoMeanUdaf :: new ( ) ) ;
478
+ let simplified_geo_mean_udf = AggregateUDF :: from ( SimplifiedGeoMeanUdaf :: new ( ) ) ;
479
+
480
+ for ( udf, udf_name) in [
481
+ ( geo_mean_udf, "geo_mean" ) ,
482
+ ( simplified_geo_mean_udf, "simplified_geo_mean" ) ,
483
+ ] {
484
+ ctx. register_udaf ( udf. clone ( ) ) ;
409
485
410
- let sql_df = ctx. sql ( "SELECT geo_mean(a) FROM t group by b" ) . await ?;
411
- sql_df. show ( ) . await ?;
486
+ let sql_df = ctx
487
+ . sql ( & format ! ( "SELECT {}(a) FROM t GROUP BY b" , udf_name) )
488
+ . await ?;
489
+ sql_df. show ( ) . await ?;
412
490
413
- // get a DataFrame from the context
414
- // this table has 1 column `a` f32 with values {2,4,8,64}, whose geometric mean is 8.0.
415
- let df = ctx. table ( "t" ) . await ?;
491
+ // get a DataFrame from the context
492
+ // this table has 1 column `a` f32 with values {2,4,8,64}, whose geometric mean is 8.0.
493
+ let df = ctx. table ( "t" ) . await ?;
416
494
417
- // perform the aggregation
418
- let df = df. aggregate ( vec ! [ ] , vec ! [ geometric_mean . call( vec![ col( "a" ) ] ) ] ) ?;
495
+ // perform the aggregation
496
+ let df = df. aggregate ( vec ! [ ] , vec ! [ udf . call( vec![ col( "a" ) ] ) ] ) ?;
419
497
420
- // note that "a" is f32, not f64. DataFusion coerces it to match the UDAF's signature.
498
+ // note that "a" is f32, not f64. DataFusion coerces it to match the UDAF's signature.
421
499
422
- // execute the query
423
- let results = df. collect ( ) . await ?;
500
+ // execute the query
501
+ let results = df. collect ( ) . await ?;
424
502
425
- // downcast the array to the expected type
426
- let result = as_float64_array ( results[ 0 ] . column ( 0 ) ) ?;
503
+ // downcast the array to the expected type
504
+ let result = as_float64_array ( results[ 0 ] . column ( 0 ) ) ?;
427
505
428
- // verify that the calculation is correct
429
- assert ! ( ( result. value( 0 ) - 8.0 ) . abs( ) < f64 :: EPSILON ) ;
430
- println ! ( "The geometric mean of [2,4,8,64] is {}" , result. value( 0 ) ) ;
506
+ // verify that the calculation is correct
507
+ assert ! ( ( result. value( 0 ) - 8.0 ) . abs( ) < f64 :: EPSILON ) ;
508
+ println ! ( "The geometric mean of [2,4,8,64] is {}" , result. value( 0 ) ) ;
509
+ }
431
510
432
511
Ok ( ( ) )
433
512
}
0 commit comments