Compat: stop using storage buffer when reading texture mix weights (#…
greggman authored Nov 28, 2024
1 parent 51b744d commit 68633ee
Showing 1 changed file with 45 additions and 32 deletions.
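
In outline, the change has the vertex and fragment paths return their weights as bitcast<vec4u> values into an rgba32uint render target (the compute path keeps writing to the storage buffer), and then reads the target back with copyTextureToBuffer. A minimal sketch of that render-target readback pattern, assuming hypothetical names `device` and `numResults`, one result per texel:

const target = device.createTexture({
  size: [numResults, 1],
  format: 'rgba32uint',
  usage: GPUTextureUsage.RENDER_ATTACHMENT | GPUTextureUsage.COPY_SRC,
});
// One vec4<u32> (16 bytes) per texel; padded to 256 bytes, mirroring the
// align(storageBuffer.size, 256) used in the change for copyTextureToBuffer.
const readback = device.createBuffer({
  size: Math.ceil((numResults * 16) / 256) * 256,
  usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ,
});
const encoder = device.createCommandEncoder();
// ...begin a render pass on `target` and draw one bitcast<vec4u> result per texel...
encoder.copyTextureToBuffer({ texture: target }, { buffer: readback }, [numResults, 1]);
device.queue.submit([encoder.finish()]);
await readback.mapAsync(GPUMapMode.READ);
// The u32 texels hold f32 bit patterns, so view the copied bytes as floats.
const weights = Array.from(new Float32Array(readback.getMappedRange(0, numResults * 16)));
readback.unmap();
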
@@ -199,11 +199,15 @@ const builtinNeedsMipLevelWeights = (builtin: TextureBuiltin) =>
/**
* Splits an array into multiple arrays where every Nth value goes to a different array.
* If srcStride is provided, each group of srcStride source values contributes its first
* num values, one per output array.
*/
function unzip<T>(array: T[], num: number) {
function unzip<T>(array: T[], num: number, srcStride?: number) {
srcStride = srcStride === undefined ? num : srcStride;
const arrays: T[][] = range(num, () => []);
array.forEach((v, i) => {
arrays[i % num].push(v);
});
const numEntries = Math.ceil(array.length / srcStride);
for (let i = 0; i < numEntries; ++i) {
for (let j = 0; j < num; ++j) {
arrays[j].push(array[i * srcStride + j]);
}
}
return arrays;
}
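
For illustration only (hypothetical values), the widened unzip keeps the old every-Nth splitting when srcStride is omitted, and otherwise takes the first num values out of each group of srcStride:

// num = 2, srcStride = 4: each group of 4 contributes its first two values,
// one to each output array; the trailing padding (the 9s) is dropped.
const [a, b] = unzip([0.0, 0.5, 9, 9, 1.0, 0.25, 9, 9], 2, 4);
// a === [0.0, 1.0], b === [0.5, 0.25]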

@@ -395,11 +399,12 @@ ${graphWeights(32, weights)}
export async function queryMipLevelMixWeightsForDevice(t: GPUTest, stage: ShaderStage) {
const { device } = t;
const kNumWeightTypes = 2;
assert(kNumWeightTypes <= 4);
const module = device.createShaderModule({
code: `
@group(0) @binding(0) var tex: texture_2d<f32>;
@group(0) @binding(1) var smp: sampler;
@group(0) @binding(2) var<storage, read_write> result: array<f32>;
@group(0) @binding(2) var<storage, read_write> result: array<vec4f>;
struct VSOutput {
@builtin(position) pos: vec4f,
@@ -419,13 +424,6 @@ export async function queryMipLevelMixWeightsForDevice(t: GPUTest, stage: Shader
0);
}
fn recordMixLevels(wNdx: u32, r: vec4f) {
let ndx = wNdx * ${kNumWeightTypes};
for (var i: u32 = 0; i < ${kNumWeightTypes}; i++) {
result[ndx + i] = r[i];
}
}
fn getPosition(vNdx: u32) -> vec4f {
let pos = array(
vec2f(-1, 3),
@@ -436,26 +434,30 @@ export async function queryMipLevelMixWeightsForDevice(t: GPUTest, stage: Shader
return vec4f(p, 0, 1);
}
// -- for getting fragment stage weights --
@vertex fn vs(@builtin(vertex_index) vNdx: u32, @builtin(instance_index) iNdx: u32) -> VSOutput {
return VSOutput(getPosition(vNdx), iNdx, vec4f(0));
}
@fragment fn fsRecord(v: VSOutput) -> @location(0) vec4f {
recordMixLevels(v.ndx, getMixLevels(v.ndx));
return vec4f(0);
@fragment fn fsRecord(v: VSOutput) -> @location(0) vec4u {
return bitcast<vec4u>(getMixLevels(v.ndx));
}
// -- for getting compute stage weights --
@compute @workgroup_size(1) fn csRecord(@builtin(global_invocation_id) id: vec3u) {
recordMixLevels(id.x, getMixLevels(id.x));
result[id.x] = getMixLevels(id.x);
}
// -- for getting vertex stage weights --
@vertex fn vsRecord(@builtin(vertex_index) vNdx: u32, @builtin(instance_index) iNdx: u32) -> VSOutput {
return VSOutput(getPosition(vNdx), iNdx, getMixLevels(iNdx));
}
@fragment fn fsSaveVs(v: VSOutput) -> @location(0) vec4f {
recordMixLevels(v.ndx, v.result);
return vec4f(0);
@fragment fn fsSaveVs(v: VSOutput) -> @location(0) vec4u {
return bitcast<vec4u>(v.result);
}
`,
});
@@ -481,18 +483,18 @@ export async function queryMipLevelMixWeightsForDevice(t: GPUTest, stage: Shader
});

const target = t.createTextureTracked({
size: [1, 1],
format: 'rgba8unorm',
usage: GPUTextureUsage.RENDER_ATTACHMENT,
size: [kMipLevelWeightSteps + 1, 1],
format: 'rgba32uint',
usage: GPUTextureUsage.RENDER_ATTACHMENT | GPUTextureUsage.COPY_SRC,
});

const storageBuffer = t.createBufferTracked({
size: 4 * (kMipLevelWeightSteps + 1) * kNumWeightTypes,
size: 4 * 4 * (kMipLevelWeightSteps + 1),
usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC,
});

const resultBuffer = t.createBufferTracked({
size: storageBuffer.size,
size: align(storageBuffer.size, 256), // padded for copyTextureToBuffer
usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ,
});

@@ -502,7 +504,7 @@ export async function queryMipLevelMixWeightsForDevice(t: GPUTest, stage: Shader
entries: [
{ binding: 0, resource: texture.createView() },
{ binding: 1, resource: sampler },
{ binding: 2, resource: { buffer: storageBuffer } },
...(stage === 'compute' ? [{ binding: 2, resource: { buffer: storageBuffer } }] : []),
],
});

@@ -518,13 +520,14 @@ export async function queryMipLevelMixWeightsForDevice(t: GPUTest, stage: Shader
pass.setBindGroup(0, createBindGroup(pipeline));
pass.dispatchWorkgroups(kMipLevelWeightSteps + 1);
pass.end();
encoder.copyBufferToBuffer(storageBuffer, 0, resultBuffer, 0, resultBuffer.size);
break;
}
case 'fragment': {
const pipeline = device.createRenderPipeline({
layout: 'auto',
vertex: { module, entryPoint: 'vs' },
fragment: { module, entryPoint: 'fsRecord', targets: [{ format: 'rgba8unorm' }] },
fragment: { module, entryPoint: 'fsRecord', targets: [{ format: 'rgba32uint' }] },
});
const pass = encoder.beginRenderPass({
colorAttachments: [
@@ -537,15 +540,19 @@ export async function queryMipLevelMixWeightsForDevice(t: GPUTest, stage: Shader
});
pass.setPipeline(pipeline);
pass.setBindGroup(0, createBindGroup(pipeline));
pass.draw(3, kMipLevelWeightSteps + 1);
for (let x = 0; x <= kMipLevelWeightSteps; ++x) {
pass.setViewport(x, 0, 1, 1, 0, 1);
pass.draw(3, 1, 0, x);
}
pass.end();
encoder.copyTextureToBuffer({ texture: target }, { buffer: resultBuffer }, [target.width]);
break;
}
case 'vertex': {
const pipeline = device.createRenderPipeline({
layout: 'auto',
vertex: { module, entryPoint: 'vsRecord' },
fragment: { module, entryPoint: 'fsSaveVs', targets: [{ format: 'rgba8unorm' }] },
fragment: { module, entryPoint: 'fsSaveVs', targets: [{ format: 'rgba32uint' }] },
});
const pass = encoder.beginRenderPass({
colorAttachments: [
@@ -558,20 +565,26 @@ export async function queryMipLevelMixWeightsForDevice(t: GPUTest, stage: Shader
});
pass.setPipeline(pipeline);
pass.setBindGroup(0, createBindGroup(pipeline));
pass.draw(3, kMipLevelWeightSteps + 1);
for (let x = 0; x <= kMipLevelWeightSteps; ++x) {
pass.setViewport(x, 0, 1, 1, 0, 1);
pass.draw(3, 1, 0, x);
}
pass.end();
encoder.copyTextureToBuffer({ texture: target }, { buffer: resultBuffer }, [target.width]);
break;
}
}
encoder.copyBufferToBuffer(storageBuffer, 0, resultBuffer, 0, resultBuffer.size);
device.queue.submit([encoder.finish()]);

await resultBuffer.mapAsync(GPUMapMode.READ);
const result = Array.from(new Float32Array(resultBuffer.getMappedRange()));
// need to map a sub-portion since we may have padded the buffer.
const result = Array.from(
new Float32Array(resultBuffer.getMappedRange(0, (kMipLevelWeightSteps + 1) * 16))
);
resultBuffer.unmap();
resultBuffer.destroy();

const [sampleLevelWeights, gradWeights] = unzip(result, kNumWeightTypes);
const [sampleLevelWeights, gradWeights] = unzip(result, kNumWeightTypes, 4);

validateWeights(stage, sampleLevelWeights);
validateWeights(stage, gradWeights);
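
A note on why reading the rgba32uint copy through a Float32Array works: the fragment entry points return bitcast<vec4u>(...), so the texels hold the weights' exact f32 bit patterns, and the same bytes can simply be viewed as floats on the JS side. A tiny sketch, with `mapped` standing in for the mapped ArrayBuffer:

// Two views over the same bytes: written as u32 bit patterns,
// read back as the original f32 weights.
const asU32 = new Uint32Array(mapped);
const asF32 = new Float32Array(mapped);
// asF32[i] reinterprets the bits of asU32[i] as an IEEE-754 float.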