Skip to content

Commit bea721d

Browse files
authored
webgpu: support selu operator (#7118)
1 parent 40160ef commit bea721d

File tree

4 files changed

+44
-2
lines changed

4 files changed

+44
-2
lines changed
+30
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
/**
2+
* @license
3+
* Copyright 2022 Google LLC.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import {KernelConfig, Selu} from '@tensorflow/tfjs-core';
19+
20+
import {unaryKernelFunc} from '../kernel_utils/kernel_funcs_utils';
21+
22+
import {UnaryOpType} from '../unary_op_util';
23+
24+
export const selu = unaryKernelFunc({opType: UnaryOpType.SELU});
25+
26+
export const seluConfig: KernelConfig = {
27+
kernelName: Selu,
28+
backendName: 'webgpu',
29+
kernelFunc: selu
30+
};

tfjs-backend-webgpu/src/register_all_kernels.ts

+2
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ import {rsqrtConfig} from './kernels/Rsqrt';
120120
import {scatterNdConfig} from './kernels/ScatterNd';
121121
import {searchSortedConfig} from './kernels/SearchSorted';
122122
import {selectConfig} from './kernels/Select';
123+
import {seluConfig} from './kernels/Selu';
123124
import {sigmoidConfig} from './kernels/Sigmoid';
124125
import {signConfig} from './kernels/Sign';
125126
import {sinConfig} from './kernels/Sin';
@@ -252,6 +253,7 @@ const kernelConfigs: KernelConfig[] = [
252253
scatterNdConfig,
253254
searchSortedConfig,
254255
selectConfig,
256+
seluConfig,
255257
sigmoidConfig,
256258
signConfig,
257259
sinConfig,

tfjs-backend-webgpu/src/setup_test.ts

-2
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,6 @@ const TEST_FILTERS: TestFilter[] = [
9191
{
9292
startsWith: 'elu ',
9393
excludes: [
94-
'selu', // Not yet implemented.
9594
'derivative', // gradient function not found.
9695
'gradient' // gradient function not found.
9796
]
@@ -265,7 +264,6 @@ const TEST_FILTERS: TestFilter[] = [
265264
'raggedRange ',
266265
'raggedTensorToTensor ',
267266
'method otsu', // round
268-
'selu ',
269267
'sparseFillEmptyRows ',
270268
'sparseReshape ',
271269
'sparseSegmentMean ',

tfjs-backend-webgpu/src/unary_op_util.ts

+12
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ export enum UnaryOpType {
4747
RECIPROCAL,
4848
ROUND,
4949
RSQRT,
50+
SELU,
5051
SIGMOID,
5152
SIGN,
5253
SIN,
@@ -166,6 +167,15 @@ const RELU_VEC4 = `
166167
`;
167168
const ROUND = `return round(a);`;
168169
const RSQRT = `return inverseSqrt(a);`;
170+
// Stable and Attracting Fixed Point (0, 1) for Normalized Weights.
171+
// See: https://arxiv.org/abs/1706.02515
172+
const SELU = `
173+
if (a >= 0.0) {
174+
return ${backend_util.SELU_SCALE} * a;
175+
} else {
176+
return ${backend_util.SELU_SCALEALPHA} * (exp(a) - 1.0);
177+
}
178+
`;
169179
const SIGMOID = `return 1.0 / (1.0 + exp(-1.0 * a));`;
170180
const SIGN = `return sign(a);`;
171181
const SIN = `return sin(a);`;
@@ -258,6 +268,8 @@ export function getUnaryOpString(type: UnaryOpType, useVec4?: boolean): string {
258268
return ROUND;
259269
case UnaryOpType.RSQRT:
260270
return RSQRT;
271+
case UnaryOpType.SELU:
272+
return SELU;
261273
case UnaryOpType.SIGMOID:
262274
return SIGMOID;
263275
case UnaryOpType.SIGN:

0 commit comments

Comments
 (0)