diff --git a/face_recognition/facenet_nchw.js b/face_recognition/facenet_nchw.js index 1a7d76fd..d33ef09d 100644 --- a/face_recognition/facenet_nchw.js +++ b/face_recognition/facenet_nchw.js @@ -233,7 +233,10 @@ export class FaceNetNchw { block8_5, 6, ['977', '1104', '978', '1080', '1086'], false); const averagePool = this.builder_.averagePool2d(block8_6); - const squeeze = this.builder_.squeeze(averagePool, {axes: [2, 3]}); + // Use reshape to implement squeeze(averagePool, {axes: [2, 3]}); + const squeezed_shape = averagePool.shape(); + squeezed_shape.splice(2, 2); + const squeeze = this.builder_.reshape(averagePool, squeezed_shape); const gemm = await this.buildGemm_(squeeze); // L2Normalization will be handled in post-processing return gemm;