-
Notifications
You must be signed in to change notification settings - Fork 1.5k
User Defined Table Function (udtf) support #2177
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
d8d91a2
a16cdee
d84ef1c
64ee981
8d13284
814aa96
9107b6b
58eb2cc
d18807f
75c070b
0b7a4a6
86adb81
c72bef0
c548400
a9c1b9e
e979728
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -84,6 +84,7 @@ use crate::physical_plan::file_format::{plan_to_csv, plan_to_json, plan_to_parqu | |
use crate::physical_plan::planner::DefaultPhysicalPlanner; | ||
use crate::physical_plan::udaf::AggregateUDF; | ||
use crate::physical_plan::udf::ScalarUDF; | ||
use crate::physical_plan::udtf::TableUDF; | ||
use crate::physical_plan::ExecutionPlan; | ||
use crate::physical_plan::PhysicalPlanner; | ||
use crate::sql::{ | ||
|
@@ -452,6 +453,20 @@ impl SessionContext { | |
.insert(f.name.clone(), Arc::new(f)); | ||
} | ||
|
||
/// Registers a table UDF within this context. | ||
/// | ||
/// Note in SQL queries, function names are looked up using | ||
/// lowercase unless the query uses quotes. For example, | ||
/// | ||
/// `SELECT MY_FUNC(x)...` will look for a function named `"my_func"` | ||
/// `SELECT "my_FUNC"(x)` will look for a function named `"my_FUNC"` | ||
pub fn register_udtf(&mut self, f: TableUDF) { | ||
self.state | ||
.write() | ||
.table_functions | ||
.insert(f.name.clone(), Arc::new(f)); | ||
} | ||
|
||
/// Registers an aggregate UDF within this context. | ||
/// | ||
/// Note in SQL queries, aggregate names are looked up using | ||
|
@@ -1142,6 +1157,8 @@ pub struct SessionState { | |
pub catalog_list: Arc<dyn CatalogList>, | ||
/// Scalar functions that are registered with the context | ||
pub scalar_functions: HashMap<String, Arc<ScalarUDF>>, | ||
/// Table functions that are registered with the context | ||
pub table_functions: HashMap<String, Arc<TableUDF>>, | ||
/// Aggregate functions registered in the context | ||
pub aggregate_functions: HashMap<String, Arc<AggregateUDF>>, | ||
/// Session configuration | ||
|
@@ -1225,6 +1242,7 @@ impl SessionState { | |
query_planner: Arc::new(DefaultQueryPlanner {}), | ||
catalog_list, | ||
scalar_functions: HashMap::new(), | ||
table_functions: HashMap::new(), | ||
aggregate_functions: HashMap::new(), | ||
config, | ||
execution_props: ExecutionProps::new(), | ||
|
@@ -1385,6 +1403,12 @@ impl ContextProvider for SessionState { | |
self.aggregate_functions.get(name).cloned() | ||
} | ||
|
||
fn get_table_function_meta(&self, name: &str) -> Option<Arc<TableUDF>> { | ||
self.table_functions | ||
.get(&name.to_ascii_lowercase()) | ||
.cloned() | ||
} | ||
|
||
fn get_variable_type(&self, variable_names: &[String]) -> Option<DataType> { | ||
if variable_names.is_empty() { | ||
return None; | ||
|
@@ -1603,19 +1627,21 @@ mod tests { | |
use super::*; | ||
use crate::execution::context::QueryPlanner; | ||
use crate::logical_plan::{binary_expr, lit, Operator}; | ||
use crate::physical_plan::functions::make_scalar_function; | ||
use crate::physical_plan::functions::{make_scalar_function, make_table_function}; | ||
use crate::test; | ||
use crate::variable::VarType; | ||
use crate::{ | ||
assert_batches_eq, assert_batches_sorted_eq, | ||
logical_plan::{col, create_udf, sum, Expr}, | ||
logical_plan::{col, create_udf, create_udtf, sum, Expr}, | ||
}; | ||
use crate::{logical_plan::create_udaf, physical_plan::expressions::AvgAccumulator}; | ||
use arrow::array::ArrayRef; | ||
use arrow::array::{ | ||
ArrayBuilder, ArrayRef, Int64Array, Int64Builder, StringBuilder, StructBuilder, | ||
}; | ||
use arrow::datatypes::*; | ||
use arrow::record_batch::RecordBatch; | ||
use async_trait::async_trait; | ||
use datafusion_expr::Volatility; | ||
use datafusion_expr::{TableFunctionImplementation, Volatility}; | ||
use std::fs::File; | ||
use std::sync::Weak; | ||
use std::thread::{self, JoinHandle}; | ||
|
@@ -2115,6 +2141,195 @@ mod tests { | |
Ok(()) | ||
} | ||
|
||
#[tokio::test] | ||
async fn user_defined_table_function() -> Result<()> { | ||
let mut ctx = SessionContext::new(); | ||
|
||
let integer_series = integer_udtf(); | ||
ctx.register_udtf(create_udtf( | ||
"integer_series", | ||
vec![DataType::Int64, DataType::Int64], | ||
Arc::new(DataType::Int64), | ||
Volatility::Immutable, | ||
integer_series, | ||
)); | ||
|
||
let struct_func = struct_udtf(); | ||
ctx.register_udtf(create_udtf( | ||
"struct_func", | ||
vec![DataType::Int64], | ||
Arc::new(DataType::Struct( | ||
[ | ||
Field::new("f1", DataType::Utf8, false), | ||
Field::new("f2", DataType::Int64, false), | ||
] | ||
.to_vec(), | ||
)), | ||
Volatility::Immutable, | ||
struct_func, | ||
)); | ||
|
||
let result = plan_and_collect(&ctx, "SELECT struct_func(5)").await?; | ||
|
||
let expected = vec![ | ||
"+-------------------------+", | ||
"| struct_func(Int64(5)) |", | ||
"+-------------------------+", | ||
"| {\"f1\": \"test\", \"f2\": 5} |", | ||
"+-------------------------+", | ||
]; | ||
|
||
assert_batches_eq!(expected, &result); | ||
|
||
let result = plan_and_collect(&ctx, "SELECT integer_series(6,5)").await?; | ||
|
||
let expected = vec![ | ||
"+-----------------------------------+", | ||
"| integer_series(Int64(6),Int64(5)) |", | ||
"+-----------------------------------+", | ||
"+-----------------------------------+", | ||
]; | ||
|
||
assert_batches_eq!(expected, &result); | ||
|
||
let result = plan_and_collect(&ctx, "SELECT integer_series(1,5)").await?; | ||
|
||
let expected = vec![ | ||
"+-----------------------------------+", | ||
"| integer_series(Int64(1),Int64(5)) |", | ||
"+-----------------------------------+", | ||
"| 1 |", | ||
"| 2 |", | ||
"| 3 |", | ||
"| 4 |", | ||
"| 5 |", | ||
"+-----------------------------------+", | ||
]; | ||
Comment on lines
+2195
to
+2207
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is a good example of a […]. Would it be possible to write an example that also produces a different number of columns than went in? I think that is what @Ted-Jiang and I are pointing out in our comments below. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I didn't add support for that. You can use structs for that. |
||
|
||
assert_batches_eq!(expected, &result); | ||
|
||
let result = plan_and_collect( | ||
&ctx, | ||
"SELECT asd, struct_func(qwe), integer_series(asd, qwe), integer_series(1, qwe) r FROM (select 1 asd, 3 qwe UNION ALL select 2 asd, 4 qwe) x", | ||
) | ||
.await?; | ||
|
||
let expected = vec![ | ||
"+-----+-------------------------+-----------------------------+---+", | ||
"| asd | struct_func(x.qwe) | integer_series(x.asd,x.qwe) | r |", | ||
"+-----+-------------------------+-----------------------------+---+", | ||
"| 1 | {\"f1\": \"test\", \"f2\": 3} | 1 | 1 |", | ||
"| 1 | | 2 | 2 |", | ||
"| 1 | | 3 | 3 |", | ||
"| 2 | {\"f1\": \"test\", \"f2\": 4} | 2 | 1 |", | ||
"| 2 | | 3 | 2 |", | ||
"| 2 | | 4 | 3 |", | ||
"| 2 | | | 4 |", | ||
"+-----+-------------------------+-----------------------------+---+", | ||
]; | ||
|
||
assert_batches_eq!(expected, &result); | ||
|
||
let result = | ||
plan_and_collect(&ctx, "SELECT * from integer_series(1,5) pos(n)").await?; | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can you explain what this test is supposed to be demonstrating? I am not quite sure what it shows. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I have just explained it in the PR description. I hope I made it clear enough :) |
||
|
||
let expected = vec![ | ||
"+---+", "| n |", "+---+", "| 1 |", "| 2 |", "| 3 |", "| 4 |", "| 5 |", | ||
"+---+", | ||
]; | ||
|
||
assert_batches_eq!(expected, &result); | ||
|
||
let result = plan_and_collect(&ctx, "SELECT * from integer_series(1,5)").await?; | ||
|
||
let expected = vec![ | ||
"+-----------------------------------+", | ||
"| integer_series(Int64(1),Int64(5)) |", | ||
"+-----------------------------------+", | ||
"| 1 |", | ||
"| 2 |", | ||
"| 3 |", | ||
"| 4 |", | ||
"| 5 |", | ||
"+-----------------------------------+", | ||
]; | ||
|
||
assert_batches_eq!(expected, &result); | ||
|
||
Ok(()) | ||
} | ||
|
||
fn integer_udtf() -> TableFunctionImplementation { | ||
make_table_function(move |args: &[ArrayRef]| { | ||
assert!(args.len() == 2); | ||
|
||
let start_arr = &args[0] | ||
.as_any() | ||
.downcast_ref::<Int64Array>() | ||
.expect("cast failed"); | ||
|
||
let end_arr = &args[1] | ||
.as_any() | ||
.downcast_ref::<Int64Array>() | ||
.expect("cast failed"); | ||
|
||
let mut batch_sizes: Vec<usize> = Vec::new(); | ||
let mut builder = Int64Builder::new(1); | ||
|
||
for (start, end) in start_arr.iter().zip(end_arr.iter()) { | ||
let start_number = start.unwrap(); | ||
let end_number = end.unwrap(); | ||
let count: usize = if end_number < start_number { | ||
0 | ||
} else { | ||
(end_number - start_number + 1).try_into().unwrap() | ||
}; | ||
batch_sizes.push(count); | ||
|
||
for i in start_number..end_number + 1 { | ||
builder.append_value(i).unwrap(); | ||
} | ||
} | ||
|
||
Ok((Arc::new(builder.finish()) as ArrayRef, batch_sizes)) | ||
}) | ||
} | ||
|
||
fn struct_udtf() -> TableFunctionImplementation { | ||
make_table_function(move |args: &[ArrayRef]| { | ||
let start_arr = &args[0] | ||
.as_any() | ||
.downcast_ref::<Int64Array>() | ||
.expect("cast failed"); | ||
|
||
let mut string_builder = StringBuilder::new(1); | ||
let mut int_builder = Int64Builder::new(1); | ||
|
||
let mut batch_sizes: Vec<usize> = Vec::new(); | ||
for start in start_arr.iter() { | ||
let start_number = start.unwrap(); | ||
batch_sizes.push(1); | ||
|
||
string_builder.append_value("test").unwrap(); | ||
int_builder.append_value(start_number).unwrap(); | ||
} | ||
|
||
let mut fields = Vec::new(); | ||
let mut field_builders = Vec::new(); | ||
fields.push(Field::new("f1", DataType::Utf8, false)); | ||
field_builders.push(Box::new(string_builder) as Box<dyn ArrayBuilder>); | ||
fields.push(Field::new("f2", DataType::Int64, false)); | ||
field_builders.push(Box::new(int_builder) as Box<dyn ArrayBuilder>); | ||
|
||
let mut builder = StructBuilder::new(fields, field_builders); | ||
for _start in start_arr.iter() { | ||
builder.append(true).unwrap(); | ||
} | ||
|
||
Ok((Arc::new(builder.finish()) as ArrayRef, batch_sizes)) | ||
}) | ||
} | ||
|
||
struct MyPhysicalPlanner {} | ||
|
||
#[async_trait] | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I recommend writing it like this:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point; however, it doesn't work in this case, because the argument `fun` has different types for TableUDF and ScalarUDF.