Skip to content

Commit d15c96a

Browse files
authored
webgpu: support linSpace operator (#7119)
1 parent 1073979 commit d15c96a

File tree

4 files changed

+91
-1
lines changed

4 files changed

+91
-1
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
/**
2+
* @license
3+
* Copyright 2022 Google LLC.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import {KernelConfig, KernelFunc, LinSpace, LinSpaceAttrs, TensorInfo} from '@tensorflow/tfjs-core';
19+
20+
import {WebGPUBackend} from '../backend_webgpu';
21+
import {LinSpaceProgram} from '../lin_space_webgpu';
22+
23+
export function linSpace(args: {backend: WebGPUBackend, attrs: LinSpaceAttrs}):
24+
TensorInfo {
25+
const {backend, attrs} = args;
26+
const {start, stop, num} = attrs;
27+
const step = (stop - start) / (num - 1);
28+
29+
const program = new LinSpaceProgram(num);
30+
const uniformData =
31+
[{type: 'float32', data: [start]}, {type: 'float32', data: [step]}];
32+
return backend.runWebGPUProgram(program, [], 'float32', uniformData);
33+
}
34+
35+
export const linSpaceConfig: KernelConfig = {
36+
kernelName: LinSpace,
37+
backendName: 'webgpu',
38+
kernelFunc: linSpace as unknown as KernelFunc
39+
};
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
/**
2+
* @license
3+
* Copyright 2022 Google LLC.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import {getMainHeaderString as main, WebGPUProgram} from './webgpu_program';
19+
import {computeDispatch, flatDispatchLayout} from './webgpu_util';
20+
21+
export class LinSpaceProgram implements WebGPUProgram {
22+
variableNames: string[] = [];
23+
outputShape: number[] = [];
24+
shaderKey: string;
25+
dispatchLayout: {x: number[]};
26+
dispatch: [number, number, number];
27+
uniforms = 'start : f32, step : f32,';
28+
workgroupSize: [number, number, number] = [64, 1, 1];
29+
size = true;
30+
31+
constructor(shape: number) {
32+
this.outputShape = [shape];
33+
this.dispatchLayout = flatDispatchLayout(this.outputShape);
34+
this.dispatch = computeDispatch(
35+
this.dispatchLayout, this.outputShape, this.workgroupSize);
36+
37+
this.shaderKey = 'linSpace';
38+
}
39+
40+
getUserCode(): string {
41+
const userCode = `
42+
${main('index')} {
43+
if (index < uniforms.size) {
44+
setOutputAtIndex(index, uniforms.start + f32(index) * uniforms.step);
45+
}
46+
}
47+
`;
48+
return userCode;
49+
}
50+
}

tfjs-backend-webgpu/src/register_all_kernels.ts

+2
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ import {isNaNConfig} from './kernels/IsNaN';
7979
import {leakyReluConfig} from './kernels/LeakyRelu';
8080
import {lessConfig} from './kernels/Less';
8181
import {lessEqualConfig} from './kernels/LessEqual';
82+
import {linSpaceConfig} from './kernels/LinSpace';
8283
import {logConfig} from './kernels/Log';
8384
import {log1pConfig} from './kernels/Log1p';
8485
import {logicalAndConfig} from './kernels/LogicalAnd';
@@ -212,6 +213,7 @@ const kernelConfigs: KernelConfig[] = [
212213
leakyReluConfig,
213214
lessConfig,
214215
lessEqualConfig,
216+
linSpaceConfig,
215217
log1pConfig,
216218
logConfig,
217219
logicalAndConfig,

tfjs-backend-webgpu/src/setup_test.ts

-1
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,6 @@ const TEST_FILTERS: TestFilter[] = [
252252
'diag ',
253253
'dilation2d ',
254254
'encodeWeights ',
255-
'linspace ',
256255
'localResponseNormalization ',
257256
'maxPool3d ',
258257
'maxPool3dBackprop ',

0 commit comments

Comments
 (0)