diff --git a/datafusion/core/src/lib.rs b/datafusion/core/src/lib.rs index 780b22983393..509ef601c097 100644 --- a/datafusion/core/src/lib.rs +++ b/datafusion/core/src/lib.rs @@ -873,6 +873,12 @@ doc_comment::doctest!( user_guide_expressions ); +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/library-user-guide/adding-udfs.md", + library_user_guide_adding_udfs +); + #[cfg(doctest)] doc_comment::doctest!( "../../../docs/source/library-user-guide/using-the-sql-api.md", diff --git a/docs/source/library-user-guide/adding-udfs.md b/docs/source/library-user-guide/adding-udfs.md index a9202976973b..96a782211185 100644 --- a/docs/source/library-user-guide/adding-udfs.md +++ b/docs/source/library-user-guide/adding-udfs.md @@ -55,46 +55,62 @@ of arguments. This a lower level API with more functionality but is more complex, also documented in [`advanced_udf.rs`]. ```rust +use std::sync::Arc; use std::any::Any; +use std::sync::LazyLock; use arrow::datatypes::DataType; +use datafusion_common::cast::as_int64_array; use datafusion_common::{DataFusionError, plan_err, Result}; -use datafusion_expr::{col, ColumnarValue, Signature, Volatility}; +use datafusion_expr::{col, ColumnarValue, Documentation, ScalarFunctionArgs, Signature, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH; +use datafusion::arrow::array::{ArrayRef, Int64Array}; use datafusion_expr::{ScalarUDFImpl, ScalarUDF}; +/// This struct is for a simple UDF that adds one to an int64 #[derive(Debug)] struct AddOne { - signature: Signature -}; + signature: Signature, +} impl AddOne { - fn new() -> Self { - Self { - signature: Signature::uniform(1, vec![DataType::Int32], Volatility::Immutable) - } - } + fn new() -> Self { + Self { + signature: Signature::uniform(1, vec![DataType::Int64], Volatility::Immutable), + } + } } +static DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| { + Documentation::builder(DOC_SECTION_MATH, "Add one to an int64", "add_one(2)") + .with_argument("arg1", "The int64 
number to add one to") + .build() +}); + /// Implement the ScalarUDFImpl trait for AddOne impl ScalarUDFImpl for AddOne { - fn as_any(&self) -> &dyn Any { self } - fn name(&self) -> &str { "add_one" } - fn signature(&self) -> &Signature { &self.signature } - fn return_type(&self, args: &[DataType]) -> Result<DataType> { - if !matches!(args.get(0), Some(&DataType::Int32)) { - return plan_err!("add_one only accepts Int32 arguments"); - } - Ok(DataType::Int32) - } - // The actual implementation would add one to the argument - fn invoke_batch(&self, args: &[ColumnarValue], _number_rows: usize) -> Result<ColumnarValue> { - let args = columnar_values_to_array(args)?; + fn as_any(&self) -> &dyn Any { self } + fn name(&self) -> &str { "add_one" } + fn signature(&self) -> &Signature { &self.signature } + fn return_type(&self, args: &[DataType]) -> Result<DataType> { + if !matches!(args.get(0), Some(&DataType::Int64)) { + return plan_err!("add_one only accepts Int64 arguments"); + } + Ok(DataType::Int64) + } + // The actual implementation would add one to the argument + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> { + let args = ColumnarValue::values_to_arrays(&args.args)?; let i64s = as_int64_array(&args[0])?; let new_array = i64s .iter() .map(|array_elem| array_elem.map(|value| value + 1)) .collect::<Int64Array>(); - Ok(Arc::new(new_array)) + + Ok(ColumnarValue::from(Arc::new(new_array) as ArrayRef)) + } + fn documentation(&self) -> Option<&Documentation> { + Some(&*DOCUMENTATION) } } ``` @@ -102,15 +118,75 @@ impl ScalarUDFImpl for AddOne { We now need to register the function with DataFusion so that it can be used in the context of a query. 
```rust +# use std::sync::Arc; +# use std::any::Any; +# use std::sync::LazyLock; +# use arrow::datatypes::DataType; +# use datafusion_common::cast::as_int64_array; +# use datafusion_common::{DataFusionError, plan_err, Result}; +# use datafusion_expr::{col, ColumnarValue, Documentation, ScalarFunctionArgs, Signature, Volatility}; +# use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH; +# use datafusion::arrow::array::{ArrayRef, Int64Array}; +# use datafusion_expr::{ScalarUDFImpl, ScalarUDF}; +# +# /// This struct is for a simple UDF that adds one to an int64 +# #[derive(Debug)] +# struct AddOne { +# signature: Signature, +# } +# +# impl AddOne { +# fn new() -> Self { +# Self { +# signature: Signature::uniform(1, vec![DataType::Int64], Volatility::Immutable), +# } +# } +# } +# +# static DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| { +# Documentation::builder(DOC_SECTION_MATH, "Add one to an int64", "add_one(2)") +# .with_argument("arg1", "The int64 number to add one to") +# .build() +# }); +# +# /// Implement the ScalarUDFImpl trait for AddOne +# impl ScalarUDFImpl for AddOne { +# fn as_any(&self) -> &dyn Any { self } +# fn name(&self) -> &str { "add_one" } +# fn signature(&self) -> &Signature { &self.signature } +# fn return_type(&self, args: &[DataType]) -> Result<DataType> { +# if !matches!(args.get(0), Some(&DataType::Int64)) { +# return plan_err!("add_one only accepts Int64 arguments"); +# } +# Ok(DataType::Int64) +# } +# // The actual implementation would add one to the argument +# fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> { +# let args = ColumnarValue::values_to_arrays(&args.args)?; +# let i64s = as_int64_array(&args[0])?; +# +# let new_array = i64s +# .iter() +# .map(|array_elem| array_elem.map(|value| value + 1)) +# .collect::<Int64Array>(); +# +# Ok(ColumnarValue::from(Arc::new(new_array) as ArrayRef)) +# } +# fn documentation(&self) -> Option<&Documentation> { +# Some(&*DOCUMENTATION) +# } +# } +use datafusion::execution::context::SessionContext; + // 
Create a new ScalarUDF from the implementation let add_one = ScalarUDF::from(AddOne::new()); +// Call the function `add_one(col)` +let expr = add_one.call(vec![col("a")]); + // register the UDF with the context so it can be invoked by name and from SQL let mut ctx = SessionContext::new(); ctx.register_udf(add_one.clone()); - -// Call the function `add_one(col)` -let expr = add_one.call(vec![col("a")]); ``` ### Adding a Scalar UDF by [`create_udf`] @@ -121,7 +197,6 @@ There is a an older, more concise, but also more limited API [`create_udf`] avai ```rust use std::sync::Arc; - use datafusion::arrow::array::{ArrayRef, Int64Array}; use datafusion::common::cast::as_int64_array; use datafusion::common::Result; @@ -145,6 +220,24 @@ This "works" in isolation, i.e. if you have a slice of `ArrayRef`s, you can call `ArrayRef` with 1 added to each value. ```rust +# use std::sync::Arc; +# use datafusion::arrow::array::{ArrayRef, Int64Array}; +# use datafusion::common::cast::as_int64_array; +# use datafusion::common::Result; +# use datafusion::logical_expr::ColumnarValue; +# +# pub fn add_one(args: &[ColumnarValue]) -> Result<ColumnarValue> { +# // Error handling omitted for brevity +# let args = ColumnarValue::values_to_arrays(args)?; +# let i64s = as_int64_array(&args[0])?; +# +# let new_array = i64s +# .iter() +# .map(|array_elem| array_elem.map(|value| value + 1)) +# .collect::<Int64Array>(); +# +# Ok(ColumnarValue::from(Arc::new(new_array) as ArrayRef)) +# } let input = vec![Some(1), None, Some(3)]; let input = ColumnarValue::from(Arc::new(Int64Array::from(input)) as ArrayRef); @@ -165,9 +258,26 @@ with the `SessionContext`. DataFusion provides the [`create_udf`] and helper functions to make this easier. 
```rust +# use std::sync::Arc; +# use datafusion::arrow::array::{ArrayRef, Int64Array}; +# use datafusion::common::cast::as_int64_array; +# use datafusion::common::Result; +# use datafusion::logical_expr::ColumnarValue; +# +# pub fn add_one(args: &[ColumnarValue]) -> Result<ColumnarValue> { +# // Error handling omitted for brevity +# let args = ColumnarValue::values_to_arrays(args)?; +# let i64s = as_int64_array(&args[0])?; +# +# let new_array = i64s +# .iter() +# .map(|array_elem| array_elem.map(|value| value + 1)) +# .collect::<Int64Array>(); +# +# Ok(ColumnarValue::from(Arc::new(new_array) as ArrayRef)) +# } use datafusion::logical_expr::{Volatility, create_udf}; use datafusion::arrow::datatypes::DataType; -use std::sync::Arc; let udf = create_udf( "add_one", @@ -178,12 +288,7 @@ let udf = create_udf( ); ``` -[`scalarudf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/struct.ScalarUDF.html -[`create_udf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/fn.create_udf.html -[`process_scalar_func_inputs`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/functions/fn.process_scalar_func_inputs.html -[`advanced_udf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs - -A few things to note: +A few things to note on `create_udf`: - The first argument is the name of the function. This is the name that will be used in SQL queries. - The second argument is a vector of `DataType`s. This is the list of argument types that the function accepts. I.e. 
in @@ -198,20 +303,51 @@ A few things to note: That gives us a `ScalarUDF` that we can register with the `SessionContext`: ```rust +# use std::sync::Arc; +# use datafusion::arrow::array::{ArrayRef, Int64Array}; +# use datafusion::common::cast::as_int64_array; +# use datafusion::common::Result; +# use datafusion::logical_expr::ColumnarValue; +# +# pub fn add_one(args: &[ColumnarValue]) -> Result<ColumnarValue> { +# // Error handling omitted for brevity +# let args = ColumnarValue::values_to_arrays(args)?; +# let i64s = as_int64_array(&args[0])?; +# +# let new_array = i64s +# .iter() +# .map(|array_elem| array_elem.map(|value| value + 1)) +# .collect::<Int64Array>(); +# +# Ok(ColumnarValue::from(Arc::new(new_array) as ArrayRef)) +# } +use datafusion::logical_expr::{Volatility, create_udf}; +use datafusion::arrow::datatypes::DataType; use datafusion::execution::context::SessionContext; -let mut ctx = SessionContext::new(); - -ctx.register_udf(udf); +#[tokio::main] +async fn main() { + let udf = create_udf( + "add_one", + vec![DataType::Int64], + DataType::Int64, + Volatility::Immutable, + Arc::new(add_one), + ); + + let ctx = SessionContext::new(); + ctx.register_udf(udf); + + // At this point, you can use the `add_one` function in your query: + let query = "SELECT add_one(1)"; + let df = ctx.sql(query).await.unwrap(); +} ``` -At this point, you can use the `add_one` function in your query: - -```rust -let sql = "SELECT add_one(1)"; - -let df = ctx.sql( & sql).await.unwrap(); -``` +[`scalarudf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/struct.ScalarUDF.html +[`create_udf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/fn.create_udf.html +[`process_scalar_func_inputs`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/functions/fn.process_scalar_func_inputs.html +[`advanced_udf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs ## Adding a Window UDF @@ -294,17 +430,61 @@ with the `SessionContext`. 
DataFusion provides the [`create_udwf`] helper functi There is a lower level API with more functionality but is more complex, that is documented in [`advanced_udwf.rs`]. ```rust +# use datafusion::arrow::{array::{ArrayRef, Float64Array, AsArray}, datatypes::Float64Type}; +# use datafusion::logical_expr::{PartitionEvaluator}; +# use datafusion::common::ScalarValue; +# use datafusion::error::Result; +# +# #[derive(Clone, Debug)] +# struct MyPartitionEvaluator {} +# +# impl MyPartitionEvaluator { +# fn new() -> Self { +# Self {} +# } +# } +# +# impl PartitionEvaluator for MyPartitionEvaluator { +# fn uses_window_frame(&self) -> bool { +# true +# } +# +# fn evaluate( +# &mut self, +# values: &[ArrayRef], +# range: &std::ops::Range<usize>, +# ) -> Result<ScalarValue> { +# // Again, the input argument is an array of floating +# // point numbers to calculate a moving average +# let arr: &Float64Array = values[0].as_ref().as_primitive::<Float64Type>(); +# +# let range_len = range.end - range.start; +# +# // our smoothing function will average all the values in the +# let output = if range_len > 0 { +# let sum: f64 = arr.values().iter().skip(range.start).take(range_len).sum(); +# Some(sum / range_len as f64) +# } else { +# None +# }; +# +# Ok(ScalarValue::Float64(output)) +# } +# } +# fn make_partition_evaluator() -> Result<Box<dyn PartitionEvaluator>> { +# Ok(Box::new(MyPartitionEvaluator::new())) +# } use datafusion::logical_expr::{Volatility, create_udwf}; use datafusion::arrow::datatypes::DataType; use std::sync::Arc; // here is where we define the UDWF. 
We also declare its signature: let smooth_it = create_udwf( -"smooth_it", -DataType::Float64, -Arc::new(DataType::Float64), -Volatility::Immutable, -Arc::new(make_partition_evaluator), + "smooth_it", + DataType::Float64, + Arc::new(DataType::Float64), + Volatility::Immutable, + Arc::new(make_partition_evaluator), ); ``` @@ -327,6 +507,62 @@ The `create_udwf` has five arguments to check: That gives us a `WindowUDF` that we can register with the `SessionContext`: ```rust +# use datafusion::arrow::{array::{ArrayRef, Float64Array, AsArray}, datatypes::Float64Type}; +# use datafusion::logical_expr::{PartitionEvaluator}; +# use datafusion::common::ScalarValue; +# use datafusion::error::Result; +# +# #[derive(Clone, Debug)] +# struct MyPartitionEvaluator {} +# +# impl MyPartitionEvaluator { +# fn new() -> Self { +# Self {} +# } +# } +# +# impl PartitionEvaluator for MyPartitionEvaluator { +# fn uses_window_frame(&self) -> bool { +# true +# } +# +# fn evaluate( +# &mut self, +# values: &[ArrayRef], +# range: &std::ops::Range<usize>, +# ) -> Result<ScalarValue> { +# // Again, the input argument is an array of floating +# // point numbers to calculate a moving average +# let arr: &Float64Array = values[0].as_ref().as_primitive::<Float64Type>(); +# +# let range_len = range.end - range.start; +# +# // our smoothing function will average all the values in the +# let output = if range_len > 0 { +# let sum: f64 = arr.values().iter().skip(range.start).take(range_len).sum(); +# Some(sum / range_len as f64) +# } else { +# None +# }; +# +# Ok(ScalarValue::Float64(output)) +# } +# } +# fn make_partition_evaluator() -> Result<Box<dyn PartitionEvaluator>> { +# Ok(Box::new(MyPartitionEvaluator::new())) +# } +# use datafusion::logical_expr::{Volatility, create_udwf}; +# use datafusion::arrow::datatypes::DataType; +# use std::sync::Arc; +# +# // here is where we define the UDWF. 
We also declare its signature: +# let smooth_it = create_udwf( +# "smooth_it", +# DataType::Float64, +# Arc::new(DataType::Float64), +# Volatility::Immutable, +# Arc::new(make_partition_evaluator), +# ); use datafusion::execution::context::SessionContext; let ctx = SessionContext::new(); @@ -336,10 +572,9 @@ ctx.register_udwf(smooth_it); At this point, you can use the `smooth_it` function in your query: -For example, if we have a [ -`cars.csv`](https://github.com/apache/datafusion/blob/main/datafusion/core/tests/data/cars.csv) whose contents like +For example, if we have a [`cars.csv`](https://github.com/apache/datafusion/blob/main/datafusion/core/tests/data/cars.csv) whose contents look like -``` +```csv car,speed,time red,20.0,1996-04-12T12:05:03.000000000 red,20.3,1996-04-12T12:05:04.000000000 @@ -351,30 +586,97 @@ green,10.3,1996-04-12T12:05:04.000000000 Then, we can query like below: ```rust +# use datafusion::arrow::{array::{ArrayRef, Float64Array, AsArray}, datatypes::Float64Type}; +# use datafusion::logical_expr::{PartitionEvaluator}; +# use datafusion::common::ScalarValue; +# use datafusion::error::Result; +# +# #[derive(Clone, Debug)] +# struct MyPartitionEvaluator {} +# +# impl MyPartitionEvaluator { +# fn new() -> Self { +# Self {} +# } +# } +# +# impl PartitionEvaluator for MyPartitionEvaluator { +# fn uses_window_frame(&self) -> bool { +# true +# } +# +# fn evaluate( +# &mut self, +# values: &[ArrayRef], +# range: &std::ops::Range<usize>, +# ) -> Result<ScalarValue> { +# // Again, the input argument is an array of floating +# // point numbers to calculate a moving average +# let arr: &Float64Array = values[0].as_ref().as_primitive::<Float64Type>(); +# +# let range_len = range.end - range.start; +# +# // our smoothing function will average all the values in the +# let output = if range_len > 0 { +# let sum: f64 = arr.values().iter().skip(range.start).take(range_len).sum(); +# Some(sum / range_len as f64) +# } else { +# None +# }; +# +# Ok(ScalarValue::Float64(output)) +# } +# } +# fn 
make_partition_evaluator() -> Result<Box<dyn PartitionEvaluator>> { +# Ok(Box::new(MyPartitionEvaluator::new())) +# } +# use datafusion::logical_expr::{Volatility, create_udwf}; +# use datafusion::arrow::datatypes::DataType; +# use std::sync::Arc; +# use datafusion::execution::context::SessionContext; + use datafusion::datasource::file_format::options::CsvReadOptions; -// register csv table first -let csv_path = "cars.csv".to_string(); -ctx.register_csv("cars", & csv_path, CsvReadOptions::default ().has_header(true)).await?; -// do query with smooth_it -let df = ctx -.sql( -"SELECT \ - car, \ - speed, \ - smooth_it(speed) OVER (PARTITION BY car ORDER BY time) as smooth_speed,\ - time \ - from cars \ - ORDER BY \ - car", -) -.await?; -// print the results -df.show().await?; + +#[tokio::main] +async fn main() -> Result<()> { + + let ctx = SessionContext::new(); + + let smooth_it = create_udwf( + "smooth_it", + DataType::Float64, + Arc::new(DataType::Float64), + Volatility::Immutable, + Arc::new(make_partition_evaluator), + ); + ctx.register_udwf(smooth_it); + + // register csv table first + let csv_path = "../../datafusion/core/tests/data/cars.csv".to_string(); + ctx.register_csv("cars", &csv_path, CsvReadOptions::default().has_header(true)).await?; + + // do query with smooth_it + let df = ctx + .sql(r#" + SELECT + car, + speed, + smooth_it(speed) OVER (PARTITION BY car ORDER BY time) as smooth_speed, + time + FROM cars + ORDER BY car + "#) + .await?; + + // print the results + df.show().await?; + Ok(()) +} ``` -the output will be like: +The output will be like: -``` +```text +-------+-------+--------------------+---------------------+ | car | speed | smooth_speed | time | +-------+-------+--------------------+---------------------+ @@ -403,6 +705,7 @@ Aggregate UDFs are functions that take a group of rows and return a single value For example, we will declare a single-type, single return type UDAF that computes the geometric mean. 
```rust + use datafusion::arrow::array::ArrayRef; use datafusion::scalar::ScalarValue; use datafusion::{error::Result, physical_plan::Accumulator}; @@ -427,7 +730,7 @@ impl Accumulator for GeometricMean { // This function serializes our state to `ScalarValue`, which DataFusion uses // to pass this state between execution stages. // Note that this can be arbitrary data. - fn state(&self) -> Result<Vec<ScalarValue>> { + fn state(&mut self) -> Result<Vec<ScalarValue>> { Ok(vec![ ScalarValue::from(self.prod), ScalarValue::from(self.n), @@ -436,7 +739,7 @@ impl Accumulator for GeometricMean { // DataFusion expects this function to return the final value of this aggregator. // in this case, this is the formula of the geometric mean - fn evaluate(&self) -> Result<ScalarValue> { + fn evaluate(&mut self) -> Result<ScalarValue> { let value = self.prod.powf(1.0 / self.n as f64); Ok(ScalarValue::from(value)) } @@ -491,37 +794,106 @@ impl Accumulator for GeometricMean { } ``` -### registering an Aggregate UDF +### Registering an Aggregate UDF To register a Aggregate UDF, you need to wrap the function implementation in a [`AggregateUDF`] struct and then register it with the `SessionContext`. DataFusion provides the [`create_udaf`] helper functions to make this easier. There is a lower level API with more functionality but is more complex, that is documented in [`advanced_udaf.rs`]. 
```rust +# use datafusion::arrow::array::ArrayRef; +# use datafusion::scalar::ScalarValue; +# use datafusion::{error::Result, physical_plan::Accumulator}; +# +# #[derive(Debug)] +# struct GeometricMean { +# n: u32, +# prod: f64, +# } +# +# impl GeometricMean { +# pub fn new() -> Self { +# GeometricMean { n: 0, prod: 1.0 } +# } +# } +# +# impl Accumulator for GeometricMean { +# fn state(&mut self) -> Result<Vec<ScalarValue>> { +# Ok(vec![ +# ScalarValue::from(self.prod), +# ScalarValue::from(self.n), +# ]) +# } +# +# fn evaluate(&mut self) -> Result<ScalarValue> { +# let value = self.prod.powf(1.0 / self.n as f64); +# Ok(ScalarValue::from(value)) +# } +# +# fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { +# if values.is_empty() { +# return Ok(()); +# } +# let arr = &values[0]; +# (0..arr.len()).try_for_each(|index| { +# let v = ScalarValue::try_from_array(arr, index)?; +# +# if let ScalarValue::Float64(Some(value)) = v { +# self.prod *= value; +# self.n += 1; +# } else { +# unreachable!("") +# } +# Ok(()) +# }) +# } +# +# fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { +# if states.is_empty() { +# return Ok(()); +# } +# let arr = &states[0]; +# (0..arr.len()).try_for_each(|index| { +# let v = states +# .iter() +# .map(|array| ScalarValue::try_from_array(array, index)) +# .collect::<Result<Vec<_>>>()?; +# if let (ScalarValue::Float64(Some(prod)), ScalarValue::UInt32(Some(n))) = (&v[0], &v[1]) +# { +# self.prod *= prod; +# self.n += n; +# } else { +# unreachable!("") +# } +# Ok(()) +# }) +# } +# +# fn size(&self) -> usize { +# std::mem::size_of_val(self) +# } +# } + use datafusion::logical_expr::{Volatility, create_udaf}; use datafusion::arrow::datatypes::DataType; use std::sync::Arc; // here is where we define the UDAF. We also declare its signature: let geometric_mean = create_udaf( -// the name; used to represent it in plan descriptions and in the registry, to use in SQL. 
-"geo_mean", -// the input type; DataFusion guarantees that the first entry of `values` in `update` has this type. -vec![DataType::Float64], -// the return type; DataFusion expects this to match the type returned by `evaluate`. -Arc::new(DataType::Float64), -Volatility::Immutable, -// This is the accumulator factory; DataFusion uses it to create new accumulators. -Arc::new( | _ | Ok(Box::new(GeometricMean::new()))), -// This is the description of the state. `state()` must match the types here. -Arc::new(vec![DataType::Float64, DataType::UInt32]), + // the name; used to represent it in plan descriptions and in the registry, to use in SQL. + "geo_mean", + // the input type; DataFusion guarantees that the first entry of `values` in `update` has this type. + vec![DataType::Float64], + // the return type; DataFusion expects this to match the type returned by `evaluate`. + Arc::new(DataType::Float64), + Volatility::Immutable, + // This is the accumulator factory; DataFusion uses it to create new accumulators. + Arc::new( | _ | Ok(Box::new(GeometricMean::new()))), + // This is the description of the state. `state()` must match the types here. + Arc::new(vec![DataType::Float64, DataType::UInt32]), ); ``` -[`aggregateudf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/struct.AggregateUDF.html -[`create_udaf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/fn.create_udaf.html -[`advanced_udaf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udaf.rs - The `create_udaf` has six arguments to check: - The first argument is the name of the function. This is the name that will be used in SQL queries. @@ -535,22 +907,119 @@ The `create_udaf` has six arguments to check: - The fifth argument is the function implementation. This is the function that we defined above. - The sixth argument is the description of the state, which will by passed between execution stages. 
-That gives us a `AggregateUDF` that we can register with the `SessionContext`: - ```rust -use datafusion::execution::context::SessionContext; -let ctx = SessionContext::new(); +# use datafusion::arrow::array::ArrayRef; +# use datafusion::scalar::ScalarValue; +# use datafusion::{error::Result, physical_plan::Accumulator}; +# +# #[derive(Debug)] +# struct GeometricMean { +# n: u32, +# prod: f64, +# } +# +# impl GeometricMean { +# pub fn new() -> Self { +# GeometricMean { n: 0, prod: 1.0 } +# } +# } +# +# impl Accumulator for GeometricMean { +# fn state(&mut self) -> Result<Vec<ScalarValue>> { +# Ok(vec![ +# ScalarValue::from(self.prod), +# ScalarValue::from(self.n), +# ]) +# } +# +# fn evaluate(&mut self) -> Result<ScalarValue> { +# let value = self.prod.powf(1.0 / self.n as f64); +# Ok(ScalarValue::from(value)) +# } +# +# fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { +# if values.is_empty() { +# return Ok(()); +# } +# let arr = &values[0]; +# (0..arr.len()).try_for_each(|index| { +# let v = ScalarValue::try_from_array(arr, index)?; +# +# if let ScalarValue::Float64(Some(value)) = v { +# self.prod *= value; +# self.n += 1; +# } else { +# unreachable!("") +# } +# Ok(()) +# }) +# } +# +# fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { +# if states.is_empty() { +# return Ok(()); +# } +# let arr = &states[0]; +# (0..arr.len()).try_for_each(|index| { +# let v = states +# .iter() +# .map(|array| ScalarValue::try_from_array(array, index)) +# .collect::<Result<Vec<_>>>()?; +# if let (ScalarValue::Float64(Some(prod)), ScalarValue::UInt32(Some(n))) = (&v[0], &v[1]) +# { +# self.prod *= prod; +# self.n += n; +# } else { +# unreachable!("") +# } +# Ok(()) +# }) +# } +# +# fn size(&self) -> usize { +# std::mem::size_of_val(self) +# } +# } -ctx.register_udaf(geometric_mean); -``` +use datafusion::logical_expr::{Volatility, create_udaf}; +use datafusion::arrow::datatypes::DataType; +use std::sync::Arc; +use datafusion::execution::context::SessionContext; +use 
datafusion::datasource::file_format::options::CsvReadOptions; -Then, we can query like below: +#[tokio::main] +async fn main() -> Result<()> { + let geometric_mean = create_udaf( + "geo_mean", + vec![DataType::Float64], + Arc::new(DataType::Float64), + Volatility::Immutable, + Arc::new( | _ | Ok(Box::new(GeometricMean::new()))), + Arc::new(vec![DataType::Float64, DataType::UInt32]), + ); + + // That gives us an `AggregateUDF` that we can register with the `SessionContext`: + use datafusion::execution::context::SessionContext; + + let ctx = SessionContext::new(); + ctx.register_udaf(geometric_mean); + + // register csv table first + let csv_path = "../../datafusion/core/tests/data/cars.csv".to_string(); + ctx.register_csv("cars", &csv_path, CsvReadOptions::default().has_header(true)).await?; + + // Then, we can query like below: + let df = ctx.sql("SELECT geo_mean(speed) FROM cars").await?; + Ok(()) +} -```rust -let df = ctx.sql("SELECT geo_mean(a) FROM t").await?; ``` +[`aggregateudf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/struct.AggregateUDF.html +[`create_udaf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/fn.create_udaf.html +[`advanced_udaf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udaf.rs + ## Adding a User-Defined Table Function A User-Defined Table Function (UDTF) is a function that takes parameters and returns a `TableProvider`. @@ -592,12 +1061,17 @@ In the `call` method, you parse the input `Expr`s and return a `TableProvider`. validation of the input `Expr`s, e.g. checking that the number of arguments is correct. 
```rust -use datafusion::common::plan_err; -use datafusion::datasource::function::TableFunctionImpl; -// Other imports here +use std::sync::Arc; +use datafusion::common::{plan_err, ScalarValue, Result}; +use datafusion::catalog::{TableFunctionImpl, TableProvider}; +use datafusion::arrow::array::{ArrayRef, Int64Array}; +use datafusion::datasource::memory::MemTable; +use arrow::record_batch::RecordBatch; +use arrow::datatypes::{DataType, Field, Schema}; +use datafusion_expr::Expr; /// A table function that returns a table provider with the value as a single column -#[derive(Default)] +#[derive(Debug, Default)] pub struct EchoFunction {} impl TableFunctionImpl for EchoFunction { @@ -628,22 +1102,57 @@ impl TableFunctionImpl for EchoFunction { With the UDTF implemented, you can register it with the `SessionContext`: ```rust +# use std::sync::Arc; +# use datafusion::common::{plan_err, ScalarValue, Result}; +# use datafusion::catalog::{TableFunctionImpl, TableProvider}; +# use datafusion::arrow::array::{ArrayRef, Int64Array}; +# use datafusion::datasource::memory::MemTable; +# use arrow::record_batch::RecordBatch; +# use arrow::datatypes::{DataType, Field, Schema}; +# use datafusion_expr::Expr; +# +# /// A table function that returns a table provider with the value as a single column +# #[derive(Debug, Default)] +# pub struct EchoFunction {} +# +# impl TableFunctionImpl for EchoFunction { +# fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> { +# let Some(Expr::Literal(ScalarValue::Int64(Some(value)))) = exprs.get(0) else { +# return plan_err!("First argument must be an integer"); +# }; +# +# // Create the schema for the table +# let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)])); +# +# // Create a single RecordBatch with the value as a single column +# let batch = RecordBatch::try_new( +# schema.clone(), +# vec![Arc::new(Int64Array::from(vec![*value]))], +# )?; +# +# // Create a MemTable plan that returns the RecordBatch +# let provider = 
MemTable::try_new(schema, vec![vec![batch]])?; +# +# Ok(Arc::new(provider)) +# } +# } + use datafusion::execution::context::SessionContext; +use datafusion::arrow::util::pretty; -let ctx = SessionContext::new(); +#[tokio::main] +async fn main() -> Result<()> { + let ctx = SessionContext::new(); -ctx.register_udtf("echo", Arc::new(EchoFunction::default ())); -``` + ctx.register_udtf("echo", Arc::new(EchoFunction::default())); -And if all goes well, you can use it in your query: + // And if all goes well, you can use it in your query: -```rust -use datafusion::arrow::util::pretty; - -let df = ctx.sql("SELECT * FROM echo(1)").await?; + let results = ctx.sql("SELECT * FROM echo(1)").await?.collect().await?; + pretty::print_batches(&results)?; + Ok(()) +} -let results = df.collect().await?; -pretty::print_batches( & results) ?; // +---+ // | a | // +---+