Skip to content

Commit

Permalink
Refactor derivative tests for 0 storage buffers.
Browse files Browse the repository at this point in the history
  • Loading branch information
greggman committed Dec 13, 2024
1 parent 726f4dd commit 3c9d1cc
Showing 1 changed file with 87 additions and 61 deletions.
148 changes: 87 additions & 61 deletions src/webgpu/shader/execution/expression/call/builtin/derivatives.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { GPUTest } from '../../../../../gpu_test.js';
import { Type, Value } from '../../../../../util/conversion.js';
import { align } from '../../../../../util/math.js';
import { Case } from '../../case.js';
import { toComparator } from '../../expectation.js';
import { packScalarsToVector } from '../../expression.js';
Expand Down Expand Up @@ -29,16 +30,17 @@ export function runDerivativeTest(

////////////////////////////////////////////////////////////////
// The two input values for a given case are distributed to two different invocations in a quad.
// We will populate a storage buffer with these input values laid out sequentially:
// We will populate a uniform buffer with these input values laid out sequentially:
// [ case_0_input_1, case_0_input_0, case_1_input_1, case_1_input_0, ...]
//
// The render pipeline will be launched several times over a viewport size of (2, 2). Each draw
// call will execute a single quad (four fragment invocation), which will exercise two test cases.
// Each of these draw calls will use a different instance index, which is forwarded to the
// fragment shader. Each invocation will determine its index into the storage buffer using its
// fragment position and the instance index for that draw call.
// The render pipeline will be launched once per pixel per pair of cases over
// a viewport size of (2, 2) with the viewport set to cover 1 pixel.
// Each 2x2 set of calls will will exercise two test cases. Each of these
// draw calls will use a different instance index, which is forwarded to the
// fragment shader. Each invocation returns the result which is stored in
// a rgba32uint texture.
//
// Consider two draw calls that test 4 cases (c_0, c_1, c_2, c_3).
// Consider draw calls that test 4 cases (c_0, c_1, c_2, c_3).
//
// For derivatives along the 'x' direction, the mapping from fragment position to case input is:
// Quad 0: | c_0_i_1 | c_0_i_0 | Quad 1: | c_2_i_1 | c_2_i_0 |
Expand All @@ -54,13 +56,23 @@ export function runDerivativeTest(
const dir = builtin[3];

// Determine the WGSL type to use in the shader, and the stride in bytes between values.
let valueStride = 4;
let wgslType = 'f32';
const valueStride = 16;
let conversionFromInput = 'input.x';
let conversionToOutput = `vec4f(v)`;
if (vectorize) {
wgslType = `vec${vectorize}f`;
valueStride = vectorize * 4;
if (vectorize === 3) {
valueStride = 16;
switch (vectorize) {
case 2:
conversionFromInput = 'input.xy';
conversionToOutput = 'vec4f(v, 0, 0)';
break;
case 3:
conversionFromInput = 'input.xyz';
conversionToOutput = 'vec4f(v, 0)';
break;
case 4:
conversionFromInput = 'input';
conversionToOutput = 'v';
break;
}
}

Expand All @@ -84,17 +96,17 @@ fn vert(@builtin(vertex_index) vertex_idx: u32,
return CaseInfo(vec4(kVertices[vertex_idx], 0, 1), instance_idx);
}
@group(0) @binding(0) var<storage, read> inputs : array<${wgslType}>;
@group(0) @binding(1) var<storage, read_write> outputs : array<${wgslType}>;
@group(0) @binding(0) var<uniform> inputs : array<vec4f, ${cases.length * 2}>;
@fragment
fn frag(info : CaseInfo) {
fn frag(info : CaseInfo) -> @location(0) vec4u {
let case_idx = u32(info.position.${dir === 'x' ? 'y' : 'x'});
let inv_idx = u32(info.position.${dir});
let index = info.quad_idx*4 + case_idx*2 + inv_idx;
let input = inputs[index];
${non_uniform_discard ? 'if inv_idx == 0 { discard; }' : ''}
outputs[index] = ${builtin}(input);
let v = ${builtin}(${conversionFromInput});
return bitcast<vec4u>(${conversionToOutput});
}
`;

Expand All @@ -103,22 +115,18 @@ fn frag(info : CaseInfo) {
const pipeline = t.device.createRenderPipeline({
layout: 'auto',
vertex: { module },
fragment: { module, targets: [{ format: 'rgba8unorm', writeMask: 0 }] },
fragment: { module, targets: [{ format: 'rgba32uint' }] },
});

// Create storage buffers to hold the inputs and outputs.
const bufferSize = cases.length * 2 * valueStride;
const inputBuffer = t.createBufferTracked({
size: bufferSize,
usage: GPUBufferUsage.STORAGE,
usage: GPUBufferUsage.UNIFORM,
mappedAtCreation: true,
});
const outputBuffer = t.createBufferTracked({
size: bufferSize,
usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC,
});

// Populate the input storage buffer with case input values.
// Populate the input uniform buffer with case input values.
const valuesData = new Uint8Array(inputBuffer.getMappedRange());
for (let i = 0; i < cases.length; i++) {
const inputs = cases[i].input as ReadonlyArray<Value>;
Expand All @@ -129,61 +137,79 @@ fn frag(info : CaseInfo) {

// Create a bind group for the storage buffers.
const group = t.device.createBindGroup({
entries: [
{ binding: 0, resource: { buffer: inputBuffer } },
{ binding: 1, resource: { buffer: outputBuffer } },
],
entries: [{ binding: 0, resource: { buffer: inputBuffer } }],
layout: pipeline.getBindGroupLayout(0),
});

// Create a texture to use as a color attachment.
// We only need this for launching the desired number of fragment invocations.
const colorAttachment = t.createTextureTracked({
size: { width: 2, height: 2 },
format: 'rgba8unorm',
usage: GPUTextureUsage.RENDER_ATTACHMENT,
format: 'rgba32uint',
usage: GPUTextureUsage.RENDER_ATTACHMENT | GPUTextureUsage.COPY_SRC,
});
const bytesPerRow = align(valueStride * colorAttachment.width, 256);

// Submit the render pass to the device.
const results = [];
const encoder = t.device.createCommandEncoder();
const pass = encoder.beginRenderPass({
colorAttachments: [
{
view: colorAttachment.createView(),
loadOp: 'clear',
storeOp: 'discard',
},
],
});
pass.setPipeline(pipeline);
pass.setBindGroup(0, group);
for (let quad = 0; quad < cases.length / 2; quad++) {
pass.draw(3, 1, undefined, quad);
const pass = encoder.beginRenderPass({
colorAttachments: [
{
view: colorAttachment.createView(),
loadOp: 'clear',
storeOp: 'store',
},
],
});
pass.setPipeline(pipeline);
pass.setBindGroup(0, group);
for (let y = 0; y < colorAttachment.height; ++y) {
for (let x = 0; x < colorAttachment.width; ++x) {
pass.setViewport(x, y, 1, 1, 0, 1);
pass.draw(3, 1, 0, quad);
}
}
pass.end();
const outputBuffer = t.createBufferTracked({
size: bytesPerRow * colorAttachment.height,
usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.COPY_SRC,
});
results.push(outputBuffer);
encoder.copyTextureToBuffer(
{ texture: colorAttachment },
{ buffer: outputBuffer, bytesPerRow },
[colorAttachment.width, colorAttachment.height]
);
}
pass.end();

t.queue.submit([encoder.finish()]);

// Check the outputs match the expected results.
t.expectGPUBufferValuesPassCheck(
outputBuffer,
(outputData: Uint8Array) => {
for (let i = 0; i < cases.length; i++) {
const c = cases[i];

// Both invocations involved in the derivative should get the same result.
for (let d = 0; d < 2; d++) {
if (non_uniform_discard && d === 0) {
results.forEach((outputBuffer, quadNdx) => {
t.expectGPUBufferValuesPassCheck(
outputBuffer,
(outputData: Uint8Array) => {
for (let i = 0; i < 4; ++i) {
const tx = i % 2;
const ty = (i / 2) | 0;
const [inputNdx, caseNdx] = dir === 'x' ? [tx, ty] : [ty, tx];
const c = cases[quadNdx * 2 + caseNdx];

// Both invocations involved in the derivative should get the same result.
if (non_uniform_discard && inputNdx === 0) {
continue;
}

const index = (i * 2 + d) * valueStride;
const index = ty * bytesPerRow + tx * valueStride;
const result = type.read(outputData, index);
const cmp = toComparator(c.expected).compare(result);
if (!cmp.matched) {
// If this is a coarse derivative, the implementation is also allowed to calculate only
// one of the two derivatives and return that result to all of the invocations.
if (!builtin.endsWith('Fine')) {
const c0 = cases[i % 2 === 0 ? i + 1 : i - 1];
const c0 = cases[inputNdx];
const cmp0 = toComparator(c0.expected).compare(result);
if (!cmp0.matched) {
return new Error(`
Expand All @@ -204,12 +230,12 @@ fn frag(info : CaseInfo) {
}
}
}
return undefined;
},
{
type: Uint8Array,
typedLength: outputBuffer.size,
}
return undefined;
},
{
type: Uint8Array,
typedLength: bufferSize,
}
);
);
});
}

0 comments on commit 3c9d1cc

Please sign in to comment.