Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Commit f133e46

Browse files
authored
Align dl.conv2d with tf.conv2d (remove bias as param) (#723)
1 parent 8c5c730 commit f133e46

File tree

19 files changed: 94 additions and 244 deletions.

demos/benchmarks/conv_benchmarks.ts

Lines changed: 1 addition & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -45,16 +45,14 @@ export class ConvGPUBenchmark implements BenchmarkTest {
4545

4646
let x: dl.Tensor3D = dl.randomUniform(inShape, -1, 1);
4747
let W: dl.Tensor4D;
48-
let b: dl.Tensor1D;
4948

5049
let benchmark: () => dl.Tensor;
5150
if (opType === 'regular') {
5251
const regParams = params as RegularConvParams;
5352
const wShape = dl.conv_util.computeWeightsShape4D(
5453
inDepth, regParams.outDepth, filterSize, filterSize);
5554
W = dl.randomUniform(wShape, -1, 1);
56-
b = dl.randomUniform([regParams.outDepth], -1, 1);
57-
benchmark = () => x.conv2d(W, b, stride, pad);
55+
benchmark = () => x.conv2d(W, stride, pad);
5856
} else if (opType === 'transposed') {
5957
const regParams = params as RegularConvParams;
6058
const wShape = dl.conv_util.computeWeightsShape4D(
@@ -80,9 +78,6 @@ export class ConvGPUBenchmark implements BenchmarkTest {
8078
x.dispose();
8179
W.dispose();
8280
math.dispose();
83-
if (b != null) {
84-
b.dispose();
85-
}
8681

8782
return time;
8883
}

demos/fast-style-transfer/net.ts

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -88,8 +88,8 @@ export class TransformNet implements dl.Model {
8888
input: dl.Tensor3D, strides: number, relu: boolean,
8989
varId: number): dl.Tensor3D {
9090
const y = input.conv2d(
91-
this.variables[this.varName(varId)] as dl.Tensor4D, null,
92-
[strides, strides], 'same');
91+
this.variables[this.varName(varId)] as dl.Tensor4D, [strides, strides],
92+
'same');
9393

9494
const y2 = this.instanceNorm(y, varId + 1);
9595

demos/performance_rnn/performance_rnn.ts

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -445,7 +445,8 @@ async function generateStep(loopId: number) {
445445
const logits = outputH.matMul(fcW).add(fcB);
446446

447447
const softmax = logits.as1D().softmax();
448-
const sampledOutput = dl.multinomial(softmax, 1).asScalar();
448+
// TODO(smilkov): Use dl.multinomial once exposed to the user.
449+
const sampledOutput = dl.ENV.math.multinomial(softmax, 1).asScalar();
449450

450451
outputs.push(sampledOutput);
451452
dl.keep(sampledOutput);

src/graph/ops/convolution.ts

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -16,8 +16,8 @@
1616
*/
1717

1818
import {keep, tidy} from '../../globals';
19-
import * as conv_util from '../../ops/conv_util';
2019
import {NDArrayMath} from '../../math';
20+
import * as conv_util from '../../ops/conv_util';
2121
import {Tensor1D, Tensor3D, Tensor4D} from '../../tensor';
2222
import * as util from '../../util';
2323
import {SymbolicTensor} from '../graph';
@@ -84,7 +84,7 @@ export class Convolution2D extends Operation {
8484
tidy(() => {
8585
const dw =
8686
math.conv2dDerFilter(x, dy, filter.shape, this.stride, this.zeroPad);
87-
const db = math.conv2dDerBias(dy);
87+
const db = math.sum(dy, [0, 1] /* axis */);
8888
const dx =
8989
math.conv2dDerInput(x.shape, dy, filter, this.stride, this.zeroPad);
9090
gradientArrays.add(this.wTensor, dw);

src/kernels/backend.ts

Lines changed: 1 addition & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -150,13 +150,10 @@ export interface KernelBackend extends TensorStorage, BackendTimer {
150150

151151
step<T extends Tensor>(x: T, alpha: number): T;
152152

153-
conv2d(
154-
x: Tensor4D, filter: Tensor4D, bias: Tensor1D|null,
155-
convInfo: Conv2DInfo): Tensor4D;
153+
conv2d(x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo): Tensor4D;
156154
conv2dDerInput(dy: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo):
157155
Tensor4D;
158156
conv2dDerFilter(x: Tensor4D, dY: Tensor4D, convInfo: Conv2DInfo): Tensor4D;
159-
conv2dDerBias(dY: Tensor4D): Tensor1D;
160157

161158
depthwiseConv2D(input: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo):
162159
Tensor4D;

src/kernels/backend_cpu.ts

Lines changed: 2 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -903,9 +903,7 @@ export class MathBackendCPU implements KernelBackend {
903903
return Tensor.make(x.shape, {values: resultValues}) as T;
904904
}
905905

906-
conv2d(
907-
x: Tensor4D, filter: Tensor4D, bias: Tensor1D|null,
908-
convInfo: Conv2DInfo): Tensor4D {
906+
conv2d(x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo): Tensor4D {
909907
const filterHeight = convInfo.filterHeight;
910908
const filterWidth = convInfo.filterWidth;
911909
const padLeft = convInfo.padInfo.left;
@@ -934,8 +932,7 @@ export class MathBackendCPU implements KernelBackend {
934932
}
935933
}
936934
}
937-
const biasVal = (bias != null) ? bias.get(d2) : 0;
938-
y.set(dotProd + biasVal, b, yR, yC, d2);
935+
y.set(dotProd, b, yR, yC, d2);
939936
}
940937
}
941938
}
@@ -1031,23 +1028,6 @@ export class MathBackendCPU implements KernelBackend {
10311028
return dW.toTensor();
10321029
}
10331030

1034-
conv2dDerBias(dy: Tensor4D): Tensor1D {
1035-
const [batchSize, numRows, numCols, outDepth] = dy.shape;
1036-
const values = new Float32Array(outDepth);
1037-
for (let d2 = 0; d2 < outDepth; ++d2) {
1038-
let sum = 0;
1039-
for (let b = 0; b < batchSize; ++b) {
1040-
for (let r = 0; r < numRows; ++r) {
1041-
for (let c = 0; c < numCols; ++c) {
1042-
sum += dy.get(b, r, c, d2);
1043-
}
1044-
}
1045-
}
1046-
values[d2] = sum;
1047-
}
1048-
return ops.tensor1d(values);
1049-
}
1050-
10511031
depthwiseConv2D(x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo):
10521032
Tensor4D {
10531033
const filterHeight = convInfo.filterHeight;

src/kernels/backend_webgl.ts

Lines changed: 4 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -37,7 +37,7 @@ import {BinaryOpProgram} from './webgl/binaryop_gpu';
3737
import {ClipProgram} from './webgl/clip_gpu';
3838
import {ConcatProgram} from './webgl/concat_gpu';
3939
// tslint:disable-next-line:max-line-length
40-
import {Conv2DDerBiasProgram, Conv2DDerFilterProgram, Conv2DDerInputProgram} from './webgl/conv_backprop_gpu';
40+
import {Conv2DDerFilterProgram, Conv2DDerInputProgram} from './webgl/conv_backprop_gpu';
4141
import {Conv2DProgram} from './webgl/conv_gpu';
4242
import {DepthwiseConv2DProgram} from './webgl/conv_gpu_depthwise';
4343
import {FromPixelsProgram} from './webgl/from_pixels_gpu';
@@ -776,12 +776,9 @@ export class MathBackendWebGL implements KernelBackend {
776776
return this.compileAndRun(program, [x]) as T;
777777
}
778778

779-
conv2d(
780-
x: Tensor4D, filter: Tensor4D, bias: Tensor1D|null,
781-
convInfo: Conv2DInfo): Tensor4D {
782-
const program = new Conv2DProgram(convInfo, bias != null);
783-
const inputs = bias != null ? [x, filter, bias] : [x, filter];
784-
return this.compileAndRun(program, inputs);
779+
conv2d(x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo): Tensor4D {
780+
const program = new Conv2DProgram(convInfo);
781+
return this.compileAndRun(program, [x, filter]);
785782
}
786783

787784
conv2dDerInput(dy: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo):
@@ -795,11 +792,6 @@ export class MathBackendWebGL implements KernelBackend {
795792
return this.compileAndRun(program, [x, dy]);
796793
}
797794

798-
conv2dDerBias(dy: Tensor4D): Tensor1D {
799-
const program = new Conv2DDerBiasProgram(dy.shape);
800-
return this.compileAndRun(program, [dy]);
801-
}
802-
803795
depthwiseConv2D(x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo):
804796
Tensor4D {
805797
const program = new DepthwiseConv2DProgram(convInfo);

src/kernels/kernel_registry.ts

Lines changed: 3 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,7 @@ import {CastNode} from './types/cast';
2828
// tslint:disable-next-line:max-line-length
2929
import {ConcatNode} from './types/concat';
3030
// tslint:disable-next-line:max-line-length
31-
import {Conv2DDerBiasNode, Conv2DDerFilterNode, Conv2DDerInputNode, Conv2DNode, DepthwiseConv2DNode} from './types/conv';
31+
import {Conv2DDerFilterNode, Conv2DDerInputNode, Conv2DNode, DepthwiseConv2DNode} from './types/conv';
3232
import {GatherNode} from './types/gather';
3333
import {EqualNode, LogicalNode, WhereNode} from './types/logical';
3434
import {LRN4DNode} from './types/lrn';
@@ -286,8 +286,8 @@ executeKernel<R extends Rank, K extends keyof KernelConfigRegistry<R>, O extends
286286
} else if (kernelName === 'Conv2D') {
287287
const config = inputAndArgs as Conv2DNode['inputAndArgs'];
288288
return backend.conv2d(
289-
config.inputs.x, config.inputs.filter, config.inputs.bias,
290-
config.args.convInfo) as O;
289+
config.inputs.x, config.inputs.filter, config.args.convInfo) as
290+
O;
291291
} else if (kernelName === 'Conv2DDerInput') {
292292
const config = inputAndArgs as Conv2DDerInputNode['inputAndArgs'];
293293
return backend.conv2dDerInput(
@@ -297,9 +297,6 @@ executeKernel<R extends Rank, K extends keyof KernelConfigRegistry<R>, O extends
297297
const config = inputAndArgs as Conv2DDerFilterNode['inputAndArgs'];
298298
return backend.conv2dDerFilter(
299299
config.inputs.x, config.inputs.dy, config.args.convInfo) as O;
300-
} else if (kernelName === 'Conv2DDerBias') {
301-
const config = inputAndArgs as Conv2DDerBiasNode['inputAndArgs'];
302-
return backend.conv2dDerBias(config.inputs.dy) as O;
303300
} else if (kernelName === 'DepthwiseConv2D') {
304301
const config = inputAndArgs as DepthwiseConv2DNode['inputAndArgs'];
305302
return backend.depthwiseConv2D(
@@ -423,7 +420,6 @@ export interface KernelConfigRegistry<R extends Rank> {
423420
Conv2D: Conv2DNode;
424421
Conv2DDerInput: Conv2DDerInputNode;
425422
Conv2DDerFilter: Conv2DDerFilterNode;
426-
Conv2DDerBias: Conv2DDerBiasNode;
427423
DepthwiseConv2D: Conv2DNode;
428424
MaxPool: PoolNode;
429425
MaxPoolBackprop: PoolBackpropNode;

src/kernels/types/conv.ts

Lines changed: 2 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -17,18 +17,16 @@
1717

1818
import {Conv2DInfo} from '../../ops/conv_util';
1919
import {KernelNode} from '../../tape_types';
20-
import {Tensor1D, Tensor4D} from '../../tensor';
20+
import {Tensor4D} from '../../tensor';
2121

2222
export interface Conv2DNode extends KernelNode {
2323
inputAndArgs: {
24-
inputs: {x: Tensor4D; filter: Tensor4D; bias?: Tensor1D;};
25-
args: {convInfo: Conv2DInfo;};
24+
inputs: {x: Tensor4D; filter: Tensor4D;}; args: {convInfo: Conv2DInfo;};
2625
};
2726
output: Tensor4D;
2827
gradient: (dy: Tensor4D, y: Tensor4D) => {
2928
x: () => Tensor4D;
3029
filter: () => Tensor4D;
31-
bias?: () => Tensor1D;
3230
};
3331
}
3432

@@ -53,14 +51,6 @@ export interface Conv2DDerFilterNode extends KernelNode {
5351
};
5452
}
5553

56-
export interface Conv2DDerBiasNode extends KernelNode {
57-
inputAndArgs: {inputs: {dy: Tensor4D;};};
58-
output: Tensor1D;
59-
gradient: (dy: Tensor1D, y: Tensor1D) => {
60-
dy: () => Tensor4D;
61-
};
62-
}
63-
6454
export interface DepthwiseConv2DNode extends KernelNode {
6555
inputAndArgs: {
6656
inputs: {x: Tensor4D; filter: Tensor4D;}; args: {convInfo: Conv2DInfo;};

src/kernels/webgl/conv_backprop_gpu.ts

Lines changed: 0 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -134,29 +134,3 @@ export class Conv2DDerInputProgram implements GPGPUProgram {
134134
`;
135135
}
136136
}
137-
138-
export class Conv2DDerBiasProgram implements GPGPUProgram {
139-
variableNames = ['dy'];
140-
outputShape: number[];
141-
userCode: string;
142-
143-
constructor(yShape: [number, number, number, number]) {
144-
const [batchSize, yNumRows, yNumCols, outputDepth] = yShape;
145-
this.outputShape = [outputDepth];
146-
this.userCode = `
147-
void main() {
148-
int d2 = getOutputCoords();
149-
150-
float derBias = 0.0;
151-
for (int b = 0; b < ${batchSize}; b++) {
152-
for (int yR = 0; yR < ${yNumRows}; yR++) {
153-
for (int yC = 0; yC < ${yNumCols}; yC++) {
154-
derBias += getDy(b, yR, yC, d2);
155-
}
156-
}
157-
}
158-
setOutput(derBias);
159-
}
160-
`;
161-
}
162-
}

0 commit comments

Comments
 (0)