Skip to content

Commit

Permalink
Add stochastic filtering to AF matrix * matrix cases
Browse files Browse the repository at this point in the history
  • Loading branch information
zoddicus committed Mar 6, 2024
1 parent 2f24495 commit a2b2b9f
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/resources/cache/hashes.json
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@
"webgpu/shader/execution/unary/ai_assignment.bin": "c7e6ac33",
"webgpu/shader/execution/binary/ai_arithmetic.bin": "81c11ec2",
"webgpu/shader/execution/unary/ai_arithmetic.bin": "3d27dc97",
"webgpu/shader/execution/binary/af_matrix_matrix_multiplication.bin": "17ffa31c",
"webgpu/shader/execution/binary/af_matrix_matrix_multiplication.bin": "df23d63a",
"webgpu/shader/execution/binary/af_matrix_scalar_multiplication.bin": "718ef50",
"webgpu/shader/execution/binary/af_matrix_vector_multiplication.bin": "495e66cf"
}
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
import { FP } from '../../../../util/floating_point.js';
import { sparseMatrixF64Range } from '../../../../util/math.js';
import { StochasticFilter } from '../../../../util/stochastic_filter.js';
import { makeCaseCache } from '../case_cache.js';

const sf = new StochasticFilter(0);
// Cases: matKxR_matCxK
const mat_mat_cases = ([2, 3, 4] as const)
.flatMap(k =>
([2, 3, 4] as const).flatMap(cols =>
([2, 3, 4] as const).map(rows => ({
[`mat${k}x${rows}_mat${cols}x${k}`]: () => {
return FP.abstract.generateMatrixPairToMatrixCases(
sparseMatrixF64Range(k, rows),
sparseMatrixF64Range(cols, k),
'finite',
FP.abstract.multiplicationMatrixMatrixInterval
return sf.filter(
FP.abstract.generateMatrixPairToMatrixCases(
sparseMatrixF64Range(k, rows),
sparseMatrixF64Range(cols, k),
'finite',
FP.abstract.multiplicationMatrixMatrixInterval
)
);
},
}))
Expand Down
70 changes: 70 additions & 0 deletions src/webgpu/util/stochastic_filter.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import { assert } from '../../common/util/util.js';

import { PRNG } from './prng.js';

/**
* Used to remove entries from a list in a deterministic, but stochastic manner.
* Uses the pseudo-number random number generator from prng.ts for selection.
*/
export class StochasticFilter {
private readonly prng: PRNG;
private ratio: number;

/**
* Constructor
* @param seed init value to pass down to PRNG
* @param ratio is a value between 0.0 and 1.0 indicating the number of
* entries that should be retained by the filter, 0.0 indicates
* none, 1.0 indicates all, defaults to 0.5.
* At least this many entries will be retained, but since arrays
* contain discrete number of entries, an extra element may need
* to be retained to guarantee this.
* For example given 5 entries and ratio 0.5, the result will
* have 3 elements so that 0.5 are retained.
*/
constructor(seed: number, ratio: number = 0.5) {
assert(ratio >= 0.0 && ratio <= 1.0, 'ratio needs to be in the range [0.0, 1.0]');
this.prng = new PRNG(seed);
this.ratio = ratio;
}

/**
* @returns a list of filtered elements, order of the elements is preserved.
* @param input is a list of elements to be filtered
* @param ratio is the number of elements to retain, defaults to this.ratio.
* The calculation for result length uses ceil, so if the input
* is 10 elements long and ratio is set to 0.49, 5 elements will
* be retained, because 10 * 0.49 = 4.9, which ceils to 5.
*/
public filter<T>(input: readonly T[], ratio = this.ratio): T[] {
const target_length = Math.ceil(input.length * ratio);
if (target_length === 0) {
return [];
}

if (target_length === input.length) {
return [...input];
}

return this.shuffle([...input.keys()]) // randomly shuffle list of 0 to input.length - 1 indices
.slice(0, target_length - 1) // Take the first target_length indices
.sort() // Get them back in order
.map(idx => input[idx]); // Copy out the retained indice elements from input
}

/**
* @returns the input, but shuffled.
* Implements Fisher–Yates as described in AoCP.
*/
private shuffle<T>(input: readonly T[]): T[] {
const result = [...input];
let temp: T;
for (let i = result.length - 1; i > 0; i--) {
const j = Math.floor(this.prng.random() * (i + 1));
temp = result[i];
result[i] = result[j];
result[j] = temp;
}
return result;
}
}

0 comments on commit a2b2b9f

Please sign in to comment.