From d473d09475bffec9569fe5c45834bb6aaad44818 Mon Sep 17 00:00:00 2001 From: Samson <16504129+sagudev@users.noreply.github.com> Date: Thu, 24 Oct 2024 13:06:39 +0200 Subject: [PATCH] Test indexing of a matrix using non-const index (#3982) * Test indexing of a matrix using non-const index Signed-off-by: sagudev <16504129+sagudev@users.noreply.github.com> * fixup Signed-off-by: sagudev <16504129+sagudev@users.noreply.github.com> --------- Signed-off-by: sagudev <16504129+sagudev@users.noreply.github.com> Co-authored-by: Corentin Wallez --- .../expression/access/matrix/index.spec.ts | 72 +++++++++++++++++++ 1 file changed, 72 insertions(+) diff --git a/src/webgpu/shader/execution/expression/access/matrix/index.spec.ts b/src/webgpu/shader/execution/expression/access/matrix/index.spec.ts index f6fd05b46fcb..b8872eeab99f 100644 --- a/src/webgpu/shader/execution/expression/access/matrix/index.spec.ts +++ b/src/webgpu/shader/execution/expression/access/matrix/index.spec.ts @@ -11,7 +11,9 @@ import { abstractFloat, f32, vec, + Value, } from '../../../../../util/conversion.js'; +import { align } from '../../../../../util/math.js'; import { Case } from '../../case.js'; import { allInputSources, basicExpressionBuilder, run } from '../../expression.js'; @@ -198,3 +200,73 @@ g.test('abstract_float_element') cases ); }); + +g.test('non_const_index') + .specURL('https://www.w3.org/TR/WGSL/#matrix-access-expr') + .desc(`Test indexing of a matrix using non-const index`) + .params(u => u.combine('columns', [2, 3, 4] as const).combine('rows', [2, 3, 4] as const)) + .fn(t => { + const cols = t.params.columns; + const rows = t.params.rows; + const values = Array.from(Array(cols * rows).keys()); + const wgsl = ` +@group(0) @binding(0) var output : array; + +@compute @workgroup_size(${cols}, ${rows}) +fn main(@builtin(local_invocation_id) invocation_id : vec3) { + let m = mat${cols}x${rows}f(${values.join(', ')}); + output[invocation_id.x*${rows} + invocation_id.y] = m[invocation_id.x][invocation_id.y]; +} +`; + + const pipeline = t.device.createComputePipeline({ + layout: 'auto', + compute: { + module: t.device.createShaderModule({ code: wgsl }), + entryPoint: 'main', + }, + }); + + const bufferSize = (arr: Value[]) => { + let offset = 0; + let alignment = 0; + for (const value of arr) { + alignment = Math.max(alignment, value.type.alignment); + offset = align(offset, value.type.alignment) + value.type.size; + } + return align(offset, alignment); + }; + + const toArray = (arr: Value[]) => { + const array = new Uint8Array(bufferSize(arr)); + let offset = 0; + for (const value of arr) { + offset = align(offset, value.type.alignment); + value.copyTo(array, offset); + offset += value.type.size; + } + return array; + }; + + const expected = values.map(i => Type['f32'].create(i)); + + const outputBuffer = t.createBufferTracked({ + size: bufferSize(expected), + usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC, + }); + + const bindGroup = t.device.createBindGroup({ + layout: pipeline.getBindGroupLayout(0), + entries: [{ binding: 0, resource: { buffer: outputBuffer } }], + }); + + const encoder = t.device.createCommandEncoder(); + const pass = encoder.beginComputePass(); + pass.setPipeline(pipeline); + pass.setBindGroup(0, bindGroup); + pass.dispatchWorkgroups(1); + pass.end(); + t.queue.submit([encoder.finish()]); + + t.expectGPUBufferValuesEqual(outputBuffer, toArray(expected)); + });