Skip to content

Commit 933fec8

Browse files
authored
Consolidate example: simplify_udaf_expression.rs into advanced_udaf.rs (#13905)
1 parent 9665e09 commit 933fec8

File tree

2 files changed

+132
-229
lines changed

2 files changed

+132
-229
lines changed

datafusion-examples/examples/advanced_udaf.rs

Lines changed: 132 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,9 @@ use datafusion::error::Result;
3131
use datafusion::prelude::*;
3232
use datafusion_common::{cast::as_float64_array, ScalarValue};
3333
use datafusion_expr::{
34-
function::{AccumulatorArgs, StateFieldsArgs},
34+
expr::AggregateFunction,
35+
function::{AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs},
36+
simplify::SimplifyInfo,
3537
Accumulator, AggregateUDF, AggregateUDFImpl, GroupsAccumulator, Signature,
3638
};
3739

@@ -197,40 +199,6 @@ impl Accumulator for GeometricMean {
197199
}
198200
}
199201

200-
// create local session context with an in-memory table
201-
fn create_context() -> Result<SessionContext> {
202-
use datafusion::datasource::MemTable;
203-
// define a schema.
204-
let schema = Arc::new(Schema::new(vec![
205-
Field::new("a", DataType::Float32, false),
206-
Field::new("b", DataType::Float32, false),
207-
]));
208-
209-
// define data in two partitions
210-
let batch1 = RecordBatch::try_new(
211-
schema.clone(),
212-
vec![
213-
Arc::new(Float32Array::from(vec![2.0, 4.0, 8.0])),
214-
Arc::new(Float32Array::from(vec![2.0, 2.0, 2.0])),
215-
],
216-
)?;
217-
let batch2 = RecordBatch::try_new(
218-
schema.clone(),
219-
vec![
220-
Arc::new(Float32Array::from(vec![64.0])),
221-
Arc::new(Float32Array::from(vec![2.0])),
222-
],
223-
)?;
224-
225-
// declare a new context. In spark API, this corresponds to a new spark SQLsession
226-
let ctx = SessionContext::new();
227-
228-
// declare a table in memory. In spark API, this corresponds to createDataFrame(...).
229-
let provider = MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?;
230-
ctx.register_table("t", Arc::new(provider))?;
231-
Ok(ctx)
232-
}
233-
234202
// Define a `GroupsAccumulator` for GeometricMean
235203
/// which handles accumulator state for multiple groups at once.
236204
/// This API is significantly more complicated than `Accumulator`, which manages
@@ -399,35 +367,146 @@ impl GroupsAccumulator for GeometricMeanGroupsAccumulator {
399367
}
400368
}
401369

370+
/// This example shows how to use the AggregateUDFImpl::simplify API to simplify/replace user
371+
/// defined aggregate function with a different expression which is defined in the `simplify` method.
372+
#[derive(Debug, Clone)]
373+
struct SimplifiedGeoMeanUdaf {
374+
signature: Signature,
375+
}
376+
377+
impl SimplifiedGeoMeanUdaf {
378+
fn new() -> Self {
379+
Self {
380+
signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable),
381+
}
382+
}
383+
}
384+
385+
impl AggregateUDFImpl for SimplifiedGeoMeanUdaf {
386+
fn as_any(&self) -> &dyn Any {
387+
self
388+
}
389+
390+
fn name(&self) -> &str {
391+
"simplified_geo_mean"
392+
}
393+
394+
fn signature(&self) -> &Signature {
395+
&self.signature
396+
}
397+
398+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
399+
Ok(DataType::Float64)
400+
}
401+
402+
fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
403+
unimplemented!("should not be invoked")
404+
}
405+
406+
fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
407+
unimplemented!("should not be invoked")
408+
}
409+
410+
fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
411+
true
412+
}
413+
414+
fn create_groups_accumulator(
415+
&self,
416+
_args: AccumulatorArgs,
417+
) -> Result<Box<dyn GroupsAccumulator>> {
418+
unimplemented!("should not get here");
419+
}
420+
421+
/// Optionally replaces a UDAF with another expression during query optimization.
422+
fn simplify(&self) -> Option<AggregateFunctionSimplification> {
423+
let simplify = |aggregate_function: AggregateFunction, _: &dyn SimplifyInfo| {
424+
// Replaces the UDAF with `GeoMeanUdaf` as a placeholder example to demonstrate the `simplify` method.
425+
// In real-world scenarios, you might create UDFs from built-in expressions.
426+
Ok(Expr::AggregateFunction(AggregateFunction::new_udf(
427+
Arc::new(AggregateUDF::from(GeoMeanUdaf::new())),
428+
aggregate_function.args,
429+
aggregate_function.distinct,
430+
aggregate_function.filter,
431+
aggregate_function.order_by,
432+
aggregate_function.null_treatment,
433+
)))
434+
};
435+
Some(Box::new(simplify))
436+
}
437+
}
438+
439+
// create local session context with an in-memory table
440+
fn create_context() -> Result<SessionContext> {
441+
use datafusion::datasource::MemTable;
442+
// define a schema.
443+
let schema = Arc::new(Schema::new(vec![
444+
Field::new("a", DataType::Float32, false),
445+
Field::new("b", DataType::Float32, false),
446+
]));
447+
448+
// define data in two partitions
449+
let batch1 = RecordBatch::try_new(
450+
schema.clone(),
451+
vec![
452+
Arc::new(Float32Array::from(vec![2.0, 4.0, 8.0])),
453+
Arc::new(Float32Array::from(vec![2.0, 2.0, 2.0])),
454+
],
455+
)?;
456+
let batch2 = RecordBatch::try_new(
457+
schema.clone(),
458+
vec![
459+
Arc::new(Float32Array::from(vec![64.0])),
460+
Arc::new(Float32Array::from(vec![2.0])),
461+
],
462+
)?;
463+
464+
// declare a new context. In spark API, this corresponds to a new spark SQLsession
465+
let ctx = SessionContext::new();
466+
467+
// declare a table in memory. In spark API, this corresponds to createDataFrame(...).
468+
let provider = MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?;
469+
ctx.register_table("t", Arc::new(provider))?;
470+
Ok(ctx)
471+
}
472+
402473
#[tokio::main]
403474
async fn main() -> Result<()> {
404475
let ctx = create_context()?;
405476

406-
// create the AggregateUDF
407-
let geometric_mean = AggregateUDF::from(GeoMeanUdaf::new());
408-
ctx.register_udaf(geometric_mean.clone());
477+
let geo_mean_udf = AggregateUDF::from(GeoMeanUdaf::new());
478+
let simplified_geo_mean_udf = AggregateUDF::from(SimplifiedGeoMeanUdaf::new());
479+
480+
for (udf, udf_name) in [
481+
(geo_mean_udf, "geo_mean"),
482+
(simplified_geo_mean_udf, "simplified_geo_mean"),
483+
] {
484+
ctx.register_udaf(udf.clone());
409485

410-
let sql_df = ctx.sql("SELECT geo_mean(a) FROM t group by b").await?;
411-
sql_df.show().await?;
486+
let sql_df = ctx
487+
.sql(&format!("SELECT {}(a) FROM t GROUP BY b", udf_name))
488+
.await?;
489+
sql_df.show().await?;
412490

413-
// get a DataFrame from the context
414-
// this table has 1 column `a` f32 with values {2,4,8,64}, whose geometric mean is 8.0.
415-
let df = ctx.table("t").await?;
491+
// get a DataFrame from the context
492+
// this table has 1 column `a` f32 with values {2,4,8,64}, whose geometric mean is 8.0.
493+
let df = ctx.table("t").await?;
416494

417-
// perform the aggregation
418-
let df = df.aggregate(vec![], vec![geometric_mean.call(vec![col("a")])])?;
495+
// perform the aggregation
496+
let df = df.aggregate(vec![], vec![udf.call(vec![col("a")])])?;
419497

420-
// note that "a" is f32, not f64. DataFusion coerces it to match the UDAF's signature.
498+
// note that "a" is f32, not f64. DataFusion coerces it to match the UDAF's signature.
421499

422-
// execute the query
423-
let results = df.collect().await?;
500+
// execute the query
501+
let results = df.collect().await?;
424502

425-
// downcast the array to the expected type
426-
let result = as_float64_array(results[0].column(0))?;
503+
// downcast the array to the expected type
504+
let result = as_float64_array(results[0].column(0))?;
427505

428-
// verify that the calculation is correct
429-
assert!((result.value(0) - 8.0).abs() < f64::EPSILON);
430-
println!("The geometric mean of [2,4,8,64] is {}", result.value(0));
506+
// verify that the calculation is correct
507+
assert!((result.value(0) - 8.0).abs() < f64::EPSILON);
508+
println!("The geometric mean of [2,4,8,64] is {}", result.value(0));
509+
}
431510

432511
Ok(())
433512
}

datafusion-examples/examples/simplify_udaf_expression.rs

Lines changed: 0 additions & 176 deletions
This file was deleted.

0 commit comments

Comments
 (0)