From 927fbfa73c63f89dd33f96e2ac1ceca476558d68 Mon Sep 17 00:00:00 2001 From: Andy Kwok Date: Fri, 31 Jan 2025 15:32:42 -0800 Subject: [PATCH 01/36] WMA Signed-off-by: Andy Kwok --- .../opensearch/sql/ast/tree/Trendline.java | 3 +- .../planner/physical/TrendlineOperator.java | 56 ++++++++++++++++++- ppl/src/main/antlr/OpenSearchPPLLexer.g4 | 1 + ppl/src/main/antlr/OpenSearchPPLParser.g4 | 1 + 4 files changed, 57 insertions(+), 4 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/ast/tree/Trendline.java b/core/src/main/java/org/opensearch/sql/ast/tree/Trendline.java index aa4fcc200d..3f9f9e2fbc 100644 --- a/core/src/main/java/org/opensearch/sql/ast/tree/Trendline.java +++ b/core/src/main/java/org/opensearch/sql/ast/tree/Trendline.java @@ -66,6 +66,7 @@ public R accept(AbstractNodeVisitor nodeVisitor, C context) { } public enum TrendlineType { - SMA + SMA, + WMA } } diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java index 7bf10964cf..fdc53f96db 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java @@ -11,6 +11,7 @@ import com.google.common.collect.ImmutableMap.Builder; import java.time.Instant; import java.time.LocalTime; +import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; @@ -21,6 +22,7 @@ import lombok.ToString; import org.apache.commons.lang3.tuple.Pair; import org.opensearch.sql.ast.tree.Trendline; +import org.opensearch.sql.data.model.ExprDoubleValue; import org.opensearch.sql.data.model.ExprTupleValue; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.model.ExprValueUtils; @@ -106,9 +108,10 @@ private Map consumeInputTuple(ExprValue inputValue) { private static TrendlineAccumulator createAccumulator( Pair computation) { - // Add a switch statement based on computation type to choose the accumulator when more - // types of computations are supported. - return new SimpleMovingAverageAccumulator(computation.getKey(), computation.getValue()); + return switch (computation.getKey().getComputationType()) { + case SMA -> new SimpleMovingAverageAccumulator(computation.getKey(), computation.getValue()); + case WMA -> new WeightedMovingAverageAccumulator(computation.getKey(), computation.getValue()); + }; } /** Maintains stateful information for calculating the trendline. */ @@ -187,6 +190,53 @@ public ExprValue calculate() { } } + private static class WeightedMovingAverageAccumulator implements TrendlineAccumulator { + private final LiteralExpression dataPointsNeeded; + private final ArrayList receivedValues; +// private final ArithmeticEvaluator evaluator; + + public WeightedMovingAverageAccumulator( + Trendline.TrendlineComputation computation, ExprCoreType type) { + dataPointsNeeded = DSL.literal(computation.getNumberOfDataPoints().doubleValue()); + receivedValues = new ArrayList<>(computation.getNumberOfDataPoints()+1); +// evaluator = TrendlineAccumulator.getEvaluator(type); + } + + @Override + public void accumulate(ExprValue value) { +// if (dataPointsNeeded.valueOf().integerValue() == 1) { +// receivedValues.add(value); +// return; +// } + + receivedValues.add(value); + + if (receivedValues.size() > dataPointsNeeded.valueOf().integerValue()) { + receivedValues.removeFirst(); + } + + } + + @Override + public ExprValue calculate() { + if (receivedValues.size() < dataPointsNeeded.valueOf().integerValue()) { + return null; + } else if (dataPointsNeeded.valueOf().integerValue() == 1) { + return receivedValues.getFirst(); + } + return computeWma(receivedValues); + } + } + + private static ExprValue computeWma(ArrayList receivedValues) { + double sum = 0D; + for (int i=0 ; i dataPoints); diff --git a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 index c484f34a2a..badfed9a73 100644 --- a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 @@ -60,6 +60,7 @@ NUM: 'NUM'; // TRENDLINE KEYWORDS SMA: 'SMA'; +WMA: 'WMA'; // ARGUMENT KEYWORDS KEEPEMPTY: 'KEEPEMPTY'; diff --git a/ppl/src/main/antlr/OpenSearchPPLParser.g4 b/ppl/src/main/antlr/OpenSearchPPLParser.g4 index acae54b7d9..3e7e88717f 100644 --- a/ppl/src/main/antlr/OpenSearchPPLParser.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLParser.g4 @@ -156,6 +156,7 @@ trendlineClause trendlineType : SMA + | WMA ; kmeansCommand From 3836f31f9bdcbf6301d4f2917d0f9f955c4e5a11 Mon Sep 17 00:00:00 2001 From: Andy Kwok Date: Fri, 31 Jan 2025 16:40:33 -0800 Subject: [PATCH 02/36] Update switch Signed-off-by: Andy Kwok --- .../planner/physical/TrendlineOperator.java | 63 ++++++++++++++----- 1 file changed, 46 insertions(+), 17 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java index fdc53f96db..7d8fe5a739 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java @@ -23,6 +23,7 @@ import org.apache.commons.lang3.tuple.Pair; import org.opensearch.sql.ast.tree.Trendline; import org.opensearch.sql.data.model.ExprDoubleValue; +import org.opensearch.sql.data.model.ExprLongValue; import org.opensearch.sql.data.model.ExprTupleValue; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.model.ExprValueUtils; @@ -193,28 +194,22 @@ public ExprValue calculate() { private static class WeightedMovingAverageAccumulator implements TrendlineAccumulator { private final LiteralExpression dataPointsNeeded; private final ArrayList receivedValues; -// private final ArithmeticEvaluator evaluator; + private final ExprCoreType type; + public WeightedMovingAverageAccumulator( Trendline.TrendlineComputation computation, ExprCoreType type) { - dataPointsNeeded = DSL.literal(computation.getNumberOfDataPoints().doubleValue()); - receivedValues = new ArrayList<>(computation.getNumberOfDataPoints()+1); -// evaluator = TrendlineAccumulator.getEvaluator(type); + this.dataPointsNeeded = DSL.literal(computation.getNumberOfDataPoints().doubleValue()); + this.receivedValues = new ArrayList<>(computation.getNumberOfDataPoints()+1); + this.type = type; } @Override public void accumulate(ExprValue value) { -// if (dataPointsNeeded.valueOf().integerValue() == 1) { -// receivedValues.add(value); -// return; -// } - receivedValues.add(value); - if (receivedValues.size() > dataPointsNeeded.valueOf().integerValue()) { receivedValues.removeFirst(); } - } @Override @@ -226,17 +221,51 @@ public ExprValue calculate() { } return computeWma(receivedValues); } - } - private static ExprValue computeWma(ArrayList receivedValues) { - double sum = 0D; - for (int i=0 ; i receivedValues) { + if (type == ExprCoreType.DOUBLE) { + return new ExprDoubleValue(calculateWmaInDouble(receivedValues)); + + } else if (type == ExprCoreType.DATE) { + return ExprValueUtils.dateValue( + ExprValueUtils.timestampValue(Instant.ofEpochMilli( + calculateWmaInLong(receivedValues))).dateValue()); + + } else if ( type == ExprCoreType.TIME) { + return ExprValueUtils.timeValue( + LocalTime.MIN.plus(calculateWmaInLong(receivedValues), MILLIS)); + + } else if (type == ExprCoreType.TIMESTAMP) { + return ExprValueUtils.timestampValue(Instant.ofEpochMilli( + calculateWmaInLong(receivedValues))); + } + return null; + } + + private double calculateWmaInDouble (ArrayList receivedValues) { + double sum = 0D; + for (int i=0 ; i receivedValues) { + long sum = 0L; + for (int i=0 ; i dataPoints); From 0fcd688f74c846606d019bd7d13ff43e3b698283 Mon Sep 17 00:00:00 2001 From: Andy Kwok Date: Mon, 3 Feb 2025 11:59:15 -0800 Subject: [PATCH 03/36] Unit-test Signed-off-by: Andy Kwok --- .../planner/physical/TrendlineOperator.java | 28 +- .../physical/TrendlineOperatorTest.java | 289 ++++++++++++++++++ 2 files changed, 309 insertions(+), 8 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java index 7d8fe5a739..45d3f05443 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java @@ -234,34 +234,46 @@ private ExprValue computeWma(ArrayList receivedValues) { } else if (type == ExprCoreType.DATE) { return ExprValueUtils.dateValue( ExprValueUtils.timestampValue(Instant.ofEpochMilli( - calculateWmaInLong(receivedValues))).dateValue()); + calculateWmaInTs(receivedValues))).dateValue()); } else if ( type == ExprCoreType.TIME) { return ExprValueUtils.timeValue( - LocalTime.MIN.plus(calculateWmaInLong(receivedValues), MILLIS)); + LocalTime.MIN.plus(calculateWmaInTime(receivedValues), MILLIS)); } else if (type == ExprCoreType.TIMESTAMP) { return ExprValueUtils.timestampValue(Instant.ofEpochMilli( - calculateWmaInLong(receivedValues))); + calculateWmaInTs(receivedValues))); } return null; } private double calculateWmaInDouble (ArrayList receivedValues) { double sum = 0D; + int totalWeight = (receivedValues.size()*(receivedValues.size()+1)) / 2; for (int i=0 ; i receivedValues) { + private long calculateWmaInTs (ArrayList receivedValues) { long sum = 0L; + int totalWeight = (receivedValues.size()*(receivedValues.size()+1)) / 2; for (int i=0 ; i receivedValues) { + long sum = 0L; + int totalWeight = (receivedValues.size()*(receivedValues.size()+1)) / 2; + for (int i=0 ; i Date: Mon, 3 Feb 2025 12:26:12 -0800 Subject: [PATCH 04/36] Integ-test Signed-off-by: Andy Kwok --- .../sql/ppl/TrendlineCommandIT.java | 61 +++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/TrendlineCommandIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/TrendlineCommandIT.java index 38baa0f01f..32ceb03305 100644 --- a/integ-test/src/test/java/org/opensearch/sql/ppl/TrendlineCommandIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/TrendlineCommandIT.java @@ -75,4 +75,65 @@ public void testTrendlineWithSort() throws IOException { TEST_INDEX_BANK)); verifyDataRows(result, rows(new Object[] {null}), rows(44313.0), rows(39882.5)); } + + @Test + public void testTrendlineWma() throws IOException { + final JSONObject result = + executeQuery( + String.format( + "source=%s | where balance > 39000 | sort balance | trendline wma(2, balance) as" + + " balance_trend | fields balance_trend", + TEST_INDEX_BANK)); + System.out.println("Result (Base): " + result.toString()); + verifyDataRows(result, rows(new Object[] {null}), rows(45570.666666666664), rows(40101.666666666664)); + } + + @Test + public void testTrendlineMultipleFieldsWma() throws IOException { + final JSONObject result = + executeQuery( + String.format( + "source=%s | where balance > 39000 | sort balance | trendline wma(2, balance) as" + + " balance_trend wma(2, account_number) as account_number_trend | fields" + + " balance_trend, account_number_trend", + TEST_INDEX_BANK)); + verifyDataRows(result, rows(null, null), + rows(40101.666666666664,16.999999999999996), + rows(45570.666666666664,29.666666666666664)); + } + + @Test + public void testTrendlineOverwritesExistingFieldWma() throws IOException { + final JSONObject result = + executeQuery( + String.format( + "source=%s | where balance > 39000 | sort balance | trendline wma(2, balance) as" + + " age | fields age", + TEST_INDEX_BANK)); + System.out.println("Result (Overwrite) : " + result.toString()); + verifyDataRows(result, rows(new Object[] {null}), rows(40101.666666666664), rows(45570.666666666664)); + } + + @Test + public void testTrendlineNoAliasWma() throws IOException { + final JSONObject result = + executeQuery( + String.format( + "source=%s | where balance > 39000 | sort balance | trendline wma(2, balance) |" + + " fields balance_trendline", + TEST_INDEX_BANK)); + verifyDataRows(result, rows(new Object[] {null}), rows(40101.666666666664), rows(45570.666666666664)); + } + + @Test + public void testTrendlineWithSortWma() throws IOException { + final JSONObject result = + executeQuery( + String.format( + "source=%s | where balance > 39000 | trendline sort balance wma(2, balance) |" + + " fields balance_trendline", + TEST_INDEX_BANK)); + System.out.println("Result (WithSortWma): " + result.toString()); + verifyDataRows(result, rows(new Object[] {null}), rows(40101.666666666664), rows(45570.666666666664)); + } } From ab583708e49817d8290ee57a189ee5e47b1ad943 Mon Sep 17 00:00:00 2001 From: Andy Kwok Date: Mon, 3 Feb 2025 12:51:17 -0800 Subject: [PATCH 05/36] Doc test Signed-off-by: Andy Kwok --- docs/user/ppl/cmd/trendline.rst | 80 +++++++++++++++++++++++++++++++-- 1 file changed, 76 insertions(+), 4 deletions(-) diff --git a/docs/user/ppl/cmd/trendline.rst b/docs/user/ppl/cmd/trendline.rst index e6df0d7a2c..178f52e144 100644 --- a/docs/user/ppl/cmd/trendline.rst +++ b/docs/user/ppl/cmd/trendline.rst @@ -13,7 +13,7 @@ Description ============ | Using ``trendline`` command to calculate moving averages of fields. -Syntax +Syntax - SMA (Simple Moving Average) ============ `TRENDLINE [sort <[+|-] sort-field>] SMA(number-of-datapoints, field) [AS alias] [SMA(number-of-datapoints, field) [AS alias]]...` @@ -23,8 +23,6 @@ Syntax * field: mandatory. The name of the field the moving average should be calculated for. * alias: optional. The name of the resulting column containing the moving average (defaults to the field name with "_trendline"). -At the moment only the Simple Moving Average (SMA) type is supported. - It is calculated like f[i]: The value of field 'f' in the i-th data-point @@ -70,7 +68,7 @@ PPL query:: | 15.5 | 30.5 | +------+-----------+ -Example 4: Calculate the moving average on one field without specifying an alias. +Example 3: Calculate the moving average on one field without specifying an alias. ================================================================================= The example shows how to calculate the moving average on one field. @@ -88,3 +86,77 @@ PPL query:: | 15.5 | +--------------------------+ +Syntax - WMA (Weighted Moving Average) +============ +`TRENDLINE [sort <[+|-] sort-field>] WMA(number-of-datapoints, field) [AS alias] [WMA(number-of-datapoints, field) [AS alias]]...` + +* [+|-]: optional. The plus [+] stands for ascending order and NULL/MISSING first and a minus [-] stands for descending order and NULL/MISSING last. **Default:** ascending order and NULL/MISSING first. +* sort-field: mandatory when sorting is used. The field used to sort. +* number-of-datapoints: mandatory. The number of datapoints to calculate the moving average (must be greater than zero). +* field: mandatory. The name of the field the moving average should be calculated for. +* alias: optional. The name of the resulting column containing the moving average (defaults to the field name with "_trendline"). + +It is calculated like + + f[i]: The value of field 'f' in the i-th data point + n: The number of data points in the moving window (period) + t: The current time index + w[i]: The weight assigned to the i-th data point, typically increasing for more recent points + + WMA(t) = ( Σ from i=t−n+1 to t of (w[i] * f[i]) ) / ( Σ from i=t−n+1 to t of w[i] ) + +Example 1: Calculate the weighted moving average on one field. +===================================================== + +The example shows how to calculate the weighted moving average on one field. + +PPL query:: + + os> source=accounts | trendline wma(2, account_number) as an | fields an; + fetched rows / total rows = 4/4 + +--------------------+ + | an | + |--------------------| + | null | + | 4.333333333333333 | + | 10.666666666666666 | + | 16.333333333333332 | + +--------------------+ + +Example 2: Calculate the weighted moving average on multiple fields. +=========================================================== + +The example shows how to calculate the weighted moving average on multiple fields. + +PPL query:: + + os> source=accounts | trendline wma(2, account_number) as an sma(2, age) as age_trend | fields an, age_trend ; + fetched rows / total rows = 4/4 + +--------------------+-----------+ + | an | age_trend | + |--------------------+-----------| + | null | null | + | 4.333333333333333 | 34.0 | + | 10.666666666666666 | 32.0 | + | 16.333333333333332 | 30.5 | + +--------------------+-----------+ + + +Example 3: Calculate the weighted moving average on one field without specifying an alias. +================================================================================= + +The example shows how to calculate the weighted moving average on one field. + +PPL query:: + + os> source=accounts | trendline wma(2, account_number) | fields account_number_trendline; + fetched rows / total rows = 4/4 + +--------------------------+ + | account_number_trendline | + |--------------------------| + | null | + | 4.333333333333333 | + | 10.666666666666666 | + | 16.333333333333332 | + +--------------------------+ + From 51d8395db250a1eec4d33a5b3cda586ff8af6684 Mon Sep 17 00:00:00 2001 From: Andy Kwok Date: Mon, 3 Feb 2025 13:24:32 -0800 Subject: [PATCH 06/36] Spotless Signed-off-by: Andy Kwok --- .../planner/physical/TrendlineOperator.java | 65 ++-- .../physical/TrendlineOperatorTest.java | 330 +++++++++--------- .../sql/ppl/TrendlineCommandIT.java | 72 ++-- 3 files changed, 239 insertions(+), 228 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java index 45d3f05443..24a22fa880 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java @@ -17,13 +17,13 @@ import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.function.Function; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.ToString; import org.apache.commons.lang3.tuple.Pair; import org.opensearch.sql.ast.tree.Trendline; import org.opensearch.sql.data.model.ExprDoubleValue; -import org.opensearch.sql.data.model.ExprLongValue; import org.opensearch.sql.data.model.ExprTupleValue; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.model.ExprValueUtils; @@ -109,10 +109,11 @@ private Map consumeInputTuple(ExprValue inputValue) { private static TrendlineAccumulator createAccumulator( Pair computation) { - return switch (computation.getKey().getComputationType()) { - case SMA -> new SimpleMovingAverageAccumulator(computation.getKey(), computation.getValue()); - case WMA -> new WeightedMovingAverageAccumulator(computation.getKey(), computation.getValue()); - }; + return switch (computation.getKey().getComputationType()) { + case SMA -> new SimpleMovingAverageAccumulator(computation.getKey(), computation.getValue()); + case WMA -> new WeightedMovingAverageAccumulator( + computation.getKey(), computation.getValue()); + }; } /** Maintains stateful information for calculating the trendline. */ @@ -196,11 +197,10 @@ private static class WeightedMovingAverageAccumulator implements TrendlineAccumu private final ArrayList receivedValues; private final ExprCoreType type; - public WeightedMovingAverageAccumulator( - Trendline.TrendlineComputation computation, ExprCoreType type) { + Trendline.TrendlineComputation computation, ExprCoreType type) { this.dataPointsNeeded = DSL.literal(computation.getNumberOfDataPoints().doubleValue()); - this.receivedValues = new ArrayList<>(computation.getNumberOfDataPoints()+1); + this.receivedValues = new ArrayList<>(computation.getNumberOfDataPoints() + 1); this.type = type; } @@ -224,6 +224,7 @@ public ExprValue calculate() { /** * Compute WMA values from provided dataset with in ascending order. + * * @param receivedValues the dataset for WMA calculation, sorted in ascending order. * @return ExprValue which represent he result onf WMA. */ @@ -232,52 +233,44 @@ private ExprValue computeWma(ArrayList receivedValues) { return new ExprDoubleValue(calculateWmaInDouble(receivedValues)); } else if (type == ExprCoreType.DATE) { - return ExprValueUtils.dateValue( - ExprValueUtils.timestampValue(Instant.ofEpochMilli( - calculateWmaInTs(receivedValues))).dateValue()); - - } else if ( type == ExprCoreType.TIME) { + return ExprValueUtils.timestampValue( + Instant.ofEpochMilli( + calculateWmaInLong(receivedValues, i -> i.timestampValue().toEpochMilli()))); + } else if (type == ExprCoreType.TIME) { return ExprValueUtils.timeValue( - LocalTime.MIN.plus(calculateWmaInTime(receivedValues), MILLIS)); + LocalTime.MIN.plus( + calculateWmaInLong( + receivedValues, i -> MILLIS.between(LocalTime.MIN, i.timeValue())), + MILLIS)); } else if (type == ExprCoreType.TIMESTAMP) { - return ExprValueUtils.timestampValue(Instant.ofEpochMilli( - calculateWmaInTs(receivedValues))); + return ExprValueUtils.timestampValue( + Instant.ofEpochMilli( + calculateWmaInLong(receivedValues, i -> i.timestampValue().toEpochMilli()))); } - return null; + return null; } - private double calculateWmaInDouble (ArrayList receivedValues) { + private double calculateWmaInDouble(ArrayList receivedValues) { double sum = 0D; - int totalWeight = (receivedValues.size()*(receivedValues.size()+1)) / 2; - for (int i=0 ; i receivedValues) { + private long calculateWmaInLong( + ArrayList receivedValues, Function convertFunc) { long sum = 0L; - int totalWeight = (receivedValues.size()*(receivedValues.size()+1)) / 2; - for (int i=0 ; i receivedValues) { - long sum = 0L; - int totalWeight = (receivedValues.size()*(receivedValues.size()+1)) / 2; - for (int i=0 ; i dataPoints); diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java index 1e0c3ae0d5..723e6db03c 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java @@ -401,49 +401,49 @@ public void calculates_simple_moving_average_timestamp() { public void calculates_weighted_moving_average_one_field_one_sample() { when(inputPlan.hasNext()).thenReturn(true, false); when(inputPlan.next()) - .thenReturn(ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10))); + .thenReturn(ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10))); var plan = - new TrendlineOperator( - inputPlan, - Collections.singletonList( - Pair.of( - AstDSL.computation(1, AstDSL.field("distance"), "distance_alias", WMA), - ExprCoreType.DOUBLE))); + new TrendlineOperator( + inputPlan, + Collections.singletonList( + Pair.of( + AstDSL.computation(1, AstDSL.field("distance"), "distance_alias", WMA), + ExprCoreType.DOUBLE))); plan.open(); assertTrue(plan.hasNext()); assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of("distance", 100, "time", 10, "distance_alias", 100)), - plan.next()); + ExprValueUtils.tupleValue( + ImmutableMap.of("distance", 100, "time", 10, "distance_alias", 100)), + plan.next()); } @Test public void calculates_weighted_moving_average_one_field_two_samples() { when(inputPlan.hasNext()).thenReturn(true, true, false); when(inputPlan.next()) - .thenReturn( - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10))); + .thenReturn( + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10))); var plan = - new TrendlineOperator( - inputPlan, - Collections.singletonList( - Pair.of( - AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", WMA), - ExprCoreType.DOUBLE))); + new TrendlineOperator( + inputPlan, + Collections.singletonList( + Pair.of( + AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", WMA), + ExprCoreType.DOUBLE))); plan.open(); assertTrue(plan.hasNext()); assertEquals( - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), plan.next()); + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), plan.next()); assertTrue(plan.hasNext()); assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 166.66666666666663)), - plan.next()); + ExprValueUtils.tupleValue( + ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 166.66666666666663)), + plan.next()); assertFalse(plan.hasNext()); } @@ -451,33 +451,33 @@ public void calculates_weighted_moving_average_one_field_two_samples() { public void calculates_weighted_moving_average_one_field_two_samples_three_rows() { when(inputPlan.hasNext()).thenReturn(true, true, true, false); when(inputPlan.next()) - .thenReturn( - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)), - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10))); + .thenReturn( + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10))); var plan = - new TrendlineOperator( - inputPlan, - Collections.singletonList( - Pair.of( - AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", WMA), - ExprCoreType.DOUBLE))); + new TrendlineOperator( + inputPlan, + Collections.singletonList( + Pair.of( + AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", WMA), + ExprCoreType.DOUBLE))); plan.open(); assertTrue(plan.hasNext()); assertEquals( - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), plan.next()); + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), plan.next()); assertTrue(plan.hasNext()); assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 166.66666666666663)), - plan.next()); + ExprValueUtils.tupleValue( + ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 166.66666666666663)), + plan.next()); assertTrue(plan.hasNext()); assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 199.99999999999997)), - plan.next()); + ExprValueUtils.tupleValue( + ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 199.99999999999997)), + plan.next()); assertFalse(plan.hasNext()); } @@ -485,70 +485,83 @@ public void calculates_weighted_moving_average_one_field_two_samples_three_rows( public void calculates_weighted_moving_average_multiple_computations() { when(inputPlan.hasNext()).thenReturn(true, true, true, false); when(inputPlan.next()) - .thenReturn( - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 20)), - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 20))); + .thenReturn( + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 20)), + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 20))); var plan = - new TrendlineOperator( - inputPlan, - Arrays.asList( - Pair.of( - AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", WMA), - ExprCoreType.DOUBLE), - Pair.of( - AstDSL.computation(2, AstDSL.field("time"), "time_alias", WMA), - ExprCoreType.DOUBLE))); + new TrendlineOperator( + inputPlan, + Arrays.asList( + Pair.of( + AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", WMA), + ExprCoreType.DOUBLE), + Pair.of( + AstDSL.computation(2, AstDSL.field("time"), "time_alias", WMA), + ExprCoreType.DOUBLE))); plan.open(); assertTrue(plan.hasNext()); assertEquals( - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), plan.next()); + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), plan.next()); assertTrue(plan.hasNext()); assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of( - "distance", 200, "time", 20, "distance_alias", 166.66666666666663, "time_alias", 16.666666666666664)), - plan.next()); + ExprValueUtils.tupleValue( + ImmutableMap.of( + "distance", + 200, + "time", + 20, + "distance_alias", + 166.66666666666663, + "time_alias", + 16.666666666666664)), + plan.next()); assertTrue(plan.hasNext()); assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of( - "distance", 200, "time", 20, "distance_alias", 199.99999999999997, "time_alias", 20.0)), - plan.next()); + ExprValueUtils.tupleValue( + ImmutableMap.of( + "distance", + 200, + "time", + 20, + "distance_alias", + 199.99999999999997, + "time_alias", + 20.0)), + plan.next()); assertFalse(plan.hasNext()); } - @Test public void calculates_weighted_moving_average_one_field_two_samples_three_rows_null_value() { when(inputPlan.hasNext()).thenReturn(true, true, true, false); when(inputPlan.next()) - .thenReturn( - ExprValueUtils.tupleValue(ImmutableMap.of("time", 10)), - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)), - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 300, "time", 10))); + .thenReturn( + ExprValueUtils.tupleValue(ImmutableMap.of("time", 10)), + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 300, "time", 10))); var plan = - new TrendlineOperator( - inputPlan, - Collections.singletonList( - Pair.of( - AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", WMA), - ExprCoreType.DOUBLE))); + new TrendlineOperator( + inputPlan, + Collections.singletonList( + Pair.of( + AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", WMA), + ExprCoreType.DOUBLE))); plan.open(); assertTrue(plan.hasNext()); assertEquals(ExprValueUtils.tupleValue(ImmutableMap.of("time", 10)), plan.next()); assertTrue(plan.hasNext()); assertEquals( - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)), plan.next()); + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)), plan.next()); assertTrue(plan.hasNext()); assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of("distance", 300, "time", 10, "distance_alias", 266.66666666666663)), - plan.next()); + ExprValueUtils.tupleValue( + ImmutableMap.of("distance", 300, "time", 10, "distance_alias", 266.66666666666663)), + plan.next()); assertFalse(plan.hasNext()); } @@ -556,46 +569,46 @@ public void calculates_weighted_moving_average_one_field_two_samples_three_rows_ public void calculates_weighted_moving_average_date() { when(inputPlan.hasNext()).thenReturn(true, true, true, false); when(inputPlan.next()) - .thenReturn( - ExprValueUtils.tupleValue( - ImmutableMap.of("date", ExprValueUtils.dateValue(LocalDate.EPOCH))), - ExprValueUtils.tupleValue( - ImmutableMap.of("date", ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(6)))), - ExprValueUtils.tupleValue( - ImmutableMap.of("date", ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(12))))); + .thenReturn( + ExprValueUtils.tupleValue( + ImmutableMap.of("date", ExprValueUtils.dateValue(LocalDate.EPOCH))), + ExprValueUtils.tupleValue( + ImmutableMap.of("date", ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(6)))), + ExprValueUtils.tupleValue( + ImmutableMap.of("date", ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(12))))); var plan = - new TrendlineOperator( - inputPlan, - Collections.singletonList( - Pair.of( - AstDSL.computation(2, AstDSL.field("date"), "date_alias", WMA), - ExprCoreType.DATE))); + new TrendlineOperator( + inputPlan, + Collections.singletonList( + Pair.of( + AstDSL.computation(2, AstDSL.field("date"), "date_alias", WMA), + ExprCoreType.DATE))); plan.open(); assertTrue(plan.hasNext()); assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of("date", ExprValueUtils.dateValue(LocalDate.EPOCH))), - plan.next()); + ExprValueUtils.tupleValue( + ImmutableMap.of("date", ExprValueUtils.dateValue(LocalDate.EPOCH))), + plan.next()); assertTrue(plan.hasNext()); assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of( - "date", - ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(6)), - "date_alias", - ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(4)))), - plan.next()); + ExprValueUtils.tupleValue( + ImmutableMap.of( + "date", + ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(6)), + "date_alias", + ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(4)))), + plan.next()); assertTrue(plan.hasNext()); assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of( - "date", - ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(12)), - "date_alias", - ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(10)))), - plan.next()); + ExprValueUtils.tupleValue( + ImmutableMap.of( + "date", + ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(12)), + "date_alias", + ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(10)))), + plan.next()); assertFalse(plan.hasNext()); } @@ -603,37 +616,37 @@ public void calculates_weighted_moving_average_date() { public void calculates_weighted_moving_average_time() { when(inputPlan.hasNext()).thenReturn(true, true, true, false); when(inputPlan.next()) - .thenReturn( - ExprValueUtils.tupleValue( - ImmutableMap.of("time", ExprValueUtils.timeValue(LocalTime.MIN))), - ExprValueUtils.tupleValue( - ImmutableMap.of("time", ExprValueUtils.timeValue(LocalTime.MIN.plusHours(6)))), - ExprValueUtils.tupleValue( - ImmutableMap.of("time", ExprValueUtils.timeValue(LocalTime.MIN.plusHours(12))))); + .thenReturn( + ExprValueUtils.tupleValue( + ImmutableMap.of("time", ExprValueUtils.timeValue(LocalTime.MIN))), + ExprValueUtils.tupleValue( + ImmutableMap.of("time", ExprValueUtils.timeValue(LocalTime.MIN.plusHours(6)))), + ExprValueUtils.tupleValue( + ImmutableMap.of("time", ExprValueUtils.timeValue(LocalTime.MIN.plusHours(12))))); var plan = - new TrendlineOperator( - inputPlan, - Collections.singletonList( - Pair.of( - AstDSL.computation(2, AstDSL.field("time"), "time_alias", WMA), - ExprCoreType.TIME))); + new TrendlineOperator( + inputPlan, + Collections.singletonList( + Pair.of( + AstDSL.computation(2, AstDSL.field("time"), "time_alias", WMA), + ExprCoreType.TIME))); plan.open(); assertTrue(plan.hasNext()); assertEquals(ExprValueUtils.tupleValue(ImmutableMap.of("time", LocalTime.MIN)), plan.next()); assertTrue(plan.hasNext()); assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of( - "time", LocalTime.MIN.plusHours(6), "time_alias", LocalTime.MIN.plusHours(4))), - plan.next()); + ExprValueUtils.tupleValue( + ImmutableMap.of( + "time", LocalTime.MIN.plusHours(6), "time_alias", LocalTime.MIN.plusHours(4))), + plan.next()); assertTrue(plan.hasNext()); assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of( - "time", LocalTime.MIN.plusHours(12), "time_alias", LocalTime.MIN.plusHours(10))), - plan.next()); + ExprValueUtils.tupleValue( + ImmutableMap.of( + "time", LocalTime.MIN.plusHours(12), "time_alias", LocalTime.MIN.plusHours(10))), + plan.next()); assertFalse(plan.hasNext()); } @@ -641,47 +654,46 @@ public void calculates_weighted_moving_average_time() { public void calculates_weighted_moving_average_timestamp() { when(inputPlan.hasNext()).thenReturn(true, true, true, false); when(inputPlan.next()) - .thenReturn( - ExprValueUtils.tupleValue( - ImmutableMap.of("timestamp", ExprValueUtils.timestampValue(Instant.EPOCH))), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "timestamp", ExprValueUtils.timestampValue(Instant.EPOCH.plusMillis(1000)))), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "timestamp", ExprValueUtils.timestampValue(Instant.EPOCH.plusMillis(1500))))); + .thenReturn( + ExprValueUtils.tupleValue( + ImmutableMap.of("timestamp", ExprValueUtils.timestampValue(Instant.EPOCH))), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "timestamp", ExprValueUtils.timestampValue(Instant.EPOCH.plusMillis(1000)))), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "timestamp", ExprValueUtils.timestampValue(Instant.EPOCH.plusMillis(1500))))); var plan = - new TrendlineOperator( - inputPlan, - Collections.singletonList( - Pair.of( - AstDSL.computation(2, AstDSL.field("timestamp"), "timestamp_alias", WMA), - ExprCoreType.TIMESTAMP))); + new TrendlineOperator( + inputPlan, + Collections.singletonList( + Pair.of( + AstDSL.computation(2, AstDSL.field("timestamp"), "timestamp_alias", WMA), + ExprCoreType.TIMESTAMP))); plan.open(); assertTrue(plan.hasNext()); assertEquals( - ExprValueUtils.tupleValue(ImmutableMap.of("timestamp", Instant.EPOCH)), plan.next()); + ExprValueUtils.tupleValue(ImmutableMap.of("timestamp", Instant.EPOCH)), plan.next()); assertTrue(plan.hasNext()); assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of( - "timestamp", - Instant.EPOCH.plusMillis(1000), - "timestamp_alias", - Instant.EPOCH.plusMillis(666))), - plan.next()); + ExprValueUtils.tupleValue( + ImmutableMap.of( + "timestamp", + Instant.EPOCH.plusMillis(1000), + "timestamp_alias", + Instant.EPOCH.plusMillis(666))), + plan.next()); assertTrue(plan.hasNext()); assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of( - "timestamp", - Instant.EPOCH.plusMillis(1500), - "timestamp_alias", - Instant.EPOCH.plusMillis(1333))), - plan.next()); + ExprValueUtils.tupleValue( + ImmutableMap.of( + "timestamp", + Instant.EPOCH.plusMillis(1500), + "timestamp_alias", + Instant.EPOCH.plusMillis(1333))), + plan.next()); assertFalse(plan.hasNext()); } - } diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/TrendlineCommandIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/TrendlineCommandIT.java index 32ceb03305..1db3472b54 100644 --- a/integ-test/src/test/java/org/opensearch/sql/ppl/TrendlineCommandIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/TrendlineCommandIT.java @@ -79,61 +79,67 @@ public void testTrendlineWithSort() throws IOException { @Test public void testTrendlineWma() throws IOException { final JSONObject result = - executeQuery( - String.format( - "source=%s | where balance > 39000 | sort balance | trendline wma(2, balance) as" - + " balance_trend | fields balance_trend", - TEST_INDEX_BANK)); + executeQuery( + String.format( + "source=%s | where balance > 39000 | sort balance | trendline wma(2, balance) as" + + " balance_trend | fields balance_trend", + TEST_INDEX_BANK)); System.out.println("Result (Base): " + result.toString()); - verifyDataRows(result, rows(new Object[] {null}), rows(45570.666666666664), rows(40101.666666666664)); + verifyDataRows( + result, rows(new Object[] {null}), rows(45570.666666666664), rows(40101.666666666664)); } @Test public void testTrendlineMultipleFieldsWma() throws IOException { final JSONObject result = - executeQuery( - String.format( - "source=%s | where balance > 39000 | sort balance | trendline wma(2, balance) as" - + " balance_trend wma(2, account_number) as account_number_trend | fields" - + " balance_trend, account_number_trend", - TEST_INDEX_BANK)); - verifyDataRows(result, rows(null, null), - rows(40101.666666666664,16.999999999999996), - rows(45570.666666666664,29.666666666666664)); + executeQuery( + String.format( + "source=%s | where balance > 39000 | sort balance | trendline wma(2, balance) as" + + " balance_trend wma(2, account_number) as account_number_trend | fields" + + " balance_trend, account_number_trend", + TEST_INDEX_BANK)); + verifyDataRows( + result, + rows(null, null), + rows(40101.666666666664, 16.999999999999996), + rows(45570.666666666664, 29.666666666666664)); } @Test public void testTrendlineOverwritesExistingFieldWma() throws IOException { final JSONObject result = - executeQuery( - String.format( - "source=%s | where balance > 39000 | sort balance | trendline wma(2, balance) as" - + " age | fields age", - TEST_INDEX_BANK)); + executeQuery( + String.format( + "source=%s | where balance > 39000 | sort balance | trendline wma(2, balance) as" + + " age | fields age", + TEST_INDEX_BANK)); System.out.println("Result (Overwrite) : " + result.toString()); - verifyDataRows(result, rows(new Object[] {null}), rows(40101.666666666664), rows(45570.666666666664)); + verifyDataRows( + result, rows(new Object[] {null}), rows(40101.666666666664), rows(45570.666666666664)); } @Test public void testTrendlineNoAliasWma() throws IOException { final JSONObject result = - executeQuery( - String.format( - "source=%s | where balance > 39000 | sort balance | trendline wma(2, balance) |" - + " fields balance_trendline", - TEST_INDEX_BANK)); - verifyDataRows(result, rows(new Object[] {null}), rows(40101.666666666664), rows(45570.666666666664)); + executeQuery( + String.format( + "source=%s | where balance > 39000 | sort balance | trendline wma(2, balance) |" + + " fields balance_trendline", + TEST_INDEX_BANK)); + verifyDataRows( + result, rows(new Object[] {null}), rows(40101.666666666664), rows(45570.666666666664)); } @Test public void testTrendlineWithSortWma() throws IOException { final JSONObject result = - executeQuery( - String.format( - "source=%s | where balance > 39000 | trendline sort balance wma(2, balance) |" - + " fields balance_trendline", - TEST_INDEX_BANK)); + executeQuery( + String.format( + "source=%s | where balance > 39000 | trendline sort balance wma(2, balance) |" + + " fields balance_trendline", + TEST_INDEX_BANK)); System.out.println("Result (WithSortWma): " + result.toString()); - verifyDataRows(result, rows(new Object[] {null}), rows(40101.666666666664), rows(45570.666666666664)); + verifyDataRows( + result, rows(new Object[] {null}), rows(40101.666666666664), rows(45570.666666666664)); } } From e4c25b3cf75b58e370d460305e335d222ae03317 Mon Sep 17 00:00:00 2001 From: Andy Kwok Date: Mon, 3 Feb 2025 15:04:37 -0800 Subject: [PATCH 07/36] Update test cases Signed-off-by: Andy Kwok --- .../planner/physical/TrendlineOperator.java | 103 ++++++++++-------- .../physical/TrendlineOperatorTest.java | 15 +++ 2 files changed, 75 insertions(+), 43 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java index 24a22fa880..6043721170 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java @@ -195,13 +195,23 @@ public ExprValue calculate() { private static class WeightedMovingAverageAccumulator implements TrendlineAccumulator { private final LiteralExpression dataPointsNeeded; private final ArrayList receivedValues; - private final ExprCoreType type; + private final WmaTrendlineEvaluator evaluator; public WeightedMovingAverageAccumulator( Trendline.TrendlineComputation computation, ExprCoreType type) { this.dataPointsNeeded = DSL.literal(computation.getNumberOfDataPoints().doubleValue()); - this.receivedValues = new ArrayList<>(computation.getNumberOfDataPoints() + 1); - this.type = type; + this.receivedValues = new ArrayList<>(computation.getNumberOfDataPoints()); + this.evaluator = getEvaluator(type); + } + + static WmaTrendlineEvaluator getEvaluator(ExprCoreType type) { + return switch (type) { + case DOUBLE -> NumericWmaEvaluator.INSTANCE; + case DATE, TIMESTAMP -> TimeStampWmaEvaluator.INSTANCE; + case TIME -> TimeWmaEvaluator.INSTANCE; + default -> throw new IllegalArgumentException( + String.format("Invalid type %s used for weighted moving average.", type.typeName())); + }; } @Override @@ -219,58 +229,65 @@ public ExprValue calculate() { } else if (dataPointsNeeded.valueOf().integerValue() == 1) { return receivedValues.getFirst(); } - return computeWma(receivedValues); + return evaluator.evaluate(receivedValues); } - /** - * Compute WMA values from provided dataset with in ascending order. - * - * @param receivedValues the dataset for WMA calculation, sorted in ascending order. - * @return ExprValue which represent he result onf WMA. - */ - private ExprValue computeWma(ArrayList receivedValues) { - if (type == ExprCoreType.DOUBLE) { - return new ExprDoubleValue(calculateWmaInDouble(receivedValues)); - - } else if (type == ExprCoreType.DATE) { - return ExprValueUtils.timestampValue( - Instant.ofEpochMilli( - calculateWmaInLong(receivedValues, i -> i.timestampValue().toEpochMilli()))); - } else if (type == ExprCoreType.TIME) { - return ExprValueUtils.timeValue( - LocalTime.MIN.plus( - calculateWmaInLong( - receivedValues, i -> MILLIS.between(LocalTime.MIN, i.timeValue())), - MILLIS)); - - } else if (type == ExprCoreType.TIMESTAMP) { - return ExprValueUtils.timestampValue( - Instant.ofEpochMilli( - calculateWmaInLong(receivedValues, i -> i.timestampValue().toEpochMilli()))); + private static class NumericWmaEvaluator implements WmaTrendlineEvaluator { + + private static final NumericWmaEvaluator INSTANCE = new NumericWmaEvaluator(); + + @Override + public ExprValue evaluate(ArrayList receivedValues) { + double sum = 0D; + int totalWeight = (receivedValues.size() * (receivedValues.size() + 1)) / 2; + for (int i = 0; i < receivedValues.size(); i++) { + sum += receivedValues.get(i).doubleValue() * ((i + 1D) / totalWeight); + } + return new ExprDoubleValue(sum); } - return null; } - private double calculateWmaInDouble(ArrayList receivedValues) { - double sum = 0D; - int totalWeight = (receivedValues.size() * (receivedValues.size() + 1)) / 2; - for (int i = 0; i < receivedValues.size(); i++) { - sum += receivedValues.get(i).doubleValue() * ((i + 1D) / totalWeight); + private static class TimeStampWmaEvaluator implements WmaTrendlineEvaluator { + + private static final TimeStampWmaEvaluator INSTANCE = new TimeStampWmaEvaluator(); + + @Override + public ExprValue evaluate(ArrayList receivedValues) { + long sum = 0L; + int totalWeight = (receivedValues.size() * (receivedValues.size() + 1)) / 2; + for (int i = 0; i < receivedValues.size(); i++) { + sum += (long) (receivedValues.get(i).timestampValue().toEpochMilli() * ((i + 1D) / totalWeight)); + } + + return ExprValueUtils.timestampValue(Instant.ofEpochMilli((sum))); } - return sum; } - private long calculateWmaInLong( - ArrayList receivedValues, Function convertFunc) { - long sum = 0L; - int totalWeight = (receivedValues.size() * (receivedValues.size() + 1)) / 2; - for (int i = 0; i < receivedValues.size(); i++) { - sum += (long) (convertFunc.apply(receivedValues.get(i)) * ((i + 1D) / totalWeight)); + private static class TimeWmaEvaluator implements WmaTrendlineEvaluator { + + private static final TimeWmaEvaluator INSTANCE = new TimeWmaEvaluator(); + + @Override + public ExprValue evaluate(ArrayList receivedValues) { + long sum = 0L; + int totalWeight = (receivedValues.size() * (receivedValues.size() + 1)) / 2; + for (int i = 0; i < receivedValues.size(); i++) { + sum += (long) (MILLIS.between(LocalTime.MIN, receivedValues.get(i).timeValue()) * ((i + 1D) / totalWeight)); + } + return ExprValueUtils.timeValue( + LocalTime.MIN.plus(sum, MILLIS)); } - return sum; + } + + + private interface WmaTrendlineEvaluator { + ExprValue evaluate(ArrayList receivedValues); } } + + + private interface ArithmeticEvaluator { Expression calculateFirstTotal(List dataPoints); diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java index 723e6db03c..eec534bcfd 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java @@ -696,4 +696,19 @@ public void calculates_weighted_moving_average_timestamp() { plan.next()); assertFalse(plan.hasNext()); } + + @Test + public void use_illegal_core_type_wma() { + assertThrows( + IllegalArgumentException.class, + () -> { + new TrendlineOperator( + inputPlan, + Collections.singletonList( + Pair.of( + AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", WMA), + ExprCoreType.ARRAY))); + }); + } + } From f6c82aeb56127e79e6a82e7db530015317048f4d Mon Sep 17 00:00:00 2001 From: Andy Kwok Date: Mon, 3 Feb 2025 15:14:54 -0800 Subject: [PATCH 08/36] Update test coverage Signed-off-by: Andy Kwok --- .../planner/physical/TrendlineOperator.java | 32 +++++++++---------- .../physical/TrendlineOperatorTest.java | 19 ++++++----- 2 files changed, 25 insertions(+), 26 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java index 6043721170..43e63c3aac 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java @@ -17,7 +17,6 @@ import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.function.Function; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.ToString; @@ -205,13 +204,13 @@ public WeightedMovingAverageAccumulator( } static WmaTrendlineEvaluator getEvaluator(ExprCoreType type) { - return switch (type) { - case DOUBLE -> NumericWmaEvaluator.INSTANCE; - case DATE, TIMESTAMP -> TimeStampWmaEvaluator.INSTANCE; - case TIME -> TimeWmaEvaluator.INSTANCE; - default -> throw new IllegalArgumentException( - String.format("Invalid type %s used for weighted moving average.", type.typeName())); - }; + return switch (type) { + case DOUBLE -> NumericWmaEvaluator.INSTANCE; + case DATE, TIMESTAMP -> TimeStampWmaEvaluator.INSTANCE; + case TIME -> TimeWmaEvaluator.INSTANCE; + default -> throw new IllegalArgumentException( + String.format("Invalid type %s used for weighted moving average.", type.typeName())); + }; } @Override @@ -256,7 +255,10 @@ public ExprValue evaluate(ArrayList receivedValues) { long sum = 0L; int totalWeight = (receivedValues.size() * (receivedValues.size() + 1)) / 2; for (int i = 0; i < receivedValues.size(); i++) { - sum += (long) (receivedValues.get(i).timestampValue().toEpochMilli() * ((i + 1D) / totalWeight)); + sum += + (long) + (receivedValues.get(i).timestampValue().toEpochMilli() + * ((i + 1D) / totalWeight)); } return ExprValueUtils.timestampValue(Instant.ofEpochMilli((sum))); @@ -272,22 +274,20 @@ public ExprValue evaluate(ArrayList receivedValues) { long sum = 0L; int totalWeight = (receivedValues.size() * (receivedValues.size() + 1)) / 2; for (int i = 0; i < receivedValues.size(); i++) { - sum += (long) (MILLIS.between(LocalTime.MIN, receivedValues.get(i).timeValue()) * ((i + 1D) / totalWeight)); + sum += + (long) + (MILLIS.between(LocalTime.MIN, receivedValues.get(i).timeValue()) + * ((i + 1D) / totalWeight)); } - return ExprValueUtils.timeValue( - LocalTime.MIN.plus(sum, MILLIS)); + return ExprValueUtils.timeValue(LocalTime.MIN.plus(sum, MILLIS)); } } - private interface WmaTrendlineEvaluator { ExprValue evaluate(ArrayList receivedValues); } } - - - private interface ArithmeticEvaluator { Expression calculateFirstTotal(List dataPoints); diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java index eec534bcfd..dac2d95d15 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java @@ -700,15 +700,14 @@ public void calculates_weighted_moving_average_timestamp() { @Test public void use_illegal_core_type_wma() { assertThrows( - IllegalArgumentException.class, - () -> { - new TrendlineOperator( - inputPlan, - Collections.singletonList( - Pair.of( - AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", WMA), - ExprCoreType.ARRAY))); - }); + IllegalArgumentException.class, + () -> { + new TrendlineOperator( + inputPlan, + Collections.singletonList( + Pair.of( + AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", WMA), + ExprCoreType.ARRAY))); + }); } - } From d0cf8987cfe825dffd250a70567b6bbf327ea9a6 Mon Sep 17 00:00:00 2001 From: Andy Kwok Date: Thu, 6 Feb 2025 11:46:05 -0800 Subject: [PATCH 09/36] Update docs/user/ppl/cmd/trendline.rst Co-authored-by: Taylor Curran Signed-off-by: Andy Kwok --- docs/user/ppl/cmd/trendline.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/user/ppl/cmd/trendline.rst b/docs/user/ppl/cmd/trendline.rst index 178f52e144..e4126c108b 100644 --- a/docs/user/ppl/cmd/trendline.rst +++ b/docs/user/ppl/cmd/trendline.rst @@ -149,7 +149,7 @@ The example shows how to calculate the weighted moving average on one field. PPL query:: - os> source=accounts | trendline wma(2, account_number) | fields account_number_trendline; + os> source=accounts | trendline wma(2, account_number) | fields account_number_trendline; fetched rows / total rows = 4/4 +--------------------------+ | account_number_trendline | From ce94ca9856020b40a41ed8cb889c36a884d4bdb0 Mon Sep 17 00:00:00 2001 From: Andy Kwok Date: Thu, 6 Feb 2025 11:46:21 -0800 Subject: [PATCH 10/36] Update docs/user/ppl/cmd/trendline.rst Co-authored-by: Taylor Curran Signed-off-by: Andy Kwok --- docs/user/ppl/cmd/trendline.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/user/ppl/cmd/trendline.rst b/docs/user/ppl/cmd/trendline.rst index e4126c108b..e7a54283dd 100644 --- a/docs/user/ppl/cmd/trendline.rst +++ b/docs/user/ppl/cmd/trendline.rst @@ -145,7 +145,7 @@ PPL query:: Example 3: Calculate the weighted moving average on one field without specifying an alias. ================================================================================= -The example shows how to calculate the weighted moving average on one field. +The example shows how to calculate the weighted moving average on one field without specifying an alias. PPL query:: From f32e5ad2d906b482255f006fae1f78a895ecaf0d Mon Sep 17 00:00:00 2001 From: Andy Kwok Date: Thu, 6 Feb 2025 12:05:45 -0800 Subject: [PATCH 11/36] Remove debug Signed-off-by: Andy Kwok --- .../test/java/org/opensearch/sql/ppl/TrendlineCommandIT.java | 3 --- 1 file changed, 3 deletions(-) diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/TrendlineCommandIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/TrendlineCommandIT.java index 1db3472b54..4c5c0b0153 100644 --- a/integ-test/src/test/java/org/opensearch/sql/ppl/TrendlineCommandIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/TrendlineCommandIT.java @@ -84,7 +84,6 @@ public void testTrendlineWma() throws IOException { "source=%s | where balance > 39000 | sort balance | trendline wma(2, balance) as" + " balance_trend | fields balance_trend", TEST_INDEX_BANK)); - System.out.println("Result (Base): " + result.toString()); verifyDataRows( result, rows(new Object[] {null}), rows(45570.666666666664), rows(40101.666666666664)); } @@ -113,7 +112,6 @@ public void testTrendlineOverwritesExistingFieldWma() throws IOException { "source=%s | where balance > 39000 | sort balance | trendline wma(2, balance) as" + " age | fields age", TEST_INDEX_BANK)); - System.out.println("Result (Overwrite) : " + result.toString()); verifyDataRows( result, rows(new Object[] {null}), rows(40101.666666666664), rows(45570.666666666664)); } @@ -138,7 +136,6 @@ public void testTrendlineWithSortWma() throws IOException { "source=%s | where balance > 39000 | trendline sort balance wma(2, balance) |" + " fields balance_trendline", TEST_INDEX_BANK)); - System.out.println("Result (WithSortWma): " + result.toString()); verifyDataRows( result, rows(new Object[] {null}), rows(40101.666666666664), rows(45570.666666666664)); } From 4d95529bbafa16ae64692e8136b9eb2ac0688007 Mon Sep 17 00:00:00 2001 From: Andy Kwok Date: Thu, 6 Feb 2025 15:33:09 -0800 Subject: [PATCH 12/36] Address code comments Signed-off-by: Andy Kwok --- .../planner/physical/TrendlineOperator.java | 42 ++++++++++++------- 1 file changed, 27 insertions(+), 15 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java index 43e63c3aac..4639462faf 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java @@ -12,11 +12,13 @@ import java.time.Instant; import java.time.LocalTime; import java.util.ArrayList; +import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Queue; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.ToString; @@ -116,10 +118,20 @@ private static TrendlineAccumulator createAccumulator( } /** Maintains stateful information for calculating the trendline. */ - private interface TrendlineAccumulator { - void accumulate(ExprValue value); + private abstract static class TrendlineAccumulator> { - ExprValue calculate(); + protected final LiteralExpression dataPointsNeeded; + + protected final C receivedValues; + + private TrendlineAccumulator(LiteralExpression dataPointsNeeded, C receivedValues) { + this.dataPointsNeeded = dataPointsNeeded; + this.receivedValues = receivedValues; + } + + abstract void accumulate(ExprValue value); + + abstract ExprValue calculate(); static ArithmeticEvaluator getEvaluator(ExprCoreType type) { switch (type) { @@ -137,16 +149,16 @@ static ArithmeticEvaluator getEvaluator(ExprCoreType type) { } } - private static class SimpleMovingAverageAccumulator implements TrendlineAccumulator { - private final LiteralExpression dataPointsNeeded; - private final EvictingQueue receivedValues; + private static class SimpleMovingAverageAccumulator + extends TrendlineAccumulator> { private final ArithmeticEvaluator evaluator; private Expression runningTotal = null; public SimpleMovingAverageAccumulator( Trendline.TrendlineComputation computation, ExprCoreType type) { - dataPointsNeeded = DSL.literal(computation.getNumberOfDataPoints().doubleValue()); - receivedValues = EvictingQueue.create(computation.getNumberOfDataPoints()); + super( + DSL.literal(computation.getNumberOfDataPoints().doubleValue()), + EvictingQueue.create(computation.getNumberOfDataPoints())); evaluator = TrendlineAccumulator.getEvaluator(type); } @@ -191,19 +203,19 @@ public ExprValue calculate() { } } - private static class WeightedMovingAverageAccumulator implements TrendlineAccumulator { - private final LiteralExpression dataPointsNeeded; - private final ArrayList receivedValues; + private static class WeightedMovingAverageAccumulator + extends TrendlineAccumulator> { private final WmaTrendlineEvaluator evaluator; public WeightedMovingAverageAccumulator( Trendline.TrendlineComputation computation, ExprCoreType type) { - this.dataPointsNeeded = DSL.literal(computation.getNumberOfDataPoints().doubleValue()); - this.receivedValues = new ArrayList<>(computation.getNumberOfDataPoints()); - this.evaluator = getEvaluator(type); + super( + DSL.literal(computation.getNumberOfDataPoints()), + new ArrayList<>(computation.getNumberOfDataPoints())); + this.evaluator = getWmaEvaluator(type); } - static WmaTrendlineEvaluator getEvaluator(ExprCoreType type) { + static WmaTrendlineEvaluator getWmaEvaluator(ExprCoreType type) { return switch (type) { case DOUBLE -> NumericWmaEvaluator.INSTANCE; case DATE, TIMESTAMP -> TimeStampWmaEvaluator.INSTANCE; From df3c6661cff629acf507d310d605017004641872 Mon Sep 17 00:00:00 2001 From: Andy Kwok Date: Fri, 7 Feb 2025 16:07:55 -0800 Subject: [PATCH 13/36] Add support to all numeric types Signed-off-by: Andy Kwok --- .../planner/physical/TrendlineOperator.java | 4 +- .../physical/TrendlineOperatorTest.java | 241 ++++++++++++++++++ 2 files changed, 243 insertions(+), 2 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java index 4639462faf..7ca7a8ba80 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java @@ -135,7 +135,7 @@ private TrendlineAccumulator(LiteralExpression dataPointsNeeded, C receivedValue static ArithmeticEvaluator getEvaluator(ExprCoreType type) { switch (type) { - case DOUBLE: + case INTEGER, SHORT, LONG, FLOAT, DOUBLE: return NumericArithmeticEvaluator.INSTANCE; case DATE: return DateArithmeticEvaluator.INSTANCE; @@ -217,7 +217,7 @@ public WeightedMovingAverageAccumulator( static WmaTrendlineEvaluator getWmaEvaluator(ExprCoreType type) { return switch (type) { - case DOUBLE -> NumericWmaEvaluator.INSTANCE; + case INTEGER, SHORT, LONG, FLOAT, DOUBLE -> NumericWmaEvaluator.INSTANCE; case DATE, TIMESTAMP -> TimeStampWmaEvaluator.INSTANCE; case TIME -> TimeWmaEvaluator.INSTANCE; default -> throw new IllegalArgumentException( diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java index dac2d95d15..fc4302ae9a 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java @@ -120,6 +120,109 @@ public void calculates_simple_moving_average_one_field_two_samples_three_rows() assertFalse(plan.hasNext()); } + @Test + public void calculates_simple_moving_average_data_type_support_short() { + when(inputPlan.hasNext()).thenReturn(true, true, true, false); + when(inputPlan.next()) + .thenReturn( + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10))); + + var plan = + new TrendlineOperator( + inputPlan, + Collections.singletonList( + Pair.of( + AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", SMA), + ExprCoreType.SHORT))); + + plan.open(); + assertTrue(plan.hasNext()); + assertEquals( + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), plan.next()); + assertTrue(plan.hasNext()); + assertEquals( + ExprValueUtils.tupleValue( + ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 150.0)), + plan.next()); + assertTrue(plan.hasNext()); + assertEquals( + ExprValueUtils.tupleValue( + ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 200.0)), + plan.next()); + assertFalse(plan.hasNext()); + } + + @Test + public void calculates_simple_moving_average_data_type_support_long() { + when(inputPlan.hasNext()).thenReturn(true, true, true, false); + when(inputPlan.next()) + .thenReturn( + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10))); + + var plan = + new TrendlineOperator( + inputPlan, + Collections.singletonList( + Pair.of( + AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", SMA), + ExprCoreType.SHORT))); + + plan.open(); + assertTrue(plan.hasNext()); + assertEquals( + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), plan.next()); + assertTrue(plan.hasNext()); + assertEquals( + ExprValueUtils.tupleValue( + ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 150.0)), + plan.next()); + assertTrue(plan.hasNext()); + assertEquals( + ExprValueUtils.tupleValue( + ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 200.0)), + plan.next()); + assertFalse(plan.hasNext()); + } + + @Test + public void calculates_simple_moving_average_data_type_support_float() { + when(inputPlan.hasNext()).thenReturn(true, true, true, false); + when(inputPlan.next()) + .thenReturn( + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10))); + + var plan = + new TrendlineOperator( + inputPlan, + Collections.singletonList( + Pair.of( + AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", SMA), + ExprCoreType.FLOAT))); + + plan.open(); + assertTrue(plan.hasNext()); + assertEquals( + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), plan.next()); + assertTrue(plan.hasNext()); + assertEquals( + ExprValueUtils.tupleValue( + ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 150.0)), + plan.next()); + assertTrue(plan.hasNext()); + assertEquals( + ExprValueUtils.tupleValue( + ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 200.0)), + plan.next()); + assertFalse(plan.hasNext()); + } + + @Test public void calculates_simple_moving_average_multiple_computations() { when(inputPlan.hasNext()).thenReturn(true, true, true, false); @@ -481,6 +584,144 @@ public void calculates_weighted_moving_average_one_field_two_samples_three_rows( assertFalse(plan.hasNext()); } + @Test + public void calculates_weighted_moving_average_data_type_support_short() { + when(inputPlan.hasNext()).thenReturn(true, true, true, false); + when(inputPlan.next()) + .thenReturn( + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10))); + + var plan = + new TrendlineOperator( + inputPlan, + Collections.singletonList( + Pair.of( + AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", WMA), + ExprCoreType.SHORT))); + + plan.open(); + assertTrue(plan.hasNext()); + assertEquals( + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), plan.next()); + assertTrue(plan.hasNext()); + assertEquals( + ExprValueUtils.tupleValue( + ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 166.66666666666663)), + plan.next()); + assertTrue(plan.hasNext()); + assertEquals( + ExprValueUtils.tupleValue( + ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 199.99999999999997)), + plan.next()); + assertFalse(plan.hasNext()); + } + + @Test + public void calculates_weighted_moving_average_data_type_support_integer() { + when(inputPlan.hasNext()).thenReturn(true, true, true, false); + when(inputPlan.next()) + .thenReturn( + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10))); + + var plan = + new TrendlineOperator( + inputPlan, + Collections.singletonList( + Pair.of( + AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", WMA), + ExprCoreType.INTEGER))); + + plan.open(); + assertTrue(plan.hasNext()); + assertEquals( + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), plan.next()); + assertTrue(plan.hasNext()); + assertEquals( + ExprValueUtils.tupleValue( + ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 166.66666666666663)), + plan.next()); + assertTrue(plan.hasNext()); + assertEquals( + ExprValueUtils.tupleValue( + ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 199.99999999999997)), + plan.next()); + assertFalse(plan.hasNext()); + } + + + @Test + public void calculates_weighted_moving_average_data_type_support_long() { + when(inputPlan.hasNext()).thenReturn(true, true, true, false); + when(inputPlan.next()) + .thenReturn( + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10))); + + var plan = + new TrendlineOperator( + inputPlan, + Collections.singletonList( + Pair.of( + AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", WMA), + ExprCoreType.LONG))); + + plan.open(); + assertTrue(plan.hasNext()); + assertEquals( + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), plan.next()); + assertTrue(plan.hasNext()); + assertEquals( + ExprValueUtils.tupleValue( + ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 166.66666666666663)), + plan.next()); + assertTrue(plan.hasNext()); + assertEquals( + ExprValueUtils.tupleValue( + ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 199.99999999999997)), + plan.next()); + assertFalse(plan.hasNext()); + } + + + @Test + public void calculates_weighted_moving_average_data_type_support_float() { + when(inputPlan.hasNext()).thenReturn(true, true, true, false); + when(inputPlan.next()) + .thenReturn( + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10))); + + var plan = + new TrendlineOperator( + inputPlan, + Collections.singletonList( + Pair.of( + AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", WMA), + ExprCoreType.FLOAT))); + + plan.open(); + assertTrue(plan.hasNext()); + assertEquals( + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), plan.next()); + assertTrue(plan.hasNext()); + assertEquals( + ExprValueUtils.tupleValue( + ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 166.66666666666663)), + plan.next()); + assertTrue(plan.hasNext()); + assertEquals( + ExprValueUtils.tupleValue( + ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 199.99999999999997)), + plan.next()); + assertFalse(plan.hasNext()); + } + @Test public void calculates_weighted_moving_average_multiple_computations() { when(inputPlan.hasNext()).thenReturn(true, true, true, false); From 3ed8fa88c109ef73616c69e7825ca1ca59b9fff2 Mon Sep 17 00:00:00 2001 From: Andy Kwok Date: Fri, 7 Feb 2025 16:49:32 -0800 Subject: [PATCH 14/36] Update generic Signed-off-by: Andy Kwok --- .../planner/physical/TrendlineOperator.java | 343 +++++++++--------- .../physical/TrendlineOperatorTest.java | 241 ++++++------ 2 files changed, 289 insertions(+), 295 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java index 7ca7a8ba80..2883f20673 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java @@ -11,11 +11,10 @@ import com.google.common.collect.ImmutableMap.Builder; import java.time.Instant; import java.time.LocalTime; -import java.util.ArrayList; -import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; +import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.Queue; @@ -118,13 +117,14 @@ private static TrendlineAccumulator createAccumulator( } /** Maintains stateful information for calculating the trendline. */ - private abstract static class TrendlineAccumulator> { + private abstract static class TrendlineAccumulator { protected final LiteralExpression dataPointsNeeded; - protected final C receivedValues; + protected final Queue receivedValues; - private TrendlineAccumulator(LiteralExpression dataPointsNeeded, C receivedValues) { + private TrendlineAccumulator( + LiteralExpression dataPointsNeeded, Queue receivedValues) { this.dataPointsNeeded = dataPointsNeeded; this.receivedValues = receivedValues; } @@ -132,25 +132,9 @@ private TrendlineAccumulator(LiteralExpression dataPointsNeeded, C receivedValue abstract void accumulate(ExprValue value); abstract ExprValue calculate(); - - static ArithmeticEvaluator getEvaluator(ExprCoreType type) { - switch (type) { - case INTEGER, SHORT, LONG, FLOAT, DOUBLE: - return NumericArithmeticEvaluator.INSTANCE; - case DATE: - return DateArithmeticEvaluator.INSTANCE; - case TIME: - return TimeArithmeticEvaluator.INSTANCE; - case TIMESTAMP: - return TimestampArithmeticEvaluator.INSTANCE; - } - throw new IllegalArgumentException( - String.format("Invalid type %s used for moving average.", type.typeName())); - } } - private static class SimpleMovingAverageAccumulator - extends TrendlineAccumulator> { + private static class SimpleMovingAverageAccumulator extends TrendlineAccumulator { private final ArithmeticEvaluator evaluator; private Expression runningTotal = null; @@ -159,7 +143,7 @@ public SimpleMovingAverageAccumulator( super( DSL.literal(computation.getNumberOfDataPoints().doubleValue()), EvictingQueue.create(computation.getNumberOfDataPoints())); - evaluator = TrendlineAccumulator.getEvaluator(type); + evaluator = getEvaluator(type); } @Override @@ -201,17 +185,157 @@ public ExprValue calculate() { } return evaluator.evaluate(runningTotal, dataPointsNeeded); } + + static ArithmeticEvaluator getEvaluator(ExprCoreType type) { + return switch (type) { + case INTEGER, SHORT, LONG, FLOAT, DOUBLE -> NumericArithmeticEvaluator.INSTANCE; + case DATE -> DateArithmeticEvaluator.INSTANCE; + case TIME -> TimeArithmeticEvaluator.INSTANCE; + case TIMESTAMP -> TimestampArithmeticEvaluator.INSTANCE; + default -> throw new IllegalArgumentException( + String.format("Invalid type %s used for moving average.", type.typeName())); + }; + } + + private interface ArithmeticEvaluator { + Expression calculateFirstTotal(List dataPoints); + + Expression add(Expression runningTotal, ExprValue incomingValue, ExprValue evictedValue); + + ExprValue evaluate(Expression runningTotal, LiteralExpression numberOfDataPoints); + } + + private static class NumericArithmeticEvaluator implements ArithmeticEvaluator { + private static final NumericArithmeticEvaluator INSTANCE = new NumericArithmeticEvaluator(); + + private NumericArithmeticEvaluator() {} + + @Override + public Expression calculateFirstTotal(List dataPoints) { + Expression total = DSL.literal(0.0D); + for (ExprValue dataPoint : dataPoints) { + total = DSL.add(total, DSL.literal(dataPoint.doubleValue())); + } + return DSL.literal(total.valueOf().doubleValue()); + } + + @Override + public Expression add( + Expression runningTotal, ExprValue incomingValue, ExprValue evictedValue) { + return DSL.literal( + DSL.add( + runningTotal, + DSL.subtract(DSL.literal(incomingValue), DSL.literal(evictedValue))) + .valueOf() + .doubleValue()); + } + + @Override + public ExprValue evaluate(Expression runningTotal, LiteralExpression numberOfDataPoints) { + return DSL.divide(runningTotal, numberOfDataPoints).valueOf(); + } + } + + private static class DateArithmeticEvaluator implements ArithmeticEvaluator { + private static final DateArithmeticEvaluator INSTANCE = new DateArithmeticEvaluator(); + + private DateArithmeticEvaluator() {} + + @Override + public Expression calculateFirstTotal(List dataPoints) { + return TimestampArithmeticEvaluator.INSTANCE.calculateFirstTotal(dataPoints); + } + + @Override + public Expression add( + Expression runningTotal, ExprValue incomingValue, ExprValue evictedValue) { + return TimestampArithmeticEvaluator.INSTANCE.add(runningTotal, incomingValue, evictedValue); + } + + @Override + public ExprValue evaluate(Expression runningTotal, LiteralExpression numberOfDataPoints) { + final ExprValue timestampResult = + TimestampArithmeticEvaluator.INSTANCE.evaluate(runningTotal, numberOfDataPoints); + return ExprValueUtils.dateValue(timestampResult.dateValue()); + } + } + + private static class TimeArithmeticEvaluator implements ArithmeticEvaluator { + private static final TimeArithmeticEvaluator INSTANCE = new TimeArithmeticEvaluator(); + + private TimeArithmeticEvaluator() {} + + @Override + public Expression calculateFirstTotal(List dataPoints) { + Expression total = DSL.literal(0); + for (ExprValue dataPoint : dataPoints) { + total = DSL.add(total, DSL.literal(MILLIS.between(LocalTime.MIN, dataPoint.timeValue()))); + } + return DSL.literal(total.valueOf().longValue()); + } + + @Override + public Expression add( + Expression runningTotal, ExprValue incomingValue, ExprValue evictedValue) { + return DSL.literal( + DSL.add( + runningTotal, + DSL.subtract( + DSL.literal(MILLIS.between(LocalTime.MIN, incomingValue.timeValue())), + DSL.literal(MILLIS.between(LocalTime.MIN, evictedValue.timeValue())))) + .valueOf()); + } + + @Override + public ExprValue evaluate(Expression runningTotal, LiteralExpression numberOfDataPoints) { + return ExprValueUtils.timeValue( + LocalTime.MIN.plus( + DSL.divide(runningTotal, numberOfDataPoints).valueOf().longValue(), MILLIS)); + } + } + + private static class TimestampArithmeticEvaluator implements ArithmeticEvaluator { + private static final TimestampArithmeticEvaluator INSTANCE = + new TimestampArithmeticEvaluator(); + + private TimestampArithmeticEvaluator() {} + + @Override + public Expression calculateFirstTotal(List dataPoints) { + Expression total = DSL.literal(0); + for (ExprValue dataPoint : dataPoints) { + total = DSL.add(total, DSL.literal(dataPoint.timestampValue().toEpochMilli())); + } + return DSL.literal(total.valueOf().longValue()); + } + + @Override + public Expression add( + Expression runningTotal, ExprValue incomingValue, ExprValue evictedValue) { + return DSL.literal( + DSL.add( + runningTotal, + DSL.subtract( + DSL.literal(incomingValue.timestampValue().toEpochMilli()), + DSL.literal(evictedValue.timestampValue().toEpochMilli()))) + .valueOf()); + } + + @Override + public ExprValue evaluate(Expression runningTotal, LiteralExpression numberOfDataPoints) { + return ExprValueUtils.timestampValue( + Instant.ofEpochMilli( + DSL.divide(runningTotal, numberOfDataPoints).valueOf().longValue())); + } + } } - private static class WeightedMovingAverageAccumulator - extends TrendlineAccumulator> { + private static class WeightedMovingAverageAccumulator extends TrendlineAccumulator { private final WmaTrendlineEvaluator evaluator; public WeightedMovingAverageAccumulator( Trendline.TrendlineComputation computation, ExprCoreType type) { - super( - DSL.literal(computation.getNumberOfDataPoints()), - new ArrayList<>(computation.getNumberOfDataPoints())); + super(DSL.literal(computation.getNumberOfDataPoints()), new LinkedList<>()); this.evaluator = getWmaEvaluator(type); } @@ -229,7 +353,7 @@ static WmaTrendlineEvaluator getWmaEvaluator(ExprCoreType type) { public void accumulate(ExprValue value) { receivedValues.add(value); if (receivedValues.size() > dataPointsNeeded.valueOf().integerValue()) { - receivedValues.removeFirst(); + receivedValues.remove(); } } @@ -238,7 +362,7 @@ public ExprValue calculate() { if (receivedValues.size() < dataPointsNeeded.valueOf().integerValue()) { return null; } else if (dataPointsNeeded.valueOf().integerValue() == 1) { - return receivedValues.getFirst(); + return receivedValues.peek(); } return evaluator.evaluate(receivedValues); } @@ -248,11 +372,13 @@ private static class NumericWmaEvaluator implements WmaTrendlineEvaluator { private static final NumericWmaEvaluator INSTANCE = new NumericWmaEvaluator(); @Override - public ExprValue evaluate(ArrayList receivedValues) { + public ExprValue evaluate(Queue receivedValues) { double sum = 0D; int totalWeight = (receivedValues.size() * (receivedValues.size() + 1)) / 2; - for (int i = 0; i < receivedValues.size(); i++) { - sum += receivedValues.get(i).doubleValue() * ((i + 1D) / totalWeight); + int count = 0; + for (ExprValue next : receivedValues) { + sum += next.doubleValue() * ((count + 1D) / totalWeight); + count++; } return new ExprDoubleValue(sum); } @@ -263,16 +389,14 @@ private static class TimeStampWmaEvaluator implements WmaTrendlineEvaluator { private static final TimeStampWmaEvaluator INSTANCE = new TimeStampWmaEvaluator(); @Override - public ExprValue evaluate(ArrayList receivedValues) { + public ExprValue evaluate(Queue receivedValues) { long sum = 0L; int totalWeight = (receivedValues.size() * (receivedValues.size() + 1)) / 2; - for (int i = 0; i < receivedValues.size(); i++) { - sum += - (long) - (receivedValues.get(i).timestampValue().toEpochMilli() - * ((i + 1D) / totalWeight)); + int count = 0; + for (ExprValue next : receivedValues) { + sum += (long) (next.timestampValue().toEpochMilli() * ((count + 1D) / totalWeight)); + count++; } - return ExprValueUtils.timestampValue(Instant.ofEpochMilli((sum))); } } @@ -282,149 +406,22 @@ private static class TimeWmaEvaluator implements WmaTrendlineEvaluator { private static final TimeWmaEvaluator INSTANCE = new TimeWmaEvaluator(); @Override - public ExprValue evaluate(ArrayList receivedValues) { + public ExprValue evaluate(Queue receivedValues) { long sum = 0L; int totalWeight = (receivedValues.size() * (receivedValues.size() + 1)) / 2; - for (int i = 0; i < receivedValues.size(); i++) { + int count = 0; + for (ExprValue next : receivedValues) { sum += (long) - (MILLIS.between(LocalTime.MIN, receivedValues.get(i).timeValue()) - * ((i + 1D) / totalWeight)); + (MILLIS.between(LocalTime.MIN, next.timeValue()) * ((count + 1D) / totalWeight)); + count++; } return ExprValueUtils.timeValue(LocalTime.MIN.plus(sum, MILLIS)); } } private interface WmaTrendlineEvaluator { - ExprValue evaluate(ArrayList receivedValues); - } - } - - private interface ArithmeticEvaluator { - Expression calculateFirstTotal(List dataPoints); - - Expression add(Expression runningTotal, ExprValue incomingValue, ExprValue evictedValue); - - ExprValue evaluate(Expression runningTotal, LiteralExpression numberOfDataPoints); - } - - private static class NumericArithmeticEvaluator implements ArithmeticEvaluator { - private static final NumericArithmeticEvaluator INSTANCE = new NumericArithmeticEvaluator(); - - private NumericArithmeticEvaluator() {} - - @Override - public Expression calculateFirstTotal(List dataPoints) { - Expression total = DSL.literal(0.0D); - for (ExprValue dataPoint : dataPoints) { - total = DSL.add(total, DSL.literal(dataPoint.doubleValue())); - } - return DSL.literal(total.valueOf().doubleValue()); - } - - @Override - public Expression add( - Expression runningTotal, ExprValue incomingValue, ExprValue evictedValue) { - return DSL.literal( - DSL.add(runningTotal, DSL.subtract(DSL.literal(incomingValue), DSL.literal(evictedValue))) - .valueOf() - .doubleValue()); - } - - @Override - public ExprValue evaluate(Expression runningTotal, LiteralExpression numberOfDataPoints) { - return DSL.divide(runningTotal, numberOfDataPoints).valueOf(); - } - } - - private static class DateArithmeticEvaluator implements ArithmeticEvaluator { - private static final DateArithmeticEvaluator INSTANCE = new DateArithmeticEvaluator(); - - private DateArithmeticEvaluator() {} - - @Override - public Expression calculateFirstTotal(List dataPoints) { - return TimestampArithmeticEvaluator.INSTANCE.calculateFirstTotal(dataPoints); - } - - @Override - public Expression add( - Expression runningTotal, ExprValue incomingValue, ExprValue evictedValue) { - return TimestampArithmeticEvaluator.INSTANCE.add(runningTotal, incomingValue, evictedValue); - } - - @Override - public ExprValue evaluate(Expression runningTotal, LiteralExpression numberOfDataPoints) { - final ExprValue timestampResult = - TimestampArithmeticEvaluator.INSTANCE.evaluate(runningTotal, numberOfDataPoints); - return ExprValueUtils.dateValue(timestampResult.dateValue()); - } - } - - private static class TimeArithmeticEvaluator implements ArithmeticEvaluator { - private static final TimeArithmeticEvaluator INSTANCE = new TimeArithmeticEvaluator(); - - private TimeArithmeticEvaluator() {} - - @Override - public Expression calculateFirstTotal(List dataPoints) { - Expression total = DSL.literal(0); - for (ExprValue dataPoint : dataPoints) { - total = DSL.add(total, DSL.literal(MILLIS.between(LocalTime.MIN, dataPoint.timeValue()))); - } - return DSL.literal(total.valueOf().longValue()); - } - - @Override - public Expression add( - Expression runningTotal, ExprValue incomingValue, ExprValue evictedValue) { - return DSL.literal( - DSL.add( - runningTotal, - DSL.subtract( - DSL.literal(MILLIS.between(LocalTime.MIN, incomingValue.timeValue())), - DSL.literal(MILLIS.between(LocalTime.MIN, evictedValue.timeValue())))) - .valueOf()); - } - - @Override - public ExprValue evaluate(Expression runningTotal, LiteralExpression numberOfDataPoints) { - return ExprValueUtils.timeValue( - LocalTime.MIN.plus( - DSL.divide(runningTotal, numberOfDataPoints).valueOf().longValue(), MILLIS)); - } - } - - private static class TimestampArithmeticEvaluator implements ArithmeticEvaluator { - private static final TimestampArithmeticEvaluator INSTANCE = new TimestampArithmeticEvaluator(); - - private TimestampArithmeticEvaluator() {} - - @Override - public Expression calculateFirstTotal(List dataPoints) { - Expression total = DSL.literal(0); - for (ExprValue dataPoint : dataPoints) { - total = DSL.add(total, DSL.literal(dataPoint.timestampValue().toEpochMilli())); - } - return DSL.literal(total.valueOf().longValue()); - } - - @Override - public Expression add( - Expression runningTotal, ExprValue incomingValue, ExprValue evictedValue) { - return DSL.literal( - DSL.add( - runningTotal, - DSL.subtract( - DSL.literal(incomingValue.timestampValue().toEpochMilli()), - DSL.literal(evictedValue.timestampValue().toEpochMilli()))) - .valueOf()); - } - - @Override - public ExprValue evaluate(Expression runningTotal, LiteralExpression numberOfDataPoints) { - return ExprValueUtils.timestampValue( - Instant.ofEpochMilli(DSL.divide(runningTotal, numberOfDataPoints).valueOf().longValue())); + ExprValue evaluate(Queue receivedValues); } } } diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java index fc4302ae9a..7e1c0087e3 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java @@ -124,33 +124,33 @@ public void calculates_simple_moving_average_one_field_two_samples_three_rows() public void calculates_simple_moving_average_data_type_support_short() { when(inputPlan.hasNext()).thenReturn(true, true, true, false); when(inputPlan.next()) - .thenReturn( - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)), - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10))); + .thenReturn( + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10))); var plan = - new TrendlineOperator( - inputPlan, - Collections.singletonList( - Pair.of( - AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", SMA), - ExprCoreType.SHORT))); + new TrendlineOperator( + inputPlan, + Collections.singletonList( + Pair.of( + AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", SMA), + ExprCoreType.SHORT))); plan.open(); assertTrue(plan.hasNext()); assertEquals( - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), plan.next()); + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), plan.next()); assertTrue(plan.hasNext()); assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 150.0)), - plan.next()); + ExprValueUtils.tupleValue( + ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 150.0)), + plan.next()); assertTrue(plan.hasNext()); assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 200.0)), - plan.next()); + ExprValueUtils.tupleValue( + ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 200.0)), + plan.next()); assertFalse(plan.hasNext()); } @@ -158,33 +158,33 @@ public void calculates_simple_moving_average_data_type_support_short() { public void calculates_simple_moving_average_data_type_support_long() { when(inputPlan.hasNext()).thenReturn(true, true, true, false); when(inputPlan.next()) - .thenReturn( - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)), - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10))); + .thenReturn( + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10))); var plan = - new TrendlineOperator( - inputPlan, - Collections.singletonList( - Pair.of( - AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", SMA), - ExprCoreType.SHORT))); + new TrendlineOperator( + inputPlan, + Collections.singletonList( + Pair.of( + AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", SMA), + ExprCoreType.SHORT))); plan.open(); assertTrue(plan.hasNext()); assertEquals( - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), plan.next()); + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), plan.next()); assertTrue(plan.hasNext()); assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 150.0)), - plan.next()); + ExprValueUtils.tupleValue( + ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 150.0)), + plan.next()); assertTrue(plan.hasNext()); assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 200.0)), - plan.next()); + ExprValueUtils.tupleValue( + ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 200.0)), + plan.next()); assertFalse(plan.hasNext()); } @@ -192,37 +192,36 @@ public void calculates_simple_moving_average_data_type_support_long() { public void calculates_simple_moving_average_data_type_support_float() { when(inputPlan.hasNext()).thenReturn(true, true, true, false); when(inputPlan.next()) - .thenReturn( - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)), - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10))); + .thenReturn( + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10))); var plan = - new TrendlineOperator( - inputPlan, - Collections.singletonList( - Pair.of( - AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", SMA), - ExprCoreType.FLOAT))); + new TrendlineOperator( + inputPlan, + Collections.singletonList( + Pair.of( + AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", SMA), + ExprCoreType.FLOAT))); plan.open(); assertTrue(plan.hasNext()); assertEquals( - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), plan.next()); + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), plan.next()); assertTrue(plan.hasNext()); assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 150.0)), - plan.next()); + ExprValueUtils.tupleValue( + ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 150.0)), + plan.next()); assertTrue(plan.hasNext()); assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 200.0)), - plan.next()); + ExprValueUtils.tupleValue( + ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 200.0)), + plan.next()); assertFalse(plan.hasNext()); } - @Test public void calculates_simple_moving_average_multiple_computations() { when(inputPlan.hasNext()).thenReturn(true, true, true, false); @@ -588,33 +587,33 @@ public void calculates_weighted_moving_average_one_field_two_samples_three_rows( public void calculates_weighted_moving_average_data_type_support_short() { when(inputPlan.hasNext()).thenReturn(true, true, true, false); when(inputPlan.next()) - .thenReturn( - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)), - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10))); + .thenReturn( + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10))); var plan = - new TrendlineOperator( - inputPlan, - Collections.singletonList( - Pair.of( - AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", WMA), - ExprCoreType.SHORT))); + new TrendlineOperator( + inputPlan, + Collections.singletonList( + Pair.of( + AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", WMA), + ExprCoreType.SHORT))); plan.open(); assertTrue(plan.hasNext()); assertEquals( - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), plan.next()); + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), plan.next()); assertTrue(plan.hasNext()); assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 166.66666666666663)), - plan.next()); + ExprValueUtils.tupleValue( + ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 166.66666666666663)), + plan.next()); assertTrue(plan.hasNext()); assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 199.99999999999997)), - plan.next()); + ExprValueUtils.tupleValue( + ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 199.99999999999997)), + plan.next()); assertFalse(plan.hasNext()); } @@ -622,103 +621,101 @@ public void calculates_weighted_moving_average_data_type_support_short() { public void calculates_weighted_moving_average_data_type_support_integer() { when(inputPlan.hasNext()).thenReturn(true, true, true, false); when(inputPlan.next()) - .thenReturn( - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)), - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10))); + .thenReturn( + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10))); var plan = - new TrendlineOperator( - inputPlan, - Collections.singletonList( - Pair.of( - AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", WMA), - ExprCoreType.INTEGER))); + new TrendlineOperator( + inputPlan, + Collections.singletonList( + Pair.of( + AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", WMA), + ExprCoreType.INTEGER))); plan.open(); assertTrue(plan.hasNext()); assertEquals( - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), plan.next()); + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), plan.next()); assertTrue(plan.hasNext()); assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 166.66666666666663)), - plan.next()); + ExprValueUtils.tupleValue( + ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 166.66666666666663)), + plan.next()); assertTrue(plan.hasNext()); assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 199.99999999999997)), - plan.next()); + ExprValueUtils.tupleValue( + ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 199.99999999999997)), + plan.next()); assertFalse(plan.hasNext()); } - @Test public void calculates_weighted_moving_average_data_type_support_long() { when(inputPlan.hasNext()).thenReturn(true, true, true, false); when(inputPlan.next()) - .thenReturn( - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)), - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10))); + .thenReturn( + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10))); var plan = - new TrendlineOperator( - inputPlan, - Collections.singletonList( - Pair.of( - AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", WMA), - ExprCoreType.LONG))); + new TrendlineOperator( + inputPlan, + Collections.singletonList( + Pair.of( + AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", WMA), + ExprCoreType.LONG))); plan.open(); assertTrue(plan.hasNext()); assertEquals( - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), plan.next()); + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), plan.next()); assertTrue(plan.hasNext()); assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 166.66666666666663)), - plan.next()); + ExprValueUtils.tupleValue( + ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 166.66666666666663)), + plan.next()); assertTrue(plan.hasNext()); assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 199.99999999999997)), - plan.next()); + ExprValueUtils.tupleValue( + ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 199.99999999999997)), + plan.next()); assertFalse(plan.hasNext()); } - @Test public void calculates_weighted_moving_average_data_type_support_float() { when(inputPlan.hasNext()).thenReturn(true, true, true, false); when(inputPlan.next()) - .thenReturn( - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)), - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10))); + .thenReturn( + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10))); var plan = - new TrendlineOperator( - inputPlan, - Collections.singletonList( - Pair.of( - AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", WMA), - ExprCoreType.FLOAT))); + new TrendlineOperator( + inputPlan, + Collections.singletonList( + Pair.of( + AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", WMA), + ExprCoreType.FLOAT))); plan.open(); assertTrue(plan.hasNext()); assertEquals( - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), plan.next()); + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), plan.next()); assertTrue(plan.hasNext()); assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 166.66666666666663)), - plan.next()); + ExprValueUtils.tupleValue( + ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 166.66666666666663)), + plan.next()); assertTrue(plan.hasNext()); assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 199.99999999999997)), - plan.next()); + ExprValueUtils.tupleValue( + ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 199.99999999999997)), + plan.next()); assertFalse(plan.hasNext()); } From 7fc4075cb02beace22a8d635cf890cf4939ebc36 Mon Sep 17 00:00:00 2001 From: Andy Kwok Date: Mon, 10 Feb 2025 10:39:43 -0800 Subject: [PATCH 15/36] Replace evaluator with functionalInterface Signed-off-by: Andy Kwok --- .../planner/physical/TrendlineOperator.java | 96 ++++++++----------- 1 file changed, 39 insertions(+), 57 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java index 2883f20673..aa23bec1e9 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java @@ -18,6 +18,9 @@ import java.util.List; import java.util.Map; import java.util.Queue; +import java.util.function.BiFunction; +import java.util.function.Function; + import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.ToString; @@ -331,21 +334,24 @@ public ExprValue evaluate(Expression runningTotal, LiteralExpression numberOfDat } private static class WeightedMovingAverageAccumulator extends TrendlineAccumulator { - private final WmaTrendlineEvaluator evaluator; + + private final BiFunction, Integer, ExprValue> evaluator; + private final int totalWeight; public WeightedMovingAverageAccumulator( Trendline.TrendlineComputation computation, ExprCoreType type) { super(DSL.literal(computation.getNumberOfDataPoints()), new LinkedList<>()); + this.totalWeight = (computation.getNumberOfDataPoints() * (computation.getNumberOfDataPoints() + 1)) / 2; this.evaluator = getWmaEvaluator(type); } - static WmaTrendlineEvaluator getWmaEvaluator(ExprCoreType type) { + static BiFunction, Integer, ExprValue> getWmaEvaluator(ExprCoreType type) { return switch (type) { - case INTEGER, SHORT, LONG, FLOAT, DOUBLE -> NumericWmaEvaluator.INSTANCE; - case DATE, TIMESTAMP -> TimeStampWmaEvaluator.INSTANCE; - case TIME -> TimeWmaEvaluator.INSTANCE; + case INTEGER, SHORT, LONG, FLOAT, DOUBLE -> WMA_NUMERIC_EVALUATOR; + case DATE, TIMESTAMP -> WMA_TIMESTAMP_EVALUATOR; + case TIME -> WMA_TIME_EVALUATOR; default -> throw new IllegalArgumentException( - String.format("Invalid type %s used for weighted moving average.", type.typeName())); + String.format("Invalid type %s used for weighted moving average.", type.typeName())); }; } @@ -364,64 +370,40 @@ public ExprValue calculate() { } else if (dataPointsNeeded.valueOf().integerValue() == 1) { return receivedValues.peek(); } - return evaluator.evaluate(receivedValues); + return evaluator.apply(receivedValues, totalWeight); } - private static class NumericWmaEvaluator implements WmaTrendlineEvaluator { - - private static final NumericWmaEvaluator INSTANCE = new NumericWmaEvaluator(); - - @Override - public ExprValue evaluate(Queue receivedValues) { - double sum = 0D; - int totalWeight = (receivedValues.size() * (receivedValues.size() + 1)) / 2; - int count = 0; - for (ExprValue next : receivedValues) { - sum += next.doubleValue() * ((count + 1D) / totalWeight); - count++; - } - return new ExprDoubleValue(sum); + public static final BiFunction, Integer, ExprValue> WMA_NUMERIC_EVALUATOR = (receivedValues, totalWeight) -> { + double sum = 0D; + int count = 0; + for (ExprValue next : receivedValues) { + sum += next.doubleValue() * ((count + 1D) / totalWeight); + count++; } - } - - private static class TimeStampWmaEvaluator implements WmaTrendlineEvaluator { - - private static final TimeStampWmaEvaluator INSTANCE = new TimeStampWmaEvaluator(); + return new ExprDoubleValue(sum); + }; - @Override - public ExprValue evaluate(Queue receivedValues) { - long sum = 0L; - int totalWeight = (receivedValues.size() * (receivedValues.size() + 1)) / 2; - int count = 0; - for (ExprValue next : receivedValues) { - sum += (long) (next.timestampValue().toEpochMilli() * ((count + 1D) / totalWeight)); - count++; - } - return ExprValueUtils.timestampValue(Instant.ofEpochMilli((sum))); + public static final BiFunction, Integer, ExprValue> WMA_TIMESTAMP_EVALUATOR = (receivedValues, totalWeight) -> { + long sum = 0L; + int count = 0; + for (ExprValue next : receivedValues) { + sum += (long) (next.timestampValue().toEpochMilli() * ((count + 1D) / totalWeight)); + count++; } - } - - private static class TimeWmaEvaluator implements WmaTrendlineEvaluator { - - private static final TimeWmaEvaluator INSTANCE = new TimeWmaEvaluator(); + return ExprValueUtils.timestampValue(Instant.ofEpochMilli((sum))); + }; - @Override - public ExprValue evaluate(Queue receivedValues) { - long sum = 0L; - int totalWeight = (receivedValues.size() * (receivedValues.size() + 1)) / 2; - int count = 0; - for (ExprValue next : receivedValues) { - sum += - (long) - (MILLIS.between(LocalTime.MIN, next.timeValue()) * ((count + 1D) / totalWeight)); - count++; - } - return ExprValueUtils.timeValue(LocalTime.MIN.plus(sum, MILLIS)); + public static final BiFunction, Integer, ExprValue> WMA_TIME_EVALUATOR = (receivedValues, totalWeight) -> { + long sum = 0L; + int count = 0; + for (ExprValue next : receivedValues) { + sum += + (long) + (MILLIS.between(LocalTime.MIN, next.timeValue()) * ((count + 1D) / totalWeight)); + count++; } - } + return ExprValueUtils.timeValue(LocalTime.MIN.plus(sum, MILLIS)); + }; - private interface WmaTrendlineEvaluator { - ExprValue evaluate(Queue receivedValues); - } } } From 3b7902f873a2fc2bccab5576d5dfb3e6d4bd3ba8 Mon Sep 17 00:00:00 2001 From: Andy Kwok Date: Mon, 10 Feb 2025 10:56:03 -0800 Subject: [PATCH 16/36] Apply suggestions from code review Co-authored-by: Taylor Curran Signed-off-by: Andy Kwok --- docs/user/ppl/cmd/trendline.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/user/ppl/cmd/trendline.rst b/docs/user/ppl/cmd/trendline.rst index e7a54283dd..7e94c8c1d4 100644 --- a/docs/user/ppl/cmd/trendline.rst +++ b/docs/user/ppl/cmd/trendline.rst @@ -13,9 +13,9 @@ Description ============ | Using ``trendline`` command to calculate moving averages of fields. -Syntax - SMA (Simple Moving Average) +Syntax ============ -`TRENDLINE [sort <[+|-] sort-field>] SMA(number-of-datapoints, field) [AS alias] [SMA(number-of-datapoints, field) [AS alias]]...` +`TRENDLINE [sort <[+|-] sort-field>] (number-of-datapoints, field) [AS alias] [(number-of-datapoints, field) [AS alias]]...` * [+|-]: optional. The plus [+] stands for ascending order and NULL/MISSING first and a minus [-] stands for descending order and NULL/MISSING last. **Default:** ascending order and NULL/MISSING first. * sort-field: mandatory when sorting is used. The field used to sort. @@ -71,7 +71,7 @@ PPL query:: Example 3: Calculate the moving average on one field without specifying an alias. ================================================================================= -The example shows how to calculate the moving average on one field. +The example shows how to calculate the moving average on one field without specifying an alias. PPL query:: From 47f9cec3609b80fcddf657de2980ae03169b3c00 Mon Sep 17 00:00:00 2001 From: Andy Kwok Date: Mon, 10 Feb 2025 11:13:12 -0800 Subject: [PATCH 17/36] Update wma doc Signed-off-by: Andy Kwok --- docs/user/ppl/cmd/trendline.rst | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/docs/user/ppl/cmd/trendline.rst b/docs/user/ppl/cmd/trendline.rst index 7e94c8c1d4..7cda521bc6 100644 --- a/docs/user/ppl/cmd/trendline.rst +++ b/docs/user/ppl/cmd/trendline.rst @@ -19,11 +19,12 @@ Syntax * [+|-]: optional. The plus [+] stands for ascending order and NULL/MISSING first and a minus [-] stands for descending order and NULL/MISSING last. **Default:** ascending order and NULL/MISSING first. * sort-field: mandatory when sorting is used. The field used to sort. +* trendline-type: mandatory. The type of algorithm being used for the calculation, only SMA || WMA are supported at the moment. * number-of-datapoints: mandatory. The number of datapoints to calculate the moving average (must be greater than zero). * field: mandatory. The name of the field the moving average should be calculated for. * alias: optional. The name of the resulting column containing the moving average (defaults to the field name with "_trendline"). -It is calculated like +In the case of Simple Moving Average - SMA, result will be calculated as per the below formula. f[i]: The value of field 'f' in the i-th data-point n: The number of data-points in the moving window (period) @@ -86,22 +87,14 @@ PPL query:: | 15.5 | +--------------------------+ -Syntax - WMA (Weighted Moving Average) -============ -`TRENDLINE [sort <[+|-] sort-field>] WMA(number-of-datapoints, field) [AS alias] [WMA(number-of-datapoints, field) [AS alias]]...` -* [+|-]: optional. The plus [+] stands for ascending order and NULL/MISSING first and a minus [-] stands for descending order and NULL/MISSING last. **Default:** ascending order and NULL/MISSING first. -* sort-field: mandatory when sorting is used. The field used to sort. -* number-of-datapoints: mandatory. The number of datapoints to calculate the moving average (must be greater than zero). -* field: mandatory. The name of the field the moving average should be calculated for. -* alias: optional. The name of the resulting column containing the moving average (defaults to the field name with "_trendline"). -It is calculated like +In the case of Weighted Moving Average - WMA, result will be calculated as per the below formula. f[i]: The value of field 'f' in the i-th data point n: The number of data points in the moving window (period) t: The current time index - w[i]: The weight assigned to the i-th data point, typically increasing for more recent points + w[i]: The weight of the i-th data point, increasing by one per step to prioritize recent points. WMA(t) = ( Σ from i=t−n+1 to t of (w[i] * f[i]) ) / ( Σ from i=t−n+1 to t of w[i] ) From b8cf496d87d583b5f02b4b25035a3bb08627b138 Mon Sep 17 00:00:00 2001 From: Andy Kwok Date: Mon, 10 Feb 2025 11:34:38 -0800 Subject: [PATCH 18/36] Fix style Signed-off-by: Andy Kwok --- .../planner/physical/TrendlineOperator.java | 72 ++++++++++--------- 1 file changed, 37 insertions(+), 35 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java index aa23bec1e9..302dd69115 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java @@ -19,8 +19,6 @@ import java.util.Map; import java.util.Queue; import java.util.function.BiFunction; -import java.util.function.Function; - import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.ToString; @@ -341,7 +339,8 @@ private static class WeightedMovingAverageAccumulator extends TrendlineAccumulat public WeightedMovingAverageAccumulator( Trendline.TrendlineComputation computation, ExprCoreType type) { super(DSL.literal(computation.getNumberOfDataPoints()), new LinkedList<>()); - this.totalWeight = (computation.getNumberOfDataPoints() * (computation.getNumberOfDataPoints() + 1)) / 2; + this.totalWeight = + (computation.getNumberOfDataPoints() * (computation.getNumberOfDataPoints() + 1)) / 2; this.evaluator = getWmaEvaluator(type); } @@ -351,7 +350,7 @@ static BiFunction, Integer, ExprValue> getWmaEvaluator(ExprCore case DATE, TIMESTAMP -> WMA_TIMESTAMP_EVALUATOR; case TIME -> WMA_TIME_EVALUATOR; default -> throw new IllegalArgumentException( - String.format("Invalid type %s used for weighted moving average.", type.typeName())); + String.format("Invalid type %s used for weighted moving average.", type.typeName())); }; } @@ -373,37 +372,40 @@ public ExprValue calculate() { return evaluator.apply(receivedValues, totalWeight); } - public static final BiFunction, Integer, ExprValue> WMA_NUMERIC_EVALUATOR = (receivedValues, totalWeight) -> { - double sum = 0D; - int count = 0; - for (ExprValue next : receivedValues) { - sum += next.doubleValue() * ((count + 1D) / totalWeight); - count++; - } - return new ExprDoubleValue(sum); - }; - - public static final BiFunction, Integer, ExprValue> WMA_TIMESTAMP_EVALUATOR = (receivedValues, totalWeight) -> { - long sum = 0L; - int count = 0; - for (ExprValue next : receivedValues) { - sum += (long) (next.timestampValue().toEpochMilli() * ((count + 1D) / totalWeight)); - count++; - } - return ExprValueUtils.timestampValue(Instant.ofEpochMilli((sum))); - }; - - public static final BiFunction, Integer, ExprValue> WMA_TIME_EVALUATOR = (receivedValues, totalWeight) -> { - long sum = 0L; - int count = 0; - for (ExprValue next : receivedValues) { - sum += + public static final BiFunction, Integer, ExprValue> WMA_NUMERIC_EVALUATOR = + (receivedValues, totalWeight) -> { + double sum = 0D; + int count = 0; + for (ExprValue next : receivedValues) { + sum += next.doubleValue() * ((count + 1D) / totalWeight); + count++; + } + return new ExprDoubleValue(sum); + }; + + public static final BiFunction, Integer, ExprValue> WMA_TIMESTAMP_EVALUATOR = + (receivedValues, totalWeight) -> { + long sum = 0L; + int count = 0; + for (ExprValue next : receivedValues) { + sum += (long) (next.timestampValue().toEpochMilli() * ((count + 1D) / totalWeight)); + count++; + } + return ExprValueUtils.timestampValue(Instant.ofEpochMilli((sum))); + }; + + public static final BiFunction, Integer, ExprValue> WMA_TIME_EVALUATOR = + (receivedValues, totalWeight) -> { + long sum = 0L; + int count = 0; + for (ExprValue next : receivedValues) { + sum += (long) - (MILLIS.between(LocalTime.MIN, next.timeValue()) * ((count + 1D) / totalWeight)); - count++; - } - return ExprValueUtils.timeValue(LocalTime.MIN.plus(sum, MILLIS)); - }; - + (MILLIS.between(LocalTime.MIN, next.timeValue()) + * ((count + 1D) / totalWeight)); + count++; + } + return ExprValueUtils.timeValue(LocalTime.MIN.plus(sum, MILLIS)); + }; } } From 4a56b571ebe0a4af2e61d43a06760dfbcedc285b Mon Sep 17 00:00:00 2001 From: Andy Kwok Date: Tue, 11 Feb 2025 10:44:24 -0800 Subject: [PATCH 19/36] Update docs/user/ppl/cmd/trendline.rst Co-authored-by: Taylor Curran Signed-off-by: Andy Kwok --- docs/user/ppl/cmd/trendline.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/user/ppl/cmd/trendline.rst b/docs/user/ppl/cmd/trendline.rst index 7cda521bc6..79c832a86c 100644 --- a/docs/user/ppl/cmd/trendline.rst +++ b/docs/user/ppl/cmd/trendline.rst @@ -19,7 +19,7 @@ Syntax * [+|-]: optional. The plus [+] stands for ascending order and NULL/MISSING first and a minus [-] stands for descending order and NULL/MISSING last. **Default:** ascending order and NULL/MISSING first. * sort-field: mandatory when sorting is used. The field used to sort. -* trendline-type: mandatory. The type of algorithm being used for the calculation, only SMA || WMA are supported at the moment. +* trendline-type: mandatory. The type of algorithm being used for the calculation, only SMA (simple moving average) or WMA (weighted moving average) are supported at the moment. * number-of-datapoints: mandatory. The number of datapoints to calculate the moving average (must be greater than zero). * field: mandatory. The name of the field the moving average should be calculated for. * alias: optional. The name of the resulting column containing the moving average (defaults to the field name with "_trendline"). From 0b50637cfaf9d4ab8b788bb052802ac6970f0d42 Mon Sep 17 00:00:00 2001 From: Andy Kwok Date: Tue, 11 Feb 2025 15:52:45 -0800 Subject: [PATCH 20/36] Fix code commentse Signed-off-by: Andy Kwok --- .../planner/physical/TrendlineOperator.java | 95 ++++++++++--------- .../physical/TrendlineOperatorTest.java | 2 +- 2 files changed, 52 insertions(+), 45 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java index 302dd69115..99002e12dd 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java @@ -6,6 +6,7 @@ package org.opensearch.sql.planner.physical; import static java.time.temporal.ChronoUnit.MILLIS; +import static java.util.stream.Collectors.*; import com.google.common.collect.EvictingQueue; import com.google.common.collect.ImmutableMap.Builder; @@ -14,11 +15,17 @@ import java.util.Collections; import java.util.HashMap; import java.util.HashSet; +import java.util.Iterator; import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.Queue; import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.DoubleStream; +import java.util.stream.IntStream; + import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.ToString; @@ -125,9 +132,9 @@ private abstract static class TrendlineAccumulator { protected final Queue receivedValues; private TrendlineAccumulator( - LiteralExpression dataPointsNeeded, Queue receivedValues) { - this.dataPointsNeeded = dataPointsNeeded; - this.receivedValues = receivedValues; + Trendline.TrendlineComputation config) { + this.dataPointsNeeded = DSL.literal(config.getNumberOfDataPoints().doubleValue()); + this.receivedValues = EvictingQueue.create(config.getNumberOfDataPoints()); } abstract void accumulate(ExprValue value); @@ -141,9 +148,7 @@ private static class SimpleMovingAverageAccumulator extends TrendlineAccumulator public SimpleMovingAverageAccumulator( Trendline.TrendlineComputation computation, ExprCoreType type) { - super( - DSL.literal(computation.getNumberOfDataPoints().doubleValue()), - EvictingQueue.create(computation.getNumberOfDataPoints())); + super(computation); evaluator = getEvaluator(type); } @@ -333,18 +338,20 @@ public ExprValue evaluate(Expression runningTotal, LiteralExpression numberOfDat private static class WeightedMovingAverageAccumulator extends TrendlineAccumulator { - private final BiFunction, Integer, ExprValue> evaluator; - private final int totalWeight; + private final Function, ExprValue> evaluator; + private final List weights; public WeightedMovingAverageAccumulator( Trendline.TrendlineComputation computation, ExprCoreType type) { - super(DSL.literal(computation.getNumberOfDataPoints()), new LinkedList<>()); - this.totalWeight = - (computation.getNumberOfDataPoints() * (computation.getNumberOfDataPoints() + 1)) / 2; + super(computation); + int dataPoints = computation.getNumberOfDataPoints(); this.evaluator = getWmaEvaluator(type); + this.weights = IntStream.rangeClosed(1, dataPoints) + .mapToObj(i -> i / ((dataPoints * (dataPoints + 1)) / 2d)) + .collect(toList()); } - static BiFunction, Integer, ExprValue> getWmaEvaluator(ExprCoreType type) { + Function, ExprValue> getWmaEvaluator(ExprCoreType type) { return switch (type) { case INTEGER, SHORT, LONG, FLOAT, DOUBLE -> WMA_NUMERIC_EVALUATOR; case DATE, TIMESTAMP -> WMA_TIMESTAMP_EVALUATOR; @@ -369,43 +376,43 @@ public ExprValue calculate() { } else if (dataPointsNeeded.valueOf().integerValue() == 1) { return receivedValues.peek(); } - return evaluator.apply(receivedValues, totalWeight); + return evaluator.apply(receivedValues); } - public static final BiFunction, Integer, ExprValue> WMA_NUMERIC_EVALUATOR = - (receivedValues, totalWeight) -> { - double sum = 0D; - int count = 0; - for (ExprValue next : receivedValues) { - sum += next.doubleValue() * ((count + 1D) / totalWeight); - count++; - } - return new ExprDoubleValue(sum); - }; + public final Function, ExprValue> WMA_NUMERIC_EVALUATOR = + (receivedValues) -> + new ExprDoubleValue(calculateWmaInDouble(receivedValues, ExprValue::doubleValue));; + + public final Function, ExprValue> WMA_TIMESTAMP_EVALUATOR = + (receivedValues) -> { + Long wmaResult = Math.round(calculateWmaInDouble(receivedValues, + i -> (double) (i.timestampValue().toEpochMilli()))); + return ExprValueUtils.timestampValue(Instant.ofEpochMilli((wmaResult))); - public static final BiFunction, Integer, ExprValue> WMA_TIMESTAMP_EVALUATOR = - (receivedValues, totalWeight) -> { - long sum = 0L; - int count = 0; - for (ExprValue next : receivedValues) { - sum += (long) (next.timestampValue().toEpochMilli() * ((count + 1D) / totalWeight)); - count++; - } - return ExprValueUtils.timestampValue(Instant.ofEpochMilli((sum))); }; - public static final BiFunction, Integer, ExprValue> WMA_TIME_EVALUATOR = - (receivedValues, totalWeight) -> { - long sum = 0L; - int count = 0; - for (ExprValue next : receivedValues) { - sum += - (long) - (MILLIS.between(LocalTime.MIN, next.timeValue()) - * ((count + 1D) / totalWeight)); - count++; - } - return ExprValueUtils.timeValue(LocalTime.MIN.plus(sum, MILLIS)); + public final Function, ExprValue> WMA_TIME_EVALUATOR = + (receivedValues) -> { + Long wmaResult = Math.round(calculateWmaInDouble(receivedValues, + i -> (double) (MILLIS.between(LocalTime.MIN, i.timeValue())))); + return ExprValueUtils.timeValue(LocalTime.MIN.plus(wmaResult, MILLIS)); }; + + /** + * Responsible to iterate the internal buffer, perform necessary calculation, + * and the up-to-date wma result in Double + * @param receivedValues internal buffer which stores all value in range. + * @param exprToDouble transformation function to convert incoming values to double for calcaution. + * @return wma result in Double form. + */ + private Double calculateWmaInDouble(Queue receivedValues, Function exprToDouble) { + double sum = 0D; + Iterator weightIter = weights.iterator(); + for (ExprValue next : receivedValues) { + sum += exprToDouble.apply(next) * (weightIter.next()); + } + return sum; + } + } } diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java index 7e1c0087e3..16292f40eb 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java @@ -921,7 +921,7 @@ public void calculates_weighted_moving_average_timestamp() { "timestamp", Instant.EPOCH.plusMillis(1000), "timestamp_alias", - Instant.EPOCH.plusMillis(666))), + Instant.EPOCH.plusMillis(667))), plan.next()); assertTrue(plan.hasNext()); assertEquals( From b8e39837de62b524f4e0d7563184c48af8b392f5 Mon Sep 17 00:00:00 2001 From: Andy Kwok Date: Tue, 11 Feb 2025 15:58:25 -0800 Subject: [PATCH 21/36] update default name Signed-off-by: Andy Kwok --- .../planner/physical/TrendlineOperator.java | 52 ++++++++++--------- .../sql/ppl/parser/AstExpressionBuilder.java | 7 ++- 2 files changed, 30 insertions(+), 29 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java index 99002e12dd..82fa9698f2 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java @@ -16,16 +16,11 @@ import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; -import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.Queue; -import java.util.function.BiFunction; import java.util.function.Function; -import java.util.stream.Collectors; -import java.util.stream.DoubleStream; import java.util.stream.IntStream; - import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.ToString; @@ -131,8 +126,7 @@ private abstract static class TrendlineAccumulator { protected final Queue receivedValues; - private TrendlineAccumulator( - Trendline.TrendlineComputation config) { + private TrendlineAccumulator(Trendline.TrendlineComputation config) { this.dataPointsNeeded = DSL.literal(config.getNumberOfDataPoints().doubleValue()); this.receivedValues = EvictingQueue.create(config.getNumberOfDataPoints()); } @@ -346,7 +340,8 @@ public WeightedMovingAverageAccumulator( super(computation); int dataPoints = computation.getNumberOfDataPoints(); this.evaluator = getWmaEvaluator(type); - this.weights = IntStream.rangeClosed(1, dataPoints) + this.weights = + IntStream.rangeClosed(1, dataPoints) .mapToObj(i -> i / ((dataPoints * (dataPoints + 1)) / 2d)) .collect(toList()); } @@ -381,38 +376,45 @@ public ExprValue calculate() { public final Function, ExprValue> WMA_NUMERIC_EVALUATOR = (receivedValues) -> - new ExprDoubleValue(calculateWmaInDouble(receivedValues, ExprValue::doubleValue));; + new ExprDoubleValue(calculateWmaInDouble(receivedValues, ExprValue::doubleValue)); + ; public final Function, ExprValue> WMA_TIMESTAMP_EVALUATOR = (receivedValues) -> { - Long wmaResult = Math.round(calculateWmaInDouble(receivedValues, - i -> (double) (i.timestampValue().toEpochMilli()))); + Long wmaResult = + Math.round( + calculateWmaInDouble( + receivedValues, i -> (double) (i.timestampValue().toEpochMilli()))); return ExprValueUtils.timestampValue(Instant.ofEpochMilli((wmaResult))); - }; public final Function, ExprValue> WMA_TIME_EVALUATOR = (receivedValues) -> { - Long wmaResult = Math.round(calculateWmaInDouble(receivedValues, + Long wmaResult = + Math.round( + calculateWmaInDouble( + receivedValues, i -> (double) (MILLIS.between(LocalTime.MIN, i.timeValue())))); - return ExprValueUtils.timeValue(LocalTime.MIN.plus(wmaResult, MILLIS)); + return ExprValueUtils.timeValue(LocalTime.MIN.plus(wmaResult, MILLIS)); }; /** - * Responsible to iterate the internal buffer, perform necessary calculation, - * and the up-to-date wma result in Double + * Responsible to iterate the internal buffer, perform necessary calculation, and the up-to-date + * wma result in Double + * * @param receivedValues internal buffer which stores all value in range. - * @param exprToDouble transformation function to convert incoming values to double for calcaution. + * @param exprToDouble transformation function to convert incoming values to double for + * calcaution. * @return wma result in Double form. */ - private Double calculateWmaInDouble(Queue receivedValues, Function exprToDouble) { - double sum = 0D; - Iterator weightIter = weights.iterator(); - for (ExprValue next : receivedValues) { - sum += exprToDouble.apply(next) * (weightIter.next()); - } - return sum; + private Double calculateWmaInDouble( + Queue receivedValues, Function exprToDouble) { + double sum = 0D; + Iterator weightIter = weights.iterator(); + for (ExprValue next : receivedValues) { + sum += exprToDouble.apply(next) * (weightIter.next()); + } + return sum; } - } } diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index 5a7522683a..0424b148b6 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -89,13 +89,12 @@ public Trendline.TrendlineComputation visitTrendlineClause( } final Field dataField = (Field) this.visitFieldExpression(ctx.field); + final Trendline.TrendlineType computationType = + Trendline.TrendlineType.valueOf(ctx.trendlineType().getText().toUpperCase(Locale.ROOT)); final String alias = ctx.alias != null ? ctx.alias.getText() - : dataField.getChild().get(0).toString() + "_trendline"; - - final Trendline.TrendlineType computationType = - Trendline.TrendlineType.valueOf(ctx.trendlineType().getText().toUpperCase(Locale.ROOT)); + : dataField.getChild().get(0).toString() + "_" + computationType.name() + "_trendline"; return new Trendline.TrendlineComputation( numberOfDataPoints, dataField, alias, computationType); } From 8dc96c96696707a740f92d869612b6e703a9ffc4 Mon Sep 17 00:00:00 2001 From: Andy Kwok Date: Tue, 11 Feb 2025 16:00:05 -0800 Subject: [PATCH 22/36] Doc Signed-off-by: Andy Kwok --- docs/user/ppl/cmd/trendline.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/user/ppl/cmd/trendline.rst b/docs/user/ppl/cmd/trendline.rst index 79c832a86c..301523539d 100644 --- a/docs/user/ppl/cmd/trendline.rst +++ b/docs/user/ppl/cmd/trendline.rst @@ -22,7 +22,7 @@ Syntax * trendline-type: mandatory. The type of algorithm being used for the calculation, only SMA (simple moving average) or WMA (weighted moving average) are supported at the moment. * number-of-datapoints: mandatory. The number of datapoints to calculate the moving average (must be greater than zero). * field: mandatory. The name of the field the moving average should be calculated for. -* alias: optional. The name of the resulting column containing the moving average (defaults to the field name with "_trendline"). +* alias: optional. The name of the resulting column containing the moving average (defaults to the field name with "__trendline"). In the case of Simple Moving Average - SMA, result will be calculated as per the below formula. From c197ecdffe50d8e081be11314f99853f92f2a771 Mon Sep 17 00:00:00 2001 From: Andy Kwok Date: Wed, 12 Feb 2025 09:09:37 -0800 Subject: [PATCH 23/36] Doc Signed-off-by: Andy Kwok --- docs/user/ppl/cmd/trendline.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/user/ppl/cmd/trendline.rst b/docs/user/ppl/cmd/trendline.rst index 301523539d..147e11f82b 100644 --- a/docs/user/ppl/cmd/trendline.rst +++ b/docs/user/ppl/cmd/trendline.rst @@ -32,7 +32,7 @@ In the case of Simple Moving Average - SMA, result will be calculated as per the SMA(t) = (1/n) * Σ(f[i]), where i = t-n+1 to t -Example 1: Calculate the moving average on one field. +Example 1: Calculate the simple moving average on one field. ===================================================== The example shows how to calculate the moving average on one field. @@ -51,7 +51,7 @@ PPL query:: +------+ -Example 2: Calculate the moving average on multiple fields. +Example 2: Calculate the simple moving average on multiple fields. =========================================================== The example shows how to calculate the moving average on multiple fields. @@ -69,7 +69,7 @@ PPL query:: | 15.5 | 30.5 | +------+-----------+ -Example 3: Calculate the moving average on one field without specifying an alias. +Example 3: Calculate the simple moving average on one field without specifying an alias. ================================================================================= The example shows how to calculate the moving average on one field without specifying an alias. From 5ea47aec753d38ece9f8ba536db1f46ba7c0d911 Mon Sep 17 00:00:00 2001 From: Andy Kwok Date: Wed, 12 Feb 2025 10:14:57 -0800 Subject: [PATCH 24/36] Refactor test-cases Signed-off-by: Andy Kwok --- .../physical/TrendlineOperatorTest.java | 799 ++++++++---------- 1 file changed, 344 insertions(+), 455 deletions(-) diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java index 16292f40eb..7277318e3d 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java @@ -5,6 +5,8 @@ package org.opensearch.sql.planner.physical; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsInAnyOrder; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertThrows; @@ -12,6 +14,7 @@ import static org.mockito.Mockito.when; import static org.opensearch.sql.ast.tree.Trendline.TrendlineType.SMA; import static org.opensearch.sql.ast.tree.Trendline.TrendlineType.WMA; +import static org.opensearch.sql.data.model.ExprValueUtils.tupleValue; import com.google.common.collect.ImmutableMap; import java.time.Instant; @@ -19,6 +22,8 @@ import java.time.LocalTime; import java.util.Arrays; import java.util.Collections; +import java.util.List; + import org.apache.commons.lang3.tuple.Pair; import org.junit.jupiter.api.DisplayNameGeneration; import org.junit.jupiter.api.DisplayNameGenerator; @@ -28,19 +33,20 @@ import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.sql.ast.dsl.AstDSL; import org.opensearch.sql.data.model.ExprNullValue; +import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.model.ExprValueUtils; import org.opensearch.sql.data.type.ExprCoreType; @DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) @ExtendWith(MockitoExtension.class) -public class TrendlineOperatorTest { +public class TrendlineOperatorTest extends PhysicalPlanTestBase { @Mock private PhysicalPlan inputPlan; @Test public void calculates_simple_moving_average_one_field_one_sample() { when(inputPlan.hasNext()).thenReturn(true, false); when(inputPlan.next()) - .thenReturn(ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10))); + .thenReturn(tupleValue(ImmutableMap.of("distance", 100, "time", 10))); var plan = new TrendlineOperator( @@ -50,12 +56,10 @@ public void calculates_simple_moving_average_one_field_one_sample() { AstDSL.computation(1, AstDSL.field("distance"), "distance_alias", SMA), ExprCoreType.DOUBLE))); - plan.open(); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of("distance", 100, "time", 10, "distance_alias", 100)), - plan.next()); + List result = execute(plan); + assertEquals(1, result.size()); + assertThat(result, containsInAnyOrder( + tupleValue(ImmutableMap.of("distance", 100, "time", 10, "distance_alias", 100)))); } @Test @@ -63,8 +67,8 @@ public void calculates_simple_moving_average_one_field_two_samples() { when(inputPlan.hasNext()).thenReturn(true, true, false); when(inputPlan.next()) .thenReturn( - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10))); + tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10))); var plan = new TrendlineOperator( @@ -74,16 +78,12 @@ public void calculates_simple_moving_average_one_field_two_samples() { AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", SMA), ExprCoreType.DOUBLE))); - plan.open(); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), plan.next()); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 150.0)), - plan.next()); - assertFalse(plan.hasNext()); + List result = execute(plan); + assertEquals(2, result.size()); + assertThat(result, containsInAnyOrder( + tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 150.0) + ))); } @Test @@ -91,9 +91,9 @@ public void calculates_simple_moving_average_one_field_two_samples_three_rows() when(inputPlan.hasNext()).thenReturn(true, true, true, false); when(inputPlan.next()) .thenReturn( - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)), - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10))); + tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10))); var plan = new TrendlineOperator( @@ -103,21 +103,14 @@ public void calculates_simple_moving_average_one_field_two_samples_three_rows() AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", SMA), ExprCoreType.DOUBLE))); - plan.open(); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), plan.next()); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 150.0)), - plan.next()); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 200.0)), - plan.next()); - assertFalse(plan.hasNext()); + + List result = execute(plan); + assertEquals(3, result.size()); + assertThat(result, containsInAnyOrder( + tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 150.0)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 200.0)) + )); } @Test @@ -125,9 +118,9 @@ public void calculates_simple_moving_average_data_type_support_short() { when(inputPlan.hasNext()).thenReturn(true, true, true, false); when(inputPlan.next()) .thenReturn( - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)), - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10))); + tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10))); var plan = new TrendlineOperator( @@ -137,21 +130,14 @@ public void calculates_simple_moving_average_data_type_support_short() { AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", SMA), ExprCoreType.SHORT))); - plan.open(); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), plan.next()); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 150.0)), - plan.next()); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 200.0)), - plan.next()); - assertFalse(plan.hasNext()); + + List result = execute(plan); + assertEquals(3, result.size()); + assertThat(result, containsInAnyOrder( + tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 150.0)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 200.0))) + ); } @Test @@ -159,9 +145,9 @@ public void calculates_simple_moving_average_data_type_support_long() { when(inputPlan.hasNext()).thenReturn(true, true, true, false); when(inputPlan.next()) .thenReturn( - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)), - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10))); + tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10))); var plan = new TrendlineOperator( @@ -171,21 +157,15 @@ public void calculates_simple_moving_average_data_type_support_long() { AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", SMA), ExprCoreType.SHORT))); - plan.open(); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), plan.next()); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 150.0)), - plan.next()); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 200.0)), - plan.next()); - assertFalse(plan.hasNext()); + List result = execute(plan); + assertEquals(3, result.size()); + assertThat(result, containsInAnyOrder( + tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 150.0)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 200.0))) + ); + + } @Test @@ -193,9 +173,9 @@ public void calculates_simple_moving_average_data_type_support_float() { when(inputPlan.hasNext()).thenReturn(true, true, true, false); when(inputPlan.next()) .thenReturn( - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)), - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10))); + tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10))); var plan = new TrendlineOperator( @@ -205,21 +185,13 @@ public void calculates_simple_moving_average_data_type_support_float() { AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", SMA), ExprCoreType.FLOAT))); - plan.open(); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), plan.next()); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 150.0)), - plan.next()); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 200.0)), - plan.next()); - assertFalse(plan.hasNext()); + List result = execute(plan); + assertEquals(3, result.size()); + assertThat(result, containsInAnyOrder( + tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 150.0)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 200.0))) + ); } @Test @@ -227,9 +199,9 @@ public void calculates_simple_moving_average_multiple_computations() { when(inputPlan.hasNext()).thenReturn(true, true, true, false); when(inputPlan.next()) .thenReturn( - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 20)), - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 20))); + tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 20)), + tupleValue(ImmutableMap.of("distance", 200, "time", 20))); var plan = new TrendlineOperator( @@ -242,23 +214,13 @@ public void calculates_simple_moving_average_multiple_computations() { AstDSL.computation(2, AstDSL.field("time"), "time_alias", SMA), ExprCoreType.DOUBLE))); - plan.open(); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), plan.next()); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of( - "distance", 200, "time", 20, "distance_alias", 150.0, "time_alias", 15.0)), - plan.next()); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of( - "distance", 200, "time", 20, "distance_alias", 200.0, "time_alias", 20.0)), - plan.next()); - assertFalse(plan.hasNext()); + List result = execute(plan); + assertEquals(3, result.size()); + assertThat(result, containsInAnyOrder( + tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 20, "distance_alias", 150.0, "time_alias", 15.0)), + tupleValue(ImmutableMap.of("distance", 200, "time", 20, "distance_alias", 200.0, "time_alias", 20.0))) + ); } @Test @@ -266,9 +228,9 @@ public void alias_overwrites_input_field() { when(inputPlan.hasNext()).thenReturn(true, true, true, false); when(inputPlan.next()) .thenReturn( - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)), - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10))); + tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10))); var plan = new TrendlineOperator( @@ -278,16 +240,14 @@ public void alias_overwrites_input_field() { AstDSL.computation(2, AstDSL.field("distance"), "time", SMA), ExprCoreType.DOUBLE))); - plan.open(); - assertTrue(plan.hasNext()); - assertEquals(ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100)), plan.next()); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 150.0)), plan.next()); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 200.0)), plan.next()); - assertFalse(plan.hasNext()); + List result = execute(plan); + assertEquals(3, result.size()); + assertThat(result, containsInAnyOrder( + tupleValue(ImmutableMap.of("distance", 100)), + tupleValue(ImmutableMap.of("distance", 200, "time", 150.0)), + tupleValue(ImmutableMap.of("distance", 200, "time", 200.0))) + ); + } @Test @@ -295,9 +255,9 @@ public void calculates_simple_moving_average_one_field_two_samples_three_rows_nu when(inputPlan.hasNext()).thenReturn(true, true, true, false); when(inputPlan.next()) .thenReturn( - ExprValueUtils.tupleValue(ImmutableMap.of("time", 10)), - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)), - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 300, "time", 10))); + tupleValue(ImmutableMap.of("time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + tupleValue(ImmutableMap.of("distance", 300, "time", 10))); var plan = new TrendlineOperator( @@ -307,18 +267,15 @@ public void calculates_simple_moving_average_one_field_two_samples_three_rows_nu AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", SMA), ExprCoreType.DOUBLE))); - plan.open(); - assertTrue(plan.hasNext()); - assertEquals(ExprValueUtils.tupleValue(ImmutableMap.of("time", 10)), plan.next()); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)), plan.next()); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of("distance", 300, "time", 10, "distance_alias", 250.0)), - plan.next()); - assertFalse(plan.hasNext()); + List result = execute(plan); + assertEquals(3, result.size()); + assertThat(result, containsInAnyOrder( + tupleValue(ImmutableMap.of("time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + tupleValue(ImmutableMap.of("distance", 300, "time", 10, "distance_alias", 250.0))) + ); + + } @Test @@ -326,9 +283,9 @@ public void use_null_value() { when(inputPlan.hasNext()).thenReturn(true, true, true, false); when(inputPlan.next()) .thenReturn( - ExprValueUtils.tupleValue(ImmutableMap.of("time", 10)), - ExprValueUtils.tupleValue(ImmutableMap.of("distance", ExprNullValue.of(), "time", 10)), - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10))); + tupleValue(ImmutableMap.of("time", 10)), + tupleValue(ImmutableMap.of("distance", ExprNullValue.of(), "time", 10)), + tupleValue(ImmutableMap.of("distance", 100, "time", 10))); var plan = new TrendlineOperator( @@ -338,19 +295,15 @@ public void use_null_value() { AstDSL.computation(1, AstDSL.field("distance"), "distance_alias", SMA), ExprCoreType.DOUBLE))); - plan.open(); - assertTrue(plan.hasNext()); - assertEquals(ExprValueUtils.tupleValue(ImmutableMap.of("time", 10)), plan.next()); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue(ImmutableMap.of("distance", ExprNullValue.of(), "time", 10)), - plan.next()); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of("distance", 100, "time", 10, "distance_alias", 100)), - plan.next()); - assertFalse(plan.hasNext()); + List result = execute(plan); + assertEquals(3, result.size()); + assertThat(result, containsInAnyOrder( + tupleValue(ImmutableMap.of("time", 10)), + tupleValue(ImmutableMap.of("distance", ExprNullValue.of(), "time", 10)), + tupleValue(ImmutableMap.of("distance", 100, "time", 10, "distance_alias", 100))) + ); + + } @Test @@ -372,11 +325,11 @@ public void calculates_simple_moving_average_date() { when(inputPlan.hasNext()).thenReturn(true, true, true, false); when(inputPlan.next()) .thenReturn( - ExprValueUtils.tupleValue( + tupleValue( ImmutableMap.of("date", ExprValueUtils.dateValue(LocalDate.EPOCH))), - ExprValueUtils.tupleValue( + tupleValue( ImmutableMap.of("date", ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(6)))), - ExprValueUtils.tupleValue( + tupleValue( ImmutableMap.of("date", ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(12))))); var plan = @@ -387,31 +340,21 @@ public void calculates_simple_moving_average_date() { AstDSL.computation(2, AstDSL.field("date"), "date_alias", SMA), ExprCoreType.DATE))); - plan.open(); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of("date", ExprValueUtils.dateValue(LocalDate.EPOCH))), - plan.next()); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of( - "date", - ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(6)), - "date_alias", - ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(3)))), - plan.next()); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of( - "date", - ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(12)), - "date_alias", - ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(9)))), - plan.next()); - assertFalse(plan.hasNext()); + List result = execute(plan); + assertEquals(3, result.size()); + assertThat(result, containsInAnyOrder( + tupleValue(ImmutableMap.of("date", ExprValueUtils.dateValue(LocalDate.EPOCH))), + tupleValue(ImmutableMap.of( + "date", + ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(6)), + "date_alias", + ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(3)))), + tupleValue(ImmutableMap.of( + "date", + ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(12)), + "date_alias", + ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(9))))) + ); } @Test @@ -419,11 +362,11 @@ public void calculates_simple_moving_average_time() { when(inputPlan.hasNext()).thenReturn(true, true, true, false); when(inputPlan.next()) .thenReturn( - ExprValueUtils.tupleValue( + tupleValue( ImmutableMap.of("time", ExprValueUtils.timeValue(LocalTime.MIN))), - ExprValueUtils.tupleValue( + tupleValue( ImmutableMap.of("time", ExprValueUtils.timeValue(LocalTime.MIN.plusHours(6)))), - ExprValueUtils.tupleValue( + tupleValue( ImmutableMap.of("time", ExprValueUtils.timeValue(LocalTime.MIN.plusHours(12))))); var plan = @@ -434,22 +377,17 @@ public void calculates_simple_moving_average_time() { AstDSL.computation(2, AstDSL.field("time"), "time_alias", SMA), ExprCoreType.TIME))); - plan.open(); - assertTrue(plan.hasNext()); - assertEquals(ExprValueUtils.tupleValue(ImmutableMap.of("time", LocalTime.MIN)), plan.next()); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of( - "time", LocalTime.MIN.plusHours(6), "time_alias", LocalTime.MIN.plusHours(3))), - plan.next()); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of( - "time", LocalTime.MIN.plusHours(12), "time_alias", LocalTime.MIN.plusHours(9))), - plan.next()); - assertFalse(plan.hasNext()); + List result = execute(plan); + assertEquals(3, result.size()); + assertThat(result, containsInAnyOrder( + tupleValue(ImmutableMap.of("time", LocalTime.MIN)), + tupleValue(ImmutableMap.of( + "time", LocalTime.MIN.plusHours(6), "time_alias", LocalTime.MIN.plusHours(3))), + tupleValue(ImmutableMap.of( + "time", LocalTime.MIN.plusHours(12), "time_alias", LocalTime.MIN.plusHours(9)))) + ); + + } @Test @@ -457,12 +395,12 @@ public void calculates_simple_moving_average_timestamp() { when(inputPlan.hasNext()).thenReturn(true, true, true, false); when(inputPlan.next()) .thenReturn( - ExprValueUtils.tupleValue( + tupleValue( ImmutableMap.of("timestamp", ExprValueUtils.timestampValue(Instant.EPOCH))), - ExprValueUtils.tupleValue( + tupleValue( ImmutableMap.of( "timestamp", ExprValueUtils.timestampValue(Instant.EPOCH.plusMillis(1000)))), - ExprValueUtils.tupleValue( + tupleValue( ImmutableMap.of( "timestamp", ExprValueUtils.timestampValue(Instant.EPOCH.plusMillis(1500))))); @@ -474,36 +412,30 @@ public void calculates_simple_moving_average_timestamp() { AstDSL.computation(2, AstDSL.field("timestamp"), "timestamp_alias", SMA), ExprCoreType.TIMESTAMP))); - plan.open(); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue(ImmutableMap.of("timestamp", Instant.EPOCH)), plan.next()); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of( - "timestamp", - Instant.EPOCH.plusMillis(1000), - "timestamp_alias", - Instant.EPOCH.plusMillis(500))), - plan.next()); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of( - "timestamp", - Instant.EPOCH.plusMillis(1500), - "timestamp_alias", - Instant.EPOCH.plusMillis(1250))), - plan.next()); - assertFalse(plan.hasNext()); + List result = execute(plan); + assertEquals(3, result.size()); + assertThat(result, containsInAnyOrder( + tupleValue(ImmutableMap.of("timestamp", Instant.EPOCH)), + tupleValue(ImmutableMap.of( + "timestamp", + Instant.EPOCH.plusMillis(1000), + "timestamp_alias", + Instant.EPOCH.plusMillis(500))), + tupleValue(ImmutableMap.of( + "timestamp", + Instant.EPOCH.plusMillis(1500), + "timestamp_alias", + Instant.EPOCH.plusMillis(1250)))) + ); + + } @Test public void calculates_weighted_moving_average_one_field_one_sample() { when(inputPlan.hasNext()).thenReturn(true, false); when(inputPlan.next()) - .thenReturn(ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10))); + .thenReturn(tupleValue(ImmutableMap.of("distance", 100, "time", 10))); var plan = new TrendlineOperator( @@ -514,9 +446,10 @@ public void calculates_weighted_moving_average_one_field_one_sample() { ExprCoreType.DOUBLE))); plan.open(); + assertTrue(plan.hasNext()); assertEquals( - ExprValueUtils.tupleValue( + tupleValue( ImmutableMap.of("distance", 100, "time", 10, "distance_alias", 100)), plan.next()); } @@ -526,8 +459,8 @@ public void calculates_weighted_moving_average_one_field_two_samples() { when(inputPlan.hasNext()).thenReturn(true, true, false); when(inputPlan.next()) .thenReturn( - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10))); + tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10))); var plan = new TrendlineOperator( @@ -537,16 +470,16 @@ public void calculates_weighted_moving_average_one_field_two_samples() { AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", WMA), ExprCoreType.DOUBLE))); - plan.open(); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), plan.next()); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 166.66666666666663)), - plan.next()); - assertFalse(plan.hasNext()); + + List result = execute(plan); + assertEquals(2, result.size()); + assertThat(result, containsInAnyOrder( + tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + tupleValue( + ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 166.66666666666663))) + ); + + } @Test @@ -554,9 +487,9 @@ public void calculates_weighted_moving_average_one_field_two_samples_three_rows( when(inputPlan.hasNext()).thenReturn(true, true, true, false); when(inputPlan.next()) .thenReturn( - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)), - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10))); + tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10))); var plan = new TrendlineOperator( @@ -566,21 +499,16 @@ public void calculates_weighted_moving_average_one_field_two_samples_three_rows( AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", WMA), ExprCoreType.DOUBLE))); - plan.open(); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), plan.next()); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 166.66666666666663)), - plan.next()); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 199.99999999999997)), - plan.next()); - assertFalse(plan.hasNext()); + List result = execute(plan); + assertEquals(3, result.size()); + assertThat(result, containsInAnyOrder( + tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + tupleValue( + ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 166.66666666666663)), + tupleValue( + ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 199.99999999999997))) + ); + } @Test @@ -588,9 +516,9 @@ public void calculates_weighted_moving_average_data_type_support_short() { when(inputPlan.hasNext()).thenReturn(true, true, true, false); when(inputPlan.next()) .thenReturn( - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)), - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10))); + tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10))); var plan = new TrendlineOperator( @@ -600,21 +528,17 @@ public void calculates_weighted_moving_average_data_type_support_short() { AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", WMA), ExprCoreType.SHORT))); - plan.open(); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), plan.next()); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 166.66666666666663)), - plan.next()); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 199.99999999999997)), - plan.next()); - assertFalse(plan.hasNext()); + List result = execute(plan); + assertEquals(3, result.size()); + assertThat(result, containsInAnyOrder( + tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + tupleValue( + ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 166.66666666666663)), + tupleValue( + ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 199.99999999999997))) + ); + + } @Test @@ -622,9 +546,9 @@ public void calculates_weighted_moving_average_data_type_support_integer() { when(inputPlan.hasNext()).thenReturn(true, true, true, false); when(inputPlan.next()) .thenReturn( - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)), - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10))); + tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10))); var plan = new TrendlineOperator( @@ -634,21 +558,16 @@ public void calculates_weighted_moving_average_data_type_support_integer() { AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", WMA), ExprCoreType.INTEGER))); - plan.open(); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), plan.next()); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 166.66666666666663)), - plan.next()); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 199.99999999999997)), - plan.next()); - assertFalse(plan.hasNext()); + List result = execute(plan); + assertEquals(3, result.size()); + assertThat(result, containsInAnyOrder( + tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + tupleValue( + ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 166.66666666666663)), + tupleValue( + ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 199.99999999999997))) + ); + } @Test @@ -656,9 +575,9 @@ public void calculates_weighted_moving_average_data_type_support_long() { when(inputPlan.hasNext()).thenReturn(true, true, true, false); when(inputPlan.next()) .thenReturn( - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)), - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10))); + tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10))); var plan = new TrendlineOperator( @@ -668,21 +587,18 @@ public void calculates_weighted_moving_average_data_type_support_long() { AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", WMA), ExprCoreType.LONG))); - plan.open(); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), plan.next()); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 166.66666666666663)), - plan.next()); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 199.99999999999997)), - plan.next()); - assertFalse(plan.hasNext()); + + List result = execute(plan); + assertEquals(3, result.size()); + assertThat(result, containsInAnyOrder( + tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + tupleValue( + ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 166.66666666666663)), + tupleValue( + ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 199.99999999999997))) + ); + + } @Test @@ -690,9 +606,9 @@ public void calculates_weighted_moving_average_data_type_support_float() { when(inputPlan.hasNext()).thenReturn(true, true, true, false); when(inputPlan.next()) .thenReturn( - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)), - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10))); + tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10))); var plan = new TrendlineOperator( @@ -702,21 +618,16 @@ public void calculates_weighted_moving_average_data_type_support_float() { AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", WMA), ExprCoreType.FLOAT))); - plan.open(); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), plan.next()); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 166.66666666666663)), - plan.next()); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 199.99999999999997)), - plan.next()); - assertFalse(plan.hasNext()); + List result = execute(plan); + assertEquals(3, result.size()); + assertThat(result, containsInAnyOrder( + tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + tupleValue( + ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 166.66666666666663)), + tupleValue( + ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 199.99999999999997))) + ); + } @Test @@ -724,9 +635,9 @@ public void calculates_weighted_moving_average_multiple_computations() { when(inputPlan.hasNext()).thenReturn(true, true, true, false); when(inputPlan.next()) .thenReturn( - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 20)), - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 20))); + tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 20)), + tupleValue(ImmutableMap.of("distance", 200, "time", 20))); var plan = new TrendlineOperator( @@ -739,37 +650,32 @@ public void calculates_weighted_moving_average_multiple_computations() { AstDSL.computation(2, AstDSL.field("time"), "time_alias", WMA), ExprCoreType.DOUBLE))); - plan.open(); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), plan.next()); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of( - "distance", - 200, - "time", - 20, - "distance_alias", - 166.66666666666663, - "time_alias", - 16.666666666666664)), - plan.next()); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of( - "distance", - 200, - "time", - 20, - "distance_alias", - 199.99999999999997, - "time_alias", - 20.0)), - plan.next()); - assertFalse(plan.hasNext()); + List result = execute(plan); + assertEquals(3, result.size()); + assertThat(result, containsInAnyOrder( + tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + tupleValue( + ImmutableMap.of( + "distance", + 200, + "time", + 20, + "distance_alias", + 166.66666666666663, + "time_alias", + 16.666666666666664)), + tupleValue( + ImmutableMap.of( + "distance", + 200, + "time", + 20, + "distance_alias", + 199.99999999999997, + "time_alias", + 20.0))) + ); + } @Test @@ -777,9 +683,9 @@ public void calculates_weighted_moving_average_one_field_two_samples_three_rows_ when(inputPlan.hasNext()).thenReturn(true, true, true, false); when(inputPlan.next()) .thenReturn( - ExprValueUtils.tupleValue(ImmutableMap.of("time", 10)), - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)), - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 300, "time", 10))); + tupleValue(ImmutableMap.of("time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + tupleValue(ImmutableMap.of("distance", 300, "time", 10))); var plan = new TrendlineOperator( @@ -789,18 +695,17 @@ public void calculates_weighted_moving_average_one_field_two_samples_three_rows_ AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", WMA), ExprCoreType.DOUBLE))); - plan.open(); - assertTrue(plan.hasNext()); - assertEquals(ExprValueUtils.tupleValue(ImmutableMap.of("time", 10)), plan.next()); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)), plan.next()); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of("distance", 300, "time", 10, "distance_alias", 266.66666666666663)), - plan.next()); - assertFalse(plan.hasNext()); + + List result = execute(plan); + assertEquals(3, result.size()); + assertThat(result, containsInAnyOrder( + tupleValue(ImmutableMap.of("time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + tupleValue( + ImmutableMap.of("distance", 300, "time", 10, "distance_alias", 266.66666666666663))) + ); + + } @Test @@ -808,11 +713,11 @@ public void calculates_weighted_moving_average_date() { when(inputPlan.hasNext()).thenReturn(true, true, true, false); when(inputPlan.next()) .thenReturn( - ExprValueUtils.tupleValue( + tupleValue( ImmutableMap.of("date", ExprValueUtils.dateValue(LocalDate.EPOCH))), - ExprValueUtils.tupleValue( + tupleValue( ImmutableMap.of("date", ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(6)))), - ExprValueUtils.tupleValue( + tupleValue( ImmutableMap.of("date", ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(12))))); var plan = @@ -823,31 +728,25 @@ public void calculates_weighted_moving_average_date() { AstDSL.computation(2, AstDSL.field("date"), "date_alias", WMA), ExprCoreType.DATE))); - plan.open(); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of("date", ExprValueUtils.dateValue(LocalDate.EPOCH))), - plan.next()); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of( - "date", - ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(6)), - "date_alias", - ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(4)))), - plan.next()); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of( - "date", - ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(12)), - "date_alias", - ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(10)))), - plan.next()); - assertFalse(plan.hasNext()); + + List result = execute(plan); + assertEquals(3, result.size()); + assertThat(result, containsInAnyOrder( + tupleValue(ImmutableMap.of("date", ExprValueUtils.dateValue(LocalDate.EPOCH))), + tupleValue(ImmutableMap.of( + "date", + ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(6)), + "date_alias", + ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(4)))), + tupleValue( + ImmutableMap.of( + "date", + ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(12)), + "date_alias", + ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(10))))) + ); + + } @Test @@ -855,11 +754,11 @@ public void calculates_weighted_moving_average_time() { when(inputPlan.hasNext()).thenReturn(true, true, true, false); when(inputPlan.next()) .thenReturn( - ExprValueUtils.tupleValue( + tupleValue( ImmutableMap.of("time", ExprValueUtils.timeValue(LocalTime.MIN))), - ExprValueUtils.tupleValue( + tupleValue( ImmutableMap.of("time", ExprValueUtils.timeValue(LocalTime.MIN.plusHours(6)))), - ExprValueUtils.tupleValue( + tupleValue( ImmutableMap.of("time", ExprValueUtils.timeValue(LocalTime.MIN.plusHours(12))))); var plan = @@ -870,22 +769,18 @@ public void calculates_weighted_moving_average_time() { AstDSL.computation(2, AstDSL.field("time"), "time_alias", WMA), ExprCoreType.TIME))); - plan.open(); - assertTrue(plan.hasNext()); - assertEquals(ExprValueUtils.tupleValue(ImmutableMap.of("time", LocalTime.MIN)), plan.next()); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of( - "time", LocalTime.MIN.plusHours(6), "time_alias", LocalTime.MIN.plusHours(4))), - plan.next()); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of( - "time", LocalTime.MIN.plusHours(12), "time_alias", LocalTime.MIN.plusHours(10))), - plan.next()); - assertFalse(plan.hasNext()); + + List result = execute(plan); + assertEquals(3, result.size()); + assertThat(result, containsInAnyOrder( + tupleValue(ImmutableMap.of("time", LocalTime.MIN)), + tupleValue(ImmutableMap.of( + "time", LocalTime.MIN.plusHours(6), "time_alias", LocalTime.MIN.plusHours(4))), + tupleValue( + ImmutableMap.of( + "time", LocalTime.MIN.plusHours(12), "time_alias", LocalTime.MIN.plusHours(10)))) + ); + } @Test @@ -893,12 +788,12 @@ public void calculates_weighted_moving_average_timestamp() { when(inputPlan.hasNext()).thenReturn(true, true, true, false); when(inputPlan.next()) .thenReturn( - ExprValueUtils.tupleValue( + tupleValue( ImmutableMap.of("timestamp", ExprValueUtils.timestampValue(Instant.EPOCH))), - ExprValueUtils.tupleValue( + tupleValue( ImmutableMap.of( "timestamp", ExprValueUtils.timestampValue(Instant.EPOCH.plusMillis(1000)))), - ExprValueUtils.tupleValue( + tupleValue( ImmutableMap.of( "timestamp", ExprValueUtils.timestampValue(Instant.EPOCH.plusMillis(1500))))); @@ -910,29 +805,23 @@ public void calculates_weighted_moving_average_timestamp() { AstDSL.computation(2, AstDSL.field("timestamp"), "timestamp_alias", WMA), ExprCoreType.TIMESTAMP))); - plan.open(); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue(ImmutableMap.of("timestamp", Instant.EPOCH)), plan.next()); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of( - "timestamp", - Instant.EPOCH.plusMillis(1000), - "timestamp_alias", - Instant.EPOCH.plusMillis(667))), - plan.next()); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of( - "timestamp", - Instant.EPOCH.plusMillis(1500), - "timestamp_alias", - Instant.EPOCH.plusMillis(1333))), - plan.next()); - assertFalse(plan.hasNext()); + List result = execute(plan); + assertEquals(3, result.size()); + assertThat(result, containsInAnyOrder( + tupleValue(ImmutableMap.of("timestamp", Instant.EPOCH)), + tupleValue(ImmutableMap.of( + "timestamp", + Instant.EPOCH.plusMillis(1000), + "timestamp_alias", + Instant.EPOCH.plusMillis(667))), + tupleValue( + ImmutableMap.of( + "timestamp", + Instant.EPOCH.plusMillis(1500), + "timestamp_alias", + Instant.EPOCH.plusMillis(1333)))) + ); + } @Test From 8d6ac311bc46b2546f71a01e99f2b98fb71b5f32 Mon Sep 17 00:00:00 2001 From: Andy Kwok Date: Wed, 12 Feb 2025 10:48:19 -0800 Subject: [PATCH 25/36] DataPoints test-cases Signed-off-by: Andy Kwok --- .../planner/physical/TrendlineOperator.java | 9 ++++- .../physical/TrendlineOperatorTest.java | 38 +++++++++++++++---- 2 files changed, 37 insertions(+), 10 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java index 82fa9698f2..9374561e7b 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java @@ -127,8 +127,13 @@ private abstract static class TrendlineAccumulator { protected final Queue receivedValues; private TrendlineAccumulator(Trendline.TrendlineComputation config) { - this.dataPointsNeeded = DSL.literal(config.getNumberOfDataPoints().doubleValue()); - this.receivedValues = EvictingQueue.create(config.getNumberOfDataPoints()); + Integer numberOfDataPoints = config.getNumberOfDataPoints(); + if (numberOfDataPoints <=0) { + throw new IllegalArgumentException( + String.format("Invalid dataPoints [%d] value.", numberOfDataPoints)); + }; + this.dataPointsNeeded = DSL.literal(numberOfDataPoints.doubleValue()); + this.receivedValues = EvictingQueue.create(numberOfDataPoints); } abstract void accumulate(ExprValue value); diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java index 7277318e3d..999cebbb32 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java @@ -828,13 +828,35 @@ public void calculates_weighted_moving_average_timestamp() { public void use_illegal_core_type_wma() { assertThrows( IllegalArgumentException.class, - () -> { - new TrendlineOperator( - inputPlan, - Collections.singletonList( - Pair.of( - AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", WMA), - ExprCoreType.ARRAY))); - }); + () -> new TrendlineOperator( + inputPlan, + Collections.singletonList( + Pair.of( + AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", WMA), + ExprCoreType.ARRAY)))); + } + + @Test + public void use_invalid_dataPoints_zero() { + assertThrows( + IllegalArgumentException.class, + () -> new TrendlineOperator( + inputPlan, + Collections.singletonList( + Pair.of( + AstDSL.computation(0, AstDSL.field("distance"), "distance_alias", WMA), + ExprCoreType.INTEGER)))); + } + + @Test + public void use_invalid_dataPoints_negative() { + assertThrows( + IllegalArgumentException.class, + () -> new TrendlineOperator( + inputPlan, + Collections.singletonList( + Pair.of( + AstDSL.computation(-100, AstDSL.field("distance"), "distance_alias", WMA), + ExprCoreType.INTEGER)))); } } From 9ce1d99573264d4ef89e785d8e6e52195e0ec11c Mon Sep 17 00:00:00 2001 From: Andy Kwok Date: Wed, 12 Feb 2025 13:01:14 -0800 Subject: [PATCH 26/36] Update test-cases Signed-off-by: Andy Kwok --- .../planner/physical/TrendlineOperator.java | 7 +- .../physical/TrendlineOperatorTest.java | 549 +++++++++--------- 2 files changed, 271 insertions(+), 285 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java index 9374561e7b..8e3e70e74f 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java @@ -128,10 +128,11 @@ private abstract static class TrendlineAccumulator { private TrendlineAccumulator(Trendline.TrendlineComputation config) { Integer numberOfDataPoints = config.getNumberOfDataPoints(); - if (numberOfDataPoints <=0) { + if (numberOfDataPoints <= 0) { throw new IllegalArgumentException( - String.format("Invalid dataPoints [%d] value.", numberOfDataPoints)); - }; + String.format("Invalid dataPoints [%d] value.", numberOfDataPoints)); + } + ; this.dataPointsNeeded = DSL.literal(numberOfDataPoints.doubleValue()); this.receivedValues = EvictingQueue.create(numberOfDataPoints); } diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java index 999cebbb32..6acdb14bfa 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java @@ -8,7 +8,6 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.containsInAnyOrder; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.Mockito.when; @@ -20,9 +19,12 @@ import java.time.Instant; import java.time.LocalDate; import java.time.LocalTime; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; +import java.util.Iterator; import java.util.List; +import java.util.Map; import org.apache.commons.lang3.tuple.Pair; import org.junit.jupiter.api.DisplayNameGeneration; @@ -44,10 +46,8 @@ public class TrendlineOperatorTest extends PhysicalPlanTestBase { @Test public void calculates_simple_moving_average_one_field_one_sample() { - when(inputPlan.hasNext()).thenReturn(true, false); - when(inputPlan.next()) - .thenReturn(tupleValue(ImmutableMap.of("distance", 100, "time", 10))); - + mockPlanWithData(List.of( + tupleValue(ImmutableMap.of("distance", 100, "time", 10)))); var plan = new TrendlineOperator( inputPlan, @@ -58,17 +58,24 @@ public void calculates_simple_moving_average_one_field_one_sample() { List result = execute(plan); assertEquals(1, result.size()); - assertThat(result, containsInAnyOrder( - tupleValue(ImmutableMap.of("distance", 100, "time", 10, "distance_alias", 100)))); + assertThat( + result, + containsInAnyOrder( + tupleValue(ImmutableMap.of("distance", 100, "time", 10, "distance_alias", 100)))); } @Test public void calculates_simple_moving_average_one_field_two_samples() { - when(inputPlan.hasNext()).thenReturn(true, true, false); - when(inputPlan.next()) - .thenReturn( +// when(inputPlan.hasNext()).thenReturn(true, true, false); +// when(inputPlan.next()) +// .thenReturn( +// tupleValue(ImmutableMap.of("distance", 100, "time", 10)), +// tupleValue(ImmutableMap.of("distance", 200, "time", 10))); + + mockPlanWithData(List.of( tupleValue(ImmutableMap.of("distance", 100, "time", 10)), - tupleValue(ImmutableMap.of("distance", 200, "time", 10))); + tupleValue(ImmutableMap.of("distance", 200, "time", 10)) + )); var plan = new TrendlineOperator( @@ -80,20 +87,19 @@ public void calculates_simple_moving_average_one_field_two_samples() { List result = execute(plan); assertEquals(2, result.size()); - assertThat(result, containsInAnyOrder( - tupleValue(ImmutableMap.of("distance", 100, "time", 10)), - tupleValue(ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 150.0) - ))); + assertThat( + result, + containsInAnyOrder( + tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 150.0)))); } @Test public void calculates_simple_moving_average_one_field_two_samples_three_rows() { - when(inputPlan.hasNext()).thenReturn(true, true, true, false); - when(inputPlan.next()) - .thenReturn( + mockPlanWithData(List.of( tupleValue(ImmutableMap.of("distance", 100, "time", 10)), tupleValue(ImmutableMap.of("distance", 200, "time", 10)), - tupleValue(ImmutableMap.of("distance", 200, "time", 10))); + tupleValue(ImmutableMap.of("distance", 200, "time", 10)))); var plan = new TrendlineOperator( @@ -103,24 +109,22 @@ public void calculates_simple_moving_average_one_field_two_samples_three_rows() AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", SMA), ExprCoreType.DOUBLE))); - List result = execute(plan); assertEquals(3, result.size()); - assertThat(result, containsInAnyOrder( + assertThat( + result, + containsInAnyOrder( tupleValue(ImmutableMap.of("distance", 100, "time", 10)), tupleValue(ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 150.0)), - tupleValue(ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 200.0)) - )); + tupleValue(ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 200.0)))); } @Test public void calculates_simple_moving_average_data_type_support_short() { - when(inputPlan.hasNext()).thenReturn(true, true, true, false); - when(inputPlan.next()) - .thenReturn( + mockPlanWithData(List.of( tupleValue(ImmutableMap.of("distance", 100, "time", 10)), tupleValue(ImmutableMap.of("distance", 200, "time", 10)), - tupleValue(ImmutableMap.of("distance", 200, "time", 10))); + tupleValue(ImmutableMap.of("distance", 200, "time", 10)))); var plan = new TrendlineOperator( @@ -130,24 +134,22 @@ public void calculates_simple_moving_average_data_type_support_short() { AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", SMA), ExprCoreType.SHORT))); - List result = execute(plan); assertEquals(3, result.size()); - assertThat(result, containsInAnyOrder( + assertThat( + result, + containsInAnyOrder( tupleValue(ImmutableMap.of("distance", 100, "time", 10)), tupleValue(ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 150.0)), - tupleValue(ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 200.0))) - ); + tupleValue(ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 200.0)))); } @Test public void calculates_simple_moving_average_data_type_support_long() { - when(inputPlan.hasNext()).thenReturn(true, true, true, false); - when(inputPlan.next()) - .thenReturn( + mockPlanWithData(List.of( tupleValue(ImmutableMap.of("distance", 100, "time", 10)), tupleValue(ImmutableMap.of("distance", 200, "time", 10)), - tupleValue(ImmutableMap.of("distance", 200, "time", 10))); + tupleValue(ImmutableMap.of("distance", 200, "time", 10)))); var plan = new TrendlineOperator( @@ -159,23 +161,20 @@ public void calculates_simple_moving_average_data_type_support_long() { List result = execute(plan); assertEquals(3, result.size()); - assertThat(result, containsInAnyOrder( + assertThat( + result, + containsInAnyOrder( tupleValue(ImmutableMap.of("distance", 100, "time", 10)), tupleValue(ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 150.0)), - tupleValue(ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 200.0))) - ); - - + tupleValue(ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 200.0)))); } @Test public void calculates_simple_moving_average_data_type_support_float() { - when(inputPlan.hasNext()).thenReturn(true, true, true, false); - when(inputPlan.next()) - .thenReturn( + mockPlanWithData(List.of( tupleValue(ImmutableMap.of("distance", 100, "time", 10)), tupleValue(ImmutableMap.of("distance", 200, "time", 10)), - tupleValue(ImmutableMap.of("distance", 200, "time", 10))); + tupleValue(ImmutableMap.of("distance", 200, "time", 10)))); var plan = new TrendlineOperator( @@ -187,21 +186,20 @@ public void calculates_simple_moving_average_data_type_support_float() { List result = execute(plan); assertEquals(3, result.size()); - assertThat(result, containsInAnyOrder( + assertThat( + result, + containsInAnyOrder( tupleValue(ImmutableMap.of("distance", 100, "time", 10)), tupleValue(ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 150.0)), - tupleValue(ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 200.0))) - ); + tupleValue(ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 200.0)))); } @Test public void calculates_simple_moving_average_multiple_computations() { - when(inputPlan.hasNext()).thenReturn(true, true, true, false); - when(inputPlan.next()) - .thenReturn( + mockPlanWithData(List.of( tupleValue(ImmutableMap.of("distance", 100, "time", 10)), tupleValue(ImmutableMap.of("distance", 200, "time", 20)), - tupleValue(ImmutableMap.of("distance", 200, "time", 20))); + tupleValue(ImmutableMap.of("distance", 200, "time", 20)))); var plan = new TrendlineOperator( @@ -216,21 +214,24 @@ public void calculates_simple_moving_average_multiple_computations() { List result = execute(plan); assertEquals(3, result.size()); - assertThat(result, containsInAnyOrder( + assertThat( + result, + containsInAnyOrder( tupleValue(ImmutableMap.of("distance", 100, "time", 10)), - tupleValue(ImmutableMap.of("distance", 200, "time", 20, "distance_alias", 150.0, "time_alias", 15.0)), - tupleValue(ImmutableMap.of("distance", 200, "time", 20, "distance_alias", 200.0, "time_alias", 20.0))) - ); + tupleValue( + ImmutableMap.of( + "distance", 200, "time", 20, "distance_alias", 150.0, "time_alias", 15.0)), + tupleValue( + ImmutableMap.of( + "distance", 200, "time", 20, "distance_alias", 200.0, "time_alias", 20.0)))); } @Test public void alias_overwrites_input_field() { - when(inputPlan.hasNext()).thenReturn(true, true, true, false); - when(inputPlan.next()) - .thenReturn( + mockPlanWithData(List.of( tupleValue(ImmutableMap.of("distance", 100, "time", 10)), tupleValue(ImmutableMap.of("distance", 200, "time", 10)), - tupleValue(ImmutableMap.of("distance", 200, "time", 10))); + tupleValue(ImmutableMap.of("distance", 200, "time", 10)))); var plan = new TrendlineOperator( @@ -242,22 +243,20 @@ public void alias_overwrites_input_field() { List result = execute(plan); assertEquals(3, result.size()); - assertThat(result, containsInAnyOrder( + assertThat( + result, + containsInAnyOrder( tupleValue(ImmutableMap.of("distance", 100)), tupleValue(ImmutableMap.of("distance", 200, "time", 150.0)), - tupleValue(ImmutableMap.of("distance", 200, "time", 200.0))) - ); - + tupleValue(ImmutableMap.of("distance", 200, "time", 200.0)))); } @Test public void calculates_simple_moving_average_one_field_two_samples_three_rows_null_value() { - when(inputPlan.hasNext()).thenReturn(true, true, true, false); - when(inputPlan.next()) - .thenReturn( + mockPlanWithData(List.of( tupleValue(ImmutableMap.of("time", 10)), tupleValue(ImmutableMap.of("distance", 200, "time", 10)), - tupleValue(ImmutableMap.of("distance", 300, "time", 10))); + tupleValue(ImmutableMap.of("distance", 300, "time", 10)))); var plan = new TrendlineOperator( @@ -269,23 +268,20 @@ public void calculates_simple_moving_average_one_field_two_samples_three_rows_nu List result = execute(plan); assertEquals(3, result.size()); - assertThat(result, containsInAnyOrder( + assertThat( + result, + containsInAnyOrder( tupleValue(ImmutableMap.of("time", 10)), tupleValue(ImmutableMap.of("distance", 200, "time", 10)), - tupleValue(ImmutableMap.of("distance", 300, "time", 10, "distance_alias", 250.0))) - ); - - + tupleValue(ImmutableMap.of("distance", 300, "time", 10, "distance_alias", 250.0)))); } @Test public void use_null_value() { - when(inputPlan.hasNext()).thenReturn(true, true, true, false); - when(inputPlan.next()) - .thenReturn( + mockPlanWithData(List.of( tupleValue(ImmutableMap.of("time", 10)), tupleValue(ImmutableMap.of("distance", ExprNullValue.of(), "time", 10)), - tupleValue(ImmutableMap.of("distance", 100, "time", 10))); + tupleValue(ImmutableMap.of("distance", 100, "time", 10)))); var plan = new TrendlineOperator( @@ -297,13 +293,12 @@ public void use_null_value() { List result = execute(plan); assertEquals(3, result.size()); - assertThat(result, containsInAnyOrder( + assertThat( + result, + containsInAnyOrder( tupleValue(ImmutableMap.of("time", 10)), tupleValue(ImmutableMap.of("distance", ExprNullValue.of(), "time", 10)), - tupleValue(ImmutableMap.of("distance", 100, "time", 10, "distance_alias", 100))) - ); - - + tupleValue(ImmutableMap.of("distance", 100, "time", 10, "distance_alias", 100)))); } @Test @@ -322,15 +317,12 @@ public void use_illegal_core_type() { @Test public void calculates_simple_moving_average_date() { - when(inputPlan.hasNext()).thenReturn(true, true, true, false); - when(inputPlan.next()) - .thenReturn( - tupleValue( - ImmutableMap.of("date", ExprValueUtils.dateValue(LocalDate.EPOCH))), + mockPlanWithData(List.of( + tupleValue(ImmutableMap.of("date", ExprValueUtils.dateValue(LocalDate.EPOCH))), tupleValue( ImmutableMap.of("date", ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(6)))), tupleValue( - ImmutableMap.of("date", ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(12))))); + ImmutableMap.of("date", ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(12)))))); var plan = new TrendlineOperator( @@ -342,32 +334,32 @@ public void calculates_simple_moving_average_date() { List result = execute(plan); assertEquals(3, result.size()); - assertThat(result, containsInAnyOrder( + assertThat( + result, + containsInAnyOrder( tupleValue(ImmutableMap.of("date", ExprValueUtils.dateValue(LocalDate.EPOCH))), - tupleValue(ImmutableMap.of( + tupleValue( + ImmutableMap.of( "date", ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(6)), "date_alias", ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(3)))), - tupleValue(ImmutableMap.of( + tupleValue( + ImmutableMap.of( "date", ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(12)), "date_alias", - ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(9))))) - ); + ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(9)))))); } @Test public void calculates_simple_moving_average_time() { - when(inputPlan.hasNext()).thenReturn(true, true, true, false); - when(inputPlan.next()) - .thenReturn( - tupleValue( - ImmutableMap.of("time", ExprValueUtils.timeValue(LocalTime.MIN))), + mockPlanWithData(List.of( + tupleValue(ImmutableMap.of("time", ExprValueUtils.timeValue(LocalTime.MIN))), tupleValue( ImmutableMap.of("time", ExprValueUtils.timeValue(LocalTime.MIN.plusHours(6)))), tupleValue( - ImmutableMap.of("time", ExprValueUtils.timeValue(LocalTime.MIN.plusHours(12))))); + ImmutableMap.of("time", ExprValueUtils.timeValue(LocalTime.MIN.plusHours(12)))))); var plan = new TrendlineOperator( @@ -379,30 +371,31 @@ public void calculates_simple_moving_average_time() { List result = execute(plan); assertEquals(3, result.size()); - assertThat(result, containsInAnyOrder( + assertThat( + result, + containsInAnyOrder( tupleValue(ImmutableMap.of("time", LocalTime.MIN)), - tupleValue(ImmutableMap.of( + tupleValue( + ImmutableMap.of( "time", LocalTime.MIN.plusHours(6), "time_alias", LocalTime.MIN.plusHours(3))), - tupleValue(ImmutableMap.of( - "time", LocalTime.MIN.plusHours(12), "time_alias", LocalTime.MIN.plusHours(9)))) - ); - - + tupleValue( + ImmutableMap.of( + "time", + LocalTime.MIN.plusHours(12), + "time_alias", + LocalTime.MIN.plusHours(9))))); } @Test public void calculates_simple_moving_average_timestamp() { - when(inputPlan.hasNext()).thenReturn(true, true, true, false); - when(inputPlan.next()) - .thenReturn( - tupleValue( - ImmutableMap.of("timestamp", ExprValueUtils.timestampValue(Instant.EPOCH))), + mockPlanWithData(List.of( + tupleValue(ImmutableMap.of("timestamp", ExprValueUtils.timestampValue(Instant.EPOCH))), tupleValue( ImmutableMap.of( "timestamp", ExprValueUtils.timestampValue(Instant.EPOCH.plusMillis(1000)))), tupleValue( ImmutableMap.of( - "timestamp", ExprValueUtils.timestampValue(Instant.EPOCH.plusMillis(1500))))); + "timestamp", ExprValueUtils.timestampValue(Instant.EPOCH.plusMillis(1500)))))); var plan = new TrendlineOperator( @@ -414,28 +407,28 @@ public void calculates_simple_moving_average_timestamp() { List result = execute(plan); assertEquals(3, result.size()); - assertThat(result, containsInAnyOrder( + assertThat( + result, + containsInAnyOrder( tupleValue(ImmutableMap.of("timestamp", Instant.EPOCH)), - tupleValue(ImmutableMap.of( + tupleValue( + ImmutableMap.of( "timestamp", Instant.EPOCH.plusMillis(1000), "timestamp_alias", Instant.EPOCH.plusMillis(500))), - tupleValue(ImmutableMap.of( + tupleValue( + ImmutableMap.of( "timestamp", Instant.EPOCH.plusMillis(1500), "timestamp_alias", - Instant.EPOCH.plusMillis(1250)))) - ); - - + Instant.EPOCH.plusMillis(1250))))); } @Test public void calculates_weighted_moving_average_one_field_one_sample() { - when(inputPlan.hasNext()).thenReturn(true, false); - when(inputPlan.next()) - .thenReturn(tupleValue(ImmutableMap.of("distance", 100, "time", 10))); + mockPlanWithData(List.of( + tupleValue(ImmutableMap.of("distance", 100, "time", 10)))); var plan = new TrendlineOperator( @@ -449,18 +442,15 @@ public void calculates_weighted_moving_average_one_field_one_sample() { assertTrue(plan.hasNext()); assertEquals( - tupleValue( - ImmutableMap.of("distance", 100, "time", 10, "distance_alias", 100)), + tupleValue(ImmutableMap.of("distance", 100, "time", 10, "distance_alias", 100)), plan.next()); } @Test public void calculates_weighted_moving_average_one_field_two_samples() { - when(inputPlan.hasNext()).thenReturn(true, true, false); - when(inputPlan.next()) - .thenReturn( + mockPlanWithData(List.of( tupleValue(ImmutableMap.of("distance", 100, "time", 10)), - tupleValue(ImmutableMap.of("distance", 200, "time", 10))); + tupleValue(ImmutableMap.of("distance", 200, "time", 10)))); var plan = new TrendlineOperator( @@ -470,26 +460,23 @@ public void calculates_weighted_moving_average_one_field_two_samples() { AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", WMA), ExprCoreType.DOUBLE))); - List result = execute(plan); assertEquals(2, result.size()); - assertThat(result, containsInAnyOrder( + assertThat( + result, + containsInAnyOrder( tupleValue(ImmutableMap.of("distance", 100, "time", 10)), tupleValue( - ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 166.66666666666663))) - ); - - + ImmutableMap.of( + "distance", 200, "time", 10, "distance_alias", 166.66666666666663)))); } @Test public void calculates_weighted_moving_average_one_field_two_samples_three_rows() { - when(inputPlan.hasNext()).thenReturn(true, true, true, false); - when(inputPlan.next()) - .thenReturn( + mockPlanWithData(List.of( tupleValue(ImmutableMap.of("distance", 100, "time", 10)), tupleValue(ImmutableMap.of("distance", 200, "time", 10)), - tupleValue(ImmutableMap.of("distance", 200, "time", 10))); + tupleValue(ImmutableMap.of("distance", 200, "time", 10)))); var plan = new TrendlineOperator( @@ -501,24 +488,23 @@ public void calculates_weighted_moving_average_one_field_two_samples_three_rows( List result = execute(plan); assertEquals(3, result.size()); - assertThat(result, containsInAnyOrder( + assertThat( + result, + containsInAnyOrder( tupleValue(ImmutableMap.of("distance", 100, "time", 10)), tupleValue( - ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 166.66666666666663)), + ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 166.66666666666663)), tupleValue( - ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 199.99999999999997))) - ); - + ImmutableMap.of( + "distance", 200, "time", 10, "distance_alias", 199.99999999999997)))); } @Test public void calculates_weighted_moving_average_data_type_support_short() { - when(inputPlan.hasNext()).thenReturn(true, true, true, false); - when(inputPlan.next()) - .thenReturn( + mockPlanWithData(List.of( tupleValue(ImmutableMap.of("distance", 100, "time", 10)), tupleValue(ImmutableMap.of("distance", 200, "time", 10)), - tupleValue(ImmutableMap.of("distance", 200, "time", 10))); + tupleValue(ImmutableMap.of("distance", 200, "time", 10)))); var plan = new TrendlineOperator( @@ -530,25 +516,23 @@ public void calculates_weighted_moving_average_data_type_support_short() { List result = execute(plan); assertEquals(3, result.size()); - assertThat(result, containsInAnyOrder( + assertThat( + result, + containsInAnyOrder( tupleValue(ImmutableMap.of("distance", 100, "time", 10)), tupleValue( - ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 166.66666666666663)), + ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 166.66666666666663)), tupleValue( - ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 199.99999999999997))) - ); - - + ImmutableMap.of( + "distance", 200, "time", 10, "distance_alias", 199.99999999999997)))); } @Test public void calculates_weighted_moving_average_data_type_support_integer() { - when(inputPlan.hasNext()).thenReturn(true, true, true, false); - when(inputPlan.next()) - .thenReturn( + mockPlanWithData(List.of( tupleValue(ImmutableMap.of("distance", 100, "time", 10)), tupleValue(ImmutableMap.of("distance", 200, "time", 10)), - tupleValue(ImmutableMap.of("distance", 200, "time", 10))); + tupleValue(ImmutableMap.of("distance", 200, "time", 10)))); var plan = new TrendlineOperator( @@ -560,24 +544,23 @@ public void calculates_weighted_moving_average_data_type_support_integer() { List result = execute(plan); assertEquals(3, result.size()); - assertThat(result, containsInAnyOrder( + assertThat( + result, + containsInAnyOrder( tupleValue(ImmutableMap.of("distance", 100, "time", 10)), tupleValue( - ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 166.66666666666663)), + ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 166.66666666666663)), tupleValue( - ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 199.99999999999997))) - ); - + ImmutableMap.of( + "distance", 200, "time", 10, "distance_alias", 199.99999999999997)))); } @Test public void calculates_weighted_moving_average_data_type_support_long() { - when(inputPlan.hasNext()).thenReturn(true, true, true, false); - when(inputPlan.next()) - .thenReturn( + mockPlanWithData(List.of( tupleValue(ImmutableMap.of("distance", 100, "time", 10)), tupleValue(ImmutableMap.of("distance", 200, "time", 10)), - tupleValue(ImmutableMap.of("distance", 200, "time", 10))); + tupleValue(ImmutableMap.of("distance", 200, "time", 10)))); var plan = new TrendlineOperator( @@ -587,28 +570,25 @@ public void calculates_weighted_moving_average_data_type_support_long() { AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", WMA), ExprCoreType.LONG))); - List result = execute(plan); assertEquals(3, result.size()); - assertThat(result, containsInAnyOrder( + assertThat( + result, + containsInAnyOrder( tupleValue(ImmutableMap.of("distance", 100, "time", 10)), tupleValue( - ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 166.66666666666663)), + ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 166.66666666666663)), tupleValue( - ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 199.99999999999997))) - ); - - + ImmutableMap.of( + "distance", 200, "time", 10, "distance_alias", 199.99999999999997)))); } @Test public void calculates_weighted_moving_average_data_type_support_float() { - when(inputPlan.hasNext()).thenReturn(true, true, true, false); - when(inputPlan.next()) - .thenReturn( + mockPlanWithData(List.of( tupleValue(ImmutableMap.of("distance", 100, "time", 10)), tupleValue(ImmutableMap.of("distance", 200, "time", 10)), - tupleValue(ImmutableMap.of("distance", 200, "time", 10))); + tupleValue(ImmutableMap.of("distance", 200, "time", 10)))); var plan = new TrendlineOperator( @@ -620,24 +600,23 @@ public void calculates_weighted_moving_average_data_type_support_float() { List result = execute(plan); assertEquals(3, result.size()); - assertThat(result, containsInAnyOrder( + assertThat( + result, + containsInAnyOrder( tupleValue(ImmutableMap.of("distance", 100, "time", 10)), tupleValue( - ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 166.66666666666663)), + ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 166.66666666666663)), tupleValue( - ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 199.99999999999997))) - ); - + ImmutableMap.of( + "distance", 200, "time", 10, "distance_alias", 199.99999999999997)))); } @Test public void calculates_weighted_moving_average_multiple_computations() { - when(inputPlan.hasNext()).thenReturn(true, true, true, false); - when(inputPlan.next()) - .thenReturn( + mockPlanWithData(List.of( tupleValue(ImmutableMap.of("distance", 100, "time", 10)), tupleValue(ImmutableMap.of("distance", 200, "time", 20)), - tupleValue(ImmutableMap.of("distance", 200, "time", 20))); + tupleValue(ImmutableMap.of("distance", 200, "time", 20)))); var plan = new TrendlineOperator( @@ -652,40 +631,38 @@ public void calculates_weighted_moving_average_multiple_computations() { List result = execute(plan); assertEquals(3, result.size()); - assertThat(result, containsInAnyOrder( + assertThat( + result, + containsInAnyOrder( tupleValue(ImmutableMap.of("distance", 100, "time", 10)), tupleValue( - ImmutableMap.of( - "distance", - 200, - "time", - 20, - "distance_alias", - 166.66666666666663, - "time_alias", - 16.666666666666664)), - tupleValue( - ImmutableMap.of( - "distance", - 200, - "time", - 20, - "distance_alias", - 199.99999999999997, - "time_alias", - 20.0))) - ); - + ImmutableMap.of( + "distance", + 200, + "time", + 20, + "distance_alias", + 166.66666666666663, + "time_alias", + 16.666666666666664)), + tupleValue( + ImmutableMap.of( + "distance", + 200, + "time", + 20, + "distance_alias", + 199.99999999999997, + "time_alias", + 20.0)))); } @Test public void calculates_weighted_moving_average_one_field_two_samples_three_rows_null_value() { - when(inputPlan.hasNext()).thenReturn(true, true, true, false); - when(inputPlan.next()) - .thenReturn( + mockPlanWithData(List.of( tupleValue(ImmutableMap.of("time", 10)), tupleValue(ImmutableMap.of("distance", 200, "time", 10)), - tupleValue(ImmutableMap.of("distance", 300, "time", 10))); + tupleValue(ImmutableMap.of("distance", 300, "time", 10)))); var plan = new TrendlineOperator( @@ -695,30 +672,26 @@ public void calculates_weighted_moving_average_one_field_two_samples_three_rows_ AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", WMA), ExprCoreType.DOUBLE))); - List result = execute(plan); assertEquals(3, result.size()); - assertThat(result, containsInAnyOrder( + assertThat( + result, + containsInAnyOrder( tupleValue(ImmutableMap.of("time", 10)), tupleValue(ImmutableMap.of("distance", 200, "time", 10)), tupleValue( - ImmutableMap.of("distance", 300, "time", 10, "distance_alias", 266.66666666666663))) - ); - - + ImmutableMap.of( + "distance", 300, "time", 10, "distance_alias", 266.66666666666663)))); } @Test public void calculates_weighted_moving_average_date() { - when(inputPlan.hasNext()).thenReturn(true, true, true, false); - when(inputPlan.next()) - .thenReturn( - tupleValue( - ImmutableMap.of("date", ExprValueUtils.dateValue(LocalDate.EPOCH))), + mockPlanWithData(List.of( + tupleValue(ImmutableMap.of("date", ExprValueUtils.dateValue(LocalDate.EPOCH))), tupleValue( ImmutableMap.of("date", ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(6)))), tupleValue( - ImmutableMap.of("date", ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(12))))); + ImmutableMap.of("date", ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(12)))))); var plan = new TrendlineOperator( @@ -728,38 +701,34 @@ public void calculates_weighted_moving_average_date() { AstDSL.computation(2, AstDSL.field("date"), "date_alias", WMA), ExprCoreType.DATE))); - List result = execute(plan); assertEquals(3, result.size()); - assertThat(result, containsInAnyOrder( + assertThat( + result, + containsInAnyOrder( tupleValue(ImmutableMap.of("date", ExprValueUtils.dateValue(LocalDate.EPOCH))), - tupleValue(ImmutableMap.of( + tupleValue( + ImmutableMap.of( "date", ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(6)), "date_alias", ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(4)))), tupleValue( - ImmutableMap.of( - "date", - ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(12)), - "date_alias", - ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(10))))) - ); - - + ImmutableMap.of( + "date", + ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(12)), + "date_alias", + ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(10)))))); } @Test public void calculates_weighted_moving_average_time() { - when(inputPlan.hasNext()).thenReturn(true, true, true, false); - when(inputPlan.next()) - .thenReturn( - tupleValue( - ImmutableMap.of("time", ExprValueUtils.timeValue(LocalTime.MIN))), + mockPlanWithData(List.of( + tupleValue(ImmutableMap.of("time", ExprValueUtils.timeValue(LocalTime.MIN))), tupleValue( ImmutableMap.of("time", ExprValueUtils.timeValue(LocalTime.MIN.plusHours(6)))), tupleValue( - ImmutableMap.of("time", ExprValueUtils.timeValue(LocalTime.MIN.plusHours(12))))); + ImmutableMap.of("time", ExprValueUtils.timeValue(LocalTime.MIN.plusHours(12)))))); var plan = new TrendlineOperator( @@ -769,33 +738,33 @@ public void calculates_weighted_moving_average_time() { AstDSL.computation(2, AstDSL.field("time"), "time_alias", WMA), ExprCoreType.TIME))); - List result = execute(plan); assertEquals(3, result.size()); - assertThat(result, containsInAnyOrder( + assertThat( + result, + containsInAnyOrder( tupleValue(ImmutableMap.of("time", LocalTime.MIN)), - tupleValue(ImmutableMap.of( + tupleValue( + ImmutableMap.of( "time", LocalTime.MIN.plusHours(6), "time_alias", LocalTime.MIN.plusHours(4))), tupleValue( - ImmutableMap.of( - "time", LocalTime.MIN.plusHours(12), "time_alias", LocalTime.MIN.plusHours(10)))) - ); - + ImmutableMap.of( + "time", + LocalTime.MIN.plusHours(12), + "time_alias", + LocalTime.MIN.plusHours(10))))); } @Test public void calculates_weighted_moving_average_timestamp() { - when(inputPlan.hasNext()).thenReturn(true, true, true, false); - when(inputPlan.next()) - .thenReturn( - tupleValue( - ImmutableMap.of("timestamp", ExprValueUtils.timestampValue(Instant.EPOCH))), + mockPlanWithData(List.of( + tupleValue(ImmutableMap.of("timestamp", ExprValueUtils.timestampValue(Instant.EPOCH))), tupleValue( ImmutableMap.of( "timestamp", ExprValueUtils.timestampValue(Instant.EPOCH.plusMillis(1000)))), tupleValue( ImmutableMap.of( - "timestamp", ExprValueUtils.timestampValue(Instant.EPOCH.plusMillis(1500))))); + "timestamp", ExprValueUtils.timestampValue(Instant.EPOCH.plusMillis(1500)))))); var plan = new TrendlineOperator( @@ -807,56 +776,72 @@ public void calculates_weighted_moving_average_timestamp() { List result = execute(plan); assertEquals(3, result.size()); - assertThat(result, containsInAnyOrder( + assertThat( + result, + containsInAnyOrder( tupleValue(ImmutableMap.of("timestamp", Instant.EPOCH)), - tupleValue(ImmutableMap.of( + tupleValue( + ImmutableMap.of( "timestamp", Instant.EPOCH.plusMillis(1000), "timestamp_alias", Instant.EPOCH.plusMillis(667))), tupleValue( - ImmutableMap.of( - "timestamp", - Instant.EPOCH.plusMillis(1500), - "timestamp_alias", - Instant.EPOCH.plusMillis(1333)))) - ); - + ImmutableMap.of( + "timestamp", + Instant.EPOCH.plusMillis(1500), + "timestamp_alias", + Instant.EPOCH.plusMillis(1333))))); } @Test public void use_illegal_core_type_wma() { assertThrows( IllegalArgumentException.class, - () -> new TrendlineOperator( - inputPlan, - Collections.singletonList( - Pair.of( - AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", WMA), - ExprCoreType.ARRAY)))); + () -> + new TrendlineOperator( + inputPlan, + Collections.singletonList( + Pair.of( + AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", WMA), + ExprCoreType.ARRAY)))); } @Test public void use_invalid_dataPoints_zero() { assertThrows( - IllegalArgumentException.class, - () -> new TrendlineOperator( - inputPlan, - Collections.singletonList( - Pair.of( - AstDSL.computation(0, AstDSL.field("distance"), "distance_alias", WMA), - ExprCoreType.INTEGER)))); + IllegalArgumentException.class, + () -> + new TrendlineOperator( + inputPlan, + Collections.singletonList( + Pair.of( + AstDSL.computation(0, AstDSL.field("distance"), "distance_alias", WMA), + ExprCoreType.INTEGER)))); } @Test public void use_invalid_dataPoints_negative() { assertThrows( - IllegalArgumentException.class, - () -> new TrendlineOperator( - inputPlan, - Collections.singletonList( - Pair.of( - AstDSL.computation(-100, AstDSL.field("distance"), "distance_alias", WMA), - ExprCoreType.INTEGER)))); + IllegalArgumentException.class, + () -> + new TrendlineOperator( + inputPlan, + Collections.singletonList( + Pair.of( + AstDSL.computation(-100, AstDSL.field("distance"), "distance_alias", WMA), + ExprCoreType.INTEGER)))); + } + + private void mockPlanWithData(List inputs ) { + List hasNextElements = new ArrayList<>(Collections.nCopies(inputs.size(), true)); + hasNextElements.add(false); + + Iterator hasNextIterator = hasNextElements.iterator(); + when(inputPlan.hasNext()) + .thenAnswer(i -> hasNextIterator.hasNext() ? hasNextIterator.next() : null); + Iterator iterator = inputs.iterator(); + when(inputPlan.next()) + .thenAnswer(i -> iterator.hasNext() ? iterator.next() : null); } } From bdb0f2867e5bce775c90a446dc56f2b7c822bd02 Mon Sep 17 00:00:00 2001 From: Andy Kwok Date: Wed, 12 Feb 2025 13:01:53 -0800 Subject: [PATCH 27/36] Update test-cases Signed-off-by: Andy Kwok --- .../sql/planner/physical/TrendlineOperatorTest.java | 6 ------ 1 file changed, 6 deletions(-) diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java index 6acdb14bfa..3552ba0d6e 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java @@ -66,12 +66,6 @@ public void calculates_simple_moving_average_one_field_one_sample() { @Test public void calculates_simple_moving_average_one_field_two_samples() { -// when(inputPlan.hasNext()).thenReturn(true, true, false); -// when(inputPlan.next()) -// .thenReturn( -// tupleValue(ImmutableMap.of("distance", 100, "time", 10)), -// tupleValue(ImmutableMap.of("distance", 200, "time", 10))); - mockPlanWithData(List.of( tupleValue(ImmutableMap.of("distance", 100, "time", 10)), tupleValue(ImmutableMap.of("distance", 200, "time", 10)) From fa685bdb70a6b0518fcd0395424f3a1bb9ca462b Mon Sep 17 00:00:00 2001 From: Andy Kwok Date: Wed, 12 Feb 2025 13:02:22 -0800 Subject: [PATCH 28/36] Update test-cases Signed-off-by: Andy Kwok --- .../physical/TrendlineOperatorTest.java | 87 +++++++++++-------- 1 file changed, 52 insertions(+), 35 deletions(-) diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java index 3552ba0d6e..a0c5e54c0f 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java @@ -24,8 +24,6 @@ import java.util.Collections; import java.util.Iterator; import java.util.List; -import java.util.Map; - import org.apache.commons.lang3.tuple.Pair; import org.junit.jupiter.api.DisplayNameGeneration; import org.junit.jupiter.api.DisplayNameGenerator; @@ -46,8 +44,7 @@ public class TrendlineOperatorTest extends PhysicalPlanTestBase { @Test public void calculates_simple_moving_average_one_field_one_sample() { - mockPlanWithData(List.of( - tupleValue(ImmutableMap.of("distance", 100, "time", 10)))); + mockPlanWithData(List.of(tupleValue(ImmutableMap.of("distance", 100, "time", 10)))); var plan = new TrendlineOperator( inputPlan, @@ -66,10 +63,10 @@ public void calculates_simple_moving_average_one_field_one_sample() { @Test public void calculates_simple_moving_average_one_field_two_samples() { - mockPlanWithData(List.of( + mockPlanWithData( + List.of( tupleValue(ImmutableMap.of("distance", 100, "time", 10)), - tupleValue(ImmutableMap.of("distance", 200, "time", 10)) - )); + tupleValue(ImmutableMap.of("distance", 200, "time", 10)))); var plan = new TrendlineOperator( @@ -90,7 +87,8 @@ public void calculates_simple_moving_average_one_field_two_samples() { @Test public void calculates_simple_moving_average_one_field_two_samples_three_rows() { - mockPlanWithData(List.of( + mockPlanWithData( + List.of( tupleValue(ImmutableMap.of("distance", 100, "time", 10)), tupleValue(ImmutableMap.of("distance", 200, "time", 10)), tupleValue(ImmutableMap.of("distance", 200, "time", 10)))); @@ -115,7 +113,8 @@ public void calculates_simple_moving_average_one_field_two_samples_three_rows() @Test public void calculates_simple_moving_average_data_type_support_short() { - mockPlanWithData(List.of( + mockPlanWithData( + List.of( tupleValue(ImmutableMap.of("distance", 100, "time", 10)), tupleValue(ImmutableMap.of("distance", 200, "time", 10)), tupleValue(ImmutableMap.of("distance", 200, "time", 10)))); @@ -140,7 +139,8 @@ public void calculates_simple_moving_average_data_type_support_short() { @Test public void calculates_simple_moving_average_data_type_support_long() { - mockPlanWithData(List.of( + mockPlanWithData( + List.of( tupleValue(ImmutableMap.of("distance", 100, "time", 10)), tupleValue(ImmutableMap.of("distance", 200, "time", 10)), tupleValue(ImmutableMap.of("distance", 200, "time", 10)))); @@ -165,7 +165,8 @@ public void calculates_simple_moving_average_data_type_support_long() { @Test public void calculates_simple_moving_average_data_type_support_float() { - mockPlanWithData(List.of( + mockPlanWithData( + List.of( tupleValue(ImmutableMap.of("distance", 100, "time", 10)), tupleValue(ImmutableMap.of("distance", 200, "time", 10)), tupleValue(ImmutableMap.of("distance", 200, "time", 10)))); @@ -190,7 +191,8 @@ public void calculates_simple_moving_average_data_type_support_float() { @Test public void calculates_simple_moving_average_multiple_computations() { - mockPlanWithData(List.of( + mockPlanWithData( + List.of( tupleValue(ImmutableMap.of("distance", 100, "time", 10)), tupleValue(ImmutableMap.of("distance", 200, "time", 20)), tupleValue(ImmutableMap.of("distance", 200, "time", 20)))); @@ -222,7 +224,8 @@ public void calculates_simple_moving_average_multiple_computations() { @Test public void alias_overwrites_input_field() { - mockPlanWithData(List.of( + mockPlanWithData( + List.of( tupleValue(ImmutableMap.of("distance", 100, "time", 10)), tupleValue(ImmutableMap.of("distance", 200, "time", 10)), tupleValue(ImmutableMap.of("distance", 200, "time", 10)))); @@ -247,7 +250,8 @@ public void alias_overwrites_input_field() { @Test public void calculates_simple_moving_average_one_field_two_samples_three_rows_null_value() { - mockPlanWithData(List.of( + mockPlanWithData( + List.of( tupleValue(ImmutableMap.of("time", 10)), tupleValue(ImmutableMap.of("distance", 200, "time", 10)), tupleValue(ImmutableMap.of("distance", 300, "time", 10)))); @@ -272,7 +276,8 @@ public void calculates_simple_moving_average_one_field_two_samples_three_rows_nu @Test public void use_null_value() { - mockPlanWithData(List.of( + mockPlanWithData( + List.of( tupleValue(ImmutableMap.of("time", 10)), tupleValue(ImmutableMap.of("distance", ExprNullValue.of(), "time", 10)), tupleValue(ImmutableMap.of("distance", 100, "time", 10)))); @@ -311,7 +316,8 @@ public void use_illegal_core_type() { @Test public void calculates_simple_moving_average_date() { - mockPlanWithData(List.of( + mockPlanWithData( + List.of( tupleValue(ImmutableMap.of("date", ExprValueUtils.dateValue(LocalDate.EPOCH))), tupleValue( ImmutableMap.of("date", ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(6)))), @@ -348,7 +354,8 @@ public void calculates_simple_moving_average_date() { @Test public void calculates_simple_moving_average_time() { - mockPlanWithData(List.of( + mockPlanWithData( + List.of( tupleValue(ImmutableMap.of("time", ExprValueUtils.timeValue(LocalTime.MIN))), tupleValue( ImmutableMap.of("time", ExprValueUtils.timeValue(LocalTime.MIN.plusHours(6)))), @@ -382,7 +389,8 @@ public void calculates_simple_moving_average_time() { @Test public void calculates_simple_moving_average_timestamp() { - mockPlanWithData(List.of( + mockPlanWithData( + List.of( tupleValue(ImmutableMap.of("timestamp", ExprValueUtils.timestampValue(Instant.EPOCH))), tupleValue( ImmutableMap.of( @@ -421,8 +429,7 @@ public void calculates_simple_moving_average_timestamp() { @Test public void calculates_weighted_moving_average_one_field_one_sample() { - mockPlanWithData(List.of( - tupleValue(ImmutableMap.of("distance", 100, "time", 10)))); + mockPlanWithData(List.of(tupleValue(ImmutableMap.of("distance", 100, "time", 10)))); var plan = new TrendlineOperator( @@ -442,7 +449,8 @@ public void calculates_weighted_moving_average_one_field_one_sample() { @Test public void calculates_weighted_moving_average_one_field_two_samples() { - mockPlanWithData(List.of( + mockPlanWithData( + List.of( tupleValue(ImmutableMap.of("distance", 100, "time", 10)), tupleValue(ImmutableMap.of("distance", 200, "time", 10)))); @@ -467,7 +475,8 @@ public void calculates_weighted_moving_average_one_field_two_samples() { @Test public void calculates_weighted_moving_average_one_field_two_samples_three_rows() { - mockPlanWithData(List.of( + mockPlanWithData( + List.of( tupleValue(ImmutableMap.of("distance", 100, "time", 10)), tupleValue(ImmutableMap.of("distance", 200, "time", 10)), tupleValue(ImmutableMap.of("distance", 200, "time", 10)))); @@ -495,7 +504,8 @@ public void calculates_weighted_moving_average_one_field_two_samples_three_rows( @Test public void calculates_weighted_moving_average_data_type_support_short() { - mockPlanWithData(List.of( + mockPlanWithData( + List.of( tupleValue(ImmutableMap.of("distance", 100, "time", 10)), tupleValue(ImmutableMap.of("distance", 200, "time", 10)), tupleValue(ImmutableMap.of("distance", 200, "time", 10)))); @@ -523,7 +533,8 @@ public void calculates_weighted_moving_average_data_type_support_short() { @Test public void calculates_weighted_moving_average_data_type_support_integer() { - mockPlanWithData(List.of( + mockPlanWithData( + List.of( tupleValue(ImmutableMap.of("distance", 100, "time", 10)), tupleValue(ImmutableMap.of("distance", 200, "time", 10)), tupleValue(ImmutableMap.of("distance", 200, "time", 10)))); @@ -551,7 +562,8 @@ public void calculates_weighted_moving_average_data_type_support_integer() { @Test public void calculates_weighted_moving_average_data_type_support_long() { - mockPlanWithData(List.of( + mockPlanWithData( + List.of( tupleValue(ImmutableMap.of("distance", 100, "time", 10)), tupleValue(ImmutableMap.of("distance", 200, "time", 10)), tupleValue(ImmutableMap.of("distance", 200, "time", 10)))); @@ -579,7 +591,8 @@ public void calculates_weighted_moving_average_data_type_support_long() { @Test public void calculates_weighted_moving_average_data_type_support_float() { - mockPlanWithData(List.of( + mockPlanWithData( + List.of( tupleValue(ImmutableMap.of("distance", 100, "time", 10)), tupleValue(ImmutableMap.of("distance", 200, "time", 10)), tupleValue(ImmutableMap.of("distance", 200, "time", 10)))); @@ -607,7 +620,8 @@ public void calculates_weighted_moving_average_data_type_support_float() { @Test public void calculates_weighted_moving_average_multiple_computations() { - mockPlanWithData(List.of( + mockPlanWithData( + List.of( tupleValue(ImmutableMap.of("distance", 100, "time", 10)), tupleValue(ImmutableMap.of("distance", 200, "time", 20)), tupleValue(ImmutableMap.of("distance", 200, "time", 20)))); @@ -653,7 +667,8 @@ public void calculates_weighted_moving_average_multiple_computations() { @Test public void calculates_weighted_moving_average_one_field_two_samples_three_rows_null_value() { - mockPlanWithData(List.of( + mockPlanWithData( + List.of( tupleValue(ImmutableMap.of("time", 10)), tupleValue(ImmutableMap.of("distance", 200, "time", 10)), tupleValue(ImmutableMap.of("distance", 300, "time", 10)))); @@ -680,7 +695,8 @@ public void calculates_weighted_moving_average_one_field_two_samples_three_rows_ @Test public void calculates_weighted_moving_average_date() { - mockPlanWithData(List.of( + mockPlanWithData( + List.of( tupleValue(ImmutableMap.of("date", ExprValueUtils.dateValue(LocalDate.EPOCH))), tupleValue( ImmutableMap.of("date", ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(6)))), @@ -717,7 +733,8 @@ public void calculates_weighted_moving_average_date() { @Test public void calculates_weighted_moving_average_time() { - mockPlanWithData(List.of( + mockPlanWithData( + List.of( tupleValue(ImmutableMap.of("time", ExprValueUtils.timeValue(LocalTime.MIN))), tupleValue( ImmutableMap.of("time", ExprValueUtils.timeValue(LocalTime.MIN.plusHours(6)))), @@ -751,7 +768,8 @@ public void calculates_weighted_moving_average_time() { @Test public void calculates_weighted_moving_average_timestamp() { - mockPlanWithData(List.of( + mockPlanWithData( + List.of( tupleValue(ImmutableMap.of("timestamp", ExprValueUtils.timestampValue(Instant.EPOCH))), tupleValue( ImmutableMap.of( @@ -827,15 +845,14 @@ public void use_invalid_dataPoints_negative() { ExprCoreType.INTEGER)))); } - private void mockPlanWithData(List inputs ) { + private void mockPlanWithData(List inputs) { List hasNextElements = new ArrayList<>(Collections.nCopies(inputs.size(), true)); hasNextElements.add(false); Iterator hasNextIterator = hasNextElements.iterator(); when(inputPlan.hasNext()) - .thenAnswer(i -> hasNextIterator.hasNext() ? hasNextIterator.next() : null); + .thenAnswer(i -> hasNextIterator.hasNext() ? hasNextIterator.next() : null); Iterator iterator = inputs.iterator(); - when(inputPlan.next()) - .thenAnswer(i -> iterator.hasNext() ? iterator.next() : null); + when(inputPlan.next()).thenAnswer(i -> iterator.hasNext() ? iterator.next() : null); } } From b87db9f84dc695f1538c01a52a0e6444b462212d Mon Sep 17 00:00:00 2001 From: Andy Kwok Date: Thu, 13 Feb 2025 12:11:36 -0800 Subject: [PATCH 29/36] Address code comments Signed-off-by: Andy Kwok --- .../sql/planner/physical/TrendlineOperator.java | 12 ++++++------ .../sql/planner/physical/TrendlineOperatorTest.java | 9 +++++---- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java index 8e3e70e74f..cc1bcf6290 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java @@ -6,7 +6,7 @@ package org.opensearch.sql.planner.physical; import static java.time.temporal.ChronoUnit.MILLIS; -import static java.util.stream.Collectors.*; +import static java.util.stream.Collectors.toList; import com.google.common.collect.EvictingQueue; import com.google.common.collect.ImmutableMap.Builder; @@ -31,6 +31,7 @@ import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.model.ExprValueUtils; import org.opensearch.sql.data.type.ExprCoreType; +import org.opensearch.sql.exception.SemanticCheckException; import org.opensearch.sql.expression.DSL; import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.LiteralExpression; @@ -120,7 +121,7 @@ private static TrendlineAccumulator createAccumulator( } /** Maintains stateful information for calculating the trendline. */ - private abstract static class TrendlineAccumulator { + protected abstract static class TrendlineAccumulator { protected final LiteralExpression dataPointsNeeded; @@ -129,10 +130,9 @@ private abstract static class TrendlineAccumulator { private TrendlineAccumulator(Trendline.TrendlineComputation config) { Integer numberOfDataPoints = config.getNumberOfDataPoints(); if (numberOfDataPoints <= 0) { - throw new IllegalArgumentException( + throw new SemanticCheckException( String.format("Invalid dataPoints [%d] value.", numberOfDataPoints)); } - ; this.dataPointsNeeded = DSL.literal(numberOfDataPoints.doubleValue()); this.receivedValues = EvictingQueue.create(numberOfDataPoints); } @@ -198,7 +198,7 @@ static ArithmeticEvaluator getEvaluator(ExprCoreType type) { case DATE -> DateArithmeticEvaluator.INSTANCE; case TIME -> TimeArithmeticEvaluator.INSTANCE; case TIMESTAMP -> TimestampArithmeticEvaluator.INSTANCE; - default -> throw new IllegalArgumentException( + default -> throw new SemanticCheckException( String.format("Invalid type %s used for moving average.", type.typeName())); }; } @@ -357,7 +357,7 @@ Function, ExprValue> getWmaEvaluator(ExprCoreType type) { case INTEGER, SHORT, LONG, FLOAT, DOUBLE -> WMA_NUMERIC_EVALUATOR; case DATE, TIMESTAMP -> WMA_TIMESTAMP_EVALUATOR; case TIME -> WMA_TIME_EVALUATOR; - default -> throw new IllegalArgumentException( + default -> throw new SemanticCheckException( String.format("Invalid type %s used for weighted moving average.", type.typeName())); }; } diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java index a0c5e54c0f..7dcdbf535b 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java @@ -36,6 +36,7 @@ import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.model.ExprValueUtils; import org.opensearch.sql.data.type.ExprCoreType; +import org.opensearch.sql.exception.SemanticCheckException; @DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) @ExtendWith(MockitoExtension.class) @@ -303,7 +304,7 @@ public void use_null_value() { @Test public void use_illegal_core_type() { assertThrows( - IllegalArgumentException.class, + SemanticCheckException.class, () -> { new TrendlineOperator( inputPlan, @@ -809,7 +810,7 @@ public void calculates_weighted_moving_average_timestamp() { @Test public void use_illegal_core_type_wma() { assertThrows( - IllegalArgumentException.class, + SemanticCheckException.class, () -> new TrendlineOperator( inputPlan, @@ -822,7 +823,7 @@ public void use_illegal_core_type_wma() { @Test public void use_invalid_dataPoints_zero() { assertThrows( - IllegalArgumentException.class, + SemanticCheckException.class, () -> new TrendlineOperator( inputPlan, @@ -835,7 +836,7 @@ public void use_invalid_dataPoints_zero() { @Test public void use_invalid_dataPoints_negative() { assertThrows( - IllegalArgumentException.class, + SemanticCheckException.class, () -> new TrendlineOperator( inputPlan, From 39bf1130924cbe867378e499db8e8181d878c326 Mon Sep 17 00:00:00 2001 From: Andy Kwok Date: Thu, 13 Feb 2025 13:28:43 -0800 Subject: [PATCH 30/36] Update test-cases Signed-off-by: Andy Kwok --- .../physical/TrendlineOperatorTest.java | 323 +++++------------- 1 file changed, 88 insertions(+), 235 deletions(-) diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java index 7dcdbf535b..5ab6b93dcc 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java @@ -24,14 +24,20 @@ import java.util.Collections; import java.util.Iterator; import java.util.List; +import java.util.stream.Stream; import org.apache.commons.lang3.tuple.Pair; import org.junit.jupiter.api.DisplayNameGeneration; import org.junit.jupiter.api.DisplayNameGenerator; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.sql.ast.dsl.AstDSL; +import org.opensearch.sql.ast.expression.Field; +import org.opensearch.sql.ast.tree.Trendline; import org.opensearch.sql.data.model.ExprNullValue; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.model.ExprValueUtils; @@ -43,6 +49,47 @@ public class TrendlineOperatorTest extends PhysicalPlanTestBase { @Mock private PhysicalPlan inputPlan; + static Stream supportedDataTypes() { + return Stream.of(SMA, WMA) + .flatMap( + trendlineType -> + Stream.of( + Arguments.of(trendlineType, ExprCoreType.SHORT), + Arguments.of(trendlineType, ExprCoreType.INTEGER), + Arguments.of(trendlineType, ExprCoreType.LONG), + Arguments.of(trendlineType, ExprCoreType.FLOAT), + Arguments.of(trendlineType, ExprCoreType.DOUBLE))); + } + + static Stream invalidArguments() { + return Stream.of(SMA, WMA) + .flatMap( + trendlineType -> + Stream.of( + // WMA + Arguments.of( + 2, + AstDSL.field("distance"), + "distance_alias", + trendlineType, + ExprCoreType.ARRAY, + "DateType - Array"), + Arguments.of( + -100, + AstDSL.field("distance"), + "distance_alias", + trendlineType, + ExprCoreType.INTEGER, + "DataPoints - Negative"), + Arguments.of( + 0, + AstDSL.field("distance"), + "distance_alias", + trendlineType, + ExprCoreType.INTEGER, + "DataPoints - zero"))); + } + @Test public void calculates_simple_moving_average_one_field_one_sample() { mockPlanWithData(List.of(tupleValue(ImmutableMap.of("distance", 100, "time", 10)))); @@ -112,84 +159,6 @@ public void calculates_simple_moving_average_one_field_two_samples_three_rows() tupleValue(ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 200.0)))); } - @Test - public void calculates_simple_moving_average_data_type_support_short() { - mockPlanWithData( - List.of( - tupleValue(ImmutableMap.of("distance", 100, "time", 10)), - tupleValue(ImmutableMap.of("distance", 200, "time", 10)), - tupleValue(ImmutableMap.of("distance", 200, "time", 10)))); - - var plan = - new TrendlineOperator( - inputPlan, - Collections.singletonList( - Pair.of( - AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", SMA), - ExprCoreType.SHORT))); - - List result = execute(plan); - assertEquals(3, result.size()); - assertThat( - result, - containsInAnyOrder( - tupleValue(ImmutableMap.of("distance", 100, "time", 10)), - tupleValue(ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 150.0)), - tupleValue(ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 200.0)))); - } - - @Test - public void calculates_simple_moving_average_data_type_support_long() { - mockPlanWithData( - List.of( - tupleValue(ImmutableMap.of("distance", 100, "time", 10)), - tupleValue(ImmutableMap.of("distance", 200, "time", 10)), - tupleValue(ImmutableMap.of("distance", 200, "time", 10)))); - - var plan = - new TrendlineOperator( - inputPlan, - Collections.singletonList( - Pair.of( - AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", SMA), - ExprCoreType.SHORT))); - - List result = execute(plan); - assertEquals(3, result.size()); - assertThat( - result, - containsInAnyOrder( - tupleValue(ImmutableMap.of("distance", 100, "time", 10)), - tupleValue(ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 150.0)), - tupleValue(ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 200.0)))); - } - - @Test - public void calculates_simple_moving_average_data_type_support_float() { - mockPlanWithData( - List.of( - tupleValue(ImmutableMap.of("distance", 100, "time", 10)), - tupleValue(ImmutableMap.of("distance", 200, "time", 10)), - tupleValue(ImmutableMap.of("distance", 200, "time", 10)))); - - var plan = - new TrendlineOperator( - inputPlan, - Collections.singletonList( - Pair.of( - AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", SMA), - ExprCoreType.FLOAT))); - - List result = execute(plan); - assertEquals(3, result.size()); - assertThat( - result, - containsInAnyOrder( - tupleValue(ImmutableMap.of("distance", 100, "time", 10)), - tupleValue(ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 150.0)), - tupleValue(ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 200.0)))); - } - @Test public void calculates_simple_moving_average_multiple_computations() { mockPlanWithData( @@ -301,20 +270,6 @@ public void use_null_value() { tupleValue(ImmutableMap.of("distance", 100, "time", 10, "distance_alias", 100)))); } - @Test - public void use_illegal_core_type() { - assertThrows( - SemanticCheckException.class, - () -> { - new TrendlineOperator( - inputPlan, - Collections.singletonList( - Pair.of( - AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", SMA), - ExprCoreType.ARRAY))); - }); - } - @Test public void calculates_simple_moving_average_date() { mockPlanWithData( @@ -503,122 +458,6 @@ public void calculates_weighted_moving_average_one_field_two_samples_three_rows( "distance", 200, "time", 10, "distance_alias", 199.99999999999997)))); } - @Test - public void calculates_weighted_moving_average_data_type_support_short() { - mockPlanWithData( - List.of( - tupleValue(ImmutableMap.of("distance", 100, "time", 10)), - tupleValue(ImmutableMap.of("distance", 200, "time", 10)), - tupleValue(ImmutableMap.of("distance", 200, "time", 10)))); - - var plan = - new TrendlineOperator( - inputPlan, - Collections.singletonList( - Pair.of( - AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", WMA), - ExprCoreType.SHORT))); - - List result = execute(plan); - assertEquals(3, result.size()); - assertThat( - result, - containsInAnyOrder( - tupleValue(ImmutableMap.of("distance", 100, "time", 10)), - tupleValue( - ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 166.66666666666663)), - tupleValue( - ImmutableMap.of( - "distance", 200, "time", 10, "distance_alias", 199.99999999999997)))); - } - - @Test - public void calculates_weighted_moving_average_data_type_support_integer() { - mockPlanWithData( - List.of( - tupleValue(ImmutableMap.of("distance", 100, "time", 10)), - tupleValue(ImmutableMap.of("distance", 200, "time", 10)), - tupleValue(ImmutableMap.of("distance", 200, "time", 10)))); - - var plan = - new TrendlineOperator( - inputPlan, - Collections.singletonList( - Pair.of( - AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", WMA), - ExprCoreType.INTEGER))); - - List result = execute(plan); - assertEquals(3, result.size()); - assertThat( - result, - containsInAnyOrder( - tupleValue(ImmutableMap.of("distance", 100, "time", 10)), - tupleValue( - ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 166.66666666666663)), - tupleValue( - ImmutableMap.of( - "distance", 200, "time", 10, "distance_alias", 199.99999999999997)))); - } - - @Test - public void calculates_weighted_moving_average_data_type_support_long() { - mockPlanWithData( - List.of( - tupleValue(ImmutableMap.of("distance", 100, "time", 10)), - tupleValue(ImmutableMap.of("distance", 200, "time", 10)), - tupleValue(ImmutableMap.of("distance", 200, "time", 10)))); - - var plan = - new TrendlineOperator( - inputPlan, - Collections.singletonList( - Pair.of( - AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", WMA), - ExprCoreType.LONG))); - - List result = execute(plan); - assertEquals(3, result.size()); - assertThat( - result, - containsInAnyOrder( - tupleValue(ImmutableMap.of("distance", 100, "time", 10)), - tupleValue( - ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 166.66666666666663)), - tupleValue( - ImmutableMap.of( - "distance", 200, "time", 10, "distance_alias", 199.99999999999997)))); - } - - @Test - public void calculates_weighted_moving_average_data_type_support_float() { - mockPlanWithData( - List.of( - tupleValue(ImmutableMap.of("distance", 100, "time", 10)), - tupleValue(ImmutableMap.of("distance", 200, "time", 10)), - tupleValue(ImmutableMap.of("distance", 200, "time", 10)))); - - var plan = - new TrendlineOperator( - inputPlan, - Collections.singletonList( - Pair.of( - AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", WMA), - ExprCoreType.FLOAT))); - - List result = execute(plan); - assertEquals(3, result.size()); - assertThat( - result, - containsInAnyOrder( - tupleValue(ImmutableMap.of("distance", 100, "time", 10)), - tupleValue( - ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 166.66666666666663)), - tupleValue( - ImmutableMap.of( - "distance", 200, "time", 10, "distance_alias", 199.99999999999997)))); - } - @Test public void calculates_weighted_moving_average_multiple_computations() { mockPlanWithData( @@ -807,34 +646,48 @@ public void calculates_weighted_moving_average_timestamp() { Instant.EPOCH.plusMillis(1333))))); } - @Test - public void use_illegal_core_type_wma() { - assertThrows( - SemanticCheckException.class, - () -> - new TrendlineOperator( - inputPlan, - Collections.singletonList( - Pair.of( - AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", WMA), - ExprCoreType.ARRAY)))); - } + @ParameterizedTest + @MethodSource("supportedDataTypes") + public void trendLine_dataType_support( + Trendline.TrendlineType trendlineType, ExprCoreType supportedType) { + mockPlanWithData( + List.of( + tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)))); - @Test - public void use_invalid_dataPoints_zero() { - assertThrows( - SemanticCheckException.class, - () -> - new TrendlineOperator( - inputPlan, - Collections.singletonList( - Pair.of( - AstDSL.computation(0, AstDSL.field("distance"), "distance_alias", WMA), - ExprCoreType.INTEGER)))); + var plan = + new TrendlineOperator( + inputPlan, + Collections.singletonList( + Pair.of( + AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", WMA), + supportedType))); + + List result = execute(plan); + assertEquals(3, result.size()); + assertThat( + String.format( + "Assertion error on TrendLine-WMA dataType support: %s", supportedType.typeName()), + result, + containsInAnyOrder( + tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + tupleValue( + ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 166.66666666666663)), + tupleValue( + ImmutableMap.of( + "distance", 200, "time", 10, "distance_alias", 199.99999999999997)))); } - @Test - public void use_invalid_dataPoints_negative() { + @ParameterizedTest + @MethodSource("invalidArguments") + public void use_invalid_configuration( + Integer dataPoints, + Field field, + String alias, + Trendline.TrendlineType trendlineType, + ExprCoreType dataType, + String errorMessage) { assertThrows( SemanticCheckException.class, () -> @@ -842,8 +695,8 @@ public void use_invalid_dataPoints_negative() { inputPlan, Collections.singletonList( Pair.of( - AstDSL.computation(-100, AstDSL.field("distance"), "distance_alias", WMA), - ExprCoreType.INTEGER)))); + AstDSL.computation(dataPoints, field, alias, trendlineType), dataType))), + "Unsupported arguments: " + errorMessage); } private void mockPlanWithData(List inputs) { From ba9bfb6f4ad357a53b256e56df6fb5ca3d8bc85c Mon Sep 17 00:00:00 2001 From: Andy Kwok Date: Thu, 13 Feb 2025 15:46:52 -0800 Subject: [PATCH 31/36] Update test-cases Signed-off-by: Andy Kwok --- .../planner/physical/TrendlineOperator.java | 2 +- .../physical/TrendlineOperatorTest.java | 159 +++++++++--------- 2 files changed, 77 insertions(+), 84 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java index cc1bcf6290..a4191004a6 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java @@ -133,7 +133,7 @@ private TrendlineAccumulator(Trendline.TrendlineComputation config) { throw new SemanticCheckException( String.format("Invalid dataPoints [%d] value.", numberOfDataPoints)); } - this.dataPointsNeeded = DSL.literal(numberOfDataPoints.doubleValue()); + this.dataPointsNeeded = DSL.literal(numberOfDataPoints); this.receivedValues = EvictingQueue.create(numberOfDataPoints); } diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java index 5ab6b93dcc..7c500e74ef 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java @@ -50,15 +50,12 @@ public class TrendlineOperatorTest extends PhysicalPlanTestBase { @Mock private PhysicalPlan inputPlan; static Stream supportedDataTypes() { - return Stream.of(SMA, WMA) - .flatMap( - trendlineType -> - Stream.of( - Arguments.of(trendlineType, ExprCoreType.SHORT), - Arguments.of(trendlineType, ExprCoreType.INTEGER), - Arguments.of(trendlineType, ExprCoreType.LONG), - Arguments.of(trendlineType, ExprCoreType.FLOAT), - Arguments.of(trendlineType, ExprCoreType.DOUBLE))); + return Stream.of( + Arguments.of(ExprCoreType.SHORT), + Arguments.of(ExprCoreType.INTEGER), + Arguments.of(ExprCoreType.LONG), + Arguments.of(ExprCoreType.FLOAT), + Arguments.of(ExprCoreType.DOUBLE)); } static Stream invalidArguments() { @@ -404,37 +401,12 @@ public void calculates_weighted_moving_average_one_field_one_sample() { } @Test - public void calculates_weighted_moving_average_one_field_two_samples() { - mockPlanWithData( - List.of( - tupleValue(ImmutableMap.of("distance", 100, "time", 10)), - tupleValue(ImmutableMap.of("distance", 200, "time", 10)))); - - var plan = - new TrendlineOperator( - inputPlan, - Collections.singletonList( - Pair.of( - AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", WMA), - ExprCoreType.DOUBLE))); - - List result = execute(plan); - assertEquals(2, result.size()); - assertThat( - result, - containsInAnyOrder( - tupleValue(ImmutableMap.of("distance", 100, "time", 10)), - tupleValue( - ImmutableMap.of( - "distance", 200, "time", 10, "distance_alias", 166.66666666666663)))); - } - - @Test - public void calculates_weighted_moving_average_one_field_two_samples_three_rows() { + public void calculates_weighted_moving_average_one_field_four_samples_four_rows() { mockPlanWithData( List.of( tupleValue(ImmutableMap.of("distance", 100, "time", 10)), tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)), tupleValue(ImmutableMap.of("distance", 200, "time", 10)))); var plan = @@ -442,20 +414,19 @@ public void calculates_weighted_moving_average_one_field_two_samples_three_rows( inputPlan, Collections.singletonList( Pair.of( - AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", WMA), + AstDSL.computation(4, AstDSL.field("distance"), "distance_alias", WMA), ExprCoreType.DOUBLE))); List result = execute(plan); - assertEquals(3, result.size()); + assertEquals(4, result.size()); assertThat( result, containsInAnyOrder( - tupleValue(ImmutableMap.of("distance", 100, "time", 10)), - tupleValue( - ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 166.66666666666663)), - tupleValue( - ImmutableMap.of( - "distance", 200, "time", 10, "distance_alias", 199.99999999999997)))); + tupleValue( ImmutableMap.of("distance", 100, "time", 10)), + tupleValue( ImmutableMap.of("distance", 200, "time", 10)), + tupleValue( ImmutableMap.of("distance", 200, "time", 10)), + tupleValue( ImmutableMap.of("distance", 200, "time", 10, + "distance_alias", 190)))); } @Test @@ -464,6 +435,7 @@ public void calculates_weighted_moving_average_multiple_computations() { List.of( tupleValue(ImmutableMap.of("distance", 100, "time", 10)), tupleValue(ImmutableMap.of("distance", 200, "time", 20)), + tupleValue(ImmutableMap.of("distance", 200, "time", 20)), tupleValue(ImmutableMap.of("distance", 200, "time", 20)))); var plan = @@ -471,46 +443,32 @@ public void calculates_weighted_moving_average_multiple_computations() { inputPlan, Arrays.asList( Pair.of( - AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", WMA), + AstDSL.computation(4, AstDSL.field("distance"), "distance_alias", WMA), ExprCoreType.DOUBLE), Pair.of( - AstDSL.computation(2, AstDSL.field("time"), "time_alias", WMA), + AstDSL.computation(4, AstDSL.field("time"), "time_alias", WMA), ExprCoreType.DOUBLE))); List result = execute(plan); - assertEquals(3, result.size()); + assertEquals(4, result.size()); assertThat( result, containsInAnyOrder( tupleValue(ImmutableMap.of("distance", 100, "time", 10)), - tupleValue( - ImmutableMap.of( - "distance", - 200, - "time", - 20, - "distance_alias", - 166.66666666666663, - "time_alias", - 16.666666666666664)), - tupleValue( - ImmutableMap.of( - "distance", - 200, - "time", - 20, - "distance_alias", - 199.99999999999997, - "time_alias", - 20.0)))); + tupleValue(ImmutableMap.of("distance", 200, "time", 20)), + tupleValue(ImmutableMap.of("distance", 200, "time", 20)), + tupleValue(ImmutableMap.of("distance", 200, "time", 20, + "distance_alias", 190, "time_alias", 19.0)))); } @Test - public void calculates_weighted_moving_average_one_field_two_samples_three_rows_null_value() { + public void calculates_weighted_moving_average_null_value() { mockPlanWithData( List.of( tupleValue(ImmutableMap.of("time", 10)), tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + tupleValue(ImmutableMap.of("distance", 300, "time", 10)), + tupleValue(ImmutableMap.of("distance", 300, "time", 10)), tupleValue(ImmutableMap.of("distance", 300, "time", 10)))); var plan = @@ -518,19 +476,20 @@ public void calculates_weighted_moving_average_one_field_two_samples_three_rows_ inputPlan, Collections.singletonList( Pair.of( - AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", WMA), + AstDSL.computation(4, AstDSL.field("distance"), "distance_alias", WMA), ExprCoreType.DOUBLE))); List result = execute(plan); - assertEquals(3, result.size()); + assertEquals(5, result.size()); assertThat( result, containsInAnyOrder( tupleValue(ImmutableMap.of("time", 10)), tupleValue(ImmutableMap.of("distance", 200, "time", 10)), - tupleValue( - ImmutableMap.of( - "distance", 300, "time", 10, "distance_alias", 266.66666666666663)))); + tupleValue(ImmutableMap.of("distance", 300, "time", 10)), + tupleValue(ImmutableMap.of("distance", 300, "time", 10)), + tupleValue(ImmutableMap.of("distance", 300, "time", 10, + "distance_alias", 290)))); } @Test @@ -646,14 +605,48 @@ public void calculates_weighted_moving_average_timestamp() { Instant.EPOCH.plusMillis(1333))))); } + @ParameterizedTest @MethodSource("supportedDataTypes") - public void trendLine_dataType_support( - Trendline.TrendlineType trendlineType, ExprCoreType supportedType) { + public void trendLine_dataType_support_sma(ExprCoreType supportedType) { + mockPlanWithData( + List.of( + tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)))); + + var plan = + new TrendlineOperator( + inputPlan, + Collections.singletonList( + Pair.of( + AstDSL.computation(4, AstDSL.field("distance"), "distance_alias", SMA), + supportedType))); + + List result = execute(plan); + System.out.println(result); + assertEquals(4, result.size()); + assertThat( + String.format( + "Assertion error on TrendLine-WMA dataType support: %s", supportedType.typeName()), + result, + containsInAnyOrder( + tupleValue( ImmutableMap.of("distance", 100, "time", 10)), + tupleValue( ImmutableMap.of("distance", 200, "time", 10)), + tupleValue( ImmutableMap.of("distance", 200, "time", 10)), + tupleValue( ImmutableMap.of("distance", 200, "time", 10, + "distance_alias", 175)))); + } + + @ParameterizedTest + @MethodSource("supportedDataTypes") + public void trendLine_dataType_support_wma(ExprCoreType supportedType) { mockPlanWithData( List.of( tupleValue(ImmutableMap.of("distance", 100, "time", 10)), tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)), tupleValue(ImmutableMap.of("distance", 200, "time", 10)))); var plan = @@ -661,22 +654,22 @@ public void trendLine_dataType_support( inputPlan, Collections.singletonList( Pair.of( - AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", WMA), + AstDSL.computation(4, AstDSL.field("distance"), "distance_alias", WMA), supportedType))); List result = execute(plan); - assertEquals(3, result.size()); + System.out.println(result); + assertEquals(4, result.size()); assertThat( String.format( "Assertion error on TrendLine-WMA dataType support: %s", supportedType.typeName()), result, containsInAnyOrder( - tupleValue(ImmutableMap.of("distance", 100, "time", 10)), - tupleValue( - ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 166.66666666666663)), - tupleValue( - ImmutableMap.of( - "distance", 200, "time", 10, "distance_alias", 199.99999999999997)))); + tupleValue( ImmutableMap.of("distance", 100, "time", 10)), + tupleValue( ImmutableMap.of("distance", 200, "time", 10)), + tupleValue( ImmutableMap.of("distance", 200, "time", 10)), + tupleValue( ImmutableMap.of("distance", 200, "time", 10, + "distance_alias", 190)))); } @ParameterizedTest From 9910b22025dbe59ea3397e92a2a149ff9b6b1d56 Mon Sep 17 00:00:00 2001 From: Andy Kwok Date: Thu, 13 Feb 2025 16:45:15 -0800 Subject: [PATCH 32/36] Update IT magic number Signed-off-by: Andy Kwok --- .../physical/TrendlineOperatorTest.java | 76 +++++++++---------- .../sql/ppl/TrendlineCommandIT.java | 76 ++++++++++++++----- .../sql/ppl/parser/AstExpressionBuilder.java | 5 +- 3 files changed, 99 insertions(+), 58 deletions(-) diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java index 7c500e74ef..4e4010c1a3 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java @@ -51,11 +51,11 @@ public class TrendlineOperatorTest extends PhysicalPlanTestBase { static Stream supportedDataTypes() { return Stream.of( - Arguments.of(ExprCoreType.SHORT), - Arguments.of(ExprCoreType.INTEGER), - Arguments.of(ExprCoreType.LONG), - Arguments.of(ExprCoreType.FLOAT), - Arguments.of(ExprCoreType.DOUBLE)); + Arguments.of(ExprCoreType.SHORT), + Arguments.of(ExprCoreType.INTEGER), + Arguments.of(ExprCoreType.LONG), + Arguments.of(ExprCoreType.FLOAT), + Arguments.of(ExprCoreType.DOUBLE)); } static Stream invalidArguments() { @@ -422,11 +422,10 @@ public void calculates_weighted_moving_average_one_field_four_samples_four_rows( assertThat( result, containsInAnyOrder( - tupleValue( ImmutableMap.of("distance", 100, "time", 10)), - tupleValue( ImmutableMap.of("distance", 200, "time", 10)), - tupleValue( ImmutableMap.of("distance", 200, "time", 10)), - tupleValue( ImmutableMap.of("distance", 200, "time", 10, - "distance_alias", 190)))); + tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 190)))); } @Test @@ -457,8 +456,9 @@ public void calculates_weighted_moving_average_multiple_computations() { tupleValue(ImmutableMap.of("distance", 100, "time", 10)), tupleValue(ImmutableMap.of("distance", 200, "time", 20)), tupleValue(ImmutableMap.of("distance", 200, "time", 20)), - tupleValue(ImmutableMap.of("distance", 200, "time", 20, - "distance_alias", 190, "time_alias", 19.0)))); + tupleValue( + ImmutableMap.of( + "distance", 200, "time", 20, "distance_alias", 190, "time_alias", 19.0)))); } @Test @@ -488,8 +488,7 @@ public void calculates_weighted_moving_average_null_value() { tupleValue(ImmutableMap.of("distance", 200, "time", 10)), tupleValue(ImmutableMap.of("distance", 300, "time", 10)), tupleValue(ImmutableMap.of("distance", 300, "time", 10)), - tupleValue(ImmutableMap.of("distance", 300, "time", 10, - "distance_alias", 290)))); + tupleValue(ImmutableMap.of("distance", 300, "time", 10, "distance_alias", 290)))); } @Test @@ -605,38 +604,36 @@ public void calculates_weighted_moving_average_timestamp() { Instant.EPOCH.plusMillis(1333))))); } - @ParameterizedTest @MethodSource("supportedDataTypes") public void trendLine_dataType_support_sma(ExprCoreType supportedType) { mockPlanWithData( - List.of( - tupleValue(ImmutableMap.of("distance", 100, "time", 10)), - tupleValue(ImmutableMap.of("distance", 200, "time", 10)), - tupleValue(ImmutableMap.of("distance", 200, "time", 10)), - tupleValue(ImmutableMap.of("distance", 200, "time", 10)))); + List.of( + tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)))); var plan = - new TrendlineOperator( - inputPlan, - Collections.singletonList( - Pair.of( - AstDSL.computation(4, AstDSL.field("distance"), "distance_alias", SMA), - supportedType))); + new TrendlineOperator( + inputPlan, + Collections.singletonList( + Pair.of( + AstDSL.computation(4, AstDSL.field("distance"), "distance_alias", SMA), + supportedType))); List result = execute(plan); System.out.println(result); assertEquals(4, result.size()); assertThat( - String.format( - "Assertion error on TrendLine-WMA dataType support: %s", supportedType.typeName()), - result, - containsInAnyOrder( - tupleValue( ImmutableMap.of("distance", 100, "time", 10)), - tupleValue( ImmutableMap.of("distance", 200, "time", 10)), - tupleValue( ImmutableMap.of("distance", 200, "time", 10)), - tupleValue( ImmutableMap.of("distance", 200, "time", 10, - "distance_alias", 175)))); + String.format( + "Assertion error on TrendLine-WMA dataType support: %s", supportedType.typeName()), + result, + containsInAnyOrder( + tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 175)))); } @ParameterizedTest @@ -665,11 +662,10 @@ public void trendLine_dataType_support_wma(ExprCoreType supportedType) { "Assertion error on TrendLine-WMA dataType support: %s", supportedType.typeName()), result, containsInAnyOrder( - tupleValue( ImmutableMap.of("distance", 100, "time", 10)), - tupleValue( ImmutableMap.of("distance", 200, "time", 10)), - tupleValue( ImmutableMap.of("distance", 200, "time", 10)), - tupleValue( ImmutableMap.of("distance", 200, "time", 10, - "distance_alias", 190)))); + tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 190)))); } @ParameterizedTest diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/TrendlineCommandIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/TrendlineCommandIT.java index 4c5c0b0153..0618c9cb77 100644 --- a/integ-test/src/test/java/org/opensearch/sql/ppl/TrendlineCommandIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/TrendlineCommandIT.java @@ -7,7 +7,9 @@ import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_BANK; import static org.opensearch.sql.util.MatcherUtils.rows; +import static org.opensearch.sql.util.MatcherUtils.schema; import static org.opensearch.sql.util.MatcherUtils.verifyDataRows; +import static org.opensearch.sql.util.MatcherUtils.verifySchema; import java.io.IOException; import org.json.JSONObject; @@ -60,7 +62,7 @@ public void testTrendlineNoAlias() throws IOException { executeQuery( String.format( "source=%s | where balance > 39000 | sort balance | trendline sma(2, balance) |" - + " fields balance_trendline", + + " fields balance_sma_trendline", TEST_INDEX_BANK)); verifyDataRows(result, rows(new Object[] {null}), rows(44313.0), rows(39882.5)); } @@ -71,7 +73,7 @@ public void testTrendlineWithSort() throws IOException { executeQuery( String.format( "source=%s | where balance > 39000 | trendline sort balance sma(2, balance) |" - + " fields balance_trendline", + + " fields balance_sma_trendline", TEST_INDEX_BANK)); verifyDataRows(result, rows(new Object[] {null}), rows(44313.0), rows(39882.5)); } @@ -81,11 +83,15 @@ public void testTrendlineWma() throws IOException { final JSONObject result = executeQuery( String.format( - "source=%s | where balance > 39000 | sort balance | trendline wma(2, balance) as" + "source=%s | sort balance | head 4 | trendline wma(4, balance) as" + " balance_trend | fields balance_trend", TEST_INDEX_BANK)); verifyDataRows( - result, rows(new Object[] {null}), rows(45570.666666666664), rows(40101.666666666664)); + result, + rows(new Object[] {null}), + rows(new Object[] {null}), + rows(new Object[] {null}), + rows(19615.8)); } @Test @@ -93,15 +99,17 @@ public void testTrendlineMultipleFieldsWma() throws IOException { final JSONObject result = executeQuery( String.format( - "source=%s | where balance > 39000 | sort balance | trendline wma(2, balance) as" - + " balance_trend wma(2, account_number) as account_number_trend | fields" + "source=%s | sort balance | head 5 | trendline wma(4, balance) as" + + " balance_trend wma(5, account_number) as account_number_trend | fields" + " balance_trend, account_number_trend", TEST_INDEX_BANK)); verifyDataRows( result, rows(null, null), - rows(40101.666666666664, 16.999999999999996), - rows(45570.666666666664, 29.666666666666664)); + rows(null, null), + rows(null, null), + rows(19615.8, null), + rows(29393.6, 9.8)); } @Test @@ -109,23 +117,35 @@ public void testTrendlineOverwritesExistingFieldWma() throws IOException { final JSONObject result = executeQuery( String.format( - "source=%s | where balance > 39000 | sort balance | trendline wma(2, balance) as" + "source=%s | sort balance | head 6 | trendline wma(4, balance) as" + " age | fields age", TEST_INDEX_BANK)); verifyDataRows( - result, rows(new Object[] {null}), rows(40101.666666666664), rows(45570.666666666664)); + result, + rows(new Object[] {null}), + rows(new Object[] {null}), + rows(new Object[] {null}), + rows(19615.8), + rows(29393.6), + rows(36192.9)); } @Test - public void testTrendlineNoAliasWma() throws IOException { + public void testTrendlineNoAliasWmaDefaultName() throws IOException { final JSONObject result = executeQuery( String.format( - "source=%s | where balance > 39000 | sort balance | trendline wma(2, balance) |" - + " fields balance_trendline", + "source=%s | sort balance | head 5 | trendline wma(4, balance) |" + + " fields balance_wma_trendline", TEST_INDEX_BANK)); + verifySchema(result, schema("balance_wma_trendline", "double")); verifyDataRows( - result, rows(new Object[] {null}), rows(40101.666666666664), rows(45570.666666666664)); + result, + rows(new Object[] {null}), + rows(new Object[] {null}), + rows(new Object[] {null}), + rows(19615.8), + rows(29393.6)); } @Test @@ -133,10 +153,32 @@ public void testTrendlineWithSortWma() throws IOException { final JSONObject result = executeQuery( String.format( - "source=%s | where balance > 39000 | trendline sort balance wma(2, balance) |" - + " fields balance_trendline", + "source=%s | sort balance | head 5 | trendline sort balance wma(4, balance) |" + + " fields balance_wma_trendline", TEST_INDEX_BANK)); verifyDataRows( - result, rows(new Object[] {null}), rows(40101.666666666664), rows(45570.666666666664)); + result, + rows(new Object[] {null}), + rows(new Object[] {null}), + rows(new Object[] {null}), + rows(19615.8), + rows(29393.6)); } + // + // @Test + // public void testTrendlineWithDefaultNameWma() throws IOException { + // final JSONObject result = + // executeQuery( + // String.format( + // "source=%s | where balance > 39000 | trendline sort balance wma(2, + // balance) |" + // + " fields balance_wma_trendline", + // TEST_INDEX_BANK)); + // verifySchema( + // result, + // schema("balance_wma_trendline", "double")); + // verifyDataRows( + // result, rows(new Object[] {null}), rows(40101.666666666664), + // rows(45570.666666666664)); + // } } diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index 0424b148b6..e5b43bc834 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -94,7 +94,10 @@ public Trendline.TrendlineComputation visitTrendlineClause( final String alias = ctx.alias != null ? ctx.alias.getText() - : dataField.getChild().get(0).toString() + "_" + computationType.name() + "_trendline"; + : dataField.getChild().getFirst().toString() + + "_" + + computationType.name().toLowerCase() + + "_trendline"; return new Trendline.TrendlineComputation( numberOfDataPoints, dataField, alias, computationType); } From 4be0400f60dc7fd1c3666aab08db71d501a7222d Mon Sep 17 00:00:00 2001 From: Andy Kwok Date: Thu, 13 Feb 2025 16:51:25 -0800 Subject: [PATCH 33/36] Update IT test Signed-off-by: Andy Kwok --- .../sql/ppl/TrendlineCommandIT.java | 21 +++---------------- 1 file changed, 3 insertions(+), 18 deletions(-) diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/TrendlineCommandIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/TrendlineCommandIT.java index 0618c9cb77..a555ad6525 100644 --- a/integ-test/src/test/java/org/opensearch/sql/ppl/TrendlineCommandIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/TrendlineCommandIT.java @@ -57,13 +57,14 @@ public void testTrendlineOverwritesExistingField() throws IOException { } @Test - public void testTrendlineNoAlias() throws IOException { + public void testTrendlineNoAliasDefaultName() throws IOException { final JSONObject result = executeQuery( String.format( "source=%s | where balance > 39000 | sort balance | trendline sma(2, balance) |" + " fields balance_sma_trendline", TEST_INDEX_BANK)); + verifySchema(result, schema("balance_sma_trendline", "double")); verifyDataRows(result, rows(new Object[] {null}), rows(44313.0), rows(39882.5)); } @@ -164,21 +165,5 @@ public void testTrendlineWithSortWma() throws IOException { rows(19615.8), rows(29393.6)); } - // - // @Test - // public void testTrendlineWithDefaultNameWma() throws IOException { - // final JSONObject result = - // executeQuery( - // String.format( - // "source=%s | where balance > 39000 | trendline sort balance wma(2, - // balance) |" - // + " fields balance_wma_trendline", - // TEST_INDEX_BANK)); - // verifySchema( - // result, - // schema("balance_wma_trendline", "double")); - // verifyDataRows( - // result, rows(new Object[] {null}), rows(40101.666666666664), - // rows(45570.666666666664)); - // } + } From 8348950e0c19e1f3b99bf1de6b85f9e5909ebbc4 Mon Sep 17 00:00:00 2001 From: Andy Kwok Date: Thu, 13 Feb 2025 17:02:23 -0800 Subject: [PATCH 34/36] Update test coverage Signed-off-by: Andy Kwok --- .../planner/physical/TrendlineOperator.java | 3 -- .../physical/TrendlineOperatorTest.java | 30 +++++++++++++++++++ .../sql/ppl/TrendlineCommandIT.java | 1 - 3 files changed, 30 insertions(+), 4 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java index a4191004a6..69a5dc1038 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java @@ -365,9 +365,6 @@ Function, ExprValue> getWmaEvaluator(ExprCoreType type) { @Override public void accumulate(ExprValue value) { receivedValues.add(value); - if (receivedValues.size() > dataPointsNeeded.valueOf().integerValue()) { - receivedValues.remove(); - } } @Override diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java index 4e4010c1a3..fcccec1ce2 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java @@ -428,6 +428,36 @@ public void calculates_weighted_moving_average_one_field_four_samples_four_rows( tupleValue(ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 190)))); } + @Test + public void calculates_weighted_moving_average_one_field_five_samples_four_rows() { + mockPlanWithData( + List.of( + tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)))); + + var plan = + new TrendlineOperator( + inputPlan, + Collections.singletonList( + Pair.of( + AstDSL.computation(4, AstDSL.field("distance"), "distance_alias", WMA), + ExprCoreType.DOUBLE))); + + List result = execute(plan); + assertEquals(5, result.size()); + assertThat( + result, + containsInAnyOrder( + tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 190)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 200)))); + } + @Test public void calculates_weighted_moving_average_multiple_computations() { mockPlanWithData( diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/TrendlineCommandIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/TrendlineCommandIT.java index a555ad6525..02860fe1e7 100644 --- a/integ-test/src/test/java/org/opensearch/sql/ppl/TrendlineCommandIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/TrendlineCommandIT.java @@ -165,5 +165,4 @@ public void testTrendlineWithSortWma() throws IOException { rows(19615.8), rows(29393.6)); } - } From 60377429d87eade054cc342569254bfdf20edc5d Mon Sep 17 00:00:00 2001 From: Andy Kwok Date: Thu, 13 Feb 2025 17:13:09 -0800 Subject: [PATCH 35/36] Update unit test Signed-off-by: Andy Kwok --- .../physical/TrendlineOperatorTest.java | 32 +++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java index fcccec1ce2..05583e2014 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java @@ -58,6 +58,21 @@ static Stream supportedDataTypes() { Arguments.of(ExprCoreType.DOUBLE)); } + static Stream unSupportedDataTypes() { + return Stream.of(SMA, WMA) + .flatMap( + trendlineType -> + Stream.of( + Arguments.of(trendlineType, ExprCoreType.UNDEFINED), + Arguments.of(trendlineType, ExprCoreType.BYTE), + Arguments.of(trendlineType, ExprCoreType.STRING), + Arguments.of(trendlineType, ExprCoreType.BOOLEAN), + Arguments.of(trendlineType, ExprCoreType.INTERVAL), + Arguments.of(trendlineType, ExprCoreType.IP), + Arguments.of(trendlineType, ExprCoreType.STRUCT), + Arguments.of(trendlineType, ExprCoreType.ARRAY))); + } + static Stream invalidArguments() { return Stream.of(SMA, WMA) .flatMap( @@ -698,6 +713,21 @@ public void trendLine_dataType_support_wma(ExprCoreType supportedType) { tupleValue(ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 190)))); } + @ParameterizedTest + @MethodSource("unSupportedDataTypes") + public void trendLine_unsupported_dataType( + Trendline.TrendlineType trendlineType, ExprCoreType dataType) { + assertThrows( + SemanticCheckException.class, + () -> + new TrendlineOperator( + inputPlan, + Collections.singletonList( + Pair.of( + AstDSL.computation(2, AstDSL.field("distance"), + "distance_alias", trendlineType), dataType)))); + } + @ParameterizedTest @MethodSource("invalidArguments") public void use_invalid_configuration( @@ -718,6 +748,8 @@ public void use_invalid_configuration( "Unsupported arguments: " + errorMessage); } + + private void mockPlanWithData(List inputs) { List hasNextElements = new ArrayList<>(Collections.nCopies(inputs.size(), true)); hasNextElements.add(false); From e590fb97586a03d94ea99e351ea3b2f769ff1cc9 Mon Sep 17 00:00:00 2001 From: Andy Kwok Date: Thu, 13 Feb 2025 20:51:46 -0800 Subject: [PATCH 36/36] Code style Signed-off-by: Andy Kwok --- .../physical/TrendlineOperatorTest.java | 43 +++++++++---------- 1 file changed, 21 insertions(+), 22 deletions(-) diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java index 05583e2014..e6a6d1e045 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java @@ -60,17 +60,17 @@ static Stream supportedDataTypes() { static Stream unSupportedDataTypes() { return Stream.of(SMA, WMA) - .flatMap( - trendlineType -> - Stream.of( - Arguments.of(trendlineType, ExprCoreType.UNDEFINED), - Arguments.of(trendlineType, ExprCoreType.BYTE), - Arguments.of(trendlineType, ExprCoreType.STRING), - Arguments.of(trendlineType, ExprCoreType.BOOLEAN), - Arguments.of(trendlineType, ExprCoreType.INTERVAL), - Arguments.of(trendlineType, ExprCoreType.IP), - Arguments.of(trendlineType, ExprCoreType.STRUCT), - Arguments.of(trendlineType, ExprCoreType.ARRAY))); + .flatMap( + trendlineType -> + Stream.of( + Arguments.of(trendlineType, ExprCoreType.UNDEFINED), + Arguments.of(trendlineType, ExprCoreType.BYTE), + Arguments.of(trendlineType, ExprCoreType.STRING), + Arguments.of(trendlineType, ExprCoreType.BOOLEAN), + Arguments.of(trendlineType, ExprCoreType.INTERVAL), + Arguments.of(trendlineType, ExprCoreType.IP), + Arguments.of(trendlineType, ExprCoreType.STRUCT), + Arguments.of(trendlineType, ExprCoreType.ARRAY))); } static Stream invalidArguments() { @@ -716,16 +716,17 @@ public void trendLine_dataType_support_wma(ExprCoreType supportedType) { @ParameterizedTest @MethodSource("unSupportedDataTypes") public void trendLine_unsupported_dataType( - Trendline.TrendlineType trendlineType, ExprCoreType dataType) { + Trendline.TrendlineType trendlineType, ExprCoreType dataType) { assertThrows( - SemanticCheckException.class, - () -> - new TrendlineOperator( - inputPlan, - Collections.singletonList( - Pair.of( - AstDSL.computation(2, AstDSL.field("distance"), - "distance_alias", trendlineType), dataType)))); + SemanticCheckException.class, + () -> + new TrendlineOperator( + inputPlan, + Collections.singletonList( + Pair.of( + AstDSL.computation( + 2, AstDSL.field("distance"), "distance_alias", trendlineType), + dataType)))); } @ParameterizedTest @@ -748,8 +749,6 @@ public void use_invalid_configuration( "Unsupported arguments: " + errorMessage); } - - private void mockPlanWithData(List inputs) { List hasNextElements = new ArrayList<>(Collections.nCopies(inputs.size(), true)); hasNextElements.add(false);