Skip to content

Commit

Permalink
Test indexing of a matrix using non-const index (#3982)
Browse files Browse the repository at this point in the history
* Test indexing of a matrix using non-const index

Signed-off-by: sagudev <[email protected]>

* fixup

Signed-off-by: sagudev <[email protected]>

---------

Signed-off-by: sagudev <[email protected]>
Co-authored-by: Corentin Wallez <[email protected]>
  • Loading branch information
sagudev and Kangz authored Oct 24, 2024
1 parent 44754db commit d473d09
Showing 1 changed file with 72 additions and 0 deletions.
72 changes: 72 additions & 0 deletions src/webgpu/shader/execution/expression/access/matrix/index.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ import {
abstractFloat,
f32,
vec,
Value,
} from '../../../../../util/conversion.js';
import { align } from '../../../../../util/math.js';
import { Case } from '../../case.js';
import { allInputSources, basicExpressionBuilder, run } from '../../expression.js';

Expand Down Expand Up @@ -198,3 +200,73 @@ g.test('abstract_float_element')
cases
);
});

g.test('non_const_index')
.specURL('https://www.w3.org/TR/WGSL/#matrix-access-expr')
.desc(`Test indexing of a matrix using non-const index`)
.params(u => u.combine('columns', [2, 3, 4] as const).combine('rows', [2, 3, 4] as const))
.fn(t => {
const cols = t.params.columns;
const rows = t.params.rows;
const values = Array.from(Array(cols * rows).keys());
const wgsl = `
@group(0) @binding(0) var<storage, read_write> output : array<f32, ${cols * rows}>;
@compute @workgroup_size(${cols}, ${rows})
fn main(@builtin(local_invocation_id) invocation_id : vec3<u32>) {
let m = mat${cols}x${rows}f(${values.join(', ')});
output[invocation_id.x*${rows} + invocation_id.y] = m[invocation_id.x][invocation_id.y];
}
`;

const pipeline = t.device.createComputePipeline({
layout: 'auto',
compute: {
module: t.device.createShaderModule({ code: wgsl }),
entryPoint: 'main',
},
});

const bufferSize = (arr: Value[]) => {
let offset = 0;
let alignment = 0;
for (const value of arr) {
alignment = Math.max(alignment, value.type.alignment);
offset = align(offset, value.type.alignment) + value.type.size;
}
return align(offset, alignment);
};

const toArray = (arr: Value[]) => {
const array = new Uint8Array(bufferSize(arr));
let offset = 0;
for (const value of arr) {
offset = align(offset, value.type.alignment);
value.copyTo(array, offset);
offset += value.type.size;
}
return array;
};

const expected = values.map(i => Type['f32'].create(i));

const outputBuffer = t.createBufferTracked({
size: bufferSize(expected),
usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC,
});

const bindGroup = t.device.createBindGroup({
layout: pipeline.getBindGroupLayout(0),
entries: [{ binding: 0, resource: { buffer: outputBuffer } }],
});

const encoder = t.device.createCommandEncoder();
const pass = encoder.beginComputePass();
pass.setPipeline(pipeline);
pass.setBindGroup(0, bindGroup);
pass.dispatchWorkgroups(1);
pass.end();
t.queue.submit([encoder.finish()]);

t.expectGPUBufferValuesEqual(outputBuffer, toArray(expected));
});

0 comments on commit d473d09

Please sign in to comment.