User Defined Table Function (udtf) support #2177

Closed · wants to merge 16 commits
3 changes: 3 additions & 0 deletions ballista/rust/core/src/serde/logical_plan/mod.rs
@@ -968,6 +968,9 @@ impl AsLogicalPlan for LogicalPlanNode {
LogicalPlan::DropTable(_) => Err(proto_error(
"Error converting DropTable. Not yet supported in Ballista",
)),
LogicalPlan::TableUDFs(_) => Err(proto_error(
"Error converting TableUDFs. Not yet supported in Ballista",
)),
}
}
}
1 change: 1 addition & 0 deletions datafusion/core/src/datasource/listing/helpers.rs
@@ -102,6 +102,7 @@ impl ExpressionVisitor for ApplicabilityVisitor<'_> {
Expr::ScalarUDF { fun, .. } => {
self.visit_volatility(fun.signature.volatility)
}
Expr::TableUDF { fun, .. } => self.visit_volatility(fun.signature.volatility),
Member: I recommend writing it like this:

            Expr::ScalarUDF { fun, .. } | Expr::TableUDF { fun, .. } => {
                self.visit_volatility(fun.signature.volatility)
            }

Contributor Author: Good point; however, it doesn't work in this case, because the fun argument has a different type for TableUDF than for ScalarUDF.
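A minimal sketch of the type mismatch (hypothetical stand-in types, not the real datafusion_expr definitions):

    use std::sync::Arc;

    // Stand-ins for the two UDF descriptor types.
    struct ScalarUDF;
    struct TableUDF;

    enum Expr {
        ScalarUDF { fun: Arc<ScalarUDF> },
        TableUDF { fun: Arc<TableUDF> },
    }

    fn visit(e: &Expr) {
        match e {
            // An or-pattern must bind `fun` at a single type, so these
            // two arms cannot be merged while the field types differ:
            Expr::ScalarUDF { fun } => { let _ = fun; }
            Expr::TableUDF { fun } => { let _ = fun; }
        }
    }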


// TODO other expressions are not handled yet:
// - AGGREGATE, WINDOW and SORT should not end up in filter conditions, except maybe in some edge cases
223 changes: 219 additions & 4 deletions datafusion/core/src/execution/context.rs
@@ -84,6 +84,7 @@ use crate::physical_plan::file_format::{plan_to_csv, plan_to_json, plan_to_parqu
use crate::physical_plan::planner::DefaultPhysicalPlanner;
use crate::physical_plan::udaf::AggregateUDF;
use crate::physical_plan::udf::ScalarUDF;
use crate::physical_plan::udtf::TableUDF;
use crate::physical_plan::ExecutionPlan;
use crate::physical_plan::PhysicalPlanner;
use crate::sql::{
@@ -452,6 +453,20 @@ impl SessionContext {
.insert(f.name.clone(), Arc::new(f));
}

/// Registers a table UDF within this context.
///
/// Note in SQL queries, function names are looked up using
/// lowercase unless the query uses quotes. For example,
///
/// `SELECT MY_FUNC(x)...` will look for a function named `"my_func"`
/// `SELECT "my_FUNC"(x)` will look for a function named `"my_FUNC"`
pub fn register_udtf(&mut self, f: TableUDF) {
self.state
.write()
.table_functions
.insert(f.name.clone(), Arc::new(f));
}
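For illustration, a minimal registration sketch (the names my_series and my_series_impl are hypothetical; create_udtf is the helper added in logical_plan/expr.rs below):

    let mut ctx = SessionContext::new();
    ctx.register_udtf(create_udtf(
        "my_series",
        vec![DataType::Int64, DataType::Int64],
        Arc::new(DataType::Int64),
        Volatility::Immutable,
        my_series_impl, // any TableFunctionImplementation
    ));
    // Lookup is by lowercase name, so both of these resolve to "my_series":
    ctx.sql("SELECT MY_SERIES(1, 5)").await?;
    ctx.sql("SELECT my_series(1, 5)").await?;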

/// Registers an aggregate UDF within this context.
///
/// Note in SQL queries, aggregate names are looked up using
@@ -1142,6 +1157,8 @@ pub struct SessionState {
pub catalog_list: Arc<dyn CatalogList>,
/// Scalar functions that are registered with the context
pub scalar_functions: HashMap<String, Arc<ScalarUDF>>,
/// Table functions that are registered with the context
pub table_functions: HashMap<String, Arc<TableUDF>>,
/// Aggregate functions registered in the context
pub aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
/// Session configuration
@@ -1225,6 +1242,7 @@ impl SessionState {
query_planner: Arc::new(DefaultQueryPlanner {}),
catalog_list,
scalar_functions: HashMap::new(),
table_functions: HashMap::new(),
aggregate_functions: HashMap::new(),
config,
execution_props: ExecutionProps::new(),
@@ -1385,6 +1403,12 @@ impl ContextProvider for SessionState {
self.aggregate_functions.get(name).cloned()
}

fn get_table_function_meta(&self, name: &str) -> Option<Arc<TableUDF>> {
self.table_functions
.get(&name.to_ascii_lowercase())
.cloned()
}

fn get_variable_type(&self, variable_names: &[String]) -> Option<DataType> {
if variable_names.is_empty() {
return None;
@@ -1603,19 +1627,21 @@ mod tests {
use super::*;
use crate::execution::context::QueryPlanner;
use crate::logical_plan::{binary_expr, lit, Operator};
use crate::physical_plan::functions::make_scalar_function;
use crate::physical_plan::functions::{make_scalar_function, make_table_function};
use crate::test;
use crate::variable::VarType;
use crate::{
assert_batches_eq, assert_batches_sorted_eq,
logical_plan::{col, create_udf, sum, Expr},
logical_plan::{col, create_udf, create_udtf, sum, Expr},
};
use crate::{logical_plan::create_udaf, physical_plan::expressions::AvgAccumulator};
use arrow::array::ArrayRef;
use arrow::array::{
ArrayBuilder, ArrayRef, Int64Array, Int64Builder, StringBuilder, StructBuilder,
};
use arrow::datatypes::*;
use arrow::record_batch::RecordBatch;
use async_trait::async_trait;
use datafusion_expr::Volatility;
use datafusion_expr::{TableFunctionImplementation, Volatility};
use std::fs::File;
use std::sync::Weak;
use std::thread::{self, JoinHandle};
@@ -2115,6 +2141,195 @@ mod tests {
Ok(())
}

#[tokio::test]
async fn user_defined_table_function() -> Result<()> {
let mut ctx = SessionContext::new();

let integer_series = integer_udtf();
ctx.register_udtf(create_udtf(
"integer_series",
vec![DataType::Int64, DataType::Int64],
Arc::new(DataType::Int64),
Volatility::Immutable,
integer_series,
));

let struct_func = struct_udtf();
ctx.register_udtf(create_udtf(
"struct_func",
vec![DataType::Int64],
Arc::new(DataType::Struct(
[
Field::new("f1", DataType::Utf8, false),
Field::new("f2", DataType::Int64, false),
]
.to_vec(),
)),
Volatility::Immutable,
struct_func,
));

let result = plan_and_collect(&ctx, "SELECT struct_func(5)").await?;

let expected = vec![
"+-------------------------+",
"| struct_func(Int64(5)) |",
"+-------------------------+",
"| {\"f1\": \"test\", \"f2\": 5} |",
"+-------------------------+",
];

assert_batches_eq!(expected, &result);

let result = plan_and_collect(&ctx, "SELECT integer_series(6,5)").await?;

let expected = vec![
"+-----------------------------------+",
"| integer_series(Int64(6),Int64(5)) |",
"+-----------------------------------+",
"+-----------------------------------+",
];

assert_batches_eq!(expected, &result);

let result = plan_and_collect(&ctx, "SELECT integer_series(1,5)").await?;

let expected = vec![
"+-----------------------------------+",
"| integer_series(Int64(1),Int64(5)) |",
"+-----------------------------------+",
"| 1 |",
"| 2 |",
"| 3 |",
"| 4 |",
"| 5 |",
"+-----------------------------------+",
];
Comment on lines +2195 to +2207
Contributor: This is a good example of a UDTF producing more rows than went in 👍

Would it be possible to write an example that also produces a different number of columns than went in? I think that is what @Ted-Jiang and I are pointing out in our comments below.

Contributor Author: I didn't add support for that. You can use structs for it.
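(The struct_func UDTF in this test shows that approach: multiple output fields are packed into a single struct-typed column.)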


assert_batches_eq!(expected, &result);

let result = plan_and_collect(
&ctx,
"SELECT asd, struct_func(qwe), integer_series(asd, qwe), integer_series(1, qwe) r FROM (select 1 asd, 3 qwe UNION ALL select 2 asd, 4 qwe) x",
)
.await?;

let expected = vec![
"+-----+-------------------------+-----------------------------+---+",
"| asd | struct_func(x.qwe) | integer_series(x.asd,x.qwe) | r |",
"+-----+-------------------------+-----------------------------+---+",
"| 1 | {\"f1\": \"test\", \"f2\": 3} | 1 | 1 |",
"| 1 | | 2 | 2 |",
"| 1 | | 3 | 3 |",
"| 2 | {\"f1\": \"test\", \"f2\": 4} | 2 | 1 |",
"| 2 | | 3 | 2 |",
"| 2 | | 4 | 3 |",
"| 2 | | | 4 |",
"+-----+-------------------------+-----------------------------+---+",
];

assert_batches_eq!(expected, &result);

let result =
plan_and_collect(&ctx, "SELECT * from integer_series(1,5) pos(n)").await?;
Contributor: Can you explain what this test is supposed to be demonstrating? I am not quite sure what it shows.

Contributor Author: I just explained it in the PR description. Hope I made it clear enough :)
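(For context: pos(n) gives the table function's output the table alias pos and renames its single column to n, which is why the expected result below has one column named n.)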


let expected = vec![
"+---+", "| n |", "+---+", "| 1 |", "| 2 |", "| 3 |", "| 4 |", "| 5 |",
"+---+",
];

assert_batches_eq!(expected, &result);

let result = plan_and_collect(&ctx, "SELECT * from integer_series(1,5)").await?;

let expected = vec![
"+-----------------------------------+",
"| integer_series(Int64(1),Int64(5)) |",
"+-----------------------------------+",
"| 1 |",
"| 2 |",
"| 3 |",
"| 4 |",
"| 5 |",
"+-----------------------------------+",
];

assert_batches_eq!(expected, &result);

Ok(())
}

fn integer_udtf() -> TableFunctionImplementation {
make_table_function(move |args: &[ArrayRef]| {
assert!(args.len() == 2);

let start_arr = &args[0]
.as_any()
.downcast_ref::<Int64Array>()
.expect("cast failed");

let end_arr = &args[1]
.as_any()
.downcast_ref::<Int64Array>()
.expect("cast failed");

let mut batch_sizes: Vec<usize> = Vec::new();
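// batch_sizes[i] records how many output rows this UDTF produces for
// input row i, so the caller can align the expanded output column with
// the remaining columns of each row.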
let mut builder = Int64Builder::new(1);

for (start, end) in start_arr.iter().zip(end_arr.iter()) {
let start_number = start.unwrap();
let end_number = end.unwrap();
let count: usize = if end_number < start_number {
0
} else {
(end_number - start_number + 1).try_into().unwrap()
};
batch_sizes.push(count);

for i in start_number..end_number + 1 {
builder.append_value(i).unwrap();
}
}

Ok((Arc::new(builder.finish()) as ArrayRef, batch_sizes))
})
}

fn struct_udtf() -> TableFunctionImplementation {
make_table_function(move |args: &[ArrayRef]| {
let start_arr = &args[0]
.as_any()
.downcast_ref::<Int64Array>()
.expect("cast failed");

let mut string_builder = StringBuilder::new(1);
let mut int_builder = Int64Builder::new(1);

let mut batch_sizes: Vec<usize> = Vec::new();
for start in start_arr.iter() {
let start_number = start.unwrap();
batch_sizes.push(1);

string_builder.append_value("test").unwrap();
int_builder.append_value(start_number).unwrap();
}

let mut fields = Vec::new();
let mut field_builders = Vec::new();
fields.push(Field::new("f1", DataType::Utf8, false));
field_builders.push(Box::new(string_builder) as Box<dyn ArrayBuilder>);
fields.push(Field::new("f2", DataType::Int64, false));
field_builders.push(Box::new(int_builder) as Box<dyn ArrayBuilder>);

let mut builder = StructBuilder::new(fields, field_builders);
for _start in start_arr.iter() {
builder.append(true).unwrap();
}

Ok((Arc::new(builder.finish()) as ArrayRef, batch_sizes))
})
}

struct MyPhysicalPlanner {}

#[async_trait]
25 changes: 24 additions & 1 deletion datafusion/core/src/logical_plan/builder.rs
@@ -26,7 +26,7 @@ use crate::error::{DataFusionError, Result};
use crate::logical_expr::ExprSchemable;
use crate::logical_plan::plan::{
Aggregate, Analyze, EmptyRelation, Explain, Filter, Join, Projection, Sort,
SubqueryAlias, TableScan, ToStringifiedPlan, Union, Window,
SubqueryAlias, TableScan, TableUDFs, ToStringifiedPlan, Union, Window,
};
use crate::optimizer::utils;
use crate::prelude::*;
@@ -1199,6 +1199,29 @@ pub(crate) fn expand_qualified_wildcard(
expand_wildcard(&qualifier_schema, plan)
}

/// Build a merged schema from the input plan's schema and the TableUDF expressions
pub fn build_table_udf_schema(
input: &LogicalPlan,
udtf_expr: &[Expr],
) -> Result<DFSchemaRef> {
let input_schema = input.schema();
let mut schema = (**input_schema).clone();
schema.merge(&DFSchema::new_with_metadata(
exprlist_to_fields(udtf_expr, input_schema)?,
HashMap::new(),
)?);
Ok(Arc::new(schema))
}

pub(crate) fn table_udfs(plan: LogicalPlan, udtf_expr: Vec<Expr>) -> Result<LogicalPlan> {
let schema = build_table_udf_schema(&plan, &udtf_expr)?;
Ok(LogicalPlan::TableUDFs(TableUDFs {
expr: udtf_expr,
input: Arc::new(plan),
schema,
}))
}
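For example, in the context.rs test above, an input with schema (asd, qwe) and a udtf_expr of [integer_series(asd, qwe)] yields the merged schema (asd, qwe, integer_series(asd,qwe)): the UDTF output fields are appended after the input columns.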

#[cfg(test)]
mod tests {
use arrow::datatypes::{DataType, Field};
25 changes: 24 additions & 1 deletion datafusion/core/src/logical_plan/expr.rs
@@ -25,11 +25,13 @@ use crate::logical_plan::{DFField, DFSchema};
use arrow::datatypes::DataType;
pub use datafusion_common::{Column, ExprSchema};
pub use datafusion_expr::expr_fn::*;
use datafusion_expr::AccumulatorFunctionImplementation;
use datafusion_expr::BuiltinScalarFunction;
pub use datafusion_expr::Expr;
use datafusion_expr::StateTypeFunction;
pub use datafusion_expr::{lit, lit_timestamp_nano, Literal};
use datafusion_expr::{
AccumulatorFunctionImplementation, TableFunctionImplementation, TableUDF,
};
use datafusion_expr::{AggregateUDF, ScalarUDF};
use datafusion_expr::{
ReturnTypeFunction, ScalarFunctionImplementation, Signature, Volatility,
@@ -113,6 +115,27 @@ pub fn create_udf(
)
}

/// Creates a new UDTF with a specific signature and a fixed return type.
/// This is a helper function to create a new UDTF;
/// it covers only a subset of all possible `TableUDF`s:
/// * the UDTF has a fixed return type
/// * the UDTF has a fixed signature (e.g. [f64, f64])
pub fn create_udtf(
name: &str,
input_types: Vec<DataType>,
return_type: Arc<DataType>,
volatility: Volatility,
fun: TableFunctionImplementation,
) -> TableUDF {
let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(return_type.clone()));
TableUDF::new(
name,
&Signature::exact(input_types, volatility),
&return_type,
&fun,
)
}
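A usage sketch mirroring the context.rs test above (integer_udtf as defined there):

    let udtf = create_udtf(
        "integer_series",
        vec![DataType::Int64, DataType::Int64],
        Arc::new(DataType::Int64),
        Volatility::Immutable,
        integer_udtf(),
    );
    ctx.register_udtf(udtf); // ctx: a mutable SessionContext

Since the return type is captured in a closure that ignores its argument, the UDTF reports the same DataType regardless of the input types.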

/// Creates a new UDAF with a specific signature, state type and return type.
/// The signature and state type must match the `Accumulator's implementation`.
#[allow(clippy::rc_buffer)]