Skip to content

Commit

Permalink
Pre-allocate the output buffer and reuse it for every compute
Browse files Browse the repository at this point in the history
  • Loading branch information
Honry committed Jul 9, 2021
1 parent cf1e8a7 commit 6b39ec2
Show file tree
Hide file tree
Showing 18 changed files with 81 additions and 94 deletions.
8 changes: 4 additions & 4 deletions code/samples/dynamic_shape.js
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ const b = builder.input('b', descB);
const c = builder.matmul(a, b);
const graph = builder.build({'c': c});

function allocateAndCompute(shapeA, shapeB, shapeC) {
function compute(shapeA, shapeB, shapeC) {
const bufferA = new Float32Array(sizeOfShape(shapeA)).fill(0.5);
const bufferB = new Float32Array(sizeOfShape(shapeB)).fill(0.5);
const bufferC = new Float32Array(sizeOfShape(shapeC));
Expand All @@ -24,6 +24,6 @@ function allocateAndCompute(shapeA, shapeB, shapeC) {
console.log(`values: ${bufferC}`);
}

allocateAndCompute([3, 4], [4, 3], [3, 3]);
allocateAndCompute([4, 4], [4, 4], [4, 4]);
allocateAndCompute([5, 4], [4, 5], [5, 5]);
compute([3, 4], [4, 3], [3, 3]);
compute([4, 4], [4, 4], [4, 4]);
compute([5, 4], [4, 5], [5, 5]);
2 changes: 1 addition & 1 deletion code/samples/matmul.js
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ const c = builder.matmul(a, b);

const graph = builder.build({c});
const bufferA = new Float32Array(sizeOfShape(descA.dimensions)).fill(0.5);
const bufferC = new Float32Array(9);
const bufferC = new Float32Array(sizeOfShape([3, 3]));
const inputs = {'a': bufferA};
const outputs = {'c': bufferC};
graph.compute(inputs, outputs);
Expand Down
10 changes: 6 additions & 4 deletions image_classification/main.js
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import {SqueezeNetNhwc} from './squeezenet_nhwc.js';
import {ResNet50V2Nchw} from './resnet50v2_nchw.js';
import {ResNet101V2Nhwc} from './resnet101v2_nhwc.js';
import {showProgressComponent, readyShowResultComponents} from '../common/ui.js';
import {getInputTensor, getMedianValue} from '../common/utils.js';
import {getInputTensor, getMedianValue, sizeOfShape} from '../common/utils.js';

const maxWidth = 380;
const maxHeight = 380;
Expand All @@ -27,6 +27,7 @@ let loadTime = 0;
let buildTime = 0;
let computeTime = 0;
let inputOptions;
let outputBuffer;

async function fetchLabels(url) {
const response = await fetch(url);
Expand Down Expand Up @@ -101,7 +102,7 @@ async function renderCamStream() {
const inputBuffer = getInputTensor(camElement, inputOptions);
console.log('- Computing... ');
const start = performance.now();
const outputBuffer = netInstance.compute(inputBuffer);
netInstance.compute(inputBuffer, outputBuffer);
computeTime = (performance.now() - start).toFixed(2);
console.log(` done in ${computeTime} ms.`);
camElement.width = camElement.videoWidth;
Expand Down Expand Up @@ -223,6 +224,8 @@ export async function main() {
netInstance = constructNetObject(instanceType);
inputOptions = netInstance.inputOptions;
labels = await fetchLabels(inputOptions.labelUrl);
outputBuffer =
new Float32Array(sizeOfShape(netInstance.outputDimensions));
isFirstTimeLoad = false;
console.log(`- Model name: ${modelName}, Model layout: ${layout} -`);
// UI shows model loading progress
Expand All @@ -247,10 +250,9 @@ export async function main() {
console.log('- Computing... ');
const computeTimeArray = [];
let medianComputeTime;
let outputBuffer;
for (let i = 0; i < numRuns; i++) {
start = performance.now();
outputBuffer = netInstance.compute(inputBuffer);
netInstance.compute(inputBuffer, outputBuffer);
computeTime = (performance.now() - start).toFixed(2);
console.log(` compute time ${i+1}: ${computeTime} ms`);
computeTimeArray.push(Number(computeTime));
Expand Down
5 changes: 2 additions & 3 deletions image_classification/mobilenet_nchw.js
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ export class MobileNetV2Nchw {
labelUrl: './labels/labels1000.txt',
inputDimensions: [1, 3, 224, 224],
};
this.outputDimensions = [1, 1000];
}

async buildConv_(input, name, relu6 = true, options = undefined) {
Expand Down Expand Up @@ -132,11 +133,9 @@ export class MobileNetV2Nchw {
}
}

compute(inputBuffer) {
compute(inputBuffer, outputBuffer) {
const inputs = {'input': inputBuffer};
const outputBuffer = new Float32Array(1000);
const outputs = {'output': outputBuffer};
this.graph_.compute(inputs, outputs);
return outputBuffer;
}
}
5 changes: 2 additions & 3 deletions image_classification/mobilenet_nhwc.js
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ export class MobileNetV2Nhwc {
labelUrl: './labels/labels1001.txt',
inputDimensions: [1, 224, 224, 3],
};
this.outputDimensions = [1, 1001];
}

async buildConv_(input, weightsSubName, biasSubName, relu6, options) {
Expand Down Expand Up @@ -131,11 +132,9 @@ export class MobileNetV2Nhwc {
}
}

compute(inputBuffer) {
compute(inputBuffer, outputBuffer) {
const inputs = {'input': inputBuffer};
const outputBuffer = new Float32Array(1001);
const outputs = {'output': outputBuffer};
this.graph_.compute(inputs, outputs);
return outputBuffer;
}
}
5 changes: 2 additions & 3 deletions image_classification/resnet101v2_nhwc.js
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ export class ResNet101V2Nhwc {
labelUrl: './labels/labels1001.txt',
inputDimensions: [1, 299, 299, 3],
};
this.outputDimensions = [1, 1001];
}

async buildConv_(input, nameIndices, options = undefined, relu = true) {
Expand Down Expand Up @@ -180,11 +181,9 @@ export class ResNet101V2Nhwc {
}
}

compute(inputBuffer) {
compute(inputBuffer, outputBuffer) {
const inputs = {'input': inputBuffer};
const outputBuffer = new Float32Array(1001);
const outputs = {'output': outputBuffer};
this.graph_.compute(inputs, outputs);
return outputBuffer;
}
}
5 changes: 2 additions & 3 deletions image_classification/resnet50v2_nchw.js
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ export class ResNet50V2Nchw {
labelUrl: './labels/labels1000.txt',
inputDimensions: [1, 3, 224, 224],
};
this.outputDimensions = [1, 1000];
}

async buildConv_(input, name, stageName, options = undefined) {
Expand Down Expand Up @@ -161,11 +162,9 @@ export class ResNet50V2Nchw {
}
}

compute(inputBuffer) {
compute(inputBuffer, outputBuffer) {
const inputs = {'input': inputBuffer};
const outputBuffer = new Float32Array(1000);
const outputs = {'output': outputBuffer};
this.graph_.compute(inputs, outputs);
return outputBuffer;
}
}
5 changes: 2 additions & 3 deletions image_classification/squeezenet_nchw.js
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ export class SqueezeNetNchw {
labelUrl: './labels/labels1000.txt',
inputDimensions: [1, 3, 224, 224],
};
this.outputDimensions = [1, 1000];
}

async buildConv_(input, name, options = undefined) {
Expand Down Expand Up @@ -77,11 +78,9 @@ export class SqueezeNetNchw {
}
}

compute(inputBuffer) {
compute(inputBuffer, outputBuffer) {
const inputs = {'input': inputBuffer};
const outputBuffer = new Float32Array(1000);
const outputs = {'output': outputBuffer};
this.graph_.compute(inputs, outputs);
return outputBuffer;
}
}
5 changes: 2 additions & 3 deletions image_classification/squeezenet_nhwc.js
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ export class SqueezeNetNhwc {
labelUrl: './labels/labels1001.txt',
inputDimensions: [1, 224, 224, 3],
};
this.outputDimensions = [1, 1001];
}

async buildConv_(input, name, options = undefined) {
Expand Down Expand Up @@ -85,11 +86,9 @@ export class SqueezeNetNhwc {
}
}

compute(inputBuffer) {
compute(inputBuffer, outputBuffer) {
const inputs = {'input': inputBuffer};
const outputBuffer = new Float32Array(1001);
const outputs = {'output': outputBuffer};
this.graph_.compute(inputs, outputs);
return outputBuffer;
}
}
4 changes: 1 addition & 3 deletions lenet/lenet.js
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,9 @@ export class LeNet {
this.graph_ = this.builder_.build({'output': outputOperand});
}

predict(inputBuffer) {
predict(inputBuffer, outputBuffer) {
const inputs = {'input': inputBuffer};
const outputBuffer = new Float32Array(10);
const outputs = {'output': outputBuffer};
this.graph_.compute(inputs, outputs);
return outputBuffer;
}
}
9 changes: 5 additions & 4 deletions lenet/main.js
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
'use strict';

import {sizeOfShape} from '../common/utils.js';
import {LeNet} from './lenet.js';
import {Pen} from './pen.js';

Expand Down Expand Up @@ -88,17 +89,17 @@ export async function main() {
}

let start;
let result;
let inferenceTime;
const inferenceTimeArray = [];
const input = getInputFromCanvas();
const outputBuffer = new Float32Array(sizeOfShape([1, 10]));

for (let i = 0; i < n; i++) {
start = performance.now();
result = lenet.predict(input);
lenet.predict(input, outputBuffer);
inferenceTime = performance.now() - start;
console.log(`execution elapsed time: ${inferenceTime.toFixed(2)} ms`);
console.log(`execution result: ${result}`);
console.log(`execution result: ${outputBuffer}`);
inferenceTimeArray.push(inferenceTime);
}

Expand All @@ -114,7 +115,7 @@ export async function main() {
'</span> ms';
}

const classes = topK(Array.from(result));
const classes = topK(Array.from(outputBuffer));
classes.forEach((c, i) => {
console.log(`\tlabel: ${c.label}, probability: ${c.prob}%`);
const labelElement = document.getElementById(`label${i}`);
Expand Down
33 changes: 21 additions & 12 deletions object_detection/main.js
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import {TinyYoloV2Nhwc} from './tiny_yolov2_nhwc.js';
import {SsdMobilenetV1Nchw} from './ssd_mobilenetv1_nchw.js';
import {SsdMobilenetV1Nhwc} from './ssd_mobilenetv1_nhwc.js';
import {showProgressComponent, readyShowResultComponents} from '../common/ui.js';
import {getInputTensor, getMedianValue} from '../common/utils.js';
import {getInputTensor, getMedianValue, sizeOfShape} from '../common/utils.js';
import * as Yolo2Decoder from './libs/yolo2Decoder.js';
import * as SsdDecoder from './libs/ssdDecoder.js';

Expand All @@ -25,6 +25,7 @@ let loadTime = 0;
let buildTime = 0;
let computeTime = 0;
let inputOptions;
let outputs;

async function fetchLabels(url) {
const response = await fetch(url);
Expand Down Expand Up @@ -99,7 +100,7 @@ async function renderCamStream() {
const inputBuffer = getInputTensor(camElement, inputOptions);
console.log('- Computing... ');
const start = performance.now();
const outputs = netInstance.compute(inputBuffer);
netInstance.compute(inputBuffer, outputs);
computeTime = (performance.now() - start).toFixed(2);
console.log(` done in ${computeTime} ms.`);
camElement.width = camElement.videoWidth;
Expand All @@ -117,31 +118,30 @@ async function drawOutput(inputElement, outputs, labels) {

// Draw output for SSD Mobilenet V1 model
if (modelName === 'ssdmobilenetv1') {
const boxesTensor = outputs.boxes;
const scoresTensor = outputs.scores;
const anchors = SsdDecoder.generateAnchors({});
SsdDecoder.decodeOutputBoxTensor({}, boxesTensor, anchors);
SsdDecoder.decodeOutputBoxTensor({}, outputs.boxes, anchors);
let [totalDetections, boxesList, scoresList, classesList] =
SsdDecoder.nonMaxSuppression({}, boxesTensor, scoresTensor);
SsdDecoder.nonMaxSuppression({}, outputs.boxes, outputs.scores);
boxesList = SsdDecoder.cropSsdBox(
inputElement, totalDetections, boxesList, inputOptions.margin);
SsdDecoder.drawBoxes(
outputElement, totalDetections, inputElement,
boxesList, scoresList, classesList, labels);
} else {
// Draw output for Tiny Yolo V2 model
let outputTensor = outputs.output;
// Transpose 'nchw' output to 'nhwc' for postprocessing
let outputBuffer = outputs.output;
if (layout === 'nchw') {
const tf = navigator.ml.createContext().tf;
const a = tf.tensor(outputTensor, [1, 125, 13, 13], 'float32');
const a =
tf.tensor(outputBuffer, netInstance.outputDimensions, 'float32');
const b = tf.transpose(a, [0, 2, 3, 1]);
const buffer = await b.buffer();
tf.dispose();
outputTensor = buffer.values;
outputBuffer = buffer.values;
}
const decodeOut = Yolo2Decoder.decodeYOLOv2({numClasses: 20},
outputTensor, inputOptions.anchors);
outputBuffer, inputOptions.anchors);
const boxes = Yolo2Decoder.getBoxes(decodeOut, inputOptions.margin);
Yolo2Decoder.drawBoxes(inputElement, outputElement, boxes, labels);
}
Expand Down Expand Up @@ -204,6 +204,16 @@ export async function main() {
netInstance = constructNetObject(instanceType);
inputOptions = netInstance.inputOptions;
labels = await fetchLabels(inputOptions.labelUrl);
if (modelName === 'tinyyolov2') {
outputs = {
'output': new Float32Array(sizeOfShape(netInstance.outputDimensions)),
};
} else {
outputs = {
'boxes': new Float32Array(sizeOfShape([1, 1917, 1, 4])),
'scores': new Float32Array(sizeOfShape([1, 1917, 91])),
};
}
isFirstTimeLoad = false;
console.log(`- Model name: ${modelName}, Model layout: ${layout} -`);
// UI shows model loading progress
Expand All @@ -228,10 +238,9 @@ export async function main() {
console.log('- Computing... ');
const computeTimeArray = [];
let medianComputeTime;
let outputs;
for (let i = 0; i < numRuns; i++) {
start = performance.now();
outputs = netInstance.compute(inputBuffer);
netInstance.compute(inputBuffer, outputs);
computeTime = (performance.now() - start).toFixed(2);
console.log(` compute time ${i+1}: ${computeTime} ms`);
computeTimeArray.push(Number(computeTime));
Expand Down
9 changes: 1 addition & 8 deletions object_detection/ssd_mobilenetv1_nchw.js
Original file line number Diff line number Diff line change
Expand Up @@ -259,15 +259,8 @@ ${nameArray[1]}_BatchNorm_batchnorm`;
}
}

compute(inputBuffer) {
compute(inputBuffer, outputs) {
const inputs = {'input': inputBuffer};
const boxesBuffer = new Float32Array(sizeOfShape([1, 1917, 1, 4]));
const scoresBuffer = new Float32Array(sizeOfShape([1, 1917, 91]));
const outputs = {
'boxes': boxesBuffer,
'scores': scoresBuffer,
};
this.graph_.compute(inputs, outputs);
return outputs;
}
}
9 changes: 1 addition & 8 deletions object_detection/ssd_mobilenetv1_nhwc.js
Original file line number Diff line number Diff line change
Expand Up @@ -238,15 +238,8 @@ ${nameArray[1]}_BatchNorm_batchnorm`;
}
}

compute(inputBuffer) {
compute(inputBuffer, outputs) {
const inputs = {'input': inputBuffer};
const boxesBuffer = new Float32Array(sizeOfShape([1, 1917, 1, 4]));
const scoresBuffer = new Float32Array(sizeOfShape([1, 1917, 91]));
const outputs = {
'boxes': boxesBuffer,
'scores': scoresBuffer,
};
this.graph_.compute(inputs, outputs);
return outputs;
}
}
6 changes: 2 additions & 4 deletions object_detection/tiny_yolov2_nchw.js
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ export class TinyYoloV2Nchw {
anchors: [1.08, 1.19, 3.42, 4.41, 6.63, 11.38, 9.42, 5.11, 16.62, 10.52],
inputDimensions: [1, 3, 416, 416],
};
this.outputDimensions = [1, 125, 13, 13];
}

async buildConv_(input, name, useBias = false) {
Expand Down Expand Up @@ -103,11 +104,8 @@ export class TinyYoloV2Nchw {
}
}

compute(inputBuffer) {
compute(inputBuffer, outputs) {
const inputs = {'input': inputBuffer};
const outputBuffer = new Float32Array(sizeOfShape([1, 125, 13, 13]));
const outputs = {'output': outputBuffer};
this.graph_.compute(inputs, outputs);
return outputs;
}
}
Loading

0 comments on commit 6b39ec2

Please sign in to comment.