
Commit 48367b2

Fix bug with dl.customGrad() and add much better error handling for gradient-related API (#731)
1 parent bc95397 commit 48367b2
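
In short: `CustomGradientFunc` now lets `gradFunc` return either a single `Tensor` or a `Tensor[]`, the engine validates the inputs and outputs of `customGrad(f)` and `gradients(f, xs)`, and calling a gradient function more than once no longer breaks. A minimal usage sketch, assuming `import * as dl from 'deeplearn'` as in the tests below:

    const customSquare = dl.customGrad(x => {
      // Custom x^2 op whose gradient is overridden to dy * |x|.
      // gradFunc may now return a single Tensor instead of a one-element array.
      return {value: x.square(), gradFunc: dy => dy.mul(x.abs())};
    });
    const g = dl.grad(x => customSquare(x));
    g(dl.tensor1d([-1, -2, 3]));  // ~[1, 2, 3]
    g(dl.tensor1d([-1, -2, 3]));  // calling the gradient a second time also works now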

7 files changed: +200 additions, -96 deletions

src/engine.ts

Lines changed: 50 additions & 17 deletions
@@ -37,13 +37,13 @@ interface ScopeState {
 }
 
 /**
- * @docalias (...inputs: Tensor[]) => {
+ * @docalias (a: Tensor, b: Tensor,...) => {
  *   value: Tensor,
- *   gradFunc: (dy: Tensor) => Tensor[]
+ *   gradFunc: (dy: Tensor) => Tensor|Tensor[]
  * }
  */
 export type CustomGradientFunc<T extends Tensor> = (...args: Tensor[]) => {
-  value: T, gradFunc: (dy: T) => Tensor[];
+  value: T, gradFunc: (dy: T) => Tensor | Tensor[];
 };
 
 export interface TensorManager {
@@ -287,23 +287,27 @@ export class Engine implements TensorManager {
   }
 
   /**
-   * Returns gradients of `f` w.r.t. each of the `xs`. The gradients returned
-   * are of the same length as `xs`, but some might be null if `f` was not
-   * a function of that `x`. It also takes optional dy to multiply the gradient,
-   * which defaults to `1`.
+   * Returns gradients of `f` with respect to each of the `xs`. The gradients
+   * returned are of the same length as `xs`, but some might be null if `f` was
+   * not a function of that `x`. It also takes optional dy to multiply the
+   * gradient, which defaults to `1`.
    */
-  gradients<T extends Tensor>(f: () => T, xs: Tensor[], dy?: T):
-      {value: T, grads: Tensor[]} {
+  gradients<T extends Tensor>(
+      f: () => T, xs: Tensor[], dy?: T,
+      allowNoGradients = false): {value: T, grads: Tensor[]} {
     return tidy('gradients', () => {
       const y = f();
+      util.assert(
+          y instanceof Tensor,
+          'The result y returned by f() must be a tensor.');
       // Filter out the nodes that don't connect x => y.
       const filteredTape =
           tape_util.getFilteredNodesXToY(this.activeTape, xs, y);
-      if (filteredTape.length === 0 && xs.length > 0) {
+      if (!allowNoGradients && filteredTape.length === 0 && xs.length > 0) {
         throw new Error(
-            `Cannot compute gradient: y is not a function of \`x\`s. ` +
-            `Make sure the xs you are computing gradients with respect ` +
-            `to are used inside the gradient function.`);
+            'Cannot compute gradient of y=f(x) with respect to x. Make sure ' +
+            'that the f you passed encloses all operations that lead from x ' +
+            'to y.');
       }
 
       const accumulatedGradientMap: {[tensorId: number]: Tensor} = {};
@@ -319,21 +323,50 @@ export class Engine implements TensorManager {
 
   customGrad<T extends Tensor>(f: CustomGradientFunc<T>):
       (...args: Tensor[]) => T {
-    this.customGradientDepth++;
-
+    util.assert(
+        util.isFunction(f),
+        'The f passed in customGrad(f) must be a function.');
     return (...inputs: Tensor[]): T => {
-      let gradientsFunc: (dy: T) => Tensor[];
+      util.assert(
+          inputs.every(t => t instanceof Tensor),
+          'The args passed in customGrad(f)(x1, x2,...) must all be tensors');
+      this.customGradientDepth++;
+
+      let gradientsFunc: (dy: T) => Tensor | Tensor[];
      const gradientsMode = true;
      const result = tidy(f.name, () => {
        const {value, gradFunc} = f(...inputs);
+        util.assert(
+            value instanceof Tensor,
+            'The function f passed in customGrad(f) must return an object ' +
+            'where `obj.value` is a tensor');
+        util.assert(
+            util.isFunction(gradFunc),
+            'The function f passed in customGrad(f) must return an object ' +
+            'where `obj.gradFunc` is a function.');
        gradientsFunc = gradFunc;
        return value;
      }, gradientsMode);
 
      this.customGradientDepth--;
 
      if (this.shouldRecord()) {
-        this.addTapeNode(inputs, result, gradientsFunc);
+        const gradFunc = (dy: T): Tensor[] => {
+          const res = gradientsFunc(dy);
+          const grads: Tensor[] = Array.isArray(res) ? res : [res];
+          util.assert(
+              grads.length === inputs.length,
+              'The function f passed in customGrad(f) must return an object ' +
+              'where `obj.gradFunc` is a function that returns the same ' +
+              'number of tensors as inputs passed to f(...).');
+          util.assert(
+              grads.every(t => t instanceof Tensor),
+              'The function f passed in customGrad(f) must return an object ' +
+              'where `obj.gradFunc` is a function that returns a list of ' +
+              'only tensors.');
+          return grads;
+        };
+        this.addTapeNode(inputs, result, gradFunc);
      }
      return result;
    };
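
Beyond the type change, `Engine.gradients()` now asserts that `f()` returns a tensor, takes an `allowNoGradients` flag, and reports a clearer error when `y` does not depend on `x`, while `Engine.customGrad()` validates `f`, its inputs, and the shape of what `gradFunc` returns. A hedged sketch of the new failure modes, assuming `dl.grad` and `dl.customGrad` are backed by the engine methods above:

    // Passing a non-function is rejected immediately:
    dl.customGrad(null as any);
    // -> Error: The f passed in customGrad(f) must be a function.

    // A y that is not a function of x produces the new, clearer message:
    const g = dl.grad(x => dl.scalar(3));
    g(dl.scalar(2));
    // -> Error: Cannot compute gradient of y=f(x) with respect to x. Make sure
    //    that the f you passed encloses all operations that lead from x to y.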

src/engine_test.ts

Lines changed: 32 additions & 2 deletions
@@ -233,12 +233,30 @@ describeWithFlags('gradients', ALL_ENVS, () => {
     expectArraysClose(result, [.2, .4]);
   });
 
+  it('calling grad(f) twice works', () => {
+    const grad = dl.grad(x => x.square());
+
+    const result = grad(dl.tensor1d([.1, .2]));
+    const result2 = grad(dl.tensor1d([.1, .4]));
+    expectArraysClose(result, [.2, .4]);
+    expectArraysClose(result2, [.2, .8]);
+  });
+
   it('grads(f)', () => {
     const grads = dl.grads(x => x.square());
     const result = grads([dl.tensor1d([.1, .2])]);
     expectArraysClose(result[0], [.2, .4]);
   });
 
+  it('calling grads(f) twice works', () => {
+    const grads = dl.grads(x => x.square());
+
+    const result = grads([dl.tensor1d([.1, .2])]);
+    const result2 = grads([dl.tensor1d([.1, .4])]);
+    expectArraysClose(result[0], [.2, .4]);
+    expectArraysClose(result2[0], [.2, .8]);
+  });
+
   it('works with reshape', () => {
     const a = dl.tensor2d([1, 2, 3, 4], [2, 2]);
     const exponent = dl.tensor1d([2, 2, 2, 2], 'int32');
@@ -390,7 +408,7 @@ describeWithFlags('customGradient', ALL_ENVS, () => {
 
     const customPow = dl.customGrad(a => {
       const value = dl.pow(a, b);
-      const gradFunc = (dy: Tensor) => [dy.mul(dl.scalar(0.1))];
+      const gradFunc = (dy: Tensor) => dy.mul(dl.scalar(0.1));
      return {value, gradFunc};
    });
 
@@ -409,7 +427,7 @@ describeWithFlags('customGradient', ALL_ENVS, () => {
 
     const customPow = dl.customGrad(a => {
       const value = dl.pow(a, b);
-      const gradFunc = (dy: Tensor) => [dy.mul(a)];
+      const gradFunc = (dy: Tensor) => dy.mul(a);
      return {value, gradFunc};
    });
 
@@ -419,6 +437,18 @@ describeWithFlags('customGradient', ALL_ENVS, () => {
     // First order: dy * a. Second order: dy.
     expectArraysClose(dda, dy);
   });
+
+  it('calling gradient of custom op twice works', () => {
+    const customOp = dl.customGrad(x => {
+      // Override gradient of our custom x ^ 2 op to be dy * abs(x);
+      return {value: x.square(), gradFunc: dy => dy.mul(x.abs())};
+    });
+    const x = dl.tensor1d([-1, -2, 3]);
+    const grad = dl.grad(x => customOp(x));
+
+    expectArraysClose(grad(x), [1, 2, 3]);
+    expectArraysClose(grad(x), [1, 2, 3]);
+  });
 });
 
 describeWithFlags('memory', ALL_ENVS, () => {
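
The tests above exercise the single-input case, where `gradFunc` may now return one tensor. For multi-input custom ops, `gradFunc` must still return one gradient per input, and the new assert in `Engine.customGrad()` enforces that the count matches. A hedged sketch (not part of this commit's tests) of a two-input op:

    const customMul = dl.customGrad((a, b) => ({
      value: a.mul(b),
      // d(a*b)/da = b and d(a*b)/db = a, one gradient per input.
      gradFunc: dy => [dy.mul(b), dy.mul(a)]
    }));
    const [da, db] = dl.grads((a, b) => customMul(a, b))([dl.scalar(2), dl.scalar(3)]);
    // da ~ 3, db ~ 2; returning a single tensor here would trip the
    // "same number of tensors as inputs passed to f(...)" assert.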
