Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add NPU device type and three fp16 models for image classification #226

Merged
merged 6 commits into from
May 7, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions common/component/component.js
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,18 @@ $(document).ready(async () => {
"title",
"WebNN is supported, disable WebNN Polyfill."
);
// Disable WebNN NPU backend if failed to find a capable NPU adapter.
try {
await navigator.ml.createContext({deviceType: 'npu'});
} catch (error) {
$('#webnn_npu').parent().addClass('disabled');
$('#webnn_npu').parent().addClass('btn-outline-secondary');
$('#webnn_npu').parent().removeClass('btn-outline-info');
$('#webnn_npu').parent().attr(
"title",
"Unable to find a capable NPU adapter."
);
}
}
}
$("#webnnstatus").html("supported").addClass("webnn-status-true");
Expand Down
17 changes: 17 additions & 0 deletions common/ui.js
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,23 @@ export function handleClick(cssSelectors, disabled = true) {
}
}

/**
* Handle button UI, disable or enable the button.
* @param {String} selector, css selector.
* @param {Boolean} disabled, disable or enable the button.
*/
export function handleBtnUI(selector, disabled = true) {
if (disabled) {
$(selector).addClass('disabled');
$(selector).addClass('btn-outline-secondary');
$(selector).removeClass('btn-outline-info');
} else {
$(selector).removeClass('disabled');
$(selector).removeClass('btn-outline-secondary');
$(selector).addClass('btn-outline-info');
}
}

/**
* Show flexible alert messages
* @param {String} msg, alert message.
Expand Down
1 change: 1 addition & 0 deletions image_classification/.eslintrc.js
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module.exports = {
globals: {
'MLGraphBuilder': 'readonly',
'tf': 'readonly',
},
};
178 changes: 178 additions & 0 deletions image_classification/efficientnet_fp16_nchw.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
'use strict';

import {buildConstantByNpy, weightsOrigin} from '../common/utils.js';

// EfficientNet fp16 model with 'nchw' input layout
export class EfficientNetFP16Nchw {
constructor() {
this.context_ = null;
this.builder_ = null;
this.graph_ = null;
this.targetDataType_ = 'float16';
mingmingtasd marked this conversation as resolved.
Show resolved Hide resolved
this.weightsUrl_ = weightsOrigin() +
'/test-data/models/efficientnet_fp16_nchw_optimized/weights/';
this.inputOptions = {
mean: [0.485, 0.456, 0.406],
std: [0.229, 0.224, 0.225],
norm: true,
inputLayout: 'nchw',
labelUrl: './labels/labels1000.txt',
inputDimensions: [1, 3, 224, 224],
dataType: 'float32',
mingmingtasd marked this conversation as resolved.
Show resolved Hide resolved
};
this.outputDimensions = [1, 1000];
}

async buildConv_(input, name, blockName, clip = false, options = {}) {
let prefix = '';
if (blockName !== '') {
prefix = this.weightsUrl_ + 'block' + blockName + '_conv' +
name;
} else {
prefix = this.weightsUrl_ + 'conv' + name;
}
const weight = buildConstantByNpy(this.builder_, prefix + '_w.npy',
this.targetDataType_ = 'float16');
mingmingtasd marked this conversation as resolved.
Show resolved Hide resolved
options.bias = await buildConstantByNpy(this.builder_, prefix + '_b.npy',
this.targetDataType_ = 'float16');
if (clip) {
return this.builder_.clamp(
this.builder_.conv2d(await input, await weight, options),
{minValue: 0, maxValue: 6});
}
return this.builder_.conv2d(await input, await weight, options);
}

async buildGemm_(input, name) {
const prefix = this.weightsUrl_ + 'dense' + name;
const weightName = prefix + '_w.npy';
const weight = buildConstantByNpy(this.builder_, weightName,
this.targetDataType_ = 'float16');
const biasName = prefix + '_b.npy';
const bias = buildConstantByNpy(this.builder_, biasName,
this.targetDataType_ = 'float16');
const options =
{c: this.builder_.reshape(await bias, [1, 1000])};
return this.builder_.gemm(await input, await weight, options);
}

async buildBottleneck_(input, blockName, group, pad = 1) {
const conv1 = this.buildConv_(input, '0', blockName, true);
const conv2 = this.buildConv_(conv1, '1', blockName, true,
{groups: group, padding: [pad, pad, pad, pad]});
const conv3 = this.buildConv_(conv2, '2', blockName);
return this.builder_.add(await conv3, await input);
}

async buildBottlenecks_(input, blockNames, group, pad = 1) {
let result = input;
for (let i = 0; i < blockNames.length; i++) {
const bottleneck = await this.buildBottleneck_(result, blockNames[i],
group, pad);
result = bottleneck;
}
return result;
}

async load(contextOptions) {
this.context_ = await navigator.ml.createContext(contextOptions);
this.builder_ = new MLGraphBuilder(this.context_);
let data = this.builder_.input('input', {
dataType: this.inputOptions.dataType,
dimensions: this.inputOptions.inputDimensions,
});
data = this.builder_.cast(data, 'float16');
// Block 0
const conv1 = this.buildConv_(
data, '0', '0', true, {padding: [0, 1, 0, 1], strides: [2, 2]});
const conv2 = this.buildConv_(conv1, '1', '0', true,
{groups: 32, padding: [1, 1, 1, 1]});
const conv3 = this.buildConv_(conv2, '2', '0');

// Block 1
const conv4 = this.buildConv_(conv3, '0', '1', true);
const conv5 = this.buildConv_(conv4, '1', '1', true,
{groups: 144, padding: [0, 1, 0, 1], strides: [2, 2]});
const conv6 = this.buildConv_(conv5, '2', '1');

// Block 2~4
const bottleneck4 = this.buildBottlenecks_(conv6,
['2', '3', '4'], 192);

// Block 5
const conv7 = this.buildConv_(bottleneck4, '0', '5', true);
const conv8 = this.buildConv_(conv7, '1', '5', true,
{groups: 192, padding: [1, 2, 1, 2], strides: [2, 2]});
const conv9 = this.buildConv_(conv8, '2', '5');

// Block 6~8
const bottleneck8 = this.buildBottlenecks_(conv9,
['6', '7', '8'], 336, 2);

// Block 9
const conv10 = this.buildConv_(bottleneck8, '0', '9', true);
const conv11 = this.buildConv_(conv10, '1', '9', true,
{groups: 336, padding: [0, 1, 0, 1], strides: [2, 2]});
const conv12 = this.buildConv_(conv11, '2', '9');

// Block 10~14
const bottleneck14 = this.buildBottlenecks_(conv12,
['10', '11', '12', '13', '14'], 672);

// Block 15
const conv13 = this.buildConv_(bottleneck14, '0', '15', true);
const conv14 = this.buildConv_(conv13, '1', '15', true,
{groups: 672, padding: [2, 2, 2, 2]});
const conv15 = this.buildConv_(conv14, '2', '15');

// Block 16~20
const bottleneck20 = await this.buildBottlenecks_(conv15,
['16', '17', '18', '19', '20'], 960, 2);

// Block 21
const conv16 = this.buildConv_(bottleneck20, '0', '21', true);
const conv17 = this.buildConv_(conv16, '1', '21', true,
{groups: 960, padding: [1, 2, 1, 2], strides: [2, 2]});
const conv18 = this.buildConv_(conv17, '2', '21');

// Block 22~28
const bottleneck28 = this.buildBottlenecks_(conv18,
['22', '23', '24', '25', '26', '27', '28'], 1632, 2);

// Block 29
const conv19 = this.buildConv_(bottleneck28, '0', '29', true);
const conv20 = this.buildConv_(conv19, '1', '29', true,
{groups: 1632, padding: [1, 1, 1, 1]});
const conv21 = this.buildConv_(conv20, '2', '29');

const conv22 = this.buildConv_(conv21, '0', '', true);
const pool1 = this.builder_.averagePool2d(await conv22);
const reshape = this.builder_.reshape(pool1, [1, 1280]);
const gemm = this.buildGemm_(reshape, '0');
if (contextOptions.deviceType === 'npu') {
return this.builder_.cast(await gemm, 'float32');
mingmingtasd marked this conversation as resolved.
Show resolved Hide resolved
} else {
const softmax = this.builder_.softmax(await gemm);
mingmingtasd marked this conversation as resolved.
Show resolved Hide resolved
return this.builder_.cast(softmax, 'float32');
}
}

async build(outputOperand) {
this.graph_ = await this.builder_.build({'output': outputOperand});
}

// Release the constant tensors of a model
dispose() {
// dispose() is only available in webnn-polyfill
if (this.graph_ !== null && 'dispose' in this.graph_) {
this.graph_.dispose();
}
}

async compute(inputBuffer, outputBuffer) {
const inputs = {'input': inputBuffer};
const outputs = {'output': outputBuffer};
const results = await this.context_.compute(this.graph_, inputs, outputs);
return results;
}
}
46 changes: 37 additions & 9 deletions image_classification/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@
<label class="btn btn-outline-info custom" name="webnn">
<input type="radio" name="backend" id="webnn_gpu" autocomplete="off">WebNN (GPU)
</label>
<label class="btn btn-outline-info custom" name="webnn">
<input type="radio" name="backend" id="webnn_npu" autocomplete="off">WebNN (NPU)
mingmingtasd marked this conversation as resolved.
Show resolved Hide resolved
</label>
</div>
</div>
</div>
Expand All @@ -61,21 +64,43 @@
</div>
</div>
</div> -->
<div class="row mb-2 align-items-center">
<div class="col-1 col-md-1">
<span>Data Type</span>
</div>
<div class="col-md-auto">
<div class="btn-group-toggle" data-toggle="buttons" id="dataTypeBtns">
<label class="btn btn-outline-info" id="float32Label" active>
<input type="radio" name="layout" id="float32" autocomplete="off" checked>Float32
</label>
<label class="btn btn-outline-info" id="float16Label">
<input type="radio" name="layout" id="float16" autocomplete="off">Float16
</label>
</div>
</div>
</div>
<div class="row align-items-center">
<div class="col col-md-1">
<span>Model</span>
</div>
<div class="col-md-auto">
<div class="btn-group-toggle" data-toggle="buttons" id="modelBtns">
<label class="btn btn-outline-info">
<input type="radio" name="model" id="mobilenet" autocomplete="off">MobileNet V2
</label>
<label class="btn btn-outline-info">
<input type="radio" name="model" id="squeezenet" autocomplete="off">SqueezeNet
</label>
<label class="btn btn-outline-info">
<input type="radio" name="model" id="resnet50" autocomplete="off">ResNet V2 50
</label>
<label class="btn btn-outline-info">
<input type="radio" name="model" id="mobilenet" autocomplete="off">MobileNet V2
</label>
<label class="btn btn-outline-info">
<input type="radio" name="model" id="squeezenet" autocomplete="off">SqueezeNet
</label>
<label class="btn btn-outline-info">
<input type="radio" name="model" id="resnet50v2" autocomplete="off">ResNet 50 V2
</label>
<label class="btn btn-outline-info">
<input type="radio" name="model" id="resnet50v1" autocomplete="off">ResNet 50 V1
</label>
<label class="btn btn-outline-info">
<input type="radio" name="model" id="efficientnet" autocomplete="off">EfficientNet
</label>

</div>
</div>
</div>
Expand Down Expand Up @@ -213,6 +238,9 @@ <h2 class="text-uppercase text-info">No model selected</h2>
<script src="https://cdn.jsdelivr.net/npm/[email protected]/dist/umd/popper.min.js"
integrity="sha384-9/reFTGAW83EW2RDu2S0VKaIzap3H66lZH81PoYlFhbGU+6BZp6G7niu735Sk7lN"
crossorigin="anonymous"></script>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/[email protected]/dist/tf.min.js"
integrity="sha256-28ZvjeNGrGNEIj9/2D8YAPE6Vm5JSvvDs+LI4ED31x8="
crossorigin="anonymous"></script>
<script src="https://stackpath.bootstrapcdn.com/bootstrap/4.5.2/js/bootstrap.min.js"
integrity="sha384-B4gt1jrGC7Jh4AgTPSdUtOBvfO8shuf57BaghqFfPlYxofvL8/KUEfYiJOMMV+rV"
crossorigin="anonymous"></script>
Expand Down
Loading