Skip to content

Commit

Permalink
subgroupBallot tests in fragment shaders (#4068)
Browse files Browse the repository at this point in the history
* Tests subgroupBallot in fragment shaders
  • Loading branch information
alan-baker authored Nov 29, 2024
1 parent ecb8816 commit 5fd787c
Show file tree
Hide file tree
Showing 2 changed files with 318 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,14 @@ local_invocation_index. Tests should avoid assuming there is.

import { makeTestGroup } from '../../../../../../common/framework/test_group.js';
import { keysOf } from '../../../../../../common/util/data_tables.js';
import { iterRange } from '../../../../../../common/util/util.js';
import { iterRange, assert } from '../../../../../../common/util/util.js';
import { kTextureFormatInfo } from '../../../../../format_info.js';
import { GPUTest } from '../../../../../gpu_test.js';
import { align } from '../../../../../util/math.js';

export const g = makeTestGroup(GPUTest);
import { SubgroupTest, kFramebufferSizes, getUintsPerFramebuffer } from './subgroup_util.js';

export const g = makeTestGroup(SubgroupTest);

// 128 is the maximum possible subgroup size. This matches the 128 bits
// (vec4u, 4 x 32) that a subgroupBallot result carries in the shaders in this file.
const kInvocations = 128;
Expand Down Expand Up @@ -344,4 +348,315 @@ fn main(@builtin(subgroup_size) subgroupSize : u32,
await runTest(t, wgsl, testcase.filter, testcase.expect, false);
});

g.test('fragment').unimplemented();
// Table of ballot predicates for the fragment tests.
// * `cond` is the WGSL boolean expression evaluated per invocation in the shader.
// * `filter` is the host-side mirror of `cond`. Every filter rejects the last row and
//   the last column of the framebuffer before applying its predicate.
const kFragmentPredicates = {
  odd_row: {
    cond: `u32(pos.y) % 2 == 1`,
    filter: (row: number, col: number, width: number, height: number, id: number, size: number) =>
      row !== height - 1 && col !== width - 1 && row % 2 === 1,
  },
  even_row: {
    cond: `u32(pos.y) % 2 == 0`,
    filter: (row: number, col: number, width: number, height: number, id: number, size: number) =>
      row !== height - 1 && col !== width - 1 && row % 2 === 0,
  },
  odd_col: {
    cond: `u32(pos.x) % 2 == 1`,
    filter: (row: number, col: number, width: number, height: number, id: number, size: number) =>
      row !== height - 1 && col !== width - 1 && col % 2 === 1,
  },
  even_col: {
    cond: `u32(pos.x) % 2 == 0`,
    filter: (row: number, col: number, width: number, height: number, id: number, size: number) =>
      row !== height - 1 && col !== width - 1 && col % 2 === 0,
  },
  odd_id: {
    cond: `id % 2 == 1`,
    filter: (row: number, col: number, width: number, height: number, id: number, size: number) =>
      row !== height - 1 && col !== width - 1 && id % 2 === 1,
  },
  even_id: {
    cond: `id % 2 == 0`,
    filter: (row: number, col: number, width: number, height: number, id: number, size: number) =>
      row !== height - 1 && col !== width - 1 && id % 2 === 0,
  },
  upper_half: {
    cond: `id > subgroupSize / 2`,
    // Math.floor mirrors the unsigned integer division performed by the WGSL expression.
    filter: (row: number, col: number, width: number, height: number, id: number, size: number) =>
      row !== height - 1 && col !== width - 1 && id > Math.floor(size / 2),
  },
  lower_half: {
    cond: `id < subgroupSize / 2`,
    // Math.floor mirrors the unsigned integer division performed by the WGSL expression.
    filter: (row: number, col: number, width: number, height: number, id: number, size: number) =>
      row !== height - 1 && col !== width - 1 && id < Math.floor(size / 2),
  },
  first_two_or_diagonal: {
    cond: `id == 0 || id == 1 || u32(pos.x) == u32(pos.y)`,
    filter: (row: number, col: number, width: number, height: number, id: number, size: number) =>
      row !== height - 1 && col !== width - 1 && (id === 0 || id === 1 || row === col),
  },
};

/**
 * Validates subgroupBallot results produced by a fragment shader.
 *
 * Helper invocations may contribute extra bits to a ballot, so the ballot is not
 * compared against a fully computed expectation. Instead, for every texel outside the
 * last row and last column, two properties are verified:
 * * the bit belonging to the texel's own subgroup_invocation_id matches `filter`, and
 * * all texels reporting the same generated subgroup id observed the same ballot.
 *
 * @param ballots Framebuffer of ballot results (four 32-bit words per texel are read)
 * @param metadata Framebuffer of metadata
 * * component 0 is subgroup_invocation_id
 * * component 1 is subgroup_size
 * * component 2 is a unique, generated subgroup id
 * @param format The framebuffer format
 * @param width The framebuffer width
 * @param height The framebuffer height
 * @param filter A functor that returns true if the invocation should be included in the ballot
 * @returns An Error describing the first failure found, or undefined if all checks pass
 */
function checkFragmentBallots(
  ballots: Uint32Array,
  metadata: Uint32Array,
  format: GPUTextureFormat,
  width: number,
  height: number,
  filter: (
    row: number,
    col: number,
    width: number,
    height: number,
    id: number,
    size: number
  ) => boolean
): Error | undefined {
  if (width < 3 || height < 3) {
    return new Error(
      `Insufficient framebuffer size [${width}w x ${height}h]. Minimum is [3w x 3h].`
    );
  }

  const layout = getUintsPerFramebuffer(format, width, height);
  const rowStride = layout.uintsPerRow;
  const texelStride = layout.uintsPerTexel;

  // First ballot observed for each generated subgroup id; later texels of the same
  // subgroup must agree with it.
  const ballotBySubgroup = new Map<number, bigint>();

  // The last row and column are not visited: they may involve helper invocations, and
  // it is not guaranteed whether or not those participate in the subgroup operation.
  for (let y = 0; y < height - 1; y++) {
    for (let x = 0; x < width - 1; x++) {
      const base = rowStride * y + x * texelStride;

      const invocationId = metadata[base];
      const size = metadata[base + 1];
      const subgroupId = metadata[base + 2];

      // Assemble the four 32-bit words into one 128-bit value, lowest word first.
      let observed = 0n;
      for (let word = 0; word < 4; word++) {
        observed |= BigInt(ballots[base + word]) << BigInt(32 * word);
      }

      const wantBit = filter(y, x, width, height, invocationId, size) ? 1n : 0n;
      const haveBit = (observed >> BigInt(invocationId)) & 1n;

      if (wantBit !== haveBit) {
        return new Error(`Row ${y}, col ${x}: incorrect ballot bit ${invocationId}:
- expected: ${wantBit.toString(10)}
- got: ${haveBit.toString(10)}`);
      }

      const reference = ballotBySubgroup.get(subgroupId);
      if (reference === undefined) {
        ballotBySubgroup.set(subgroupId, observed);
        continue;
      }
      if (reference !== observed) {
        return new Error(`Row ${y} col ${x}: ballot mismatch:
- expected: ${reference.toString(16)}
- got: ${observed.toString(16)}`);
      }
    }
  }

  return undefined;
}

// Renders one triangle into two rgba32uint targets (ballot results and per-invocation
// metadata), reads both targets back, and validates them with checkFragmentBallots.
g.test('fragment')
  .desc('Tests subgroupBallot in fragment shaders')
  .params(u =>
    u
      .combine('predicate', keysOf(kFragmentPredicates))
      .beginSubcases()
      .combine('size', kFramebufferSizes)
      .combineWithParams([{ format: 'rgba32uint' }] as const)
  )
  .beforeAllSubcases(t => {
    t.selectDeviceOrSkipTestCase('subgroups' as GPUFeatureName);
  })
  .fn(async t => {
    const { predicate, size, format } = t.params;
    const width = size[0];
    const height = size[1];
    const { cond, filter } = kFragmentPredicates[predicate];

    // Target 0 receives the ballot; target 1 receives
    // (subgroup_invocation_id, subgroup_size, generated subgroup id, 0).
    const fsShader = `
enable subgroups;
struct FSOutput {
@location(0) ballot : vec4u,
@location(1) metadata : vec4u,
}
@fragment
fn main(
@builtin(position) pos : vec4f,
@builtin(subgroup_size) subgroupSize : u32,
@builtin(subgroup_invocation_id) id : u32,
) -> FSOutput {
let linear = u32(pos.x) + u32(pos.y) * ${width};
let subgroup_id = subgroupBroadcastFirst(linear + 1);
// Filter out possible helper invocations.
let x_in_range = u32(pos.x) < (${width} - 1);
let y_in_range = u32(pos.y) < (${height} - 1);
let in_range = x_in_range && y_in_range;
let cond = ${cond};
let ballot = subgroupBallot(in_range && cond);
var out : FSOutput;
out.ballot = ballot;
out.metadata = vec4u(id, subgroupSize, subgroup_id, 0);
return out;
}`;

    // A single oversized triangle, drawn with three vertices.
    const vsShader = `
@vertex
fn vsMain(@builtin(vertex_index) index : u32) -> @builtin(position) vec4f {
const vertices = array(
vec2(-2, 4), vec2(-2, -4), vec2(2, 0),
);
return vec4f(vec2f(vertices[index]), 0, 1);
}`;

    const vertexModule = t.device.createShaderModule({ code: vsShader });
    const fragmentModule = t.device.createShaderModule({ code: fsShader });
    const pipeline = t.device.createRenderPipeline({
      layout: 'auto',
      vertex: { module: vertexModule },
      fragment: {
        module: fragmentModule,
        targets: [{ format }, { format }],
      },
      primitive: { topology: 'triangle-list' },
    });

    // Number of 32-bit words in a full readback of one framebuffer.
    const formatInfo = kTextureFormatInfo[format];
    const bytesPerBlock = formatInfo.bytesPerBlock;
    assert(bytesPerBlock !== undefined);

    const blocksPerRow = width / formatInfo.blockWidth;
    const blocksPerColumn = height / formatInfo.blockHeight;
    // Rows are padded to 256 bytes because of image copy requirements.
    const bytesPerRow = align(blocksPerRow * (bytesPerBlock ?? 1), 256);
    const byteLength = bytesPerRow * blocksPerColumn;
    const uintLength = byteLength / 4;

    // Both render targets share the same size, usage and format.
    const framebufferUsage =
      GPUTextureUsage.COPY_SRC |
      GPUTextureUsage.COPY_DST |
      GPUTextureUsage.RENDER_ATTACHMENT |
      GPUTextureUsage.TEXTURE_BINDING;
    const createFramebuffer = () =>
      t.createTextureTracked({
        size: [width, height],
        usage: framebufferUsage,
        format,
      });

    const ballotFB = createFramebuffer();
    const metadataFB = createFramebuffer();

    const encoder = t.device.createCommandEncoder();
    const pass = encoder.beginRenderPass({
      colorAttachments: [ballotFB, metadataFB].map(framebuffer => ({
        view: framebuffer.createView(),
        loadOp: 'clear' as const,
        storeOp: 'store' as const,
      })),
    });
    pass.setPipeline(pipeline);
    pass.draw(3);
    pass.end();
    t.queue.submit([encoder.finish()]);

    // Copies a framebuffer into a new buffer and reads it back as 32-bit words.
    const readFramebuffer = async (framebuffer: typeof ballotFB): Promise<Uint32Array> => {
      const staging = t.copyWholeTextureToNewBufferSimple(framebuffer, 0);
      const readback = await t.readGPUBufferRangeTyped(staging, {
        srcByteOffset: 0,
        type: Uint32Array,
        typedLength: uintLength,
        method: 'copy',
      });
      return readback.data;
    };

    const ballots = await readFramebuffer(ballotFB);
    const metadata = await readFramebuffer(metadataFB);

    t.expectOK(checkFragmentBallots(ballots, metadata, format, width, height, filter));
  });
Original file line number Diff line number Diff line change
Expand Up @@ -890,8 +890,6 @@ g.test('fragment')
t.selectDeviceOrSkipTestCase('subgroups' as GPUFeatureName);
})
.fn(async t => {
//t.skipIf(t.params.id !== 2);
//t.skipIf(t.params.op !== 'subgroupShuffleUp');
const fsShader = `
enable subgroups;
Expand Down

0 comments on commit 5fd787c

Please sign in to comment.