Skip to content

Commit 281d8f8

Browse files
authored
[webgpu] More preparation for element-wise binary op restructuring (#7666)
The current code isn't great in that the vec4 shaders have diverged from the scalar ones more than necessary. Here is the common preparation work, so that following refactoring can be done on a per-op basis.
1 parent 04d4a86 commit 281d8f8

File tree

1 file changed

+30
-13
lines changed

1 file changed

+30
-13
lines changed

tfjs-backend-webgpu/src/binary_op_util.ts

+30-13
Original file line numberDiff line numberDiff line change
@@ -41,15 +41,6 @@ export enum BinaryOpType {
4141
SUB
4242
}
4343

44-
const CHECK_NAN_SNIPPET = `
45-
resultTemp = select(resultTemp, valueForNaN, isNaN | isnan(a) | isnan(b));`;
46-
47-
const CHECK_NAN_SNIPPET_VEC4 = `
48-
resultTemp = select(
49-
resultTemp, vec4<f32>(valueForNaN),
50-
vec4<bool>(isNaN) | isnanVec4(a) | isnanVec4(b));
51-
`;
52-
5344
const ADD = 'return a + b;';
5445
const ATAN2 = 'var resultTemp = atan2(a, b);';
5546
// (Ar + Ai)(Br + Bi) =
@@ -183,9 +174,10 @@ const SUB = 'return a - b;';
183174

184175
export function getBinaryOpString(
185176
type: BinaryOpType, useVec4?: boolean): string {
177+
let doOpSnippet: string;
178+
186179
// Ops with NaN check
187180
do {
188-
let doOpSnippet: string;
189181
switch (type) {
190182
case BinaryOpType.ATAN2:
191183
doOpSnippet = ATAN2;
@@ -208,13 +200,34 @@ export function getBinaryOpString(
208200
default:
209201
continue;
210202
}
203+
204+
let isNaN: string;
205+
let dTypeN: string;
206+
let boolN: string;
207+
if (useVec4) {
208+
isNaN = 'isnanVec4';
209+
dTypeN = 'vec4<f32>';
210+
boolN = 'vec4<bool>';
211+
} else {
212+
isNaN = 'isnan';
213+
dTypeN = 'f32';
214+
boolN = 'bool';
215+
}
216+
211217
return `
218+
let aIsNaN = ${isNaN}(a);
219+
let aPostLegalization = select(a, ${dTypeN}(42), aIsNaN);
220+
let bIsNaN = ${isNaN}(b);
221+
let bPostLegalization = select(b, ${dTypeN}(42), bIsNaN);
212222
let isNaN = false;
213223
let valueForNaN = uniforms.NAN;
214224
{
225+
let a = aPostLegalization;
226+
let b = bPostLegalization;
215227
${doOpSnippet}
216-
${useVec4 ? CHECK_NAN_SNIPPET_VEC4 : CHECK_NAN_SNIPPET}
217-
return resultTemp;
228+
return select(
229+
resultTemp, ${dTypeN}(valueForNaN),
230+
${boolN}(isNaN) | aIsNaN | bIsNaN);
218231
}
219232
`;
220233
} while (false);
@@ -256,6 +269,10 @@ export function getBinaryOpString(
256269
case BinaryOpType.SUB:
257270
return SUB;
258271
default:
259-
throw new Error(`BinaryType ${type} is not implemented!`);
272+
// throw new Error(`BinaryType ${type} is not implemented!`);
260273
}
274+
return `
275+
${doOpSnippet}
276+
return resultTemp;
277+
`;
261278
}

0 commit comments

Comments
 (0)