Skip to content

Commit 5517c1e

Browse files
authored
Merge group fields for aggregate if having dependent group fields (opensearch-project#4703)
* [Enhancement]Merge group fields for aggregate if having dependent group fields Signed-off-by: Heng Qian <[email protected]> * fix CI Signed-off-by: Heng Qian <[email protected]> * Fix windows UT Signed-off-by: Heng Qian <[email protected]> --------- Signed-off-by: Heng Qian <[email protected]>
1 parent 373b394 commit 5517c1e

File tree

13 files changed

+454
-72
lines changed

13 files changed

+454
-72
lines changed

core/src/main/java/org/opensearch/sql/calcite/plan/OpenSearchRules.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,13 @@
1010
import org.apache.calcite.plan.RelOptRule;
1111

1212
public class OpenSearchRules {
13-
private static final PPLAggregateConvertRule AGGREGATE_CONVERT_RULE =
13+
public static final PPLAggregateConvertRule AGGREGATE_CONVERT_RULE =
1414
PPLAggregateConvertRule.Config.SUM_CONVERTER.toRule();
15+
public static final PPLAggGroupMergeRule AGG_GROUP_MERGE_RULE =
16+
PPLAggGroupMergeRule.Config.GROUP_MERGE.toRule();
1517

1618
public static final List<RelOptRule> OPEN_SEARCH_OPT_RULES =
17-
ImmutableList.of(AGGREGATE_CONVERT_RULE);
19+
ImmutableList.of(AGGREGATE_CONVERT_RULE, AGG_GROUP_MERGE_RULE);
1820

1921
// prevent instantiation
2022
private OpenSearchRules() {}
Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.sql.calcite.plan;
7+
8+
import java.util.ArrayList;
9+
import java.util.Collection;
10+
import java.util.List;
11+
import java.util.Set;
12+
import java.util.stream.Collectors;
13+
import java.util.stream.IntStream;
14+
import org.apache.calcite.plan.RelOptRuleCall;
15+
import org.apache.calcite.plan.RelRule;
16+
import org.apache.calcite.rel.logical.LogicalAggregate;
17+
import org.apache.calcite.rel.logical.LogicalProject;
18+
import org.apache.calcite.rex.RexCall;
19+
import org.apache.calcite.rex.RexInputRef;
20+
import org.apache.calcite.rex.RexNode;
21+
import org.apache.calcite.rex.RexUtil;
22+
import org.apache.calcite.sql.SqlKind;
23+
import org.apache.calcite.tools.RelBuilder;
24+
import org.apache.calcite.util.ImmutableBitSet;
25+
import org.apache.calcite.util.mapping.Mapping;
26+
import org.apache.calcite.util.mapping.Mappings;
27+
import org.apache.commons.lang3.tuple.Pair;
28+
import org.immutables.value.Value;
29+
import org.opensearch.sql.calcite.utils.CalciteUtils;
30+
31+
/**
32+
* Planner rule that merges multiple agg group fields into a single one, on which all other group
33+
* fields depend. e.g.
34+
*
35+
* <p>stats ... by a, f1(a), f2(a) -> stats ... by a | eval `f1(a)` = f1(a), `f2(a)` = f2(a)
36+
*
37+
* <p>TODO: this rule could be expanded further for more cases: 1. support multiple base group
38+
* fields, e.g. stats ... by a, f1(a), b, f2(b), f3(a, b) -> stats ... by a, b | eval `f1(a)` =
39+
* f1(a), `f2(b)` = f2(b), `f3(a, b)` = f3(a, b) 2. support no base fields, e.g. stats ... by f1(a),
40+
* f2(a) -> stats ... by a | eval `f1(a)` = f1(a), `f2(a)` = f2(a) | fields - a Note that one of
41+
* these UDFs' output must have equivalent cardinality as `a`.
42+
*/
43+
@Value.Enclosing
44+
public class PPLAggGroupMergeRule extends RelRule<PPLAggGroupMergeRule.Config> {
45+
46+
/** Creates a PPLAggGroupMergeRule with the given rule configuration. */
47+
protected PPLAggGroupMergeRule(Config config) {
48+
super(config);
49+
}
50+
51+
@Override
52+
public void onMatch(RelOptRuleCall call) {
53+
if (call.rels.length == 2) {
54+
final LogicalAggregate aggregate = call.rel(0);
55+
final LogicalProject project = call.rel(1);
56+
apply(call, aggregate, project);
57+
} else {
58+
throw new AssertionError(
59+
String.format(
60+
"The length of rels should be %s but got %s",
61+
this.operands.size(), call.rels.length));
62+
}
63+
}
64+
65+
public void apply(RelOptRuleCall call, LogicalAggregate aggregate, LogicalProject project) {
66+
List<Integer> groupSet = aggregate.getGroupSet().asList();
67+
List<RexNode> groupNodes =
68+
groupSet.stream().map(group -> project.getProjects().get(group)).toList();
69+
Pair<List<Integer>, List<Integer>> baseFieldsAndOthers =
70+
CalciteUtils.partition(
71+
groupSet, i -> project.getProjects().get(i).getKind() == SqlKind.INPUT_REF);
72+
List<Integer> baseGroupList = baseFieldsAndOthers.getLeft();
73+
// TODO: support more base fields in the future.
74+
if (baseGroupList.size() != 1) return;
75+
Integer baseGroupField = baseGroupList.get(0);
76+
RexInputRef baseGroupRef = (RexInputRef) project.getProjects().get(baseGroupField);
77+
List<Integer> otherGroupList = baseFieldsAndOthers.getRight();
78+
boolean allDependOnBaseField =
79+
otherGroupList.stream()
80+
.map(i -> project.getProjects().get(i))
81+
.allMatch(node -> isDependentField(node, List.of(baseGroupRef)));
82+
if (!allDependOnBaseField) return;
83+
84+
final RelBuilder relBuilder = call.builder();
85+
relBuilder.push(project);
86+
87+
relBuilder.aggregate(
88+
relBuilder.groupKey(ImmutableBitSet.of(baseGroupField)), aggregate.getAggCallList());
89+
90+
/* Build the final project-aggregate-project */
91+
final Mapping mapping =
92+
Mappings.target(
93+
List.of(baseGroupRef.getIndex()),
94+
baseGroupRef.getIndex() + 1); // set source count greater than the max ref index
95+
List<RexNode> parentProjections = new ArrayList<>(RexUtil.apply(mapping, groupNodes));
96+
List<RexNode> aggCallRefs =
97+
relBuilder.fields(
98+
IntStream.range(baseGroupList.size(), relBuilder.peek().getRowType().getFieldCount())
99+
.boxed()
100+
.toList());
101+
parentProjections.addAll(aggCallRefs);
102+
relBuilder.project(parentProjections);
103+
call.transformTo(relBuilder.build());
104+
}
105+
106+
/** Rule configuration. */
107+
@Value.Immutable
108+
public interface Config extends RelRule.Config {
109+
Config GROUP_MERGE =
110+
ImmutablePPLAggGroupMergeRule.Config.builder()
111+
.build()
112+
.withOperandSupplier(
113+
b0 ->
114+
b0.operand(LogicalAggregate.class)
115+
.predicate(Config::containsMultipleGroupSets)
116+
.oneInput(
117+
b1 ->
118+
b1.operand(LogicalProject.class)
119+
.predicate(Config::containsDependentFields)
120+
.anyInputs()));
121+
122+
static boolean containsMultipleGroupSets(LogicalAggregate aggregate) {
123+
return aggregate.getGroupSet().cardinality() > 1;
124+
}
125+
126+
// Only rough predication here since we don't know which fields are group fields currently.
127+
static boolean containsDependentFields(LogicalProject project) {
128+
Set<RexNode> baseFields =
129+
project.getProjects().stream()
130+
.filter(node -> node.getKind() == SqlKind.INPUT_REF)
131+
.collect(Collectors.toUnmodifiableSet());
132+
return project.getProjects().stream()
133+
.anyMatch(node -> PPLAggGroupMergeRule.isDependentField(node, baseFields));
134+
}
135+
136+
@Override
137+
default PPLAggGroupMergeRule toRule() {
138+
return new PPLAggGroupMergeRule(this);
139+
}
140+
}
141+
142+
public static boolean isDependentField(RexNode node, Collection<RexNode> baseFields) {
143+
// Always view literal field as dependent field here since we can always implement a function
144+
// to transform a field into such a literal
145+
if (node.getKind() == SqlKind.LITERAL) return true;
146+
if (node.getKind() == SqlKind.INPUT_REF && baseFields.contains(node)) return true;
147+
if (node instanceof RexCall && ((RexCall) node).getOperator().isDeterministic()) {
148+
return ((RexCall) node)
149+
.getOperands().stream().allMatch(op -> isDependentField(op, baseFields));
150+
}
151+
return false;
152+
}
153+
}

core/src/main/java/org/opensearch/sql/calcite/plan/PPLAggregateConvertRule.java

Lines changed: 2 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,10 @@
77

88
import com.google.common.collect.ImmutableList;
99
import java.util.ArrayList;
10-
import java.util.HashMap;
1110
import java.util.List;
12-
import java.util.Map;
13-
import java.util.Set;
1411
import java.util.function.Function;
1512
import java.util.stream.IntStream;
1613
import org.apache.calcite.plan.RelOptRuleCall;
17-
import org.apache.calcite.plan.RelOptUtil;
1814
import org.apache.calcite.plan.RelRule;
1915
import org.apache.calcite.rel.RelNode;
2016
import org.apache.calcite.rel.core.AggregateCall;
@@ -30,8 +26,6 @@
3026
import org.apache.calcite.sql.SqlKind;
3127
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
3228
import org.apache.calcite.tools.RelBuilder;
33-
import org.apache.calcite.util.ImmutableBitSet;
34-
import org.apache.calcite.util.mapping.Mappings;
3529
import org.apache.commons.lang3.tuple.Pair;
3630
import org.immutables.value.Value;
3731

@@ -188,36 +182,9 @@ public void apply(RelOptRuleCall call, LogicalAggregate aggregate, LogicalProjec
188182
}
189183
}
190184

191-
/* Eliminate unused fields in the child project */
192-
ImmutableBitSet newGroupSet = aggregate.getGroupSet();
193-
;
194-
ImmutableList<ImmutableBitSet> newGroupSets = aggregate.getGroupSets();
195-
;
196-
final Set<Integer> fieldsUsed =
197-
RelOptUtil.getAllFields2(aggregate.getGroupSet(), distinctAggregateCalls);
198-
if (fieldsUsed.size() < newChildProjects.size()) {
199-
// Some fields are computed but not used. Prune them.
200-
final Map<Integer, Integer> sourceFieldToTargetFieldMap = new HashMap<>();
201-
for (int source : fieldsUsed) {
202-
sourceFieldToTargetFieldMap.put(source, sourceFieldToTargetFieldMap.size());
203-
}
204-
newGroupSet = aggregate.getGroupSet().permute(sourceFieldToTargetFieldMap);
205-
newGroupSets =
206-
ImmutableBitSet.ORDERING.immutableSortedCopy(
207-
ImmutableBitSet.permute(aggregate.getGroupSets(), sourceFieldToTargetFieldMap));
208-
final Mappings.TargetMapping targetMapping =
209-
Mappings.target(sourceFieldToTargetFieldMap, newChildProjects.size(), fieldsUsed.size());
210-
final List<AggregateCall> oldAggregateCalls = new ArrayList<>(distinctAggregateCalls);
211-
distinctAggregateCalls.clear();
212-
for (AggregateCall aggregateCall : oldAggregateCalls) {
213-
distinctAggregateCalls.add(aggregateCall.transform(targetMapping));
214-
}
215-
// Project the used fields
216-
relBuilder.project(relBuilder.fields(fieldsUsed.stream().toList()));
217-
}
185+
relBuilder.aggregate(relBuilder.groupKey(aggregate.getGroupSet()), distinctAggregateCalls);
218186

219-
/* Build the final project-aggregate-project after eliminating unused fields */
220-
relBuilder.aggregate(relBuilder.groupKey(newGroupSet, newGroupSets), distinctAggregateCalls);
187+
/* Build the final project-aggregate-project */
221188
List<RexNode> parentProjects =
222189
new ArrayList<>(relBuilder.fields(IntStream.range(0, groupSetOffset).boxed().toList()));
223190
parentProjects.addAll(

0 commit comments

Comments
 (0)