# feat: DH-18351: Add CumCountWhere() and RollingCountWhere() features to `UpdateBy` (#6566)

## Groovy Examples
```groovy
table = emptyTable(1000).update("key=randomInt(0,10)", "intCol=randomInt(0,1000)")

// zero-key
t_summary = table.updateBy([
    CumCountWhere("running_gt_500", "intCol > 500"),
    RollingCountWhere(50, "windowed_gt_500", "intCol > 500"),
    ])

// bucketed
t_summary = table.updateBy([
    CumCountWhere("running_gt_500", "intCol > 500"),
    RollingCountWhere(50, "windowed_gt_500", "intCol > 500"),
    ], "key")
```

## Python Examples
```python
from deephaven import empty_table
from deephaven.updateby import cum_count_where, rolling_count_where_tick

table = empty_table(1000).update(["key=randomInt(0,10)", "intCol=randomInt(0,1000)"])

# zero-key
t_summary = table.update_by([
    cum_count_where(col="running_gt_500", filters="intCol > 500"),
    rolling_count_where_tick(rev_ticks=50, col="windowed_gt_500", filters="intCol > 500"),
    ])

# bucketed
t_summary_bucketed = table.update_by([
    cum_count_where(col="running_gt_500", filters="intCol > 500"),
    rolling_count_where_tick(rev_ticks=50, col="windowed_gt_500", filters="intCol > 500"),
    ], by="key")
```
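Independent of Deephaven, the semantics of the two operators can be sketched with plain Python lists (a simplified illustrative model, not the engine implementation; the predicate and the tick-window convention mirror the examples above, where `rev_ticks` includes the current row):

```python
def cum_count_where(values, predicate):
    """Running count of rows matching predicate, one output per input row."""
    out, count = [], 0
    for v in values:
        if predicate(v):
            count += 1
        out.append(count)
    return out

def rolling_count_where_tick(values, predicate, rev_ticks):
    """Count of matching rows in a trailing window of rev_ticks rows
    (the current row plus rev_ticks - 1 preceding rows)."""
    out = []
    for i in range(len(values)):
        window = values[max(0, i - rev_ticks + 1): i + 1]
        out.append(sum(1 for v in window if predicate(v)))
    return out

vals = [600, 100, 700, 501, 499]
print(cum_count_where(vals, lambda v: v > 500))              # [1, 1, 2, 3, 3]
print(rolling_count_where_tick(vals, lambda v: v > 500, 2))  # [1, 1, 1, 2, 1]
```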

## Performance Notes

TL;DR: Performance compares very well.

`RollingCountWhere()` has near-identical performance to the comparison
benchmarks (and can be faster, depending on the complexity of the filter).
`CumCountWhere()` also compares well to `Ema()` but can't catch up to
zero-key `CumSum()`, which is remarkably fast.

Comparing `CumCountWhere` to `CumSum` and `Ema`:
```
120,000,000 rows
avg of 2 runs

ZeroKey
CumSum	137.36250
Ema	449.5528125
CumCountWhereConstant	475.9980005
CumCountWhereMatch	649.9689995
CumCountWhereRange	654.322250
CumCountWhereMultiple	695.4477915
CumCountWhereMultipleOr	704.900583

Bucketed - 250 buckets
CumSum	2979.1730005
Ema	3024.152458
CumCountWhereConstant	2569.7280835
CumCountWhereMatch	3031.6534795
CumCountWhereRange	3030.5433335
CumCountWhereMultiple	3052.597625
CumCountWhereMultipleOr	3059.911729

Bucketed - 640 buckets
CumSum	3827.299833
Ema	3880.2538125
CumCountWhereConstant	3416.4387715
CumCountWhereMatch	3906.691333
CumCountWhereRange	3902.3064375
CumCountWhereMultiple	3967.1584795
CumCountWhereMultipleOr	3925.0775205
```

Comparing `RollingCountWhere` to `RollingCount` and `RollingSum`:
```
120,000,000 rows
avg of 2 runs

ZeroKey
RollingCount	1511.7957295
RollingSum	1513.6013545
RollingCountWhereConstant	1403.2817915
RollingCountWhereMatch	1453.9323125
RollingCountWhereRange	1764.2137915
RollingCountWhereMultiple	1576.4896255
RollingCountWhereMultipleOr	1541.5631455

Bucketed - 250 buckets
RollingCount	3468.7696665
RollingSum	3326.047792
RollingCountWhereConstant	2858.677771
RollingCountWhereMatch	3327.958604
RollingCountWhereRange	3347.961083
RollingCountWhereMultiple	3429.413562
RollingCountWhereMultipleOr	3364.244104

Bucketed - 640 buckets
RollingCount	4310.4265835
RollingSum	4286.427479
RollingCountWhereConstant	3869.1892705
RollingCountWhereMatch	4333.8479375
RollingCountWhereRange	4269.3454375
RollingCountWhereMultiple	4290.0618545
RollingCountWhereMultipleOr	4346.8478535
```
lbooker42 authored Jan 21, 2025
1 parent e7f731b commit fac6ad2
Showing 43 changed files with 7,180 additions and 2,559 deletions.
@@ -33,7 +33,7 @@
* {@link UpdateByOperator#initializeRolling(Context, RowSet)} (Context)} for windowed operators</li>
* <li>{@link UpdateByOperator.Context#accumulateCumulative(RowSequence, Chunk[], LongChunk, int)} for cumulative
* operators or
* {@link UpdateByOperator.Context#accumulateRolling(RowSequence, Chunk[], LongChunk, LongChunk, IntChunk, IntChunk, int)}
* {@link UpdateByOperator.Context#accumulateRolling(RowSequence, Chunk[], LongChunk, LongChunk, IntChunk, IntChunk, int, int)}
* for windowed operators</li>
* <li>{@link #finishUpdate(UpdateByOperator.Context)}</li>
* </ol>
@@ -99,18 +99,48 @@ protected void pop(int count) {
throw new UnsupportedOperationException("pop() must be overridden by rolling operators");
}

public abstract void accumulateCumulative(RowSequence inputKeys,
/**
* For cumulative operators only, this method will be called to pass the input chunk data to the operator and
* produce the output data values.
*
* @param inputKeys the keys for the input data rows (also matches the output keys)
* @param valueChunkArr the input data chunks needed by the operator for internal calculations
* @param tsChunk the timestamp chunk for the input data (if applicable)
* @param len the number of items in the input data chunks
*/
public abstract void accumulateCumulative(
RowSequence inputKeys,
Chunk<? extends Values>[] valueChunkArr,
LongChunk<? extends Values> tsChunk,
int len);

public abstract void accumulateRolling(RowSequence inputKeys,
/**
* For windowed operators only, this method will be called to pass the input chunk data to the operator and
 * produce the output data values. Note that the sizes of the influencer (input) and affected
 * (output) chunks are not likely to be the same. We pass both sizes explicitly for the benefit of
 * operators (such as {@link io.deephaven.engine.table.impl.updateby.countwhere.CountWhereOperator} with
 * zero input columns) where no input chunks are provided but we must still process the exact number of input
 * rows.
*
* @param inputKeys the keys for the input data rows (also matches the output keys)
* @param influencerValueChunkArr the input data chunks needed by the operator for internal calculations, these
* values will be pushed and popped into the current window
* @param affectedPosChunk the row positions of the affected rows
* @param influencerPosChunk the row positions of the influencer rows
* @param pushChunk a chunk containing the push instructions for each output row to be calculated
* @param popChunk a chunk containing the pop instructions for each output row to be calculated
* @param affectedCount how many affected (output) rows are being computed
* @param influencerCount how many influencer (input) rows are needed for the computation
*/
public abstract void accumulateRolling(
RowSequence inputKeys,
Chunk<? extends Values>[] influencerValueChunkArr,
LongChunk<OrderedRowKeys> affectedPosChunk,
LongChunk<OrderedRowKeys> influencerPosChunk,
IntChunk<? extends Values> pushChunk,
IntChunk<? extends Values> popChunk,
int len);
int affectedCount,
int influencerCount);
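The push/pop protocol described in the javadoc above can be modeled in plain Python (a hypothetical, simplified sketch, not the engine's chunked implementation): for each output row, expired values are first popped out of the window, newly influencing values are pushed in, and then the value for that row is emitted.

```python
from collections import deque

def accumulate_rolling(influencer_values, push_counts, pop_counts, predicate):
    """For each output row i: pop pop_counts[i] values out of the window,
    push the next push_counts[i] influencer values in, then emit the count
    of window values matching predicate."""
    window = deque()
    out = []
    next_influencer = 0
    for push, pop in zip(push_counts, pop_counts):
        for _ in range(pop):
            window.popleft()
        for _ in range(push):
            window.append(influencer_values[next_influencer])
            next_influencer += 1
        out.append(sum(1 for v in window if predicate(v)))
    return out

# A 2-row trailing window over [600, 100, 700, 501]:
print(accumulate_rolling([600, 100, 700, 501],
                         push_counts=[1, 1, 1, 1],
                         pop_counts=[0, 0, 1, 1],
                         predicate=lambda v: v > 500))  # [1, 1, 1, 2]
```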

/**
* Write the current value for this row to the output chunk
@@ -10,12 +10,21 @@
import io.deephaven.api.updateby.UpdateByControl;
import io.deephaven.api.updateby.UpdateByOperation;
import io.deephaven.api.updateby.spec.*;
import io.deephaven.base.verify.Require;
import io.deephaven.engine.rowset.RowSetFactory;
import io.deephaven.engine.table.ColumnDefinition;
import io.deephaven.engine.table.ColumnSource;
import io.deephaven.engine.table.Table;
import io.deephaven.engine.table.TableDefinition;
import io.deephaven.engine.table.impl.MatchPair;
import io.deephaven.engine.table.impl.QueryCompilerRequestProcessor;
import io.deephaven.engine.table.impl.QueryTable;
import io.deephaven.engine.table.impl.select.FormulaColumn;
import io.deephaven.engine.table.impl.select.SelectColumn;
import io.deephaven.engine.table.impl.select.WhereFilter;
import io.deephaven.engine.table.impl.sources.NullValueColumnSource;
import io.deephaven.engine.table.impl.sources.ReinterpretUtils;
import io.deephaven.engine.table.impl.updateby.countwhere.CountWhereOperator;
import io.deephaven.engine.table.impl.updateby.delta.*;
import io.deephaven.engine.table.impl.updateby.em.*;
import io.deephaven.engine.table.impl.updateby.emstd.*;
@@ -45,6 +54,7 @@
import java.time.Instant;
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

import static io.deephaven.util.BooleanUtils.NULL_BOOLEAN_AS_BYTE;
@@ -414,6 +424,12 @@ public Void visit(CumProdSpec cps) {
return null;
}

@Override
public Void visit(CumCountWhereSpec spec) {
ops.add(makeCountWhereOperator(tableDef, spec));
return null;
}

@Override
public Void visit(@NotNull final DeltaSpec spec) {
Arrays.stream(pairs)
@@ -537,6 +553,12 @@ public Void visit(@NotNull final RollingCountSpec spec) {
return null;
}

@Override
public Void visit(@NotNull final RollingCountWhereSpec spec) {
ops.add(makeCountWhereOperator(tableDef, spec));
return null;
}

@Override
public Void visit(@NotNull final RollingFormulaSpec spec) {
final boolean isTimeBased = spec.revWindowScale().isTimeBased();
@@ -1240,6 +1262,130 @@ private UpdateByOperator makeRollingCountOperator(@NotNull final MatchPair pair,
}
}

/**
* This is used for Cum/Rolling CountWhere operators
*/
private UpdateByOperator makeCountWhereOperator(
@NotNull final TableDefinition tableDef,
@NotNull final UpdateBySpec spec) {

Require.eqTrue(spec instanceof CumCountWhereSpec || spec instanceof RollingCountWhereSpec,
"spec instanceof CumCountWhereSpec || spec instanceof RollingCountWhereSpec");

final boolean isCumulative = spec instanceof CumCountWhereSpec;

final WhereFilter[] whereFilters = isCumulative
? WhereFilter.fromInternal(((CumCountWhereSpec) spec).filter())
: WhereFilter.fromInternal(((RollingCountWhereSpec) spec).filter());

final List<String> inputColumnNameList = new ArrayList<>();
final Map<String, Integer> inputColumnMap = new HashMap<>();
final List<int[]> filterInputColumnIndicesList = new ArrayList<>();

// Verify all the columns in the where filters are present in the dummy table and valid for use.
for (final WhereFilter whereFilter : whereFilters) {
whereFilter.init(tableDef);
if (whereFilter.isRefreshing()) {
throw new UnsupportedOperationException("CountWhere does not support refreshing filters");
}

// Compute which input sources this filter will use.
final List<String> filterColumnName = whereFilter.getColumns();
final int inputColumnCount = whereFilter.getColumns().size();
final int[] inputColumnIndices = new int[inputColumnCount];
for (int ii = 0; ii < inputColumnCount; ++ii) {
final String inputColumnName = filterColumnName.get(ii);
final int inputColumnIndex = inputColumnMap.computeIfAbsent(inputColumnName, k -> {
inputColumnNameList.add(inputColumnName);
return inputColumnNameList.size() - 1;
});
inputColumnIndices[ii] = inputColumnIndex;
}
filterInputColumnIndicesList.add(inputColumnIndices);
}

// Gather the input column type info and create a dummy table we can use to initialize filters.
final String[] inputColumnNames = inputColumnNameList.toArray(String[]::new);
final ColumnSource<?>[] originalColumnSources = new ColumnSource[inputColumnNames.length];
final ColumnSource<?>[] reinterpretedColumnSources = new ColumnSource[inputColumnNames.length];

final Map<String, ColumnSource<?>> columnSourceMap = new LinkedHashMap<>();
for (int i = 0; i < inputColumnNames.length; i++) {
final String col = inputColumnNames[i];
final ColumnDefinition<?> def = tableDef.getColumn(col);
// Create a representative column source of the correct type for the filter.
final ColumnSource<?> nullSource =
NullValueColumnSource.getInstance(def.getDataType(), def.getComponentType());
// Create a reinterpreted version of the column source.
final ColumnSource<?> maybeReinterpretedSource = ReinterpretUtils.maybeConvertToPrimitive(nullSource);
if (nullSource != maybeReinterpretedSource) {
originalColumnSources[i] = nullSource;
}
columnSourceMap.put(col, maybeReinterpretedSource);
reinterpretedColumnSources[i] = maybeReinterpretedSource;
}
final Table dummyTable = new QueryTable(RowSetFactory.empty().toTracking(), columnSourceMap);

final CountWhereOperator.CountFilter[] countFilters =
CountWhereOperator.CountFilter.createCountFilters(whereFilters, dummyTable,
filterInputColumnIndicesList);

// If any filter is ConditionFilter or ChunkFilter and uses a reinterpreted column, need to produce
// original-typed chunks.
final boolean originalChunksRequired = Arrays.asList(countFilters).stream()
.anyMatch(filter -> (filter.chunkFilter() != null || filter.conditionFilter() != null)
&& IntStream.of(filter.inputColumnIndices())
.anyMatch(i -> originalColumnSources[i] != null));

// If any filter is a standard WhereFilter or we need to produce original-typed chunks, need a chunk source
// table.
final boolean chunkSourceTableRequired = originalChunksRequired ||
Arrays.asList(countFilters).stream().anyMatch(filter -> filter.whereFilter() != null);

// Create a new column pair with the same name for the left and right columns
final String columnName = isCumulative
? ((CumCountWhereSpec) spec).column().name()
: ((RollingCountWhereSpec) spec).column().name();
final MatchPair pair = new MatchPair(columnName, columnName);

// Create and return the operator.
if (isCumulative) {
return new CountWhereOperator(
pair,
countFilters,
inputColumnNames,
originalColumnSources,
reinterpretedColumnSources,
chunkSourceTableRequired,
originalChunksRequired);
} else {
final RollingCountWhereSpec rs = (RollingCountWhereSpec) spec;

final String[] affectingColumns;
if (rs.revWindowScale().timestampCol() == null) {
affectingColumns = inputColumnNames;
} else {
affectingColumns = ArrayUtils.add(inputColumnNames, rs.revWindowScale().timestampCol());
}

final long prevWindowScaleUnits = rs.revWindowScale().getTimeScaleUnits();
final long fwdWindowScaleUnits = rs.fwdWindowScale().getTimeScaleUnits();

return new CountWhereOperator(
pair,
affectingColumns,
rs.revWindowScale().timestampCol(),
prevWindowScaleUnits,
fwdWindowScaleUnits,
countFilters,
inputColumnNames,
originalColumnSources,
reinterpretedColumnSources,
chunkSourceTableRequired,
originalChunksRequired);
}
}
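The column-index bookkeeping in the method above (the `computeIfAbsent` loop) assigns each distinct filter input column a stable index and lets filters that share a column share that index. A rough Python equivalent of just that step (illustrative only; names are hypothetical):

```python
def map_filter_columns(filters_columns):
    """filters_columns: one list of input column names per filter.
    Returns (ordered distinct column names, per-filter index arrays)."""
    column_index = {}    # name -> position in input_columns
    input_columns = []   # distinct columns, in first-seen order
    per_filter_indices = []
    for cols in filters_columns:
        indices = []
        for name in cols:
            if name not in column_index:
                column_index[name] = len(input_columns)
                input_columns.append(name)
            indices.append(column_index[name])
        per_filter_indices.append(indices)
    return input_columns, per_filter_indices

cols, idx = map_filter_columns([["intCol"], ["intCol", "key"], ["key"]])
print(cols)  # ['intCol', 'key']
print(idx)   # [[0], [0, 1], [1]]
```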

private UpdateByOperator makeRollingStdOperator(@NotNull final MatchPair pair,
@NotNull final TableDefinition tableDef,
@NotNull final RollingStdSpec rs) {
@@ -204,7 +204,8 @@ void processWindowBucketOperatorSet(final UpdateByWindowBucketContext context,
influencePosChunk,
ctx.pushChunks[affectedChunkOffset],
ctx.popChunks[affectedChunkOffset],
affectedChunkSize);
affectedChunkSize,
influencerCount);
}

affectedChunkOffset++;