diff --git a/image_classification/resnet50v2_nhwc.js b/image_classification/resnet50v2_nhwc.js index a9cf6212..945813d1 100644 --- a/image_classification/resnet50v2_nhwc.js +++ b/image_classification/resnet50v2_nhwc.js @@ -83,12 +83,9 @@ export class ResNet50V2Nhwc { } if (!downsample && shortcut) { residual = this.builder_.maxPool2d( - input, {windowDimensions: [1, 1], strides, layout, autoPad}); - const padding = this.builder_.constant( - {type: 'int32', dimensions: [4, 2]}, - new Int32Array([0, 0, 1, 1, 1, 1, 0, 0])); - const pad = this.builder_.pad(conv1, padding); - conv2 = await this.buildConv_(pad, nameIndices.concat(['2']), {strides}); + input, {windowDimensions: [2, 2], strides, layout, autoPad}); + conv2 = await this.buildConv_( + conv1, nameIndices.concat(['2']), {strides, padding: [1, 1, 1, 1]}); } else { conv2 = await this.buildConv_( conv1, nameIndices.concat(['2']), {autoPad}); @@ -101,14 +98,10 @@ export class ResNet50V2Nhwc { async load(contextOptions) { this.context_ = await navigator.ml.createContext(contextOptions); this.builder_ = new MLGraphBuilder(this.context_); - const padding = this.builder_.constant( - {type: 'int32', dimensions: [4, 2]}, - new Int32Array([0, 0, 3, 3, 3, 3, 0, 0])); - const input = this.builder_.input('input', {type: 'float32', dimensions: this.inputOptions.inputDimensions}); - const pad = this.builder_.pad(input, padding); - const conv1 = await this.buildConv_(pad, ['', '', '1'], {strides}, false); + const conv1 = await this.buildConv_( + input, ['', '', '1'], {strides, padding: [3, 3, 3, 3]}, false); const pool = this.builder_.maxPool2d( conv1, {windowDimensions: [3, 3], strides, layout, autoPad}); // Block 1 @@ -155,8 +148,7 @@ export class ResNet50V2Nhwc { const fusedBn = await this.buildFusedBatchNorm_(bottleneck13, ['postnorm']); - const mean = this.builder_.reduceMean( - fusedBn, {keepDimensions: true, axes: [1, 2]}); + const mean = this.builder_.averagePool2d(fusedBn, {layout}); const conv2 = await this.buildConv_( mean, ['', '', 'logits'], {autoPad}, false); const reshape = this.builder_.reshape(conv2, [1, null]);