
Commit
add adamax
jacobjinkelly committed Apr 19, 2020
1 parent 3ca7f6e commit 8078973
Showing 1 changed file with 35 additions and 0 deletions.
jax/experimental/optimizers.py: 35 additions, 0 deletions
@@ -382,6 +382,41 @@ def get_params(state):
    return x
  return init, update, get_params


@optimizer
def adamax(step_size, b1=0.9, b2=0.999, eps=1e-8):
  """Construct optimizer triple for AdaMax (a variant of Adam based on infinity norm).

  Args:
    step_size: positive scalar, or a callable representing a step size schedule
      that maps the iteration index to a positive scalar.
    b1: optional, a positive scalar value for beta_1, the exponential decay rate
      for the first moment estimates (default 0.9).
    b2: optional, a positive scalar value for beta_2, the exponential decay rate
      for the exponentially weighted infinity norm (default 0.999).
    eps: optional, a positive scalar value for epsilon, a small constant for
      numerical stability (default 1e-8).

  Returns:
    An (init_fun, update_fun, get_params) triple.
  """
  step_size = make_schedule(step_size)
  def init(x0):
    m0 = np.zeros_like(x0)
    u0 = np.zeros_like(x0)
    return x0, m0, u0
  def update(i, g, state):
    x, m, u = state
    m = (1 - b1) * g + b1 * m  # First moment estimate.
    u = np.maximum(b2 * u, np.abs(g))  # Update exponentially weighted infinity norm.
    x = x - (step_size(i) / (1 - b1 ** (i + 1))) * m / (u + eps)  # Bias-corrected step.
    return x, m, u
  def get_params(state):
    x, m, u = state
    return x
  return init, update, get_params


@optimizer
def sm3(step_size, momentum=0.9):
"""Construct optimizer triple for SM3.
Expand Down
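
For reference, the added function implements Algorithm 2 (AdaMax) from Kingma & Ba (2015): m_t = b1 * m_{t-1} + (1 - b1) * g_t, u_t = max(b2 * u_{t-1}, |g_t|), and x_t = x_{t-1} - (alpha / (1 - b1^t)) * m_t / u_t, with eps added to the denominator for numerical stability. A minimal usage sketch of the new triple follows; the quadratic loss is a hypothetical example, not part of this change:

import jax.numpy as jnp
from jax import grad
from jax.experimental.optimizers import adamax

def loss(x):
  return jnp.sum(x ** 2)  # Hypothetical toy objective.

init_fun, update_fun, get_params = adamax(step_size=1e-2)
opt_state = init_fun(jnp.ones(3))
for i in range(200):
  g = grad(loss)(get_params(opt_state))
  opt_state = update_fun(i, g, opt_state)
params = get_params(opt_state)  # Parameters driven toward zero.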
