Commit a0d6250

Add async execution detection for benchmarking models (#3610)
FEATURE
1 parent 2e29f98 commit a0d6250
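
Summary of the approach: instead of always calling `executeAsync` on a tf.GraphModel, the benchmark helpers now probe the model once with a synchronous `execute` call inside `tf.tidy` and fall back to `executeAsync` only if that probe throws, which is what GraphModels do when the graph contains dynamic (control-flow) ops. Below is a minimal standalone sketch of that detection pattern, written for illustration rather than copied from the diff; the helper name `pickPredictFn` is hypothetical.

// Sketch of the async-execution detection added by this commit.
// Assumes `tf` is the TensorFlow.js namespace and `model` is already loaded.
function pickPredictFn(model, input) {
  if (model instanceof tf.GraphModel) {
    try {
      // Probe once; nothing is returned from the callback, so tf.tidy
      // disposes every tensor the probe allocates.
      tf.tidy(() => {
        model.execute(input);
      });
      return () => model.execute(input);             // static graph: sync path
    } catch (e) {
      return async () => model.executeAsync(input);  // dynamic ops: async path
    }
  }
  if (model instanceof tf.LayersModel) {
    return () => model.predict(input);
  }
  throw new Error('Expected a tf.GraphModel or tf.LayersModel.');
}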

3 files changed: +128 −47 lines

e2e/benchmarks/benchmark_util.js

Lines changed: 18 additions & 6 deletions
@@ -92,16 +92,26 @@ function generateInput(model)
  * wrapping the predict function.
  * @param input The input tensor container for model inference.
  */
-function wrapPredictFnForModel(model, input) {
+function getPredictFnForModel(model, input) {
   let predict;
   if (model instanceof tf.GraphModel) {
-    predict = () => model.executeAsync(input);
+    // Because there's no straightforward way to analyze whether a graph has
+    // dynamic op, so we try to use `execute` and, if it fails, we will fall
+    // back to `executeAsync`.
+    try {
+      tf.tidy(() => {
+        model.execute(input);
+      });
+      predict = () => model.execute(input);
+    } catch (e) {
+      predict = async () => await model.executeAsync(input);
+    }
   } else if (model instanceof tf.LayersModel) {
     predict = () => model.predict(input);
   } else {
     throw new Error(
-        'Please pass in an instance of tf.GraphModel ' +
-        'or tf.LayersModel as the first parameter.');
+        'Predict function was not found. Please provide a tf.GraphModel or ' +
+        'tf.LayersModel');
   }
   return predict;
 }
@@ -132,7 +142,7 @@ function wrapPredictFnForModel(model, input) {
  * @param numRuns The number of rounds for timing the inference process.
  */
 async function profileInferenceTimeForModel(model, input, numRuns = 1) {
-  const predict = wrapPredictFnForModel(model, input);
+  const predict = getPredictFnForModel(model, input);
   return profileInferenceTime(predict, numRuns);
 }
 
@@ -246,7 +256,7 @@ async function downloadValuesFromTensorContainer(tensorContainer) {
  * @param input The input tensor container for model inference.
  */
 async function profileInferenceMemoryForModel(model, input) {
-  const predict = wrapPredictFnForModel(model, input);
+  const predict = getPredictFnForModel(model, input);
   return profileInferenceMemory(predict);
 }
 
@@ -291,6 +301,8 @@ async function profileInferenceMemory(predict) {
 }
 
 /**
+ * This function is temporarily used and will be deleted after a new release of
+ * tf-core. This function modifies
  * This function is temporarily used and will be deleted after a new release
  * of tf-core. This function modifies
  * [`tf.profile`](https://github.com/tensorflow/tfjs/blob/95b5f878218ee45c0f8464386ee01d1f96e78297/tfjs-core/src/engine.ts#L848)
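
How a caller is expected to use these helpers, as a hedged sketch rather than code from this commit: `benchmarkModel` is a hypothetical wrapper, while `generateInput`, `getPredictFnForModel`, and `profileInferenceTimeForModel` are the functions shown above.

// Hypothetical caller of the benchmark_util.js helpers.
async function benchmarkModel(model, numRuns = 20) {
  const input = generateInput(model);  // random inputs matching the model's signature
  try {
    // profileInferenceTimeForModel resolves the right predict function via
    // getPredictFnForModel, then times numRuns inference calls.
    return await profileInferenceTimeForModel(model, input, numRuns);
  } finally {
    tf.dispose(input);  // release the generated input tensors
  }
}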

e2e/benchmarks/benchmark_util_test.js

Lines changed: 103 additions & 28 deletions
@@ -21,7 +21,7 @@
  */
 
 describe('benchmark_util', () => {
-  beforeAll(() => tf.setBackend('cpu'));
+  beforeEach(() => tf.setBackend('cpu'));
 
   describe('generateInput', () => {
     it('LayersModel', () => {
@@ -56,8 +56,81 @@ describe('benchmark_util', () => {
     });
   });
 
+  describe('getPredictFnForModel', () => {
+    it('graph model with async ops uses executeAsync to run', () => {
+      const model = new tf.GraphModel();
+      const input = tf.tensor([1]);
+      const oldTensorNum = tf.memory().numTensors;
+      spyOn(model, 'execute').and.callFake(() => {
+        const leakedTensor = tf.tensor([1]);
+        throw new Error(
+            'This model has dynamic ops, ' +
+            'please use model.executeAsync() instead');
+        return leakedTensor;
+      });
+      spyOn(model, 'executeAsync');
+
+      const wrappedPredict = getPredictFnForModel(model, input);
+      expect(tf.memory().numTensors).toBe(oldTensorNum);
+      expect(model.execute.calls.count()).toBe(1);
+      expect(model.execute.calls.first().args).toEqual([input]);
+
+      wrappedPredict();
+      expect(model.execute.calls.count()).toBe(1);
+      expect(model.executeAsync.calls.count()).toBe(1);
+      expect(model.executeAsync.calls.first().args).toEqual([input]);
+
+      tf.dispose(input);
+    });
+
+    it('graph model without async ops uses execute to run', () => {
+      const model = new tf.GraphModel();
+      const input = tf.tensor([1]);
+      const oldTensorNum = tf.memory().numTensors;
+      spyOn(model, 'execute').and.callFake(() => {
+        const leakedTensor = tf.tensor([1]);
+      });
+      spyOn(model, 'executeAsync');
+
+      const wrappedPredict = getPredictFnForModel(model, input);
+      expect(tf.memory().numTensors).toBe(oldTensorNum);
+      expect(model.execute.calls.count()).toBe(1);
+      expect(model.execute.calls.first().args).toEqual([input]);
+
+      wrappedPredict();
+      expect(model.execute.calls.count()).toBe(2);
+      expect(model.execute.calls.argsFor(1)).toEqual([input]);
+      expect(model.executeAsync.calls.count()).toBe(0);
+
+      tf.dispose(input);
+    });
+
+    it('layers model uses predict to run', () => {
+      const model = tf.sequential(
+          {layers: [tf.layers.dense({units: 1, inputShape: [1]})]});
+      const input = tf.ones([1, 1]);
+      spyOn(model, 'predict');
+
+      const wrappedPredict = getPredictFnForModel(model, input);
+      wrappedPredict();
+
+      expect(model.predict.calls.count()).toBe(1);
+      expect(model.predict.calls.first().args).toEqual([input]);
+
+      tf.dispose(input);
+      model.dispose();
+    });
+
+    it('throws when passed in a model that is not layers or graph model',
+       () => {
+         const model = {};
+         const input = [];
+         expect(() => getPredictFnForModel(model, input)).toThrowError(Error);
+       });
+  });
+
   describe('setEnvFlags', () => {
-    describe('change nothing', () => {
+    describe('changes nothing when setting empty config or rejecting', () => {
       let originalFlags = {};
 
       beforeEach(() => {
@@ -70,7 +143,7 @@ describe('benchmark_util', () => {
         expect(tf.env().flags).toEqual(originalFlags);
       });
 
-      it('untunable flag', async () => {
+      it('rejects when setting untunable flags', async () => {
        const flagConfig = {
          IS_BROWSER: false,
        };
@@ -80,23 +153,23 @@ describe('benchmark_util', () => {
         expect(tf.env().flags).toEqual(originalFlags);
       });
 
-      it('set a number type flag by a boolean value', async () => {
+      it('rejects when setting a number flag by a boolean value', async () => {
         const flagConfig = {
           WEBGL_VERSION: false,
         };
         expectAsync(setEnvFlags(flagConfig)).toBeRejectedWithError(Error);
         expect(tf.env().flags).toEqual(originalFlags);
       });
 
-      it('set boolean flag by a number', async () => {
+      it('rejects when setting boolean flag by a number', async () => {
         const flagConfig = {
           WEBGL_PACK: 1,
         };
         expectAsync(setEnvFlags(flagConfig)).toBeRejectedWithError(Error);
         expect(tf.env().flags).toEqual(originalFlags);
       });
 
-      it('set flag value out of the range', async () => {
+      it('rejects when setting flag value out of the range', async () => {
         const outOfRangeValue =
             Math.max(...TUNABLE_FLAG_VALUE_RANGE_MAP.WEBGL_VERSION) + 1;
         const flagConfig = {
@@ -107,7 +180,7 @@ describe('benchmark_util', () => {
       });
     });
 
-    describe('reset flags', () => {
+    describe('reset simple flags', () => {
      beforeEach(() => tf.env().reset());
      afterEach(() => tf.env().reset());
 
@@ -201,13 +274,13 @@ describe('benchmark_util', () => {
     beforeEach(() => tf.setBackend('cpu'));
     afterAll(() => tf.engine().reset());
 
-    it('reset a backend that is not registed', async () => {
+    it('rejects when resetting a backend that is not registed', async () => {
      expectAsync(resetBackend('invalidBackendName'))
          .toBeRejectedWithError(
              Error, 'invalidBackendName backend is not registed.');
    });
 
-    it('reset a backend that is not generated', async () => {
+    it('do nothing when resetting a backend that is not created', async () => {
      const testCpuBackend = 'testCpuBackend';
      tf.registerBackend(testCpuBackend, tf.findBackendFactory('cpu'));
      expect(tf.engine().registry[testCpuBackend]).toBeUndefined();
@@ -223,7 +296,7 @@ describe('benchmark_util', () => {
      tf.removeBackend(testCpuBackend);
    });
 
-    it('reset a backend that has been generated', async () => {
+    it('reset the backend when resetting an existed backend', async () => {
      await tf.ready();
      const currentBackend = tf.getBackend();
      expect(tf.engine().registry[currentBackend]).toBeDefined();
@@ -238,23 +311,25 @@ describe('benchmark_util', () => {
      expect(tf.registerBackend.calls.count()).toBe(1);
    });
 
-    it('reset the active backend', async () => {
-      const currentBackend = tf.getBackend();
-      spyOn(tf, 'setBackend');
-      await resetBackend(currentBackend);
-      expect(tf.setBackend.calls.count()).toBe(1);
-    });
-
-    it('reset an inactive backend', async () => {
-      const testCpuBackend = 'testCpuBackend';
-      tf.registerBackend(testCpuBackend, tf.findBackendFactory('cpu'));
-      expect(tf.getBackend()).not.toBe(testCpuBackend);
-      spyOn(tf, 'setBackend');
-
-      await resetBackend(testCpuBackend);
-
-      expect(tf.setBackend.calls.count()).toBe(0);
-      tf.removeBackend(testCpuBackend);
-    });
+    it('tf.setBackend is called when resetting the active backend',
+       async () => {
+         const currentBackend = tf.getBackend();
+         spyOn(tf, 'setBackend');
+         await resetBackend(currentBackend);
+         expect(tf.setBackend.calls.count()).toBe(1);
+       });
+
+    it('tf.setBackend is not called when resetting an inactive backend',
+       async () => {
+         const testCpuBackend = 'testCpuBackend';
+         tf.registerBackend(testCpuBackend, tf.findBackendFactory('cpu'));
+         expect(tf.getBackend()).not.toBe(testCpuBackend);
+         spyOn(tf, 'setBackend');
+
+         await resetBackend(testCpuBackend);
+
+         expect(tf.setBackend.calls.count()).toBe(0);
+         tf.removeBackend(testCpuBackend);
+       });
   });
 });
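
A detail these tests lean on: `tf.tidy` still cleans up tensors allocated inside its callback when that callback throws, so the spies can allocate a `leakedTensor` before throwing and the tests can still assert that `tf.memory().numTensors` is unchanged after `getPredictFnForModel` probes the model. A small illustration of that behavior (not part of the commit):

// tf.tidy disposes scope-allocated tensors even when the callback throws.
const before = tf.memory().numTensors;
try {
  tf.tidy(() => {
    tf.tensor([1]);  // allocated inside the tidy scope
    throw new Error('simulated dynamic-op failure');
  });
} catch (e) {
  // The error still propagates, but the scope was cleaned up first.
}
console.log(tf.memory().numTensors === before);  // true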

e2e/benchmarks/modelConfig.js

Lines changed: 7 additions & 13 deletions
@@ -175,7 +175,7 @@ const benchmarks = {
       return async model => {
         const res = await model.embed(sentences30);
         return res;
-      }
+      };
     }
   },
   'USE - batchsize 1': {
@@ -191,7 +191,7 @@ const benchmarks = {
         nextIdx += 1;
         const res = await model.embed(next);
         return res;
-      }
+      };
     }
   },
   'posenet': {
@@ -204,7 +204,7 @@ const benchmarks = {
     predictFunc: () => {
       return async model => {
         return model.estimateSinglePose(model.image);
-      }
+      };
     }
   },
   'bodypix': {
@@ -217,7 +217,7 @@ const benchmarks = {
     predictFunc: () => {
       return async model => {
         return model.segmentPerson(model.image);
-      }
+      };
     }
   },
   'custom': {
@@ -230,15 +230,9 @@ const benchmarks = {
         let inferenceInput;
         try {
           inferenceInput = generateInput(model);
-          let resultTensor;
-          if (model instanceof tf.GraphModel && model.executeAsync != null) {
-            resultTensor = await model.executeAsync(inferenceInput);
-          } else if (model.predict != null) {
-            resultTensor = model.predict(inferenceInput);
-          } else {
-            throw new Error('Predict function was not found.');
-          }
-          return resultTensor;
+          const predict = getPredictFnForModel(model, inferenceInput);
+          const inferenceOutput = await predict();
+          return inferenceOutput;
         } finally {
           // dispose input tensors
           tf.dispose(inferenceInput);
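
Pieced together from the hunks above, the `custom` entry's predict function after this change reduces to roughly the following sketch; the `return async model => {...}` wrapper and the surrounding entry fields are assumed from the other entries in modelConfig.js rather than shown in this hunk.

// Approximate shape of the 'custom' predictFunc after the refactor.
predictFunc: () => {
  return async model => {
    let inferenceInput;
    try {
      inferenceInput = generateInput(model);
      // Shared helper from benchmark_util.js: picks execute, executeAsync or predict.
      const predict = getPredictFnForModel(model, inferenceInput);
      return await predict();
    } finally {
      // dispose input tensors
      tf.dispose(inferenceInput);
    }
  };
}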

0 commit comments
