17
17
18
18
import '@tensorflow/tfjs-backend-cpu' ;
19
19
import '@tensorflow/tfjs-backend-webgl' ;
20
+ import '@tensorflow/tfjs-backend-webgpu' ;
20
21
21
22
import * as tfc from '@tensorflow/tfjs-core' ;
22
23
// tslint:disable-next-line: no-imports-from-dist
@@ -27,27 +28,28 @@ import {SMOKE} from './constants';
27
28
/**
28
29
* This file tests backend switching scenario.
29
30
*/
30
-
31
+ // TODO: Support backend switching between wasm and cpu.
32
+ // https://github.com/tensorflow/tfjs/issues/7623
31
33
describeWithFlags (
32
34
`${ SMOKE } backend switching` , {
33
- predicate : testEnv => testEnv . backendName === 'webgl' &&
34
- tfc . findBackend ( 'webgl' ) !== null && tfc . findBackend ( 'cpu' ) !== null
35
+ predicate : testEnv =>
36
+ testEnv . backendName !== 'cpu' && testEnv . backendName !== 'wasm'
35
37
} ,
36
38
37
- ( ) => {
38
- it ( `from webgl to cpu.` , async ( ) => {
39
- await tfc . setBackend ( 'webgl' ) ;
39
+ ( env ) => {
40
+ it ( `from ${ env . name } to cpu.` , async ( ) => {
41
+ await tfc . setBackend ( env . name ) ;
40
42
41
- const webglBefore = tfc . engine ( ) . backend . numDataIds ( ) ;
43
+ const backendBefore = tfc . engine ( ) . backend . numDataIds ( ) ;
42
44
43
45
const input = tfc . tensor2d ( [ 1 , 1 , 1 , 1 ] , [ 2 , 2 ] , 'float32' ) ;
44
- // input is stored in webgl backend.
46
+ // input is stored in backend.
45
47
46
48
const inputReshaped = tfc . reshape ( input , [ 2 , 2 ] ) ;
47
49
48
- const webglAfter = tfc . engine ( ) . backend . numDataIds ( ) ;
50
+ const backendAfter = tfc . engine ( ) . backend . numDataIds ( ) ;
49
51
50
- expect ( webglAfter ) . toEqual ( webglBefore + 1 ) ;
52
+ expect ( backendAfter ) . toEqual ( backendBefore + 1 ) ;
51
53
52
54
await tfc . setBackend ( 'cpu' ) ;
53
55
@@ -56,8 +58,9 @@ describeWithFlags(
56
58
const inputReshaped2 = tfc . reshape ( inputReshaped , [ 2 , 2 ] ) ;
57
59
// input moved to cpu.
58
60
59
- // Because input is moved to cpu, data should be deleted from webgl
60
- expect ( tfc . findBackend ( 'webgl' ) . numDataIds ( ) ) . toEqual ( webglAfter - 1 ) ;
61
+ // Because input is moved to cpu, data should be deleted from backend.
62
+ expect ( tfc . findBackend ( env . name ) . numDataIds ( ) )
63
+ . toEqual ( backendAfter - 1 ) ;
61
64
62
65
const cpuAfter = tfc . engine ( ) . backend . numDataIds ( ) ;
63
66
@@ -77,7 +80,7 @@ describeWithFlags(
77
80
expect ( after ) . toBe ( cpuBefore ) ;
78
81
} ) ;
79
82
80
- it ( `from cpu to webgl .` , async ( ) => {
83
+ it ( `from cpu to ${ env . name } .` , async ( ) => {
81
84
await tfc . setBackend ( 'cpu' ) ;
82
85
83
86
const cpuBefore = tfc . engine ( ) . backend . numDataIds ( ) ;
@@ -91,46 +94,47 @@ describeWithFlags(
91
94
92
95
expect ( cpuAfter ) . toEqual ( cpuBefore + 1 ) ;
93
96
94
- await tfc . setBackend ( 'webgl' ) ;
97
+ await tfc . setBackend ( env . name ) ;
95
98
96
- const webglBefore = tfc . engine ( ) . backend . numDataIds ( ) ;
99
+ const backendBefore = tfc . engine ( ) . backend . numDataIds ( ) ;
97
100
98
101
const inputReshaped2 = tfc . reshape ( inputReshaped , [ 2 , 2 ] ) ;
99
- // input moved to webgl.
102
+ // input moved to webgl or webgpu .
100
103
101
- // Because input is moved to webgl, data should be deleted from cpu
104
+ // Because input is moved to backend, data should be deleted
105
+ // from cpu.
102
106
expect ( tfc . findBackend ( 'cpu' ) . numDataIds ( ) ) . toEqual ( cpuAfter - 1 ) ;
103
107
104
- const webglAfter = tfc . engine ( ) . backend . numDataIds ( ) ;
108
+ const backendAfter = tfc . engine ( ) . backend . numDataIds ( ) ;
105
109
106
- expect ( webglAfter ) . toEqual ( webglBefore + 1 ) ;
110
+ expect ( backendAfter ) . toEqual ( backendBefore + 1 ) ;
107
111
108
112
input . dispose ( ) ;
109
113
110
- expect ( tfc . engine ( ) . backend . numDataIds ( ) ) . toEqual ( webglAfter ) ;
114
+ expect ( tfc . engine ( ) . backend . numDataIds ( ) ) . toEqual ( backendAfter ) ;
111
115
112
116
inputReshaped . dispose ( ) ;
113
117
114
- expect ( tfc . engine ( ) . backend . numDataIds ( ) ) . toEqual ( webglAfter ) ;
118
+ expect ( tfc . engine ( ) . backend . numDataIds ( ) ) . toEqual ( backendAfter ) ;
115
119
116
120
inputReshaped2 . dispose ( ) ;
117
121
118
122
const after = tfc . engine ( ) . backend . numDataIds ( ) ;
119
123
120
- expect ( after ) . toBe ( webglBefore ) ;
124
+ expect ( after ) . toBe ( backendBefore ) ;
121
125
} ) ;
122
126
123
127
it ( 'can execute op with data from mixed backends' , async ( ) => {
124
128
const numTensors = tfc . memory ( ) . numTensors ;
125
- const webglNumDataIds = tfc . findBackend ( 'webgl' ) . numDataIds ( ) ;
129
+ const backendNumDataIds = tfc . findBackend ( env . name ) . numDataIds ( ) ;
126
130
const cpuNumDataIds = tfc . findBackend ( 'cpu' ) . numDataIds ( ) ;
127
131
128
132
await tfc . setBackend ( 'cpu' ) ;
129
133
// This scalar lives in cpu.
130
134
const a = tfc . scalar ( 5 ) ;
131
135
132
- await tfc . setBackend ( 'webgl' ) ;
133
- // This scalar lives in webgl.
136
+ await tfc . setBackend ( env . name ) ;
137
+ // This scalar lives in webgl or webgpu .
134
138
const b = tfc . scalar ( 3 ) ;
135
139
136
140
// Verify that ops can execute with mixed backend data.
@@ -141,32 +145,34 @@ describeWithFlags(
141
145
tfc . test_util . expectArraysClose ( await result . data ( ) , [ 8 ] ) ;
142
146
expect ( tfc . findBackend ( 'cpu' ) . numDataIds ( ) ) . toBe ( cpuNumDataIds + 3 ) ;
143
147
144
- await tfc . setBackend ( 'webgl' ) ;
148
+ await tfc . setBackend ( env . name ) ;
145
149
tfc . test_util . expectArraysClose ( await tfc . add ( a , b ) . data ( ) , [ 8 ] ) ;
146
- expect ( tfc . findBackend ( 'webgl' ) . numDataIds ( ) ) . toBe ( webglNumDataIds + 3 ) ;
150
+ expect ( tfc . findBackend ( env . name ) . numDataIds ( ) )
151
+ . toBe ( backendNumDataIds + 3 ) ;
147
152
148
153
tfc . engine ( ) . endScope ( ) ;
149
154
150
155
expect ( tfc . memory ( ) . numTensors ) . toBe ( numTensors + 2 ) ;
151
- expect ( tfc . findBackend ( 'webgl' ) . numDataIds ( ) ) . toBe ( webglNumDataIds + 2 ) ;
156
+ expect ( tfc . findBackend ( env . name ) . numDataIds ( ) )
157
+ . toBe ( backendNumDataIds + 2 ) ;
152
158
expect ( tfc . findBackend ( 'cpu' ) . numDataIds ( ) ) . toBe ( cpuNumDataIds ) ;
153
159
154
160
tfc . dispose ( [ a , b ] ) ;
155
161
156
162
expect ( tfc . memory ( ) . numTensors ) . toBe ( numTensors ) ;
157
- expect ( tfc . findBackend ( 'webgl' ) . numDataIds ( ) ) . toBe ( webglNumDataIds ) ;
163
+ expect ( tfc . findBackend ( env . name ) . numDataIds ( ) ) . toBe ( backendNumDataIds ) ;
158
164
expect ( tfc . findBackend ( 'cpu' ) . numDataIds ( ) ) . toBe ( cpuNumDataIds ) ;
159
165
} ) ;
160
166
161
167
// tslint:disable-next-line: ban
162
- xit ( ' can move complex tensor from cpu to webgl.' , async ( ) => {
168
+ xit ( ` can move complex tensor from cpu to ${ env . name } .` , async ( ) => {
163
169
await tfc . setBackend ( 'cpu' ) ;
164
170
165
171
const real1 = tfc . tensor1d ( [ 1 ] ) ;
166
172
const imag1 = tfc . tensor1d ( [ 2 ] ) ;
167
173
const complex1 = tfc . complex ( real1 , imag1 ) ;
168
174
169
- await tfc . setBackend ( 'webgl' ) ;
175
+ await tfc . setBackend ( env . name ) ;
170
176
171
177
const real2 = tfc . tensor1d ( [ 3 ] ) ;
172
178
const imag2 = tfc . tensor1d ( [ 4 ] ) ;
@@ -178,8 +184,8 @@ describeWithFlags(
178
184
} ) ;
179
185
180
186
// tslint:disable-next-line: ban
181
- xit ( ' can move complex tensor from webgl to cpu.' , async ( ) => {
182
- await tfc . setBackend ( 'webgl' ) ;
187
+ xit ( ` can move complex tensor from ${ env . name } to cpu.` , async ( ) => {
188
+ await tfc . setBackend ( env . name ) ;
183
189
184
190
const real1 = tfc . tensor1d ( [ 1 ] ) ;
185
191
const imag1 = tfc . tensor1d ( [ 2 ] ) ;
0 commit comments