Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
197 changes: 182 additions & 15 deletions src/rewrite/exploitation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,55 @@ use crate::materialized::cast_to_materialized;
use super::normal_form::SpjNormalForm;
use super::QueryRewriteOptions;

/// Logical rewrite metadata propagated alongside equivalent candidate plans.
#[derive(Debug, Clone, Default, PartialEq, PartialOrd, Eq, Hash)]
pub struct RewriteContext {
root_table_refs: Vec<String>,
}

impl RewriteContext {
/// Create a new rewrite context from the root table refs visible during rewrite.
pub fn new(root_table_refs: Vec<String>) -> Self {
Self { root_table_refs }
}

/// Returns the root table refs that produced this rewrite opportunity.
pub fn root_table_refs(&self) -> &[String] {
&self.root_table_refs
}
}

/// Inputs provided to a cost function when selecting the best candidate plan.
pub struct CostContext<'a> {
candidate_plans: Box<dyn Iterator<Item = &'a dyn ExecutionPlan> + 'a>,
rewrite_context: &'a RewriteContext,
}

impl<'a> CostContext<'a> {
/// Create a new cost context.
pub fn new(
candidate_plans: Box<dyn Iterator<Item = &'a dyn ExecutionPlan> + 'a>,
rewrite_context: &'a RewriteContext,
) -> Self {
Self {
candidate_plans,
rewrite_context,
}
}

/// Consume the context and return the candidate plans iterator.
pub fn into_candidate_plans(self) -> Box<dyn Iterator<Item = &'a dyn ExecutionPlan> + 'a> {
self.candidate_plans
}

/// Returns rewrite metadata for the current candidate set.
pub fn rewrite_context(&self) -> &RewriteContext {
self.rewrite_context
}
}

/// A cost function. Used to evaluate the best physical plan among multiple equivalent choices.
pub type CostFn = Arc<
dyn for<'a> Fn(Box<dyn Iterator<Item = &'a dyn ExecutionPlan> + 'a>) -> Vec<f64> + Send + Sync,
>;
pub type CostFn = Arc<dyn for<'a> Fn(CostContext<'a>) -> Vec<f64> + Send + Sync>;

/// A logical optimizer that generates candidate logical plans in the form of [`OneOf`] nodes.
#[derive(Debug)]
Expand Down Expand Up @@ -186,9 +231,10 @@ impl TreeNodeRewriter for ViewMatchingRewriter<'_> {
} else {
Ok(Transformed::new(
LogicalPlan::Extension(Extension {
node: Arc::new(OneOf {
branches: Some(node).into_iter().chain(candidates).collect_vec(),
}),
node: Arc::new(OneOf::with_rewrite_context(
Some(node).into_iter().chain(candidates).collect_vec(),
RewriteContext::new(vec![table_reference.to_string()]),
)),
}),
true,
TreeNodeRecursion::Jump,
Expand Down Expand Up @@ -241,9 +287,9 @@ impl ExtensionPlanner for ViewExploitationPlanner {
physical_inputs: &[Arc<dyn ExecutionPlan>],
_session_state: &SessionState,
) -> Result<Option<Arc<dyn ExecutionPlan>>> {
if node.as_any().downcast_ref::<OneOf>().is_none() {
let Some(one_of) = node.as_any().downcast_ref::<OneOf>() else {
return Ok(None);
}
};

// Compare schemas ignoring nullability differences.
// Different table types (FileScanTable, LiveTable, MV) may expose
Expand Down Expand Up @@ -282,6 +328,7 @@ impl ExtensionPlanner for ViewExploitationPlanner {
physical_inputs.to_vec(),
None,
Arc::clone(&self.cost),
one_of.rewrite_context().clone(),
)?)))
}
}
Expand All @@ -291,12 +338,29 @@ impl ExtensionPlanner for ViewExploitationPlanner {
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Hash)]
pub struct OneOf {
branches: Vec<LogicalPlan>,
rewrite_context: RewriteContext,
}

impl OneOf {
/// Create a new OneOf node with the given branches.
pub fn new(branches: Vec<LogicalPlan>) -> Self {
Self { branches }
Self::with_rewrite_context(branches, RewriteContext::default())
}

/// Create a new OneOf node with the given branches and rewrite context.
pub fn with_rewrite_context(
branches: Vec<LogicalPlan>,
rewrite_context: RewriteContext,
) -> Self {
Self {
branches,
rewrite_context,
}
}

/// Returns logical rewrite metadata for this candidate set.
pub fn rewrite_context(&self) -> &RewriteContext {
&self.rewrite_context
}
}

Expand Down Expand Up @@ -333,7 +397,10 @@ impl UserDefinedLogicalNodeCore for OneOf {
_exprs: Vec<datafusion::prelude::Expr>,
inputs: Vec<LogicalPlan>,
) -> Result<Self> {
Ok(Self { branches: inputs })
Ok(Self {
branches: inputs,
rewrite_context: self.rewrite_context.clone(),
})
}
}

Expand All @@ -349,6 +416,7 @@ pub struct OneOfExec {
best: usize,
// Cost function to use in optimization
cost: CostFn,
rewrite_context: RewriteContext,
}

impl std::fmt::Debug for OneOfExec {
Expand All @@ -357,6 +425,7 @@ impl std::fmt::Debug for OneOfExec {
.field("candidates", &self.candidates)
.field("required_input_ordering", &self.required_input_ordering)
.field("best", &self.best)
.field("rewrite_context", &self.rewrite_context)
.finish_non_exhaustive()
}
}
Expand All @@ -367,23 +436,28 @@ impl OneOfExec {
candidates: Vec<Arc<dyn ExecutionPlan>>,
required_input_ordering: Option<OrderingRequirements>,
cost: CostFn,
rewrite_context: RewriteContext,
) -> Result<Self> {
if candidates.is_empty() {
return Err(DataFusionError::Plan(
"can't create OneOfExec with empty children".to_string(),
));
}

let best = cost(Box::new(candidates.iter().map(|c| c.as_ref())))
.iter()
.position_min_by_key(|&cost| OrderedFloat(*cost))
.unwrap();
let best = cost(CostContext::new(
Box::new(candidates.iter().map(|c| c.as_ref())),
&rewrite_context,
))
.iter()
.position_min_by_key(|&cost| OrderedFloat(*cost))
.unwrap();

Ok(Self {
candidates,
required_input_ordering,
best,
cost,
rewrite_context,
})
}

Expand All @@ -393,6 +467,11 @@ impl OneOfExec {
Arc::clone(&self.candidates[self.best])
}

/// Returns rewrite metadata for this candidate set.
pub fn rewrite_context(&self) -> &RewriteContext {
&self.rewrite_context
}

/// Modify this plan's required input ordering.
/// Used for sort pushdown
pub fn with_required_input_ordering(self, requirement: Option<OrderingRequirements>) -> Self {
Expand Down Expand Up @@ -444,6 +523,7 @@ impl ExecutionPlan for OneOfExec {
children,
self.required_input_ordering.clone(),
Arc::clone(&self.cost),
self.rewrite_context.clone(),
)?))
}

Expand Down Expand Up @@ -473,7 +553,10 @@ impl ExecutionPlan for OneOfExec {

impl DisplayAs for OneOfExec {
fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
let costs = (self.cost)(Box::new(self.children().iter().map(|arc| arc.as_ref())));
let costs = (self.cost)(CostContext::new(
Box::new(self.children().iter().map(|arc| arc.as_ref())),
&self.rewrite_context,
));
match t {
DisplayFormatType::Default | DisplayFormatType::Verbose => {
write!(
Expand Down Expand Up @@ -595,3 +678,87 @@ mod tests_nullability {
assert!(!schemas_equal_ignoring_nullability(&a, &b));
}
}

#[cfg(test)]
mod tests_rewrite_context {
use super::*;
use arrow_schema::Schema;
use datafusion::physical_plan::empty::EmptyExec;
use datafusion_expr::LogicalPlanBuilder;
use std::sync::Mutex;

#[test]
fn one_of_preserves_rewrite_context_when_rebuilt() {
let plan = LogicalPlanBuilder::empty(false)
.build()
.expect("empty plan");
let one_of = OneOf::with_rewrite_context(
vec![plan.clone()],
RewriteContext::new(vec!["catalog.schema.root_table".to_string()]),
);

let rebuilt =
UserDefinedLogicalNodeCore::with_exprs_and_inputs(&one_of, vec![], vec![plan])
.expect("rebuild one_of");

assert_eq!(
rebuilt.rewrite_context().root_table_refs(),
["catalog.schema.root_table".to_string()]
);
}

#[test]
fn one_of_exec_passes_rewrite_context_to_cost_function() {
let seen = Arc::new(Mutex::new(Vec::<String>::new()));
let seen_clone = Arc::clone(&seen);
let cost: CostFn = Arc::new(move |ctx| {
*seen_clone.lock().expect("lock seen") =
ctx.rewrite_context().root_table_refs().to_vec();
ctx.into_candidate_plans().map(|_| 1.0).collect()
});
let context = RewriteContext::new(vec!["catalog.schema.root_table".to_string()]);
let schema = Arc::new(Schema::empty());
let candidates = vec![
Arc::new(EmptyExec::new(Arc::clone(&schema))) as Arc<dyn ExecutionPlan>,
Arc::new(EmptyExec::new(schema)) as Arc<dyn ExecutionPlan>,
];

let exec =
OneOfExec::try_new(candidates, None, cost, context.clone()).expect("one_of exec");

assert_eq!(exec.rewrite_context(), &context);
assert_eq!(*seen.lock().expect("lock seen"), context.root_table_refs());
}

#[test]
fn one_of_exec_with_new_children_preserves_rewrite_context() {
let cost: CostFn = Arc::new(|ctx| ctx.into_candidate_plans().map(|_| 1.0).collect());
let context = RewriteContext::new(vec!["catalog.schema.root_table".to_string()]);
let schema = Arc::new(Schema::empty());
let exec = Arc::new(
OneOfExec::try_new(
vec![
Arc::new(EmptyExec::new(Arc::clone(&schema))) as Arc<dyn ExecutionPlan>,
Arc::new(EmptyExec::new(Arc::clone(&schema))) as Arc<dyn ExecutionPlan>,
],
None,
cost,
context.clone(),
)
.expect("one_of exec"),
);

let rebuilt = exec
.with_new_children(vec![
Arc::new(EmptyExec::new(Arc::clone(&schema))) as Arc<dyn ExecutionPlan>,
Arc::new(EmptyExec::new(schema)) as Arc<dyn ExecutionPlan>,
])
.expect("rebuild exec");
let rebuilt = rebuilt
.as_any()
.downcast_ref::<OneOfExec>()
.expect("expected OneOfExec");

assert_eq!(rebuilt.rewrite_context(), &context);
}
}
Loading