@@ -37,13 +37,13 @@ interface ScopeState {
 }
 
 /**
- * @docalias (...inputs: Tensor[]) => {
+ * @docalias (a: Tensor, b: Tensor,...) => {
  *   value: Tensor,
- *   gradFunc: (dy: Tensor) => Tensor[]
+ *   gradFunc: (dy: Tensor) => Tensor|Tensor[]
  * }
  */
 export type CustomGradientFunc<T extends Tensor> = (...args: Tensor[]) => {
-  value: T, gradFunc: (dy: T) => Tensor[];
+  value: T, gradFunc: (dy: T) => Tensor|Tensor[];
 };
 
 export interface TensorManager {
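
A minimal usage sketch of the relaxed type (not part of the diff; tf.customGrad, tf.grad, tf.scalar, and the chained ops are assumed from the public tfjs-core surface of this era). With gradFunc typed as (dy: T) => Tensor|Tensor[], a single-input custom op can now return one tensor instead of a one-element array:

// d/dx(x^2) = 2x. gradFunc returns a single Tensor, which the relaxed
// CustomGradientFunc signature now permits.
const square = tf.customGrad((x: tf.Tensor) => ({
  value: x.square(),
  gradFunc: (dy: tf.Tensor) => dy.mul(x.mul(tf.scalar(2)))
}));
const dx = tf.grad(x => square(x))(tf.scalar(4));  // dx == 8
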
@@ -287,23 +287,27 @@ export class Engine implements TensorManager {
   }
 
   /**
-   * Returns gradients of `f` w.r.t. each of the `xs`. The gradients returned
-   * are of the same length as `xs`, but some might be null if `f` was not
-   * a function of that `x`. It also takes optional dy to multiply the gradient,
-   * which defaults to `1`.
+   * Returns gradients of `f` with respect to each of the `xs`. The gradients
+   * returned are of the same length as `xs`, but some might be null if `f` was
+   * not a function of that `x`. It also takes an optional `dy` to multiply the
+   * gradient, which defaults to `1`.
    */
-  gradients<T extends Tensor>(f: () => T, xs: Tensor[], dy?: T):
-      {value: T, grads: Tensor[]} {
+  gradients<T extends Tensor>(
+      f: () => T, xs: Tensor[], dy?: T,
+      allowNoGradients = false): {value: T, grads: Tensor[]} {
     return tidy('gradients', () => {
       const y = f();
+      util.assert(
+          y instanceof Tensor,
+          'The result y returned by f() must be a tensor.');
       // Filter out the nodes that don't connect x => y.
       const filteredTape =
           tape_util.getFilteredNodesXToY(this.activeTape, xs, y);
-      if (filteredTape.length === 0 && xs.length > 0) {
+      if (!allowNoGradients && filteredTape.length === 0 && xs.length > 0) {
         throw new Error(
-            `Cannot compute gradient: y is not a function of \`x\`s. ` +
-            `Make sure the xs you are computing gradients with respect ` +
-            `to are used inside the gradient function.`);
+            'Cannot compute gradient of y=f(x) with respect to x. Make sure ' +
+            'that the f you passed encloses all operations that lead from x ' +
+            'to y.');
       }
 
       const accumulatedGradientMap: {[tensorId: number]: Tensor} = {};
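
A hypothetical call sketch of the new allowNoGradients flag. Engine is internal, so how an instance is reached varies; the names engine and scalar below stand in for whatever the surrounding codebase provides. With the flag set, a y that is not a function of the xs returns instead of throwing, and the caller decides how to treat the missing gradients:

// Hypothetical names: `engine` is an Engine instance, `scalar` creates a
// rank-0 tensor (both assumed, not defined in this diff).
const x = scalar(3);
const {value, grads} = engine.gradients(
    () => scalar(1),  // y does not depend on x
    [x], undefined /* dy */, true /* allowNoGradients */);
// Before this change, the call above threw 'Cannot compute gradient...'.
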
@@ -319,21 +323,50 @@ export class Engine implements TensorManager {
 
   customGrad<T extends Tensor>(f: CustomGradientFunc<T>):
       (...args: Tensor[]) => T {
-    this.customGradientDepth++;
-
+    util.assert(
+        util.isFunction(f),
+        'The f passed in customGrad(f) must be a function.');
     return (...inputs: Tensor[]): T => {
-      let gradientsFunc: (dy: T) => Tensor[];
+      util.assert(
+          inputs.every(t => t instanceof Tensor),
+          'The args passed in customGrad(f)(x1, x2,...) must all be tensors');
+      this.customGradientDepth++;
+
+      let gradientsFunc: (dy: T) => Tensor|Tensor[];
       const gradientsMode = true;
       const result = tidy(f.name, () => {
         const {value, gradFunc} = f(...inputs);
+        util.assert(
+            value instanceof Tensor,
+            'The function f passed in customGrad(f) must return an object ' +
+            'where `obj.value` is a tensor');
+        util.assert(
+            util.isFunction(gradFunc),
+            'The function f passed in customGrad(f) must return an object ' +
+            'where `obj.gradFunc` is a function.');
         gradientsFunc = gradFunc;
         return value;
       }, gradientsMode);
 
       this.customGradientDepth--;
 
       if (this.shouldRecord()) {
-        this.addTapeNode(inputs, result, gradientsFunc);
+        const gradFunc = (dy: T): Tensor[] => {
+          const res = gradientsFunc(dy);
+          const grads: Tensor[] = Array.isArray(res) ? res : [res];
+          util.assert(
+              grads.length === inputs.length,
+              'The function f passed in customGrad(f) must return an object ' +
+              'where `obj.gradFunc` is a function that returns the same ' +
+              'number of tensors as inputs passed to f(...).');
+          util.assert(
+              grads.every(t => t instanceof Tensor),
+              'The function f passed in customGrad(f) must return an object ' +
+              'where `obj.gradFunc` is a function that returns a list of ' +
+              'only tensors.');
+          return grads;
+        };
+        this.addTapeNode(inputs, result, gradFunc);
       }
       return result;
     };
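
A sketch of what the new wrapper around gradientsFunc catches (same assumed tf names as in the earlier sketch): a gradFunc that returns the wrong number of gradients now fails fast with a descriptive assertion when the gradient is computed, instead of failing later in a less obvious way:

const badMul = tf.customGrad((a: tf.Tensor, b: tf.Tensor) => ({
  value: a.mul(b),
  // One gradient for two inputs: the wrapper asserts
  // grads.length === inputs.length and throws during backprop.
  gradFunc: (dy: tf.Tensor) => dy.mul(b)
}));
tf.grads(badMul)([tf.scalar(2), tf.scalar(3)]);  // asserts here
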