Skip to content

Commit a00a811

Browse files
feat(core,isthmus): support grouping set index in Aggregate (#565)
Signed-off-by: Niels Pardon <[email protected]> Co-authored-by: Mark S. Lewis <[email protected]>
1 parent 1b1dc73 commit a00a811

File tree

4 files changed

+114
-8
lines changed

4 files changed

+114
-8
lines changed

core/src/main/java/io/substrait/relation/Aggregate.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,13 @@ protected Type.Struct deriveRecordType() {
5656

5757
final Stream<Type> measureTypes = getMeasures().stream().map(t -> t.getFunction().getType());
5858

59-
return TypeCreator.REQUIRED.struct(Stream.concat(groupingTypes, measureTypes));
59+
// an aggregate relation with more than one grouping set receives an extra i32 column on the
60+
// right-hand side per spec:
61+
// https://substrait.io/relations/logical_relations/#aggregate-operation
62+
final Stream<Type> groupingSetIndex = Stream.of(TypeCreator.REQUIRED.I32);
63+
64+
return TypeCreator.REQUIRED.struct(
65+
Stream.concat(Stream.concat(groupingTypes, measureTypes), groupingSetIndex));
6066
}
6167

6268
@Override

isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,18 @@
3434
import io.substrait.relation.NamedWrite;
3535
import io.substrait.relation.Project;
3636
import io.substrait.relation.Rel;
37+
import io.substrait.relation.Rel.Remap;
3738
import io.substrait.relation.Set;
3839
import io.substrait.relation.Sort;
3940
import io.substrait.relation.VirtualTableScan;
4041
import io.substrait.type.NamedStruct;
42+
import io.substrait.type.TypeCreator;
4143
import io.substrait.util.VisitationContext;
4244
import java.util.ArrayList;
4345
import java.util.Collection;
4446
import java.util.Collections;
4547
import java.util.HashSet;
48+
import java.util.LinkedList;
4649
import java.util.List;
4750
import java.util.Optional;
4851
import java.util.OptionalLong;
@@ -74,6 +77,7 @@
7477
import org.apache.calcite.rex.RexSlot;
7578
import org.apache.calcite.sql.SqlAggFunction;
7679
import org.apache.calcite.sql.SqlOperator;
80+
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
7781
import org.apache.calcite.sql.parser.SqlParser;
7882
import org.apache.calcite.tools.Frameworks;
7983
import org.apache.calcite.tools.RelBuilder;
@@ -348,8 +352,43 @@ public RelNode visit(Aggregate aggregate, Context context) throws RuntimeExcepti
348352
aggregate.getMeasures().stream()
349353
.map(measure -> fromMeasure(measure, context))
350354
.collect(java.util.stream.Collectors.toList());
355+
356+
Optional<Remap> remap = aggregate.getRemap();
357+
final int lastFieldIndex = groupExprs.size() + aggregateCalls.size();
358+
359+
// map grouping set index if it is not removed via remap
360+
final boolean emitDirect = remap.isEmpty();
361+
final boolean groupingSetIndexGetsRemapped =
362+
remap.map(r -> r.indices().contains(lastFieldIndex)).orElse(false);
363+
if (aggregate.getGroupings().size() > 1 && (emitDirect || groupingSetIndexGetsRemapped)) {
364+
aggregateCalls.add(
365+
AggregateCall.create(
366+
SqlStdOperatorTable.GROUP_ID,
367+
false,
368+
false,
369+
false,
370+
Collections.emptyList(),
371+
Collections.emptyList(),
372+
-1,
373+
null,
374+
RelCollations.EMPTY,
375+
typeConverter.toCalcite(typeFactory, TypeCreator.REQUIRED.I64),
376+
null));
377+
final int groupingCallIndex = aggregateCalls.size() - 1;
378+
if (groupingSetIndexGetsRemapped) {
379+
List<Integer> remapList = new LinkedList<>(remap.get().indices());
380+
for (int i = 0; i < remapList.size(); i++) {
381+
if (remapList.get(i).equals(lastFieldIndex)) {
382+
// replace last field index with field index of the GROUP_ID() function call
383+
remapList.set(i, groupingCallIndex);
384+
}
385+
}
386+
remap = Optional.of(Remap.of(remapList));
387+
}
388+
}
389+
351390
RelNode node = relBuilder.push(child).aggregate(groupKey, aggregateCalls).build();
352-
return applyRemap(node, aggregate.getRemap());
391+
return applyRemap(node, remap);
353392
}
354393

355394
private AggregateCall fromMeasure(Aggregate.Measure measure, Context context) {

isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java

Lines changed: 59 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import io.substrait.relation.EmptyScan;
2525
import io.substrait.relation.Fetch;
2626
import io.substrait.relation.Filter;
27+
import io.substrait.relation.ImmutableAggregate;
2728
import io.substrait.relation.ImmutableFetch;
2829
import io.substrait.relation.ImmutableMeasure.Builder;
2930
import io.substrait.relation.Join;
@@ -34,6 +35,7 @@
3435
import io.substrait.relation.NamedWrite;
3536
import io.substrait.relation.Project;
3637
import io.substrait.relation.Rel;
38+
import io.substrait.relation.Rel.Remap;
3739
import io.substrait.relation.Set;
3840
import io.substrait.relation.Sort;
3941
import io.substrait.relation.VirtualTableScan;
@@ -45,6 +47,7 @@
4547
import java.util.Optional;
4648
import java.util.OptionalLong;
4749
import java.util.stream.Collectors;
50+
import java.util.stream.IntStream;
4851
import java.util.stream.Stream;
4952
import org.apache.calcite.rel.RelFieldCollation;
5053
import org.apache.calcite.rel.RelFieldCollation.Direction;
@@ -58,6 +61,7 @@
5861
import org.apache.calcite.rex.RexBuilder;
5962
import org.apache.calcite.rex.RexFieldAccess;
6063
import org.apache.calcite.rex.RexNode;
64+
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
6165
import org.apache.calcite.util.ImmutableBitSet;
6266
import org.immutables.value.Value;
6367

@@ -262,16 +266,65 @@ public Rel visit(org.apache.calcite.rel.core.Aggregate aggregate) {
262266
List<Grouping> groupings =
263267
sets.filter(s -> s != null).map(s -> fromGroupSet(s, input)).collect(Collectors.toList());
264268

265-
List<Measure> aggCalls =
269+
// get GROUP_ID() function calls
270+
List<AggregateCall> groupIdCalls =
271+
aggregate.getAggCallList().stream()
272+
.filter(c -> c.getAggregation().equals(SqlStdOperatorTable.GROUP_ID))
273+
.collect(Collectors.toList());
274+
275+
List<AggregateCall> filteredAggCalls =
266276
aggregate.getAggCallList().stream()
277+
// remove GROUP_ID() function calls
278+
.filter(c -> !groupIdCalls.contains(c))
279+
.collect(Collectors.toList());
280+
281+
List<Measure> aggCalls =
282+
filteredAggCalls.stream()
267283
.map(c -> fromAggCall(aggregate.getInput(), input.getRecordType(), c))
268284
.collect(Collectors.toList());
269285

270-
return Aggregate.builder()
271-
.input(input)
272-
.addAllGroupings(groupings)
273-
.addAllMeasures(aggCalls)
274-
.build();
286+
ImmutableAggregate.Builder builder =
287+
Aggregate.builder().input(input).addAllGroupings(groupings).addAllMeasures(aggCalls);
288+
289+
if (groupings.size() > 1) {
290+
// remove the grouping set index if there was no explicit GROUP_ID() function call
291+
if (groupIdCalls.isEmpty()) {
292+
int groupingExprSize =
293+
Math.toIntExact(
294+
groupings.stream().flatMap(g -> g.getExpressions().stream()).distinct().count());
295+
builder.remap(Remap.offset(0, groupingExprSize + aggCalls.size()));
296+
} else {
297+
// remap grouping set index at the field positions where the GROUP_ID() function calls were
298+
final int groupingFieldCount =
299+
Math.toIntExact(groupings.stream().flatMap(g -> g.getExpressions().stream()).count());
300+
final int filterAggCallCount = aggCalls.size();
301+
final Integer groupingSetIndex = groupingFieldCount + filterAggCallCount;
302+
303+
final List<Integer> remap =
304+
IntStream.range(0, groupingFieldCount)
305+
.mapToObj(i -> i)
306+
.collect(Collectors.toCollection(ArrayList::new));
307+
308+
for (int i = 0; i < aggregate.getAggCallList().size(); i++) {
309+
AggregateCall aggCall = aggregate.getAggCallList().get(i);
310+
if (filteredAggCalls.contains(aggCall)) {
311+
remap.add(
312+
i + groupingFieldCount, filteredAggCalls.indexOf(aggCall) + groupingFieldCount);
313+
} else if (groupIdCalls.contains(aggCall)) {
314+
remap.add(i + groupingFieldCount, groupingSetIndex);
315+
} else {
316+
// this should never get triggered
317+
throw new IllegalStateException(
318+
"encountered AggregateCall that is neither in filteredAggCalls nor in groupIdCalls"
319+
+ aggCall);
320+
}
321+
}
322+
323+
builder.remap(Remap.of(remap));
324+
}
325+
}
326+
327+
return builder.build();
275328
}
276329

277330
Aggregate.Grouping fromGroupSet(ImmutableBitSet bitSet, Rel input) {

isthmus/src/test/java/io/substrait/isthmus/Substrait2SqlTest.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,14 @@ public void simpleTestGroupingSets() throws Exception {
9292
"select sum(l_discount) from lineitem group by grouping sets ((l_orderkey, L_COMMITDATE), l_shipdate, ()), l_linestatus");
9393
assertSqlSubstraitRelRoundTrip(
9494
"select sum(l_discount) from lineitem group by grouping sets ((l_orderkey, L_COMMITDATE), (l_orderkey, L_COMMITDATE, l_linestatus), l_shipdate, ())");
95+
96+
// GROUP_ID()
97+
assertSqlSubstraitRelRoundTrip(
98+
"select sum(l_discount), group_id() from lineitem group by grouping sets ((l_orderkey, L_COMMITDATE), l_shipdate)");
99+
assertSqlSubstraitRelRoundTrip(
100+
"select group_id(), sum(l_discount) from lineitem group by grouping sets ((l_orderkey, L_COMMITDATE), l_shipdate)");
101+
assertSqlSubstraitRelRoundTrip(
102+
"select group_id(), sum(l_discount), group_id() from lineitem group by grouping sets ((l_orderkey, L_COMMITDATE), l_shipdate)");
95103
}
96104

97105
@Test

0 commit comments

Comments
 (0)