diff --git a/ballista/rust/core/src/serde/logical_plan/mod.rs b/ballista/rust/core/src/serde/logical_plan/mod.rs index 5307aff653e4..ab291210003f 100644 --- a/ballista/rust/core/src/serde/logical_plan/mod.rs +++ b/ballista/rust/core/src/serde/logical_plan/mod.rs @@ -968,6 +968,9 @@ impl AsLogicalPlan for LogicalPlanNode { LogicalPlan::DropTable(_) => Err(proto_error( "Error converting DropTable. Not yet supported in Ballista", )), + LogicalPlan::TableUDFs(_) => Err(proto_error( + "Error converting TableUDFs. Not yet supported in Ballista", + )), } } } diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs index d066d8d9dd2c..70ae29eb2ae1 100644 --- a/datafusion/core/src/datasource/listing/helpers.rs +++ b/datafusion/core/src/datasource/listing/helpers.rs @@ -102,6 +102,7 @@ impl ExpressionVisitor for ApplicabilityVisitor<'_> { Expr::ScalarUDF { fun, .. } => { self.visit_volatility(fun.signature.volatility) } + Expr::TableUDF { fun, .. } => self.visit_volatility(fun.signature.volatility), // TODO other expressions are not handled yet: // - AGGREGATE, WINDOW and SORT should not end up in filter conditions, except maybe in some edge cases diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs index 5378e38da6d0..9ea2ce963e85 100644 --- a/datafusion/core/src/execution/context.rs +++ b/datafusion/core/src/execution/context.rs @@ -84,6 +84,7 @@ use crate::physical_plan::file_format::{plan_to_csv, plan_to_json, plan_to_parqu use crate::physical_plan::planner::DefaultPhysicalPlanner; use crate::physical_plan::udaf::AggregateUDF; use crate::physical_plan::udf::ScalarUDF; +use crate::physical_plan::udtf::TableUDF; use crate::physical_plan::ExecutionPlan; use crate::physical_plan::PhysicalPlanner; use crate::sql::{ @@ -452,6 +453,20 @@ impl SessionContext { .insert(f.name.clone(), Arc::new(f)); } + /// Registers a table UDF within this context. + /// + /// Note in SQL queries, function names are looked up using + /// lowercase unless the query uses quotes. For example, + /// + /// `SELECT MY_FUNC(x)...` will look for a function named `"my_func"` + /// `SELECT "my_FUNC"(x)` will look for a function named `"my_FUNC"` + pub fn register_udtf(&mut self, f: TableUDF) { + self.state + .write() + .table_functions + .insert(f.name.clone(), Arc::new(f)); + } + /// Registers an aggregate UDF within this context. 
/// /// Note in SQL queries, aggregate names are looked up using @@ -1142,6 +1157,8 @@ pub struct SessionState { pub catalog_list: Arc, /// Scalar functions that are registered with the context pub scalar_functions: HashMap>, + /// Table functions that are registered with the context + pub table_functions: HashMap>, /// Aggregate functions registered in the context pub aggregate_functions: HashMap>, /// Session configuration @@ -1225,6 +1242,7 @@ impl SessionState { query_planner: Arc::new(DefaultQueryPlanner {}), catalog_list, scalar_functions: HashMap::new(), + table_functions: HashMap::new(), aggregate_functions: HashMap::new(), config, execution_props: ExecutionProps::new(), @@ -1385,6 +1403,12 @@ impl ContextProvider for SessionState { self.aggregate_functions.get(name).cloned() } + fn get_table_function_meta(&self, name: &str) -> Option> { + self.table_functions + .get(&name.to_ascii_lowercase()) + .cloned() + } + fn get_variable_type(&self, variable_names: &[String]) -> Option { if variable_names.is_empty() { return None; @@ -1603,19 +1627,21 @@ mod tests { use super::*; use crate::execution::context::QueryPlanner; use crate::logical_plan::{binary_expr, lit, Operator}; - use crate::physical_plan::functions::make_scalar_function; + use crate::physical_plan::functions::{make_scalar_function, make_table_function}; use crate::test; use crate::variable::VarType; use crate::{ assert_batches_eq, assert_batches_sorted_eq, - logical_plan::{col, create_udf, sum, Expr}, + logical_plan::{col, create_udf, create_udtf, sum, Expr}, }; use crate::{logical_plan::create_udaf, physical_plan::expressions::AvgAccumulator}; - use arrow::array::ArrayRef; + use arrow::array::{ + ArrayBuilder, ArrayRef, Int64Array, Int64Builder, StringBuilder, StructBuilder, + }; use arrow::datatypes::*; use arrow::record_batch::RecordBatch; use async_trait::async_trait; - use datafusion_expr::Volatility; + use datafusion_expr::{TableFunctionImplementation, Volatility}; use std::fs::File; use std::sync::Weak; use std::thread::{self, JoinHandle}; @@ -2115,6 +2141,195 @@ mod tests { Ok(()) } + #[tokio::test] + async fn user_defined_table_function() -> Result<()> { + let mut ctx = SessionContext::new(); + + let integer_series = integer_udtf(); + ctx.register_udtf(create_udtf( + "integer_series", + vec![DataType::Int64, DataType::Int64], + Arc::new(DataType::Int64), + Volatility::Immutable, + integer_series, + )); + + let struct_func = struct_udtf(); + ctx.register_udtf(create_udtf( + "struct_func", + vec![DataType::Int64], + Arc::new(DataType::Struct( + [ + Field::new("f1", DataType::Utf8, false), + Field::new("f2", DataType::Int64, false), + ] + .to_vec(), + )), + Volatility::Immutable, + struct_func, + )); + + let result = plan_and_collect(&ctx, "SELECT struct_func(5)").await?; + + let expected = vec![ + "+-------------------------+", + "| struct_func(Int64(5)) |", + "+-------------------------+", + "| {\"f1\": \"test\", \"f2\": 5} |", + "+-------------------------+", + ]; + + assert_batches_eq!(expected, &result); + + let result = plan_and_collect(&ctx, "SELECT integer_series(6,5)").await?; + + let expected = vec![ + "+-----------------------------------+", + "| integer_series(Int64(6),Int64(5)) |", + "+-----------------------------------+", + "+-----------------------------------+", + ]; + + assert_batches_eq!(expected, &result); + + let result = plan_and_collect(&ctx, "SELECT integer_series(1,5)").await?; + + let expected = vec![ + "+-----------------------------------+", + "| integer_series(Int64(1),Int64(5)) |", + 
"+-----------------------------------+", + "| 1 |", + "| 2 |", + "| 3 |", + "| 4 |", + "| 5 |", + "+-----------------------------------+", + ]; + + assert_batches_eq!(expected, &result); + + let result = plan_and_collect( + &ctx, + "SELECT asd, struct_func(qwe), integer_series(asd, qwe), integer_series(1, qwe) r FROM (select 1 asd, 3 qwe UNION ALL select 2 asd, 4 qwe) x", + ) + .await?; + + let expected = vec![ + "+-----+-------------------------+-----------------------------+---+", + "| asd | struct_func(x.qwe) | integer_series(x.asd,x.qwe) | r |", + "+-----+-------------------------+-----------------------------+---+", + "| 1 | {\"f1\": \"test\", \"f2\": 3} | 1 | 1 |", + "| 1 | | 2 | 2 |", + "| 1 | | 3 | 3 |", + "| 2 | {\"f1\": \"test\", \"f2\": 4} | 2 | 1 |", + "| 2 | | 3 | 2 |", + "| 2 | | 4 | 3 |", + "| 2 | | | 4 |", + "+-----+-------------------------+-----------------------------+---+", + ]; + + assert_batches_eq!(expected, &result); + + let result = + plan_and_collect(&ctx, "SELECT * from integer_series(1,5) pos(n)").await?; + + let expected = vec![ + "+---+", "| n |", "+---+", "| 1 |", "| 2 |", "| 3 |", "| 4 |", "| 5 |", + "+---+", + ]; + + assert_batches_eq!(expected, &result); + + let result = plan_and_collect(&ctx, "SELECT * from integer_series(1,5)").await?; + + let expected = vec![ + "+-----------------------------------+", + "| integer_series(Int64(1),Int64(5)) |", + "+-----------------------------------+", + "| 1 |", + "| 2 |", + "| 3 |", + "| 4 |", + "| 5 |", + "+-----------------------------------+", + ]; + + assert_batches_eq!(expected, &result); + + Ok(()) + } + + fn integer_udtf() -> TableFunctionImplementation { + make_table_function(move |args: &[ArrayRef]| { + assert!(args.len() == 2); + + let start_arr = &args[0] + .as_any() + .downcast_ref::() + .expect("cast failed"); + + let end_arr = &args[1] + .as_any() + .downcast_ref::() + .expect("cast failed"); + + let mut batch_sizes: Vec = Vec::new(); + let mut builder = Int64Builder::new(1); + + for (start, end) in start_arr.iter().zip(end_arr.iter()) { + let start_number = start.unwrap(); + let end_number = end.unwrap(); + let count: usize = if end_number < start_number { + 0 + } else { + (end_number - start_number + 1).try_into().unwrap() + }; + batch_sizes.push(count); + + for i in start_number..end_number + 1 { + builder.append_value(i).unwrap(); + } + } + + Ok((Arc::new(builder.finish()) as ArrayRef, batch_sizes)) + }) + } + + fn struct_udtf() -> TableFunctionImplementation { + make_table_function(move |args: &[ArrayRef]| { + let start_arr = &args[0] + .as_any() + .downcast_ref::() + .expect("cast failed"); + + let mut string_builder = StringBuilder::new(1); + let mut int_builder = Int64Builder::new(1); + + let mut batch_sizes: Vec = Vec::new(); + for start in start_arr.iter() { + let start_number = start.unwrap(); + batch_sizes.push(1); + + string_builder.append_value("test").unwrap(); + int_builder.append_value(start_number).unwrap(); + } + + let mut fields = Vec::new(); + let mut field_builders = Vec::new(); + fields.push(Field::new("f1", DataType::Utf8, false)); + field_builders.push(Box::new(string_builder) as Box); + fields.push(Field::new("f2", DataType::Int64, false)); + field_builders.push(Box::new(int_builder) as Box); + + let mut builder = StructBuilder::new(fields, field_builders); + for _start in start_arr.iter() { + builder.append(true).unwrap(); + } + + Ok((Arc::new(builder.finish()) as ArrayRef, batch_sizes)) + }) + } + struct MyPhysicalPlanner {} #[async_trait] diff --git 
a/datafusion/core/src/logical_plan/builder.rs b/datafusion/core/src/logical_plan/builder.rs index 0b0fdebb68f7..7480ff49b6ce 100644 --- a/datafusion/core/src/logical_plan/builder.rs +++ b/datafusion/core/src/logical_plan/builder.rs @@ -26,7 +26,7 @@ use crate::error::{DataFusionError, Result}; use crate::logical_expr::ExprSchemable; use crate::logical_plan::plan::{ Aggregate, Analyze, EmptyRelation, Explain, Filter, Join, Projection, Sort, - SubqueryAlias, TableScan, ToStringifiedPlan, Union, Window, + SubqueryAlias, TableScan, TableUDFs, ToStringifiedPlan, Union, Window, }; use crate::optimizer::utils; use crate::prelude::*; @@ -1199,6 +1199,29 @@ pub(crate) fn expand_qualified_wildcard( expand_wildcard(&qualifier_schema, plan) } +/// Build the merged table UDF schema from the input schema and the TableUDF expressions +pub fn build_table_udf_schema( + input: &LogicalPlan, + udtf_expr: &[Expr], +) -> Result<DFSchemaRef> { + let input_schema = input.schema(); + let mut schema = (**input_schema).clone(); + schema.merge(&DFSchema::new_with_metadata( + exprlist_to_fields(udtf_expr, input_schema)?, + HashMap::new(), + )?); + Ok(Arc::new(schema)) +} + +pub(crate) fn table_udfs(plan: LogicalPlan, udtf_expr: Vec<Expr>) -> Result<LogicalPlan> { + let schema = build_table_udf_schema(&plan, &udtf_expr)?; + Ok(LogicalPlan::TableUDFs(TableUDFs { + expr: udtf_expr, + input: Arc::new(plan), + schema, + })) +} + #[cfg(test)] mod tests { use arrow::datatypes::{DataType, Field}; diff --git a/datafusion/core/src/logical_plan/expr.rs b/datafusion/core/src/logical_plan/expr.rs index 673345c69b61..d5cef86443c5 100644 --- a/datafusion/core/src/logical_plan/expr.rs +++ b/datafusion/core/src/logical_plan/expr.rs @@ -25,11 +25,13 @@ use crate::logical_plan::{DFField, DFSchema}; use arrow::datatypes::DataType; pub use datafusion_common::{Column, ExprSchema}; pub use datafusion_expr::expr_fn::*; -use datafusion_expr::AccumulatorFunctionImplementation; use datafusion_expr::BuiltinScalarFunction; pub use datafusion_expr::Expr; use datafusion_expr::StateTypeFunction; pub use datafusion_expr::{lit, lit_timestamp_nano, Literal}; +use datafusion_expr::{ + AccumulatorFunctionImplementation, TableFunctionImplementation, TableUDF, +}; use datafusion_expr::{AggregateUDF, ScalarUDF}; use datafusion_expr::{ ReturnTypeFunction, ScalarFunctionImplementation, Signature, Volatility, @@ -113,6 +115,27 @@ pub fn create_udf( ) } +/// Creates a new UDTF with a specific signature and specific return type. +/// This is a helper function to create a new UDTF. +/// The function `create_udtf` returns a subset of all possible `TableFunction`: +/// * the UDTF has a fixed return type +/// * the UDTF has a fixed signature (e.g. [f64, f64]) +pub fn create_udtf( + name: &str, + input_types: Vec<DataType>, + return_type: Arc<DataType>, + volatility: Volatility, + fun: TableFunctionImplementation, +) -> TableUDF { + let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(return_type.clone())); + TableUDF::new( + name, + &Signature::exact(input_types, volatility), + &return_type, + &fun, + ) +} + /// Creates a new UDAF with a specific signature, state type and return type. /// The signature and state type must match the `Accumulator's implementation`.
#[allow(clippy::rc_buffer)] diff --git a/datafusion/core/src/logical_plan/expr_rewriter.rs b/datafusion/core/src/logical_plan/expr_rewriter.rs index ef6d8dff9c8f..5df348c9838b 100644 --- a/datafusion/core/src/logical_plan/expr_rewriter.rs +++ b/datafusion/core/src/logical_plan/expr_rewriter.rs @@ -193,6 +193,10 @@ impl ExprRewritable for Expr { args: rewrite_vec(args, rewriter)?, fun, }, + Expr::TableUDF { args, fun } => Expr::TableUDF { + args: rewrite_vec(args, rewriter)?, + fun, + }, Expr::WindowFunction { args, fun, @@ -456,6 +460,34 @@ pub fn unnormalize_cols(exprs: impl IntoIterator<Item = Expr>) -> Vec<Expr> { exprs.into_iter().map(unnormalize_col).collect() } +/// Rewrite all table UDFs to column references. +/// For now it is used by the Projection node (to access the columns returned by the TableUDFs node). +pub fn rewrite_udtfs_to_columns(exprs: Vec<Expr>, schema: DFSchema) -> Vec<Expr> { + struct ReplaceUdtfWithColumn<'a> { + schema: &'a DFSchema, + } + impl<'a> ExprRewriter for ReplaceUdtfWithColumn<'a> { + fn mutate(&mut self, expr: Expr) -> Result<Expr> { + if let Expr::TableUDF { .. } = expr { + Ok(Expr::Column(Column { + relation: None, + name: expr.name(self.schema).unwrap(), + })) + } else { + Ok(expr) + } + } + } + + exprs + .into_iter() + .map(|expr| { + expr.rewrite(&mut ReplaceUdtfWithColumn { schema: &schema }) + .unwrap() + }) + .collect::<Vec<Expr>>() +} + #[cfg(test)] mod test { use super::*; diff --git a/datafusion/core/src/logical_plan/expr_visitor.rs b/datafusion/core/src/logical_plan/expr_visitor.rs index 7c578da19b75..93238d3a4c53 100644 --- a/datafusion/core/src/logical_plan/expr_visitor.rs +++ b/datafusion/core/src/logical_plan/expr_visitor.rs @@ -147,6 +147,7 @@ impl ExprVisitable for Expr { } Expr::ScalarFunction { args, .. } | Expr::ScalarUDF { args, .. } + | Expr::TableUDF { args, .. } | Expr::AggregateFunction { args, .. } | Expr::AggregateUDF { args, ..
} => args .iter() diff --git a/datafusion/core/src/logical_plan/mod.rs b/datafusion/core/src/logical_plan/mod.rs index bcf0a161447f..bfbd95dc11c2 100644 --- a/datafusion/core/src/logical_plan/mod.rs +++ b/datafusion/core/src/logical_plan/mod.rs @@ -30,7 +30,8 @@ pub mod plan; mod registry; pub mod window_frames; pub use builder::{ - build_join_schema, union_with_alias, LogicalPlanBuilder, UNNAMED_TABLE, + build_join_schema, build_table_udf_schema, union_with_alias, LogicalPlanBuilder, + UNNAMED_TABLE, }; pub use datafusion_common::{DFField, DFSchema, DFSchemaRef, ToDFSchema}; pub use datafusion_expr::{expr_fn::binary_expr, Operator}; @@ -40,9 +41,9 @@ pub use expr::{ abs, acos, and, approx_distinct, approx_percentile_cont, array, ascii, asin, atan, avg, bit_length, btrim, call_fn, case, ceil, character_length, chr, coalesce, col, columnize_expr, combine_filters, concat, concat_expr, concat_ws, concat_ws_expr, cos, - count, count_distinct, create_udaf, create_udf, date_part, date_trunc, digest, - exists, exp, exprlist_to_fields, floor, in_list, in_subquery, initcap, left, length, - lit, lit_timestamp_nano, ln, log10, log2, lower, lpad, ltrim, max, md5, min, + count, count_distinct, create_udaf, create_udf, create_udtf, date_part, date_trunc, + digest, exists, exp, exprlist_to_fields, floor, in_list, in_subquery, initcap, left, + length, lit, lit_timestamp_nano, ln, log10, log2, lower, lpad, ltrim, max, md5, min, not_exists, not_in_subquery, now, now_expr, nullif, octet_length, or, random, regexp_match, regexp_replace, repeat, replace, reverse, right, round, rpad, rtrim, scalar_subquery, sha224, sha256, sha384, sha512, signum, sin, split_part, sqrt, @@ -52,7 +53,8 @@ pub use expr::{ }; pub use expr_rewriter::{ normalize_col, normalize_cols, replace_col, rewrite_sort_cols_by_aggs, - unnormalize_col, unnormalize_cols, ExprRewritable, ExprRewriter, RewriteRecursion, + rewrite_udtfs_to_columns, unnormalize_col, unnormalize_cols, ExprRewritable, + ExprRewriter, RewriteRecursion, }; pub use expr_simplier::{ExprSimplifiable, SimplifyInfo}; pub use expr_visitor::{ExprVisitable, ExpressionVisitor, Recursion}; diff --git a/datafusion/core/src/logical_plan/plan.rs b/datafusion/core/src/logical_plan/plan.rs index 08d1fa120e30..b170244f1481 100644 --- a/datafusion/core/src/logical_plan/plan.rs +++ b/datafusion/core/src/logical_plan/plan.rs @@ -28,8 +28,8 @@ pub use crate::logical_expr::{ CreateMemoryTable, CrossJoin, DropTable, EmptyRelation, Explain, Extension, FileType, Filter, Join, JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, PlanVisitor, Projection, Repartition, Sort, - StringifiedPlan, Subquery, SubqueryAlias, TableScan, ToStringifiedPlan, Union, - UserDefinedLogicalNode, Values, Window, + StringifiedPlan, Subquery, SubqueryAlias, TableScan, TableUDFs, + ToStringifiedPlan, Union, UserDefinedLogicalNode, Values, Window, }, TableProviderFilterPushDown, TableSource, }; diff --git a/datafusion/core/src/optimizer/common_subexpr_eliminate.rs b/datafusion/core/src/optimizer/common_subexpr_eliminate.rs index a9983cdf1e08..9efe8f0063cf 100644 --- a/datafusion/core/src/optimizer/common_subexpr_eliminate.rs +++ b/datafusion/core/src/optimizer/common_subexpr_eliminate.rs @@ -19,7 +19,7 @@ use crate::error::Result; use crate::execution::context::ExecutionProps; -use crate::logical_plan::plan::{Filter, Projection, Window}; +use crate::logical_plan::plan::{Filter, Projection, TableUDFs, Window}; use crate::logical_plan::{ col, plan::{Aggregate, Sort}, @@ -110,6 +110,28 @@ fn 
optimize(plan: &LogicalPlan, execution_props: &ExecutionProps) -> Result { + let arrays = to_arrays(expr, input, &mut expr_set)?; + + let (mut new_expr, new_input) = rewrite_expr( + &[expr], + &[&arrays], + input, + &mut expr_set, + schema, + execution_props, + )?; + + Ok(LogicalPlan::TableUDFs(TableUDFs { + expr: new_expr.pop().unwrap(), + input: Arc::new(new_input), + schema: schema.clone(), + })) + } LogicalPlan::Filter(Filter { predicate, input }) => { let schema = plan.schema().as_ref().clone(); let data_type = if let Ok(data_type) = predicate.get_type(&schema) { @@ -440,6 +462,10 @@ impl ExprIdentifierVisitor<'_> { desc.push_str("ScalarUDF-"); desc.push_str(&fun.name); } + Expr::TableUDF { fun, .. } => { + desc.push_str("TableUDF-"); + desc.push_str(&fun.name); + } Expr::WindowFunction { fun, window_frame, .. } => { diff --git a/datafusion/core/src/optimizer/projection_push_down.rs b/datafusion/core/src/optimizer/projection_push_down.rs index 5062082e8643..e553b961dc70 100644 --- a/datafusion/core/src/optimizer/projection_push_down.rs +++ b/datafusion/core/src/optimizer/projection_push_down.rs @@ -475,6 +475,7 @@ fn optimize_plan( | LogicalPlan::CreateCatalog(_) | LogicalPlan::DropTable(_) | LogicalPlan::CrossJoin(_) + | LogicalPlan::TableUDFs(_) | LogicalPlan::Extension { .. } => { let expr = plan.expressions(); // collect all required columns by this plan diff --git a/datafusion/core/src/optimizer/simplify_expressions.rs b/datafusion/core/src/optimizer/simplify_expressions.rs index 457224619007..9ff9e7aa17c6 100644 --- a/datafusion/core/src/optimizer/simplify_expressions.rs +++ b/datafusion/core/src/optimizer/simplify_expressions.rs @@ -384,6 +384,7 @@ impl<'a> ConstEvaluator<'a> { | Expr::QualifiedWildcard { .. } => false, Expr::ScalarFunction { fun, .. } => Self::volatility_ok(fun.volatility()), Expr::ScalarUDF { fun, .. } => Self::volatility_ok(fun.signature.volatility), + Expr::TableUDF { .. } => false, Expr::Literal(_) | Expr::BinaryExpr { .. } | Expr::Not(_) diff --git a/datafusion/core/src/optimizer/utils.rs b/datafusion/core/src/optimizer/utils.rs index df36761fec40..015143226d01 100644 --- a/datafusion/core/src/optimizer/utils.rs +++ b/datafusion/core/src/optimizer/utils.rs @@ -21,9 +21,10 @@ use super::optimizer::OptimizerRule; use crate::execution::context::ExecutionProps; use datafusion_expr::logical_plan::{ Aggregate, Analyze, Extension, Filter, Join, Projection, Sort, Subquery, - SubqueryAlias, Window, + SubqueryAlias, TableUDFs, Window, }; +use crate::logical_plan::builder::build_table_udf_schema; use crate::logical_plan::{ build_join_schema, Column, CreateMemoryTable, DFSchemaRef, Expr, ExprVisitable, Limit, LogicalPlan, LogicalPlanBuilder, Operator, Partitioning, Recursion, @@ -81,6 +82,7 @@ impl ExpressionVisitor for ColumnNameVisitor<'_> { | Expr::Sort { .. } | Expr::ScalarFunction { .. } | Expr::ScalarUDF { .. } + | Expr::TableUDF { .. } | Expr::WindowFunction { .. } | Expr::AggregateFunction { .. } | Expr::AggregateUDF { .. } @@ -156,6 +158,13 @@ pub fn from_plan( alias: alias.clone(), })) } + LogicalPlan::TableUDFs(TableUDFs { .. }) => { + Ok(LogicalPlan::TableUDFs(TableUDFs { + expr: expr.to_vec(), + input: Arc::new(inputs[0].clone()), + schema: build_table_udf_schema(&inputs[0], expr)?, + })) + } LogicalPlan::Values(Values { schema, .. }) => Ok(LogicalPlan::Values(Values { schema: schema.clone(), values: expr @@ -321,6 +330,7 @@ pub fn expr_sub_expressions(expr: &Expr) -> Result> { | Expr::GetIndexedField { expr, .. 
} => Ok(vec![expr.as_ref().to_owned()]), Expr::ScalarFunction { args, .. } | Expr::ScalarUDF { args, .. } + | Expr::TableUDF { args, .. } | Expr::AggregateFunction { args, .. } | Expr::AggregateUDF { args, .. } => Ok(args.clone()), Expr::WindowFunction { @@ -406,6 +416,10 @@ pub fn rewrite_expression(expr: &Expr, expressions: &[Expr]) -> Result { fun: fun.clone(), args: expressions.to_vec(), }), + Expr::TableUDF { fun, .. } => Ok(Expr::TableUDF { + fun: fun.clone(), + args: expressions.to_vec(), + }), Expr::WindowFunction { fun, window_frame, .. } => { diff --git a/datafusion/core/src/physical_plan/functions.rs b/datafusion/core/src/physical_plan/functions.rs index 60cb33a80f86..22f6c27a6ffc 100644 --- a/datafusion/core/src/physical_plan/functions.rs +++ b/datafusion/core/src/physical_plan/functions.rs @@ -36,7 +36,10 @@ use crate::physical_plan::expressions::{ }; use crate::{ error::{DataFusionError, Result}, - logical_expr::{function, BuiltinScalarFunction, ScalarFunctionImplementation}, + logical_expr::{ + function, BuiltinScalarFunction, ScalarFunctionImplementation, + TableFunctionImplementation, + }, scalar::ScalarValue, }; use arrow::{ @@ -168,6 +171,7 @@ pub fn create_physical_expr( } pub use datafusion_physical_expr::ScalarFunctionExpr; +pub use datafusion_physical_expr::TableFunctionExpr; #[cfg(feature = "crypto_expressions")] macro_rules! invoke_if_crypto_expressions_feature_flag { @@ -267,6 +271,30 @@ where }) } +/// Create a table function. +pub fn make_table_function(inner: F) -> TableFunctionImplementation +where + F: Fn(&[ArrayRef]) -> Result<(ArrayRef, Vec)> + Sync + Send + 'static, +{ + Arc::new(move |args: &[ColumnarValue], num_rows| { + // to array + let args = args + .iter() + .map(|arg| arg.clone().into_array(num_rows)) + .collect::>(); + + let result = (inner)(&args); + + let to_return: (ArrayRef, Vec) = result + .iter() + .map(|(r, indexes)| (r.clone(), indexes.clone())) + .next() + .unwrap(); + + Ok(to_return) + }) +} + /// Create a physical scalar function. pub fn create_physical_fun( fun: &BuiltinScalarFunction, diff --git a/datafusion/core/src/physical_plan/mod.rs b/datafusion/core/src/physical_plan/mod.rs index b7b25a636efc..08a7547ec67b 100644 --- a/datafusion/core/src/physical_plan/mod.rs +++ b/datafusion/core/src/physical_plan/mod.rs @@ -568,9 +568,11 @@ pub mod repartition; pub mod sort_merge_join; pub mod sorts; pub mod stream; +pub mod table_fun; pub mod type_coercion; pub mod udaf; pub mod udf; +pub mod udtf; pub mod union; pub mod values; pub mod windows; diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs index 4cca768c7a34..b9d40f42c764 100644 --- a/datafusion/core/src/physical_plan/planner.rs +++ b/datafusion/core/src/physical_plan/planner.rs @@ -18,6 +18,7 @@ //! 
Physical query planner use super::analyze::AnalyzeExec; +use super::table_fun::TableFunExec; use super::{ aggregates, empty::EmptyExec, expressions::binary, functions, hash_join::PartitionMode, udaf, union::UnionExec, values::ValuesExec, windows, @@ -25,7 +26,7 @@ use super::{ use crate::execution::context::{ExecutionProps, SessionState}; use crate::logical_plan::plan::{ source_as_provider, Aggregate, EmptyRelation, Filter, Join, Projection, Sort, - SubqueryAlias, TableScan, Window, + SubqueryAlias, TableScan, TableUDFs, Window, }; use crate::logical_plan::{ unalias, unnormalize_cols, CrossJoin, DFSchema, Expr, LogicalPlan, Operator, @@ -48,6 +49,7 @@ use crate::physical_plan::projection::ProjectionExec; use crate::physical_plan::repartition::RepartitionExec; use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::udf; +use crate::physical_plan::udtf; use crate::physical_plan::windows::WindowAggExec; use crate::physical_plan::{join_utils, Partitioning}; use crate::physical_plan::{AggregateExpr, ExecutionPlan, PhysicalExpr, WindowExpr}; @@ -157,6 +159,9 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { Expr::ScalarUDF { fun, args, .. } => { create_function_physical_name(&fun.name, false, args) } + Expr::TableUDF { fun, args, .. } => { + create_function_physical_name(&fun.name, false, args) + } Expr::WindowFunction { fun, args, .. } => { create_function_physical_name(&fun.to_string(), false, args) } @@ -584,6 +589,32 @@ impl DefaultPhysicalPlanner { physical_input_schema.clone(), )?) ) } + LogicalPlan::TableUDFs(TableUDFs { input, expr, .. }) => { + let input_exec = self.create_initial_plan(input, session_state).await?; + let input_schema = input.as_ref().schema(); + + let physical_exprs = expr + .iter() + .map(|e| { + let physical_name = physical_name(e); + + tuple_err(( + self.create_physical_expr( + e, + input_schema, + &input_exec.schema(), + session_state, + ), + physical_name, + )) + }) + .collect::>>()?; + + Ok(Arc::new(TableFunExec::try_new( + physical_exprs, + input_exec, + )?)) + } LogicalPlan::Projection(Projection { input, expr, .. }) => { let input_exec = self.create_initial_plan(input, session_state).await?; let input_schema = input.as_ref().schema(); @@ -634,7 +665,7 @@ impl DefaultPhysicalPlanner { Ok(Arc::new(ProjectionExec::try_new( physical_exprs, input_exec, - )?) ) + )?)) } LogicalPlan::Filter(Filter { input, predicate, .. @@ -1090,6 +1121,19 @@ pub fn create_physical_expr( udf::create_physical_expr(fun.clone().as_ref(), &physical_args, input_schema) } + Expr::TableUDF { fun, args } => { + let mut physical_args = vec![]; + for e in args { + physical_args.push(create_physical_expr( + e, + input_dfschema, + input_schema, + execution_props, + )?); + } + + udtf::create_physical_expr(fun.clone().as_ref(), &physical_args, input_schema) + } Expr::Between { expr, negated, diff --git a/datafusion/core/src/physical_plan/table_fun.rs b/datafusion/core/src/physical_plan/table_fun.rs new file mode 100644 index 000000000000..5f48c1b381eb --- /dev/null +++ b/datafusion/core/src/physical_plan/table_fun.rs @@ -0,0 +1,359 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! TableFun (UDTFs) Physical Node + +use std::any::Any; +use std::collections::{BTreeMap, HashMap}; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use crate::error::{DataFusionError, Result}; +use crate::physical_plan::{ + ColumnStatistics, DisplayFormatType, ExecutionPlan, Partitioning, PhysicalExpr, +}; +use arrow::array::{new_null_array, Array, ArrayRef, NullArray}; +use arrow::compute::kernels; +use arrow::datatypes::{Field, Schema, SchemaRef}; +use arrow::error::Result as ArrowResult; +use arrow::record_batch::RecordBatch; +use datafusion_physical_expr::TableFunctionExpr; + +use super::expressions::{Column, PhysicalSortExpr}; +use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; +use super::{RecordBatchStream, SendableRecordBatchStream, Statistics}; +use crate::execution::context::TaskContext; +use async_trait::async_trait; +use futures::stream::Stream; +use futures::stream::StreamExt; + +/// Execution plan for a table_fun +#[derive(Debug)] +pub struct TableFunExec { + // + // schema + // input + /// The expressions stored as tuples of (expression, output column name) + expr: Vec<(Arc, String)>, + /// The schema once the node has been applied to the input + schema: SchemaRef, + /// The input plan + input: Arc, + /// Execution metrics + metrics: ExecutionPlanMetricsSet, +} + +impl TableFunExec { + /// Create a TableFun + pub fn try_new( + expr: Vec<(Arc, String)>, + input: Arc, + ) -> Result { + let input_schema = input.schema(); + + let fields: Result> = expr + .iter() + .map(|(e, name)| { + let mut field = Field::new( + name, + e.data_type(&input_schema)?, + e.nullable(&input_schema)?, + ); + field.set_metadata(get_field_metadata(e, &input_schema)); + + Ok(field) + }) + .collect(); + + let mut total_fields = input_schema.fields().clone(); + total_fields.append(&mut fields.unwrap()); + + let schema = Schema::new_with_metadata(total_fields, HashMap::new()); + + Ok(Self { + expr, + schema: Arc::new(schema), + input: input.clone(), + metrics: ExecutionPlanMetricsSet::new(), + }) + } + + /// The expressions stored as tuples of (expression, output column name) + pub fn expr(&self) -> &[(Arc, String)] { + &self.expr + } + + /// The input plan + pub fn input(&self) -> &Arc { + &self.input + } +} + +#[async_trait] +impl ExecutionPlan for TableFunExec { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + /// Get the schema for this execution plan + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + fn children(&self) -> Vec> { + vec![self.input.clone()] + } + + /// Get the output partitioning of this plan + fn output_partitioning(&self) -> Partitioning { + self.input.output_partitioning() + } + + fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { + self.input.output_ordering() + } + + fn maintains_input_order(&self) -> bool { + // tell optimizer this operator doesn't reorder its input + true + } + + fn relies_on_input_order(&self) -> bool { + false + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + match 
children.len() { + 1 => Ok(Arc::new(TableFunExec::try_new( + self.expr.clone(), + children[0].clone(), + )?)), + _ => Err(DataFusionError::Internal( + "TableFunExec wrong number of children".to_string(), + )), + } + } + + async fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + Ok(Box::pin(TableFunStream { + schema: self.schema.clone(), + expr: self.expr.iter().map(|x| x.0.clone()).collect(), + input: self.input.execute(partition, context).await?, + baseline_metrics: BaselineMetrics::new(&self.metrics, partition), + })) + } + + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default => { + let expr: Vec = self + .expr + .iter() + .map(|(e, alias)| { + let e = e.to_string(); + if &e != alias { + format!("{} as {}", e, alias) + } else { + e + } + }) + .collect(); + + write!(f, "TableFunExec: expr=[{}]", expr.join(", ")) + } + } + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } + + fn statistics(&self) -> Statistics { + stats_table_fun( + self.input.statistics(), + self.expr.iter().map(|(e, _)| Arc::clone(e)), + ) + } +} + +/// If e is a direct column reference, returns the field level +/// metadata for that field, if any. Otherwise returns None +fn get_field_metadata( + e: &Arc, + input_schema: &Schema, +) -> Option> { + let name = if let Some(column) = e.as_any().downcast_ref::() { + column.name() + } else { + return None; + }; + + input_schema + .field_with_name(name) + .ok() + .and_then(|f| f.metadata().as_ref().cloned()) +} + +fn stats_table_fun( + stats: Statistics, + exprs: impl Iterator>, +) -> Statistics { + let column_statistics = stats.column_statistics.map(|input_col_stats| { + exprs + .map(|e| { + if let Some(col) = e.as_any().downcast_ref::() { + input_col_stats[col.index()].clone() + } else { + // TODO stats: estimate more statistics from expressions + // (expressions should compute their statistics themselves) + ColumnStatistics::default() + } + }) + .collect() + }); + + Statistics { + is_exact: stats.is_exact, + num_rows: stats.num_rows, + column_statistics, + // TODO stats: knowing the type of the new columns we can guess the output size + total_byte_size: None, + } +} + +impl TableFunStream { + fn batch(&self, batch: &RecordBatch) -> ArrowResult { + let arrays = (batch + .columns() + .iter() + // Filtering out placeholder column of EmptyRelation + .filter(|a| a.as_any().downcast_ref::().is_none()) + .map(|a| { + Ok(( + ( + a.clone(), + (0..a.len()).into_iter().map(|_| 1).collect::>(), + ), + false, + )) + })) + .chain(self.expr.iter().map(|expr| { + let t_expr = expr.as_any().downcast_ref::().unwrap(); + Ok((t_expr.evaluate_table(batch)?, true)) + })) + .collect::>>()?; + + // Detect count sizes of batch sections + let mut max_batch_sizes: Vec = vec![0; batch.num_rows()]; + for ((_, indexes), _) in arrays.iter() { + for (i, batch_size) in indexes.iter().enumerate() { + if max_batch_sizes[i] < *batch_size { + max_batch_sizes[i] = *batch_size; + } + } + } + + let mut columns: Vec = Vec::new(); + + for (col_i, ((col_arr, cur_batch_sizes), is_table_fun)) in + arrays.iter().enumerate() + { + let mut sections: Vec = Vec::new(); + + let mut start_i_of_batch = 0; + + for (i_batch, max_size) in max_batch_sizes.iter().enumerate() { + let current_batch_size = cur_batch_sizes[i_batch]; + sections.push(col_arr.slice(start_i_of_batch, current_batch_size)); + if *is_table_fun { + if max_size - current_batch_size > 0 { + sections.push(new_null_array( + 
self.schema.field(col_i).data_type(), + max_size - current_batch_size, + )); + } + } else if max_size - current_batch_size > 0 { + for _ in 0..(max_size - current_batch_size) { + sections + .push(col_arr.slice(start_i_of_batch, current_batch_size)); + } + } + + start_i_of_batch += cur_batch_sizes[i_batch]; + } + + columns.push(kernels::concat::concat( + sections + .iter() + .map(|a| a.as_ref()) + .collect::<Vec<&dyn Array>>() + .as_slice(), + )?); + } + + RecordBatch::try_new(self.schema.clone(), columns) + } +} + +/// TableFun iterator +struct TableFunStream { + schema: SchemaRef, + expr: Vec<Arc<dyn PhysicalExpr>>, + input: SendableRecordBatchStream, + baseline_metrics: BaselineMetrics, +} + +impl Stream for TableFunStream { + type Item = ArrowResult<RecordBatch>; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<Option<Self::Item>> { + let poll = self.input.poll_next_unpin(cx).map(|x| match x { + Some(Ok(batch)) => Some(self.batch(&batch)), + other => other, + }); + + self.baseline_metrics.record_poll(poll) + } + + fn size_hint(&self) -> (usize, Option<usize>) { + // same number of record batches + self.input.size_hint() + } +} + +impl RecordBatchStream for TableFunStream { + /// Get the schema + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} diff --git a/datafusion/core/src/physical_plan/udtf.rs b/datafusion/core/src/physical_plan/udtf.rs new file mode 100644 index 000000000000..b2b7b2305297 --- /dev/null +++ b/datafusion/core/src/physical_plan/udtf.rs @@ -0,0 +1,51 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! UDTF support + +use super::type_coercion::coerce; +use crate::error::Result; +use crate::physical_plan::functions::TableFunctionExpr; +use crate::physical_plan::PhysicalExpr; +use arrow::datatypes::Schema; + +pub use datafusion_expr::TableUDF; + +use std::sync::Arc; + +/// Create a physical expression of the UDTF. +/// This function errors when `args` can't be coerced to a valid argument type of the UDTF.
+pub fn create_physical_expr( + fun: &TableUDF, + input_phy_exprs: &[Arc], + input_schema: &Schema, +) -> Result> { + // coerce + let coerced_phy_exprs = coerce(input_phy_exprs, input_schema, &fun.signature)?; + + let coerced_exprs_types = coerced_phy_exprs + .iter() + .map(|e| e.data_type(input_schema)) + .collect::>>()?; + + Ok(Arc::new(TableFunctionExpr::new( + &fun.name, + fun.fun.clone(), + coerced_phy_exprs, + (fun.return_type)(&coerced_exprs_types)?.as_ref(), + ))) +} diff --git a/datafusion/core/src/sql/planner.rs b/datafusion/core/src/sql/planner.rs index e693d7b209c8..3e5e855efcde 100644 --- a/datafusion/core/src/sql/planner.rs +++ b/datafusion/core/src/sql/planner.rs @@ -26,23 +26,25 @@ use std::{convert::TryInto, vec}; use crate::catalog::TableReference; use crate::datasource::TableProvider; use crate::logical_plan::window_frames::{WindowFrame, WindowFrameUnits}; +use crate::logical_plan::EmptyRelation; use crate::logical_plan::Expr::Alias; use crate::logical_plan::{ and, builder::expand_qualified_wildcard, builder::expand_wildcard, col, lit, - normalize_col, union_with_alias, Column, CreateCatalog, CreateCatalogSchema, - CreateExternalTable as PlanCreateExternalTable, CreateMemoryTable, DFSchema, - DFSchemaRef, DropTable, Expr, FileType, LogicalPlan, LogicalPlanBuilder, Operator, - PlanType, ToDFSchema, ToStringifiedPlan, + normalize_col, rewrite_udtfs_to_columns, union_with_alias, Column, CreateCatalog, + CreateCatalogSchema, CreateExternalTable as PlanCreateExternalTable, + CreateMemoryTable, DFSchema, DFSchemaRef, DropTable, Expr, FileType, LogicalPlan, + LogicalPlanBuilder, Operator, PlanType, ToDFSchema, ToStringifiedPlan, }; use crate::optimizer::utils::exprlist_to_columns; use crate::prelude::JoinType; use crate::scalar::ScalarValue; -use crate::sql::utils::{make_decimal_type, normalize_ident}; +use crate::sql::utils::{find_udtf_exprs, make_decimal_type, normalize_ident}; use crate::{ error::{DataFusionError, Result}, physical_plan::aggregates, physical_plan::udaf::AggregateUDF, physical_plan::udf::ScalarUDF, + physical_plan::udtf::TableUDF, sql::parser::{CreateExternalTable, Statement as DFStatement}, }; use arrow::datatypes::*; @@ -68,7 +70,7 @@ use super::{ resolve_aliases_to_exprs, resolve_positions_to_exprs, }, }; -use crate::logical_plan::builder::project_with_alias; +use crate::logical_plan::builder::{project_with_alias, table_udfs}; use crate::logical_plan::plan::{Analyze, Explain}; /// The ContextProvider trait allows the query planner to obtain meta-data about tables and @@ -78,6 +80,8 @@ pub trait ContextProvider { fn get_table_provider(&self, name: TableReference) -> Option>; /// Getter for a UDF description fn get_function_meta(&self, name: &str) -> Option>; + /// Getter for a UDTF description + fn get_table_function_meta(&self, name: &str) -> Option>; /// Getter for a UDAF description fn get_aggregate_meta(&self, name: &str) -> Option>; /// Getter for system/user-defined variable type @@ -666,7 +670,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ) -> Result { let (plan, alias) = match relation { TableFactor::Table { - ref name, alias, .. + ref name, + alias, + args, + .. 
} => { let table_name = name.to_string(); let cte = ctes.get(&table_name); @@ -685,10 +692,66 @@ }; scan?.build() } - (None, None) => Err(DataFusionError::Plan(format!( - "Table or CTE with name '{}' not found", - name - ))), + (None, None) => { + if let Some(fun) = + self.schema_provider.get_table_function_meta(&table_name) + { + let udtf = Expr::TableUDF { + fun, + args: self + .function_args_to_expr(args, &DFSchema::empty())?, + }; + + let udtf_plan = table_udfs( + LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: true, + schema: Arc::new(DFSchema::empty()), + }), + vec![udtf.clone()], + )?; + + if alias.is_none() { + return Ok(udtf_plan); + } + + let mut select_exprs = rewrite_udtfs_to_columns( + vec![udtf], + udtf_plan.schema().clone().as_ref().to_owned(), + ); + + let alias = alias.unwrap(); + + if !alias.columns.is_empty() { + select_exprs = select_exprs + .iter() + .enumerate() + .map(|(i, e)| { + if alias.columns.len() > i { + Expr::Alias( + Box::new(e.clone()), + alias.columns[i].to_string(), + ) + } else { + e.clone() + } + }) + .collect(); + } + + return project_with_alias( + udtf_plan, + select_exprs, + Some(alias.name.value), + ); + } + + Err(DataFusionError::Plan(format!( + "Table or CTE with name '{}' not found", + name + ))) + } }?, alias, ) @@ -907,18 +970,30 @@ let empty_from = matches!(plans.first(), Some(LogicalPlan::EmptyRelation(_))); // process `where` clause - let plan = self.plan_selection(select.selection, plans, outer_query_schema)?; + let mut plan = + self.plan_selection(select.selection, plans, outer_query_schema)?; // process the SELECT expressions, with wildcards expanded.
- let select_exprs = self.prepare_select_exprs( + let mut select_exprs = self.prepare_select_exprs( &plan, select.projection, empty_from, outer_query_schema, )?; + + // Create a proxy node to evaluate the UDTFs, and rewrite the UDTF calls to the columns returned by the TableUDFs node + let udtf_exprs = find_udtf_exprs(select_exprs.as_slice()); + if !udtf_exprs.is_empty() { + plan = table_udfs(plan, udtf_exprs)?; + select_exprs = rewrite_udtfs_to_columns( + select_exprs, + plan.schema().clone().as_ref().to_owned(), + ); + } + // having and group by clause may reference aliases defined in select projection let projected_plan = self.project(plan.clone(), select_exprs.clone())?; + let mut combined_schema = (**projected_plan.schema()).clone(); combined_schema.merge(plan.schema()); @@ -1870,10 +1945,16 @@ let args = self.function_args_to_expr(function.args, schema)?; Ok(Expr::AggregateUDF { fun: fm, args }) } - _ => Err(DataFusionError::Plan(format!( - "Invalid function '{}'", - name - ))), + None => match self.schema_provider.get_table_function_meta(&name) { + Some(fm) => { + let args = self.function_args_to_expr(function.args, schema)?; + Ok(Expr::TableUDF { fun: fm, args }) + } + _ => Err(DataFusionError::Plan(format!( + "Invalid function '{}'", + name + ))), + }, }, } } @@ -2371,7 +2452,9 @@ pub fn convert_data_type(sql_type: &SQLDataType) -> Result<DataType> { SQLDataType::Float(_) => Ok(DataType::Float32), SQLDataType::Real => Ok(DataType::Float32), SQLDataType::Double => Ok(DataType::Float64), - SQLDataType::Char(_) | SQLDataType::Varchar(_) => Ok(DataType::Utf8), + SQLDataType::Char(_) | SQLDataType::Varchar(_) | SQLDataType::Text => { + Ok(DataType::Utf8) + } SQLDataType::Timestamp => Ok(DataType::Timestamp(TimeUnit::Nanosecond, None)), SQLDataType::Date => Ok(DataType::Date32), SQLDataType::Decimal(precision, scale) => make_decimal_type(*precision, *scale), @@ -4186,6 +4269,10 @@ mod tests { } } + fn get_table_function_meta(&self, _name: &str) -> Option<Arc<TableUDF>> { + unimplemented!() + } + fn get_aggregate_meta(&self, _name: &str) -> Option<Arc<AggregateUDF>> { unimplemented!() } diff --git a/datafusion/core/src/sql/utils.rs b/datafusion/core/src/sql/utils.rs index 9ed7f0b66cc9..04521b0add86 100644 --- a/datafusion/core/src/sql/utils.rs +++ b/datafusion/core/src/sql/utils.rs @@ -63,6 +63,14 @@ pub(crate) fn find_column_exprs(exprs: &[Expr]) -> Vec<Expr> { find_exprs_in_exprs(exprs, &|nested_expr| matches!(nested_expr, Expr::Column(_))) } +/// Collect all deeply nested `Expr::TableUDF`'s. They are returned in order of +/// appearance (depth first), with duplicates omitted. +pub(crate) fn find_udtf_exprs(exprs: &[Expr]) -> Vec<Expr> { + find_exprs_in_exprs(exprs, &|nested_expr| { + matches!(nested_expr, Expr::TableUDF { .. }) + }) +} + /// Search the provided `Expr`'s, and all of their nested `Expr`, for any that /// pass the provided test. The returned `Expr`'s are deduplicated and returned /// in order of appearance (depth first).
@@ -332,6 +340,13 @@ where .map(|arg| clone_with_replacement(arg, replacement_fn)) .collect::>>()?, }), + Expr::TableUDF { fun, args } => Ok(Expr::TableUDF { + fun: fun.clone(), + args: args + .iter() + .map(|arg| clone_with_replacement(arg, replacement_fn)) + .collect::>>()?, + }), Expr::Negative(nested_expr) => Ok(Expr::Negative(Box::new( clone_with_replacement(&**nested_expr, replacement_fn)?, ))), diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 4d88ed815b14..fc8c819e98e7 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -26,6 +26,7 @@ use crate::window_function; use crate::AggregateUDF; use crate::Operator; use crate::ScalarUDF; +use crate::TableUDF; use arrow::datatypes::DataType; use datafusion_common::Column; use datafusion_common::{DFSchema, Result}; @@ -189,6 +190,13 @@ pub enum Expr { /// List of expressions to feed to the functions as arguments args: Vec, }, + /// Represents the call of a user-defined table function with arguments. + TableUDF { + /// The function + fun: Arc, + /// List of expressions to feed to the functions as arguments + args: Vec, + }, /// Represents the call of an aggregate built-in function with arguments. AggregateFunction { /// Name of the function @@ -494,6 +502,9 @@ impl fmt::Debug for Expr { Expr::ScalarUDF { fun, ref args, .. } => { fmt_function(f, &fun.name, false, args, false) } + Expr::TableUDF { fun, ref args, .. } => { + fmt_function(f, &fun.name, false, args, false) + } Expr::WindowFunction { fun, args, @@ -673,6 +684,9 @@ fn create_name(e: &Expr, input_schema: &DFSchema) -> Result { Expr::ScalarUDF { fun, args, .. } => { create_function_name(&fun.name, false, args, input_schema) } + Expr::TableUDF { fun, args, .. } => { + create_function_name(&fun.name, false, args, input_schema) + } Expr::WindowFunction { fun, args, diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index b932eefa0b96..925bec8557f1 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -70,6 +70,13 @@ impl ExprSchemable for Expr { .collect::>>()?; Ok((fun.return_type)(&data_types)?.as_ref().clone()) } + Expr::TableUDF { fun, args } => { + let data_types = args + .iter() + .map(|e| e.get_type(schema)) + .collect::>>()?; + Ok((fun.return_type)(&data_types)?.as_ref().clone()) + } Expr::ScalarFunction { fun, args } => { let data_types = args .iter() @@ -174,6 +181,7 @@ impl ExprSchemable for Expr { | Expr::TryCast { .. } | Expr::ScalarFunction { .. } | Expr::ScalarUDF { .. } + | Expr::TableUDF { .. } | Expr::WindowFunction { .. } | Expr::AggregateFunction { .. } | Expr::AggregateUDF { .. } => Ok(true), diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index 93c5d0e12fce..c74ec8b3987b 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -24,6 +24,7 @@ use crate::{ array_expressions, conditional_expressions, Accumulator, BuiltinScalarFunction, Signature, TypeSignature, }; +use arrow::array::ArrayRef; use arrow::datatypes::{DataType, Field, TimeUnit}; use datafusion_common::{DataFusionError, Result}; use std::sync::Arc; @@ -39,6 +40,10 @@ use std::sync::Arc; pub type ScalarFunctionImplementation = Arc Result + Send + Sync>; +/// Table function. 
Second tuple element holds, for each input row, the number of output rows the function produced. +pub type TableFunctionImplementation = + Arc<dyn Fn(&[ColumnarValue], usize) -> Result<(ArrayRef, Vec<usize>)> + Send + Sync>; + /// A function's return type pub type ReturnTypeFunction = Arc<dyn Fn(&[DataType]) -> Result<Arc<DataType>> + Send + Sync>; diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index b513bf52d4a1..7b95fbe2a173 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -36,6 +36,7 @@ mod table_source; pub mod type_coercion; mod udaf; mod udf; +mod udtf; pub mod window_frame; pub mod window_function; @@ -59,7 +60,7 @@ pub use expr_fn::{ pub use expr_schema::ExprSchemable; pub use function::{ AccumulatorFunctionImplementation, ReturnTypeFunction, ScalarFunctionImplementation, - StateTypeFunction, + StateTypeFunction, TableFunctionImplementation, }; pub use literal::{lit, lit_timestamp_nano, Literal, TimestampLiteral}; pub use logical_plan::{LogicalPlan, PlanVisitor}; @@ -69,5 +70,6 @@ pub use signature::{Signature, TypeSignature, Volatility}; pub use table_source::{TableProviderFilterPushDown, TableSource, TableType}; pub use udaf::AggregateUDF; pub use udf::ScalarUDF; +pub use udtf::TableUDF; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; pub use window_function::{BuiltInWindowFunction, WindowFunction}; diff --git a/datafusion/expr/src/logical_plan/mod.rs b/datafusion/expr/src/logical_plan/mod.rs index a37729f7dc5f..32ceba5b0aab 100644 --- a/datafusion/expr/src/logical_plan/mod.rs +++ b/datafusion/expr/src/logical_plan/mod.rs @@ -24,7 +24,7 @@ pub use plan::{ CreateMemoryTable, CrossJoin, DropTable, EmptyRelation, Explain, Extension, FileType, Filter, Join, JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, PlanVisitor, Projection, Repartition, Sort, StringifiedPlan, Subquery, SubqueryAlias, - TableScan, ToStringifiedPlan, Union, Values, Window, + TableScan, TableUDFs, ToStringifiedPlan, Union, Values, Window, }; pub use display::display_schema; diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 579898dbe207..a0178a774690 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -93,6 +93,8 @@ pub enum LogicalPlan { /// Runs the actual plan, and then prints the physical plan /// with execution metrics. Analyze(Analyze), + /// UDTF Node + TableUDFs(TableUDFs), /// Extension operator defined outside of DataFusion Extension(Extension), } @@ -132,6 +134,7 @@ impl LogicalPlan { } LogicalPlan::CreateCatalog(CreateCatalog { schema, .. }) => schema, LogicalPlan::DropTable(DropTable { schema, .. }) => schema, + LogicalPlan::TableUDFs(TableUDFs { schema, .. }) => schema, } } @@ -144,7 +147,8 @@ impl LogicalPlan { LogicalPlan::Values(Values { schema, .. }) => vec![schema], LogicalPlan::Window(Window { input, schema, .. }) | LogicalPlan::Projection(Projection { input, schema, .. }) - | LogicalPlan::Aggregate(Aggregate { input, schema, .. }) => { + | LogicalPlan::Aggregate(Aggregate { input, schema, .. }) + | LogicalPlan::TableUDFs(TableUDFs { input, schema, .. }) => { let mut schemas = input.all_schemas(); schemas.insert(0, schema); schemas @@ -227,6 +231,7 @@ impl LogicalPlan { .collect(), LogicalPlan::Sort(Sort { expr, .. }) => expr.clone(), LogicalPlan::Extension(extension) => extension.node.expressions(), + LogicalPlan::TableUDFs(TableUDFs { expr, .. }) => expr.clone(), // plans without expressions LogicalPlan::TableScan { .. } | LogicalPlan::EmptyRelation(_) @@ -269,6 +274,7 @@ impl LogicalPlan { LogicalPlan::CreateMemoryTable(CreateMemoryTable { input, ..
}) => { vec![input] } + LogicalPlan::TableUDFs(TableUDFs { input, .. }) => vec![input], // plans without inputs LogicalPlan::TableScan { .. } | LogicalPlan::EmptyRelation { .. } @@ -418,6 +424,7 @@ impl LogicalPlan { } LogicalPlan::Explain(explain) => explain.plan.accept(visitor)?, LogicalPlan::Analyze(analyze) => analyze.input.accept(visitor)?, + LogicalPlan::TableUDFs(TableUDFs { input, .. }) => input.accept(visitor)?, // plans without inputs LogicalPlan::TableScan { .. } | LogicalPlan::EmptyRelation(_) @@ -702,6 +709,16 @@ impl LogicalPlan { } Ok(()) } + LogicalPlan::TableUDFs(TableUDFs { ref expr, .. }) => { + write!(f, "TableUDFs: ")?; + for (i, expr_item) in expr.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{:?}", expr_item)?; + } + Ok(()) + } LogicalPlan::Filter(Filter { predicate: ref expr, .. @@ -1154,6 +1171,17 @@ pub struct Join { pub null_equals_null: bool, } +/// Handle UDTFs +#[derive(Clone)] +pub struct TableUDFs { + /// The list of expressions + pub expr: Vec, + /// The incoming logical plan + pub input: Arc, + /// The schema description of the output + pub schema: DFSchemaRef, +} + /// Subquery #[derive(Clone)] pub struct Subquery { diff --git a/datafusion/expr/src/udtf.rs b/datafusion/expr/src/udtf.rs new file mode 100644 index 000000000000..dde5f6fb2a9f --- /dev/null +++ b/datafusion/expr/src/udtf.rs @@ -0,0 +1,86 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Udtf module contains foundational types that are used to represent UDTFs in DataFusion. + +use crate::{Expr, ReturnTypeFunction, Signature, TableFunctionImplementation}; +use std::fmt; +use std::fmt::Debug; +use std::fmt::Formatter; +use std::sync::Arc; + +/// Logical representation of a UDTF. 
+#[derive(Clone)] +pub struct TableUDF { + /// name + pub name: String, + /// signature + pub signature: Signature, + /// Return type + pub return_type: ReturnTypeFunction, + /// actual implementation + pub fun: TableFunctionImplementation, +} + +impl Debug for TableUDF { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + f.debug_struct("TableUDF") + .field("name", &self.name) + .field("signature", &self.signature) + .field("fun", &"") + .finish() + } +} + +impl PartialEq for TableUDF { + fn eq(&self, other: &Self) -> bool { + self.name == other.name && self.signature == other.signature + } +} + +impl std::hash::Hash for TableUDF { + fn hash(&self, state: &mut H) { + self.name.hash(state); + self.signature.hash(state); + } +} + +impl TableUDF { + /// Create a new TableUDF + pub fn new( + name: &str, + signature: &Signature, + return_type: &ReturnTypeFunction, + fun: &TableFunctionImplementation, + ) -> Self { + Self { + name: name.to_owned(), + signature: signature.clone(), + return_type: return_type.clone(), + fun: fun.clone(), + } + } + + /// creates a logical expression with a call of the UDTF + /// This utility allows using the UDTF without requiring access to the registry. + pub fn call(&self, args: Vec) -> Expr { + Expr::TableUDF { + fun: Arc::new(self.clone()), + args, + } + } +} diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 1350d49510d5..2204531c1139 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -30,13 +30,14 @@ //! to a function that supports f64, it is coerced to f64. use crate::PhysicalExpr; +use arrow::array::ArrayRef; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; -use datafusion_common::Result; -use datafusion_expr::BuiltinScalarFunction; +use datafusion_common::{DataFusionError, Result}; use datafusion_expr::ColumnarValue; pub use datafusion_expr::NullColumnarValue; use datafusion_expr::ScalarFunctionImplementation; +use datafusion_expr::{BuiltinScalarFunction, TableFunctionImplementation}; use std::any::Any; use std::fmt::Debug; use std::fmt::{self, Formatter}; @@ -146,3 +147,113 @@ impl PhysicalExpr for ScalarFunctionExpr { (fun)(&inputs) } } + +pub struct TableFunctionExpr { + fun: TableFunctionImplementation, + name: String, + args: Vec>, + return_type: DataType, +} + +impl Debug for TableFunctionExpr { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + f.debug_struct("TableFunctionExpr") + .field("fun", &"") + .field("name", &self.name) + .field("args", &self.args) + .field("return_type", &self.return_type) + .finish() + } +} + +impl TableFunctionExpr { + /// Create a new Table function + pub fn new( + name: &str, + fun: TableFunctionImplementation, + args: Vec>, + return_type: &DataType, + ) -> Self { + Self { + fun, + name: name.to_owned(), + args, + return_type: return_type.clone(), + } + } + + /// Get the table function implementation + pub fn fun(&self) -> &TableFunctionImplementation { + &self.fun + } + + /// The name for this expression + pub fn name(&self) -> &str { + &self.name + } + + /// Input arguments + pub fn args(&self) -> &[Arc] { + &self.args + } + + /// Data type produced by this expression + pub fn return_type(&self) -> &DataType { + &self.return_type + } + + pub fn evaluate_table(&self, batch: &RecordBatch) -> Result<(ArrayRef, Vec)> { + // evaluate the arguments, if there are no arguments we'll instead pass in a null array + // indicating the batch size (as a convention) + let inputs = match 
(self.args.len(), self.name.parse::<BuiltinScalarFunction>()) { + (0, Ok(table_fun)) if table_fun.supports_zero_argument() => { + vec![NullColumnarValue::from(batch)] + } + _ => self + .args + .iter() + .map(|e| e.evaluate(batch)) + .collect::<Result<Vec<ColumnarValue>>>()?, + }; + + // evaluate the function + let fun = self.fun.as_ref(); + (fun)(&inputs, batch.num_rows()) + } +} + +impl fmt::Display for TableFunctionExpr { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "{}({})", + self.name, + self.args + .iter() + .map(|e| format!("{}", e)) + .collect::<Vec<String>>() + .join(", ") + ) + } +} + +impl PhysicalExpr for TableFunctionExpr { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, _input_schema: &Schema) -> Result<DataType> { + Ok(self.return_type.clone()) + } + + fn nullable(&self, _input_schema: &Schema) -> Result<bool> { + Ok(true) + } + + fn evaluate(&self, _batch: &RecordBatch) -> Result<ColumnarValue> { + Err(DataFusionError::NotImplemented( + "Use evaluate_table for table funs".to_string(), + )) + } +} diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index 38b0da3b8d1e..01e13f90a61e 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -35,5 +35,6 @@ pub mod window; pub use aggregate::AggregateExpr; pub use functions::ScalarFunctionExpr; +pub use functions::TableFunctionExpr; pub use physical_expr::PhysicalExpr; pub use sort_expr::PhysicalSortExpr;
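For reference, a minimal end-to-end usage sketch of the API introduced above, mirroring the `user_defined_table_function` test: `make_table_function` wraps a closure that receives one Arrow array per argument and returns the flattened output array together with the number of output rows produced for each input row; `create_udtf` attaches a name, signature, and return type; and `SessionContext::register_udtf` makes the function visible to the SQL planner. The `my_series` name is illustrative (not registered by this change), and exact import paths outside the core crate may differ.

use std::sync::Arc;

use arrow::array::{ArrayRef, Int64Array, Int64Builder};
use arrow::datatypes::DataType;
use datafusion::error::Result;
use datafusion::execution::context::SessionContext;
use datafusion::logical_plan::create_udtf;
use datafusion::physical_plan::functions::make_table_function;
use datafusion_expr::Volatility;

async fn register_and_run() -> Result<()> {
    let mut ctx = SessionContext::new();

    // The closure returns the flattened output array plus the per-input-row
    // output counts; TableFunExec uses these counts to align the generated
    // rows with the other columns of the input batch.
    let fun = make_table_function(|args: &[ArrayRef]| {
        let starts = args[0].as_any().downcast_ref::<Int64Array>().expect("cast failed");
        let ends = args[1].as_any().downcast_ref::<Int64Array>().expect("cast failed");

        let mut builder = Int64Builder::new(0);
        let mut batch_sizes: Vec<usize> = Vec::new();
        for (start, end) in starts.iter().zip(ends.iter()) {
            let (start, end) = (start.unwrap(), end.unwrap());
            let count = if end < start { 0 } else { (end - start + 1) as usize };
            batch_sizes.push(count);
            for v in start..=end {
                builder.append_value(v).unwrap();
            }
        }
        Ok((Arc::new(builder.finish()) as ArrayRef, batch_sizes))
    });

    // `my_series` is a hypothetical function name used only for illustration.
    ctx.register_udtf(create_udtf(
        "my_series",
        vec![DataType::Int64, DataType::Int64],
        Arc::new(DataType::Int64),
        Volatility::Immutable,
        fun,
    ));

    // The function is now usable in the SELECT list or, with an alias, in FROM.
    let df = ctx.sql("SELECT * FROM my_series(1, 3) t(n)").await?;
    df.show().await?;
    Ok(())
}

As the test above shows, calling the function in FROM with a column alias (`t(n)`) yields a plain column named `n`; without an alias the output column is named after the call itself (e.g. `my_series(Int64(1),Int64(3))`).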