diff --git a/src/webgpu/shader/execution/shader_io/fragment_builtins.spec.ts b/src/webgpu/shader/execution/shader_io/fragment_builtins.spec.ts index 7a6aa8901e28..6431092d478f 100644 --- a/src/webgpu/shader/execution/shader_io/fragment_builtins.spec.ts +++ b/src/webgpu/shader/execution/shader_io/fragment_builtins.spec.ts @@ -1578,23 +1578,6 @@ fn vsMain(@builtin(vertex_index) index : u32) -> @builtin(position) vec4f { const byteLength = bytesPerRow * blocksPerColumn; const uintLength = byteLength / 4; - const buffer = t.makeBufferWithContents( - new Uint32Array([1]), - GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST - ); - - const bg = t.device.createBindGroup({ - layout: pipeline.getBindGroupLayout(0), - entries: [ - { - binding: 0, - resource: { - buffer, - }, - }, - ], - }); - for (let i = 0; i < 2; i++) { const framebuffer = t.createTextureTracked({ size: [width, height], @@ -1617,8 +1600,8 @@ fn vsMain(@builtin(vertex_index) index : u32) -> @builtin(position) vec4f { ], }); pass.setPipeline(pipeline); - pass.setBindGroup(0, bg); - pass.draw(3, 1, i); + // Draw the uperr-left triangle (vertices 0-2) or the lower-right triangle (vertices 3-5) + pass.draw(3, 1, i * 3); pass.end(); t.queue.submit([encoder.finish()]); @@ -1659,15 +1642,11 @@ enable subgroups; const width = ${t.params.size[0]}; const height = ${t.params.size[1]}; -@group(0) @binding(0) var for_layout : u32; - @fragment fn fsMain( @builtin(position) pos : vec4f, @builtin(subgroup_size) sg_size : u32, ) -> @location(0) vec4u { - _ = for_layout; - let ballot = countOneBits(subgroupBallot(true)); let ballotSize = ballot.x + ballot.y + ballot.z + ballot.w; @@ -1699,17 +1678,23 @@ fn fsMain( ); }); +// A non-zero magic number indicating no expectation error, in order to prevent the false no-error +// result from zero-initialization. +const kSubgroupInvocationIdNoError = 17; + /** * Checks subgroup_invocation_id value consistency * * Very little uniformity is expected for subgroup_invocation_id. * This function checks that all ids are less than the subgroup size - * and no id is repeated. + * (not the ballot size, since the subgroup id can be allocated to + * inactivate invocations between active ones) and no id is repeated. * @param data An array of vec4u that contains (per texel): * * subgroup_invocation_id - * * ballot size - * * non-zero ID unique to each subgroup - * * 0 + * * subgroup size + * * ballot active invocation number + * * error flag, should be equal to kSubgroupInvocationIdNoError or shader found + * expection failed otherwise. * @param format The texture format of data * @param width The width of the framebuffer * @param height The height of the framebuffer @@ -1726,31 +1711,44 @@ function checkSubgroupInvocationIdConsistency( const uintsPerRow = bytesPerRow / 4; const uintsPerTexel = (bytesPerBlock ?? 1) / blockWidth / blockHeight / 4; - const mappings = new Map(); for (let row = 0; row < height; row++) { for (let col = 0; col < width; col++) { const offset = uintsPerRow * row + col * uintsPerTexel; const id = data[offset]; - const size = data[offset + 1]; - const repId = data[offset + 2]; - - if (repId === 0) { + const sgSize = data[offset + 1]; + const ballotSize = data[offset + 2]; + const error = data[offset + 3]; + + if (error === 0) { + // Inactive fragment get error `0` instead of noError. Check all output being zero. + if (id !== 0 || sgSize !== 0 || ballotSize !== 0) { + return new Error( + `Unexpected zero error with non-zero outputs for (${row}, ${col}): got output [${id}, ${sgSize}, ${ballotSize}, ${error}]` + ); + } continue; } - if (size < id) { + if (sgSize < id) { return new Error( - `Invocation id '${id}' is greater than subgroup size '${size}' for (${row}, ${col})` + `Invocation id '${id}' is greater than subgroup size '${sgSize}' for (${row}, ${col})` ); } - let v = mappings.get(repId) ?? 0n; - const mask = 1n << BigInt(id); - if ((mask & v) !== 0n) { - return new Error(`Multiple invocations with id '${id}' in subgroup '${repId}'`); + if (sgSize < ballotSize) { + return new Error( + `Ballot size '${ballotSize}' is greater than subgroup size '${sgSize}' for (${row}, ${col})` + ); + } + + if (error !== kSubgroupInvocationIdNoError) { + return new Error( + `Unexpected error value +- icoord: (${row}, ${col}) +- expected: noError (${kSubgroupInvocationIdNoError}) +- got: ${error}` + ); } - v |= mask; - mappings.set(repId, v); } } @@ -1775,7 +1773,10 @@ enable subgroups; const width = ${t.params.size[0]}; const height = ${t.params.size[1]}; -@group(0) @binding(0) var counter : atomic; +const maxSubgroupSize = 128u; +// A non-zero magic number indicating no expectation error, in order to prevent the +// false no-error result from zero-initialization. +const noError = ${kSubgroupInvocationIdNoError}u; @fragment fn fsMain( @@ -1783,14 +1784,40 @@ fn fsMain( @builtin(subgroup_invocation_id) id : u32, @builtin(subgroup_size) sg_size : u32, ) -> @location(0) vec4u { - let ballot = countOneBits(subgroupBallot(true)); - let ballotSize = ballot.x + ballot.y + ballot.z + ballot.w; - // Generate representative id for this subgroup. - var repId = atomicAdd(&counter, 1); - repId = subgroupBroadcast(repId, 0); + var error: u32 = noError; + + // Validate that reported subgroup size is no larger than maxSubgroupSize + if (sg_size > maxSubgroupSize) { + error++; + } + + // Validate that reported subgroup invocation id is smaller than subgroup size + if (id >= sg_size) { + error++; + } + + // Validate that each subgroup id is assigned to at most one active invocation + // in the subgroup + var countAssignedId: u32 = 0u; + for (var i: u32 = 0; i < maxSubgroupSize; i++) { + let ballotIdEqualsI = countOneBits(subgroupBallot(id == i)); + let countInvocationIdEqualsI = ballotIdEqualsI.x + ballotIdEqualsI.y + ballotIdEqualsI.z + ballotIdEqualsI.w; + // Validate an id assigned at most once + error += select(1u, 0u, countInvocationIdEqualsI <= 1); + // Validate id larger than subgroup size will not get balloted + error += select(1u, 0u, (id < sg_size) || (countInvocationIdEqualsI == 0)); + // Sum up the assigned invocation number of each id + countAssignedId += countInvocationIdEqualsI; + } + // Validate that all active invocation get counted during the above loop + let ballotActive = countOneBits(subgroupBallot(true)); + let activeInvocations = ballotActive.x + ballotActive.y + ballotActive.z + ballotActive.w; + if (activeInvocations != countAssignedId) { + error++; + } - return vec4u(id, ballotSize, repId, 0); + return vec4u(id, sg_size, activeInvocations, error); }`; await runSubgroupTest(