diff --git a/src/rewrite/exploitation.rs b/src/rewrite/exploitation.rs index 159d84b..7310d93 100644 --- a/src/rewrite/exploitation.rs +++ b/src/rewrite/exploitation.rs @@ -41,10 +41,55 @@ use crate::materialized::cast_to_materialized; use super::normal_form::SpjNormalForm; use super::QueryRewriteOptions; +/// Logical rewrite metadata propagated alongside equivalent candidate plans. +#[derive(Debug, Clone, Default, PartialEq, PartialOrd, Eq, Hash)] +pub struct RewriteContext { + root_table_refs: Vec, +} + +impl RewriteContext { + /// Create a new rewrite context from the root table refs visible during rewrite. + pub fn new(root_table_refs: Vec) -> Self { + Self { root_table_refs } + } + + /// Returns the root table refs that produced this rewrite opportunity. + pub fn root_table_refs(&self) -> &[String] { + &self.root_table_refs + } +} + +/// Inputs provided to a cost function when selecting the best candidate plan. +pub struct CostContext<'a> { + candidate_plans: Box + 'a>, + rewrite_context: &'a RewriteContext, +} + +impl<'a> CostContext<'a> { + /// Create a new cost context. + pub fn new( + candidate_plans: Box + 'a>, + rewrite_context: &'a RewriteContext, + ) -> Self { + Self { + candidate_plans, + rewrite_context, + } + } + + /// Consume the context and return the candidate plans iterator. + pub fn into_candidate_plans(self) -> Box + 'a> { + self.candidate_plans + } + + /// Returns rewrite metadata for the current candidate set. + pub fn rewrite_context(&self) -> &RewriteContext { + self.rewrite_context + } +} + /// A cost function. Used to evaluate the best physical plan among multiple equivalent choices. -pub type CostFn = Arc< - dyn for<'a> Fn(Box + 'a>) -> Vec + Send + Sync, ->; +pub type CostFn = Arc Fn(CostContext<'a>) -> Vec + Send + Sync>; /// A logical optimizer that generates candidate logical plans in the form of [`OneOf`] nodes. #[derive(Debug)] @@ -186,9 +231,10 @@ impl TreeNodeRewriter for ViewMatchingRewriter<'_> { } else { Ok(Transformed::new( LogicalPlan::Extension(Extension { - node: Arc::new(OneOf { - branches: Some(node).into_iter().chain(candidates).collect_vec(), - }), + node: Arc::new(OneOf::with_rewrite_context( + Some(node).into_iter().chain(candidates).collect_vec(), + RewriteContext::new(vec![table_reference.to_string()]), + )), }), true, TreeNodeRecursion::Jump, @@ -241,9 +287,9 @@ impl ExtensionPlanner for ViewExploitationPlanner { physical_inputs: &[Arc], _session_state: &SessionState, ) -> Result>> { - if node.as_any().downcast_ref::().is_none() { + let Some(one_of) = node.as_any().downcast_ref::() else { return Ok(None); - } + }; // Compare schemas ignoring nullability differences. // Different table types (FileScanTable, LiveTable, MV) may expose @@ -282,6 +328,7 @@ impl ExtensionPlanner for ViewExploitationPlanner { physical_inputs.to_vec(), None, Arc::clone(&self.cost), + one_of.rewrite_context().clone(), )?))) } } @@ -291,12 +338,29 @@ impl ExtensionPlanner for ViewExploitationPlanner { #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Hash)] pub struct OneOf { branches: Vec, + rewrite_context: RewriteContext, } impl OneOf { /// Create a new OneOf node with the given branches. pub fn new(branches: Vec) -> Self { - Self { branches } + Self::with_rewrite_context(branches, RewriteContext::default()) + } + + /// Create a new OneOf node with the given branches and rewrite context. + pub fn with_rewrite_context( + branches: Vec, + rewrite_context: RewriteContext, + ) -> Self { + Self { + branches, + rewrite_context, + } + } + + /// Returns logical rewrite metadata for this candidate set. + pub fn rewrite_context(&self) -> &RewriteContext { + &self.rewrite_context } } @@ -333,7 +397,10 @@ impl UserDefinedLogicalNodeCore for OneOf { _exprs: Vec, inputs: Vec, ) -> Result { - Ok(Self { branches: inputs }) + Ok(Self { + branches: inputs, + rewrite_context: self.rewrite_context.clone(), + }) } } @@ -349,6 +416,7 @@ pub struct OneOfExec { best: usize, // Cost function to use in optimization cost: CostFn, + rewrite_context: RewriteContext, } impl std::fmt::Debug for OneOfExec { @@ -357,6 +425,7 @@ impl std::fmt::Debug for OneOfExec { .field("candidates", &self.candidates) .field("required_input_ordering", &self.required_input_ordering) .field("best", &self.best) + .field("rewrite_context", &self.rewrite_context) .finish_non_exhaustive() } } @@ -367,6 +436,7 @@ impl OneOfExec { candidates: Vec>, required_input_ordering: Option, cost: CostFn, + rewrite_context: RewriteContext, ) -> Result { if candidates.is_empty() { return Err(DataFusionError::Plan( @@ -374,16 +444,20 @@ impl OneOfExec { )); } - let best = cost(Box::new(candidates.iter().map(|c| c.as_ref()))) - .iter() - .position_min_by_key(|&cost| OrderedFloat(*cost)) - .unwrap(); + let best = cost(CostContext::new( + Box::new(candidates.iter().map(|c| c.as_ref())), + &rewrite_context, + )) + .iter() + .position_min_by_key(|&cost| OrderedFloat(*cost)) + .unwrap(); Ok(Self { candidates, required_input_ordering, best, cost, + rewrite_context, }) } @@ -393,6 +467,11 @@ impl OneOfExec { Arc::clone(&self.candidates[self.best]) } + /// Returns rewrite metadata for this candidate set. + pub fn rewrite_context(&self) -> &RewriteContext { + &self.rewrite_context + } + /// Modify this plan's required input ordering. /// Used for sort pushdown pub fn with_required_input_ordering(self, requirement: Option) -> Self { @@ -444,6 +523,7 @@ impl ExecutionPlan for OneOfExec { children, self.required_input_ordering.clone(), Arc::clone(&self.cost), + self.rewrite_context.clone(), )?)) } @@ -473,7 +553,10 @@ impl ExecutionPlan for OneOfExec { impl DisplayAs for OneOfExec { fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result { - let costs = (self.cost)(Box::new(self.children().iter().map(|arc| arc.as_ref()))); + let costs = (self.cost)(CostContext::new( + Box::new(self.children().iter().map(|arc| arc.as_ref())), + &self.rewrite_context, + )); match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { write!( @@ -595,3 +678,87 @@ mod tests_nullability { assert!(!schemas_equal_ignoring_nullability(&a, &b)); } } + +#[cfg(test)] +mod tests_rewrite_context { + use super::*; + use arrow_schema::Schema; + use datafusion::physical_plan::empty::EmptyExec; + use datafusion_expr::LogicalPlanBuilder; + use std::sync::Mutex; + + #[test] + fn one_of_preserves_rewrite_context_when_rebuilt() { + let plan = LogicalPlanBuilder::empty(false) + .build() + .expect("empty plan"); + let one_of = OneOf::with_rewrite_context( + vec![plan.clone()], + RewriteContext::new(vec!["catalog.schema.root_table".to_string()]), + ); + + let rebuilt = + UserDefinedLogicalNodeCore::with_exprs_and_inputs(&one_of, vec![], vec![plan]) + .expect("rebuild one_of"); + + assert_eq!( + rebuilt.rewrite_context().root_table_refs(), + ["catalog.schema.root_table".to_string()] + ); + } + + #[test] + fn one_of_exec_passes_rewrite_context_to_cost_function() { + let seen = Arc::new(Mutex::new(Vec::::new())); + let seen_clone = Arc::clone(&seen); + let cost: CostFn = Arc::new(move |ctx| { + *seen_clone.lock().expect("lock seen") = + ctx.rewrite_context().root_table_refs().to_vec(); + ctx.into_candidate_plans().map(|_| 1.0).collect() + }); + let context = RewriteContext::new(vec!["catalog.schema.root_table".to_string()]); + let schema = Arc::new(Schema::empty()); + let candidates = vec![ + Arc::new(EmptyExec::new(Arc::clone(&schema))) as Arc, + Arc::new(EmptyExec::new(schema)) as Arc, + ]; + + let exec = + OneOfExec::try_new(candidates, None, cost, context.clone()).expect("one_of exec"); + + assert_eq!(exec.rewrite_context(), &context); + assert_eq!(*seen.lock().expect("lock seen"), context.root_table_refs()); + } + + #[test] + fn one_of_exec_with_new_children_preserves_rewrite_context() { + let cost: CostFn = Arc::new(|ctx| ctx.into_candidate_plans().map(|_| 1.0).collect()); + let context = RewriteContext::new(vec!["catalog.schema.root_table".to_string()]); + let schema = Arc::new(Schema::empty()); + let exec = Arc::new( + OneOfExec::try_new( + vec![ + Arc::new(EmptyExec::new(Arc::clone(&schema))) as Arc, + Arc::new(EmptyExec::new(Arc::clone(&schema))) as Arc, + ], + None, + cost, + context.clone(), + ) + .expect("one_of exec"), + ); + + let rebuilt = exec + .with_new_children(vec![ + Arc::new(EmptyExec::new(Arc::clone(&schema))) as Arc, + Arc::new(EmptyExec::new(schema)) as Arc, + ]) + .expect("rebuild exec"); + let rebuilt = rebuilt + .as_any() + .downcast_ref::() + .expect("expected OneOfExec"); + + assert_eq!(rebuilt.rewrite_context(), &context); + } +}