Skip to content

Commit 6f87009

Browse files
committed
Support User Defined Window Functions
1 parent 11a5a63 commit 6f87009

File tree

20 files changed

+1093
-19
lines changed

20 files changed

+1093
-19
lines changed

datafusion-examples/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ prost = { version = "0.11", default-features = false }
5656
prost-derive = { version = "0.11", default-features = false }
5757
serde = { version = "1.0.136", features = ["derive"] }
5858
serde_json = "1.0.82"
59+
tempfile = "3"
5960
tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync", "parking_lot"] }
6061
tonic = "0.9"
6162
url = "2.2"

datafusion-examples/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ cargo run --example csv_sql
5757
- [`rewrite_expr.rs`](examples/rewrite_expr.rs): Define and invoke a custom Query Optimizer pass
5858
- [`simple_udaf.rs`](examples/simple_udaf.rs): Define and invoke a User Defined Aggregate Function (UDAF)
5959
- [`simple_udf.rs`](examples/simple_udf.rs): Define and invoke a User Defined (scalar) Function (UDF)
60+
- [`simple_udwf.rs`](examples/simple_udwf.rs): Define and invoke a User Defined Window Function (UDWF)
6061

6162
## Distributed
6263

Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use std::sync::Arc;
19+
20+
use arrow::{
21+
array::{ArrayRef, AsArray, Float64Array},
22+
datatypes::Float64Type,
23+
};
24+
use arrow_schema::DataType;
25+
use datafusion::datasource::file_format::options::CsvReadOptions;
26+
27+
use datafusion::error::Result;
28+
use datafusion::prelude::*;
29+
use datafusion_common::{DataFusionError, ScalarValue};
30+
use datafusion_expr::{PartitionEvaluator, Signature, Volatility, WindowUDF};
31+
32+
// create local execution context with `cars.csv` registered as a table named `cars`
33+
async fn create_context() -> Result<SessionContext> {
34+
// declare a new context. In spark API, this corresponds to a new spark SQLsession
35+
let ctx = SessionContext::new();
36+
37+
// declare a table in memory. In spark API, this corresponds to createDataFrame(...).
38+
println!("pwd: {}", std::env::current_dir().unwrap().display());
39+
let csv_path = format!("datafusion/core/tests/data/cars.csv");
40+
let read_options = CsvReadOptions::default().has_header(true);
41+
42+
ctx.register_csv("cars", &csv_path, read_options).await?;
43+
Ok(ctx)
44+
}
45+
46+
/// In this example we declare a user defined window function (UDWF) that computes a moving average and then run it using SQL.
47+
#[tokio::main]
48+
async fn main() -> Result<()> {
49+
let ctx = create_context().await?;
50+
51+
// register the window function with DataFusion so we can call it
52+
ctx.register_udwf(smooth_it());
53+
54+
// Use SQL to show the raw contents of the `cars` table first
55+
let df = ctx.sql("SELECT * from cars").await?;
56+
// print the results
57+
df.show().await?;
58+
59+
// Use SQL to run the new window function:
60+
//
61+
// `PARTITION BY car`: each distinct value of car (red and green)
62+
// should be treated as a separate partition (and will result in
63+
// creating a new `PartitionEvaluator`)
64+
//
65+
// `ORDER BY time`: within each partition ('green' or 'red') the
66+
// rows will be ordered by the value in the `time` column
67+
//
68+
// `PartitionEvaluator::evaluate` is invoked with a window defined by the
69+
// SQL. In this case:
70+
//
71+
// The first invocation will be passed row 0, the first row in the
72+
// partition.
73+
//
74+
// The second invocation will be passed rows 0 and 1, the first
75+
// two rows in the partition.
76+
//
77+
// etc.
78+
let df = ctx
79+
.sql(
80+
"SELECT \
81+
car, \
82+
speed, \
83+
smooth_it(speed) OVER (PARTITION BY car ORDER BY time),\
84+
time \
85+
from cars \
86+
ORDER BY \
87+
car",
88+
)
89+
.await?;
90+
// print the results
91+
df.show().await?;
92+
93+
// this time, call the new window function with an explicit
94+
// window frame (the evaluator opts in via `uses_window_frame`).
95+
//
96+
// `ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING`: each invocation
97+
// sees at most 3 rows: the row before, the current row, and the 1
98+
// row afterward.
99+
let df = ctx.sql(
100+
"SELECT \
101+
car, \
102+
speed, \
103+
smooth_it(speed) OVER (PARTITION BY car ORDER BY time ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING),\
104+
time \
105+
from cars \
106+
ORDER BY \
107+
car",
108+
).await?;
109+
// print the results
110+
df.show().await?;
111+
112+
// TODO: show how to run the same function via the DataFrame API as well
113+
114+
Ok(())
115+
}
116+
fn smooth_it() -> WindowUDF {
117+
WindowUDF {
118+
name: String::from("smooth_it"),
119+
// it will take 1 arguments -- the column to smooth
120+
signature: Signature::exact(vec![DataType::Int32], Volatility::Immutable),
121+
return_type: Arc::new(return_type),
122+
partition_evaluator_factory: Arc::new(make_partition_evaluator),
123+
}
124+
}
125+
126+
/// Compute the return type of the smooth_it window function given
127+
/// arguments of `arg_types`.
128+
fn return_type(arg_types: &[DataType]) -> Result<Arc<DataType>> {
129+
if arg_types.len() != 1 {
130+
return Err(DataFusionError::Plan(format!(
131+
"my_udwf expects 1 argument, got {}: {:?}",
132+
arg_types.len(),
133+
arg_types
134+
)));
135+
}
136+
Ok(Arc::new(arg_types[0].clone()))
137+
}
138+
139+
/// Create a `PartitionEvalutor` to evaluate this function on a new
140+
/// partition.
141+
fn make_partition_evaluator() -> Result<Box<dyn PartitionEvaluator>> {
142+
Ok(Box::new(MyPartitionEvaluator::new()))
143+
}
144+
145+
/// The lowest level evaluation logic for the window function.
///
/// It is responsible for calculating the value of the window function for
/// each distinct value of `PARTITION BY` (each car type in our example).
/// The struct has no fields.
#[derive(Clone, Debug)]
struct MyPartitionEvaluator {}

impl MyPartitionEvaluator {
    /// Builds a new evaluator.
    fn new() -> Self {
        MyPartitionEvaluator {}
    }
}
157+
158+
/// These different evaluation methods are called depending on the various settings of WindowUDF
159+
impl PartitionEvaluator for MyPartitionEvaluator {
160+
/// Tell DataFusion the window function varies based on the value
161+
/// of the window frame.
162+
fn uses_window_frame(&self) -> bool {
163+
true
164+
}
165+
166+
/// This function is called once per input row.
167+
///
168+
/// `range`specifies which indexes of `values` should be
169+
/// considered for the calculation.
170+
///
171+
/// Note this is the SLOWEST, but simplest, way to evaluate a
172+
/// window function. It is much faster to implement
173+
/// evaluate_all or evaluate_all_with_rank, if possible
174+
fn evaluate(
175+
&mut self,
176+
values: &[ArrayRef],
177+
range: &std::ops::Range<usize>,
178+
) -> Result<ScalarValue> {
179+
println!("evaluate_inside_range(). range: {range:#?}, values: {values:#?}");
180+
181+
// Again, the input argument is an array of floating
182+
// point numbers to calculate a moving average
183+
let arr: &Float64Array = values[0].as_ref().as_primitive::<Float64Type>();
184+
185+
let range_len = range.end - range.start;
186+
187+
// our smoothing function will average all the values in the
188+
let output = if range_len > 0 {
189+
let sum: f64 = arr.values().iter().skip(range.start).take(range_len).sum();
190+
Some(sum / range_len as f64)
191+
} else {
192+
None
193+
};
194+
195+
Ok(ScalarValue::Float64(output))
196+
}
197+
}

datafusion/core/src/execution/context.rs

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ use datafusion_common::alias::AliasGenerator;
3232
use datafusion_execution::registry::SerializerRegistry;
3333
use datafusion_expr::{
3434
logical_plan::{DdlStatement, Statement},
35-
DescribeTable, StringifiedPlan, UserDefinedLogicalNode,
35+
DescribeTable, StringifiedPlan, UserDefinedLogicalNode, WindowUDF,
3636
};
3737
pub use datafusion_physical_expr::execution_props::ExecutionProps;
3838
use datafusion_physical_expr::var_provider::is_system_variables;
@@ -786,6 +786,20 @@ impl SessionContext {
786786
.insert(f.name.clone(), Arc::new(f));
787787
}
788788

789+
/// Registers an window UDF within this context.
790+
///
791+
/// Note in SQL queries, window function names are looked up using
792+
/// lowercase unless the query uses quotes. For example,
793+
///
794+
/// - `SELECT MY_UDAF(x)...` will look for an aggregate named `"my_udaf"`
795+
/// - `SELECT "my_UDAF"(x)` will look for an aggregate named `"my_UDAF"`
796+
pub fn register_udwf(&self, f: WindowUDF) {
797+
self.state
798+
.write()
799+
.window_functions
800+
.insert(f.name.clone(), Arc::new(f));
801+
}
802+
789803
/// Creates a [`DataFrame`] for reading a data source.
790804
///
791805
/// For more control such as reading multiple files, you can use
@@ -1279,6 +1293,10 @@ impl FunctionRegistry for SessionContext {
12791293
fn udaf(&self, name: &str) -> Result<Arc<AggregateUDF>> {
12801294
self.state.read().udaf(name)
12811295
}
1296+
1297+
/// Looks up the window UDF called `name` by delegating to the session state
/// while holding its read lock.
fn udwf(&self, name: &str) -> Result<Arc<WindowUDF>> {
    let state = self.state.read();
    state.udwf(name)
}
12821300
}
12831301

12841302
/// A planner used to add extensions to DataFusion logical and physical plans.
@@ -1329,6 +1347,8 @@ pub struct SessionState {
13291347
scalar_functions: HashMap<String, Arc<ScalarUDF>>,
13301348
/// Aggregate functions registered in the context
13311349
aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
1350+
/// Window functions registered in the context
1351+
window_functions: HashMap<String, Arc<WindowUDF>>,
13321352
/// Deserializer registry for extensions.
13331353
serializer_registry: Arc<dyn SerializerRegistry>,
13341354
/// Session configuration
@@ -1423,6 +1443,7 @@ impl SessionState {
14231443
catalog_list,
14241444
scalar_functions: HashMap::new(),
14251445
aggregate_functions: HashMap::new(),
1446+
window_functions: HashMap::new(),
14261447
serializer_registry: Arc::new(EmptySerializerRegistry),
14271448
config,
14281449
execution_props: ExecutionProps::new(),
@@ -1899,6 +1920,11 @@ impl SessionState {
18991920
&self.aggregate_functions
19001921
}
19011922

1923+
/// Returns a reference to the map of registered window functions, keyed by function name.
1924+
pub fn window_functions(&self) -> &HashMap<String, Arc<WindowUDF>> {
1925+
&self.window_functions
1926+
}
1927+
19021928
/// Return [SerializerRegistry] for extensions
19031929
pub fn serializer_registry(&self) -> Arc<dyn SerializerRegistry> {
19041930
self.serializer_registry.clone()
@@ -1932,6 +1958,10 @@ impl<'a> ContextProvider for SessionContextProvider<'a> {
19321958
self.state.aggregate_functions().get(name).cloned()
19331959
}
19341960

1961+
/// Returns the window UDF registered under `name` in the session state,
/// or `None` when there is no such function.
fn get_window_meta(&self, name: &str) -> Option<Arc<WindowUDF>> {
    let registered = self.state.window_functions();
    registered.get(name).map(Arc::clone)
}
1964+
19351965
fn get_variable_type(&self, variable_names: &[String]) -> Option<DataType> {
19361966
if variable_names.is_empty() {
19371967
return None;
@@ -1979,6 +2009,16 @@ impl FunctionRegistry for SessionState {
19792009
))
19802010
})
19812011
}
2012+
2013+
fn udwf(&self, name: &str) -> Result<Arc<WindowUDF>> {
2014+
let result = self.window_functions.get(name);
2015+
2016+
result.cloned().ok_or_else(|| {
2017+
DataFusionError::Plan(format!(
2018+
"There is no UDWF named \"{name}\" in the registry"
2019+
))
2020+
})
2021+
}
19822022
}
19832023

19842024
impl OptimizerConfig for SessionState {
@@ -2012,6 +2052,7 @@ impl From<&SessionState> for TaskContext {
20122052
state.config.clone(),
20132053
state.scalar_functions.clone(),
20142054
state.aggregate_functions.clone(),
2055+
state.window_functions.clone(),
20152056
state.runtime_env.clone(),
20162057
)
20172058
}

0 commit comments

Comments
 (0)