16
16
*/
17
17
18
18
import { backend_util , util } from '@tensorflow/tfjs-core' ;
19
+
20
+ import { getGlobalIndexStringWgsl , getMainHeaderStringWgsl } from '../shader_preprocessor_wgsl' ;
19
21
import { computeDispatch } from '../webgpu_util' ;
22
+
20
23
import { mapActivationToShaderProgram } from './activation_util' ;
21
- import { WebGPUProgram } from './webgpu_program' ;
24
+ import { getUseWgsl , WebGPUProgram } from './webgpu_program' ;
22
25
23
26
export class DepthwiseConv2D3x3Program implements WebGPUProgram {
24
27
outputShape : number [ ] ;
@@ -27,12 +30,15 @@ export class DepthwiseConv2D3x3Program implements WebGPUProgram {
27
30
dispatch : [ number , number , number ] ;
28
31
variableNames = [ 'x' , 'W' ] ;
29
32
uniforms = 'ivec2 pad, stride, dilation, inDims;' ;
33
+ uniformsWgsl =
34
+ 'pad : vec2<u32>; stride : vec2<u32>; dilation : vec2<u32>; inDims : vec2<u32>;' ;
30
35
workGroupSize : [ number , number , number ] = [ 4 , 4 , 4 ] ;
31
36
convInfo : backend_util . Conv2DInfo ;
32
37
addBias : boolean ;
33
38
activation : backend_util . Activation ;
34
39
hasPreluActivation : boolean ;
35
40
isVec4 = true ;
41
+ useWgsl : boolean ;
36
42
37
43
constructor (
38
44
convInfo : backend_util . Conv2DInfo , addBias = false ,
@@ -59,6 +65,7 @@ export class DepthwiseConv2D3x3Program implements WebGPUProgram {
59
65
this . hasPreluActivation = hasPreluActivation ;
60
66
61
67
this . shaderKey = `depthwise3x3_${ activation } ` ;
68
+ this . useWgsl = getUseWgsl ( ) ;
62
69
}
63
70
64
71
getUserCode ( ) : string {
@@ -153,4 +160,100 @@ export class DepthwiseConv2D3x3Program implements WebGPUProgram {
153
160
` ;
154
161
return userCode ;
155
162
}
163
+
164
/**
 * Assembles the WGSL user code for this program.
 *
 * The returned source consists of an optional `activation` helper function
 * followed by the entry point whose header is produced by
 * `getMainHeaderStringWgsl(this.workGroupSize)`.
 *
 * Per invocation the generated shader:
 *  - takes the output row from `globalId.x`, the first output column from
 *    `globalId.y * 4` and the channel offset from `globalId.z * 4`
 *    (the batch index is fixed to 0);
 *  - reads the nine filter taps with `getW` and a 3 x 6 window of input
 *    values with `getX`, substituting zeros for positions outside
 *    `uniforms.inDims`;
 *  - accumulates four vec4 dot products, one per adjacent output column;
 *  - for each of the four outputs that lies inside `uniforms.outShape`,
 *    applies the optional bias and activation and stores it via `setOutput`.
 *
 * NOTE(review): the accumulation addresses neighbouring windows as
 * `xVals[wR][i + wC]`, independent of the stride/dilation uniforms, which
 * presumably means callers only use this program with stride and dilation
 * of 1 -- confirm against the call site.
 *
 * @returns The WGSL source text for this program's user code.
 */
getUserCodeWgsl(): string {
  // Statement injected into the output loop when a bias tensor is present.
  const biasStatement = this.addBias ?
      'dotProd[i] = dotProd[i] + getBiasAtOutCoordsByCoords(coords);' :
      '';

  // Activation helper declaration and its per-output call; both stay empty
  // when no activation was requested.
  let activationFn = '';
  let activationStatement = '';
  if (this.activation) {
    const opBody = mapActivationToShaderProgram(
        this.activation, this.isVec4, this.useWgsl);
    // The prelu variant additionally fetches the per-element weights `b`
    // before running the activation op.
    activationFn = this.hasPreluActivation ?
        `fn activation(a : vec4<f32>, globalId : vec3<u32>, globalIndex : u32) -> vec4<f32> {
          let b = getPreluActivationWeightsAtOutCoordsByGlobalId(globalId, globalIndex);
          ${opBody}
        }` :
        `
        fn activation(a : vec4<f32>, globalId : vec3<u32>, globalIndex : u32) -> vec4<f32> {
          ${opBody}
        }
        `;
    activationStatement =
        `dotProd[i] = activation(dotProd[i], globalId, index);`;
  }

  return `
      ${activationFn}

      ${getMainHeaderStringWgsl(this.workGroupSize)} {
        ${getGlobalIndexStringWgsl(this.workGroupSize)}
        let batch = 0u;
        let r = globalId.x;
        let c = globalId.y * 4u;
        let d2 = globalId.z * 4u;
        let xRCCorner = vec2<i32>(vec2<u32>(r, c) * uniforms.stride - uniforms.pad);
        let d1 = d2;
        let q = 0u;

        let xRCorner = xRCCorner.x;
        let xCCorner = xRCCorner.y;

        var wVals : array<vec4<f32>, 9>;
        wVals[0] = getW(0u, 0u, d1, q);
        wVals[1] = getW(0u, 1u, d1, q);
        wVals[2] = getW(0u, 2u, d1, q);
        wVals[3] = getW(1u, 0u, d1, q);
        wVals[4] = getW(1u, 1u, d1, q);
        wVals[5] = getW(1u, 2u, d1, q);
        wVals[6] = getW(2u, 0u, d1, q);
        wVals[7] = getW(2u, 1u, d1, q);
        wVals[8] = getW(2u, 2u, d1, q);

        var xVals : array<array<vec4<f32>, 6>, 3>;
        for (var wR = 0u; wR < 3u; wR = wR + 1u) {
          let xR = xRCorner + i32(wR * uniforms.dilation[0]);
          for (var wC = 0u; wC < 6u; wC = wC + 1u) {
            let xC = xCCorner + i32(wC * uniforms.dilation[1]);
            if (xR < 0 || xR >= i32(uniforms.inDims[0]) || xC < 0 || xC >= i32(uniforms.inDims[1])) {
              xVals[wR][wC] = vec4<f32>(0.0);
            } else {
              xVals[wR][wC] = getX(batch, u32(xR), u32(xC), d1);
            }
          }
        }

        var dotProd : array<vec4<f32>, 4>;
        dotProd[0] = vec4<f32>(0.0);
        dotProd[1] = vec4<f32>(0.0);
        dotProd[2] = vec4<f32>(0.0);
        dotProd[3] = vec4<f32>(0.0);

        for (var wR = 0u; wR < 3u; wR = wR + 1u) {
          for (var wC = 0u; wC < 3u; wC = wC + 1u) {
            let indexW = wR * 3u + wC;
            dotProd[0] = dotProd[0] + xVals[wR][0u + wC] * wVals[indexW];
            dotProd[1] = dotProd[1] + xVals[wR][1u + wC] * wVals[indexW];
            dotProd[2] = dotProd[2] + xVals[wR][2u + wC] * wVals[indexW];
            dotProd[3] = dotProd[3] + xVals[wR][3u + wC] * wVals[indexW];
          }
        }

        for (var i = 0u; i < 4u; i = i + 1u) {
          let coords = vec4<u32>(batch, r, c + i, d2);
          if (coordsInBounds4D(coords, uniforms.outShape)) {
            ${biasStatement}
            ${activationStatement}
            setOutput(coords[0], coords[1], coords[2], coords[3], dotProd[i]);
          }
        }
      }
    `;
}
156
259
}
0 commit comments