|
| 1 | +/* |
| 2 | + * Copyright OpenSearch Contributors |
| 3 | + * SPDX-License-Identifier: Apache-2.0 |
| 4 | + */ |
| 5 | + |
| 6 | +package org.opensearch.sql.calcite.plan; |
| 7 | + |
| 8 | +import java.util.ArrayList; |
| 9 | +import java.util.Collection; |
| 10 | +import java.util.List; |
| 11 | +import java.util.Set; |
| 12 | +import java.util.stream.Collectors; |
| 13 | +import java.util.stream.IntStream; |
| 14 | +import org.apache.calcite.plan.RelOptRuleCall; |
| 15 | +import org.apache.calcite.plan.RelRule; |
| 16 | +import org.apache.calcite.rel.logical.LogicalAggregate; |
| 17 | +import org.apache.calcite.rel.logical.LogicalProject; |
| 18 | +import org.apache.calcite.rex.RexCall; |
| 19 | +import org.apache.calcite.rex.RexInputRef; |
| 20 | +import org.apache.calcite.rex.RexNode; |
| 21 | +import org.apache.calcite.rex.RexUtil; |
| 22 | +import org.apache.calcite.sql.SqlKind; |
| 23 | +import org.apache.calcite.tools.RelBuilder; |
| 24 | +import org.apache.calcite.util.ImmutableBitSet; |
| 25 | +import org.apache.calcite.util.mapping.Mapping; |
| 26 | +import org.apache.calcite.util.mapping.Mappings; |
| 27 | +import org.apache.commons.lang3.tuple.Pair; |
| 28 | +import org.immutables.value.Value; |
| 29 | +import org.opensearch.sql.calcite.utils.CalciteUtils; |
| 30 | + |
| 31 | +/** |
| 32 | + * Planner rule that merge multiple agg group fields into a single one, on which all other group |
| 33 | + * fields depend. e.g. |
| 34 | + * |
| 35 | + * <p>stats ... by a, f1(a), f2(a) -> stats ... by a | eval `f1(a)` = f1(a), `f2(a)` = f2(a) |
| 36 | + * |
| 37 | + * <p>TODO: this rule could be expanded further for more cases: 1. support multiple base group |
| 38 | + * fields, e.g. stats ... by a, f1(a), b, f2(b), f3(a, b) -> stats ... by a, b | eval `f1(a)` = |
| 39 | + * f1(a), `f2(b)` = f2(b), `f3(a, b)` = f3(a, b) 2. support no base fields, e.g. stats ... by f1(a), |
| 40 | + * f2(a) -> stats ... by a | eval `f1(a)` = f1(a), `f2(a)` = f2(a) | fields - a Note that one of |
| 41 | + * these UDFs' output must have equivalent cardinality as `a`. |
| 42 | + */ |
| 43 | +@Value.Enclosing |
| 44 | +public class PPLAggGroupMergeRule extends RelRule<PPLAggGroupMergeRule.Config> { |
| 45 | + |
| 46 | + /** Creates a OpenSearchAggregateConvertRule. */ |
| 47 | + protected PPLAggGroupMergeRule(Config config) { |
| 48 | + super(config); |
| 49 | + } |
| 50 | + |
| 51 | + @Override |
| 52 | + public void onMatch(RelOptRuleCall call) { |
| 53 | + if (call.rels.length == 2) { |
| 54 | + final LogicalAggregate aggregate = call.rel(0); |
| 55 | + final LogicalProject project = call.rel(1); |
| 56 | + apply(call, aggregate, project); |
| 57 | + } else { |
| 58 | + throw new AssertionError( |
| 59 | + String.format( |
| 60 | + "The length of rels should be %s but got %s", |
| 61 | + this.operands.size(), call.rels.length)); |
| 62 | + } |
| 63 | + } |
| 64 | + |
| 65 | + public void apply(RelOptRuleCall call, LogicalAggregate aggregate, LogicalProject project) { |
| 66 | + List<Integer> groupSet = aggregate.getGroupSet().asList(); |
| 67 | + List<RexNode> groupNodes = |
| 68 | + groupSet.stream().map(group -> project.getProjects().get(group)).toList(); |
| 69 | + Pair<List<Integer>, List<Integer>> baseFieldsAndOthers = |
| 70 | + CalciteUtils.partition( |
| 71 | + groupSet, i -> project.getProjects().get(i).getKind() == SqlKind.INPUT_REF); |
| 72 | + List<Integer> baseGroupList = baseFieldsAndOthers.getLeft(); |
| 73 | + // TODO: support more base fields in the future. |
| 74 | + if (baseGroupList.size() != 1) return; |
| 75 | + Integer baseGroupField = baseGroupList.get(0); |
| 76 | + RexInputRef baseGroupRef = (RexInputRef) project.getProjects().get(baseGroupField); |
| 77 | + List<Integer> otherGroupList = baseFieldsAndOthers.getRight(); |
| 78 | + boolean allDependOnBaseField = |
| 79 | + otherGroupList.stream() |
| 80 | + .map(i -> project.getProjects().get(i)) |
| 81 | + .allMatch(node -> isDependentField(node, List.of(baseGroupRef))); |
| 82 | + if (!allDependOnBaseField) return; |
| 83 | + |
| 84 | + final RelBuilder relBuilder = call.builder(); |
| 85 | + relBuilder.push(project); |
| 86 | + |
| 87 | + relBuilder.aggregate( |
| 88 | + relBuilder.groupKey(ImmutableBitSet.of(baseGroupField)), aggregate.getAggCallList()); |
| 89 | + |
| 90 | + /* Build the final project-aggregate-project */ |
| 91 | + final Mapping mapping = |
| 92 | + Mappings.target( |
| 93 | + List.of(baseGroupRef.getIndex()), |
| 94 | + baseGroupRef.getIndex() + 1); // set source count greater than the max ref index |
| 95 | + List<RexNode> parentProjections = new ArrayList<>(RexUtil.apply(mapping, groupNodes)); |
| 96 | + List<RexNode> aggCallRefs = |
| 97 | + relBuilder.fields( |
| 98 | + IntStream.range(baseGroupList.size(), relBuilder.peek().getRowType().getFieldCount()) |
| 99 | + .boxed() |
| 100 | + .toList()); |
| 101 | + parentProjections.addAll(aggCallRefs); |
| 102 | + relBuilder.project(parentProjections); |
| 103 | + call.transformTo(relBuilder.build()); |
| 104 | + } |
| 105 | + |
| 106 | + /** Rule configuration. */ |
| 107 | + @Value.Immutable |
| 108 | + public interface Config extends RelRule.Config { |
| 109 | + Config GROUP_MERGE = |
| 110 | + ImmutablePPLAggGroupMergeRule.Config.builder() |
| 111 | + .build() |
| 112 | + .withOperandSupplier( |
| 113 | + b0 -> |
| 114 | + b0.operand(LogicalAggregate.class) |
| 115 | + .predicate(Config::containsMultipleGroupSets) |
| 116 | + .oneInput( |
| 117 | + b1 -> |
| 118 | + b1.operand(LogicalProject.class) |
| 119 | + .predicate(Config::containsDependentFields) |
| 120 | + .anyInputs())); |
| 121 | + |
| 122 | + static boolean containsMultipleGroupSets(LogicalAggregate aggregate) { |
| 123 | + return aggregate.getGroupSet().cardinality() > 1; |
| 124 | + } |
| 125 | + |
| 126 | + // Only rough predication here since we don't know which fields are group fields currently. |
| 127 | + static boolean containsDependentFields(LogicalProject project) { |
| 128 | + Set<RexNode> baseFields = |
| 129 | + project.getProjects().stream() |
| 130 | + .filter(node -> node.getKind() == SqlKind.INPUT_REF) |
| 131 | + .collect(Collectors.toUnmodifiableSet()); |
| 132 | + return project.getProjects().stream() |
| 133 | + .anyMatch(node -> PPLAggGroupMergeRule.isDependentField(node, baseFields)); |
| 134 | + } |
| 135 | + |
| 136 | + @Override |
| 137 | + default PPLAggGroupMergeRule toRule() { |
| 138 | + return new PPLAggGroupMergeRule(this); |
| 139 | + } |
| 140 | + } |
| 141 | + |
| 142 | + public static boolean isDependentField(RexNode node, Collection<RexNode> baseFields) { |
| 143 | + // Always view literal field as dependent field here since we can always implement a function |
| 144 | + // to transform a field into such a literal |
| 145 | + if (node.getKind() == SqlKind.LITERAL) return true; |
| 146 | + if (node.getKind() == SqlKind.INPUT_REF && baseFields.contains(node)) return true; |
| 147 | + if (node instanceof RexCall && ((RexCall) node).getOperator().isDeterministic()) { |
| 148 | + return ((RexCall) node) |
| 149 | + .getOperands().stream().allMatch(op -> isDependentField(op, baseFields)); |
| 150 | + } |
| 151 | + return false; |
| 152 | + } |
| 153 | +} |
0 commit comments