diff --git a/datafusion/expr/src/logical_plan/extension.rs b/datafusion/expr/src/logical_plan/extension.rs index bb2c932ce391..b55256ca17de 100644 --- a/datafusion/expr/src/logical_plan/extension.rs +++ b/datafusion/expr/src/logical_plan/extension.rs @@ -98,6 +98,24 @@ pub trait UserDefinedLogicalNode: fmt::Debug + Send + Sync { inputs: &[LogicalPlan], ) -> Arc; + /// Returns the necessary input columns for this node required to compute + /// the columns in the output schema + /// + /// This is used for projection push-down when DataFusion has determined that + /// only a subset of the output columns of this node are needed by its parents. + /// This API is used to tell DataFusion which, if any, of the input columns are no longer + /// needed. + /// + /// Return `None`, the default, if this information can not be determined. + /// Returns `Some(_)` with the column indices for each child of this node that are + /// needed to compute `output_columns` + fn necessary_children_exprs( + &self, + _output_columns: &[usize], + ) -> Option>> { + None + } + /// Update the hash `state` with this node requirements from /// [`Hash`]. /// @@ -243,6 +261,24 @@ pub trait UserDefinedLogicalNodeCore: // but the doc comments have not been updated. #[allow(clippy::wrong_self_convention)] fn from_template(&self, exprs: &[Expr], inputs: &[LogicalPlan]) -> Self; + + /// Returns the necessary input columns for this node required to compute + /// the columns in the output schema + /// + /// This is used for projection push-down when DataFusion has determined that + /// only a subset of the output columns of this node are needed by its parents. + /// This API is used to tell DataFusion which, if any, of the input columns are no longer + /// needed. + /// + /// Return `None`, the default, if this information can not be determined. + /// Returns `Some(_)` with the column indices for each child of this node that are + /// needed to compute `output_columns` + fn necessary_children_exprs( + &self, + _output_columns: &[usize], + ) -> Option>> { + None + } } /// Automatically derive UserDefinedLogicalNode to `UserDefinedLogicalNode` @@ -284,6 +320,13 @@ impl UserDefinedLogicalNode for T { Arc::new(self.from_template(exprs, inputs)) } + fn necessary_children_exprs( + &self, + output_columns: &[usize], + ) -> Option>> { + self.necessary_children_exprs(output_columns) + } + fn dyn_hash(&self, state: &mut dyn Hasher) { let mut s = state; self.hash(&mut s); diff --git a/datafusion/optimizer/src/optimize_projections.rs b/datafusion/optimizer/src/optimize_projections.rs index 08ee38f64abd..b942f187c331 100644 --- a/datafusion/optimizer/src/optimize_projections.rs +++ b/datafusion/optimizer/src/optimize_projections.rs @@ -31,7 +31,8 @@ use crate::{OptimizerConfig, OptimizerRule}; use arrow::datatypes::SchemaRef; use datafusion_common::{ - get_required_group_by_exprs_indices, Column, DFSchema, DFSchemaRef, JoinType, Result, + get_required_group_by_exprs_indices, internal_err, Column, DFSchema, DFSchemaRef, + JoinType, Result, }; use datafusion_expr::expr::{Alias, ScalarFunction}; use datafusion_expr::{ @@ -162,14 +163,40 @@ fn optimize_projections( .map(|input| ((0..input.schema().fields().len()).collect_vec(), false)) .collect::>() } + LogicalPlan::Extension(extension) => { + let necessary_children_indices = if let Some(necessary_children_indices) = + extension.node.necessary_children_exprs(indices) + { + necessary_children_indices + } else { + // Requirements from parent cannot be routed down to user defined logical plan safely + return Ok(None); + }; + let children = extension.node.inputs(); + if children.len() != necessary_children_indices.len() { + return internal_err!("Inconsistent length between children and necessary children indices. \ + Make sure `.necessary_children_exprs` implementation of the `UserDefinedLogicalNode` is \ + consistent with actual children length for the node."); + } + // Expressions used by node. + let exprs = plan.expressions(); + children + .into_iter() + .zip(necessary_children_indices) + .map(|(child, necessary_indices)| { + let child_schema = child.schema(); + let child_req_indices = + indices_referred_by_exprs(child_schema, exprs.iter())?; + Ok((merge_slices(&necessary_indices, &child_req_indices), false)) + }) + .collect::>>()? + } LogicalPlan::EmptyRelation(_) | LogicalPlan::RecursiveQuery(_) | LogicalPlan::Statement(_) | LogicalPlan::Values(_) - | LogicalPlan::Extension(_) | LogicalPlan::DescribeTable(_) => { // These operators have no inputs, so stop the optimization process. - // TODO: Add support for `LogicalPlan::Extension`. return Ok(None); } LogicalPlan::Projection(proj) => { @@ -899,21 +926,161 @@ fn is_projection_unnecessary(input: &LogicalPlan, proj_exprs: &[Expr]) -> Result #[cfg(test)] mod tests { + use std::fmt::Formatter; use std::sync::Arc; use crate::optimize_projections::OptimizeProjections; - use crate::test::{assert_optimized_plan_eq, test_table_scan}; + use crate::test::{ + assert_optimized_plan_eq, test_table_scan, test_table_scan_with_name, + }; use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_common::{Result, TableReference}; + use datafusion_common::{Column, DFSchemaRef, JoinType, Result, TableReference}; use datafusion_expr::{ - binary_expr, col, count, lit, logical_plan::builder::LogicalPlanBuilder, not, - table_scan, try_cast, when, Expr, Like, LogicalPlan, Operator, + binary_expr, build_join_schema, col, count, lit, + logical_plan::builder::LogicalPlanBuilder, not, table_scan, try_cast, when, + BinaryExpr, Expr, Extension, Like, LogicalPlan, Operator, + UserDefinedLogicalNodeCore, }; fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(OptimizeProjections::new()), plan, expected) } + #[derive(Debug, Hash, PartialEq, Eq)] + struct NoOpUserDefined { + exprs: Vec, + schema: DFSchemaRef, + input: Arc, + } + + impl NoOpUserDefined { + fn new(schema: DFSchemaRef, input: Arc) -> Self { + Self { + exprs: vec![], + schema, + input, + } + } + + fn with_exprs(mut self, exprs: Vec) -> Self { + self.exprs = exprs; + self + } + } + + impl UserDefinedLogicalNodeCore for NoOpUserDefined { + fn name(&self) -> &str { + "NoOpUserDefined" + } + + fn inputs(&self) -> Vec<&LogicalPlan> { + vec![&self.input] + } + + fn schema(&self) -> &DFSchemaRef { + &self.schema + } + + fn expressions(&self) -> Vec { + self.exprs.clone() + } + + fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result { + write!(f, "NoOpUserDefined") + } + + fn from_template(&self, exprs: &[Expr], inputs: &[LogicalPlan]) -> Self { + Self { + exprs: exprs.to_vec(), + input: Arc::new(inputs[0].clone()), + schema: self.schema.clone(), + } + } + + fn necessary_children_exprs( + &self, + output_columns: &[usize], + ) -> Option>> { + // Since schema is same. Output columns requires their corresponding version in the input columns. + Some(vec![output_columns.to_vec()]) + } + } + + #[derive(Debug, Hash, PartialEq, Eq)] + struct UserDefinedCrossJoin { + exprs: Vec, + schema: DFSchemaRef, + left_child: Arc, + right_child: Arc, + } + + impl UserDefinedCrossJoin { + fn new(left_child: Arc, right_child: Arc) -> Self { + let left_schema = left_child.schema(); + let right_schema = right_child.schema(); + let schema = Arc::new( + build_join_schema(left_schema, right_schema, &JoinType::Inner).unwrap(), + ); + Self { + exprs: vec![], + schema, + left_child, + right_child, + } + } + } + + impl UserDefinedLogicalNodeCore for UserDefinedCrossJoin { + fn name(&self) -> &str { + "UserDefinedCrossJoin" + } + + fn inputs(&self) -> Vec<&LogicalPlan> { + vec![&self.left_child, &self.right_child] + } + + fn schema(&self) -> &DFSchemaRef { + &self.schema + } + + fn expressions(&self) -> Vec { + self.exprs.clone() + } + + fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result { + write!(f, "UserDefinedCrossJoin") + } + + fn from_template(&self, exprs: &[Expr], inputs: &[LogicalPlan]) -> Self { + assert_eq!(inputs.len(), 2); + Self { + exprs: exprs.to_vec(), + left_child: Arc::new(inputs[0].clone()), + right_child: Arc::new(inputs[1].clone()), + schema: self.schema.clone(), + } + } + + fn necessary_children_exprs( + &self, + output_columns: &[usize], + ) -> Option>> { + let left_child_len = self.left_child.schema().fields().len(); + let mut left_reqs = vec![]; + let mut right_reqs = vec![]; + for &out_idx in output_columns { + if out_idx < left_child_len { + left_reqs.push(out_idx); + } else { + // Output indices further than the left_child_len + // comes from right children + right_reqs.push(out_idx - left_child_len) + } + } + Some(vec![left_reqs, right_reqs]) + } + } + #[test] fn merge_two_projection() -> Result<()> { let table_scan = test_table_scan()?; @@ -1192,4 +1359,112 @@ mod tests { \n TableScan: test projection=[a]"; assert_optimized_plan_equal(&plan, expected) } + + // Since only column `a` is referred at the output. Scan should only contain projection=[a]. + // User defined node should be able to propagate necessary expressions by its parent to its child. + #[test] + fn test_user_defined_logical_plan_node() -> Result<()> { + let table_scan = test_table_scan()?; + let custom_plan = LogicalPlan::Extension(Extension { + node: Arc::new(NoOpUserDefined::new( + table_scan.schema().clone(), + Arc::new(table_scan.clone()), + )), + }); + let plan = LogicalPlanBuilder::from(custom_plan) + .project(vec![col("a"), lit(0).alias("d")])? + .build()?; + + let expected = "Projection: test.a, Int32(0) AS d\ + \n NoOpUserDefined\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } + + // Only column `a` is referred at the output. However, User defined node itself uses column `b` + // during its operation. Hence, scan should contain projection=[a, b]. + // User defined node should be able to propagate necessary expressions by its parent, as well as its own + // required expressions. + #[test] + fn test_user_defined_logical_plan_node2() -> Result<()> { + let table_scan = test_table_scan()?; + let exprs = vec![Expr::Column(Column::from_qualified_name("b"))]; + let custom_plan = LogicalPlan::Extension(Extension { + node: Arc::new( + NoOpUserDefined::new( + table_scan.schema().clone(), + Arc::new(table_scan.clone()), + ) + .with_exprs(exprs), + ), + }); + let plan = LogicalPlanBuilder::from(custom_plan) + .project(vec![col("a"), lit(0).alias("d")])? + .build()?; + + let expected = "Projection: test.a, Int32(0) AS d\ + \n NoOpUserDefined\ + \n TableScan: test projection=[a, b]"; + assert_optimized_plan_equal(&plan, expected) + } + + // Only column `a` is referred at the output. However, User defined node itself uses expression `b+c` + // during its operation. Hence, scan should contain projection=[a, b, c]. + // User defined node should be able to propagate necessary expressions by its parent, as well as its own + // required expressions. Expressions doesn't have to be just column. Requirements from complex expressions + // should be propagated also. + #[test] + fn test_user_defined_logical_plan_node3() -> Result<()> { + let table_scan = test_table_scan()?; + let left_expr = Expr::Column(Column::from_qualified_name("b")); + let right_expr = Expr::Column(Column::from_qualified_name("c")); + let binary_expr = Expr::BinaryExpr(BinaryExpr::new( + Box::new(left_expr), + Operator::Plus, + Box::new(right_expr), + )); + let exprs = vec![binary_expr]; + let custom_plan = LogicalPlan::Extension(Extension { + node: Arc::new( + NoOpUserDefined::new( + table_scan.schema().clone(), + Arc::new(table_scan.clone()), + ) + .with_exprs(exprs), + ), + }); + let plan = LogicalPlanBuilder::from(custom_plan) + .project(vec![col("a"), lit(0).alias("d")])? + .build()?; + + let expected = "Projection: test.a, Int32(0) AS d\ + \n NoOpUserDefined\ + \n TableScan: test projection=[a, b, c]"; + assert_optimized_plan_equal(&plan, expected) + } + + // Columns `l.a`, `l.c`, `r.a` is referred at the output. + // User defined node should be able to propagate necessary expressions by its parent, to its children. + // Even if it has multiple children. + // left child should have `projection=[a, c]`, and right side should have `projection=[a]`. + #[test] + fn test_user_defined_logical_plan_node4() -> Result<()> { + let left_table = test_table_scan_with_name("l")?; + let right_table = test_table_scan_with_name("r")?; + let custom_plan = LogicalPlan::Extension(Extension { + node: Arc::new(UserDefinedCrossJoin::new( + Arc::new(left_table.clone()), + Arc::new(right_table.clone()), + )), + }); + let plan = LogicalPlanBuilder::from(custom_plan) + .project(vec![col("l.a"), col("l.c"), col("r.a"), lit(0).alias("d")])? + .build()?; + + let expected = "Projection: l.a, l.c, r.a, Int32(0) AS d\ + \n UserDefinedCrossJoin\ + \n TableScan: l projection=[a, c]\ + \n TableScan: r projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } }