Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Commit 611838b

Browse files
author
Nikhil Thorat
authored
Add variable() to chain API. (#730)
1 parent 8ea6dd8 commit 611838b

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

src/tensor.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -879,6 +879,11 @@ export class Tensor<R extends Rank = Rank> {
879879
return ops.localResponseNormalization(
880880
this, radius, bias, alpha, beta, normRegion);
881881
}
882+
883+
variable(trainable = true, name?: string, dtype?: DataType): Variable<R> {
884+
this.throwIfDisposed();
885+
return Variable.variable(this, trainable, name, dtype);
886+
}
882887
}
883888

884889
/** @doclink Tensor */

src/variable_test.ts

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,14 @@ describeWithFlags('variable', ALL_ENVS, () => {
3030
expectArraysClose(v, [4, 5, 6]);
3131
});
3232

33+
it('simple chain assign', () => {
34+
const v = dl.tensor1d([1, 2, 3]).variable();
35+
expectArraysClose(v, [1, 2, 3]);
36+
37+
v.assign(dl.tensor1d([4, 5, 6]));
38+
expectArraysClose(v, [4, 5, 6]);
39+
});
40+
3341
it('default names are unique', () => {
3442
const v = variable(dl.tensor1d([1, 2, 3]));
3543
expect(v.name).not.toBeNull();
@@ -50,13 +58,19 @@ describeWithFlags('variable', ALL_ENVS, () => {
5058
.toThrowError();
5159
});
5260

53-
it('math ops can take variables', () => {
61+
it('ops can take variables', () => {
5462
const value = dl.tensor1d([1, 2, 3]);
5563
const v = variable(value);
5664
const res = dl.sum(v);
5765
expectArraysClose(res, [6]);
5866
});
5967

68+
it('chained variables works', () => {
69+
const v = dl.tensor1d([1, 2, 3]).variable();
70+
const res = dl.sum(v);
71+
expectArraysClose(res, [6]);
72+
});
73+
6074
it('variables are not affected by tidy', () => {
6175
let v: Variable<Rank.R1>;
6276
expect(dl.memory().numTensors).toBe(0);

0 commit comments

Comments
 (0)