Skip to content

Commit 07fe23f

Browse files
jayzhan211skyzh
andauthored
support simple/cross lateral joins (#16015)
* support simple lateral joins Signed-off-by: Alex Chi Z <[email protected]> * fix explain test Signed-off-by: Alex Chi Z <[email protected]> * plan scalar agg correctly Signed-off-by: Alex Chi Z <[email protected]> * add uncorrelated query tests Signed-off-by: Alex Chi Z <[email protected]> * fix clippy + fmt Signed-off-by: Alex Chi Z <[email protected]> * make rule matching faster Signed-off-by: Alex Chi Z <[email protected]> * revert build_join visibility Signed-off-by: Alex Chi Z <[email protected]> * revert find plan outer column changes Signed-off-by: Alex Chi Z <[email protected]> * remove clone * address comment --------- Signed-off-by: Alex Chi Z <[email protected]> Co-authored-by: Alex Chi Z <[email protected]>
1 parent c74faee commit 07fe23f

File tree

6 files changed

+256
-0
lines changed

6 files changed

+256
-0
lines changed

datafusion/optimizer/src/decorrelate.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,9 @@ pub struct PullUpCorrelatedExpr {
7171
pub collected_count_expr_map: HashMap<LogicalPlan, ExprResultMap>,
7272
/// pull up having expr, which must be evaluated after the Join
7373
pub pull_up_having_expr: Option<Expr>,
74+
/// whether we have converted a scalar aggregation into a group aggregation. When unnesting
75+
/// lateral joins, we need to produce a left outer join in such cases.
76+
pub pulled_up_scalar_agg: bool,
7477
}
7578

7679
impl Default for PullUpCorrelatedExpr {
@@ -91,6 +94,7 @@ impl PullUpCorrelatedExpr {
9194
need_handle_count_bug: false,
9295
collected_count_expr_map: HashMap::new(),
9396
pull_up_having_expr: None,
97+
pulled_up_scalar_agg: false,
9498
}
9599
}
96100

@@ -313,6 +317,11 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr {
313317
missing_exprs.push(un_matched_row);
314318
}
315319
}
320+
if aggregate.group_expr.is_empty() {
321+
// TODO: how do we handle the case where we have pulled multiple aggregations? For example,
322+
// a group agg with a scalar agg as child.
323+
self.pulled_up_scalar_agg = true;
324+
}
316325
let new_plan = LogicalPlanBuilder::from((*aggregate.input).clone())
317326
.aggregate(missing_exprs, aggregate.aggr_expr.to_vec())?
318327
.build()?;
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! [`DecorrelateLateralJoin`] decorrelates logical plans produced by lateral joins.
19+
20+
use std::collections::BTreeSet;
21+
22+
use crate::decorrelate::PullUpCorrelatedExpr;
23+
use crate::optimizer::ApplyOrder;
24+
use crate::{OptimizerConfig, OptimizerRule};
25+
use datafusion_expr::{lit, Join};
26+
27+
use datafusion_common::tree_node::{
28+
Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
29+
};
30+
use datafusion_common::Result;
31+
use datafusion_expr::logical_plan::JoinType;
32+
use datafusion_expr::utils::conjunction;
33+
use datafusion_expr::{LogicalPlan, LogicalPlanBuilder};
34+
35+
/// Optimizer rule that rewrites lateral joins into regular joins.
///
/// The rule itself is stateless: all of the work happens in its
/// `OptimizerRule::rewrite` implementation, which inspects one join node at a
/// time.
#[derive(Default, Debug)]
pub struct DecorrelateLateralJoin {}

impl DecorrelateLateralJoin {
    /// Creates a new [`DecorrelateLateralJoin`] rule.
    ///
    /// Equivalent to [`DecorrelateLateralJoin::default`].
    pub fn new() -> Self {
        Self::default()
    }
}
45+
46+
impl OptimizerRule for DecorrelateLateralJoin {
47+
fn supports_rewrite(&self) -> bool {
48+
true
49+
}
50+
51+
fn rewrite(
52+
&self,
53+
plan: LogicalPlan,
54+
_config: &dyn OptimizerConfig,
55+
) -> Result<Transformed<LogicalPlan>> {
56+
// Find cross joins with outer column references on the right side (i.e., the apply operator).
57+
let LogicalPlan::Join(join) = plan else {
58+
return Ok(Transformed::no(plan));
59+
};
60+
61+
rewrite_internal(join)
62+
}
63+
64+
fn name(&self) -> &str {
65+
"decorrelate_lateral_join"
66+
}
67+
68+
fn apply_order(&self) -> Option<ApplyOrder> {
69+
Some(ApplyOrder::TopDown)
70+
}
71+
}
72+
73+
// Build the decorrelated join based on the original lateral join query. For now, we only support cross/inner
74+
// lateral joins.
75+
fn rewrite_internal(join: Join) -> Result<Transformed<LogicalPlan>> {
76+
if join.join_type != JoinType::Inner {
77+
return Ok(Transformed::no(LogicalPlan::Join(join)));
78+
}
79+
80+
match join.right.apply_with_subqueries(|p| {
81+
// TODO: support outer joins
82+
if p.contains_outer_reference() {
83+
Ok(TreeNodeRecursion::Stop)
84+
} else {
85+
Ok(TreeNodeRecursion::Continue)
86+
}
87+
})? {
88+
TreeNodeRecursion::Stop => {}
89+
TreeNodeRecursion::Continue => {
90+
// The left side contains outer references, we need to decorrelate it.
91+
return Ok(Transformed::new(
92+
LogicalPlan::Join(join),
93+
false,
94+
TreeNodeRecursion::Jump,
95+
));
96+
}
97+
TreeNodeRecursion::Jump => {
98+
unreachable!("")
99+
}
100+
}
101+
102+
let LogicalPlan::Subquery(subquery) = join.right.as_ref() else {
103+
return Ok(Transformed::no(LogicalPlan::Join(join)));
104+
};
105+
106+
if join.join_type != JoinType::Inner {
107+
return Ok(Transformed::no(LogicalPlan::Join(join)));
108+
}
109+
let subquery_plan = subquery.subquery.as_ref();
110+
let mut pull_up = PullUpCorrelatedExpr::new().with_need_handle_count_bug(true);
111+
let rewritten_subquery = subquery_plan.clone().rewrite(&mut pull_up).data()?;
112+
if !pull_up.can_pull_up {
113+
return Ok(Transformed::no(LogicalPlan::Join(join)));
114+
}
115+
116+
let mut all_correlated_cols = BTreeSet::new();
117+
pull_up
118+
.correlated_subquery_cols_map
119+
.values()
120+
.for_each(|cols| all_correlated_cols.extend(cols.clone()));
121+
let join_filter_opt = conjunction(pull_up.join_filters);
122+
let join_filter = match join_filter_opt {
123+
Some(join_filter) => join_filter,
124+
None => lit(true),
125+
};
126+
// -- inner join but the right side always has one row, we need to rewrite it to a left join
127+
// SELECT * FROM t0, LATERAL (SELECT sum(v1) FROM t1 WHERE t0.v0 = t1.v0);
128+
// -- inner join but the right side number of rows is related to the filter (join) condition, so keep inner join.
129+
// SELECT * FROM t0, LATERAL (SELECT * FROM t1 WHERE t0.v0 = t1.v0);
130+
let new_plan = LogicalPlanBuilder::from(join.left)
131+
.join_on(
132+
rewritten_subquery,
133+
if pull_up.pulled_up_scalar_agg {
134+
JoinType::Left
135+
} else {
136+
JoinType::Inner
137+
},
138+
Some(join_filter),
139+
)?
140+
.build()?;
141+
// TODO: handle count(*) bug
142+
Ok(Transformed::new(new_plan, true, TreeNodeRecursion::Jump))
143+
}

datafusion/optimizer/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
pub mod analyzer;
4141
pub mod common_subexpr_eliminate;
4242
pub mod decorrelate;
43+
pub mod decorrelate_lateral_join;
4344
pub mod decorrelate_predicate_subquery;
4445
pub mod eliminate_cross_join;
4546
pub mod eliminate_duplicated_expr;

datafusion/optimizer/src/optimizer.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ use datafusion_common::{internal_err, DFSchema, DataFusionError, HashSet, Result
3333
use datafusion_expr::logical_plan::LogicalPlan;
3434

3535
use crate::common_subexpr_eliminate::CommonSubexprEliminate;
36+
use crate::decorrelate_lateral_join::DecorrelateLateralJoin;
3637
use crate::decorrelate_predicate_subquery::DecorrelatePredicateSubquery;
3738
use crate::eliminate_cross_join::EliminateCrossJoin;
3839
use crate::eliminate_duplicated_expr::EliminateDuplicatedExpr;
@@ -226,6 +227,7 @@ impl Optimizer {
226227
Arc::new(EliminateJoin::new()),
227228
Arc::new(DecorrelatePredicateSubquery::new()),
228229
Arc::new(ScalarSubqueryToJoin::new()),
230+
Arc::new(DecorrelateLateralJoin::new()),
229231
Arc::new(ExtractEquijoinPredicate::new()),
230232
Arc::new(EliminateDuplicatedExpr::new()),
231233
Arc::new(EliminateFilter::new()),

datafusion/sqllogictest/test_files/explain.slt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,7 @@ logical_plan after replace_distinct_aggregate SAME TEXT AS ABOVE
183183
logical_plan after eliminate_join SAME TEXT AS ABOVE
184184
logical_plan after decorrelate_predicate_subquery SAME TEXT AS ABOVE
185185
logical_plan after scalar_subquery_to_join SAME TEXT AS ABOVE
186+
logical_plan after decorrelate_lateral_join SAME TEXT AS ABOVE
186187
logical_plan after extract_equijoin_predicate SAME TEXT AS ABOVE
187188
logical_plan after eliminate_duplicated_expr SAME TEXT AS ABOVE
188189
logical_plan after eliminate_filter SAME TEXT AS ABOVE
@@ -204,6 +205,7 @@ logical_plan after replace_distinct_aggregate SAME TEXT AS ABOVE
204205
logical_plan after eliminate_join SAME TEXT AS ABOVE
205206
logical_plan after decorrelate_predicate_subquery SAME TEXT AS ABOVE
206207
logical_plan after scalar_subquery_to_join SAME TEXT AS ABOVE
208+
logical_plan after decorrelate_lateral_join SAME TEXT AS ABOVE
207209
logical_plan after extract_equijoin_predicate SAME TEXT AS ABOVE
208210
logical_plan after eliminate_duplicated_expr SAME TEXT AS ABOVE
209211
logical_plan after eliminate_filter SAME TEXT AS ABOVE

datafusion/sqllogictest/test_files/join.slt.part

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1404,3 +1404,102 @@ set datafusion.execution.target_partitions = 4;
14041404

14051405
statement ok
14061406
set datafusion.optimizer.repartition_joins = false;
1407+
1408+
statement ok
1409+
CREATE TABLE t1(v0 BIGINT, v1 BIGINT);
1410+
1411+
statement ok
1412+
CREATE TABLE t0(v0 BIGINT, v1 BIGINT);
1413+
1414+
statement ok
1415+
INSERT INTO t0(v0, v1) VALUES (1, 1), (1, 2), (3, 3), (4, 4);
1416+
1417+
statement ok
1418+
INSERT INTO t1(v0, v1) VALUES (1, 1), (3, 2), (3, 5);
1419+
1420+
query TT
1421+
explain SELECT *
1422+
FROM t0,
1423+
LATERAL (SELECT sum(v1) FROM t1 WHERE t0.v0 = t1.v0);
1424+
----
1425+
logical_plan
1426+
01)Projection: t0.v0, t0.v1, sum(t1.v1)
1427+
02)--Left Join: t0.v0 = t1.v0
1428+
03)----TableScan: t0 projection=[v0, v1]
1429+
04)----Projection: sum(t1.v1), t1.v0
1430+
05)------Aggregate: groupBy=[[t1.v0]], aggr=[[sum(t1.v1)]]
1431+
06)--------TableScan: t1 projection=[v0, v1]
1432+
physical_plan
1433+
01)ProjectionExec: expr=[v0@1 as v0, v1@2 as v1, sum(t1.v1)@0 as sum(t1.v1)]
1434+
02)--CoalesceBatchesExec: target_batch_size=8192
1435+
03)----HashJoinExec: mode=CollectLeft, join_type=Right, on=[(v0@1, v0@0)], projection=[sum(t1.v1)@0, v0@2, v1@3]
1436+
04)------CoalescePartitionsExec
1437+
05)--------ProjectionExec: expr=[sum(t1.v1)@1 as sum(t1.v1), v0@0 as v0]
1438+
06)----------AggregateExec: mode=FinalPartitioned, gby=[v0@0 as v0], aggr=[sum(t1.v1)]
1439+
07)------------CoalesceBatchesExec: target_batch_size=8192
1440+
08)--------------RepartitionExec: partitioning=Hash([v0@0], 4), input_partitions=4
1441+
09)----------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
1442+
10)------------------AggregateExec: mode=Partial, gby=[v0@0 as v0], aggr=[sum(t1.v1)]
1443+
11)--------------------DataSourceExec: partitions=1, partition_sizes=[1]
1444+
12)------DataSourceExec: partitions=1, partition_sizes=[1]
1445+
1446+
query III
1447+
SELECT *
1448+
FROM t0,
1449+
LATERAL (SELECT sum(v1) FROM t1 WHERE t0.v0 = t1.v0);
1450+
----
1451+
1 1 1
1452+
1 2 1
1453+
3 3 7
1454+
4 4 NULL
1455+
1456+
query TT
1457+
explain SELECT * FROM t0, LATERAL (SELECT * FROM t1 WHERE t0.v0 = t1.v0);
1458+
----
1459+
logical_plan
1460+
01)Inner Join: t0.v0 = t1.v0
1461+
02)--TableScan: t0 projection=[v0, v1]
1462+
03)--TableScan: t1 projection=[v0, v1]
1463+
physical_plan
1464+
01)CoalesceBatchesExec: target_batch_size=8192
1465+
02)--HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(v0@0, v0@0)]
1466+
03)----DataSourceExec: partitions=1, partition_sizes=[1]
1467+
04)----DataSourceExec: partitions=1, partition_sizes=[1]
1468+
1469+
query IIII
1470+
SELECT * FROM t0, LATERAL (SELECT * FROM t1 WHERE t0.v0 = t1.v0);
1471+
----
1472+
1 1 1 1
1473+
1 2 1 1
1474+
3 3 3 2
1475+
3 3 3 5
1476+
1477+
query III
1478+
SELECT * FROM t0, LATERAL (SELECT 1);
1479+
----
1480+
1 1 1
1481+
1 2 1
1482+
3 3 1
1483+
4 4 1
1484+
1485+
query IIII
1486+
SELECT * FROM t0, LATERAL (SELECT * FROM t1 WHERE t1.v0 = 1);
1487+
----
1488+
1 1 1 1
1489+
1 2 1 1
1490+
3 3 1 1
1491+
4 4 1 1
1492+
1493+
query IIII
1494+
SELECT * FROM t0 JOIN LATERAL (SELECT * FROM t1 WHERE t1.v0 = 1) on true;
1495+
----
1496+
1 1 1 1
1497+
1 2 1 1
1498+
3 3 1 1
1499+
4 4 1 1
1500+
1501+
statement ok
1502+
drop table t1;
1503+
1504+
statement ok
1505+
drop table t0;

0 commit comments

Comments
 (0)