Skip to content

Commit

Permalink
webgpu/shader/exection/builtin: Fix checkExpectation() (#781)
Browse files Browse the repository at this point in the history
The `expected` field of a case is an array of possible allowed values. The logic was checking that the output value matched *all* the values in the array, which would never work.

Fix the subnormal cases of the `abs()` tests. These are allowed to return 0.

Also simplify the formatted string for 0, Infinity, -Infinity.
  • Loading branch information
ben-clayton authored Oct 18, 2021
1 parent 6e1b3a2 commit 7f6e233
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 34 deletions.
4 changes: 2 additions & 2 deletions src/webgpu/shader/execution/builtin/abs.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -203,8 +203,8 @@ Component-wise when T is a vector. (GLSLstd450Fabs)

// Subnormal f32
// TODO(sarahM0): Check if this is needed (or if it has to fail). If yes add other values.
{input: NumberRepr.fromF32Bits(kBit.f32.subnormal.positive.max), expected: [NumberRepr.fromF32Bits(kBit.f32.subnormal.positive.max)] },
{input: NumberRepr.fromF32Bits(kBit.f32.subnormal.positive.min), expected: [NumberRepr.fromF32Bits(kBit.f32.subnormal.positive.min)] },
{input: NumberRepr.fromF32Bits(kBit.f32.subnormal.positive.max), expected: [NumberRepr.fromF32Bits(kBit.f32.subnormal.positive.max), NumberRepr.fromF32(0)] },
{input: NumberRepr.fromF32Bits(kBit.f32.subnormal.positive.min), expected: [NumberRepr.fromF32Bits(kBit.f32.subnormal.positive.min), NumberRepr.fromF32(0)] },

// Infinity f32
{input: NumberRepr.fromF32Bits(kBit.f32.infinity.negative), expected: [NumberRepr.fromF32Bits(kBit.f32.infinity.positive)] },
Expand Down
82 changes: 50 additions & 32 deletions src/webgpu/shader/execution/builtin/builtin.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ export function runShaderTest<F extends NumberType>(
struct Data {
values : [[stride(16)]] array<${type}, ${cases.length}>;
};
[[group(0), binding(0)]] var<${storageClass}, ${storageMode}> inputs : Data;
[[group(0), binding(1)]] var<${storageClass}, write> outputs : Data;
[[stage(compute), workgroup_size(1)]]
fn main() {
for(var i = 0; i < ${cases.length}; i = i + 1) {
Expand Down Expand Up @@ -90,44 +90,62 @@ export function runShaderTest<F extends NumberType>(

t.queue.submit([encoder.finish()]);

// Returns the string representation of number.
const formatNum = (num: number | bigint) => {
switch (num) {
case 0:
case Infinity:
case -Infinity:
return num.toString();
default:
return num + ' (0x' + num.toString(16) + ')';
}
};

const checkExpectation = (outputData: typeof inputData) => {
// The list of expectation failures
const errs: string[] = [];

// For each case...
for (let i = 0; i < cases.length; i++) {
const input: string[] = [];
const output: string[] = [];
const expected: string[] = [];
let matched = true;
// String representations of the input, output and expectation values for this case.
const inputValue: string[] = [];
const outputValue: string[] = [];
const expectedValue: string[] = [];
let caseMatched = true;

// For each element in the case...
for (let j = 0; j < arrayLength; j++) {
const idx = i * 4 + j;
const caseExpected: string[] = [];
const expectedIndex = i + j < cases.length ? i + j : i;
for (const e of cases[expectedIndex].expected) {
caseExpected.push(e.value + ' (0x' + e.value.toString(16) + ')');
if (outputData[idx] !== e.value) {
matched = false;
}

input.push(inputData[idx] + ' (0x' + inputData[idx].toString(16) + ')');
output.push(outputData[idx] + ' (0x' + outputData[idx].toString(16) + ')');
expected.push(caseExpected.join(' or '));
}
}
if (matched) {
continue;
inputValue.push(formatNum(inputData[idx]));

// `cases[expectedIndex].expected` is an array of values that are treated as a pass.
// Do any of these expected values match?
const elementMatched = cases[expectedIndex].expected.some(e => e.value === outputData[idx]);
outputValue.push(formatNum(outputData[idx]));

const caseExpected = cases[expectedIndex].expected.map(e => formatNum(e.value));
expectedValue.push(caseExpected.join(' or '));

// If none of the expected values matched, then the case has failed.
caseMatched = caseMatched && elementMatched;
}

if (arrayLength > 1) {
errs.push(
`${builtin}(${type}(${input.join(', ')}))\n` +
` returned: ${type}(${output.join(', ')})\n` +
` expected: ${type}(${expected.join(', ')})`
);
} else {
errs.push(
`${builtin}(${input.join(', ')})\n` +
` returned: ${output.join(', ')}\n` +
` expected: ${expected.join(', ')}`
);
if (!caseMatched) {
if (arrayLength > 1) {
errs.push(
`${builtin}(${type}(${inputValue.join(', ')}))\n` +
` returned: ${type}(${outputValue.join(', ')})\n` +
` expected: ${type}(${expectedValue.join(', ')})`
);
} else {
errs.push(
`${builtin}(${inputValue.join(', ')})\n` +
` returned: ${outputValue.join(', ')}\n` +
` expected: ${expectedValue.join(', ')}`
);
}
}
}

Expand Down

0 comments on commit 7f6e233

Please sign in to comment.