Skip to content

Commit ef4ca7e

Browse files
authored
Improve documentation for ExprVisitor, port simple uses to new walking function (#4916)
1 parent ba9fc12 commit ef4ca7e

File tree

4 files changed

+93
-126
lines changed

4 files changed

+93
-126
lines changed

datafusion/expr/src/expr_visitor.rs

Lines changed: 47 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,46 @@ pub enum Recursion<V: ExpressionVisitor> {
3333
Stop(V),
3434
}
3535

36-
/// Encode the traversal of an expression tree. When passed to
37-
/// `Expr::accept`, `ExpressionVisitor::visit` is invoked
38-
/// recursively on all nodes of an expression tree. See the comments
39-
/// on `Expr::accept` for details on its use
36+
/// Implements the [visitor
37+
/// pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for recursively walking [`Expr`]s.
38+
///
39+
/// [`ExpressionVisitor`] allows keeping the algorithms
40+
/// separate from the code to traverse the structure of the `Expr`
41+
/// tree and makes it easier to add new types of expressions and
42+
/// algorithms.
43+
///
44+
/// When passed to [`Expr::accept`], [`ExpressionVisitor::pre_visit`]
45+
/// and [`ExpressionVisitor::post_visit`] are invoked recursively
46+
/// on all nodes of an expression tree.
47+
///
48+
///
49+
/// For an expression tree such as
50+
/// ```text
51+
/// BinaryExpr (GT)
52+
/// left: Column("foo")
53+
/// right: Column("bar")
54+
/// ```
55+
///
56+
/// The nodes are visited using the following order
57+
/// ```text
58+
/// pre_visit(BinaryExpr(GT))
59+
/// pre_visit(Column("foo"))
60+
/// post_visit(Column("foo"))
61+
/// pre_visit(Column("bar"))
62+
/// post_visit(Column("bar"))
63+
/// post_visit(BinaryExpr(GT))
64+
/// ```
65+
///
66+
/// If an [`Err`] result is returned, recursion is stopped
67+
/// immediately.
68+
///
69+
/// If [`Recursion::Stop`] is returned on a call to pre_visit, no
70+
/// children of that expression are visited, nor is post_visit
71+
/// called on that expression.
72+
///
73+
/// # See Also:
74+
/// * [`Expr::accept`] to drive a visitor through an [`Expr`]
75+
/// * [`inspect_expr_pre`]: For visiting [`Expr`]s using functions
4076
pub trait ExpressionVisitor<E: ExprVisitable = Expr>: Sized {
4177
/// Invoked before any children of `expr` are visited.
4278
fn pre_visit(self, expr: &E) -> Result<Recursion<Self>>
@@ -58,37 +94,7 @@ pub trait ExprVisitable: Sized {
5894

5995
impl ExprVisitable for Expr {
6096
/// Performs a depth first walk of an expression and
61-
/// its children, calling [`ExpressionVisitor::pre_visit`] and
62-
/// `visitor.post_visit`.
63-
///
64-
/// Implements the [visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern) to
65-
/// separate expression algorithms from the structure of the
66-
/// `Expr` tree and make it easier to add new types of expressions
67-
/// and algorithms that walk the tree.
68-
///
69-
/// For an expression tree such as
70-
/// ```text
71-
/// BinaryExpr (GT)
72-
/// left: Column("foo")
73-
/// right: Column("bar")
74-
/// ```
75-
///
76-
/// The nodes are visited using the following order
77-
/// ```text
78-
/// pre_visit(BinaryExpr(GT))
79-
/// pre_visit(Column("foo"))
80-
/// post_visit(Column("foo"))
81-
/// pre_visit(Column("bar"))
82-
/// post_visit(Column("bar"))
83-
/// post_visit(BinaryExpr(GT))
84-
/// ```
85-
///
86-
/// If an Err result is returned, recursion is stopped immediately
87-
///
88-
/// If `Recursion::Stop` is returned on a call to pre_visit, no
89-
/// children of that expression are visited, nor is post_visit
90-
/// called on that expression
91-
///
97+
/// its children. See [`ExpressionVisitor`] for more details.
9298
fn accept<V: ExpressionVisitor>(&self, visitor: V) -> Result<V> {
9399
let visitor = match visitor.pre_visit(self)? {
94100
Recursion::Continue(visitor) => visitor,
@@ -223,6 +229,7 @@ impl ExprVisitable for Expr {
223229

224230
struct VisitorAdapter<F, E> {
225231
f: F,
232+
// Store returned error as it may not be a DataFusionError
226233
err: std::result::Result<(), E>,
227234
}
228235

@@ -242,10 +249,12 @@ where
242249
}
243250
}
244251

245-
/// Conveniece function for using a mutable function as an expression visiitor
252+
/// Recursively inspect an [`Expr`] and all its children.
246253
///
247-
/// TODO make this match names in physical plan
248-
pub fn walk_expr_down<F, E>(expr: &Expr, f: F) -> std::result::Result<(), E>
254+
/// Performs a pre-visit traversal by recursively calling `f(expr)` on
255+
/// `expr`, and then on all its children. See [`ExpressionVisitor`]
256+
/// for more details and more options to control the walk.
257+
pub fn inspect_expr_pre<F, E>(expr: &Expr, f: F) -> std::result::Result<(), E>
249258
where
250259
F: FnMut(&Expr) -> std::result::Result<(), E>,
251260
{

datafusion/expr/src/logical_plan/plan.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
// under the License.
1717

1818
use crate::expr_rewriter::{ExprRewritable, ExprRewriter};
19-
use crate::expr_visitor::walk_expr_down;
19+
use crate::expr_visitor::inspect_expr_pre;
2020
///! Logical plan types
2121
use crate::logical_plan::builder::validate_unique_names;
2222
use crate::logical_plan::display::{GraphvizVisitor, IndentVisitor};
@@ -580,7 +580,7 @@ impl LogicalPlan {
580580
{
581581
self.inspect_expressions(|expr| {
582582
// recursively look for subqueries
583-
walk_expr_down(expr, |expr| {
583+
inspect_expr_pre(expr, |expr| {
584584
match expr {
585585
Expr::Exists { subquery, .. }
586586
| Expr::InSubquery { subquery, .. }
@@ -1219,7 +1219,8 @@ pub struct DropView {
12191219
pub schema: DFSchemaRef,
12201220
}
12211221

1222-
/// Set a Variable's value -- value in [`ConfigOptions`]
1222+
/// Set a Variable's value -- value in
1223+
/// [`ConfigOptions`](datafusion_common::config::ConfigOptions)
12231224
#[derive(Clone)]
12241225
pub struct SetVariable {
12251226
/// The variable name

datafusion/expr/src/utils.rs

Lines changed: 29 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919
2020
use crate::expr::{Sort, WindowFunction};
2121
use crate::expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion};
22-
use crate::expr_visitor::{ExprVisitable, ExpressionVisitor, Recursion};
22+
use crate::expr_visitor::{
23+
inspect_expr_pre, ExprVisitable, ExpressionVisitor, Recursion,
24+
};
2325
use crate::logical_plan::builder::build_join_schema;
2426
use crate::logical_plan::{
2527
Aggregate, Analyze, CreateMemoryTable, CreateView, Distinct, Extension, Filter, Join,
@@ -83,20 +85,16 @@ pub fn grouping_set_to_exprlist(group_expr: &[Expr]) -> Result<Vec<Expr>> {
8385
}
8486
}
8587

86-
/// Recursively walk an expression tree, collecting the unique set of column names
88+
/// Recursively walk an expression tree, collecting the unique set of columns
8789
/// referenced in the expression
88-
struct ColumnNameVisitor<'a> {
89-
accum: &'a mut HashSet<Column>,
90-
}
91-
92-
impl ExpressionVisitor for ColumnNameVisitor<'_> {
93-
fn pre_visit(self, expr: &Expr) -> Result<Recursion<Self>> {
90+
pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet<Column>) -> Result<()> {
91+
inspect_expr_pre(expr, |expr| {
9492
match expr {
9593
Expr::Column(qc) => {
96-
self.accum.insert(qc.clone());
94+
accum.insert(qc.clone());
9795
}
9896
Expr::ScalarVariable(_, var_names) => {
99-
self.accum.insert(Column::from_name(var_names.join(".")));
97+
accum.insert(Column::from_name(var_names.join(".")));
10098
}
10199
Expr::Alias(_, _)
102100
| Expr::Literal(_)
@@ -134,15 +132,8 @@ impl ExpressionVisitor for ColumnNameVisitor<'_> {
134132
| Expr::GetIndexedField { .. }
135133
| Expr::Placeholder { .. } => {}
136134
}
137-
Ok(Recursion::Continue(self))
138-
}
139-
}
140-
141-
/// Recursively walk an expression tree, collecting the unique set of columns
142-
/// referenced in the expression
143-
pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet<Column>) -> Result<()> {
144-
expr.accept(ColumnNameVisitor { accum })?;
145-
Ok(())
135+
Ok(())
136+
})
146137
}
147138

148139
/// Resolves an `Expr::Wildcard` to a collection of `Expr::Column`'s.
@@ -861,27 +852,17 @@ pub fn find_column_exprs(exprs: &[Expr]) -> Vec<Expr> {
861852
.collect()
862853
}
863854

864-
/// Recursively find all columns referenced by an expression
865-
#[derive(Debug, Default)]
866-
struct ColumnCollector {
867-
exprs: Vec<Column>,
868-
}
869-
870-
impl ExpressionVisitor for ColumnCollector {
871-
fn pre_visit(mut self, expr: &Expr) -> Result<Recursion<Self>> {
855+
pub(crate) fn find_columns_referenced_by_expr(e: &Expr) -> Vec<Column> {
856+
let mut exprs = vec![];
857+
inspect_expr_pre(e, |expr| {
872858
if let Expr::Column(c) = expr {
873-
self.exprs.push(c.clone())
859+
exprs.push(c.clone())
874860
}
875-
Ok(Recursion::Continue(self))
876-
}
877-
}
878-
879-
pub(crate) fn find_columns_referenced_by_expr(e: &Expr) -> Vec<Column> {
861+
Ok(()) as Result<()>
862+
})
880863
// As the closure passed to `inspect_expr_pre` above always returns Ok, this
881864
// "can't" error
882-
let ColumnCollector { exprs } = e
883-
.accept(ColumnCollector::default())
884-
.expect("Unexpected error");
865+
.expect("Unexpected error");
885866
exprs
886867
}
887868

@@ -898,43 +879,26 @@ pub fn expr_as_column_expr(expr: &Expr, plan: &LogicalPlan) -> Result<Expr> {
898879

899880
/// Recursively walk an expression tree, collecting the column indexes
900881
/// referenced in the expression
901-
struct ColumnIndexesCollector<'a> {
902-
schema: &'a DFSchemaRef,
903-
indexes: Vec<usize>,
904-
}
905-
906-
impl ExpressionVisitor for ColumnIndexesCollector<'_> {
907-
fn pre_visit(mut self, expr: &Expr) -> Result<Recursion<Self>>
908-
where
909-
Self: ExpressionVisitor,
910-
{
882+
pub(crate) fn find_column_indexes_referenced_by_expr(
883+
e: &Expr,
884+
schema: &DFSchemaRef,
885+
) -> Vec<usize> {
886+
let mut indexes = vec![];
887+
inspect_expr_pre(e, |expr| {
911888
match expr {
912889
Expr::Column(qc) => {
913-
if let Ok(idx) = self.schema.index_of_column(qc) {
914-
self.indexes.push(idx);
890+
if let Ok(idx) = schema.index_of_column(qc) {
891+
indexes.push(idx);
915892
}
916893
}
917894
Expr::Literal(_) => {
918-
self.indexes.push(std::usize::MAX);
895+
indexes.push(std::usize::MAX);
919896
}
920897
_ => {}
921898
}
922-
Ok(Recursion::Continue(self))
923-
}
924-
}
925-
926-
pub(crate) fn find_column_indexes_referenced_by_expr(
927-
e: &Expr,
928-
schema: &DFSchemaRef,
929-
) -> Vec<usize> {
930-
// As the `ExpressionVisitor` impl above always returns Ok, this
931-
// "can't" error
932-
let ColumnIndexesCollector { indexes, .. } = e
933-
.accept(ColumnIndexesCollector {
934-
schema,
935-
indexes: vec![],
936-
})
937-
.expect("Unexpected error");
899+
Ok(()) as Result<()>
900+
})
901+
.unwrap();
938902
indexes
939903
}
940904

datafusion/optimizer/src/utils.rs

Lines changed: 13 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ use datafusion_common::Result;
2222
use datafusion_common::{plan_err, Column, DFSchemaRef};
2323
use datafusion_expr::expr::{BinaryExpr, Sort};
2424
use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter};
25-
use datafusion_expr::expr_visitor::{ExprVisitable, ExpressionVisitor, Recursion};
25+
use datafusion_expr::expr_visitor::inspect_expr_pre;
2626
use datafusion_expr::{
2727
and, col,
2828
logical_plan::{Filter, LogicalPlan},
@@ -232,28 +232,21 @@ pub fn unalias(expr: Expr) -> Expr {
232232
///
233233
/// A PlanError if a disjunction is found
234234
pub fn verify_not_disjunction(predicates: &[&Expr]) -> Result<()> {
235-
struct DisjunctionVisitor {}
236-
237-
impl ExpressionVisitor for DisjunctionVisitor {
238-
fn pre_visit(self, expr: &Expr) -> Result<Recursion<Self>> {
239-
match expr {
240-
Expr::BinaryExpr(BinaryExpr {
241-
left: _,
242-
op: Operator::Or,
243-
right: _,
244-
}) => {
245-
plan_err!("Optimizing disjunctions not supported!")
246-
}
247-
_ => Ok(Recursion::Continue(self)),
235+
// recursively check for disallowed predicates in expr
236+
fn check(expr: &&Expr) -> Result<()> {
237+
inspect_expr_pre(expr, |expr| match expr {
238+
Expr::BinaryExpr(BinaryExpr {
239+
left: _,
240+
op: Operator::Or,
241+
right: _,
242+
}) => {
243+
plan_err!("Optimizing disjunctions not supported!")
248244
}
249-
}
250-
}
251-
252-
for predicate in predicates.iter() {
253-
predicate.accept(DisjunctionVisitor {})?;
245+
_ => Ok(()),
246+
})
254247
}
255248

256-
Ok(())
249+
predicates.iter().try_for_each(check)
257250
}
258251

259252
/// returns a new [LogicalPlan] that wraps `plan` in a [LogicalPlan::Filter] with

0 commit comments

Comments
 (0)