From ab90807fcdb16e809f91bb8d468870c03df82dc1 Mon Sep 17 00:00:00 2001 From: Gregg Tavares Date: Tue, 16 Jan 2024 14:14:42 -0800 Subject: [PATCH] Compat: Refactor compute_builtins test This test used 5 storage buffers which is more than the 4 min max of compat mode. Refactored to use 1 buffer. --- .../shader_io/compute_builtins.spec.ts | 175 +++++++----------- 1 file changed, 68 insertions(+), 107 deletions(-) diff --git a/src/webgpu/shader/execution/shader_io/compute_builtins.spec.ts b/src/webgpu/shader/execution/shader_io/compute_builtins.spec.ts index fcf3159c642c..a40b42633283 100644 --- a/src/webgpu/shader/execution/shader_io/compute_builtins.spec.ts +++ b/src/webgpu/shader/execution/shader_io/compute_builtins.spec.ts @@ -1,7 +1,6 @@ export const description = `Test compute shader builtin variables`; import { makeTestGroup } from '../../../../common/framework/test_group.js'; -import { iterRange } from '../../../../common/util/util.js'; import { GPUTest } from '../../../gpu_test.js'; export const g = makeTestGroup(GPUTest); @@ -98,17 +97,14 @@ g.test('inputs') // WGSL shader that stores every builtin value to a buffer, for every invocation in the grid. const wgsl = ` - struct S { - data : array + struct Outputs { + local_id: vec3u, + local_index: u32, + global_id: vec3u, + group_id: vec3u, + num_groups: vec3u, }; - struct V { - data : array> - }; - @group(0) @binding(0) var local_id_out : V; - @group(0) @binding(1) var local_index_out : S; - @group(0) @binding(2) var global_id_out : V; - @group(0) @binding(3) var group_id_out : V; - @group(0) @binding(4) var num_groups_out : V; + @group(0) @binding(0) var outputs : array; ${structures} @@ -122,11 +118,13 @@ g.test('inputs') ) { let group_index = ((${group_id}.z * ${num_groups}.y) + ${group_id}.y) * ${num_groups}.x + ${group_id}.x; let global_index = group_index * ${invocationsPerGroup}u + ${local_index}; - local_id_out.data[global_index] = ${local_id}; - local_index_out.data[global_index] = ${local_index}; - global_id_out.data[global_index] = ${global_id}; - group_id_out.data[global_index] = ${group_id}; - num_groups_out.data[global_index] = ${num_groups}; + var o: Outputs; + o.local_id = ${local_id}; + o.local_index = ${local_index}; + o.global_id = ${global_id}; + o.group_id = ${group_id}; + o.num_groups = ${num_groups}; + outputs[global_index] = o; } `; @@ -140,35 +138,24 @@ g.test('inputs') }, }); - // Helper to create a `size`-byte buffer with binding number `binding`. - function createBuffer(size: number, binding: number) { - const buffer = t.device.createBuffer({ - size, - usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC, - }); - t.trackForCleanup(buffer); - - bindGroupEntries.push({ - binding, - resource: { - buffer, - }, - }); - - return buffer; - } + // Offsets are in u32 size units + const kLocalIdOffset = 0; + const kLocalIndexOffset = 3; + const kGlobalIdOffset = 4; + const kGroupIdOffset = 8; + const kNumGroupsOffset = 12; + const kOutputElementSize = 16; // Create the output buffers. - const bindGroupEntries: GPUBindGroupEntry[] = []; - const localIdBuffer = createBuffer(totalInvocations * 16, 0); - const localIndexBuffer = createBuffer(totalInvocations * 4, 1); - const globalIdBuffer = createBuffer(totalInvocations * 16, 2); - const groupIdBuffer = createBuffer(totalInvocations * 16, 3); - const numGroupsBuffer = createBuffer(totalInvocations * 16, 4); + const outputBuffer = t.device.createBuffer({ + size: totalInvocations * kOutputElementSize * 4, + usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC, + }); + t.trackForCleanup(outputBuffer); const bindGroup = t.device.createBindGroup({ layout: pipeline.getBindGroupLayout(0), - entries: bindGroupEntries, + entries: [{ binding: 0, resource: { buffer: outputBuffer } }], }); // Run the shader. @@ -204,11 +191,7 @@ g.test('inputs') // Helper to check that the vec3 value at each index of the provided `output` buffer // matches the expected value for that invocation, as generated by the `getBuiltinValue` // function. The `name` parameter is the builtin name, used for error messages. - const checkEachIndex = ( - output: Uint32Array, - name: string, - getBuiltinValue: (groupId: vec3, localId: vec3) => vec3 - ) => { + const checkEachIndex = (output: Uint32Array) => { // Loop over workgroups. for (let gz = 0; gz < t.params.numGroups.z; gz++) { for (let gy = 0; gy < t.params.numGroups.y; gy++) { @@ -220,30 +203,44 @@ g.test('inputs') const groupIndex = (gz * t.params.numGroups.y + gy) * t.params.numGroups.x + gx; const localIndex = (lz * t.params.groupSize.y + ly) * t.params.groupSize.x + lx; const globalIndex = groupIndex * invocationsPerGroup + localIndex; - const expected = getBuiltinValue( - { x: gx, y: gy, z: gz }, - { x: lx, y: ly, z: lz } - ); - if (output[globalIndex * 4 + 0] !== expected.x) { - return new Error( - `${name}.x failed at group(${gx},${gy},${gz}) local(${lx},${ly},${lz}))\n` + - ` expected: ${expected.x}\n` + - ` got: ${output[globalIndex * 4 + 0]}` - ); - } - if (output[globalIndex * 4 + 1] !== expected.y) { - return new Error( - `${name}.y failed at group(${gx},${gy},${gz}) local(${lx},${ly},${lz}))\n` + - ` expected: ${expected.y}\n` + - ` got: ${output[globalIndex * 4 + 1]}` + const globalOffset = globalIndex * kOutputElementSize; + + const expectEqual = (name: string, expected: number, actual: number) => { + if (actual !== expected) { + return new Error( + `${name} failed at group(${gx},${gy},${gz}) local(${lx},${ly},${lz}))\n` + + ` expected: ${expected}\n` + + ` got: ${actual}` + ); + } + return undefined; + }; + + const checkVec3Value = (name: string, fieldOffset: number, expected: vec3) => { + const offset = globalOffset + fieldOffset; + return ( + expectEqual(`${name}.x`, expected.x, output[offset + 0]) || + expectEqual(`${name}.y`, expected.y, output[offset + 1]) || + expectEqual(`${name}.z`, expected.z, output[offset + 2]) ); - } - if (output[globalIndex * 4 + 2] !== expected.z) { - return new Error( - `${name}.z failed at group(${gx},${gy},${gz}) local(${lx},${ly},${lz}))\n` + - ` expected: ${expected.z}\n` + - ` got: ${output[globalIndex * 4 + 2]}` + }; + + const error = + checkVec3Value('local_id', kLocalIdOffset, { x: lx, y: ly, z: lz }) || + checkVec3Value('global_id', kGlobalIdOffset, { + x: gx * t.params.groupSize.x + lx, + y: gy * t.params.groupSize.y + ly, + z: gz * t.params.groupSize.z + lz, + }) || + checkVec3Value('group_id', kGroupIdOffset, { x: gx, y: gy, z: gz }) || + checkVec3Value('num_groups', kNumGroupsOffset, t.params.numGroups) || + expectEqual( + 'local_index', + localIndex, + output[globalOffset + kLocalIndexOffset] ); + if (error) { + return error; } } } @@ -254,44 +251,8 @@ g.test('inputs') return undefined; }; - // Check @builtin(local_invocation_index) values. - t.expectGPUBufferValuesEqual( - localIndexBuffer, - new Uint32Array([...iterRange(totalInvocations, x => x % invocationsPerGroup)]) - ); - - // Check @builtin(local_invocation_id) values. - t.expectGPUBufferValuesPassCheck( - localIdBuffer, - outputData => checkEachIndex(outputData, 'local_invocation_id', (_, localId) => localId), - { type: Uint32Array, typedLength: totalInvocations * 4 } - ); - - // Check @builtin(global_invocation_id) values. - const getGlobalId = (groupId: vec3, localId: vec3) => { - return { - x: groupId.x * t.params.groupSize.x + localId.x, - y: groupId.y * t.params.groupSize.y + localId.y, - z: groupId.z * t.params.groupSize.z + localId.z, - }; - }; - t.expectGPUBufferValuesPassCheck( - globalIdBuffer, - outputData => checkEachIndex(outputData, 'global_invocation_id', getGlobalId), - { type: Uint32Array, typedLength: totalInvocations * 4 } - ); - - // Check @builtin(workgroup_id) values. - t.expectGPUBufferValuesPassCheck( - groupIdBuffer, - outputData => checkEachIndex(outputData, 'workgroup_id', (groupId, _) => groupId), - { type: Uint32Array, typedLength: totalInvocations * 4 } - ); - - // Check @builtin(num_workgroups) values. - t.expectGPUBufferValuesPassCheck( - numGroupsBuffer, - outputData => checkEachIndex(outputData, 'num_workgroups', () => t.params.numGroups), - { type: Uint32Array, typedLength: totalInvocations * 4 } - ); + t.expectGPUBufferValuesPassCheck(outputBuffer, outputData => checkEachIndex(outputData), { + type: Uint32Array, + typedLength: outputBuffer.size / 4, + }); });