Skip to content

Commit

Permalink
WIP: Make test params readonly
Browse files Browse the repository at this point in the history
This should hopefully categorically prevent bugs like the one fixed in
#3096
  • Loading branch information
kainino0x committed Oct 25, 2023
1 parent e5f120e commit ae43985
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 29 deletions.
4 changes: 2 additions & 2 deletions src/common/framework/params_builder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ export function builderIterateCasesWithSubcases(
*/
export class CaseParamsBuilder<CaseP extends {}>
extends ParamsBuilderBase<CaseP, {}>
implements Iterable<CaseP>, ParamsBuilder {
implements Iterable<Readonly<CaseP>>, ParamsBuilder {
*iterateCasesWithSubcases(caseFilter: TestParams | null): CaseSubcaseIterable<CaseP, {}> {
for (const caseP of this.cases(caseFilter)) {
if (caseFilter) {
Expand All @@ -159,7 +159,7 @@ export class CaseParamsBuilder<CaseP extends {}>
}
}

[Symbol.iterator](): Iterator<CaseP> {
[Symbol.iterator](): Iterator<Readonly<CaseP>> {
return this.cases(null);
}

Expand Down
3 changes: 2 additions & 1 deletion src/common/internal/test_group.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import {
stringifyPublicParamsUniquely,
} from '../internal/query/stringify_params.js';
import { validQueryPart } from '../internal/query/validQueryPart.js';
import { DeepReadonly } from '../util/types.js';
import { assert, unreachable } from '../util/util.js';

import { logToWebsocket } from './websocket_logger.js';
Expand Down Expand Up @@ -216,7 +217,7 @@ interface TestBuilderWithParams<F extends Fixture, CaseP extends {}, SubcaseP ex
* Set the test function.
* @param fn the test function.
*/
fn(fn: TestFn<F, Merged<CaseP, SubcaseP>>): void;
fn(fn: TestFn<F, DeepReadonly<Merged<CaseP, SubcaseP>>>): void;
/**
* Mark the test as unimplemented.
*/
Expand Down
14 changes: 14 additions & 0 deletions src/common/util/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,20 @@ export type TypeEqual<X, Y> = (<T>() => T extends X ? 1 : 2) extends <T>() => T
/* eslint-disable-next-line @typescript-eslint/no-unused-vars */
export function assertTypeTrue<T extends true>() {}

export type DeepReadonly<T> = T extends (infer R)[]
? DeepReadonlyArray<R>
: T extends Function
? T
: T extends object
? DeepReadonlyObject<T>
: T;

type DeepReadonlyArray<T> = ReadonlyArray<DeepReadonly<T>>;

type DeepReadonlyObject<T> = {
readonly [P in keyof T]: DeepReadonly<T[P]>;
};

/**
* Computes the intersection of a set of types, given the union of those types.
*
Expand Down
6 changes: 3 additions & 3 deletions src/webgpu/util/floating_point.ts
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ export type FPKind = 'f32' | 'f16' | 'abstract';
* two elements, the first is the lower bound of the interval and the second is
* the upper bound.
*/
export type IntervalBounds = [number] | [number, number];
export type IntervalBounds = readonly [number] | readonly [number, number];

/** Represents a closed interval of floating point numbers */
export class FPInterval {
Expand Down Expand Up @@ -224,7 +224,7 @@ export type FPVector =
| [FPInterval, FPInterval, FPInterval, FPInterval];

/** Shorthand for an Array of Arrays that contains a column-major matrix */
type Array2D<T> = T[][];
type Array2D<T> = ReadonlyArray<ReadonlyArray<T>>;

/**
* Representation of a matCxR of floating point intervals as an array of arrays
Expand Down Expand Up @@ -808,7 +808,7 @@ export abstract class FPTraits {
`Matrix span is not defined for Matrices of differing dimensions`
);

const result: Array2D<FPInterval> = [...Array(num_cols)].map(_ => [...Array(num_rows)]);
const result: FPInterval[][] = [...Array(num_cols)].map(_ => [...Array(num_rows)]);
for (let i = 0; i < num_cols; i++) {
for (let j = 0; j < num_rows; j++) {
result[i][j] = this.spanIntervals(...ms.map(m => m[i][j]));
Expand Down
49 changes: 26 additions & 23 deletions src/webgpu/util/math.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ import {
reinterpretU16AsF16,
} from './reinterpret.js';

type ROArrayArray<T> = ReadonlyArray<ReadonlyArray<T>>;
type ROArrayArrayArray<T> = ReadonlyArray<ReadonlyArray<ReadonlyArray<T>>>;

/**
* A multiple of 8 guaranteed to be way too large to allocate (just under 8 pebibytes).
* This is a "safe" integer (ULP <= 1.0) very close to MAX_SAFE_INTEGER.
Expand Down Expand Up @@ -1172,7 +1175,7 @@ const kVectorI32Values = {
* vector to get a spread of testing over the entire range. This reduces the
* number of cases being run substantially, but maintains coverage.
*/
export function vectorI32Range(dim: number): number[][] {
export function vectorI32Range(dim: number): ROArrayArray<number> {
assert(dim === 2 || dim === 3 || dim === 4, 'vectorI32Range only accepts dimensions 2, 3, and 4');
return kVectorI32Values[dim];
}
Expand Down Expand Up @@ -1250,7 +1253,7 @@ const kVectorU32Values = {
* vector to get a spread of testing over the entire range. This reduces the
* number of cases being run substantially, but maintains coverage.
*/
export function vectorU32Range(dim: number): number[][] {
export function vectorU32Range(dim: number): ROArrayArray<number> {
assert(dim === 2 || dim === 3 || dim === 4, 'vectorU32Range only accepts dimensions 2, 3, and 4');
return kVectorU32Values[dim];
}
Expand Down Expand Up @@ -1342,7 +1345,7 @@ const kVectorF32Values = {
* vector to get a spread of testing over the entire range. This reduces the
* number of cases being run substantially, but maintains coverage.
*/
export function vectorF32Range(dim: number): number[][] {
export function vectorF32Range(dim: number): ROArrayArray<number> {
assert(dim === 2 || dim === 3 || dim === 4, 'vectorF32Range only accepts dimensions 2, 3, and 4');
return kVectorF32Values[dim];
}
Expand Down Expand Up @@ -1371,7 +1374,7 @@ const kSparseVectorF32Values = {
* All of the interesting floats from sparseF32 are guaranteed to be tested, but
* not in every position.
*/
export function sparseVectorF32Range(dim: number): number[][] {
export function sparseVectorF32Range(dim: number): ROArrayArray<number> {
assert(
dim === 2 || dim === 3 || dim === 4,
'sparseVectorF32Range only accepts dimensions 2, 3, and 4'
Expand Down Expand Up @@ -1490,7 +1493,7 @@ const kSparseMatrixF32Values = {
* All of the interesting floats from sparseF32 are guaranteed to be tested, but
* not in every position.
*/
export function sparseMatrixF32Range(c: number, r: number): number[][][] {
export function sparseMatrixF32Range(c: number, r: number): ROArrayArrayArray<number> {
assert(
c === 2 || c === 3 || c === 4,
'sparseMatrixF32Range only accepts column counts of 2, 3, and 4'
Expand Down Expand Up @@ -1578,7 +1581,7 @@ const kVectorF16Values = {
* vector to get a spread of testing over the entire range. This reduces the
* number of cases being run substantially, but maintains coverage.
*/
export function vectorF16Range(dim: number): number[][] {
export function vectorF16Range(dim: number): ROArrayArray<number> {
assert(dim === 2 || dim === 3 || dim === 4, 'vectorF16Range only accepts dimensions 2, 3, and 4');
return kVectorF16Values[dim];
}
Expand Down Expand Up @@ -1607,7 +1610,7 @@ const kSparseVectorF16Values = {
* All of the interesting floats from sparseF16 are guaranteed to be tested, but
* not in every position.
*/
export function sparseVectorF16Range(dim: number): number[][] {
export function sparseVectorF16Range(dim: number): ROArrayArray<number> {
assert(
dim === 2 || dim === 3 || dim === 4,
'sparseVectorF16Range only accepts dimensions 2, 3, and 4'
Expand Down Expand Up @@ -1726,7 +1729,7 @@ const kSparseMatrixF16Values = {
* All of the interesting floats from sparseF16 are guaranteed to be tested, but
* not in every position.
*/
export function sparseMatrixF16Range(c: number, r: number): number[][][] {
export function sparseMatrixF16Range(c: number, r: number): ROArrayArray<number>[] {
assert(
c === 2 || c === 3 || c === 4,
'sparseMatrixF16Range only accepts column counts of 2, 3, and 4'
Expand Down Expand Up @@ -1814,7 +1817,7 @@ const kVectorF64Values = {
* vector to get a spread of testing over the entire range. This reduces the
* number of cases being run substantially, but maintains coverage.
*/
export function vectorF64Range(dim: number): number[][] {
export function vectorF64Range(dim: number): ROArrayArray<number> {
assert(dim === 2 || dim === 3 || dim === 4, 'vectorF64Range only accepts dimensions 2, 3, and 4');
return kVectorF64Values[dim];
}
Expand Down Expand Up @@ -1843,7 +1846,7 @@ const kSparseVectorF64Values = {
* All the interesting floats from sparseF64 are guaranteed to be tested, but
* not in every position.
*/
export function sparseVectorF64Range(dim: number): number[][] {
export function sparseVectorF64Range(dim: number): ROArrayArray<number> {
assert(
dim === 2 || dim === 3 || dim === 4,
'sparseVectorF64Range only accepts dimensions 2, 3, and 4'
Expand Down Expand Up @@ -1962,7 +1965,7 @@ const kSparseMatrixF64Values = {
* All the interesting floats from sparseF64 are guaranteed to be tested, but
* not in every position.
*/
export function sparseMatrixF64Range(c: number, r: number): number[][][] {
export function sparseMatrixF64Range(c: number, r: number): ROArrayArray<number>[] {
assert(
c === 2 || c === 3 || c === 4,
'sparseMatrixF64Range only accepts column counts of 2, 3, and 4'
Expand Down Expand Up @@ -2074,8 +2077,8 @@ export function lcm(a: number, b: number): number {
* @param intermediate arrays of values representing the partial result of
* cartesianProduct
*/
function cartesianProductImpl<T>(elements: T[], intermediate: T[][]): T[][] {
const result: T[][] = [];
function cartesianProductImpl<T>(elements: T[], intermediate: ROArrayArray<T>): ROArrayArray<T> {
const result: ROArrayArray<T> = [];
elements.forEach((e: T) => {
if (intermediate.length > 0) {
intermediate.forEach((i: T[]) => {
Expand All @@ -2098,8 +2101,8 @@ function cartesianProductImpl<T>(elements: T[], intermediate: T[][]): T[][] {
*
* @param inputs arrays of numbers to calculate cartesian product over
*/
export function cartesianProduct<T>(...inputs: T[][]): T[][] {
let result: T[][] = [];
export function cartesianProduct<T>(...inputs: ROArrayArray<T>): ROArrayArray<T> {
let result: ROArrayArray<T> = [];
inputs.forEach((i: T[]) => {
result = cartesianProductImpl<T>(i, result);
});
Expand All @@ -2122,7 +2125,7 @@ export function cartesianProduct<T>(...inputs: T[][]): T[][] {
*
* @param input the array to get permutations of
*/
export function calculatePermutations<T>(input: T[]): T[][] {
export function calculatePermutations<T>(input: T[]): ROArrayArray<T> {
if (input.length === 0) {
return [];
}
Expand All @@ -2135,7 +2138,7 @@ export function calculatePermutations<T>(input: T[]): T[][] {
return [input, [input[1], input[0]]];
}

const result: T[][] = [];
const result: ROArrayArray<T> = [];
input.forEach((head, idx) => {
const tail = input.slice(0, idx).concat(input.slice(idx + 1));
const permutations = calculatePermutations(tail);
Expand All @@ -2155,7 +2158,7 @@ export function calculatePermutations<T>(input: T[]): T[][] {
*
* @param m Matrix to convert
*/
export function flatten2DArray<T>(m: T[][]): T[] {
export function flatten2DArray<T>(m: ROArrayArray<T>): T[] {
const c = m.length;
const r = m[0].length;
assert(
Expand All @@ -2177,13 +2180,13 @@ export function flatten2DArray<T>(m: T[][]): T[] {
* @param c number of elements in the array containing arrays
* @param r number of elements in the arrays that are contained
*/
export function unflatten2DArray<T>(n: T[], c: number, r: number): T[][] {
export function unflatten2DArray<T>(n: T[], c: number, r: number): ROArrayArray<T> {
assert(
c > 0 && Number.isInteger(c) && r > 0 && Number.isInteger(r),
`columns (${c}) and rows (${r}) need to be positive integers`
);
assert(n.length === c * r, `m.length(${n.length}) should equal c * r (${c * r})`);
const result: T[][] = [...Array(c)].map(_ => [...Array(r)]);
const result: ROArrayArray<T> = [...Array(c)].map(_ => [...Array(r)]);
for (let i = 0; i < c; i++) {
for (let j = 0; j < r; j++) {
result[i][j] = n[j + i * r];
Expand All @@ -2200,14 +2203,14 @@ export function unflatten2DArray<T>(n: T[], c: number, r: number): T[][] {
* @param op operation that converts an element of type T to one of type S
* @returns a matrix with elements of type S that are calculated by applying op element by element
*/
export function map2DArray<T, S>(m: T[][], op: (input: T) => S): S[][] {
export function map2DArray<T, S>(m: ROArrayArray<T>, op: (input: T) => S): ROArrayArray<S> {
const c = m.length;
const r = m[0].length;
assert(
m.every(c => c.length === r),
`Unexpectedly received jagged array to map`
);
const result: S[][] = [...Array(c)].map(_ => [...Array(r)]);
const result: ROArrayArray<S> = [...Array(c)].map(_ => [...Array(r)]);
for (let i = 0; i < c; i++) {
for (let j = 0; j < r; j++) {
result[i][j] = op(m[i][j]);
Expand All @@ -2223,7 +2226,7 @@ export function map2DArray<T, S>(m: T[][], op: (input: T) => S): S[][] {
* @param op operation that performs a test on an element
* @returns a boolean indicating if the test passed for every element
*/
export function every2DArray<T>(m: T[][], op: (input: T) => boolean): boolean {
export function every2DArray<T>(m: ROArrayArray<T>, op: (input: T) => boolean): boolean {
const r = m[0].length;
assert(
m.every(c => c.length === r),
Expand Down

0 comments on commit ae43985

Please sign in to comment.