Skip to content

Commit dabb710

Browse files
committed
Add unit tests for chart command
Signed-off-by: Yuanchun Shen <[email protected]>
1 parent df809bd commit dabb710

File tree

1 file changed

+381
-0
lines changed

1 file changed

+381
-0
lines changed
Lines changed: 381 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,381 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.sql.ppl.calcite;
7+
8+
import com.google.common.collect.ImmutableList;
9+
import java.util.List;
10+
import lombok.RequiredArgsConstructor;
11+
import org.apache.calcite.DataContext;
12+
import org.apache.calcite.config.CalciteConnectionConfig;
13+
import org.apache.calcite.linq4j.Enumerable;
14+
import org.apache.calcite.linq4j.Linq4j;
15+
import org.apache.calcite.plan.RelTraitDef;
16+
import org.apache.calcite.rel.RelCollations;
17+
import org.apache.calcite.rel.RelNode;
18+
import org.apache.calcite.rel.type.RelDataType;
19+
import org.apache.calcite.rel.type.RelDataTypeFactory;
20+
import org.apache.calcite.rel.type.RelProtoDataType;
21+
import org.apache.calcite.schema.ScannableTable;
22+
import org.apache.calcite.schema.Schema;
23+
import org.apache.calcite.schema.SchemaPlus;
24+
import org.apache.calcite.schema.Statistic;
25+
import org.apache.calcite.schema.Statistics;
26+
import org.apache.calcite.sql.SqlCall;
27+
import org.apache.calcite.sql.SqlNode;
28+
import org.apache.calcite.sql.parser.SqlParser;
29+
import org.apache.calcite.sql.type.SqlTypeName;
30+
import org.apache.calcite.test.CalciteAssert;
31+
import org.apache.calcite.tools.Frameworks;
32+
import org.apache.calcite.tools.Programs;
33+
import org.checkerframework.checker.nullness.qual.Nullable;
34+
import org.junit.Test;
35+
import org.opensearch.sql.ast.tree.UnresolvedPlan;
36+
import org.opensearch.sql.ppl.antlr.PPLSyntaxParser;
37+
import org.opensearch.sql.ppl.parser.AstBuilder;
38+
39+
public class CalcitePPLChartTest extends CalcitePPLAbstractTest {
40+
41+
/** Creates the test fixture on top of the {@code SCOTT_WITH_TEMPORAL} schema spec. */
public CalcitePPLChartTest() {
  super(CalciteAssert.SchemaSpec.SCOTT_WITH_TEMPORAL);
}
44+
45+
@Override
46+
protected Frameworks.ConfigBuilder config(CalciteAssert.SchemaSpec... schemaSpecs) {
47+
final SchemaPlus rootSchema = Frameworks.createRootSchema(true);
48+
final SchemaPlus schema = CalciteAssert.addSchema(rootSchema, schemaSpecs);
49+
// Add events table for chart tests - similar to bank data used in integration tests
50+
ImmutableList<Object[]> rows =
51+
ImmutableList.of(
52+
new Object[] {32838, "F", 28, "VA", java.sql.Timestamp.valueOf("2024-07-01 00:00:00")},
53+
new Object[] {40540, "F", 39, "PA", java.sql.Timestamp.valueOf("2024-07-01 00:01:00")},
54+
new Object[] {39225, "M", 32, "IL", java.sql.Timestamp.valueOf("2024-07-01 00:02:00")},
55+
new Object[] {4180, "M", 33, "MD", java.sql.Timestamp.valueOf("2024-07-01 00:03:00")},
56+
new Object[] {11052, "M", 36, "WA", java.sql.Timestamp.valueOf("2024-07-01 00:04:00")},
57+
new Object[] {48086, "F", 34, "IN", java.sql.Timestamp.valueOf("2024-07-01 00:05:00")});
58+
schema.add("bank", new BankTable(rows));
59+
60+
// Add time_data table for span tests
61+
ImmutableList<Object[]> timeRows =
62+
ImmutableList.of(
63+
new Object[] {java.sql.Timestamp.valueOf("2025-07-28 00:00:00"), "A", 9367},
64+
new Object[] {java.sql.Timestamp.valueOf("2025-07-29 00:00:00"), "B", 9521},
65+
new Object[] {java.sql.Timestamp.valueOf("2025-07-30 00:00:00"), "C", 9187},
66+
new Object[] {java.sql.Timestamp.valueOf("2025-07-31 00:00:00"), "D", 8736},
67+
new Object[] {java.sql.Timestamp.valueOf("2025-08-01 00:00:00"), "A", 9015});
68+
schema.add("time_data", new TimeDataTable(timeRows));
69+
70+
return Frameworks.newConfigBuilder()
71+
.parserConfig(SqlParser.Config.DEFAULT)
72+
.defaultSchema(schema)
73+
.traitDefs((List<RelTraitDef>) null)
74+
.programs(Programs.heuristicJoinOrder(Programs.RULE_SET, true, 2));
75+
}
76+
77+
@Test
78+
public void testChartWithSingleGroupKey() {
79+
String ppl = "source=bank | chart avg(balance) by gender";
80+
81+
RelNode root = getRelNode(ppl);
82+
String expectedSparkSql =
83+
"SELECT AVG(`balance`) `avg(balance)`, `gender`\n"
84+
+ "FROM `scott`.`bank`\n"
85+
+ "GROUP BY `gender`";
86+
verifyPPLToSparkSQL(root, expectedSparkSql);
87+
}
88+
89+
@Test
90+
public void testChartWithOverSyntax() {
91+
String ppl = "source=bank | chart avg(balance) over gender";
92+
93+
RelNode root = getRelNode(ppl);
94+
String expectedSparkSql =
95+
"SELECT AVG(`balance`) `avg(balance)`, `gender`\n"
96+
+ "FROM `scott`.`bank`\n"
97+
+ "GROUP BY `gender`";
98+
verifyPPLToSparkSQL(root, expectedSparkSql);
99+
}
100+
101+
@Test
102+
public void testChartWithMultipleGroupKeys() {
103+
String ppl = "source=bank | chart avg(balance) over gender by age";
104+
105+
RelNode root = getRelNode(ppl);
106+
String expectedSparkSql =
107+
"SELECT `t1`.`gender`, CASE WHEN `t1`.`age` IS NULL THEN 'NULL' WHEN `t6`.`__row_number__`"
108+
+ " <= 10 THEN `t1`.`age` ELSE 'OTHER' END `age`, AVG(`t1`.`avg(balance)`)"
109+
+ " `avg(balance)`\n"
110+
+ "FROM (SELECT AVG(`balance`) `avg(balance)`, `gender`, SAFE_CAST(`age` AS STRING)"
111+
+ " `age`\n"
112+
+ "FROM `scott`.`bank`\n"
113+
+ "GROUP BY `gender`, `age`) `t1`\n"
114+
+ "LEFT JOIN (SELECT `age`, AVG(`avg(balance)`) `__grand_total__`, ROW_NUMBER() OVER"
115+
+ " (ORDER BY AVG(`avg(balance)`) DESC) `__row_number__`\n"
116+
+ "FROM (SELECT AVG(`balance`) `avg(balance)`, SAFE_CAST(`age` AS STRING) `age`\n"
117+
+ "FROM `scott`.`bank`\n"
118+
+ "GROUP BY `gender`, `age`) `t4`\n"
119+
+ "GROUP BY `age`) `t6` ON `t1`.`age` = `t6`.`age`\n"
120+
+ "GROUP BY `t1`.`gender`, CASE WHEN `t1`.`age` IS NULL THEN 'NULL' WHEN"
121+
+ " `t6`.`__row_number__` <= 10 THEN `t1`.`age` ELSE 'OTHER' END";
122+
verifyPPLToSparkSQL(root, expectedSparkSql);
123+
}
124+
125+
@Test
126+
public void testChartWithMultipleGroupKeysAlternativeSyntax() {
127+
String ppl = "source=bank | chart avg(balance) by gender, age";
128+
129+
RelNode root = getRelNode(ppl);
130+
String expectedSparkSql =
131+
"SELECT `t1`.`gender`, CASE WHEN `t1`.`age` IS NULL THEN 'NULL' WHEN `t6`.`__row_number__`"
132+
+ " <= 10 THEN `t1`.`age` ELSE 'OTHER' END `age`, AVG(`t1`.`avg(balance)`)"
133+
+ " `avg(balance)`\n"
134+
+ "FROM (SELECT AVG(`balance`) `avg(balance)`, `gender`, SAFE_CAST(`age` AS STRING)"
135+
+ " `age`\n"
136+
+ "FROM `scott`.`bank`\n"
137+
+ "GROUP BY `gender`, `age`) `t1`\n"
138+
+ "LEFT JOIN (SELECT `age`, AVG(`avg(balance)`) `__grand_total__`, ROW_NUMBER() OVER"
139+
+ " (ORDER BY AVG(`avg(balance)`) DESC) `__row_number__`\n"
140+
+ "FROM (SELECT AVG(`balance`) `avg(balance)`, SAFE_CAST(`age` AS STRING) `age`\n"
141+
+ "FROM `scott`.`bank`\n"
142+
+ "GROUP BY `gender`, `age`) `t4`\n"
143+
+ "GROUP BY `age`) `t6` ON `t1`.`age` = `t6`.`age`\n"
144+
+ "GROUP BY `t1`.`gender`, CASE WHEN `t1`.`age` IS NULL THEN 'NULL' WHEN"
145+
+ " `t6`.`__row_number__` <= 10 THEN `t1`.`age` ELSE 'OTHER' END";
146+
verifyPPLToSparkSQL(root, expectedSparkSql);
147+
}
148+
149+
@Test
150+
public void testChartWithLimit() {
151+
String ppl = "source=bank | chart limit=2 avg(balance) by gender";
152+
153+
RelNode root = getRelNode(ppl);
154+
String expectedSparkSql =
155+
"SELECT AVG(`balance`) `avg(balance)`, `gender`\n"
156+
+ "FROM `scott`.`bank`\n"
157+
+ "GROUP BY `gender`";
158+
verifyPPLToSparkSQL(root, expectedSparkSql);
159+
}
160+
161+
@Test
162+
public void testChartWithLimitZero() {
163+
String ppl = "source=bank | chart limit=0 avg(balance) over state by gender";
164+
165+
RelNode root = getRelNode(ppl);
166+
String expectedSparkSql =
167+
"SELECT AVG(`balance`) `avg(balance)`, `state`, `gender`\n"
168+
+ "FROM `scott`.`bank`\n"
169+
+ "GROUP BY `state`, `gender`";
170+
verifyPPLToSparkSQL(root, expectedSparkSql);
171+
}
172+
173+
@Test
174+
public void testChartWithSpan() {
175+
String ppl = "source=bank | chart max(balance) by age span=10";
176+
177+
RelNode root = getRelNode(ppl);
178+
String expectedSparkSql =
179+
"SELECT MAX(`balance`) `max(balance)`, `SPAN`(`age`, 10, NULL) `age`\n"
180+
+ "FROM `scott`.`bank`\n"
181+
+ "GROUP BY `SPAN`(`age`, 10, NULL)";
182+
verifyPPLToSparkSQL(root, expectedSparkSql);
183+
}
184+
185+
@Test
186+
public void testChartWithTimeSpan() {
187+
String ppl = "source=time_data | chart max(value) over timestamp span=1week by category";
188+
189+
RelNode root = getRelNode(ppl);
190+
String expectedSparkSql =
191+
"SELECT `t1`.`timestamp`, CASE WHEN `t1`.`category` IS NULL THEN 'NULL' WHEN"
192+
+ " `t6`.`__row_number__` <= 10 THEN `t1`.`category` ELSE 'OTHER' END `category`,"
193+
+ " MAX(`t1`.`max(value)`) `max(value)`\n"
194+
+ "FROM (SELECT MAX(`value`) `max(value)`, `SPAN`(`timestamp`, 1, 'w') `timestamp`,"
195+
+ " `category`\n"
196+
+ "FROM `scott`.`time_data`\n"
197+
+ "GROUP BY `category`, `SPAN`(`timestamp`, 1, 'w')) `t1`\n"
198+
+ "LEFT JOIN (SELECT `category`, MAX(`max(value)`) `__grand_total__`, ROW_NUMBER() OVER"
199+
+ " (ORDER BY MAX(`max(value)`) DESC) `__row_number__`\n"
200+
+ "FROM (SELECT MAX(`value`) `max(value)`, `category`\n"
201+
+ "FROM `scott`.`time_data`\n"
202+
+ "GROUP BY `category`, `SPAN`(`timestamp`, 1, 'w')) `t4`\n"
203+
+ "GROUP BY `category`) `t6` ON `t1`.`category` = `t6`.`category`\n"
204+
+ "GROUP BY `t1`.`timestamp`, CASE WHEN `t1`.`category` IS NULL THEN 'NULL' WHEN"
205+
+ " `t6`.`__row_number__` <= 10 THEN `t1`.`category` ELSE 'OTHER' END";
206+
verifyPPLToSparkSQL(root, expectedSparkSql);
207+
}
208+
209+
@Test
210+
public void testChartWithUseOtherTrue() {
211+
String ppl = "source=bank | chart useother=true avg(balance) by gender";
212+
213+
RelNode root = getRelNode(ppl);
214+
String expectedSparkSql =
215+
"SELECT AVG(`balance`) `avg(balance)`, `gender`\n"
216+
+ "FROM `scott`.`bank`\n"
217+
+ "GROUP BY `gender`";
218+
verifyPPLToSparkSQL(root, expectedSparkSql);
219+
}
220+
221+
@Test
222+
public void testChartWithUseOtherFalse() {
223+
String ppl = "source=bank | chart useother=false limit=2 avg(balance) by gender";
224+
225+
RelNode root = getRelNode(ppl);
226+
String expectedSparkSql =
227+
"SELECT AVG(`balance`) `avg(balance)`, `gender`\n"
228+
+ "FROM `scott`.`bank`\n"
229+
+ "GROUP BY `gender`";
230+
verifyPPLToSparkSQL(root, expectedSparkSql);
231+
}
232+
233+
@Test
234+
public void testChartWithOtherStr() {
235+
String ppl = "source=bank | chart limit=1 otherstr='other_values' avg(balance) by gender";
236+
237+
RelNode root = getRelNode(ppl);
238+
String expectedSparkSql =
239+
"SELECT AVG(`balance`) `avg(balance)`, `gender`\n"
240+
+ "FROM `scott`.`bank`\n"
241+
+ "GROUP BY `gender`";
242+
verifyPPLToSparkSQL(root, expectedSparkSql);
243+
}
244+
245+
@Test
246+
public void testChartWithNullStr() {
247+
String ppl = "source=bank | chart nullstr='null_values' avg(balance) by gender";
248+
249+
RelNode root = getRelNode(ppl);
250+
String expectedSparkSql =
251+
"SELECT AVG(`balance`) `avg(balance)`, `gender`\n"
252+
+ "FROM `scott`.`bank`\n"
253+
+ "GROUP BY `gender`";
254+
verifyPPLToSparkSQL(root, expectedSparkSql);
255+
}
256+
257+
@Test
258+
public void testChartWithUseNull() {
259+
String ppl = "source=bank | chart usenull=false avg(balance) by gender";
260+
261+
RelNode root = getRelNode(ppl);
262+
String expectedSparkSql =
263+
"SELECT AVG(`balance`) `avg(balance)`, `gender`\n"
264+
+ "FROM `scott`.`bank`\n"
265+
+ "WHERE `gender` IS NOT NULL\n"
266+
+ "GROUP BY `gender`";
267+
verifyPPLToSparkSQL(root, expectedSparkSql);
268+
}
269+
270+
private UnresolvedPlan parsePPL(String query) {
271+
PPLSyntaxParser parser = new PPLSyntaxParser();
272+
AstBuilder astBuilder = new AstBuilder(query);
273+
return astBuilder.visit(parser.parse(query));
274+
}
275+
276+
@RequiredArgsConstructor
277+
public static class BankTable implements ScannableTable {
278+
private final ImmutableList<Object[]> rows;
279+
280+
protected final RelProtoDataType protoRowType =
281+
factory ->
282+
factory
283+
.builder()
284+
.add("balance", SqlTypeName.INTEGER)
285+
.nullable(true)
286+
.add("gender", SqlTypeName.VARCHAR)
287+
.nullable(true)
288+
.add("age", SqlTypeName.INTEGER)
289+
.nullable(true)
290+
.add("state", SqlTypeName.VARCHAR)
291+
.nullable(true)
292+
.add("timestamp", SqlTypeName.TIMESTAMP)
293+
.nullable(true)
294+
.build();
295+
296+
@Override
297+
public Enumerable<@Nullable Object[]> scan(DataContext root) {
298+
return Linq4j.asEnumerable(rows);
299+
}
300+
301+
@Override
302+
public RelDataType getRowType(RelDataTypeFactory typeFactory) {
303+
return protoRowType.apply(typeFactory);
304+
}
305+
306+
@Override
307+
public Statistic getStatistic() {
308+
return Statistics.of(0d, ImmutableList.of(), RelCollations.createSingleton(0));
309+
}
310+
311+
@Override
312+
public Schema.TableType getJdbcTableType() {
313+
return Schema.TableType.TABLE;
314+
}
315+
316+
@Override
317+
public boolean isRolledUp(String column) {
318+
return false;
319+
}
320+
321+
@Override
322+
public boolean rolledUpColumnValidInsideAgg(
323+
String column,
324+
SqlCall call,
325+
@Nullable SqlNode parent,
326+
@Nullable CalciteConnectionConfig config) {
327+
return false;
328+
}
329+
}
330+
331+
@RequiredArgsConstructor
332+
public static class TimeDataTable implements ScannableTable {
333+
private final ImmutableList<Object[]> rows;
334+
335+
protected final RelProtoDataType protoRowType =
336+
factory ->
337+
factory
338+
.builder()
339+
.add("timestamp", SqlTypeName.TIMESTAMP)
340+
.nullable(true)
341+
.add("category", SqlTypeName.VARCHAR)
342+
.nullable(true)
343+
.add("value", SqlTypeName.INTEGER)
344+
.nullable(true)
345+
.build();
346+
347+
@Override
348+
public Enumerable<@Nullable Object[]> scan(DataContext root) {
349+
return Linq4j.asEnumerable(rows);
350+
}
351+
352+
@Override
353+
public RelDataType getRowType(RelDataTypeFactory typeFactory) {
354+
return protoRowType.apply(typeFactory);
355+
}
356+
357+
@Override
358+
public Statistic getStatistic() {
359+
return Statistics.of(0d, ImmutableList.of(), RelCollations.createSingleton(0));
360+
}
361+
362+
@Override
363+
public Schema.TableType getJdbcTableType() {
364+
return Schema.TableType.TABLE;
365+
}
366+
367+
@Override
368+
public boolean isRolledUp(String column) {
369+
return false;
370+
}
371+
372+
@Override
373+
public boolean rolledUpColumnValidInsideAgg(
374+
String column,
375+
SqlCall call,
376+
@Nullable SqlNode parent,
377+
@Nullable CalciteConnectionConfig config) {
378+
return false;
379+
}
380+
}
381+
}

0 commit comments

Comments
 (0)