From afbdb6d4e466d69025d198932154453e2985d759 Mon Sep 17 00:00:00 2001 From: James Price Date: Thu, 30 Sep 2021 17:02:55 -0400 Subject: [PATCH] Add test for compute shader input builtin values (#766) * Add test for compute shader input builtin values Launch a compute shader with a variety of workgroup and dispatch sizes, and test that the values received for each input builtin are correct for every invocation in the grid. Covers parameters, structures, and a combination of both. * Fix formatting issues * tweak refactoring of expectGPUBufferValuesPassCheck calls, add a bit of documentation * Indent WGSL source string Co-authored-by: Kai Ninomiya --- src/webgpu/gpu_test.ts | 2 + .../shader_io/compute_builtins.spec.ts | 278 ++++++++++++++++++ 2 files changed, 280 insertions(+) create mode 100644 src/webgpu/shader/execution/shader_io/compute_builtins.spec.ts diff --git a/src/webgpu/gpu_test.ts b/src/webgpu/gpu_test.ts index 5daeef621e23..2e9a7aaf59fd 100644 --- a/src/webgpu/gpu_test.ts +++ b/src/webgpu/gpu_test.ts @@ -275,6 +275,8 @@ export class GPUTest extends Fixture { /** * Expect a GPUBuffer's contents to pass the provided check. + * + * A library of checks can be found in {@link webgpu/util/check_contents}. */ expectGPUBufferValuesPassCheck( src: GPUBuffer, diff --git a/src/webgpu/shader/execution/shader_io/compute_builtins.spec.ts b/src/webgpu/shader/execution/shader_io/compute_builtins.spec.ts new file mode 100644 index 000000000000..384940e5db7b --- /dev/null +++ b/src/webgpu/shader/execution/shader_io/compute_builtins.spec.ts @@ -0,0 +1,278 @@ +export const description = `Test compute shader builtin variables`; + +import { makeTestGroup } from '../../../../common/framework/test_group.js'; +import { iterRange } from '../../../../common/util/util.js'; +import { GPUTest } from '../../../gpu_test.js'; + +export const g = makeTestGroup(GPUTest); + +// Test that the values for each input builtin are correct. +g.test('inputs') + .desc(`Test compute shader builtin inputs values`) + .params(u => + u + .combine('method', ['param', 'struct', 'mixed'] as const) + .combineWithParams([ + { + groupSize: { x: 1, y: 1, z: 1 }, + numGroups: { x: 1, y: 1, z: 1 }, + }, + { + groupSize: { x: 8, y: 4, z: 2 }, + numGroups: { x: 1, y: 1, z: 1 }, + }, + { + groupSize: { x: 1, y: 1, z: 1 }, + numGroups: { x: 8, y: 4, z: 2 }, + }, + { + groupSize: { x: 3, y: 7, z: 5 }, + numGroups: { x: 13, y: 9, z: 11 }, + }, + ] as const) + .beginSubcases() + ) + .fn(async t => { + const invocationsPerGroup = t.params.groupSize.x * t.params.groupSize.y * t.params.groupSize.z; + const totalInvocations = + invocationsPerGroup * t.params.numGroups.x * t.params.numGroups.y * t.params.numGroups.z; + + // Generate the structures, parameters, and builtin expressions used in the shader. + let params = ''; + let structures = ''; + let local_id = ''; + let local_index = ''; + let global_id = ''; + let group_id = ''; + let num_groups = ''; + switch (t.params.method) { + case 'param': + params = ` + [[builtin(local_invocation_id)]] local_id : vec3, + [[builtin(local_invocation_index)]] local_index : u32, + [[builtin(global_invocation_id)]] global_id : vec3, + [[builtin(workgroup_id)]] group_id : vec3, + [[builtin(num_workgroups)]] num_groups : vec3, + `; + local_id = 'local_id'; + local_index = 'local_index'; + global_id = 'global_id'; + group_id = 'group_id'; + num_groups = 'num_groups'; + break; + case 'struct': + structures = `struct Inputs { + [[builtin(local_invocation_id)]] local_id : vec3; + [[builtin(local_invocation_index)]] local_index : u32; + [[builtin(global_invocation_id)]] global_id : vec3; + [[builtin(workgroup_id)]] group_id : vec3; + [[builtin(num_workgroups)]] num_groups : vec3; + };`; + params = `inputs : Inputs`; + local_id = 'inputs.local_id'; + local_index = 'inputs.local_index'; + global_id = 'inputs.global_id'; + group_id = 'inputs.group_id'; + num_groups = 'inputs.num_groups'; + break; + case 'mixed': + structures = `struct InputsA { + [[builtin(local_invocation_index)]] local_index : u32; + [[builtin(global_invocation_id)]] global_id : vec3; + }; + struct InputsB { + [[builtin(workgroup_id)]] group_id : vec3; + };`; + params = `[[builtin(local_invocation_id)]] local_id : vec3, + inputsA : InputsA, + inputsB : InputsB, + [[builtin(num_workgroups)]] num_groups : vec3,`; + local_id = 'local_id'; + local_index = 'inputsA.local_index'; + global_id = 'inputsA.global_id'; + group_id = 'inputsB.group_id'; + num_groups = 'num_groups'; + break; + } + + // WGSL shader that stores every builtin value to a buffer, for every invocation in the grid. + const wgsl = ` + [[block]] + struct S { + data : array; + }; + [[block]] + struct V { + data : array>; + }; + [[group(0), binding(0)]] var local_id_out : V; + [[group(0), binding(1)]] var local_index_out : S; + [[group(0), binding(2)]] var global_id_out : V; + [[group(0), binding(3)]] var group_id_out : V; + [[group(0), binding(4)]] var num_groups_out : V; + + ${structures} + + let group_width = ${t.params.groupSize.x}u; + let group_height = ${t.params.groupSize.y}u; + let group_depth = ${t.params.groupSize.z}u; + + [[stage(compute), workgroup_size(group_width, group_height, group_depth)]] + fn main( + ${params} + ) { + let group_index = ((${group_id}.z * ${num_groups}.y) + ${group_id}.y) * ${num_groups}.x + ${group_id}.x; + let global_index = group_index * ${invocationsPerGroup}u + ${local_index}; + local_id_out.data[global_index] = ${local_id}; + local_index_out.data[global_index] = ${local_index}; + global_id_out.data[global_index] = ${global_id}; + group_id_out.data[global_index] = ${group_id}; + num_groups_out.data[global_index] = ${num_groups}; + } + `; + + const pipeline = t.device.createComputePipeline({ + compute: { + module: t.device.createShaderModule({ + code: wgsl, + }), + entryPoint: 'main', + }, + }); + + // Helper to create a `size`-byte buffer with binding number `binding`. + function createBuffer(size: number, binding: number) { + const buffer = t.device.createBuffer({ + size, + usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC, + }); + t.trackForCleanup(buffer); + + bindGroupEntries.push({ + binding, + resource: { + buffer, + }, + }); + + return buffer; + } + + // Create the output buffers. + const bindGroupEntries: GPUBindGroupEntry[] = []; + const localIdBuffer = createBuffer(totalInvocations * 16, 0); + const localIndexBuffer = createBuffer(totalInvocations * 4, 1); + const globalIdBuffer = createBuffer(totalInvocations * 16, 2); + const groupIdBuffer = createBuffer(totalInvocations * 16, 3); + const numGroupsBuffer = createBuffer(totalInvocations * 16, 4); + + const bindGroup = t.device.createBindGroup({ + layout: pipeline.getBindGroupLayout(0), + entries: bindGroupEntries, + }); + + // Run the shader. + const encoder = t.device.createCommandEncoder(); + const pass = encoder.beginComputePass(); + pass.setPipeline(pipeline); + pass.setBindGroup(0, bindGroup); + pass.dispatch(t.params.numGroups.x, t.params.numGroups.y, t.params.numGroups.z); + pass.endPass(); + t.queue.submit([encoder.finish()]); + + type vec3 = { x: number; y: number; z: number }; + + // Helper to check that the vec3 value at each index of the provided `output` buffer + // matches the expected value for that invocation, as generated by the `getBuiltinValue` + // function. The `name` parameter is the builtin name, used for error messages. + const checkEachIndex = ( + output: Uint32Array, + name: string, + getBuiltinValue: (groupId: vec3, localId: vec3) => vec3 + ) => { + // Loop over workgroups. + for (let gz = 0; gz < t.params.numGroups.z; gz++) { + for (let gy = 0; gy < t.params.numGroups.y; gy++) { + for (let gx = 0; gx < t.params.numGroups.x; gx++) { + // Loop over invocations within a group. + for (let lz = 0; lz < t.params.groupSize.z; lz++) { + for (let ly = 0; ly < t.params.groupSize.y; ly++) { + for (let lx = 0; lx < t.params.groupSize.x; lx++) { + const groupIndex = (gz * t.params.numGroups.y + gy) * t.params.numGroups.x + gx; + const localIndex = (lz * t.params.groupSize.y + ly) * t.params.groupSize.x + lx; + const globalIndex = groupIndex * invocationsPerGroup + localIndex; + const expected = getBuiltinValue( + { x: gx, y: gy, z: gz }, + { x: lx, y: ly, z: lz } + ); + if (output[globalIndex * 4 + 0] !== expected.x) { + return new Error( + `${name}.x failed at group(${gx},${gy},${gz}) local(${lx},${ly},${lz}))\n` + + ` expected: ${expected.x}\n` + + ` got: ${output[globalIndex * 4 + 0]}` + ); + } + if (output[globalIndex * 4 + 1] !== expected.y) { + return new Error( + `${name}.y failed at group(${gx},${gy},${gz}) local(${lx},${ly},${lz}))\n` + + ` expected: ${expected.y}\n` + + ` got: ${output[globalIndex * 4 + 1]}` + ); + } + if (output[globalIndex * 4 + 2] !== expected.z) { + return new Error( + `${name}.z failed at group(${gx},${gy},${gz}) local(${lx},${ly},${lz}))\n` + + ` expected: ${expected.z}\n` + + ` got: ${output[globalIndex * 4 + 2]}` + ); + } + } + } + } + } + } + } + return undefined; + }; + + // Check [[builtin(local_invocation_index)]] values. + t.expectGPUBufferValuesEqual( + localIndexBuffer, + new Uint32Array([...iterRange(totalInvocations, x => x % invocationsPerGroup)]) + ); + + // Check [[builtin(local_invocation_id)]] values. + t.expectGPUBufferValuesPassCheck( + localIdBuffer, + outputData => checkEachIndex(outputData, 'local_invocation_id', (_, localId) => localId), + { type: Uint32Array, typedLength: totalInvocations * 4 } + ); + + // Check [[builtin(global_invocation_id)]] values. + const getGlobalId = (groupId: vec3, localId: vec3) => { + return { + x: groupId.x * t.params.groupSize.x + localId.x, + y: groupId.y * t.params.groupSize.y + localId.y, + z: groupId.z * t.params.groupSize.z + localId.z, + }; + }; + t.expectGPUBufferValuesPassCheck( + globalIdBuffer, + outputData => checkEachIndex(outputData, 'global_invocation_id', getGlobalId), + { type: Uint32Array, typedLength: totalInvocations * 4 } + ); + + // Check [[builtin(workgroup_id)]] values. + t.expectGPUBufferValuesPassCheck( + groupIdBuffer, + outputData => checkEachIndex(outputData, 'workgroup_id', (groupId, _) => groupId), + { type: Uint32Array, typedLength: totalInvocations * 4 } + ); + + // Check [[builtin(num_workgroups)]] values. + t.expectGPUBufferValuesPassCheck( + numGroupsBuffer, + outputData => checkEachIndex(outputData, 'num_workgroups', () => t.params.numGroups), + { type: Uint32Array, typedLength: totalInvocations * 4 } + ); + });