Skip to content

Commit

Permalink
Fix resnet50v2_nhwc model (webmachinelearning#162)
Browse files Browse the repository at this point in the history
* Fix resnet50v2_nhwc model

1. Change 1x1 filter to 2x2 for `maxPool2d` [1].
2. Merge `pad` into `conv2d`.
3. Replace `reduceMean` for axis 1, 2 with `averagePool2d`.

[1]: https://github.com/google/XNNPACK/issues/472s

* Fix lint error
  • Loading branch information
huningxin authored Mar 2, 2023
1 parent b836c99 commit 23e6f66
Showing 1 changed file with 6 additions and 14 deletions.
20 changes: 6 additions & 14 deletions image_classification/resnet50v2_nhwc.js
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,9 @@ export class ResNet50V2Nhwc {
}
if (!downsample && shortcut) {
residual = this.builder_.maxPool2d(
input, {windowDimensions: [1, 1], strides, layout, autoPad});
const padding = this.builder_.constant(
{type: 'int32', dimensions: [4, 2]},
new Int32Array([0, 0, 1, 1, 1, 1, 0, 0]));
const pad = this.builder_.pad(conv1, padding);
conv2 = await this.buildConv_(pad, nameIndices.concat(['2']), {strides});
input, {windowDimensions: [2, 2], strides, layout, autoPad});
conv2 = await this.buildConv_(
conv1, nameIndices.concat(['2']), {strides, padding: [1, 1, 1, 1]});
} else {
conv2 = await this.buildConv_(
conv1, nameIndices.concat(['2']), {autoPad});
Expand All @@ -101,14 +98,10 @@ export class ResNet50V2Nhwc {
async load(contextOptions) {
this.context_ = await navigator.ml.createContext(contextOptions);
this.builder_ = new MLGraphBuilder(this.context_);
const padding = this.builder_.constant(
{type: 'int32', dimensions: [4, 2]},
new Int32Array([0, 0, 3, 3, 3, 3, 0, 0]));

const input = this.builder_.input('input',
{type: 'float32', dimensions: this.inputOptions.inputDimensions});
const pad = this.builder_.pad(input, padding);
const conv1 = await this.buildConv_(pad, ['', '', '1'], {strides}, false);
const conv1 = await this.buildConv_(
input, ['', '', '1'], {strides, padding: [3, 3, 3, 3]}, false);
const pool = this.builder_.maxPool2d(
conv1, {windowDimensions: [3, 3], strides, layout, autoPad});
// Block 1
Expand Down Expand Up @@ -155,8 +148,7 @@ export class ResNet50V2Nhwc {

const fusedBn =
await this.buildFusedBatchNorm_(bottleneck13, ['postnorm']);
const mean = this.builder_.reduceMean(
fusedBn, {keepDimensions: true, axes: [1, 2]});
const mean = this.builder_.averagePool2d(fusedBn, {layout});
const conv2 = await this.buildConv_(
mean, ['', '', 'logits'], {autoPad}, false);
const reshape = this.builder_.reshape(conv2, [1, null]);
Expand Down

0 comments on commit 23e6f66

Please sign in to comment.