Skip to content

Commit df809bd

Browse files
committed
Remove unimplemented support for multiple aggregations in chart command
Signed-off-by: Yuanchun Shen <[email protected]>
1 parent 6b8934e commit df809bd

File tree

6 files changed

+46
-59
lines changed

6 files changed

+46
-59
lines changed

core/src/main/java/org/opensearch/sql/ast/tree/Chart.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ public class Chart extends UnresolvedPlan {
3434
private UnresolvedPlan child;
3535
private UnresolvedExpression rowSplit;
3636
private UnresolvedExpression columnSplit;
37-
private List<UnresolvedExpression> aggregationFunctions;
37+
private UnresolvedExpression aggregationFunction;
3838
private List<Argument> arguments;
3939

4040
@Override

core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java

Lines changed: 34 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2018,30 +2018,31 @@ public RelNode visitChart(Chart node, CalcitePlanContext context) {
20182018
ArgumentMap argMap = ArgumentMap.of(node.getArguments());
20192019
List<UnresolvedExpression> groupExprList =
20202020
Stream.of(node.getRowSplit(), node.getColumnSplit()).filter(Objects::nonNull).toList();
2021-
Boolean useNull = (Boolean) argMap.getOrDefault("usenull", Chart.DEFAULT_USE_NULL).getValue();
2021+
ChartConfig config = ChartConfig.fromArguments(argMap);
20222022
Aggregation aggregation =
20232023
new Aggregation(
2024-
node.getAggregationFunctions(),
2024+
List.of(node.getAggregationFunction()),
20252025
List.of(),
20262026
groupExprList,
20272027
null,
2028-
List.of(new Argument(Argument.BUCKET_NULLABLE, AstDSL.booleanLiteral(useNull))));
2028+
List.of(new Argument(Argument.BUCKET_NULLABLE, AstDSL.booleanLiteral(config.useNull))));
20292029
RelNode aggregated = visitAggregation(aggregation, context);
20302030

20312031
// If row or column split does not present or limit equals 0, this is the same as `stats agg
2032-
// [group by col]`
2032+
// [group by col]` because all truncating is performed on the column split
20332033
Integer limit = (Integer) argMap.getOrDefault("limit", Chart.DEFAULT_LIMIT).getValue();
20342034
if (node.getRowSplit() == null || node.getColumnSplit() == null || Objects.equals(limit, 0)) {
20352035
return aggregated;
20362036
}
20372037

2038-
String aggFunctionName = getAggFunctionName(node.getAggregationFunctions().getFirst());
2039-
Optional<BuiltinFunctionName> aggFuncNameOptional = BuiltinFunctionName.of(aggFunctionName);
2040-
if (aggFuncNameOptional.isEmpty()) {
2041-
throw new IllegalArgumentException(
2042-
StringUtils.format("Unrecognized aggregation function: %s", aggFunctionName));
2043-
}
2044-
BuiltinFunctionName aggFunction = aggFuncNameOptional.get();
2038+
String aggFunctionName = getAggFunctionName(node.getAggregationFunction());
2039+
BuiltinFunctionName aggFunction =
2040+
BuiltinFunctionName.of(aggFunctionName)
2041+
.orElseThrow(
2042+
() ->
2043+
new IllegalArgumentException(
2044+
StringUtils.format(
2045+
"Unrecognized aggregation function: %s", aggFunctionName)));
20452046

20462047
// Convert the column split to string if necessary: column split was supposed to be pivoted to
20472048
// column names. This guarantees that its type compatibility with useother and usenull
@@ -2058,12 +2059,6 @@ public RelNode visitChart(Chart node, CalcitePlanContext context) {
20582059
relBuilder.project(relBuilder.field(0), relBuilder.field(1), colSplit);
20592060
aggregated = relBuilder.peek();
20602061

2061-
Boolean top = (Boolean) argMap.getOrDefault("top", Chart.DEFAULT_TOP).getValue();
2062-
Boolean useOther =
2063-
(Boolean) argMap.getOrDefault("useother", Chart.DEFAULT_USE_OTHER).getValue();
2064-
String otherStr = (String) argMap.getOrDefault("otherstr", Chart.DEFAULT_OTHER_STR).getValue();
2065-
String nullStr = (String) argMap.getOrDefault("nullstr", Chart.DEFAULT_NULL_STR).getValue();
2066-
20672062
// 0: agg; 2: column-split
20682063
relBuilder.project(relBuilder.field(0), relBuilder.field(2));
20692064
// 1: column split; 0: agg
@@ -2075,7 +2070,7 @@ public RelNode visitChart(Chart node, CalcitePlanContext context) {
20752070
// Apply sorting: for MIN/EARLIEST, reverse the top/bottom logic
20762071
boolean smallestFirst =
20772072
aggFunction == BuiltinFunctionName.MIN || aggFunction == BuiltinFunctionName.EARLIEST;
2078-
if (top != smallestFirst) {
2073+
if (config.top != smallestFirst) {
20792074
grandTotal = relBuilder.desc(grandTotal);
20802075
}
20812076

@@ -2108,26 +2103,26 @@ public RelNode visitChart(Chart node, CalcitePlanContext context) {
21082103
relBuilder.literal(limit));
21092104
RexNode nullCondition = relBuilder.isNull(colSplitPostJoin);
21102105
RexNode columnSplitExpr;
2111-
if (!useOther) {
2106+
if (!config.useOther) {
21122107
relBuilder.filter(lteCondition);
21132108
}
21142109

2115-
if (useNull) {
2110+
if (config.useNull) {
21162111
columnSplitExpr =
21172112
relBuilder.call(
21182113
SqlStdOperatorTable.CASE,
21192114
nullCondition,
2120-
relBuilder.literal(nullStr),
2115+
relBuilder.literal(config.nullStr),
21212116
lteCondition,
21222117
relBuilder.field(2),
2123-
relBuilder.literal(otherStr));
2118+
relBuilder.literal(config.otherStr));
21242119
} else {
21252120
columnSplitExpr =
21262121
relBuilder.call(
21272122
SqlStdOperatorTable.CASE,
21282123
lteCondition,
21292124
relBuilder.field(2),
2130-
relBuilder.literal(otherStr));
2125+
relBuilder.literal(config.otherStr));
21312126
}
21322127

21332128
String aggFieldName = relBuilder.peek().getRowType().getFieldNames().getFirst();
@@ -2141,6 +2136,21 @@ public RelNode visitChart(Chart node, CalcitePlanContext context) {
21412136
return relBuilder.peek();
21422137
}
21432138

2139+
private record ChartConfig(
2140+
int limit, boolean top, boolean useOther, boolean useNull, String otherStr, String nullStr) {
2141+
static ChartConfig fromArguments(ArgumentMap argMap) {
2142+
int limit = (Integer) argMap.getOrDefault("limit", Chart.DEFAULT_LIMIT).getValue();
2143+
boolean top = (Boolean) argMap.getOrDefault("top", Chart.DEFAULT_TOP).getValue();
2144+
boolean useOther =
2145+
(Boolean) argMap.getOrDefault("useother", Chart.DEFAULT_USE_OTHER).getValue();
2146+
boolean useNull = (Boolean) argMap.getOrDefault("usenull", Chart.DEFAULT_USE_NULL).getValue();
2147+
String otherStr =
2148+
(String) argMap.getOrDefault("otherstr", Chart.DEFAULT_OTHER_STR).getValue();
2149+
String nullStr = (String) argMap.getOrDefault("nullstr", Chart.DEFAULT_NULL_STR).getValue();
2150+
return new ChartConfig(limit, top, useOther, useNull, otherStr, nullStr);
2151+
}
2152+
}
2153+
21442154
/** Transforms timechart command into SQL-based operations. */
21452155
@Override
21462156
public RelNode visitTimechart(
@@ -2150,7 +2160,7 @@ public RelNode visitTimechart(
21502160
// Extract parameters
21512161
UnresolvedExpression spanExpr = node.getBinExpression();
21522162

2153-
List<UnresolvedExpression> groupExprList = Arrays.asList(spanExpr);
2163+
List<UnresolvedExpression> groupExprList;
21542164

21552165
// Handle no by field case
21562166
if (node.getByField() == null) {

ppl/src/main/antlr/OpenSearchPPLParser.g4

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -260,8 +260,8 @@ reverseCommand
260260
;
261261

262262
chartCommand
263-
: CHART chartOptions* statsAggTerm (COMMA statsAggTerm)* (OVER rowSplit)? (BY columnSplit)?
264-
| CHART chartOptions* statsAggTerm (COMMA statsAggTerm)* BY rowSplit (COMMA)? columnSplit
263+
: CHART chartOptions* statsAggTerm (OVER rowSplit)? (BY columnSplit)?
264+
| CHART chartOptions* statsAggTerm BY rowSplit (COMMA)? columnSplit
265265
;
266266

267267
chartOptions

ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -605,15 +605,11 @@ public UnresolvedPlan visitChartCommand(OpenSearchPPLParser.ChartCommandContext
605605
UnresolvedExpression columnSplit =
606606
ctx.columnSplit() == null ? null : internalVisitExpression(ctx.columnSplit());
607607
List<Argument> arguments = ArgumentFactory.getArgumentList(ctx);
608-
List<UnresolvedExpression> aggList = parseAggTerms(ctx.statsAggTerm());
609-
if (aggList.size() > 1) {
610-
throw new IllegalArgumentException(
611-
"Chart command does not support multiple aggregation functions yet");
612-
}
608+
UnresolvedExpression aggFunction = parseAggTerms(List.of(ctx.statsAggTerm())).getFirst();
613609
return Chart.builder()
614610
.rowSplit(rowSplit)
615611
.columnSplit(columnSplit)
616-
.aggregationFunctions(aggList)
612+
.aggregationFunction(aggFunction)
617613
.arguments(arguments)
618614
.build();
619615
}

ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java

Lines changed: 5 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1277,7 +1277,7 @@ public void testChartCommandBasic() {
12771277
Chart.builder()
12781278
.child(relation("t"))
12791279
.columnSplit(alias("age", field("age")))
1280-
.aggregationFunctions(List.of(alias("count()", aggregate("count", AllFields.of()))))
1280+
.aggregationFunction(alias("count()", aggregate("count", AllFields.of())))
12811281
.arguments(emptyList())
12821282
.build());
12831283
}
@@ -1290,22 +1290,7 @@ public void testChartCommandWithRowSplit() {
12901290
.child(relation("t"))
12911291
.rowSplit(alias("status", field("status")))
12921292
.columnSplit(alias("age", field("age")))
1293-
.aggregationFunctions(List.of(alias("count()", aggregate("count", AllFields.of()))))
1294-
.arguments(emptyList())
1295-
.build());
1296-
}
1297-
1298-
@Test
1299-
public void testChartCommandWithMultipleAggregations() {
1300-
assertEqual(
1301-
"source=t | chart avg(salary), max(age) by department",
1302-
Chart.builder()
1303-
.child(relation("t"))
1304-
.columnSplit(alias("department", field("department")))
1305-
.aggregationFunctions(
1306-
List.of(
1307-
alias("avg(salary)", aggregate("avg", field("salary"))),
1308-
alias("max(age)", aggregate("max", field("age")))))
1293+
.aggregationFunction(alias("count()", aggregate("count", AllFields.of())))
13091294
.arguments(emptyList())
13101295
.build());
13111296
}
@@ -1317,7 +1302,7 @@ public void testChartCommandWithOptions() {
13171302
Chart.builder()
13181303
.child(relation("t"))
13191304
.columnSplit(alias("status", field("status")))
1320-
.aggregationFunctions(List.of(alias("count()", aggregate("count", AllFields.of()))))
1305+
.aggregationFunction(alias("count()", aggregate("count", AllFields.of())))
13211306
.arguments(
13221307
exprList(
13231308
argument("limit", intLiteral(10)),
@@ -1334,8 +1319,7 @@ public void testChartCommandWithAllOptions() {
13341319
Chart.builder()
13351320
.child(relation("t"))
13361321
.columnSplit(alias("gender", field("gender")))
1337-
.aggregationFunctions(
1338-
List.of(alias("avg(balance)", aggregate("avg", field("balance")))))
1322+
.aggregationFunction(alias("avg(balance)", aggregate("avg", field("balance"))))
13391323
.arguments(
13401324
exprList(
13411325
argument("limit", intLiteral(5)),
@@ -1354,7 +1338,7 @@ public void testChartCommandWithBottomLimit() {
13541338
Chart.builder()
13551339
.child(relation("t"))
13561340
.columnSplit(alias("category", field("category")))
1357-
.aggregationFunctions(List.of(alias("count()", aggregate("count", AllFields.of()))))
1341+
.aggregationFunction(alias("count()", aggregate("count", AllFields.of())))
13581342
.arguments(
13591343
exprList(argument("limit", intLiteral(3)), argument("top", booleanLiteral(false))))
13601344
.build());

ppl/src/test/java/org/opensearch/sql/ppl/utils/ArgumentFactoryTest.java

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
import static org.opensearch.sql.ast.dsl.AstDSL.sort;
2121
import static org.opensearch.sql.ast.dsl.AstDSL.stringLiteral;
2222

23-
import com.google.common.collect.ImmutableList;
2423
import org.junit.Test;
2524
import org.opensearch.sql.ast.expression.AllFields;
2625
import org.opensearch.sql.ast.expression.Argument;
@@ -111,8 +110,7 @@ public void testChartCommandArguments() {
111110
Chart.builder()
112111
.child(relation("t"))
113112
.columnSplit(alias("age", field("age")))
114-
.aggregationFunctions(
115-
ImmutableList.of(alias("count()", aggregate("count", AllFields.of()))))
113+
.aggregationFunction(alias("count()", aggregate("count", AllFields.of())))
116114
.arguments(
117115
exprList(
118116
argument("limit", intLiteral(5)),
@@ -131,8 +129,7 @@ public void testChartCommandBottomArguments() {
131129
Chart.builder()
132130
.child(relation("t"))
133131
.columnSplit(alias("status", field("status")))
134-
.aggregationFunctions(
135-
ImmutableList.of(alias("count()", aggregate("count", AllFields.of()))))
132+
.aggregationFunction(alias("count()", aggregate("count", AllFields.of())))
136133
.arguments(
137134
exprList(argument("limit", intLiteral(3)), argument("top", booleanLiteral(false))))
138135
.build());

0 commit comments

Comments
 (0)