diff --git a/jax/experimental/optimizers.py b/jax/experimental/optimizers.py
index 7970825fe750..08dc2c98c465 100644
--- a/jax/experimental/optimizers.py
+++ b/jax/experimental/optimizers.py
@@ -382,6 +382,41 @@ def get_params(state):
     return x
   return init, update, get_params
 
+
+@optimizer
+def adamax(step_size, b1=0.9, b2=0.999, eps=1e-8):
+  """Construct optimizer triple for AdaMax (a variant of Adam based on infinity norm).
+
+  Args:
+    step_size: positive scalar, or a callable representing a step size schedule
+      that maps the iteration index to positive scalar.
+    b1: optional, a positive scalar value for beta_1, the exponential decay rate
+      for the first moment estimates (default 0.9).
+    b2: optional, a positive scalar value for beta_2, the exponential decay rate
+      for the second moment estimates (default 0.999).
+    eps: optional, a positive scalar value for epsilon, a small constant for
+      numerical stability (default 1e-8).
+
+  Returns:
+    An (init_fun, update_fun, get_params) triple.
+  """
+  step_size = make_schedule(step_size)
+  def init(x0):
+    m0 = np.zeros_like(x0)
+    u0 = np.zeros_like(x0)
+    return x0, m0, u0
+  def update(i, g, state):
+    x, m, u = state
+    m = (1 - b1) * g + b1 * m  # First moment estimate.
+    u = np.maximum(b2 * u, np.abs(g))  # Update exponentially weighted infinity norm.
+    x = x - (step_size(i) / (1 - b1 ** (i + 1))) * m / (u + eps)
+    return x, m, u
+  def get_params(state):
+    x, m, u = state
+    return x
+  return init, update, get_params
+
+
 @optimizer
 def sm3(step_size, momentum=0.9):
   """Construct optimizer triple for SM3.
diff --git a/tests/optimizers_test.py b/tests/optimizers_test.py
index 6a994e90ca5c..0df6ade29d1d 100644
--- a/tests/optimizers_test.py
+++ b/tests/optimizers_test.py
@@ -149,6 +149,13 @@ def loss(xs):
     x0 = (np.ones(2), np.ones((2, 2)))
     self._CheckOptimizer(optimizers.sm3, loss, x0, num_iters, step_size)
 
+  def testAdaMaxVector(self):
+    def loss(x): return np.dot(x, x)
+    x0 = np.ones(2)
+    num_iters = 100
+    step_size = 0.1
+    self._CheckOptimizer(optimizers.adamax, loss, x0, num_iters, step_size)
+
   def testSgdVectorExponentialDecaySchedule(self):
     def loss(x): return np.dot(x, x)
     x0 = np.ones(2)
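
A minimal usage sketch (not part of the patch) of how the new adamax triple would plug into the existing jax.experimental.optimizers API once this change is applied; the toy quadratic loss, step size, and iteration count are illustrative choices, mirroring the added test rather than any recommended settings:

    # Illustrative only: drives the adamax triple added by this patch on a toy loss.
    import jax.numpy as np
    from jax import grad
    from jax.experimental import optimizers

    def loss(x):
      return np.dot(x, x)  # simple convex objective, as in testAdaMaxVector

    init_fun, update_fun, get_params = optimizers.adamax(step_size=0.1)
    opt_state = init_fun(np.ones(2))  # state holds params, first moment m, and infinity norm u

    for i in range(100):
      g = grad(loss)(get_params(opt_state))   # gradient at the current parameters
      opt_state = update_fun(i, g, opt_state) # one AdaMax step (bias-corrected m, max-decayed u)

    print(get_params(opt_state))  # parameters should approach the origin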