jax_momentum.py
"""Submission file for a SGD with HeavyBall momentum optimizer in Jax."""
from typing import Callable
import jax
import jax.numpy as jnp
import optax
from algoperf import spec
from algorithms.target_setting_algorithms.data_selection import ( # noqa: F401
data_selection,
)
from algorithms.target_setting_algorithms.get_batch_size import ( # noqa: F401
get_batch_size,
)
from algorithms.target_setting_algorithms.jax_submission_base import ( # noqa: F401
  update_params,
)


def init_optimizer_state(
workload: spec.Workload,
model_params: spec.ParameterContainer,
model_state: spec.ModelAuxiliaryState,
hyperparameters: spec.Hyperparameters,
rng: spec.RandomState,
) -> spec.OptimizerState:
"""Creates a Nesterov optimizer and a learning rate schedule."""
del model_params
del model_state
del rng
# Create learning rate schedule.
target_setting_step_hint = int(0.75 * workload.step_hint)
lr_schedule_fn = create_lr_schedule_fn(
target_setting_step_hint, hyperparameters
)
# Create optimizer.
params_zeros_like = jax.tree.map(
lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes
)
opt_init_fn, opt_update_fn = sgd(
learning_rate=lr_schedule_fn,
weight_decay=hyperparameters.weight_decay,
momentum=hyperparameters.beta1,
nesterov=False,
)
optimizer_state = opt_init_fn(params_zeros_like)
  return optimizer_state, opt_update_fn


def create_lr_schedule_fn(
  step_hint: int, hyperparameters: spec.Hyperparameters
) -> Callable[[int], float]:
  """Creates a linear warmup followed by linear decay learning rate schedule."""
warmup_fn = optax.linear_schedule(
init_value=0.0,
end_value=hyperparameters.learning_rate,
transition_steps=hyperparameters.warmup_steps,
)
decay_steps = step_hint - hyperparameters.warmup_steps
polynomial_schedule_fn = optax.polynomial_schedule(
init_value=hyperparameters.learning_rate,
end_value=hyperparameters.learning_rate * hyperparameters.end_factor,
power=1,
transition_steps=int(decay_steps * hyperparameters.decay_steps_factor),
)
lr_schedule_fn = optax.join_schedules(
schedules=[warmup_fn, polynomial_schedule_fn],
boundaries=[hyperparameters.warmup_steps],
)
return lr_schedule_fn
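

# Illustrative sketch (not part of the submission API): evaluating the joined
# schedule returned by `create_lr_schedule_fn` at a few steps, assuming
# hypothetical hyperparameter values. The schedule warms up linearly to
# `learning_rate` over `warmup_steps`, then decays linearly toward
# `learning_rate * end_factor` over `decay_steps_factor` of the remaining
# steps and stays constant afterwards.
def _example_lr_schedule_usage():
  from types import SimpleNamespace

  # Hypothetical hyperparameter values, chosen only for illustration.
  hparams = SimpleNamespace(
    learning_rate=0.1,
    warmup_steps=100,
    end_factor=0.01,
    decay_steps_factor=0.9,
  )
  schedule_fn = create_lr_schedule_fn(step_hint=1000, hyperparameters=hparams)
  # schedule_fn(0) == 0.0, schedule_fn(100) == 0.1, and the value then decays
  # linearly to 0.001 by step 100 + int(900 * 0.9) = 910.
  return [float(schedule_fn(step)) for step in (0, 50, 100, 500, 910, 1000)]
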
# Forked from
# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/optimizers.py.
def sgd(learning_rate, weight_decay, momentum=None, nesterov=False):
  r"""A customizable gradient descent optimizer.

  NOTE: We apply weight decay **before** computing the momentum update.
  This is equivalent to applying WD after for heavy-ball momentum,
  but slightly different when using Nesterov acceleration. This is the same as
  how the Flax optimizers handle weight decay
  https://flax.readthedocs.io/en/latest/_modules/flax/optim/momentum.html.

  Args:
    learning_rate: The learning rate. Expected as the positive learning rate,
      for example `\alpha` in `w -= \alpha * u` (as opposed to `-\alpha`).
    weight_decay: The weight decay hyperparameter.
    momentum: The momentum hyperparameter.
    nesterov: Whether or not to use Nesterov momentum.

  Returns:
    An optax gradient transformation that applies weight decay and then one of
    {SGD, Momentum, Nesterov} updates.
  """
return optax.chain(
optax.add_decayed_weights(weight_decay),
optax.sgd(
learning_rate=learning_rate, momentum=momentum, nesterov=nesterov
),
)
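

# Illustrative sketch (not part of the submission API): applying the `sgd`
# transform above to a toy parameter tree. Because `optax.add_decayed_weights`
# needs the current parameters, they are passed to the update call; the toy
# values below are hypothetical.
def _example_sgd_usage():
  params = {'w': jnp.ones((3,)), 'b': jnp.zeros((3,))}
  grads = {'w': jnp.full((3,), 0.5), 'b': jnp.full((3,), 0.1)}

  opt_init_fn, opt_update_fn = sgd(
    learning_rate=0.1, weight_decay=1e-4, momentum=0.9, nesterov=False
  )
  opt_state = opt_init_fn(params)

  # Weight decay is added to the gradients first, then the heavy-ball momentum
  # update is computed and scaled by the learning rate.
  updates, opt_state = opt_update_fn(grads, opt_state, params)
  new_params = optax.apply_updates(params, updates)
  return new_params, opt_state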