Skip to content

Commit e035ede

Browse files
Remove duplicate Prod and SparseToDense ops from converter (#7649)
Fixes #7648
1 parent d66c17b commit e035ede

9 files changed

+22
-130
lines changed

tfjs-converter/python/tensorflowjs/op_list/basic_math.json

-30
Original file line numberDiff line numberDiff line change
@@ -796,36 +796,6 @@
796796
}
797797
]
798798
},
799-
{
800-
"tfOpName": "Prod",
801-
"category": "basic_math",
802-
"inputs": [
803-
{
804-
"start": 0,
805-
"name": "x",
806-
"type": "tensor"
807-
},
808-
{
809-
"start": 1,
810-
"name": "axes",
811-
"type": "number[]"
812-
}
813-
],
814-
"attrs": [
815-
{
816-
"tfName": "keep_dims",
817-
"name": "keepDims",
818-
"type": "bool",
819-
"notSupported": true
820-
},
821-
{
822-
"tfName": "T",
823-
"name": "dtype",
824-
"type": "dtype",
825-
"notSupported": true
826-
}
827-
]
828-
},
829799
{
830800
"tfOpName": "LeakyRelu",
831801
"category": "basic_math",

tfjs-converter/python/tensorflowjs/op_list/normalization.json

-35
Original file line numberDiff line numberDiff line change
@@ -216,40 +216,5 @@
216216
"type": "tensor"
217217
}
218218
]
219-
},
220-
{
221-
"tfOpName": "SparseToDense",
222-
"category": "normalization",
223-
"inputs": [
224-
{
225-
"start": 0,
226-
"name": "sparseIndices",
227-
"type": "tensor"
228-
},
229-
{
230-
"start": 1,
231-
"name": "outputShape",
232-
"type": "number[]"
233-
},
234-
{
235-
"start": 2,
236-
"name": "sparseValues",
237-
"type": "tensor"
238-
},
239-
{
240-
"start": 3,
241-
"name": "defaultValue",
242-
"type": "tensor"
243-
}
244-
],
245-
"attrs": [
246-
{
247-
"tfName": "validate_indices",
248-
"name": "validateIndices",
249-
"type": "bool",
250-
"defaultValue": true,
251-
"notSupported": true
252-
}
253-
]
254219
}
255220
]

tfjs-converter/python/tensorflowjs/op_list/reduction.json

+6
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,12 @@
238238
"tfName": "keep_dims",
239239
"name": "keepDims",
240240
"type": "bool"
241+
},
242+
{
243+
"tfName": "T",
244+
"name": "dtype",
245+
"type": "dtype",
246+
"notSupported": true
241247
}
242248
]
243249
},

tfjs-converter/src/operations/executors/basic_math_executor.ts

-4
Original file line numberDiff line numberDiff line change
@@ -159,10 +159,6 @@ export const executeOp: InternalOpExecutor =
159159
getParamValue('x', node, tensorMap, context) as Tensor)];
160160
case 'Rsqrt':
161161
return [ops.rsqrt(getTensor(node.inputNames[0], tensorMap, context))];
162-
case 'Prod':
163-
return [ops.prod(
164-
getParamValue('x', node, tensorMap, context) as Tensor,
165-
getParamValue('axes', node, tensorMap, context) as number[])];
166162
case 'LeakyRelu':
167163
return [ops.leakyRelu(
168164
getParamValue('x', node, tensorMap, context) as Tensor,

tfjs-converter/src/operations/executors/basic_math_executor_test.ts

+1-19
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ import {Node} from '../types';
2424

2525
import {executeOp} from './basic_math_executor';
2626
import {RecursiveSpy, spyOnAllFunctions} from './spy_ops';
27-
import {createNumberAttr, createNumberAttrFromIndex, createNumericArrayAttrFromIndex, createTensorAttr, uncapitalize, validateParam} from './test_helper';
27+
import {createNumberAttr, createNumberAttrFromIndex, createTensorAttr, uncapitalize, validateParam} from './test_helper';
2828

2929
describe('basic math', () => {
3030
let node: Node;
@@ -108,24 +108,6 @@ describe('basic math', () => {
108108
expect(validateParam(node, basic_math.json)).toBeTruthy();
109109
});
110110
});
111-
describe('Prod', () => {
112-
it('should call tfOps.prod', () => {
113-
node.op = 'Prod';
114-
node.inputParams['axes'] = createNumericArrayAttrFromIndex(1);
115-
node.inputNames = ['input1', 'input2'];
116-
const input2 = [tfOps.tensor1d([2])];
117-
spyOps.prod.and.returnValue({});
118-
executeOp(node, {input1, input2}, context, spyOpsAsTfOps);
119-
120-
expect(spyOps.prod).toHaveBeenCalledWith(input1[0], [2]);
121-
});
122-
it('should match op def', () => {
123-
node.op = 'Prod';
124-
node.inputParams['axes'] = createNumericArrayAttrFromIndex(1);
125-
126-
expect(validateParam(node, basic_math.json)).toBeTruthy();
127-
});
128-
});
129111
describe('Rsqrt', () => {
130112
it('should call tfOps.rsqrt', () => {
131113
const input1 = [tfOps.scalar(1)];

tfjs-converter/src/operations/executors/normalization_executor.ts

+1-11
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
* =============================================================================
1616
*/
1717

18-
import {Scalar, Tensor, Tensor3D, Tensor4D} from '@tensorflow/tfjs-core';
18+
import {Tensor, Tensor3D, Tensor4D} from '@tensorflow/tfjs-core';
1919
// tslint:disable-next-line: no-imports-from-dist
2020
import * as tfOps from '@tensorflow/tfjs-core/dist/ops/ops_for_converter';
2121

@@ -70,16 +70,6 @@ export const executeOp: InternalOpExecutor =
7070
return [ops.logSoftmax(
7171
getParamValue('x', node, tensorMap, context) as Tensor)];
7272
}
73-
case 'SparseToDense': {
74-
return [ops.sparseToDense(
75-
getParamValue('sparseIndices', node, tensorMap, context) as
76-
Tensor,
77-
getParamValue('outputShape', node, tensorMap, context) as Tensor,
78-
getParamValue('sparseValues', node, tensorMap, context) as
79-
number[],
80-
getParamValue('defaultValue', node, tensorMap, context) as
81-
Scalar)];
82-
}
8373
default:
8474
throw TypeError(`Node type ${node.op} is not implemented`);
8575
}

tfjs-converter/src/operations/executors/normalization_executor_test.ts

-29
Original file line numberDiff line numberDiff line change
@@ -185,35 +185,6 @@ describe('normalization', () => {
185185
it('should match json def', () => {
186186
node.op = 'LogSoftmax';
187187

188-
expect(validateParam(node, normalization.json)).toBeTruthy();
189-
});
190-
});
191-
describe('SparseToDense', () => {
192-
it('should call tfOps.sparseToDense', () => {
193-
node.op = 'SparseToDense';
194-
node.inputParams.sparseIndices = createTensorAttr(0);
195-
node.inputParams.outputShape = createNumericArrayAttrFromIndex(1);
196-
node.inputParams.sparseValues = createTensorAttr(2);
197-
node.inputParams.defaultValue = createTensorAttr(3);
198-
node.inputNames = ['input1', 'input2', 'input3', 'input4'];
199-
const input2 = [tfOps.tensor1d([1], 'int32')];
200-
const input3 = [tfOps.scalar(2)];
201-
const input4 = [tfOps.scalar(3)];
202-
spyOps.sparseToDense.and.returnValue({});
203-
executeOp(node, {input1, input2, input3, input4}, context,
204-
spyOpsAsTfOps);
205-
206-
expect(spyOps.sparseToDense)
207-
.toHaveBeenCalledWith(input1[0], [1], input3[0], input4[0]);
208-
});
209-
it('should match json def', () => {
210-
node.op = 'SparseToDense';
211-
delete node.inputParams.x;
212-
node.inputParams.sparseIndices = createTensorAttr(0);
213-
node.inputParams.outputShape = createNumericArrayAttrFromIndex(1);
214-
node.inputParams.sparseValues = createTensorAttr(2);
215-
node.inputParams.defaultValue = createTensorAttr(3);
216-
217188
expect(validateParam(node, normalization.json)).toBeTruthy();
218189
});
219190
});

tfjs-converter/src/operations/executors/reduction_executor_test.ts

+10-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ import {Node} from '../types';
2323

2424
import {executeOp} from './reduction_executor';
2525
import {RecursiveSpy, spyOnAllFunctions} from './spy_ops';
26-
import {createBoolAttr, createNumberAttr, createNumberAttrFromIndex, createTensorAttr, uncapitalize, validateParam} from './test_helper';
26+
import {createBoolAttr, createNumberAttr, createNumberAttrFromIndex, createNumericArrayAttrFromIndex, createTensorAttr, uncapitalize, validateParam} from './test_helper';
2727

2828
describe('reduction', () => {
2929
let node: Node;
@@ -156,5 +156,14 @@ describe('reduction', () => {
156156
.toBeTruthy();
157157
});
158158
});
159+
describe('Prod', () => {
160+
it('should match op def', () => {
161+
node.op = 'Prod';
162+
node.inputParams['axis'] = createNumericArrayAttrFromIndex(1);
163+
node.attrParams['keepDims'] = createBoolAttr(true);
164+
165+
expect(validateParam(node, reduction.json)).toBeTruthy();
166+
});
167+
});
159168
});
160169
});

tfjs-converter/src/operations/operation_mapper_test.ts

+4-1
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,10 @@ describe('completeness check', () => {
274274
};
275275
convertedGraph = mapper.transformGraph(graph);
276276
expect(Object.keys(convertedGraph.nodes)).toEqual([tfOp.tfOpName]);
277-
expect(convertedGraph.nodes[tfOp.tfOpName].op).toEqual(tfOp.tfOpName);
277+
const node = convertedGraph.nodes[tfOp.tfOpName];
278+
expect(node.op).toEqual(tfOp.tfOpName);
279+
expect(node.category).withContext(`Op: ${node.op}, category`)
280+
.toEqual(tfOp.category);
278281
});
279282
});
280283
});

0 commit comments

Comments
 (0)