Skip to content

Commit 74179dc

Browse files
committed
Add PushPartialAggregationIntoTableScan optimization
Allows for predicate pushdown of aggregation through unions.
1 parent bddec49 commit 74179dc

File tree

7 files changed

+70185
-0
lines changed

7 files changed

+70185
-0
lines changed

core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java

+2
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@
138138
import io.trino.sql.planner.iterative.rule.PruneWindowColumns;
139139
import io.trino.sql.planner.iterative.rule.PushAggregationIntoTableScan;
140140
import io.trino.sql.planner.iterative.rule.PushAggregationThroughOuterJoin;
141+
import io.trino.sql.planner.iterative.rule.PushPartialAggregationIntoTableScan;
141142
import io.trino.sql.planner.iterative.rule.PushCastIntoRow;
142143
import io.trino.sql.planner.iterative.rule.PushDistinctLimitIntoTableScan;
143144
import io.trino.sql.planner.iterative.rule.PushDownDereferenceThroughFilter;
@@ -1003,6 +1004,7 @@ public PlanOptimizers(
10031004
ImmutableSet.<Rule<?>>builder()
10041005
.addAll(new PushPartialAggregationThroughJoin().rules())
10051006
.add(new PushPartialAggregationThroughExchange(plannerContext),
1007+
new PushPartialAggregationIntoTableScan(plannerContext),
10061008
new PruneJoinColumns(),
10071009
new PruneJoinChildrenColumns(),
10081010
new RemoveRedundantIdentityProjections())
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,252 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
package io.trino.sql.planner.iterative.rule;
15+
16+
import com.google.common.collect.ImmutableBiMap;
17+
import com.google.common.collect.ImmutableList;
18+
import com.google.common.collect.ImmutableMap;
19+
import io.trino.Session;
20+
import io.trino.matching.Capture;
21+
import io.trino.matching.Captures;
22+
import io.trino.matching.Pattern;
23+
import io.trino.metadata.TableHandle;
24+
import io.trino.spi.connector.AggregateFunction;
25+
import io.trino.spi.connector.AggregationApplicationResult;
26+
import io.trino.spi.connector.Assignment;
27+
import io.trino.spi.connector.ColumnHandle;
28+
import io.trino.spi.connector.SortItem;
29+
import io.trino.spi.expression.ConnectorExpression;
30+
import io.trino.spi.expression.Variable;
31+
import io.trino.spi.function.BoundSignature;
32+
import io.trino.spi.predicate.TupleDomain;
33+
import io.trino.sql.PlannerContext;
34+
import io.trino.sql.ir.Expression;
35+
import io.trino.sql.ir.Reference;
36+
import io.trino.sql.planner.ConnectorExpressionTranslator;
37+
import io.trino.sql.planner.OrderingScheme;
38+
import io.trino.sql.planner.Symbol;
39+
import io.trino.sql.planner.iterative.Rule;
40+
import io.trino.sql.planner.plan.AggregationNode;
41+
import io.trino.sql.planner.plan.Assignments;
42+
import io.trino.sql.planner.plan.PlanNode;
43+
import io.trino.sql.planner.plan.ProjectNode;
44+
import io.trino.sql.planner.plan.TableScanNode;
45+
46+
import java.util.HashMap;
47+
import java.util.List;
48+
import java.util.Map;
49+
import java.util.Map.Entry;
50+
import java.util.Optional;
51+
import java.util.stream.IntStream;
52+
53+
import static com.google.common.base.Verify.verify;
54+
import static com.google.common.collect.ImmutableList.toImmutableList;
55+
import static com.google.common.collect.ImmutableMap.toImmutableMap;
56+
import static io.trino.SystemSessionProperties.isAllowPushdownIntoConnectors;
57+
import static io.trino.matching.Capture.newCapture;
58+
import static io.trino.sql.ir.optimizer.IrExpressionOptimizer.newOptimizer;
59+
import static io.trino.sql.planner.iterative.rule.Rules.deriveTableStatisticsForPushdown;
60+
import static io.trino.sql.planner.plan.Patterns.Aggregation.step;
61+
import static io.trino.sql.planner.plan.Patterns.aggregation;
62+
import static io.trino.sql.planner.plan.Patterns.source;
63+
import static io.trino.sql.planner.plan.Patterns.tableScan;
64+
import static java.util.Objects.requireNonNull;
65+
66+
public class PushPartialAggregationIntoTableScan
67+
implements Rule<AggregationNode>
68+
{
69+
private static final Capture<TableScanNode> TABLE_SCAN = newCapture();
70+
71+
private static final Pattern<AggregationNode> PATTERN =
72+
aggregation()
73+
.with(step().equalTo(AggregationNode.Step.PARTIAL))
74+
// skip arguments that are, for instance, lambda expressions
75+
.matching(PushPartialAggregationIntoTableScan::allArgumentsAreSimpleReferences)
76+
.matching(node -> node.getGroupingSets().getGroupingSetCount() <= 1)
77+
.matching(PushPartialAggregationIntoTableScan::hasNoMasks)
78+
.with(source().matching(tableScan().capturedAs(TABLE_SCAN)));
79+
80+
private final PlannerContext plannerContext;
81+
82+
public PushPartialAggregationIntoTableScan(PlannerContext plannerContext)
83+
{
84+
this.plannerContext = requireNonNull(plannerContext, "plannerContext is null");
85+
}
86+
87+
@Override
88+
public Pattern<AggregationNode> getPattern()
89+
{
90+
return PATTERN;
91+
}
92+
93+
@Override
94+
public boolean isEnabled(Session session)
95+
{
96+
return isAllowPushdownIntoConnectors(session);
97+
}
98+
99+
private static boolean allArgumentsAreSimpleReferences(AggregationNode node)
100+
{
101+
return node.getAggregations()
102+
.values().stream()
103+
.flatMap(aggregation -> aggregation.getArguments().stream())
104+
.allMatch(Reference.class::isInstance);
105+
}
106+
107+
private static boolean hasNoMasks(AggregationNode node)
108+
{
109+
return node.getAggregations()
110+
.values().stream()
111+
.allMatch(aggregation -> aggregation.getMask().isEmpty());
112+
}
113+
114+
@Override
115+
public Result apply(AggregationNode node, Captures captures, Context context)
116+
{
117+
return PushPartialAggregationIntoTableScan(plannerContext, context, node, captures.get(TABLE_SCAN), node.getAggregations(), node.getGroupingSets().getGroupingKeys())
118+
.map(Rule.Result::ofPlanNode)
119+
.orElseGet(Rule.Result::empty);
120+
}
121+
122+
public static Optional<PlanNode> PushPartialAggregationIntoTableScan(
123+
PlannerContext plannerContext,
124+
Context context,
125+
PlanNode aggregationNode,
126+
TableScanNode tableScan,
127+
Map<Symbol, AggregationNode.Aggregation> aggregations,
128+
List<Symbol> groupingKeys)
129+
{
130+
Session session = context.getSession();
131+
132+
if (groupingKeys.isEmpty() && aggregations.isEmpty()) {
133+
// Global aggregation with no aggregate functions. No point to push this down into connector.
134+
return Optional.empty();
135+
}
136+
137+
Map<String, ColumnHandle> assignments = tableScan.getAssignments()
138+
.entrySet().stream()
139+
.collect(toImmutableMap(entry -> entry.getKey().name(), Entry::getValue));
140+
141+
List<Entry<Symbol, AggregationNode.Aggregation>> aggregationsList = ImmutableList.copyOf(aggregations.entrySet());
142+
143+
List<AggregateFunction> aggregateFunctions = aggregationsList.stream()
144+
.map(Entry::getValue)
145+
.map(PushPartialAggregationIntoTableScan::toAggregateFunction)
146+
.collect(toImmutableList());
147+
148+
List<Symbol> aggregationOutputSymbols = aggregationsList.stream()
149+
.map(Entry::getKey)
150+
.collect(toImmutableList());
151+
152+
List<ColumnHandle> groupByColumns = groupingKeys.stream()
153+
.map(groupByColumn -> assignments.get(groupByColumn.name()))
154+
.collect(toImmutableList());
155+
156+
Optional<AggregationApplicationResult<TableHandle>> aggregationPushdownResult = plannerContext.getMetadata().applyAggregation(
157+
session,
158+
tableScan.getTable(),
159+
aggregateFunctions,
160+
assignments,
161+
ImmutableList.of(groupByColumns));
162+
163+
if (aggregationPushdownResult.isEmpty()) {
164+
return Optional.empty();
165+
}
166+
167+
AggregationApplicationResult<TableHandle> result = aggregationPushdownResult.get();
168+
169+
// The new scan outputs should be the symbols associated with grouping columns plus the symbols associated with aggregations.
170+
ImmutableList.Builder<Symbol> newScanOutputs = ImmutableList.builder();
171+
newScanOutputs.addAll(tableScan.getOutputSymbols());
172+
173+
ImmutableBiMap.Builder<Symbol, ColumnHandle> newScanAssignments = ImmutableBiMap.builder();
174+
newScanAssignments.putAll(tableScan.getAssignments());
175+
176+
Map<String, Symbol> variableMappings = new HashMap<>();
177+
178+
for (Assignment assignment : result.getAssignments()) {
179+
Symbol symbol = context.getSymbolAllocator().newSymbol(assignment.getVariable(), assignment.getType());
180+
181+
newScanOutputs.add(symbol);
182+
newScanAssignments.put(symbol, assignment.getColumn());
183+
variableMappings.put(assignment.getVariable(), symbol);
184+
}
185+
186+
List<Expression> newProjections = result.getProjections().stream()
187+
.map(expression -> {
188+
Expression translated = ConnectorExpressionTranslator.translate(session, expression, plannerContext, variableMappings);
189+
// ConnectorExpressionTranslator may or may not preserve optimized form of expressions during round-trip. Avoid potential optimizer loop
190+
// by ensuring expression is optimized.
191+
return newOptimizer(plannerContext).process(translated, session, ImmutableMap.of()).orElse(translated);
192+
})
193+
.collect(toImmutableList());
194+
195+
verify(aggregationOutputSymbols.size() == newProjections.size());
196+
197+
Assignments.Builder assignmentBuilder = Assignments.builder();
198+
IntStream.range(0, aggregationOutputSymbols.size())
199+
.forEach(index -> assignmentBuilder.put(aggregationOutputSymbols.get(index), newProjections.get(index)));
200+
201+
ImmutableBiMap<Symbol, ColumnHandle> scanAssignments = newScanAssignments.build();
202+
ImmutableBiMap<ColumnHandle, Symbol> columnHandleToSymbol = scanAssignments.inverse();
203+
// projections assignmentBuilder should have both agg and group by so we add all the group bys as symbol references
204+
groupingKeys
205+
.forEach(groupBySymbol -> {
206+
// if the connector returned a new mapping from oldColumnHandle to newColumnHandle, groupBy needs to point to
207+
// new columnHandle's symbol reference, otherwise it will continue pointing at oldColumnHandle.
208+
ColumnHandle originalColumnHandle = assignments.get(groupBySymbol.name());
209+
ColumnHandle groupByColumnHandle = result.getGroupingColumnMapping().getOrDefault(originalColumnHandle, originalColumnHandle);
210+
assignmentBuilder.put(groupBySymbol, columnHandleToSymbol.get(groupByColumnHandle).toSymbolReference());
211+
});
212+
213+
return Optional.of(
214+
new ProjectNode(
215+
context.getIdAllocator().getNextId(),
216+
new TableScanNode(
217+
context.getIdAllocator().getNextId(),
218+
result.getHandle(),
219+
newScanOutputs.build(),
220+
scanAssignments,
221+
TupleDomain.all(),
222+
deriveTableStatisticsForPushdown(context.getStatsProvider(), session, result.isPrecalculateStatistics(), aggregationNode),
223+
tableScan.isUpdateTarget(),
224+
tableScan.getUseConnectorNodePartitioning()),
225+
assignmentBuilder.build()));
226+
}
227+
228+
private static AggregateFunction toAggregateFunction(AggregationNode.Aggregation aggregation)
229+
{
230+
BoundSignature signature = aggregation.getResolvedFunction().signature();
231+
232+
ImmutableList.Builder<ConnectorExpression> arguments = ImmutableList.builder();
233+
for (int i = 0; i < aggregation.getArguments().size(); i++) {
234+
Reference argument = (Reference) aggregation.getArguments().get(i);
235+
arguments.add(new Variable(argument.name(), signature.getArgumentTypes().get(i)));
236+
}
237+
238+
Optional<OrderingScheme> orderingScheme = aggregation.getOrderingScheme();
239+
Optional<List<SortItem>> sortBy = orderingScheme.map(OrderingScheme::toSortItems);
240+
241+
Optional<ConnectorExpression> filter = aggregation.getFilter()
242+
.map(symbol -> new Variable(symbol.name(), symbol.type()));
243+
244+
return new AggregateFunction(
245+
signature.getName().getFunctionName(),
246+
signature.getReturnType(),
247+
arguments.build(),
248+
sortBy.orElse(ImmutableList.of()),
249+
aggregation.isDistinct(),
250+
filter);
251+
}
252+
}

0 commit comments

Comments
 (0)