diff --git a/src/webgpu/shader/execution/builtin/abs.spec.ts b/src/webgpu/shader/execution/builtin/abs.spec.ts index 1e2e77f7f7fc..5c7f77ed03cb 100644 --- a/src/webgpu/shader/execution/builtin/abs.spec.ts +++ b/src/webgpu/shader/execution/builtin/abs.spec.ts @@ -203,8 +203,8 @@ Component-wise when T is a vector. (GLSLstd450Fabs) // Subnormal f32 // TODO(sarahM0): Check if this is needed (or if it has to fail). If yes add other values. - {input: NumberRepr.fromF32Bits(kBit.f32.subnormal.positive.max), expected: [NumberRepr.fromF32Bits(kBit.f32.subnormal.positive.max)] }, - {input: NumberRepr.fromF32Bits(kBit.f32.subnormal.positive.min), expected: [NumberRepr.fromF32Bits(kBit.f32.subnormal.positive.min)] }, + {input: NumberRepr.fromF32Bits(kBit.f32.subnormal.positive.max), expected: [NumberRepr.fromF32Bits(kBit.f32.subnormal.positive.max), NumberRepr.fromF32(0)] }, + {input: NumberRepr.fromF32Bits(kBit.f32.subnormal.positive.min), expected: [NumberRepr.fromF32Bits(kBit.f32.subnormal.positive.min), NumberRepr.fromF32(0)] }, // Infinity f32 {input: NumberRepr.fromF32Bits(kBit.f32.infinity.negative), expected: [NumberRepr.fromF32Bits(kBit.f32.infinity.positive)] }, diff --git a/src/webgpu/shader/execution/builtin/builtin.ts b/src/webgpu/shader/execution/builtin/builtin.ts index 151ad6c23e4b..12678d1afb49 100644 --- a/src/webgpu/shader/execution/builtin/builtin.ts +++ b/src/webgpu/shader/execution/builtin/builtin.ts @@ -25,10 +25,10 @@ export function runShaderTest( struct Data { values : [[stride(16)]] array<${type}, ${cases.length}>; }; - + [[group(0), binding(0)]] var<${storageClass}, ${storageMode}> inputs : Data; [[group(0), binding(1)]] var<${storageClass}, write> outputs : Data; - + [[stage(compute), workgroup_size(1)]] fn main() { for(var i = 0; i < ${cases.length}; i = i + 1) { @@ -90,44 +90,62 @@ export function runShaderTest( t.queue.submit([encoder.finish()]); + // Returns the string representation of number. + const formatNum = (num: number | bigint) => { + switch (num) { + case 0: + case Infinity: + case -Infinity: + return num.toString(); + default: + return num + ' (0x' + num.toString(16) + ')'; + } + }; + const checkExpectation = (outputData: typeof inputData) => { + // The list of expectation failures const errs: string[] = []; + + // For each case... for (let i = 0; i < cases.length; i++) { - const input: string[] = []; - const output: string[] = []; - const expected: string[] = []; - let matched = true; + // String representations of the input, output and expectation values for this case. + const inputValue: string[] = []; + const outputValue: string[] = []; + const expectedValue: string[] = []; + let caseMatched = true; + + // For each element in the case... for (let j = 0; j < arrayLength; j++) { const idx = i * 4 + j; - const caseExpected: string[] = []; const expectedIndex = i + j < cases.length ? i + j : i; - for (const e of cases[expectedIndex].expected) { - caseExpected.push(e.value + ' (0x' + e.value.toString(16) + ')'); - if (outputData[idx] !== e.value) { - matched = false; - } - - input.push(inputData[idx] + ' (0x' + inputData[idx].toString(16) + ')'); - output.push(outputData[idx] + ' (0x' + outputData[idx].toString(16) + ')'); - expected.push(caseExpected.join(' or ')); - } - } - if (matched) { - continue; + inputValue.push(formatNum(inputData[idx])); + + // `cases[expectedIndex].expected` is an array of values that are treated as a pass. + // Do any of these expected values match? + const elementMatched = cases[expectedIndex].expected.some(e => e.value === outputData[idx]); + outputValue.push(formatNum(outputData[idx])); + + const caseExpected = cases[expectedIndex].expected.map(e => formatNum(e.value)); + expectedValue.push(caseExpected.join(' or ')); + + // If none of the expected values matched, then the case has failed. + caseMatched = caseMatched && elementMatched; } - if (arrayLength > 1) { - errs.push( - `${builtin}(${type}(${input.join(', ')}))\n` + - ` returned: ${type}(${output.join(', ')})\n` + - ` expected: ${type}(${expected.join(', ')})` - ); - } else { - errs.push( - `${builtin}(${input.join(', ')})\n` + - ` returned: ${output.join(', ')}\n` + - ` expected: ${expected.join(', ')}` - ); + if (!caseMatched) { + if (arrayLength > 1) { + errs.push( + `${builtin}(${type}(${inputValue.join(', ')}))\n` + + ` returned: ${type}(${outputValue.join(', ')})\n` + + ` expected: ${type}(${expectedValue.join(', ')})` + ); + } else { + errs.push( + `${builtin}(${inputValue.join(', ')})\n` + + ` returned: ${outputValue.join(', ')}\n` + + ` expected: ${expectedValue.join(', ')}` + ); + } } }