-
Notifications
You must be signed in to change notification settings - Fork 78
Expand file tree
/
Copy pathjax_submission_base.py
More file actions
181 lines (162 loc) · 4.91 KB
/
jax_submission_base.py
File metadata and controls
181 lines (162 loc) · 4.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
"""Update submission function in Jax."""
from typing import Any, Dict, List, Optional, Tuple
import jax
import jax.numpy as jnp
import optax
from algoperf import jax_sharding_utils, spec
_GRAD_CLIP_EPS = 1e-6
def train_step(
    workload,
    opt_update_fn,
    model_state,
    optimizer_state,
    current_param_container,
    batch,
    rng,
    grad_clip,
    label_smoothing,
):
  """Run one forward/backward pass and apply a single optimizer update.

  The summed training loss is normalized by the number of valid examples
  reported by ``workload.loss_fn``. The gradients are normalized the same way.
  If ``grad_clip`` is not None, the gradients are rescaled so that their global
  L2 norm does not exceed it.

  Args:
    workload: Object exposing ``model_fn`` and ``loss_fn``.
    opt_update_fn: Callable ``(grads, opt_state, params) -> (updates,
      new_opt_state)``.
    model_state: Auxiliary model state passed to ``model_fn``.
    optimizer_state: Current optimizer state passed to ``opt_update_fn``.
    current_param_container: Current model parameters.
    batch: Mapping with a ``'targets'`` entry and an optional ``'weights'``
      mask.
    rng: Random key forwarded to ``model_fn``.
    grad_clip: Maximum global gradient norm, or None to disable clipping.
    label_smoothing: Label-smoothing value forwarded to ``loss_fn``.

  Returns:
    Tuple ``(new_optimizer_state, updated_params, new_model_state, loss,
    grad_norm)``. ``grad_norm`` is the norm of the averaged gradient measured
    before clipping.
  """

  def _compute_loss(params):
    """Return the summed loss plus (n_valid_examples, new_model_state)."""
    logits, updated_model_state = workload.model_fn(
        params,
        batch,
        model_state,
        spec.ForwardPassMode.TRAIN,
        rng,
        update_batch_norm=True,
    )
    losses = workload.loss_fn(
        label_batch=batch['targets'],
        logits_batch=logits,
        mask_batch=batch.get('weights'),
        label_smoothing=label_smoothing,
    )
    return losses['summed'], (losses['n_valid_examples'], updated_model_state)

  value_and_grad_fn = jax.value_and_grad(_compute_loss, has_aux=True)
  (total_loss, (num_valid, new_model_state)), raw_grads = value_and_grad_fn(
      current_param_container
  )

  # Turn the summed loss and summed gradients into per-example means.
  mean_loss = total_loss / num_valid
  mean_grads = jax.tree.map(lambda g: g / num_valid, raw_grads)

  # Global L2 norm across every gradient leaf; reported before clipping.
  squared_norm = sum(jnp.sum(leaf**2) for leaf in jax.tree.leaves(mean_grads))
  grad_norm = jnp.sqrt(squared_norm)

  if grad_clip is not None:
    # The factor is clamped to [0, 1], so gradients are only ever shrunk,
    # never amplified. The epsilon keeps the division finite at zero norm.
    scale = jax.lax.clamp(0.0, grad_clip / (grad_norm + _GRAD_CLIP_EPS), 1.0)
    mean_grads = jax.tree.map(lambda g: g * scale, mean_grads)

  updates, new_optimizer_state = opt_update_fn(
      mean_grads, optimizer_state, current_param_container
  )
  new_params = optax.apply_updates(current_param_container, updates)
  return new_optimizer_state, new_params, new_model_state, mean_loss, grad_norm
def update_params(
    workload: spec.Workload,
    current_param_container: spec.ParameterContainer,
    current_params_types: spec.ParameterTypeTree,
    model_state: spec.ModelAuxiliaryState,
    hyperparameters: spec.Hyperparameters,
    batch: Dict[str, spec.Tensor],
    loss_type: spec.LossType,
    optimizer_state: spec.OptimizerState,
    eval_results: List[Tuple[int, float]],
    global_step: int,
    rng: spec.RandomState,
    train_state: Optional[Dict[str, Any]] = None,
) -> spec.UpdateReturn:
  """Return (updated_optimizer_state, updated_params, updated_model_state).

  ``optimizer_state`` is expected to be a pair ``(inner_state,
  opt_update_fn)``. It is unpacked here and re-packed in the return value.
  ``train_step`` is run under ``jax.jit``, with the batch sharded along its
  batch dimension and everything else replicated. ``model_state``, the
  inner optimizer state and ``current_param_container`` are donated to the
  jitted call.

  Optional hyperparameters:
    label_smoothing: Defaults to 0.0 when absent.
    grad_clip: Defaults to None (no clipping) when absent.

  Loss and gradient norm are logged for steps <= 100 and every 500th step,
  but only when ``workload.metrics_logger`` is set.
  """
  # Arguments this submission does not use.
  del current_params_types, loss_type, train_state, eval_results

  inner_optimizer_state, opt_update_fn = optimizer_state

  label_smoothing = getattr(hyperparameters, 'label_smoothing', 0.0)
  grad_clip = getattr(hyperparameters, 'grad_clip', None)

  replicated = jax_sharding_utils.get_replicate_sharding()  # no partitioning
  batch_sharded = jax_sharding_utils.get_batch_dim_sharding()  # split on batch

  # workload and opt_update_fn (argnums 0 and 1) are static, so they get
  # no sharding entry.
  in_shardings = (
      replicated,  # model_state
      replicated,  # optimizer_state
      replicated,  # current_param_container
      batch_sharded,  # batch
      replicated,  # rng
      replicated,  # grad_clip
      replicated,  # label_smoothing
  )
  # All five train_step outputs (opt state, params, model state, loss,
  # grad_norm) are replicated.
  out_shardings = (replicated,) * 5

  step_fn = jax.jit(
      train_step,
      static_argnums=(0, 1),
      donate_argnums=(2, 3, 4),
      in_shardings=in_shardings,
      out_shardings=out_shardings,
  )
  step_outputs = step_fn(
      workload,
      opt_update_fn,
      model_state,
      inner_optimizer_state,
      current_param_container,
      batch,
      rng,
      grad_clip,
      label_smoothing,
  )
  new_inner_state, new_params, new_model_state, loss, grad_norm = step_outputs

  should_log = global_step <= 100 or global_step % 500 == 0
  if should_log and workload.metrics_logger is not None:
    scalars = {
        'loss': loss.item(),
        'grad_norm': grad_norm.item(),
    }
    workload.metrics_logger.append_scalar_metrics(scalars, global_step)

  return (new_inner_state, opt_update_fn), new_params, new_model_state
def prepare_for_eval(
    workload: spec.Workload,
    current_param_container: spec.ParameterContainer,
    current_params_types: spec.ParameterTypeTree,
    model_state: spec.ModelAuxiliaryState,
    hyperparameters: spec.Hyperparameters,
    loss_type: spec.LossType,
    optimizer_state: spec.OptimizerState,
    eval_results: List[Tuple[int, float]],
    global_step: int,
    rng: spec.RandomState,
) -> spec.UpdateReturn:
  """Return (optimizer_state, current_param_container, model_state) unchanged.

  This submission does no pre-evaluation work. The optimizer state,
  parameters and model state are passed through exactly as received, and
  all other arguments are ignored.

  Returns:
    A 3-tuple ``(optimizer_state, current_param_container, model_state)``.
  """
  # Arguments not needed because no pre-eval transformation is applied.
  del workload
  del hyperparameters
  del current_params_types
  del loss_type
  del eval_results
  del global_step
  del rng
  return (optimizer_state, current_param_container, model_state)