// Dependencies of this class, imported at the top of
// src/webgpu/util/stochastic_filter.ts:
//   `assert` from '../../common/util/util.js'
//   `PRNG`   from './prng.js'
//
// Usage in the accompanying case-cache builder
// (af_matrix_matrix_multiplication.cache.ts): a single module-level
// `new StochasticFilter(0)` wraps each generated case list in `sf.filter(...)`.

/**
 * Used to remove entries from a list in a deterministic, but stochastic manner.
 * Uses the pseudo-random number generator from prng.ts for selection.
 *
 * All randomness is drawn from one PRNG owned by the instance, so the result of
 * a call depends on the seed and on the calls to `filter` made before it.
 */
export class StochasticFilter {
  private readonly prng: PRNG;
  private readonly ratio: number;

  /**
   * Constructor
   * @param seed init value to pass down to PRNG
   * @param ratio is a value between 0.0 and 1.0 indicating the fraction of
   *              entries that should be retained by the filter, 0.0 indicates
   *              none, 1.0 indicates all, defaults to 0.5.
   *              At least this fraction of entries will be retained, but since
   *              arrays contain a discrete number of entries, an extra element
   *              may need to be retained to guarantee this.
   *              For example given 5 entries and ratio 0.5, the result will
   *              have 3 elements so that at least half are retained.
   */
  constructor(seed: number, ratio: number = 0.5) {
    assert(ratio >= 0.0 && ratio <= 1.0, 'ratio needs to be in the range [0.0, 1.0]');
    this.prng = new PRNG(seed);
    this.ratio = ratio;
  }

  /**
   * @returns a new list of the retained elements; the relative order of the
   *          elements is preserved and `input` is not modified.
   * @param input is a list of elements to be filtered
   * @param ratio is the fraction of elements to retain, in the range
   *              [0.0, 1.0], defaults to this.ratio.
   *              The calculation for result length uses ceil, so if the input
   *              is 10 elements long and ratio is set to 0.49, 5 elements will
   *              be retained, because 10 * 0.49 = 4.9, which ceils to 5.
   */
  public filter<T>(input: readonly T[], ratio: number = this.ratio): T[] {
    // Apply the same range contract as the constructor to per-call overrides.
    assert(ratio >= 0.0 && ratio <= 1.0, 'ratio needs to be in the range [0.0, 1.0]');

    const target_length = Math.ceil(input.length * ratio);

    // The two trivial outcomes need no random selection, so no PRNG values are
    // consumed for them.
    if (target_length === 0) {
      return [];
    }
    if (target_length === input.length) {
      return [...input];
    }

    return this.shuffle([...input.keys()]) // randomly shuffle list of 0 to input.length - 1 indices
      .slice(0, target_length) // Take the first target_length indices (slice end is exclusive)
      .sort((a, b) => a - b) // Numeric sort; the default sort compares as strings ("10" < "2")
      .map(idx => input[idx]); // Copy out the retained elements from input
  }

  /**
   * @returns a shuffled copy of the input; the input itself is left untouched.
   * Implements Fisher–Yates as described in AoCP.
   */
  private shuffle<T>(input: readonly T[]): T[] {
    const result = [...input];
    // Walk from the back, swapping each slot with a uniformly chosen slot at or
    // before it (j in [0, i]).
    for (let i = result.length - 1; i > 0; i--) {
      const j = Math.floor(this.prng.random() * (i + 1));
      const temp = result[i];
      result[i] = result[j];
      result[j] = temp;
    }
    return result;
  }
}