Add PushPartialAggregationIntoTableScan optimization
Allows for pushdown of aggregations through unions into table scans.
tontinton committed Dec 25, 2024
1 parent bddec49 commit 74179dc
Showing 7 changed files with 70,185 additions and 0 deletions.
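
For context: the plan shape this rule targets arises when the optimizer splits an aggregation into PARTIAL and FINAL steps, and PushPartialAggregationThroughExchange (registered in the same rule set, see the PlanOptimizers change below) duplicates the PARTIAL step below the exchange implementing a UNION ALL. Given an illustrative query (table and column names here are invented, not taken from this commit):

SELECT orderstatus, count(*)
FROM (
    SELECT orderstatus FROM orders_2023
    UNION ALL
    SELECT orderstatus FROM orders_2024)
GROUP BY orderstatus

the plan ends up roughly like this (a sketch, not actual EXPLAIN output):

Aggregation[FINAL, keys = orderstatus]
  Exchange[GATHER]
    Aggregation[PARTIAL, keys = orderstatus]   <-- matched by the new rule
      TableScan[orders_2023]
    Aggregation[PARTIAL, keys = orderstatus]   <-- matched by the new rule
      TableScan[orders_2024]

Each PARTIAL aggregation sits directly on a table scan, where the new rule can offer it to the connector; the FINAL step still runs in the engine and merges the partial results.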
core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java
@@ -138,6 +138,7 @@
import io.trino.sql.planner.iterative.rule.PruneWindowColumns;
import io.trino.sql.planner.iterative.rule.PushAggregationIntoTableScan;
import io.trino.sql.planner.iterative.rule.PushAggregationThroughOuterJoin;
import io.trino.sql.planner.iterative.rule.PushPartialAggregationIntoTableScan;
import io.trino.sql.planner.iterative.rule.PushCastIntoRow;
import io.trino.sql.planner.iterative.rule.PushDistinctLimitIntoTableScan;
import io.trino.sql.planner.iterative.rule.PushDownDereferenceThroughFilter;
@@ -1003,6 +1004,7 @@ public PlanOptimizers(
        ImmutableSet.<Rule<?>>builder()
                .addAll(new PushPartialAggregationThroughJoin().rules())
                .add(new PushPartialAggregationThroughExchange(plannerContext),
                        new PushPartialAggregationIntoTableScan(plannerContext),
                        new PruneJoinColumns(),
                        new PruneJoinChildrenColumns(),
                        new RemoveRedundantIdentityProjections())
core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPartialAggregationIntoTableScan.java
@@ -0,0 +1,252 @@
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableBiMap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.Session;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.metadata.TableHandle;
import io.trino.spi.connector.AggregateFunction;
import io.trino.spi.connector.AggregationApplicationResult;
import io.trino.spi.connector.Assignment;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.SortItem;
import io.trino.spi.expression.ConnectorExpression;
import io.trino.spi.expression.Variable;
import io.trino.spi.function.BoundSignature;
import io.trino.spi.predicate.TupleDomain;
import io.trino.sql.PlannerContext;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.Reference;
import io.trino.sql.planner.ConnectorExpressionTranslator;
import io.trino.sql.planner.OrderingScheme;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.TableScanNode;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.stream.IntStream;

import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static io.trino.SystemSessionProperties.isAllowPushdownIntoConnectors;
import static io.trino.matching.Capture.newCapture;
import static io.trino.sql.ir.optimizer.IrExpressionOptimizer.newOptimizer;
import static io.trino.sql.planner.iterative.rule.Rules.deriveTableStatisticsForPushdown;
import static io.trino.sql.planner.plan.Patterns.Aggregation.step;
import static io.trino.sql.planner.plan.Patterns.aggregation;
import static io.trino.sql.planner.plan.Patterns.source;
import static io.trino.sql.planner.plan.Patterns.tableScan;
import static java.util.Objects.requireNonNull;

public class PushPartialAggregationIntoTableScan
        implements Rule<AggregationNode>
{
    private static final Capture<TableScanNode> TABLE_SCAN = newCapture();

    // Matches a PARTIAL-step aggregation directly above a table scan,
    // with at most one grouping set and no aggregation masks.
    private static final Pattern<AggregationNode> PATTERN =
            aggregation()
                    .with(step().equalTo(AggregationNode.Step.PARTIAL))
                    // skip arguments that are, for instance, lambda expressions
                    .matching(PushPartialAggregationIntoTableScan::allArgumentsAreSimpleReferences)
                    .matching(node -> node.getGroupingSets().getGroupingSetCount() <= 1)
                    .matching(PushPartialAggregationIntoTableScan::hasNoMasks)
                    .with(source().matching(tableScan().capturedAs(TABLE_SCAN)));

    private final PlannerContext plannerContext;

    public PushPartialAggregationIntoTableScan(PlannerContext plannerContext)
    {
        this.plannerContext = requireNonNull(plannerContext, "plannerContext is null");
    }

    @Override
    public Pattern<AggregationNode> getPattern()
    {
        return PATTERN;
    }

    @Override
    public boolean isEnabled(Session session)
    {
        return isAllowPushdownIntoConnectors(session);
    }

    private static boolean allArgumentsAreSimpleReferences(AggregationNode node)
    {
        return node.getAggregations()
                .values().stream()
                .flatMap(aggregation -> aggregation.getArguments().stream())
                .allMatch(Reference.class::isInstance);
    }

    private static boolean hasNoMasks(AggregationNode node)
    {
        return node.getAggregations()
                .values().stream()
                .allMatch(aggregation -> aggregation.getMask().isEmpty());
    }

    @Override
    public Result apply(AggregationNode node, Captures captures, Context context)
    {
        return pushPartialAggregationIntoTableScan(plannerContext, context, node, captures.get(TABLE_SCAN), node.getAggregations(), node.getGroupingSets().getGroupingKeys())
                .map(Rule.Result::ofPlanNode)
                .orElseGet(Rule.Result::empty);
    }

    public static Optional<PlanNode> pushPartialAggregationIntoTableScan(
            PlannerContext plannerContext,
            Context context,
            PlanNode aggregationNode,
            TableScanNode tableScan,
            Map<Symbol, AggregationNode.Aggregation> aggregations,
            List<Symbol> groupingKeys)
    {
        Session session = context.getSession();

        if (groupingKeys.isEmpty() && aggregations.isEmpty()) {
            // Global aggregation with no aggregate functions. There is no point pushing this down into the connector.
            return Optional.empty();
        }

        Map<String, ColumnHandle> assignments = tableScan.getAssignments()
                .entrySet().stream()
                .collect(toImmutableMap(entry -> entry.getKey().name(), Entry::getValue));

        List<Entry<Symbol, AggregationNode.Aggregation>> aggregationsList = ImmutableList.copyOf(aggregations.entrySet());

        List<AggregateFunction> aggregateFunctions = aggregationsList.stream()
                .map(Entry::getValue)
                .map(PushPartialAggregationIntoTableScan::toAggregateFunction)
                .collect(toImmutableList());

        List<Symbol> aggregationOutputSymbols = aggregationsList.stream()
                .map(Entry::getKey)
                .collect(toImmutableList());

        List<ColumnHandle> groupByColumns = groupingKeys.stream()
                .map(groupByColumn -> assignments.get(groupByColumn.name()))
                .collect(toImmutableList());

        Optional<AggregationApplicationResult<TableHandle>> aggregationPushdownResult = plannerContext.getMetadata().applyAggregation(
                session,
                tableScan.getTable(),
                aggregateFunctions,
                assignments,
                ImmutableList.of(groupByColumns));

        if (aggregationPushdownResult.isEmpty()) {
            return Optional.empty();
        }

        AggregationApplicationResult<TableHandle> result = aggregationPushdownResult.get();

        // The new scan outputs are the original scan outputs plus the symbols created for the
        // connector-provided assignments (grouping columns and aggregation results).
        ImmutableList.Builder<Symbol> newScanOutputs = ImmutableList.builder();
        newScanOutputs.addAll(tableScan.getOutputSymbols());

        ImmutableBiMap.Builder<Symbol, ColumnHandle> newScanAssignments = ImmutableBiMap.builder();
        newScanAssignments.putAll(tableScan.getAssignments());

        Map<String, Symbol> variableMappings = new HashMap<>();

        for (Assignment assignment : result.getAssignments()) {
            Symbol symbol = context.getSymbolAllocator().newSymbol(assignment.getVariable(), assignment.getType());

            newScanOutputs.add(symbol);
            newScanAssignments.put(symbol, assignment.getColumn());
            variableMappings.put(assignment.getVariable(), symbol);
        }

        List<Expression> newProjections = result.getProjections().stream()
                .map(expression -> {
                    Expression translated = ConnectorExpressionTranslator.translate(session, expression, plannerContext, variableMappings);
                    // ConnectorExpressionTranslator may or may not preserve the optimized form of expressions during the round-trip.
                    // Avoid a potential optimizer loop by ensuring the expression is optimized.
                    return newOptimizer(plannerContext).process(translated, session, ImmutableMap.of()).orElse(translated);
                })
                .collect(toImmutableList());

        verify(aggregationOutputSymbols.size() == newProjections.size());

        Assignments.Builder assignmentBuilder = Assignments.builder();
        IntStream.range(0, aggregationOutputSymbols.size())
                .forEach(index -> assignmentBuilder.put(aggregationOutputSymbols.get(index), newProjections.get(index)));

        ImmutableBiMap<Symbol, ColumnHandle> scanAssignments = newScanAssignments.build();
        ImmutableBiMap<ColumnHandle, Symbol> columnHandleToSymbol = scanAssignments.inverse();
        // The projection must output the grouping keys as well as the aggregations, so add each grouping key as a symbol reference.
        groupingKeys.forEach(groupBySymbol -> {
            // If the connector returned a new mapping from the old column handle to a new one, the group by needs to point to
            // the new column handle's symbol reference; otherwise it continues to point at the old column handle.
            ColumnHandle originalColumnHandle = assignments.get(groupBySymbol.name());
            ColumnHandle groupByColumnHandle = result.getGroupingColumnMapping().getOrDefault(originalColumnHandle, originalColumnHandle);
            assignmentBuilder.put(groupBySymbol, columnHandleToSymbol.get(groupByColumnHandle).toSymbolReference());
        });

        return Optional.of(
                new ProjectNode(
                        context.getIdAllocator().getNextId(),
                        new TableScanNode(
                                context.getIdAllocator().getNextId(),
                                result.getHandle(),
                                newScanOutputs.build(),
                                scanAssignments,
                                TupleDomain.all(),
                                deriveTableStatisticsForPushdown(context.getStatsProvider(), session, result.isPrecalculateStatistics(), aggregationNode),
                                tableScan.isUpdateTarget(),
                                tableScan.getUseConnectorNodePartitioning()),
                        assignmentBuilder.build()));
    }

    private static AggregateFunction toAggregateFunction(AggregationNode.Aggregation aggregation)
    {
        BoundSignature signature = aggregation.getResolvedFunction().signature();

        ImmutableList.Builder<ConnectorExpression> arguments = ImmutableList.builder();
        for (int i = 0; i < aggregation.getArguments().size(); i++) {
            Reference argument = (Reference) aggregation.getArguments().get(i);
            arguments.add(new Variable(argument.name(), signature.getArgumentTypes().get(i)));
        }

        Optional<OrderingScheme> orderingScheme = aggregation.getOrderingScheme();
        Optional<List<SortItem>> sortBy = orderingScheme.map(OrderingScheme::toSortItems);

        Optional<ConnectorExpression> filter = aggregation.getFilter()
                .map(symbol -> new Variable(symbol.name(), symbol.type()));

        return new AggregateFunction(
                signature.getName().getFunctionName(),
                signature.getReturnType(),
                arguments.build(),
                sortBy.orElse(ImmutableList.of()),
                aggregation.isDistinct(),
                filter);
    }
}
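
For connector authors, a minimal sketch of the other side of this handshake. The rule fires only when Metadata.applyAggregation, which delegates to the connector's ConnectorMetadata.applyAggregation, returns a non-empty result. ExampleMetadata, ExampleTableHandle, and ExampleColumnHandle below are invented for illustration; the SPI types and the applyAggregation signature are the real ones the rule calls above. A connector that can compute a bare count(*) natively might respond like this:

import io.trino.spi.connector.AggregateFunction;
import io.trino.spi.connector.AggregationApplicationResult;
import io.trino.spi.connector.Assignment;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.ConnectorMetadata;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.connector.ConnectorTableHandle;
import io.trino.spi.expression.ConnectorExpression;
import io.trino.spi.expression.Variable;
import io.trino.spi.type.Type;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;

public class ExampleMetadata
        implements ConnectorMetadata
{
    // Invented handle types; a real connector defines its own.
    record ExampleColumnHandle(String name, Type type)
            implements ColumnHandle {}

    record ExampleTableHandle(List<ColumnHandle> groupingColumns, List<Assignment> aggregations)
            implements ConnectorTableHandle {}

    @Override
    public Optional<AggregationApplicationResult<ConnectorTableHandle>> applyAggregation(
            ConnectorSession session,
            ConnectorTableHandle handle,
            List<AggregateFunction> aggregates,
            Map<String, ColumnHandle> assignments,
            List<List<ColumnHandle>> groupingSets)
    {
        List<ConnectorExpression> projections = new ArrayList<>();
        List<Assignment> resultAssignments = new ArrayList<>();

        for (AggregateFunction aggregate : aggregates) {
            // This sketch only handles a bare count(*); reject anything else so the engine keeps the aggregation.
            if (!aggregate.getFunctionName().equals("count")
                    || !aggregate.getArguments().isEmpty()
                    || aggregate.isDistinct()
                    || aggregate.getFilter().isPresent()) {
                return Optional.empty();
            }
            // Expose the pushed-down count as a synthesized column of the new table handle.
            ExampleColumnHandle countColumn = new ExampleColumnHandle("count_" + resultAssignments.size(), aggregate.getOutputType());
            projections.add(new Variable(countColumn.name(), countColumn.type()));
            resultAssignments.add(new Assignment(countColumn.name(), countColumn, countColumn.type()));
        }

        ConnectorTableHandle newHandle = new ExampleTableHandle(groupingSets.get(0), resultAssignments);
        // No grouping-column remapping and no precalculated statistics in this sketch.
        return Optional.of(new AggregationApplicationResult<>(newHandle, projections, resultAssignments, Map.of(), false));
    }
}

Returning Optional.empty() for anything unsupported is always safe: the plan is left unchanged and the engine computes the partial aggregation itself. Because the pushed-down step is PARTIAL, the connector may return per-split results; the engine's FINAL step combines them (for count, by summing).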