Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compat: Refactor compute_builtins test #3277

Merged
merged 1 commit into from
Jan 17, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 68 additions & 107 deletions src/webgpu/shader/execution/shader_io/compute_builtins.spec.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
export const description = `Test compute shader builtin variables`;

import { makeTestGroup } from '../../../../common/framework/test_group.js';
import { iterRange } from '../../../../common/util/util.js';
import { GPUTest } from '../../../gpu_test.js';

export const g = makeTestGroup(GPUTest);
Expand Down Expand Up @@ -98,17 +97,14 @@ g.test('inputs')

// WGSL shader that stores every builtin value to a buffer, for every invocation in the grid.
const wgsl = `
struct S {
data : array<u32>
struct Outputs {
local_id: vec3u,
local_index: u32,
global_id: vec3u,
group_id: vec3u,
num_groups: vec3u,
};
struct V {
data : array<vec3<u32>>
};
@group(0) @binding(0) var<storage, read_write> local_id_out : V;
@group(0) @binding(1) var<storage, read_write> local_index_out : S;
@group(0) @binding(2) var<storage, read_write> global_id_out : V;
@group(0) @binding(3) var<storage, read_write> group_id_out : V;
@group(0) @binding(4) var<storage, read_write> num_groups_out : V;
@group(0) @binding(0) var<storage, read_write> outputs : array<Outputs>;

${structures}

Expand All @@ -122,11 +118,13 @@ g.test('inputs')
) {
let group_index = ((${group_id}.z * ${num_groups}.y) + ${group_id}.y) * ${num_groups}.x + ${group_id}.x;
let global_index = group_index * ${invocationsPerGroup}u + ${local_index};
local_id_out.data[global_index] = ${local_id};
local_index_out.data[global_index] = ${local_index};
global_id_out.data[global_index] = ${global_id};
group_id_out.data[global_index] = ${group_id};
num_groups_out.data[global_index] = ${num_groups};
var o: Outputs;
o.local_id = ${local_id};
o.local_index = ${local_index};
o.global_id = ${global_id};
o.group_id = ${group_id};
o.num_groups = ${num_groups};
outputs[global_index] = o;
}
`;

Expand All @@ -140,35 +138,24 @@ g.test('inputs')
},
});

// Helper to create a `size`-byte buffer with binding number `binding`.
function createBuffer(size: number, binding: number) {
const buffer = t.device.createBuffer({
size,
usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC,
});
t.trackForCleanup(buffer);

bindGroupEntries.push({
binding,
resource: {
buffer,
},
});

return buffer;
}
// Offsets are in u32 size units
const kLocalIdOffset = 0;
const kLocalIndexOffset = 3;
const kGlobalIdOffset = 4;
const kGroupIdOffset = 8;
const kNumGroupsOffset = 12;
const kOutputElementSize = 16;

// Create the output buffers.
const bindGroupEntries: GPUBindGroupEntry[] = [];
const localIdBuffer = createBuffer(totalInvocations * 16, 0);
const localIndexBuffer = createBuffer(totalInvocations * 4, 1);
const globalIdBuffer = createBuffer(totalInvocations * 16, 2);
const groupIdBuffer = createBuffer(totalInvocations * 16, 3);
const numGroupsBuffer = createBuffer(totalInvocations * 16, 4);
const outputBuffer = t.device.createBuffer({
size: totalInvocations * kOutputElementSize * 4,
usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC,
});
t.trackForCleanup(outputBuffer);

const bindGroup = t.device.createBindGroup({
layout: pipeline.getBindGroupLayout(0),
entries: bindGroupEntries,
entries: [{ binding: 0, resource: { buffer: outputBuffer } }],
});

// Run the shader.
Expand Down Expand Up @@ -204,11 +191,7 @@ g.test('inputs')
// Helper to check that the vec3<u32> value at each index of the provided `output` buffer
// matches the expected value for that invocation, as generated by the `getBuiltinValue`
// function. The `name` parameter is the builtin name, used for error messages.
const checkEachIndex = (
output: Uint32Array,
name: string,
getBuiltinValue: (groupId: vec3, localId: vec3) => vec3
) => {
const checkEachIndex = (output: Uint32Array) => {
// Loop over workgroups.
for (let gz = 0; gz < t.params.numGroups.z; gz++) {
for (let gy = 0; gy < t.params.numGroups.y; gy++) {
Expand All @@ -220,30 +203,44 @@ g.test('inputs')
const groupIndex = (gz * t.params.numGroups.y + gy) * t.params.numGroups.x + gx;
const localIndex = (lz * t.params.groupSize.y + ly) * t.params.groupSize.x + lx;
const globalIndex = groupIndex * invocationsPerGroup + localIndex;
const expected = getBuiltinValue(
{ x: gx, y: gy, z: gz },
{ x: lx, y: ly, z: lz }
);
if (output[globalIndex * 4 + 0] !== expected.x) {
return new Error(
`${name}.x failed at group(${gx},${gy},${gz}) local(${lx},${ly},${lz}))\n` +
` expected: ${expected.x}\n` +
` got: ${output[globalIndex * 4 + 0]}`
);
}
if (output[globalIndex * 4 + 1] !== expected.y) {
return new Error(
`${name}.y failed at group(${gx},${gy},${gz}) local(${lx},${ly},${lz}))\n` +
` expected: ${expected.y}\n` +
` got: ${output[globalIndex * 4 + 1]}`
const globalOffset = globalIndex * kOutputElementSize;

const expectEqual = (name: string, expected: number, actual: number) => {
if (actual !== expected) {
return new Error(
`${name} failed at group(${gx},${gy},${gz}) local(${lx},${ly},${lz}))\n` +
` expected: ${expected}\n` +
` got: ${actual}`
);
}
return undefined;
};

const checkVec3Value = (name: string, fieldOffset: number, expected: vec3) => {
const offset = globalOffset + fieldOffset;
return (
expectEqual(`${name}.x`, expected.x, output[offset + 0]) ||
expectEqual(`${name}.y`, expected.y, output[offset + 1]) ||
expectEqual(`${name}.z`, expected.z, output[offset + 2])
);
}
if (output[globalIndex * 4 + 2] !== expected.z) {
return new Error(
`${name}.z failed at group(${gx},${gy},${gz}) local(${lx},${ly},${lz}))\n` +
` expected: ${expected.z}\n` +
` got: ${output[globalIndex * 4 + 2]}`
};

const error =
checkVec3Value('local_id', kLocalIdOffset, { x: lx, y: ly, z: lz }) ||
checkVec3Value('global_id', kGlobalIdOffset, {
x: gx * t.params.groupSize.x + lx,
y: gy * t.params.groupSize.y + ly,
z: gz * t.params.groupSize.z + lz,
}) ||
checkVec3Value('group_id', kGroupIdOffset, { x: gx, y: gy, z: gz }) ||
checkVec3Value('num_groups', kNumGroupsOffset, t.params.numGroups) ||
expectEqual(
'local_index',
localIndex,
output[globalOffset + kLocalIndexOffset]
);
if (error) {
return error;
}
}
}
Expand All @@ -254,44 +251,8 @@ g.test('inputs')
return undefined;
};

// Check @builtin(local_invocation_index) values.
t.expectGPUBufferValuesEqual(
localIndexBuffer,
new Uint32Array([...iterRange(totalInvocations, x => x % invocationsPerGroup)])
);

// Check @builtin(local_invocation_id) values.
t.expectGPUBufferValuesPassCheck(
localIdBuffer,
outputData => checkEachIndex(outputData, 'local_invocation_id', (_, localId) => localId),
{ type: Uint32Array, typedLength: totalInvocations * 4 }
);

// Check @builtin(global_invocation_id) values.
const getGlobalId = (groupId: vec3, localId: vec3) => {
return {
x: groupId.x * t.params.groupSize.x + localId.x,
y: groupId.y * t.params.groupSize.y + localId.y,
z: groupId.z * t.params.groupSize.z + localId.z,
};
};
t.expectGPUBufferValuesPassCheck(
globalIdBuffer,
outputData => checkEachIndex(outputData, 'global_invocation_id', getGlobalId),
{ type: Uint32Array, typedLength: totalInvocations * 4 }
);

// Check @builtin(workgroup_id) values.
t.expectGPUBufferValuesPassCheck(
groupIdBuffer,
outputData => checkEachIndex(outputData, 'workgroup_id', (groupId, _) => groupId),
{ type: Uint32Array, typedLength: totalInvocations * 4 }
);

// Check @builtin(num_workgroups) values.
t.expectGPUBufferValuesPassCheck(
numGroupsBuffer,
outputData => checkEachIndex(outputData, 'num_workgroups', () => t.params.numGroups),
{ type: Uint32Array, typedLength: totalInvocations * 4 }
);
t.expectGPUBufferValuesPassCheck(outputBuffer, outputData => checkEachIndex(outputData), {
type: Uint32Array,
typedLength: outputBuffer.size / 4,
});
});