Skip to content

Commit 13cd26f

Browse files
committed
Move ranking by column split to a helper function
Signed-off-by: Yuanchun Shen <[email protected]>
1 parent ed159ba commit 13cd26f

File tree

1 file changed

+46
-37
lines changed

1 file changed

+46
-37
lines changed

core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java

Lines changed: 46 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -2415,9 +2415,9 @@ private String getAggFieldAlias(UnresolvedExpression aggregateFunction) {
24152415
public RelNode visitChart(Chart node, CalcitePlanContext context) {
24162416
visitChildren(node, context);
24172417
ArgumentMap argMap = ArgumentMap.of(node.getArguments());
2418+
ChartConfig config = ChartConfig.fromArguments(argMap);
24182419
List<UnresolvedExpression> groupExprList =
24192420
Stream.of(node.getRowSplit(), node.getColumnSplit()).filter(Objects::nonNull).toList();
2420-
ChartConfig config = ChartConfig.fromArguments(argMap);
24212421
Aggregation aggregation =
24222422
new Aggregation(
24232423
List.of(node.getAggregationFunction()), List.of(), groupExprList, null, List.of());
@@ -2441,15 +2441,6 @@ public RelNode visitChart(Chart node, CalcitePlanContext context) {
24412441
return relBuilder.peek();
24422442
}
24432443

2444-
String aggFunctionName = getAggFunctionName(node.getAggregationFunction());
2445-
BuiltinFunctionName aggFunction =
2446-
BuiltinFunctionName.of(aggFunctionName)
2447-
.orElseThrow(
2448-
() ->
2449-
new IllegalArgumentException(
2450-
StringUtils.format(
2451-
"Unrecognized aggregation function: %s", aggFunctionName)));
2452-
24532444
// Convert the column split to string if necessary: column split was supposed to be pivoted to
24542445
// column names. This guarantees that its type compatibility with useother and usenull
24552446
RexNode colSplit = relBuilder.field(1);
@@ -2463,34 +2454,8 @@ public RelNode visitChart(Chart node, CalcitePlanContext context) {
24632454
}
24642455
relBuilder.project(relBuilder.field(0), colSplit, relBuilder.field(2));
24652456
RelNode aggregated = relBuilder.peek();
2466-
24672457
// 1: column-split, 2: agg
2468-
relBuilder.project(relBuilder.field(1), relBuilder.field(2));
2469-
// Make sure that rows who don't have a column split not interfere grand total calculation
2470-
relBuilder.filter(relBuilder.isNotNull(relBuilder.field(0)));
2471-
final String GRAND_TOTAL_COL = "__grand_total__";
2472-
relBuilder.aggregate(
2473-
relBuilder.groupKey(relBuilder.field(0)),
2474-
// Top-K semantic: Retain categories whose summed values are among the greatest
2475-
relBuilder.sum(relBuilder.field(1)).as(GRAND_TOTAL_COL)); // results: group key, agg calls
2476-
RexNode grandTotal = relBuilder.field(GRAND_TOTAL_COL);
2477-
// Apply sorting: keep the max values if top is set
2478-
if (config.top) {
2479-
grandTotal = relBuilder.desc(grandTotal);
2480-
}
2481-
// Always set it to null last so that nulls don't interfere with top / bottom calculation
2482-
grandTotal = relBuilder.nullsLast(grandTotal);
2483-
RexNode rowNum =
2484-
PlanUtils.makeOver(
2485-
context,
2486-
BuiltinFunctionName.ROW_NUMBER,
2487-
relBuilder.literal(1), // dummy expression for row number calculation
2488-
List.of(),
2489-
List.of(),
2490-
List.of(grandTotal),
2491-
WindowFrame.toCurrentRow());
2492-
relBuilder.projectPlus(relBuilder.alias(rowNum, PlanUtils.ROW_NUMBER_COLUMN_FOR_CHART));
2493-
RelNode ranked = relBuilder.build();
2458+
RelNode ranked = rankByColumnSplit(context, 1, 2, config.top);
24942459

24952460
relBuilder.push(aggregated);
24962461
relBuilder.push(ranked);
@@ -2534,6 +2499,14 @@ public RelNode visitChart(Chart node, CalcitePlanContext context) {
25342499
relBuilder.field(0),
25352500
relBuilder.alias(columnSplitExpr, columnSplitName),
25362501
relBuilder.field(2));
2502+
String aggFunctionName = getAggFunctionName(node.getAggregationFunction());
2503+
BuiltinFunctionName aggFunction =
2504+
BuiltinFunctionName.of(aggFunctionName)
2505+
.orElseThrow(
2506+
() ->
2507+
new IllegalArgumentException(
2508+
StringUtils.format(
2509+
"Unrecognized aggregation function: %s", aggFunctionName)));
25372510
relBuilder.aggregate(
25382511
relBuilder.groupKey(relBuilder.field(0), relBuilder.field(1)),
25392512
buildAggCall(context.relBuilder, aggFunction, relBuilder.field(2)).as(aggFieldName));
@@ -2542,6 +2515,42 @@ public RelNode visitChart(Chart node, CalcitePlanContext context) {
25422515
return relBuilder.peek();
25432516
}
25442517

2518+
/**
2519+
* Aggregate by column split then rank by grand total (summed value of each category). The output
2520+
* is <code>[col-split, grand-total, row-number]</code>
2521+
*/
2522+
private RelNode rankByColumnSplit(
2523+
CalcitePlanContext context, int columnSplitOrdinal, int aggOrdinal, boolean top) {
2524+
RelBuilder relBuilder = context.relBuilder;
2525+
2526+
relBuilder.project(relBuilder.field(columnSplitOrdinal), relBuilder.field(aggOrdinal));
2527+
// Make sure that rows who don't have a column split not interfere grand total calculation
2528+
relBuilder.filter(relBuilder.isNotNull(relBuilder.field(0)));
2529+
final String GRAND_TOTAL_COL = "__grand_total__";
2530+
relBuilder.aggregate(
2531+
relBuilder.groupKey(relBuilder.field(0)),
2532+
// Top-K semantic: Retain categories whose summed values are among the greatest
2533+
relBuilder.sum(relBuilder.field(1)).as(GRAND_TOTAL_COL)); // results: group key, agg calls
2534+
RexNode grandTotal = relBuilder.field(GRAND_TOTAL_COL);
2535+
// Apply sorting: keep the max values if top is set
2536+
if (top) {
2537+
grandTotal = relBuilder.desc(grandTotal);
2538+
}
2539+
// Always set it to null last so that nulls don't interfere with top / bottom calculation
2540+
grandTotal = relBuilder.nullsLast(grandTotal);
2541+
RexNode rowNum =
2542+
PlanUtils.makeOver(
2543+
context,
2544+
BuiltinFunctionName.ROW_NUMBER,
2545+
relBuilder.literal(1), // dummy expression for row number calculation
2546+
List.of(),
2547+
List.of(),
2548+
List.of(grandTotal),
2549+
WindowFrame.toCurrentRow());
2550+
relBuilder.projectPlus(relBuilder.alias(rowNum, PlanUtils.ROW_NUMBER_COLUMN_FOR_CHART));
2551+
return relBuilder.build();
2552+
}
2553+
25452554
@AllArgsConstructor
25462555
private static class ChartConfig {
25472556
private final int limit;

0 commit comments

Comments
 (0)