Skip to content

Commit

Permalink
Add AggWAvg support for rollup() (#5121)
Browse files Browse the repository at this point in the history
* Initial commit of Rollup wAvg

* Addressed PR comments.
  • Loading branch information
lbooker42 authored and devinrsmith committed Feb 12, 2024
1 parent c7dc934 commit f3ebf39
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 459 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -127,15 +127,7 @@
import static io.deephaven.engine.table.ChunkSource.WithPrev.ZERO_LENGTH_CHUNK_SOURCE_WITH_PREV_ARRAY;
import static io.deephaven.engine.table.Table.AGGREGATION_ROW_LOOKUP_ATTRIBUTE;
import static io.deephaven.engine.table.impl.by.IterativeChunkedAggregationOperator.ZERO_LENGTH_ITERATIVE_CHUNKED_AGGREGATION_OPERATOR_ARRAY;
import static io.deephaven.engine.table.impl.by.RollupConstants.ROLLUP_COLUMN_SUFFIX;
import static io.deephaven.engine.table.impl.by.RollupConstants.ROLLUP_DISTINCT_SSM_COLUMN_ID;
import static io.deephaven.engine.table.impl.by.RollupConstants.ROLLUP_NAN_COUNT_COLUMN_ID;
import static io.deephaven.engine.table.impl.by.RollupConstants.ROLLUP_NI_COUNT_COLUMN_ID;
import static io.deephaven.engine.table.impl.by.RollupConstants.ROLLUP_NONNULL_COUNT_COLUMN_ID;
import static io.deephaven.engine.table.impl.by.RollupConstants.ROLLUP_PI_COUNT_COLUMN_ID;
import static io.deephaven.engine.table.impl.by.RollupConstants.ROLLUP_RUNNING_SUM2_COLUMN_ID;
import static io.deephaven.engine.table.impl.by.RollupConstants.ROLLUP_RUNNING_SUM_COLUMN_ID;
import static io.deephaven.engine.table.impl.by.RollupConstants.ROW_REDIRECTION_PREFIX;
import static io.deephaven.engine.table.impl.by.RollupConstants.*;
import static io.deephaven.util.QueryConstants.*;
import static io.deephaven.util.type.TypeUtils.getBoxedType;
import static io.deephaven.util.type.TypeUtils.isNumeric;
Expand Down Expand Up @@ -584,7 +576,10 @@ final void descendingSortedFirstOrLastUnsupported(@NotNull final SortColumn sort
isFirst ? "SortedFirst" : "SortedLast", sortColumn));
}

final void addWeightedAvgOrSumOperator(@NotNull final String weightName, final boolean isSum) {
final void addWeightedAvgOrSumOperator(
@NotNull final String weightName,
final boolean isSum,
final boolean exposeInternal) {
final ColumnSource<?> weightSource = table.getColumnSource(weightName);
final boolean weightSourceIsFloatingPoint;
if (isInteger(weightSource.getChunkType())) {
Expand Down Expand Up @@ -653,7 +648,7 @@ final void addWeightedAvgOrSumOperator(@NotNull final String weightName, final b
}
} else {
resultOperator = new ChunkedWeightedAverageOperator(
r.source.getChunkType(), doubleWeightOperator, r.pair.output().name());
r.source.getChunkType(), doubleWeightOperator, r.pair.output().name(), exposeInternal);
}
addOperator(resultOperator, r.source, r.pair.input().name(), weightName);
});
Expand Down Expand Up @@ -824,12 +819,12 @@ public void visit(@NotNull final AggSpecUnique unique) {

@Override
public void visit(@NotNull final AggSpecWAvg wAvg) {
addWeightedAvgOrSumOperator(wAvg.weight().name(), false);
addWeightedAvgOrSumOperator(wAvg.weight().name(), false, false);
}

@Override
public void visit(@NotNull final AggSpecWSum wSum) {
addWeightedAvgOrSumOperator(wSum.weight().name(), true);
addWeightedAvgOrSumOperator(wSum.weight().name(), true, false);
}

@Override
Expand Down Expand Up @@ -904,13 +899,6 @@ default void visit(@NotNull final AggSpecPercentile pct) {
default void visit(@NotNull final AggSpecTDigest tDigest) {
rollupUnsupported("TDigest");
}

@Override
@FinalDefault
default void visit(@NotNull final AggSpecWAvg wAvg) {
// TODO(deephaven-core#3350): AggWAvg support for rollup()
rollupUnsupported("WAvg", 3350);
}
}

private static void rollupUnsupported(@NotNull final String operationName) {
Expand Down Expand Up @@ -1042,7 +1030,14 @@ public void visit(@NotNull final AggSpecUnique unique) {

@Override
public void visit(@NotNull final AggSpecWSum wSum) {
addWeightedAvgOrSumOperator(wSum.weight().name(), true);
// Weighted sum does not need to expose internal columns to re-aggregate.
addWeightedAvgOrSumOperator(wSum.weight().name(), true, false);
}

@Override
public void visit(@NotNull final AggSpecWAvg wAvg) {
// Weighted average needs access internal columns to re-aggregate.
addWeightedAvgOrSumOperator(wAvg.weight().name(), false, true);
}

@Override
Expand Down Expand Up @@ -1190,6 +1185,12 @@ public void visit(@NotNull final AggSpecWSum wSum) {
reaggregateAsSum();
}

@Override
public void visit(@NotNull final AggSpecWAvg wAvg) {
reaggregateWAvgOperator();
}


@Override
public void visit(@NotNull final AggSpecVar var) {
reaggregateStdOrVarOperators(false);
Expand Down Expand Up @@ -1288,6 +1289,28 @@ private void reaggregateAvgOperator() {
}
}

private void reaggregateWAvgOperator() {
for (final Pair pair : resultPairs) {
final String resultName = pair.output().name();

// Make a recording operator for the sum of weights column
final String sumOfWeightsName = resultName + ROLLUP_SUM_WEIGHTS_COLUMN_ID + ROLLUP_COLUMN_SUFFIX;
final ColumnSource<?> sumOfWeightsSource = table.getColumnSource(sumOfWeightsName);

final DoubleWeightRecordingInternalOperator doubleWeightOperator =
new DoubleWeightRecordingInternalOperator(sumOfWeightsSource.getChunkType());
addOperator(doubleWeightOperator, sumOfWeightsSource, resultName, sumOfWeightsName);

final ColumnSource<?> weightedAveragesSource = table.getColumnSource(resultName);

// The sum of weights column is directly usable as the weights for the WAvg re-aggregation.
final IterativeChunkedAggregationOperator resultOperator = new ChunkedWeightedAverageOperator(
weightedAveragesSource.getChunkType(), doubleWeightOperator, resultName, true);

addOperator(resultOperator, weightedAveragesSource, resultName, sumOfWeightsName);
}
}

private void reaggregateStdOrVarOperators(final boolean isStd) {
for (final Pair pair : resultPairs) {
final String resultName = pair.output().name();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,16 @@
import org.apache.commons.lang3.mutable.MutableInt;

import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;

import static io.deephaven.engine.table.impl.by.RollupConstants.*;

class ChunkedWeightedAverageOperator implements IterativeChunkedAggregationOperator {
private final ChunkType chunkType;
private final DoubleWeightRecordingInternalOperator weightOperator;
private final String resultName;
private final ChunkType chunkType;
private final boolean exposeInternalColumns;

private long tableSize;
private final LongArraySource normalCount;
Expand All @@ -32,11 +36,15 @@ class ChunkedWeightedAverageOperator implements IterativeChunkedAggregationOpera
private final DoubleArraySource weightedSum;
private final DoubleArraySource resultColumn;

ChunkedWeightedAverageOperator(ChunkType chunkType, DoubleWeightRecordingInternalOperator weightOperator,
String name) {
ChunkedWeightedAverageOperator(
ChunkType chunkType,
DoubleWeightRecordingInternalOperator weightOperator,
String name,
boolean exposeInternalColumns) {
this.chunkType = chunkType;
this.weightOperator = weightOperator;
this.resultName = name;
this.exposeInternalColumns = exposeInternalColumns;

tableSize = 0;
normalCount = new LongArraySource();
Expand Down Expand Up @@ -416,12 +424,22 @@ public void ensureCapacity(long tableSize) {

@Override
public Map<String, ? extends ColumnSource<?>> getResultColumns() {
return Collections.singletonMap(resultName, resultColumn);
if (exposeInternalColumns) {
final Map<String, ColumnSource<?>> results = new LinkedHashMap<>(2);
results.put(resultName, resultColumn);
results.put(resultName + ROLLUP_SUM_WEIGHTS_COLUMN_ID + ROLLUP_COLUMN_SUFFIX, sumOfWeights);
return results;
} else {
return Collections.singletonMap(resultName, resultColumn);
}
}

@Override
public void startTrackingPrevValues() {
resultColumn.startTrackingPrevValues();
if (exposeInternalColumns) {
sumOfWeights.startTrackingPrevValues();
}
}

private class Context implements BucketedContext, SingletonContext {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ private RollupConstants() {}
*/
static final String ROLLUP_RUNNING_SUM_COLUMN_ID = "_RS_";

/**
* Middle column name component (between source column name and {@link #ROLLUP_COLUMN_SUFFIX suffix}) for sum of
* weights columns used in rollup wavg aggregations.
*/
static final String ROLLUP_SUM_WEIGHTS_COLUMN_ID = "_RSW_";

/**
* Middle column name component (between source column name and {@link #ROLLUP_COLUMN_SUFFIX suffix}) for running
* sum of squares columns used in rollup aggregations.
Expand Down
Loading

0 comments on commit f3ebf39

Please sign in to comment.