Merge pull request jax-ml#2763 from jacobjinkelly/adamax
Add AdaMax optimizer
mattjj authored Apr 21, 2020
2 parents 4433ddf + 61fc2bf commit f527ed4
Showing 2 changed files with 42 additions and 0 deletions.
35 changes: 35 additions & 0 deletions jax/experimental/optimizers.py
@@ -382,6 +382,41 @@ def get_params(state):
    return x
  return init, update, get_params


@optimizer
def adamax(step_size, b1=0.9, b2=0.999, eps=1e-8):
  """Construct optimizer triple for AdaMax (a variant of Adam based on infinity norm).

  Args:
    step_size: positive scalar, or a callable representing a step size schedule
      that maps the iteration index to positive scalar.
    b1: optional, a positive scalar value for beta_1, the exponential decay rate
      for the first moment estimates (default 0.9).
    b2: optional, a positive scalar value for beta_2, the exponential decay rate
      for the second moment estimates (default 0.999).
    eps: optional, a positive scalar value for epsilon, a small constant for
      numerical stability (default 1e-8).

  Returns:
    An (init_fun, update_fun, get_params) triple.
  """
  step_size = make_schedule(step_size)
  def init(x0):
    m0 = np.zeros_like(x0)
    u0 = np.zeros_like(x0)
    return x0, m0, u0
  def update(i, g, state):
    x, m, u = state
    m = (1 - b1) * g + b1 * m  # First moment estimate.
    u = np.maximum(b2 * u, np.abs(g))  # Update exponentially weighted infinity norm.
    x = x - (step_size(i) / (1 - b1 ** (i + 1))) * m / (u + eps)
    return x, m, u
  def get_params(state):
    x, m, u = state
    return x
  return init, update, get_params


@optimizer
def sm3(step_size, momentum=0.9):
"""Construct optimizer triple for SM3.
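For reference, the update function added above corresponds to the AdaMax rule (Algorithm 2 of Kingma & Ba, 2015), with an eps term in the denominator that appears to be an implementation choice for numerical stability; the paper's version divides by u_t alone. A sketch of the per-step update in LaTeX, with t = i + 1 (iterations counted from 1):

\begin{aligned}
m_t &= \beta_1\, m_{t-1} + (1 - \beta_1)\, g_t && \text{(biased first moment estimate)} \\
u_t &= \max\bigl(\beta_2\, u_{t-1},\; |g_t|\bigr) && \text{(exponentially weighted infinity norm)} \\
x_t &= x_{t-1} - \frac{\alpha_t}{1 - \beta_1^{t}} \cdot \frac{m_t}{u_t + \epsilon} && \text{(bias-corrected parameter step)}
\end{aligned}
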
7 changes: 7 additions & 0 deletions tests/optimizers_test.py
@@ -149,6 +149,13 @@ def loss(xs):
    x0 = (np.ones(2), np.ones((2, 2)))
    self._CheckOptimizer(optimizers.sm3, loss, x0, num_iters, step_size)

  def testAdaMaxVector(self):
    def loss(x): return np.dot(x, x)
    x0 = np.ones(2)
    num_iters = 100
    step_size = 0.1
    self._CheckOptimizer(optimizers.adamax, loss, x0, num_iters, step_size)

  def testSgdVectorExponentialDecaySchedule(self):
    def loss(x): return np.dot(x, x)
    x0 = np.ones(2)
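Beyond the _CheckOptimizer-based test above, the following is a minimal standalone sketch of how the new optimizer is used through the jax.experimental.optimizers triple API. The quadratic loss mirrors the test; the step size, iteration count, and printout are illustrative assumptions, not part of the commit.

# Minimal usage sketch for the new adamax optimizer (illustrative, not from the commit).
import jax.numpy as jnp
from jax import grad, jit
from jax.experimental import optimizers

def loss(x):
  return jnp.dot(x, x)  # Simple quadratic bowl with its minimum at zero.

# adamax returns the usual (init_fun, update_fun, get_params) triple.
opt_init, opt_update, get_params = optimizers.adamax(step_size=0.1)
opt_state = opt_init(jnp.ones(2))

@jit
def step(i, opt_state):
  g = grad(loss)(get_params(opt_state))
  return opt_update(i, g, opt_state)

for i in range(100):
  opt_state = step(i, opt_state)

print(get_params(opt_state))  # Should be close to the zero vector.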
