Skip to content

Commit cb9a98b

Browse files
authored
[e2e] Enable webgpu intergration test (#7543)
1 parent c3f04be commit cb9a98b

13 files changed

+109
-126
lines changed

e2e/integration_tests/backends_test.ts

+39-33
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import '@tensorflow/tfjs-backend-cpu';
1919
import '@tensorflow/tfjs-backend-webgl';
20+
import '@tensorflow/tfjs-backend-webgpu';
2021

2122
import * as tfc from '@tensorflow/tfjs-core';
2223
// tslint:disable-next-line: no-imports-from-dist
@@ -27,27 +28,28 @@ import {SMOKE} from './constants';
2728
/**
2829
* This file tests backend switching scenario.
2930
*/
30-
31+
// TODO: Support backend switching between wasm and cpu.
32+
// https://github.com/tensorflow/tfjs/issues/7623
3133
describeWithFlags(
3234
`${SMOKE} backend switching`, {
33-
predicate: testEnv => testEnv.backendName === 'webgl' &&
34-
tfc.findBackend('webgl') !== null && tfc.findBackend('cpu') !== null
35+
predicate: testEnv =>
36+
testEnv.backendName !== 'cpu' && testEnv.backendName !== 'wasm'
3537
},
3638

37-
() => {
38-
it(`from webgl to cpu.`, async () => {
39-
await tfc.setBackend('webgl');
39+
(env) => {
40+
it(`from ${env.name} to cpu.`, async () => {
41+
await tfc.setBackend(env.name);
4042

41-
const webglBefore = tfc.engine().backend.numDataIds();
43+
const backendBefore = tfc.engine().backend.numDataIds();
4244

4345
const input = tfc.tensor2d([1, 1, 1, 1], [2, 2], 'float32');
44-
// input is stored in webgl backend.
46+
// input is stored in backend.
4547

4648
const inputReshaped = tfc.reshape(input, [2, 2]);
4749

48-
const webglAfter = tfc.engine().backend.numDataIds();
50+
const backendAfter = tfc.engine().backend.numDataIds();
4951

50-
expect(webglAfter).toEqual(webglBefore + 1);
52+
expect(backendAfter).toEqual(backendBefore + 1);
5153

5254
await tfc.setBackend('cpu');
5355

@@ -56,8 +58,9 @@ describeWithFlags(
5658
const inputReshaped2 = tfc.reshape(inputReshaped, [2, 2]);
5759
// input moved to cpu.
5860

59-
// Because input is moved to cpu, data should be deleted from webgl
60-
expect(tfc.findBackend('webgl').numDataIds()).toEqual(webglAfter - 1);
61+
// Because input is moved to cpu, data should be deleted from backend.
62+
expect(tfc.findBackend(env.name).numDataIds())
63+
.toEqual(backendAfter - 1);
6164

6265
const cpuAfter = tfc.engine().backend.numDataIds();
6366

@@ -77,7 +80,7 @@ describeWithFlags(
7780
expect(after).toBe(cpuBefore);
7881
});
7982

80-
it(`from cpu to webgl.`, async () => {
83+
it(`from cpu to ${env.name}.`, async () => {
8184
await tfc.setBackend('cpu');
8285

8386
const cpuBefore = tfc.engine().backend.numDataIds();
@@ -91,46 +94,47 @@ describeWithFlags(
9194

9295
expect(cpuAfter).toEqual(cpuBefore + 1);
9396

94-
await tfc.setBackend('webgl');
97+
await tfc.setBackend(env.name);
9598

96-
const webglBefore = tfc.engine().backend.numDataIds();
99+
const backendBefore = tfc.engine().backend.numDataIds();
97100

98101
const inputReshaped2 = tfc.reshape(inputReshaped, [2, 2]);
99-
// input moved to webgl.
102+
// input moved to webgl or webgpu.
100103

101-
// Because input is moved to webgl, data should be deleted from cpu
104+
// Because input is moved to backend, data should be deleted
105+
// from cpu.
102106
expect(tfc.findBackend('cpu').numDataIds()).toEqual(cpuAfter - 1);
103107

104-
const webglAfter = tfc.engine().backend.numDataIds();
108+
const backendAfter = tfc.engine().backend.numDataIds();
105109

106-
expect(webglAfter).toEqual(webglBefore + 1);
110+
expect(backendAfter).toEqual(backendBefore + 1);
107111

108112
input.dispose();
109113

110-
expect(tfc.engine().backend.numDataIds()).toEqual(webglAfter);
114+
expect(tfc.engine().backend.numDataIds()).toEqual(backendAfter);
111115

112116
inputReshaped.dispose();
113117

114-
expect(tfc.engine().backend.numDataIds()).toEqual(webglAfter);
118+
expect(tfc.engine().backend.numDataIds()).toEqual(backendAfter);
115119

116120
inputReshaped2.dispose();
117121

118122
const after = tfc.engine().backend.numDataIds();
119123

120-
expect(after).toBe(webglBefore);
124+
expect(after).toBe(backendBefore);
121125
});
122126

123127
it('can execute op with data from mixed backends', async () => {
124128
const numTensors = tfc.memory().numTensors;
125-
const webglNumDataIds = tfc.findBackend('webgl').numDataIds();
129+
const backendNumDataIds = tfc.findBackend(env.name).numDataIds();
126130
const cpuNumDataIds = tfc.findBackend('cpu').numDataIds();
127131

128132
await tfc.setBackend('cpu');
129133
// This scalar lives in cpu.
130134
const a = tfc.scalar(5);
131135

132-
await tfc.setBackend('webgl');
133-
// This scalar lives in webgl.
136+
await tfc.setBackend(env.name);
137+
// This scalar lives in webgl or webgpu.
134138
const b = tfc.scalar(3);
135139

136140
// Verify that ops can execute with mixed backend data.
@@ -141,32 +145,34 @@ describeWithFlags(
141145
tfc.test_util.expectArraysClose(await result.data(), [8]);
142146
expect(tfc.findBackend('cpu').numDataIds()).toBe(cpuNumDataIds + 3);
143147

144-
await tfc.setBackend('webgl');
148+
await tfc.setBackend(env.name);
145149
tfc.test_util.expectArraysClose(await tfc.add(a, b).data(), [8]);
146-
expect(tfc.findBackend('webgl').numDataIds()).toBe(webglNumDataIds + 3);
150+
expect(tfc.findBackend(env.name).numDataIds())
151+
.toBe(backendNumDataIds + 3);
147152

148153
tfc.engine().endScope();
149154

150155
expect(tfc.memory().numTensors).toBe(numTensors + 2);
151-
expect(tfc.findBackend('webgl').numDataIds()).toBe(webglNumDataIds + 2);
156+
expect(tfc.findBackend(env.name).numDataIds())
157+
.toBe(backendNumDataIds + 2);
152158
expect(tfc.findBackend('cpu').numDataIds()).toBe(cpuNumDataIds);
153159

154160
tfc.dispose([a, b]);
155161

156162
expect(tfc.memory().numTensors).toBe(numTensors);
157-
expect(tfc.findBackend('webgl').numDataIds()).toBe(webglNumDataIds);
163+
expect(tfc.findBackend(env.name).numDataIds()).toBe(backendNumDataIds);
158164
expect(tfc.findBackend('cpu').numDataIds()).toBe(cpuNumDataIds);
159165
});
160166

161167
// tslint:disable-next-line: ban
162-
xit('can move complex tensor from cpu to webgl.', async () => {
168+
xit(`can move complex tensor from cpu to ${env.name}.`, async () => {
163169
await tfc.setBackend('cpu');
164170

165171
const real1 = tfc.tensor1d([1]);
166172
const imag1 = tfc.tensor1d([2]);
167173
const complex1 = tfc.complex(real1, imag1);
168174

169-
await tfc.setBackend('webgl');
175+
await tfc.setBackend(env.name);
170176

171177
const real2 = tfc.tensor1d([3]);
172178
const imag2 = tfc.tensor1d([4]);
@@ -178,8 +184,8 @@ describeWithFlags(
178184
});
179185

180186
// tslint:disable-next-line: ban
181-
xit('can move complex tensor from webgl to cpu.', async () => {
182-
await tfc.setBackend('webgl');
187+
xit(`can move complex tensor from ${env.name} to cpu.`, async () => {
188+
await tfc.setBackend(env.name);
183189

184190
const real1 = tfc.tensor1d([1]);
185191
const imag1 = tfc.tensor1d([2]);

e2e/integration_tests/constants.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ export const GOLDEN = '#GOLDEN';
2727
export const TAGS = [SMOKE, REGRESSION, GOLDEN];
2828

2929
/** Testing backends. */
30-
export const BACKENDS = ['cpu', 'webgl'];
30+
export const BACKENDS = ['cpu', 'webgl', 'webgpu'];
3131

3232
/** Testing models for CUJ: create -> save -> predict. */
3333
export const LAYERS_MODELS = [

e2e/integration_tests/convert_predict.ts

+2
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
*/
2727
import '@tensorflow/tfjs-backend-cpu';
2828
import '@tensorflow/tfjs-backend-webgl';
29+
import '@tensorflow/tfjs-backend-webgpu';
2930

3031
import * as tfconverter from '@tensorflow/tfjs-converter';
3132
import * as tfc from '@tensorflow/tfjs-core';
@@ -61,6 +62,7 @@ describeWithFlags(`${REGRESSION} convert_predict`, ALL_ENVS, (env) => {
6162
continue;
6263
}
6364
it(`${model}.`, async () => {
65+
await tfc.setBackend(env.name);
6466
let inputsNames: string[];
6567
let inputsData: tfc.TypedArray[];
6668
let inputsShapes: number[][];

0 commit comments

Comments
 (0)