21
21
*/
22
22
23
23
describe ( 'benchmark_util' , ( ) => {
24
- beforeAll ( ( ) => tf . setBackend ( 'cpu' ) ) ;
24
+ beforeEach ( ( ) => tf . setBackend ( 'cpu' ) ) ;
25
25
26
26
describe ( 'generateInput' , ( ) => {
27
27
it ( 'LayersModel' , ( ) => {
@@ -56,8 +56,81 @@ describe('benchmark_util', () => {
56
56
} ) ;
57
57
} ) ;
58
58
59
+ describe ( 'getPredictFnForModel' , ( ) => {
60
+ it ( 'graph model with async ops uses executeAsync to run' , ( ) => {
61
+ const model = new tf . GraphModel ( ) ;
62
+ const input = tf . tensor ( [ 1 ] ) ;
63
+ const oldTensorNum = tf . memory ( ) . numTensors ;
64
+ spyOn ( model , 'execute' ) . and . callFake ( ( ) => {
65
+ const leakedTensor = tf . tensor ( [ 1 ] ) ;
66
+ throw new Error (
67
+ 'This model has dynamic ops, ' +
68
+ 'please use model.executeAsync() instead' ) ;
69
+ return leakedTensor ;
70
+ } ) ;
71
+ spyOn ( model , 'executeAsync' ) ;
72
+
73
+ const wrappedPredict = getPredictFnForModel ( model , input ) ;
74
+ expect ( tf . memory ( ) . numTensors ) . toBe ( oldTensorNum ) ;
75
+ expect ( model . execute . calls . count ( ) ) . toBe ( 1 ) ;
76
+ expect ( model . execute . calls . first ( ) . args ) . toEqual ( [ input ] ) ;
77
+
78
+ wrappedPredict ( ) ;
79
+ expect ( model . execute . calls . count ( ) ) . toBe ( 1 ) ;
80
+ expect ( model . executeAsync . calls . count ( ) ) . toBe ( 1 ) ;
81
+ expect ( model . executeAsync . calls . first ( ) . args ) . toEqual ( [ input ] ) ;
82
+
83
+ tf . dispose ( input ) ;
84
+ } ) ;
85
+
86
+ it ( 'graph model without async ops uses execute to run' , ( ) => {
87
+ const model = new tf . GraphModel ( ) ;
88
+ const input = tf . tensor ( [ 1 ] ) ;
89
+ const oldTensorNum = tf . memory ( ) . numTensors ;
90
+ spyOn ( model , 'execute' ) . and . callFake ( ( ) => {
91
+ const leakedTensor = tf . tensor ( [ 1 ] ) ;
92
+ } ) ;
93
+ spyOn ( model , 'executeAsync' ) ;
94
+
95
+ const wrappedPredict = getPredictFnForModel ( model , input ) ;
96
+ expect ( tf . memory ( ) . numTensors ) . toBe ( oldTensorNum ) ;
97
+ expect ( model . execute . calls . count ( ) ) . toBe ( 1 ) ;
98
+ expect ( model . execute . calls . first ( ) . args ) . toEqual ( [ input ] ) ;
99
+
100
+ wrappedPredict ( ) ;
101
+ expect ( model . execute . calls . count ( ) ) . toBe ( 2 ) ;
102
+ expect ( model . execute . calls . argsFor ( 1 ) ) . toEqual ( [ input ] ) ;
103
+ expect ( model . executeAsync . calls . count ( ) ) . toBe ( 0 ) ;
104
+
105
+ tf . dispose ( input ) ;
106
+ } ) ;
107
+
108
+ it ( 'layers model uses predict to run' , ( ) => {
109
+ const model = tf . sequential (
110
+ { layers : [ tf . layers . dense ( { units : 1 , inputShape : [ 1 ] } ) ] } ) ;
111
+ const input = tf . ones ( [ 1 , 1 ] ) ;
112
+ spyOn ( model , 'predict' ) ;
113
+
114
+ const wrappedPredict = getPredictFnForModel ( model , input ) ;
115
+ wrappedPredict ( ) ;
116
+
117
+ expect ( model . predict . calls . count ( ) ) . toBe ( 1 ) ;
118
+ expect ( model . predict . calls . first ( ) . args ) . toEqual ( [ input ] ) ;
119
+
120
+ tf . dispose ( input ) ;
121
+ model . dispose ( ) ;
122
+ } ) ;
123
+
124
+ it ( 'throws when passed in a model that is not layers or graph model' ,
125
+ ( ) => {
126
+ const model = { } ;
127
+ const input = [ ] ;
128
+ expect ( ( ) => getPredictFnForModel ( model , input ) ) . toThrowError ( Error ) ;
129
+ } ) ;
130
+ } ) ;
131
+
59
132
describe ( 'setEnvFlags' , ( ) => {
60
- describe ( 'change nothing' , ( ) => {
133
+ describe ( 'changes nothing when setting empty config or rejecting ' , ( ) => {
61
134
let originalFlags = { } ;
62
135
63
136
beforeEach ( ( ) => {
@@ -70,7 +143,7 @@ describe('benchmark_util', () => {
70
143
expect ( tf . env ( ) . flags ) . toEqual ( originalFlags ) ;
71
144
} ) ;
72
145
73
- it ( 'untunable flag ' , async ( ) => {
146
+ it ( 'rejects when setting untunable flags ' , async ( ) => {
74
147
const flagConfig = {
75
148
IS_BROWSER : false ,
76
149
} ;
@@ -80,23 +153,23 @@ describe('benchmark_util', () => {
80
153
expect ( tf . env ( ) . flags ) . toEqual ( originalFlags ) ;
81
154
} ) ;
82
155
83
- it ( 'set a number type flag by a boolean value' , async ( ) => {
156
+ it ( 'rejects when setting a number flag by a boolean value' , async ( ) => {
84
157
const flagConfig = {
85
158
WEBGL_VERSION : false ,
86
159
} ;
87
160
expectAsync ( setEnvFlags ( flagConfig ) ) . toBeRejectedWithError ( Error ) ;
88
161
expect ( tf . env ( ) . flags ) . toEqual ( originalFlags ) ;
89
162
} ) ;
90
163
91
- it ( 'set boolean flag by a number' , async ( ) => {
164
+ it ( 'rejects when setting boolean flag by a number' , async ( ) => {
92
165
const flagConfig = {
93
166
WEBGL_PACK : 1 ,
94
167
} ;
95
168
expectAsync ( setEnvFlags ( flagConfig ) ) . toBeRejectedWithError ( Error ) ;
96
169
expect ( tf . env ( ) . flags ) . toEqual ( originalFlags ) ;
97
170
} ) ;
98
171
99
- it ( 'set flag value out of the range' , async ( ) => {
172
+ it ( 'rejects when setting flag value out of the range' , async ( ) => {
100
173
const outOfRangeValue =
101
174
Math . max ( ...TUNABLE_FLAG_VALUE_RANGE_MAP . WEBGL_VERSION ) + 1 ;
102
175
const flagConfig = {
@@ -107,7 +180,7 @@ describe('benchmark_util', () => {
107
180
} ) ;
108
181
} ) ;
109
182
110
- describe ( 'reset flags' , ( ) => {
183
+ describe ( 'reset simple flags' , ( ) => {
111
184
beforeEach ( ( ) => tf . env ( ) . reset ( ) ) ;
112
185
afterEach ( ( ) => tf . env ( ) . reset ( ) ) ;
113
186
@@ -201,13 +274,13 @@ describe('benchmark_util', () => {
201
274
beforeEach ( ( ) => tf . setBackend ( 'cpu' ) ) ;
202
275
afterAll ( ( ) => tf . engine ( ) . reset ( ) ) ;
203
276
204
- it ( 'reset a backend that is not registed' , async ( ) => {
277
+ it ( 'rejects when resetting a backend that is not registed' , async ( ) => {
205
278
expectAsync ( resetBackend ( 'invalidBackendName' ) )
206
279
. toBeRejectedWithError (
207
280
Error , 'invalidBackendName backend is not registed.' ) ;
208
281
} ) ;
209
282
210
- it ( 'reset a backend that is not generated ' , async ( ) => {
283
+ it ( 'do nothing when resetting a backend that is not created ' , async ( ) => {
211
284
const testCpuBackend = 'testCpuBackend' ;
212
285
tf . registerBackend ( testCpuBackend , tf . findBackendFactory ( 'cpu' ) ) ;
213
286
expect ( tf . engine ( ) . registry [ testCpuBackend ] ) . toBeUndefined ( ) ;
@@ -223,7 +296,7 @@ describe('benchmark_util', () => {
223
296
tf . removeBackend ( testCpuBackend ) ;
224
297
} ) ;
225
298
226
- it ( 'reset a backend that has been generated ' , async ( ) => {
299
+ it ( 'reset the backend when resetting an existed backend ' , async ( ) => {
227
300
await tf . ready ( ) ;
228
301
const currentBackend = tf . getBackend ( ) ;
229
302
expect ( tf . engine ( ) . registry [ currentBackend ] ) . toBeDefined ( ) ;
@@ -238,23 +311,25 @@ describe('benchmark_util', () => {
238
311
expect ( tf . registerBackend . calls . count ( ) ) . toBe ( 1 ) ;
239
312
} ) ;
240
313
241
- it ( 'reset the active backend' , async ( ) => {
242
- const currentBackend = tf . getBackend ( ) ;
243
- spyOn ( tf , 'setBackend' ) ;
244
- await resetBackend ( currentBackend ) ;
245
- expect ( tf . setBackend . calls . count ( ) ) . toBe ( 1 ) ;
246
- } ) ;
247
-
248
- it ( 'reset an inactive backend' , async ( ) => {
249
- const testCpuBackend = 'testCpuBackend' ;
250
- tf . registerBackend ( testCpuBackend , tf . findBackendFactory ( 'cpu' ) ) ;
251
- expect ( tf . getBackend ( ) ) . not . toBe ( testCpuBackend ) ;
252
- spyOn ( tf , 'setBackend' ) ;
253
-
254
- await resetBackend ( testCpuBackend ) ;
255
-
256
- expect ( tf . setBackend . calls . count ( ) ) . toBe ( 0 ) ;
257
- tf . removeBackend ( testCpuBackend ) ;
258
- } ) ;
314
+ it ( 'tf.setBackend is called when resetting the active backend' ,
315
+ async ( ) => {
316
+ const currentBackend = tf . getBackend ( ) ;
317
+ spyOn ( tf , 'setBackend' ) ;
318
+ await resetBackend ( currentBackend ) ;
319
+ expect ( tf . setBackend . calls . count ( ) ) . toBe ( 1 ) ;
320
+ } ) ;
321
+
322
+ it ( 'tf.setBackend is not called when resetting an inactive backend' ,
323
+ async ( ) => {
324
+ const testCpuBackend = 'testCpuBackend' ;
325
+ tf . registerBackend ( testCpuBackend , tf . findBackendFactory ( 'cpu' ) ) ;
326
+ expect ( tf . getBackend ( ) ) . not . toBe ( testCpuBackend ) ;
327
+ spyOn ( tf , 'setBackend' ) ;
328
+
329
+ await resetBackend ( testCpuBackend ) ;
330
+
331
+ expect ( tf . setBackend . calls . count ( ) ) . toBe ( 0 ) ;
332
+ tf . removeBackend ( testCpuBackend ) ;
333
+ } ) ;
259
334
} ) ;
260
335
} ) ;
0 commit comments