Skip to content

Commit 98f6c26

Browse files
axinging and qjia7 authored
[webgpu] Add depthwise3x3/argminmax/reduce/resize_nearest_neighbor WGSL support (#5535)
* [webgpu] Add argminmax/depthwise3x3/reduce/resize WGSL support
* Fix u32 usage
* Fix glsl compile error
Co-authored-by: Jiajia Qin <[email protected]>
1 parent 6ceb1c9 commit 98f6c26

6 files changed

+402
-7
lines changed

tfjs-backend-webgpu/src/kernels/argminmax_webgpu.ts

+135-1
Original file line number | Diff line number | Diff line change
@@ -18,9 +18,10 @@
1818
import {backend_util, util} from '@tensorflow/tfjs-core';
1919

2020
import {getCoordsDataType} from '../shader_preprocessor';
21+
import {getCoordsDataTypeWgsl, getGlobalIndexStringWgsl, getMainHeaderStringWgsl} from '../shader_preprocessor_wgsl';
2122
import {computeDispatch} from '../webgpu_util';
2223

23-
import {WebGPUProgram} from './webgpu_program';
24+
import {getUseWgsl, WebGPUProgram} from './webgpu_program';
2425

2526
export class ArgMinMaxProgram implements WebGPUProgram {
2627
outputShape: number[];
@@ -30,9 +31,11 @@ export class ArgMinMaxProgram implements WebGPUProgram {
3031
workGroupSize: [number, number, number];
3132
variableNames = ['x'];
3233
uniforms = 'int axis;';
34+
uniformsWgsl = 'axis : u32;';
3335
inputShape: number[];
3436
reductionFactor: number;
3537
op: string;
38+
useWgsl: boolean;
3639

3740
constructor(inputShape: number[], axis: number, reduceType: 'min'|'max') {
3841
const axes = [axis];
@@ -67,6 +70,7 @@ export class ArgMinMaxProgram implements WebGPUProgram {
6770

6871
this.inputShape = inputShape;
6972
this.shaderKey = `argMinMax${this.op}`;
73+
this.useWgsl = getUseWgsl();
7074
}
7175

7276
getUserCode(): string {
@@ -192,4 +196,134 @@ export class ArgMinMaxProgram implements WebGPUProgram {
192196
`;
193197
return userCode;
194198
}
199+
200+
getUserCodeWgsl(): string {
201+
// When this.workGroupSize[0] > 1, each thread reduces Length /
202+
// this.workGroupSize[0] values. These results are stored in shared memory
203+
// and iteratively reduced.
204+
const reduceInSharedMemory = this.workGroupSize[0] > 1;
205+
const sharedMemorySnippet = `
206+
var<workgroup> xBestIndices : array<u32, ${this.workGroupSize[0]}>;
207+
var<workgroup> xBestValues : array<f32, ${this.workGroupSize[0]}>;
208+
`;
209+
210+
const sharedMemoryReduceSnippet = `
211+
xBestIndices[localId.x] = bestIndex;
212+
xBestValues[localId.x] = bestValue;
213+
214+
for(var currentSize = WorkGroupSize; currentSize > 1u; currentSize = DIV_CEIL(currentSize, ${
215+
this.reductionFactor}u)) {
216+
workgroupBarrier();
217+
218+
for (var w = 0u; w < ${this.reductionFactor}u; w = w + 1u) {
219+
let i = localId.x * ${this.reductionFactor}u + w;
220+
if (i < currentSize) {
221+
let candidateIndex = xBestIndices[i];
222+
let candidate = xBestValues[i];
223+
if(candidate ${this.op} bestValue && !isNanCustom(candidate)) {
224+
bestValue = candidate;
225+
bestIndex = candidateIndex;
226+
}
227+
}
228+
}
229+
230+
xBestIndices[localId.x] = bestIndex;
231+
xBestValues[localId.x] = bestValue;
232+
}
233+
234+
if (localId.x == 0u) {
235+
setOutputFlatI32(flatOutputIndex, i32(bestIndex));
236+
}
237+
`;
238+
239+
const outputCoordsType = getCoordsDataTypeWgsl(this.outputShape.length);
240+
241+
const indexOutputCoords = (outputCoords: string, index: string) => {
242+
if (this.outputShape.length === 1) {
243+
return outputCoords;
244+
} else {
245+
return `${outputCoords}[${index}]`;
246+
}
247+
};
248+
249+
const indexInputShape = (index: string) => {
250+
if (this.inputShape.length === 1) {
251+
return 'uniforms.xShape';
252+
} else {
253+
return `uniforms.xShape[${index}]`;
254+
}
255+
};
256+
257+
const userCode = `
258+
fn DIV_CEIL(a : u32, b : u32) -> u32 {
259+
return ((a - 1u) / b + 1u);
260+
}
261+
262+
let WorkGroupSize = ${this.workGroupSize[0]}u;
263+
264+
${reduceInSharedMemory ? sharedMemorySnippet : ''}
265+
266+
// In order to get a flattened index into the input tensor, we need to
267+
// add back the index along the reduced dimension to |outputCoords|.
268+
// This function outputs the offset to the first value along
269+
// |axis| and the stride to get the next value of the input along |axis|.
270+
fn getInputCoordInfo(globalId : vec3<u32>, globalIndex : u32) -> vec2<u32>{
271+
let outputCoords : ${
272+
outputCoordsType} = getOutputCoords(globalId, globalIndex);
273+
var i = ${this.outputShape.length - 1}u;
274+
275+
var stride = 1u;
276+
var inputStride = 1u;
277+
var offset = 0u;
278+
279+
for (var r = 1u; r <= ${this.inputShape.length}u; r = r + 1u) {
280+
let length = ${indexInputShape(`${this.inputShape.length}u - r`)};
281+
if (${this.inputShape.length}u - r == uniforms.axis) {
282+
inputStride = stride;
283+
} else {
284+
offset = offset + ${
285+
indexOutputCoords('outputCoords', 'i')} * stride;
286+
i = i - 1u;
287+
}
288+
stride = stride * length;
289+
}
290+
291+
return vec2<u32>(offset, inputStride);
292+
}
293+
294+
fn getInputIndex(coordInfo : vec2<u32>, index : u32) -> u32{
295+
return coordInfo[0] + coordInfo[1] * index;
296+
}
297+
298+
${getMainHeaderStringWgsl(this.workGroupSize)} {
299+
${getGlobalIndexStringWgsl(this.workGroupSize)}
300+
let coordInfo = getInputCoordInfo(globalId, index);
301+
302+
var bestIndex = 0u;
303+
var bestValue = x.numbers[getInputIndex(coordInfo, bestIndex)];
304+
305+
let Length = ${indexInputShape('uniforms.axis')};
306+
let WorkPerThread = DIV_CEIL(Length, WorkGroupSize);
307+
308+
for (var w = 0u; w < WorkPerThread; w = w + 1u) {
309+
let i = globalId.x * WorkPerThread + w;
310+
if (i < Length) {
311+
let candidate = x.numbers[getInputIndex(coordInfo, i)];
312+
if (candidate ${
313+
this.op} bestValue && !isNanCustom(f32(candidate))) {
314+
bestValue = candidate;
315+
bestIndex = i;
316+
}
317+
}
318+
}
319+
320+
let flatOutputIndex = globalId.y;
321+
${
322+
reduceInSharedMemory ?
323+
sharedMemoryReduceSnippet :
324+
'setOutputFlatI32(flatOutputIndex, i32(bestIndex));'}
325+
}
326+
`;
327+
return userCode;
328+
}
195329
}

tfjs-backend-webgpu/src/kernels/depthwise_conv2d_3x3_webgpu.ts

+104-1
Original file line number | Diff line number | Diff line change
@@ -16,9 +16,12 @@
1616
*/
1717

1818
import {backend_util, util} from '@tensorflow/tfjs-core';
19+
20+
import {getGlobalIndexStringWgsl, getMainHeaderStringWgsl} from '../shader_preprocessor_wgsl';
1921
import {computeDispatch} from '../webgpu_util';
22+
2023
import {mapActivationToShaderProgram} from './activation_util';
21-
import {WebGPUProgram} from './webgpu_program';
24+
import {getUseWgsl, WebGPUProgram} from './webgpu_program';
2225

2326
export class DepthwiseConv2D3x3Program implements WebGPUProgram {
2427
outputShape: number[];
@@ -27,12 +30,15 @@ export class DepthwiseConv2D3x3Program implements WebGPUProgram {
2730
dispatch: [number, number, number];
2831
variableNames = ['x', 'W'];
2932
uniforms = 'ivec2 pad, stride, dilation, inDims;';
33+
uniformsWgsl =
34+
'pad : vec2<u32>; stride : vec2<u32>; dilation : vec2<u32>; inDims : vec2<u32>;';
3035
workGroupSize: [number, number, number] = [4, 4, 4];
3136
convInfo: backend_util.Conv2DInfo;
3237
addBias: boolean;
3338
activation: backend_util.Activation;
3439
hasPreluActivation: boolean;
3540
isVec4 = true;
41+
useWgsl: boolean;
3642

3743
constructor(
3844
convInfo: backend_util.Conv2DInfo, addBias = false,
@@ -59,6 +65,7 @@ export class DepthwiseConv2D3x3Program implements WebGPUProgram {
5965
this.hasPreluActivation = hasPreluActivation;
6066

6167
this.shaderKey = `depthwise3x3_${activation}`;
68+
this.useWgsl = getUseWgsl();
6269
}
6370

6471
getUserCode(): string {
@@ -153,4 +160,100 @@ export class DepthwiseConv2D3x3Program implements WebGPUProgram {
153160
`;
154161
return userCode;
155162
}
163+
164+
getUserCodeWgsl(): string {
165+
let activationSnippet = '', applyActivationSnippet = '';
166+
if (this.activation) {
167+
const activationOp = mapActivationToShaderProgram(
168+
this.activation, this.isVec4, this.useWgsl);
169+
if (this.hasPreluActivation) {
170+
activationSnippet =
171+
`fn activation(a : vec4<f32>, globalId : vec3<u32>, globalIndex : u32) -> vec4<f32> {
172+
let b = getPreluActivationWeightsAtOutCoordsByGlobalId(globalId, globalIndex);
173+
${activationOp}
174+
}`;
175+
} else {
176+
activationSnippet = `
177+
fn activation(a : vec4<f32>, globalId : vec3<u32>, globalIndex : u32) -> vec4<f32> {
178+
${activationOp}
179+
}
180+
`;
181+
}
182+
183+
applyActivationSnippet =
184+
`dotProd[i] = activation(dotProd[i], globalId, index);`;
185+
}
186+
187+
const addBiasSnippet = this.addBias ?
188+
'dotProd[i] = dotProd[i] + getBiasAtOutCoordsByCoords(coords);' :
189+
'';
190+
191+
const userCode = `
192+
${activationSnippet}
193+
194+
${getMainHeaderStringWgsl(this.workGroupSize)} {
195+
${getGlobalIndexStringWgsl(this.workGroupSize)}
196+
let batch = 0u;
197+
let r = globalId.x;
198+
let c = globalId.y * 4u;
199+
let d2 = globalId.z * 4u;
200+
let xRCCorner = vec2<i32>(vec2<u32>(r, c) * uniforms.stride - uniforms.pad);
201+
let d1 = d2;
202+
let q = 0u;
203+
204+
let xRCorner = xRCCorner.x;
205+
let xCCorner = xRCCorner.y;
206+
207+
var wVals : array<vec4<f32>, 9>;
208+
wVals[0] = getW(0u, 0u, d1, q);
209+
wVals[1] = getW(0u, 1u, d1, q);
210+
wVals[2] = getW(0u, 2u, d1, q);
211+
wVals[3] = getW(1u, 0u, d1, q);
212+
wVals[4] = getW(1u, 1u, d1, q);
213+
wVals[5] = getW(1u, 2u, d1, q);
214+
wVals[6] = getW(2u, 0u, d1, q);
215+
wVals[7] = getW(2u, 1u, d1, q);
216+
wVals[8] = getW(2u, 2u, d1, q);
217+
218+
var xVals : array<array<vec4<f32>, 6>, 3>;
219+
for (var wR = 0u; wR < 3u; wR = wR + 1u) {
220+
let xR = xRCorner + i32(wR * uniforms.dilation[0]);
221+
for (var wC = 0u; wC < 6u; wC = wC + 1u) {
222+
let xC = xCCorner + i32(wC * uniforms.dilation[1]);
223+
if (xR < 0 || xR >= i32(uniforms.inDims[0]) || xC < 0 || xC >= i32(uniforms.inDims[1])) {
224+
xVals[wR][wC] = vec4<f32>(0.0);
225+
} else {
226+
xVals[wR][wC] = getX(batch, u32(xR), u32(xC), d1);
227+
}
228+
}
229+
}
230+
231+
var dotProd : array<vec4<f32>, 4>;
232+
dotProd[0] = vec4<f32>(0.0);
233+
dotProd[1] = vec4<f32>(0.0);
234+
dotProd[2] = vec4<f32>(0.0);
235+
dotProd[3] = vec4<f32>(0.0);
236+
237+
for (var wR = 0u; wR < 3u; wR = wR + 1u) {
238+
for (var wC = 0u; wC < 3u; wC = wC + 1u) {
239+
let indexW = wR * 3u + wC;
240+
dotProd[0] = dotProd[0] + xVals[wR][0u + wC] * wVals[indexW];
241+
dotProd[1] = dotProd[1] + xVals[wR][1u + wC] * wVals[indexW];
242+
dotProd[2] = dotProd[2] + xVals[wR][2u + wC] * wVals[indexW];
243+
dotProd[3] = dotProd[3] + xVals[wR][3u + wC] * wVals[indexW];
244+
}
245+
}
246+
247+
for (var i = 0u; i < 4u; i = i + 1u) {
248+
let coords = vec4<u32>(batch, r, c + i, d2);
249+
if (coordsInBounds4D(coords, uniforms.outShape)) {
250+
${addBiasSnippet}
251+
${applyActivationSnippet}
252+
setOutput(coords[0], coords[1], coords[2], coords[3], dotProd[i]);
253+
}
254+
}
255+
}
256+
`;
257+
return userCode;
258+
}
156259
}

tfjs-backend-webgpu/src/kernels/depthwise_conv2d_webgpu.ts

+1-2
Original file line number | Diff line number | Diff line change
@@ -230,8 +230,7 @@ export class DepthwiseConv2DProgram implements WebGPUProgram {
230230
231231
// Extract if checking out of for loop for performance.
232232
if (inputRowStart >= 0 && inputColStart >= 0 &&
233-
inputRowEnd < i32(uniforms.inDims[0]) && inputColEnd < i32(uniforms.inDims[1]))
234-
{
233+
inputRowEnd < i32(uniforms.inDims[0]) && inputColEnd < i32(uniforms.inDims[1])) {
235234
// Here using a constant value |this.convInfo.filterHeight| instead
236235
// of uniform value is in order to loop unrolling.
237236
for (var wR = 0u; wR < ${

0 commit comments

Comments
 (0)