Skip to content

Commit bfffdba

Browse files
authored
Implement serialization for UDWF and UDAF in plan protobuf (#6769)
* update
* update udwf test
* update
1 parent 2f78536 commit bfffdba

File tree

6 files changed

+179
-15
lines changed

6 files changed

+179
-15
lines changed

datafusion/proto/proto/datafusion.proto

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -640,7 +640,8 @@ message WindowExprNode {
640640
oneof window_function {
641641
AggregateFunction aggr_function = 1;
642642
BuiltInWindowFunction built_in_function = 2;
643-
// udaf = 3
643+
string udaf = 3;
644+
string udwf = 9;
644645
}
645646
LogicalExprNode expr = 4;
646647
repeated LogicalExprNode partition_by = 5;

datafusion/proto/src/generated/pbjson.rs

Lines changed: 24 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion/proto/src/generated/prost.rs

Lines changed: 5 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion/proto/src/logical_plan/from_proto.rs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -998,6 +998,36 @@ pub fn parse_expr(
998998
window_frame,
999999
)))
10001000
}
1001+
window_expr_node::WindowFunction::Udaf(udaf_name) => {
1002+
let udaf_function = registry.udaf(udaf_name)?;
1003+
let args = parse_optional_expr(expr.expr.as_deref(), registry)?
1004+
.map(|e| vec![e])
1005+
.unwrap_or_else(Vec::new);
1006+
Ok(Expr::WindowFunction(WindowFunction::new(
1007+
datafusion_expr::window_function::WindowFunction::AggregateUDF(
1008+
udaf_function,
1009+
),
1010+
args,
1011+
partition_by,
1012+
order_by,
1013+
window_frame,
1014+
)))
1015+
}
1016+
window_expr_node::WindowFunction::Udwf(udwf_name) => {
1017+
let udwf_function = registry.udwf(udwf_name)?;
1018+
let args = parse_optional_expr(expr.expr.as_deref(), registry)?
1019+
.map(|e| vec![e])
1020+
.unwrap_or_else(Vec::new);
1021+
Ok(Expr::WindowFunction(WindowFunction::new(
1022+
datafusion_expr::window_function::WindowFunction::WindowUDF(
1023+
udwf_function,
1024+
),
1025+
args,
1026+
partition_by,
1027+
order_by,
1028+
window_frame,
1029+
)))
1030+
}
10011031
}
10021032
}
10031033
ExprType::AggregateExpr(expr) => {

datafusion/proto/src/logical_plan/mod.rs

Lines changed: 110 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1460,7 +1460,8 @@ mod roundtrip_tests {
14601460
Expr, LogicalPlan, Operator, TryCast, Volatility,
14611461
};
14621462
use datafusion_expr::{
1463-
create_udaf, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunction,
1463+
create_udaf, PartitionEvaluator, Signature, WindowFrame, WindowFrameBound,
1464+
WindowFrameUnits, WindowFunction, WindowUDF,
14641465
};
14651466
use prost::Message;
14661467
use std::collections::HashMap;
@@ -2786,12 +2787,119 @@ mod roundtrip_tests {
27862787
vec![col("col1")],
27872788
vec![col("col1")],
27882789
vec![col("col2")],
2790+
row_number_frame.clone(),
2791+
));
2792+
2793+
// 5. test with AggregateUDF
2794+
#[derive(Debug)]
2795+
struct DummyAggr {}
2796+
2797+
impl Accumulator for DummyAggr {
2798+
fn state(&self) -> datafusion::error::Result<Vec<ScalarValue>> {
2799+
Ok(vec![])
2800+
}
2801+
2802+
fn update_batch(
2803+
&mut self,
2804+
_values: &[ArrayRef],
2805+
) -> datafusion::error::Result<()> {
2806+
Ok(())
2807+
}
2808+
2809+
fn merge_batch(
2810+
&mut self,
2811+
_states: &[ArrayRef],
2812+
) -> datafusion::error::Result<()> {
2813+
Ok(())
2814+
}
2815+
2816+
fn evaluate(&self) -> datafusion::error::Result<ScalarValue> {
2817+
Ok(ScalarValue::Float64(None))
2818+
}
2819+
2820+
fn size(&self) -> usize {
2821+
std::mem::size_of_val(self)
2822+
}
2823+
}
2824+
2825+
let dummy_agg = create_udaf(
2826+
// the name; used to represent it in plan descriptions and in the registry, to use in SQL.
2827+
"dummy_agg",
2828+
// the input type; DataFusion guarantees that the first entry of `values` in `update` has this type.
2829+
DataType::Float64,
2830+
// the return type; DataFusion expects this to match the type returned by `evaluate`.
2831+
Arc::new(DataType::Float64),
2832+
Volatility::Immutable,
2833+
// This is the accumulator factory; DataFusion uses it to create new accumulators.
2834+
Arc::new(|_| Ok(Box::new(DummyAggr {}))),
2835+
// This is the description of the state. `state()` must match the types here.
2836+
Arc::new(vec![DataType::Float64, DataType::UInt32]),
2837+
);
2838+
2839+
let test_expr5 = Expr::WindowFunction(expr::WindowFunction::new(
2840+
WindowFunction::AggregateUDF(Arc::new(dummy_agg.clone())),
2841+
vec![col("col1")],
2842+
vec![col("col1")],
2843+
vec![col("col2")],
2844+
row_number_frame.clone(),
2845+
));
2846+
ctx.register_udaf(dummy_agg);
2847+
2848+
// 6. test with WindowUDF
2849+
#[derive(Clone, Debug)]
2850+
struct DummyWindow {}
2851+
2852+
impl PartitionEvaluator for DummyWindow {
2853+
fn uses_window_frame(&self) -> bool {
2854+
true
2855+
}
2856+
2857+
fn evaluate(
2858+
&mut self,
2859+
_values: &[ArrayRef],
2860+
_range: &std::ops::Range<usize>,
2861+
) -> Result<ScalarValue> {
2862+
Ok(ScalarValue::Float64(None))
2863+
}
2864+
}
2865+
2866+
fn return_type(arg_types: &[DataType]) -> Result<Arc<DataType>> {
2867+
if arg_types.len() != 1 {
2868+
return Err(DataFusionError::Plan(format!(
2869+
"dummy_udwf expects 1 argument, got {}: {:?}",
2870+
arg_types.len(),
2871+
arg_types
2872+
)));
2873+
}
2874+
Ok(Arc::new(arg_types[0].clone()))
2875+
}
2876+
2877+
fn make_partition_evaluator() -> Result<Box<dyn PartitionEvaluator>> {
2878+
Ok(Box::new(DummyWindow {}))
2879+
}
2880+
2881+
let dummy_window_udf = WindowUDF {
2882+
name: String::from("dummy_udwf"),
2883+
signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable),
2884+
return_type: Arc::new(return_type),
2885+
partition_evaluator_factory: Arc::new(make_partition_evaluator),
2886+
};
2887+
2888+
let test_expr6 = Expr::WindowFunction(expr::WindowFunction::new(
2889+
WindowFunction::WindowUDF(Arc::new(dummy_window_udf.clone())),
2890+
vec![col("col1")],
2891+
vec![col("col1")],
2892+
vec![col("col2")],
27892893
row_number_frame,
27902894
));
27912895

2896+
ctx.register_udwf(dummy_window_udf);
2897+
27922898
roundtrip_expr_test(test_expr1, ctx.clone());
27932899
roundtrip_expr_test(test_expr2, ctx.clone());
27942900
roundtrip_expr_test(test_expr3, ctx.clone());
2795-
roundtrip_expr_test(test_expr4, ctx);
2901+
roundtrip_expr_test(test_expr4, ctx.clone());
2902+
roundtrip_expr_test(test_expr5, ctx.clone());
2903+
roundtrip_expr_test(test_expr6, ctx);
27962904
}
27972905
}

datafusion/proto/src/logical_plan/to_proto.rs

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -584,17 +584,15 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode {
584584
protobuf::BuiltInWindowFunction::from(fun).into(),
585585
)
586586
}
587-
// TODO: Tracked in https://github.com/apache/arrow-datafusion/issues/4584
588-
WindowFunction::AggregateUDF(_) => {
589-
return Err(Error::NotImplemented(
590-
"UDAF as window function in proto".to_string(),
591-
))
587+
WindowFunction::AggregateUDF(aggr_udf) => {
588+
protobuf::window_expr_node::WindowFunction::Udaf(
589+
aggr_udf.name.clone(),
590+
)
592591
}
593-
// TODO: Tracked in https://github.com/apache/arrow-datafusion/issues/6733
594-
WindowFunction::WindowUDF(_) => {
595-
return Err(Error::NotImplemented(
596-
"UDWF as window function in proto".to_string(),
597-
))
592+
WindowFunction::WindowUDF(window_udf) => {
593+
protobuf::window_expr_node::WindowFunction::Udwf(
594+
window_udf.name.clone(),
595+
)
598596
}
599597
};
600598
let arg_expr: Option<Box<Self>> = if !args.is_empty() {

0 commit comments

Comments
 (0)