This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Commit 27db87b

manrajgrover authored and Nikhil Thorat committed
AdamaxOpt: Adds Adamax Optimizer with eager mode (#691)
* Adamax: Migrated to math
* Adamax Optimizer: Adds eager gradient
* Adamax Optimizer: Test cases updated
* Merge branch 'master' into add-adamax-opt
* Adamax Optimizer: Renames variables
* Adamax Opt: Now uses variables
* Merge branch 'master' into add-adamax-opt
* Adamax Opt: Simplifies code
* Adamax Optimizer: Now allows decay
* Merge branch 'master' into add-adamax-opt
* Adamax Opt: Graphs support decay
* Merge branch 'add-adamax-opt' of https://github.com/ManrajGrover/deeplearnjs into add-adamax-opt
* Merge branch 'master' into add-adamax-opt
* Merge branch 'master' into add-adamax-opt
* Merge branch 'master' into add-adamax-opt
* Adamax Opt: Fixes test cases
* Merge branch 'master' into add-adamax-opt
1 parent 9b3fa6c commit 27db87b
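
For context, this is the Adamax update rule from Kingma & Ba (2015), "Adam: A Method for Stochastic Optimization", which the new optimizer implements (plus a learning-rate decay term visible in the code below):

```latex
m_t      = \beta_1 m_{t-1} + (1 - \beta_1) g_t    % first moment
u_t      = \max(\beta_2 u_{t-1}, |g_t|)           % weighted infinity norm
\alpha_t = \alpha / (1 + \mathrm{decay} \cdot t)  % decayed learning rate
\theta_t = \theta_{t-1} - \frac{\alpha_t}{1 - \beta_1^t} \cdot \frac{m_t}{u_t + \epsilon}
```

Note that the classic formulation omits the epsilon in the denominator; this implementation adds it for numerical safety.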

File tree

6 files changed: +431 lines, −90 lines


src/graph/optimizers/adamax_optimizer_test.ts

Lines changed: 0 additions & 87 deletions
This file was deleted.

src/index.ts

Lines changed: 1 addition & 1 deletion
```diff
@@ -38,7 +38,6 @@ export {Graph, SymbolicTensor} from './graph/graph';
 export {GraphRunner, GraphRunnerEventObserver, MetricReduction} from './graph/graph_runner';
 // tslint:disable-next-line:max-line-length
 export {ConstantInitializer, Initializer, OnesInitializer, RandomNormalInitializer, RandomTruncatedNormalInitializer, RandomUniformInitializer, TensorInitializer, VarianceScalingInitializer, ZerosInitializer} from './graph/initializers';
-export {AdamaxOptimizer} from './graph/optimizers/adamax_optimizer';
 export {CostReduction, FeedEntry, Session} from './graph/session';
 export {MathBackendCPU, NDArrayMathCPU} from './kernels/backend_cpu';
 // tslint:disable-next-line:max-line-length
@@ -51,6 +50,7 @@ export {LSTMCell} from './ops/lstm';
 export {AdadeltaOptimizer} from './optimizers/adadelta_optimizer';
 export {AdagradOptimizer} from './optimizers/adagrad_optimizer';
 export {AdamOptimizer} from './optimizers/adam_optimizer';
+export {AdamaxOptimizer} from './optimizers/adamax_optimizer';
 export {MomentumOptimizer} from './optimizers/momentum_optimizer';
 export {Optimizer} from './optimizers/optimizer';
 export {RMSPropOptimizer} from './optimizers/rmsprop_optimizer';
```

src/optimizers/adamax_optimizer.ts

Lines changed: 219 additions & 0 deletions
```ts
/**
 * @license
 * Copyright 2018 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

import {ENV} from '../environment';
import {keep, tidy} from '../globals';
import {Node} from '../graph/graph';
import {SessionRuntime} from '../graph/session';
// tslint:disable-next-line:max-line-length
import {SummedTensorArrayMap, TensorArrayMap} from '../graph/tensor_array_map';
import {NDArrayMath} from '../math';
import {scalar, zerosLike} from '../ops/ops';
import {Scalar, Tensor, Variable} from '../tensor';
import {variable} from '../tensor';
import {NamedVariableMap} from '../types';

import {Optimizer} from './optimizer';

export class AdamaxOptimizer extends Optimizer {
  private c: Scalar;
  private eps: Scalar;
  private accBeta1: Variable;
  private beta1: Scalar;
  private beta2: Scalar;
  private decay: Scalar;
  private oneMinusBeta1: Scalar;
  private one: Scalar;
  private iteration: Variable;

  private accumulatedFirstMoment: NamedVariableMap = {};
  private accumulatedWeightedInfNorm: NamedVariableMap = {};

  constructor(
      protected learningRate: number, beta1: number, beta2: number,
      epsilon = 1e-8, decay = 0.0,
      /** @deprecated */ specifiedVariableList?: Node[]) {
    super(learningRate, specifiedVariableList);
    this.c = keep(scalar(-learningRate));
    this.eps = keep(scalar(epsilon));
    // b1, b2 keep initial value of beta* hyperparameters.
    this.beta1 = keep(scalar(beta1));
    this.beta2 = keep(scalar(beta2));

    this.decay = keep(scalar(decay));

    tidy(() => {
      this.iteration = variable(scalar(0));
      this.accBeta1 = variable(scalar(beta1));
    });

    this.oneMinusBeta1 = keep(scalar(1 - beta1));
    this.one = keep(scalar(1));
  }

  // Eager-mode path: applies one Adamax step to each registered variable,
  // given a map from variable name to gradient.
  applyGradients(variableGradients: NamedVariableMap) {
    tidy(() => {
      const oneMinusAccBeta1 = this.one.sub(this.accBeta1);
      const lr = this.c.div(this.one.add(this.decay.mul(this.iteration)));

      for (const variableName in variableGradients) {
        const value = ENV.engine.registeredVariables[variableName];
        // Lazily create the per-variable accumulators on first use.
        if (this.accumulatedFirstMoment[variableName] == null) {
          const trainable = false;
          this.accumulatedFirstMoment[variableName] =
              variable(zerosLike(value), trainable);
        }
        if (this.accumulatedWeightedInfNorm[variableName] == null) {
          const trainable = false;
          this.accumulatedWeightedInfNorm[variableName] =
              variable(zerosLike(value), trainable);
        }

        const gradient = variableGradients[variableName];
        const firstMoment = this.accumulatedFirstMoment[variableName];
        const weightedInfNorm = this.accumulatedWeightedInfNorm[variableName];

        const newFirstMoment =
            this.beta1.mul(firstMoment).add(this.oneMinusBeta1.mul(gradient));

        const ut0 = this.beta2.mul(weightedInfNorm);
        const ut1 = gradient.abs();

        const newWeightedInfNorm = ut0.maximum(ut1);

        this.accumulatedFirstMoment[variableName].assign(newFirstMoment);
        this.accumulatedWeightedInfNorm[variableName].assign(
            newWeightedInfNorm);

        const newValue =
            lr.div(oneMinusAccBeta1)
                .mul(newFirstMoment.div(this.eps.add(newWeightedInfNorm)))
                .add(value);

        value.assign(newValue);
      }

      this.iteration.assign(this.iteration.add(this.one));
      this.accBeta1.assign(this.accBeta1.mul(this.beta1));
    });
  }

  // Deprecated graph-mode path: initializes the per-node accumulators.
  beforeBatch(
      math: NDArrayMath, batchSize: number, runtime: SessionRuntime,
      activationArrayMap: TensorArrayMap,
      gradientArrayMap: SummedTensorArrayMap) {
    super.beforeBatch(
        math, batchSize, runtime, activationArrayMap, gradientArrayMap);

    if (this.firstMomentGraph.size() === 0) {
      this.variableNodes.forEach(node => {
        this.firstMomentGraph.set(node.output, Tensor.zeros(node.output.shape));
      });
    }

    if (this.weightedInfNormGraph.size() === 0) {
      this.variableNodes.forEach(node => {
        this.weightedInfNormGraph.set(
            node.output, Tensor.zeros(node.output.shape));
      });
    }
  }

  // Deprecated graph-mode path: performs the same Adamax update per node.
  afterBatch(
      math: NDArrayMath, batchSize: number, runtime: SessionRuntime,
      activationArrayMap: TensorArrayMap,
      gradientArrayMap: SummedTensorArrayMap) {
    tidy(() => {
      const lr = this.cGraph.div(this.one.add(this.decay.mul(this.iteration)));

      this.variableNodes.forEach(node => {
        const oldVariable = activationArrayMap.get(node.output);

        const gradient = this.variableGradients.get(node.output);
        const oldFirstMoment = this.firstMomentGraph.get(node.output);
        const oldWeightedInfNorm = this.weightedInfNormGraph.get(node.output);

        const newFirstMoment = math.scaledArrayAdd(
            this.beta1, oldFirstMoment, this.oneMinusBeta1, gradient);

        const ut0 = this.beta2.mul(oldWeightedInfNorm);
        const ut1 = gradient.abs();

        const newWeightedInfNorm = ut0.maximum(ut1);

        const variable = math.scaledArrayAdd(
            this.one, oldVariable, lr.div(this.one.sub(this.accBeta1)),
            newFirstMoment.div(this.eps.add(newWeightedInfNorm)));

        activationArrayMap.set(node.output, keep(variable));
        node.data = variable;

        this.firstMomentGraph.set(node.output, keep(newFirstMoment));
        this.weightedInfNormGraph.set(node.output, keep(newWeightedInfNorm));

        oldVariable.dispose();
        gradient.dispose();
        oldFirstMoment.dispose();
        oldWeightedInfNorm.dispose();
      });

      this.iteration.assign(this.iteration.add(this.one));
      this.accBeta1.assign(this.accBeta1.mul(this.beta1));
    });

    this.variableGradients.dispose();
    this.variableGradients = new TensorArrayMap();
  }

  dispose() {
    super.dispose();
    this.c.dispose();
    this.eps.dispose();
    this.accBeta1.dispose();
    this.beta1.dispose();
    this.beta2.dispose();
    this.oneMinusBeta1.dispose();

    this.decay.dispose();
    this.iteration.dispose();

    this.one.dispose();

    if (this.firstMomentGraph != null) {
      this.firstMomentGraph.dispose();
    }

    if (this.weightedInfNormGraph != null) {
      this.weightedInfNormGraph.dispose();
    }

    if (this.accumulatedFirstMoment != null) {
      Object.keys(this.accumulatedFirstMoment)
          .forEach(name => this.accumulatedFirstMoment[name].dispose());
    }

    if (this.accumulatedWeightedInfNorm != null) {
      Object.keys(this.accumulatedWeightedInfNorm)
          .forEach(name => this.accumulatedWeightedInfNorm[name].dispose());
    }
  }

  // Average of 1st gradient (first moment)
  private firstMomentGraph = new TensorArrayMap();
  // Average of exponentially weighted infinity norm
  private weightedInfNormGraph = new TensorArrayMap();
}
```
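
The class now serves both execution modes: `applyGradients` is the eager path, resolving each gradient's variable through `ENV.engine.registeredVariables`, while `beforeBatch`/`afterBatch` keep the deprecated graph path working. A rough sketch of one eager step, using only symbols the file itself imports; the gradient map is hand-built purely for illustration (real training would obtain it from the engine's gradient computation):

```ts
import {scalar} from './ops/ops';
import {variable} from './tensor';
import {AdamaxOptimizer} from './optimizers/adamax_optimizer';

// Creating a variable registers it with the engine under an
// auto-generated name, exposed as `w.name`.
const w = variable(scalar(3));

const optimizer = new AdamaxOptimizer(0.1, 0.9, 0.999);

// Pretend d(loss)/dw = 2 and take a single Adamax step.
// (Cast because NamedVariableMap is typed for variables, not plain tensors.)
optimizer.applyGradients({[w.name]: scalar(2)} as any);

console.log(w.dataSync());  // w has moved opposite the gradient's sign
```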
