Skip to content

Commit

Permalink
Add test for compute shader input builtin values (#766)
Browse files Browse the repository at this point in the history
* Add test for compute shader input builtin values

Launch a compute shader with a variety of workgroup and dispatch
sizes, and test that the values received for each input builtin are
correct for every invocation in the grid.

Covers parameters, structures, and a combination of both.

* Fix formatting issues

* tweak refactoring of expectGPUBufferValuesPassCheck calls, add a bit of documentation

* Indent WGSL source string

Co-authored-by: Kai Ninomiya <[email protected]>
  • Loading branch information
jrprice and kainino0x authored Sep 30, 2021
1 parent e100dde commit afbdb6d
Show file tree
Hide file tree
Showing 2 changed files with 280 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/webgpu/gpu_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,8 @@ export class GPUTest extends Fixture {

/**
* Expect a GPUBuffer's contents to pass the provided check.
*
* A library of checks can be found in {@link webgpu/util/check_contents}.
*/
expectGPUBufferValuesPassCheck<T extends TypedArrayBufferView>(
src: GPUBuffer,
Expand Down
278 changes: 278 additions & 0 deletions src/webgpu/shader/execution/shader_io/compute_builtins.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,278 @@
export const description = `Test compute shader builtin variables`;

import { makeTestGroup } from '../../../../common/framework/test_group.js';
import { iterRange } from '../../../../common/util/util.js';
import { GPUTest } from '../../../gpu_test.js';

export const g = makeTestGroup(GPUTest);

// Test that the values for each input builtin are correct.
g.test('inputs')
.desc(`Test compute shader builtin inputs values`)
.params(u =>
u
.combine('method', ['param', 'struct', 'mixed'] as const)
.combineWithParams([
{
groupSize: { x: 1, y: 1, z: 1 },
numGroups: { x: 1, y: 1, z: 1 },
},
{
groupSize: { x: 8, y: 4, z: 2 },
numGroups: { x: 1, y: 1, z: 1 },
},
{
groupSize: { x: 1, y: 1, z: 1 },
numGroups: { x: 8, y: 4, z: 2 },
},
{
groupSize: { x: 3, y: 7, z: 5 },
numGroups: { x: 13, y: 9, z: 11 },
},
] as const)
.beginSubcases()
)
.fn(async t => {
const invocationsPerGroup = t.params.groupSize.x * t.params.groupSize.y * t.params.groupSize.z;
const totalInvocations =
invocationsPerGroup * t.params.numGroups.x * t.params.numGroups.y * t.params.numGroups.z;

// Generate the structures, parameters, and builtin expressions used in the shader.
let params = '';
let structures = '';
let local_id = '';
let local_index = '';
let global_id = '';
let group_id = '';
let num_groups = '';
switch (t.params.method) {
case 'param':
params = `
[[builtin(local_invocation_id)]] local_id : vec3<u32>,
[[builtin(local_invocation_index)]] local_index : u32,
[[builtin(global_invocation_id)]] global_id : vec3<u32>,
[[builtin(workgroup_id)]] group_id : vec3<u32>,
[[builtin(num_workgroups)]] num_groups : vec3<u32>,
`;
local_id = 'local_id';
local_index = 'local_index';
global_id = 'global_id';
group_id = 'group_id';
num_groups = 'num_groups';
break;
case 'struct':
structures = `struct Inputs {
[[builtin(local_invocation_id)]] local_id : vec3<u32>;
[[builtin(local_invocation_index)]] local_index : u32;
[[builtin(global_invocation_id)]] global_id : vec3<u32>;
[[builtin(workgroup_id)]] group_id : vec3<u32>;
[[builtin(num_workgroups)]] num_groups : vec3<u32>;
};`;
params = `inputs : Inputs`;
local_id = 'inputs.local_id';
local_index = 'inputs.local_index';
global_id = 'inputs.global_id';
group_id = 'inputs.group_id';
num_groups = 'inputs.num_groups';
break;
case 'mixed':
structures = `struct InputsA {
[[builtin(local_invocation_index)]] local_index : u32;
[[builtin(global_invocation_id)]] global_id : vec3<u32>;
};
struct InputsB {
[[builtin(workgroup_id)]] group_id : vec3<u32>;
};`;
params = `[[builtin(local_invocation_id)]] local_id : vec3<u32>,
inputsA : InputsA,
inputsB : InputsB,
[[builtin(num_workgroups)]] num_groups : vec3<u32>,`;
local_id = 'local_id';
local_index = 'inputsA.local_index';
global_id = 'inputsA.global_id';
group_id = 'inputsB.group_id';
num_groups = 'num_groups';
break;
}

// WGSL shader that stores every builtin value to a buffer, for every invocation in the grid.
const wgsl = `
[[block]]
struct S {
data : array<u32>;
};
[[block]]
struct V {
data : array<vec3<u32>>;
};
[[group(0), binding(0)]] var<storage, write> local_id_out : V;
[[group(0), binding(1)]] var<storage, write> local_index_out : S;
[[group(0), binding(2)]] var<storage, write> global_id_out : V;
[[group(0), binding(3)]] var<storage, write> group_id_out : V;
[[group(0), binding(4)]] var<storage, write> num_groups_out : V;
${structures}
let group_width = ${t.params.groupSize.x}u;
let group_height = ${t.params.groupSize.y}u;
let group_depth = ${t.params.groupSize.z}u;
[[stage(compute), workgroup_size(group_width, group_height, group_depth)]]
fn main(
${params}
) {
let group_index = ((${group_id}.z * ${num_groups}.y) + ${group_id}.y) * ${num_groups}.x + ${group_id}.x;
let global_index = group_index * ${invocationsPerGroup}u + ${local_index};
local_id_out.data[global_index] = ${local_id};
local_index_out.data[global_index] = ${local_index};
global_id_out.data[global_index] = ${global_id};
group_id_out.data[global_index] = ${group_id};
num_groups_out.data[global_index] = ${num_groups};
}
`;

const pipeline = t.device.createComputePipeline({
compute: {
module: t.device.createShaderModule({
code: wgsl,
}),
entryPoint: 'main',
},
});

// Helper to create a `size`-byte buffer with binding number `binding`.
function createBuffer(size: number, binding: number) {
const buffer = t.device.createBuffer({
size,
usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC,
});
t.trackForCleanup(buffer);

bindGroupEntries.push({
binding,
resource: {
buffer,
},
});

return buffer;
}

// Create the output buffers.
const bindGroupEntries: GPUBindGroupEntry[] = [];
const localIdBuffer = createBuffer(totalInvocations * 16, 0);
const localIndexBuffer = createBuffer(totalInvocations * 4, 1);
const globalIdBuffer = createBuffer(totalInvocations * 16, 2);
const groupIdBuffer = createBuffer(totalInvocations * 16, 3);
const numGroupsBuffer = createBuffer(totalInvocations * 16, 4);

const bindGroup = t.device.createBindGroup({
layout: pipeline.getBindGroupLayout(0),
entries: bindGroupEntries,
});

// Run the shader.
const encoder = t.device.createCommandEncoder();
const pass = encoder.beginComputePass();
pass.setPipeline(pipeline);
pass.setBindGroup(0, bindGroup);
pass.dispatch(t.params.numGroups.x, t.params.numGroups.y, t.params.numGroups.z);
pass.endPass();
t.queue.submit([encoder.finish()]);

type vec3 = { x: number; y: number; z: number };

// Helper to check that the vec3<u32> value at each index of the provided `output` buffer
// matches the expected value for that invocation, as generated by the `getBuiltinValue`
// function. The `name` parameter is the builtin name, used for error messages.
const checkEachIndex = (
output: Uint32Array,
name: string,
getBuiltinValue: (groupId: vec3, localId: vec3) => vec3
) => {
// Loop over workgroups.
for (let gz = 0; gz < t.params.numGroups.z; gz++) {
for (let gy = 0; gy < t.params.numGroups.y; gy++) {
for (let gx = 0; gx < t.params.numGroups.x; gx++) {
// Loop over invocations within a group.
for (let lz = 0; lz < t.params.groupSize.z; lz++) {
for (let ly = 0; ly < t.params.groupSize.y; ly++) {
for (let lx = 0; lx < t.params.groupSize.x; lx++) {
const groupIndex = (gz * t.params.numGroups.y + gy) * t.params.numGroups.x + gx;
const localIndex = (lz * t.params.groupSize.y + ly) * t.params.groupSize.x + lx;
const globalIndex = groupIndex * invocationsPerGroup + localIndex;
const expected = getBuiltinValue(
{ x: gx, y: gy, z: gz },
{ x: lx, y: ly, z: lz }
);
if (output[globalIndex * 4 + 0] !== expected.x) {
return new Error(
`${name}.x failed at group(${gx},${gy},${gz}) local(${lx},${ly},${lz}))\n` +
` expected: ${expected.x}\n` +
` got: ${output[globalIndex * 4 + 0]}`
);
}
if (output[globalIndex * 4 + 1] !== expected.y) {
return new Error(
`${name}.y failed at group(${gx},${gy},${gz}) local(${lx},${ly},${lz}))\n` +
` expected: ${expected.y}\n` +
` got: ${output[globalIndex * 4 + 1]}`
);
}
if (output[globalIndex * 4 + 2] !== expected.z) {
return new Error(
`${name}.z failed at group(${gx},${gy},${gz}) local(${lx},${ly},${lz}))\n` +
` expected: ${expected.z}\n` +
` got: ${output[globalIndex * 4 + 2]}`
);
}
}
}
}
}
}
}
return undefined;
};

// Check [[builtin(local_invocation_index)]] values.
t.expectGPUBufferValuesEqual(
localIndexBuffer,
new Uint32Array([...iterRange(totalInvocations, x => x % invocationsPerGroup)])
);

// Check [[builtin(local_invocation_id)]] values.
t.expectGPUBufferValuesPassCheck(
localIdBuffer,
outputData => checkEachIndex(outputData, 'local_invocation_id', (_, localId) => localId),
{ type: Uint32Array, typedLength: totalInvocations * 4 }
);

// Check [[builtin(global_invocation_id)]] values.
const getGlobalId = (groupId: vec3, localId: vec3) => {
return {
x: groupId.x * t.params.groupSize.x + localId.x,
y: groupId.y * t.params.groupSize.y + localId.y,
z: groupId.z * t.params.groupSize.z + localId.z,
};
};
t.expectGPUBufferValuesPassCheck(
globalIdBuffer,
outputData => checkEachIndex(outputData, 'global_invocation_id', getGlobalId),
{ type: Uint32Array, typedLength: totalInvocations * 4 }
);

// Check [[builtin(workgroup_id)]] values.
t.expectGPUBufferValuesPassCheck(
groupIdBuffer,
outputData => checkEachIndex(outputData, 'workgroup_id', (groupId, _) => groupId),
{ type: Uint32Array, typedLength: totalInvocations * 4 }
);

// Check [[builtin(num_workgroups)]] values.
t.expectGPUBufferValuesPassCheck(
numGroupsBuffer,
outputData => checkEachIndex(outputData, 'num_workgroups', () => t.params.numGroups),
{ type: Uint32Array, typedLength: totalInvocations * 4 }
);
});

0 comments on commit afbdb6d

Please sign in to comment.