diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 7fddd31a6ba5..7cf3a517b31d 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -639,7 +639,8 @@ message WindowExprNode { oneof window_function { AggregateFunction aggr_function = 1; BuiltInWindowFunction built_in_function = 2; - // udaf = 3 + string udaf = 3; + string udwf = 9; } LogicalExprNode expr = 4; repeated LogicalExprNode partition_by = 5; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 4c1bab5e397c..fa113a57a373 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -22190,6 +22190,12 @@ impl serde::Serialize for WindowExprNode { .ok_or_else(|| serde::ser::Error::custom(format!("Invalid variant {}", *v)))?; struct_ser.serialize_field("builtInFunction", &v)?; } + window_expr_node::WindowFunction::Udaf(v) => { + struct_ser.serialize_field("udaf", v)?; + } + window_expr_node::WindowFunction::Udwf(v) => { + struct_ser.serialize_field("udwf", v)?; + } } } struct_ser.end() @@ -22213,6 +22219,8 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode { "aggrFunction", "built_in_function", "builtInFunction", + "udaf", + "udwf", ]; #[allow(clippy::enum_variant_names)] @@ -22223,6 +22231,8 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode { WindowFrame, AggrFunction, BuiltInFunction, + Udaf, + Udwf, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -22250,6 +22260,8 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode { "windowFrame" | "window_frame" => Ok(GeneratedField::WindowFrame), "aggrFunction" | "aggr_function" => Ok(GeneratedField::AggrFunction), "builtInFunction" | "built_in_function" => Ok(GeneratedField::BuiltInFunction), + "udaf" => Ok(GeneratedField::Udaf), + "udwf" => Ok(GeneratedField::Udwf), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -22312,6 +22324,18 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode { } window_function__ = map.next_value::<::std::option::Option>()?.map(|x| window_expr_node::WindowFunction::BuiltInFunction(x as i32)); } + GeneratedField::Udaf => { + if window_function__.is_some() { + return Err(serde::de::Error::duplicate_field("udaf")); + } + window_function__ = map.next_value::<::std::option::Option<_>>()?.map(window_expr_node::WindowFunction::Udaf); + } + GeneratedField::Udwf => { + if window_function__.is_some() { + return Err(serde::de::Error::duplicate_field("udwf")); + } + window_function__ = map.next_value::<::std::option::Option<_>>()?.map(window_expr_node::WindowFunction::Udwf); + } } } Ok(WindowExprNode { diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 8dfc209477ef..312906cc89f6 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -747,7 +747,7 @@ pub struct WindowExprNode { /// repeated LogicalExprNode filter = 7; #[prost(message, optional, tag = "8")] pub window_frame: ::core::option::Option, - #[prost(oneof = "window_expr_node::WindowFunction", tags = "1, 2")] + #[prost(oneof = "window_expr_node::WindowFunction", tags = "1, 2, 3, 9")] pub window_function: ::core::option::Option, } /// Nested message and enum types in `WindowExprNode`. @@ -757,9 +757,12 @@ pub mod window_expr_node { pub enum WindowFunction { #[prost(enumeration = "super::AggregateFunction", tag = "1")] AggrFunction(i32), - /// udaf = 3 #[prost(enumeration = "super::BuiltInWindowFunction", tag = "2")] BuiltInFunction(i32), + #[prost(string, tag = "3")] + Udaf(::prost::alloc::string::String), + #[prost(string, tag = "9")] + Udwf(::prost::alloc::string::String), } } #[allow(clippy::derive_partial_eq_without_eq)] diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index ab2985f448a8..5cd7d78f835c 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -997,6 +997,36 @@ pub fn parse_expr( window_frame, ))) } + window_expr_node::WindowFunction::Udaf(udaf_name) => { + let udaf_function = registry.udaf(udaf_name)?; + let args = parse_optional_expr(expr.expr.as_deref(), registry)? + .map(|e| vec![e]) + .unwrap_or_else(Vec::new); + Ok(Expr::WindowFunction(WindowFunction::new( + datafusion_expr::window_function::WindowFunction::AggregateUDF( + udaf_function, + ), + args, + partition_by, + order_by, + window_frame, + ))) + } + window_expr_node::WindowFunction::Udwf(udwf_name) => { + let udwf_function = registry.udwf(udwf_name)?; + let args = parse_optional_expr(expr.expr.as_deref(), registry)? + .map(|e| vec![e]) + .unwrap_or_else(Vec::new); + Ok(Expr::WindowFunction(WindowFunction::new( + datafusion_expr::window_function::WindowFunction::WindowUDF( + udwf_function, + ), + args, + partition_by, + order_by, + window_frame, + ))) + } } } ExprType::AggregateExpr(expr) => { diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index 51f60dcd5fd2..7d0ddac48493 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -1460,7 +1460,8 @@ mod roundtrip_tests { Expr, LogicalPlan, Operator, TryCast, Volatility, }; use datafusion_expr::{ - create_udaf, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunction, + create_udaf, PartitionEvaluator, Signature, WindowFrame, WindowFrameBound, + WindowFrameUnits, WindowFunction, WindowUDF, }; use prost::Message; use std::collections::HashMap; @@ -2786,12 +2787,119 @@ mod roundtrip_tests { vec![col("col1")], vec![col("col1")], vec![col("col2")], + row_number_frame.clone(), + )); + + // 5. test with AggregateUDF + #[derive(Debug)] + struct DummyAggr {} + + impl Accumulator for DummyAggr { + fn state(&self) -> datafusion::error::Result> { + Ok(vec![]) + } + + fn update_batch( + &mut self, + _values: &[ArrayRef], + ) -> datafusion::error::Result<()> { + Ok(()) + } + + fn merge_batch( + &mut self, + _states: &[ArrayRef], + ) -> datafusion::error::Result<()> { + Ok(()) + } + + fn evaluate(&self) -> datafusion::error::Result { + Ok(ScalarValue::Float64(None)) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + } + } + + let dummy_agg = create_udaf( + // the name; used to represent it in plan descriptions and in the registry, to use in SQL. + "dummy_agg", + // the input type; DataFusion guarantees that the first entry of `values` in `update` has this type. + DataType::Float64, + // the return type; DataFusion expects this to match the type returned by `evaluate`. + Arc::new(DataType::Float64), + Volatility::Immutable, + // This is the accumulator factory; DataFusion uses it to create new accumulators. + Arc::new(|_| Ok(Box::new(DummyAggr {}))), + // This is the description of the state. `state()` must match the types here. + Arc::new(vec![DataType::Float64, DataType::UInt32]), + ); + + let test_expr5 = Expr::WindowFunction(expr::WindowFunction::new( + WindowFunction::AggregateUDF(Arc::new(dummy_agg.clone())), + vec![col("col1")], + vec![col("col1")], + vec![col("col2")], + row_number_frame.clone(), + )); + ctx.register_udaf(dummy_agg); + + // 6. test with WindowUDF + #[derive(Clone, Debug)] + struct DummyWindow {} + + impl PartitionEvaluator for DummyWindow { + fn uses_window_frame(&self) -> bool { + true + } + + fn evaluate( + &mut self, + _values: &[ArrayRef], + _range: &std::ops::Range, + ) -> Result { + Ok(ScalarValue::Float64(None)) + } + } + + fn return_type(arg_types: &[DataType]) -> Result> { + if arg_types.len() != 1 { + return Err(DataFusionError::Plan(format!( + "dummy_udwf expects 1 argument, got {}: {:?}", + arg_types.len(), + arg_types + ))); + } + Ok(Arc::new(arg_types[0].clone())) + } + + fn make_partition_evaluator() -> Result> { + Ok(Box::new(DummyWindow {})) + } + + let dummy_window_udf = WindowUDF { + name: String::from("dummy_udwf"), + signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable), + return_type: Arc::new(return_type), + partition_evaluator_factory: Arc::new(make_partition_evaluator), + }; + + let test_expr6 = Expr::WindowFunction(expr::WindowFunction::new( + WindowFunction::WindowUDF(Arc::new(dummy_window_udf.clone())), + vec![col("col1")], + vec![col("col1")], + vec![col("col2")], row_number_frame, )); + ctx.register_udwf(dummy_window_udf); + roundtrip_expr_test(test_expr1, ctx.clone()); roundtrip_expr_test(test_expr2, ctx.clone()); roundtrip_expr_test(test_expr3, ctx.clone()); - roundtrip_expr_test(test_expr4, ctx); + roundtrip_expr_test(test_expr4, ctx.clone()); + roundtrip_expr_test(test_expr5, ctx.clone()); + roundtrip_expr_test(test_expr6, ctx); } } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index b3e5bd0fa6c2..1bddeaeb6724 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -584,17 +584,15 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { protobuf::BuiltInWindowFunction::from(fun).into(), ) } - // TODO: Tracked in https://github.com/apache/arrow-datafusion/issues/4584 - WindowFunction::AggregateUDF(_) => { - return Err(Error::NotImplemented( - "UDAF as window function in proto".to_string(), - )) + WindowFunction::AggregateUDF(aggr_udf) => { + protobuf::window_expr_node::WindowFunction::Udaf( + aggr_udf.name.clone(), + ) } - // TODO: Tracked in https://github.com/apache/arrow-datafusion/issues/6733 - WindowFunction::WindowUDF(_) => { - return Err(Error::NotImplemented( - "UDWF as window function in proto".to_string(), - )) + WindowFunction::WindowUDF(window_udf) => { + protobuf::window_expr_node::WindowFunction::Udwf( + window_udf.name.clone(), + ) } }; let arg_expr: Option> = if !args.is_empty() {