Skip to content

Commit d0bfa88

Browse files
authored
[converter] Fix fusedMatMul bug. (#6455)
* [converter] Fix fusedMatMul bug. * Add default value to core op.
1 parent 5794028 commit d0bfa88

File tree

5 files changed

+45
-5
lines changed

5 files changed

+45
-5
lines changed

tfjs-converter/python/tensorflowjs/op_list/convolution.json

+3-2
Original file line numberDiff line numberDiff line change
@@ -394,7 +394,8 @@
394394
{
395395
"tfName": "leakyrelu_alpha",
396396
"name": "leakyreluAlpha",
397-
"type": "number"
397+
"type": "number",
398+
"defaultValue": 0.2
398399
}
399400
]
400401
},
@@ -685,4 +686,4 @@
685686
}
686687
]
687688
}
688-
]
689+
]

tfjs-converter/python/tensorflowjs/op_list/matrices.json

+7-1
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,12 @@
5050
"type": "bool",
5151
"defaultValue": false
5252
},
53+
{
54+
"tfName": "leakyrelu_alpha",
55+
"name": "leakyreluAlpha",
56+
"type": "number",
57+
"defaultValue": 0.2
58+
},
5359
{
5460
"tfName": "T",
5561
"name": "dtype",
@@ -220,4 +226,4 @@
220226
}
221227
]
222228
}
223-
]
229+
]

tfjs-converter/src/operations/executors/matrices_executor_test.ts

+14-1
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,13 @@
1818
import {Tensor} from '@tensorflow/tfjs-core';
1919
// tslint:disable-next-line: no-imports-from-dist
2020
import * as tfOps from '@tensorflow/tfjs-core/dist/ops/ops_for_converter';
21+
import * as matrices from '../op_list/matrices';
2122

2223
import {ExecutionContext} from '../../executor/execution_context';
2324
import {Node} from '../types';
2425

2526
import {executeOp} from './matrices_executor';
26-
import {createBoolAttr, createNumberAttr, createNumericArrayAttr, createStrArrayAttr, createStrAttr, createTensorAttr, createTensorsAttr} from './test_helper';
27+
import {createBoolAttr, createNumberAttr, createNumericArrayAttr, createStrArrayAttr, createStrAttr, createTensorAttr, createTensorsAttr, validateParam} from './test_helper';
2728

2829
describe('matrices', () => {
2930
let node: Node;
@@ -130,6 +131,18 @@ describe('matrices', () => {
130131
leakyreluAlpha: 0.3
131132
});
132133
});
134+
it('should match json def.', () => {
135+
node.op = '_FusedMatMul';
136+
137+
node.attrParams['fusedOps'] =
138+
createStrArrayAttr(['biasadd', 'leakyrelu']);
139+
node.attrParams['numArgs'] = createNumberAttr(1);
140+
node.attrParams.transposeA = createBoolAttr(true);
141+
node.attrParams.transposeB = createBoolAttr(false);
142+
node.attrParams.leakyreluAlpha = createNumberAttr(0.3);
143+
144+
expect(validateParam(node, matrices.json)).toBeTruthy();
145+
});
133146
});
134147
describe('BatchMatMul', () => {
135148
it('should call tfOps.matMul', () => {

tfjs-core/src/ops/fused/fused_mat_mul_test.ts

+20
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,26 @@ describeWithFlags('fused matmul', ALL_ENVS, () => {
111111
expectArraysClose(await c.data(), [0, 8, -0.9000000357627869, 20]);
112112
});
113113

114+
it('fused A x B with leakyrelu not provided.', async () => {
115+
const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
116+
const b = tf.tensor2d([0, 1, -3, 2, 2, 1], [3, 2]);
117+
const transposeA = false;
118+
const transposeB = false;
119+
120+
const c = tf.fused.matMul({
121+
a,
122+
b,
123+
transposeA,
124+
transposeB,
125+
bias: null,
126+
activation: 'leakyrelu'
127+
});
128+
129+
expect(c.shape).toEqual([2, 2]);
130+
// leakyRelu should use default alpha=0.2.
131+
expectArraysClose(await c.data(), [0, 8, -0.6000000238418579, 20]);
132+
});
133+
114134
it('fused A x B with sigmoid', async () => {
115135
const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
116136
const b = tf.tensor2d([0, 1, -3, 2, 2, 1], [3, 2]);

tfjs-core/src/ops/fused/mat_mul.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ function fusedMatMul_({
6363
bias,
6464
activation = 'linear',
6565
preluActivationWeights,
66-
leakyreluAlpha,
66+
leakyreluAlpha = 0.2,
6767
}: {
6868
a: Tensor|TensorLike,
6969
b: Tensor|TensorLike,

0 commit comments

Comments
 (0)