|
24 | 24 | import io.substrait.relation.EmptyScan; |
25 | 25 | import io.substrait.relation.Fetch; |
26 | 26 | import io.substrait.relation.Filter; |
| 27 | +import io.substrait.relation.ImmutableAggregate; |
27 | 28 | import io.substrait.relation.ImmutableFetch; |
28 | 29 | import io.substrait.relation.ImmutableMeasure.Builder; |
29 | 30 | import io.substrait.relation.Join; |
|
34 | 35 | import io.substrait.relation.NamedWrite; |
35 | 36 | import io.substrait.relation.Project; |
36 | 37 | import io.substrait.relation.Rel; |
| 38 | +import io.substrait.relation.Rel.Remap; |
37 | 39 | import io.substrait.relation.Set; |
38 | 40 | import io.substrait.relation.Sort; |
39 | 41 | import io.substrait.relation.VirtualTableScan; |
|
45 | 47 | import java.util.Optional; |
46 | 48 | import java.util.OptionalLong; |
47 | 49 | import java.util.stream.Collectors; |
| 50 | +import java.util.stream.IntStream; |
48 | 51 | import java.util.stream.Stream; |
49 | 52 | import org.apache.calcite.rel.RelFieldCollation; |
50 | 53 | import org.apache.calcite.rel.RelFieldCollation.Direction; |
|
58 | 61 | import org.apache.calcite.rex.RexBuilder; |
59 | 62 | import org.apache.calcite.rex.RexFieldAccess; |
60 | 63 | import org.apache.calcite.rex.RexNode; |
| 64 | +import org.apache.calcite.sql.fun.SqlStdOperatorTable; |
61 | 65 | import org.apache.calcite.util.ImmutableBitSet; |
62 | 66 | import org.immutables.value.Value; |
63 | 67 |
|
@@ -262,16 +266,65 @@ public Rel visit(org.apache.calcite.rel.core.Aggregate aggregate) { |
262 | 266 | List<Grouping> groupings = |
263 | 267 | sets.filter(s -> s != null).map(s -> fromGroupSet(s, input)).collect(Collectors.toList()); |
264 | 268 |
|
265 | | - List<Measure> aggCalls = |
| 269 | + // get GROUP_ID() function calls |
| 270 | + List<AggregateCall> groupIdCalls = |
| 271 | + aggregate.getAggCallList().stream() |
| 272 | + .filter(c -> c.getAggregation().equals(SqlStdOperatorTable.GROUP_ID)) |
| 273 | + .collect(Collectors.toList()); |
| 274 | + |
| 275 | + List<AggregateCall> filteredAggCalls = |
266 | 276 | aggregate.getAggCallList().stream() |
| 277 | + // remove GROUP_ID() function calls |
| 278 | + .filter(c -> !groupIdCalls.contains(c)) |
| 279 | + .collect(Collectors.toList()); |
| 280 | + |
| 281 | + List<Measure> aggCalls = |
| 282 | + filteredAggCalls.stream() |
267 | 283 | .map(c -> fromAggCall(aggregate.getInput(), input.getRecordType(), c)) |
268 | 284 | .collect(Collectors.toList()); |
269 | 285 |
|
270 | | - return Aggregate.builder() |
271 | | - .input(input) |
272 | | - .addAllGroupings(groupings) |
273 | | - .addAllMeasures(aggCalls) |
274 | | - .build(); |
| 286 | + ImmutableAggregate.Builder builder = |
| 287 | + Aggregate.builder().input(input).addAllGroupings(groupings).addAllMeasures(aggCalls); |
| 288 | + |
| 289 | + if (groupings.size() > 1) { |
| 290 | + // remove the grouping set index if there was no explicit GROUP_ID() function call |
| 291 | + if (groupIdCalls.isEmpty()) { |
| 292 | + int groupingExprSize = |
| 293 | + Math.toIntExact( |
| 294 | + groupings.stream().flatMap(g -> g.getExpressions().stream()).distinct().count()); |
| 295 | + builder.remap(Remap.offset(0, groupingExprSize + aggCalls.size())); |
| 296 | + } else { |
| 297 | + // remap grouping set index at the field positions where the GROUP_ID() function calls were |
| 298 | + final int groupingFieldCount = |
| 299 | + Math.toIntExact(groupings.stream().flatMap(g -> g.getExpressions().stream()).count()); |
| 300 | + final int filterAggCallCount = aggCalls.size(); |
| 301 | + final Integer groupingSetIndex = groupingFieldCount + filterAggCallCount; |
| 302 | + |
| 303 | + final List<Integer> remap = |
| 304 | + IntStream.range(0, groupingFieldCount) |
| 305 | + .mapToObj(i -> i) |
| 306 | + .collect(Collectors.toCollection(ArrayList::new)); |
| 307 | + |
| 308 | + for (int i = 0; i < aggregate.getAggCallList().size(); i++) { |
| 309 | + AggregateCall aggCall = aggregate.getAggCallList().get(i); |
| 310 | + if (filteredAggCalls.contains(aggCall)) { |
| 311 | + remap.add( |
| 312 | + i + groupingFieldCount, filteredAggCalls.indexOf(aggCall) + groupingFieldCount); |
| 313 | + } else if (groupIdCalls.contains(aggCall)) { |
| 314 | + remap.add(i + groupingFieldCount, groupingSetIndex); |
| 315 | + } else { |
| 316 | + // this should never get triggered |
| 317 | + throw new IllegalStateException( |
| 318 | + "encountered AggregateCall that is neither in filteredAggCalls nor in groupIdCalls" |
| 319 | + + aggCall); |
| 320 | + } |
| 321 | + } |
| 322 | + |
| 323 | + builder.remap(Remap.of(remap)); |
| 324 | + } |
| 325 | + } |
| 326 | + |
| 327 | + return builder.build(); |
275 | 328 | } |
276 | 329 |
|
277 | 330 | Aggregate.Grouping fromGroupSet(ImmutableBitSet bitSet, Rel input) { |
|
0 commit comments