@@ -2415,9 +2415,9 @@ private String getAggFieldAlias(UnresolvedExpression aggregateFunction) {
24152415 public RelNode visitChart (Chart node , CalcitePlanContext context ) {
24162416 visitChildren (node , context );
24172417 ArgumentMap argMap = ArgumentMap .of (node .getArguments ());
2418+ ChartConfig config = ChartConfig .fromArguments (argMap );
24182419 List <UnresolvedExpression > groupExprList =
24192420 Stream .of (node .getRowSplit (), node .getColumnSplit ()).filter (Objects ::nonNull ).toList ();
2420- ChartConfig config = ChartConfig .fromArguments (argMap );
24212421 Aggregation aggregation =
24222422 new Aggregation (
24232423 List .of (node .getAggregationFunction ()), List .of (), groupExprList , null , List .of ());
@@ -2441,15 +2441,6 @@ public RelNode visitChart(Chart node, CalcitePlanContext context) {
24412441 return relBuilder .peek ();
24422442 }
24432443
2444- String aggFunctionName = getAggFunctionName (node .getAggregationFunction ());
2445- BuiltinFunctionName aggFunction =
2446- BuiltinFunctionName .of (aggFunctionName )
2447- .orElseThrow (
2448- () ->
2449- new IllegalArgumentException (
2450- StringUtils .format (
2451- "Unrecognized aggregation function: %s" , aggFunctionName )));
2452-
24532444 // Convert the column split to string if necessary: column split was supposed to be pivoted to
24542445 // column names. This guarantees that its type compatibility with useother and usenull
24552446 RexNode colSplit = relBuilder .field (1 );
@@ -2463,34 +2454,8 @@ public RelNode visitChart(Chart node, CalcitePlanContext context) {
24632454 }
24642455 relBuilder .project (relBuilder .field (0 ), colSplit , relBuilder .field (2 ));
24652456 RelNode aggregated = relBuilder .peek ();
2466-
24672457 // 1: column-split, 2: agg
2468- relBuilder .project (relBuilder .field (1 ), relBuilder .field (2 ));
2469- // Make sure that rows who don't have a column split not interfere grand total calculation
2470- relBuilder .filter (relBuilder .isNotNull (relBuilder .field (0 )));
2471- final String GRAND_TOTAL_COL = "__grand_total__" ;
2472- relBuilder .aggregate (
2473- relBuilder .groupKey (relBuilder .field (0 )),
2474- // Top-K semantic: Retain categories whose summed values are among the greatest
2475- relBuilder .sum (relBuilder .field (1 )).as (GRAND_TOTAL_COL )); // results: group key, agg calls
2476- RexNode grandTotal = relBuilder .field (GRAND_TOTAL_COL );
2477- // Apply sorting: keep the max values if top is set
2478- if (config .top ) {
2479- grandTotal = relBuilder .desc (grandTotal );
2480- }
2481- // Always set it to null last so that nulls don't interfere with top / bottom calculation
2482- grandTotal = relBuilder .nullsLast (grandTotal );
2483- RexNode rowNum =
2484- PlanUtils .makeOver (
2485- context ,
2486- BuiltinFunctionName .ROW_NUMBER ,
2487- relBuilder .literal (1 ), // dummy expression for row number calculation
2488- List .of (),
2489- List .of (),
2490- List .of (grandTotal ),
2491- WindowFrame .toCurrentRow ());
2492- relBuilder .projectPlus (relBuilder .alias (rowNum , PlanUtils .ROW_NUMBER_COLUMN_FOR_CHART ));
2493- RelNode ranked = relBuilder .build ();
2458+ RelNode ranked = rankByColumnSplit (context , 1 , 2 , config .top );
24942459
24952460 relBuilder .push (aggregated );
24962461 relBuilder .push (ranked );
@@ -2534,6 +2499,14 @@ public RelNode visitChart(Chart node, CalcitePlanContext context) {
25342499 relBuilder .field (0 ),
25352500 relBuilder .alias (columnSplitExpr , columnSplitName ),
25362501 relBuilder .field (2 ));
2502+ String aggFunctionName = getAggFunctionName (node .getAggregationFunction ());
2503+ BuiltinFunctionName aggFunction =
2504+ BuiltinFunctionName .of (aggFunctionName )
2505+ .orElseThrow (
2506+ () ->
2507+ new IllegalArgumentException (
2508+ StringUtils .format (
2509+ "Unrecognized aggregation function: %s" , aggFunctionName )));
25372510 relBuilder .aggregate (
25382511 relBuilder .groupKey (relBuilder .field (0 ), relBuilder .field (1 )),
25392512 buildAggCall (context .relBuilder , aggFunction , relBuilder .field (2 )).as (aggFieldName ));
@@ -2542,6 +2515,42 @@ public RelNode visitChart(Chart node, CalcitePlanContext context) {
25422515 return relBuilder .peek ();
25432516 }
25442517
2518+ /**
2519+ * Aggregate by column split then rank by grand total (summed value of each category). The output
2520+ * is <code>[col-split, grand-total, row-number]</code>
2521+ */
2522+ private RelNode rankByColumnSplit (
2523+ CalcitePlanContext context , int columnSplitOrdinal , int aggOrdinal , boolean top ) {
2524+ RelBuilder relBuilder = context .relBuilder ;
2525+
2526+ relBuilder .project (relBuilder .field (columnSplitOrdinal ), relBuilder .field (aggOrdinal ));
2527+ // Make sure that rows who don't have a column split not interfere grand total calculation
2528+ relBuilder .filter (relBuilder .isNotNull (relBuilder .field (0 )));
2529+ final String GRAND_TOTAL_COL = "__grand_total__" ;
2530+ relBuilder .aggregate (
2531+ relBuilder .groupKey (relBuilder .field (0 )),
2532+ // Top-K semantic: Retain categories whose summed values are among the greatest
2533+ relBuilder .sum (relBuilder .field (1 )).as (GRAND_TOTAL_COL )); // results: group key, agg calls
2534+ RexNode grandTotal = relBuilder .field (GRAND_TOTAL_COL );
2535+ // Apply sorting: keep the max values if top is set
2536+ if (top ) {
2537+ grandTotal = relBuilder .desc (grandTotal );
2538+ }
2539+ // Always set it to null last so that nulls don't interfere with top / bottom calculation
2540+ grandTotal = relBuilder .nullsLast (grandTotal );
2541+ RexNode rowNum =
2542+ PlanUtils .makeOver (
2543+ context ,
2544+ BuiltinFunctionName .ROW_NUMBER ,
2545+ relBuilder .literal (1 ), // dummy expression for row number calculation
2546+ List .of (),
2547+ List .of (),
2548+ List .of (grandTotal ),
2549+ WindowFrame .toCurrentRow ());
2550+ relBuilder .projectPlus (relBuilder .alias (rowNum , PlanUtils .ROW_NUMBER_COLUMN_FOR_CHART ));
2551+ return relBuilder .build ();
2552+ }
2553+
25452554 @ AllArgsConstructor
25462555 private static class ChartConfig {
25472556 private final int limit ;
0 commit comments