Skip to content

Commit d563832

Browse files
committed
CTS for subgroup_id language feature
* Tests for subgroup_id and num_subgroups builtins
  * builtin validation
  * uniformity validation
  * execution tests
1 parent 9504252 commit d563832

File tree

6 files changed

+404
-3
lines changed

6 files changed

+404
-3
lines changed

src/webgpu/capability_info.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -936,6 +936,7 @@ export const kKnownWGSLLanguageFeatures = [
936936
'packed_4x8_integer_dot_product',
937937
'unrestricted_pointer_parameters',
938938
'pointer_composite_access',
939+
'subgroup_id',
939940
] as const;
940941

941942
export type WGSLLanguageFeature = (typeof kKnownWGSLLanguageFeatures)[number];

src/webgpu/listing_meta.json

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1913,6 +1913,8 @@
19131913
"webgpu:shader,execution,robust_access:linear_memory:*": { "subcaseMS": 5.293 },
19141914
"webgpu:shader,execution,robust_access_vertex:vertex_buffer_access:*": { "subcaseMS": 6.487 },
19151915
"webgpu:shader,execution,shader_io,compute_builtins:inputs:*": { "subcaseMS": 19.342 },
1916+
"webgpu:shader,execution,shader_io,compute_builtins:num_subgroups:*": { "subcaseMS": 139.178 },
1917+
"webgpu:shader,execution,shader_io,compute_builtins:subgroup_id:*": { "subcaseMS": 430.747 },
19161918
"webgpu:shader,execution,shader_io,compute_builtins:subgroup_invocation_id:*": { "subcaseMS": 217.700 },
19171919
"webgpu:shader,execution,shader_io,compute_builtins:subgroup_size:*": { "subcaseMS": 644.206 },
19181920
"webgpu:shader,execution,shader_io,fragment_builtins:inputs,front_facing:*": { "subcaseMS": 1.001 },

src/webgpu/shader/execution/shader_io/compute_builtins.spec.ts

Lines changed: 364 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -753,3 +753,367 @@ fn main(@builtin(subgroup_size) size : u32,
753753
)
754754
);
755755
});
756+
757+
// Sentinel written to the subgroup_id slot by invocations that were not
// elected as the representative of their subgroup.
const skipValue = 0xffff0000;

/**
 * Checks subgroup_id consistency
 *
 * @param outputData An array of vec4u (four u32 values per invocation)
 *                   * 0: comparison of subgroup_id among subgroup (1 on success)
 *                   * 1: comparison of subgroup_id < num_subgroups (1 on success)
 *                   * 2: subgroup_id (for first member) or skipValue
 *                   * 3: unused
 * @param wgSize Invocations in the workgroup
 * @param numWGs Number of workgroups
 * @returns An Error describing the first inconsistency found, or undefined if
 *          the data is consistent.
 */
function checkSubgroupIdConsistency(
  outputData: Uint32Array,
  wgSize: number,
  numWGs: number
): Error | undefined {
  // Max wgSize is 256 and min subgroup size is 4
  const maxSubgroups = Math.ceil(wgSize / 4);

  for (let wg = 0; wg < numWGs; wg++) {
    // seen[id] is 1 once some subgroup in this workgroup has reported `id`.
    const seen = new Array<number>(maxSubgroups).fill(0);
    for (let inv = 0; inv < wgSize; inv++) {
      const gid = wg * wgSize + inv;
      const outputIdx = gid * 4;
      const compare = outputData[outputIdx];
      const in_range = outputData[outputIdx + 1];
      const sid = outputData[outputIdx + 2];

      if (compare !== 1) {
        return new Error(
          `Invocation ${gid}: not all invocations in subgroup have same subgroup_id: ${compare}`
        );
      }
      if (in_range !== 1) {
        return new Error(
          `Invocation ${gid}: subgroup_id out of range of num_subgroups: ${in_range}`
        );
      }

      if (sid !== skipValue) {
        // An id beyond the table would read `undefined` below and be
        // misreported as a reused id, so reject it explicitly first.
        if (sid >= maxSubgroups) {
          return new Error(
            `Invocation ${gid}: subgroup_id ${sid} exceeds the maximum of ${maxSubgroups} subgroups for a workgroup of ${wgSize} invocations`
          );
        }
        if (seen[sid] !== 0) {
          return new Error(`Invocation ${gid}: subgroup_id reused among different subgroups`);
        }
        seen[sid] = 1;
      }
    }

    // Ids must be densely packed: no unseen id may precede a seen one.
    const firstZero = seen.indexOf(0);
    const lastOne = seen.lastIndexOf(1);
    if (firstZero !== -1 && firstZero < lastOne) {
      return new Error(`Subgroup id values are not densely packed: missing ${firstZero}`);
    }
  }

  return undefined;
}
814+
815+
g.test('subgroup_id')
  .desc(
    'Tests subgroup_id values. No mapping between local_invocation_index and subgroup_id can be relied upon.'
  )
  .params(u =>
    u
      .combine('sizes', kWGSizes)
      .beginSubcases()
      .combine('numWGs', [1, 2] as const)
      .combine('lid', [
        [0, 1, 2],
        [0, 2, 1],
        [1, 0, 2],
        [1, 2, 0],
        [2, 0, 1],
        [2, 1, 0],
      ] as const)
  )
  .fn(async t => {
    t.skipIfDeviceDoesNotHaveFeature('subgroups' as GPUFeatureName);
    t.skipIfLanguageFeatureNotSupported('subgroup_id');

    const { sizes, numWGs, lid } = t.params;
    const wgx = sizes[0];
    const wgy = sizes[1];
    const wgz = sizes[2];
    const wgThreads = wgx * wgy * wgz;

    // Compatibility mode has lower workgroup limits.
    const limits = t.device.limits;
    const tooLarge =
      limits.maxComputeInvocationsPerWorkgroup < wgThreads ||
      limits.maxComputeWorkgroupSizeX < wgx ||
      limits.maxComputeWorkgroupSizeY < wgy ||
      limits.maxComputeWorkgroupSizeZ < wgz;
    t.skipIf(tooLarge, 'Workgroup size too large');

    // Each invocation writes one vec4u:
    //   x: 1 if every member of its subgroup observes the same subgroup_id
    //   y: 1 if subgroup_id < num_subgroups
    //   z: subgroup_id for the elected invocation, skipValue otherwise
    //   w: unused
    const wgsl = `
enable subgroups;
requires subgroup_id;

const stride = ${wgThreads};

${genLID(lid[0], lid[1], lid[2], t.params.sizes)}

@group(0) @binding(0)
var<storage, read_write> output : array<vec4u>;

@compute @workgroup_size(${wgx}, ${wgy}, ${wgz})
fn main(@builtin(local_invocation_id) local_id : vec3u,
        @builtin(workgroup_id) wgid : vec3u,
        @builtin(subgroup_id) sid : u32,
        @builtin(num_subgroups) num_subgroups : u32) {
  // Remapped local id.
  let lid = getLID(local_id);

  let gid = lid + stride * wgid.x;

  // Is the subgroup_id equivalent for all members?
  let broadcast_id = subgroupBroadcastFirst(sid);
  let compare = subgroupAll(broadcast_id == sid);

  // Is subgroup_id in the range of num_subgroups?
  let in_range = sid < num_subgroups;

  var out_sid = ${skipValue}u;
  if subgroupElect() {
    out_sid = sid;
  }

  output[gid] = vec4u(
    select(0u, 1u, compare),
    select(0u, 1u, in_range),
    out_sid,
    0);
}
`;

    // Four u32 values per invocation, pre-filled with a recognizable
    // placeholder so unwritten slots fail the consistency check.
    const numInvocations = wgThreads * numWGs;
    const numUints = 4 * numInvocations;
    const placeholderValue = 999;
    const initialContents = new Uint32Array([...iterRange(numUints, _ => placeholderValue)]);
    const outputBuffer = t.makeBufferWithContents(
      initialContents,
      GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST
    );
    t.trackForCleanup(outputBuffer);

    const module = t.device.createShaderModule({ code: wgsl });
    const pipeline = t.device.createComputePipeline({
      layout: 'auto',
      compute: { module, entryPoint: 'main' },
    });
    const bg = t.device.createBindGroup({
      layout: pipeline.getBindGroupLayout(0),
      entries: [{ binding: 0, resource: { buffer: outputBuffer } }],
    });

    const encoder = t.device.createCommandEncoder();
    const pass = encoder.beginComputePass();
    pass.setPipeline(pipeline);
    pass.setBindGroup(0, bg);
    pass.dispatchWorkgroups(numWGs, 1, 1);
    pass.end();
    t.queue.submit([encoder.finish()]);

    const outputReadback = await t.readGPUBufferRangeTyped(outputBuffer, {
      srcByteOffset: 0,
      type: Uint32Array,
      typedLength: numUints,
      method: 'copy',
    });
    const outputData: Uint32Array = outputReadback.data;

    t.expectOK(checkSubgroupIdConsistency(outputData, wgThreads, numWGs));
  });
946+
947+
/**
 * Verifies that every invocation's reported num_subgroups equals the subgroup
 * count recorded for its workgroup.
 *
 * @param countData One entry per workgroup: the counted number of subgroups
 * @param outputData numWGs * wgSize entries: num_subgroups as seen by each invocation
 * @param wgSize Number of invocations in the workgroup
 * @param numWGs Number of workgroups
 * @returns An Error for the first mismatching invocation, or undefined if all match.
 */
function checkNumSubgroupsConsistency(
  countData: Uint32Array,
  outputData: Uint32Array,
  wgSize: number,
  numWGs: number
): Error | undefined {
  for (let wg = 0; wg < numWGs; wg++) {
    const expected = countData[wg];
    const begin = wg * wgSize;
    // Clamp to the available data, as slicing past the end would.
    const end = Math.min(begin + wgSize, outputData.length);
    for (let i = begin; i < end; i++) {
      const actual = outputData[i];
      if (actual === expected) {
        continue;
      }
      return new Error(`Workgroup ${wg}: inconsistent num_subgroups:
- expected: ${expected}
- got: ${actual}`);
    }
  }

  return undefined;
}
974+
975+
g.test('num_subgroups')
  .desc('Tests num_subgroups values.')
  .params(u =>
    u
      .combine('sizes', kWGSizes)
      .beginSubcases()
      .combine('numWGs', [1, 2] as const)
      .combine('lid', [
        [0, 1, 2],
        [0, 2, 1],
        [1, 0, 2],
        [1, 2, 0],
        [2, 0, 1],
        [2, 1, 0],
      ] as const)
  )
  .fn(async t => {
    t.skipIfDeviceDoesNotHaveFeature('subgroups' as GPUFeatureName);
    t.skipIfLanguageFeatureNotSupported('subgroup_id');
    const wgx = t.params.sizes[0];
    const wgy = t.params.sizes[1];
    const wgz = t.params.sizes[2];
    const lid = t.params.lid;
    const wgThreads = wgx * wgy * wgz;

    // Compatibility mode has lower workgroup limits.
    const {
      maxComputeInvocationsPerWorkgroup,
      maxComputeWorkgroupSizeX,
      maxComputeWorkgroupSizeY,
      maxComputeWorkgroupSizeZ,
    } = t.device.limits;
    t.skipIf(
      maxComputeInvocationsPerWorkgroup < wgThreads ||
        maxComputeWorkgroupSizeX < wgx ||
        maxComputeWorkgroupSizeY < wgy ||
        maxComputeWorkgroupSizeZ < wgz,
      'Workgroup size too large'
    );

    // The shader counts subgroups per workgroup by having the elected
    // invocation of each subgroup bump a workgroup atomic, then records both
    // that count (once per workgroup) and num_subgroups (once per invocation).
    const wgsl = `
enable subgroups;
requires subgroup_id;

const stride = ${wgThreads};

${genLID(lid[0], lid[1], lid[2], t.params.sizes)}

@group(0) @binding(0)
var<storage, read_write> numSubgroups : array<u32>;

@group(0) @binding(1)
var<storage, read_write> output : array<u32>;

var<workgroup> count : atomic<u32>;

@compute @workgroup_size(${wgx}, ${wgy}, ${wgz})
fn main(@builtin(local_invocation_id) local_id : vec3u,
        @builtin(workgroup_id) wgid : vec3u,
        @builtin(subgroup_id) sid : u32,
        @builtin(num_subgroups) num_subgroups : u32) {
  // Remapped local id.
  let lid = getLID(local_id);

  let gid = lid + stride * wgid.x;

  if subgroupElect() {
    atomicAdd(&count, 1);
  }

  workgroupBarrier();

  if lid == 0 {
    numSubgroups[wgid.x] = atomicLoad(&count);
  }

  output[gid] = num_subgroups;
}
`;

    const numInvocations = wgThreads * t.params.numWGs;
    const placeholderValue = 999;
    // One u32 per workgroup: the counted number of subgroups.
    const countBuffer = t.makeBufferWithContents(
      new Uint32Array([...iterRange(t.params.numWGs, _ => placeholderValue)]),
      GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST
    );
    t.trackForCleanup(countBuffer);
    // One u32 per invocation (the shader's `output` is array<u32>), matching
    // the number of elements read back below.
    const outputBuffer = t.makeBufferWithContents(
      new Uint32Array([...iterRange(numInvocations, _ => placeholderValue)]),
      GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST
    );
    t.trackForCleanup(outputBuffer);

    const pipeline = t.device.createComputePipeline({
      layout: 'auto',
      compute: {
        module: t.device.createShaderModule({
          code: wgsl,
        }),
        entryPoint: 'main',
      },
    });
    const bg = t.device.createBindGroup({
      layout: pipeline.getBindGroupLayout(0),
      entries: [
        {
          binding: 0,
          resource: {
            buffer: countBuffer,
          },
        },
        {
          binding: 1,
          resource: {
            buffer: outputBuffer,
          },
        },
      ],
    });

    const encoder = t.device.createCommandEncoder();
    const pass = encoder.beginComputePass();
    pass.setPipeline(pipeline);
    pass.setBindGroup(0, bg);
    pass.dispatchWorkgroups(t.params.numWGs, 1, 1);
    pass.end();
    t.queue.submit([encoder.finish()]);

    const countReadback = await t.readGPUBufferRangeTyped(countBuffer, {
      srcByteOffset: 0,
      type: Uint32Array,
      typedLength: t.params.numWGs,
      method: 'copy',
    });
    const countData: Uint32Array = countReadback.data;
    const outputReadback = await t.readGPUBufferRangeTyped(outputBuffer, {
      srcByteOffset: 0,
      type: Uint32Array,
      typedLength: numInvocations,
      method: 'copy',
    });
    const outputData: Uint32Array = outputReadback.data;

    t.expectOK(checkNumSubgroupsConsistency(countData, outputData, wgThreads, t.params.numWGs));
  });

0 commit comments

Comments
 (0)