-
(originally asked by @debidatta) How can I implement an Optax optimizer that uses different learning rates for different layers?
Replies: 3 comments
-
(originally answered by @levskaya) Adapted from #1453. This can easily be done with optax.multi_transform. For Flax it can be very handy to use flax.traverse_util.ModelParamTraversal to create the second parameter:

def flattened_traversal(fn):
  """Returns function that is called with `(path, param)` instead of pytree."""
  def mask(tree):
    flat = flax.traverse_util.flatten_dict(tree)
    return flax.traverse_util.unflatten_dict(
        {k: fn(k, v) for k, v in flat.items()})
  return mask

# Specify layer-wise learning rates.
lrs = {'Dense_0': 0.1, 'Dense_1': 0.2, 'head': 0.3}
label_fn = flattened_traversal(lambda path, _: path[0])
tx = optax.multi_transform(
    {name: optax.sgd(lr) for name, lr in lrs.items()}, label_fn)

Here label_fn labels every parameter with the name of its top-level module, and optax.multi_transform applies the matching sgd transform to each group of parameters.

Full example:

import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
import optax
class Model(nn.Module):
  num_layers: int
  depth: int

  @nn.compact
  def __call__(self, x):
    for i in range(self.num_layers - 1):
      x = nn.relu(nn.Dense(self.depth)(x))
    return nn.Dense(self.depth, name='head')(x)
model = Model(num_layers=3, depth=10)
x = jnp.zeros([1, 10])
params = model.init(jax.random.PRNGKey(0), x)['params']
jax.tree_map(jnp.shape, params)
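# For reference (not part of the original answer): the line above should print
# something like
# {'Dense_0': {'bias': (10,), 'kernel': (10, 10)},
#  'Dense_1': {'bias': (10,), 'kernel': (10, 10)},
#  'head': {'bias': (10,), 'kernel': (10, 10)}}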
def flattened_traversal(fn):
  """Returns function that is called with `(path, param)` instead of pytree."""
  def mask(tree):
    flat = flax.traverse_util.flatten_dict(tree)
    return flax.traverse_util.unflatten_dict(
        {k: fn(k, v) for k, v in flat.items()})
  return mask
# Specify layer-wise learning rate.
lrs = {'Dense_0': 0.1, 'Dense_1': 0.2, 'head': 0.3}
label_fn = flattened_traversal(lambda path, _: path[0])
tx = optax.multi_transform(
    {name: optax.sgd(lr) for name, lr in lrs.items()}, label_fn)
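# For reference (not part of the original answer): label_fn maps every parameter
# to the name of its top-level module, so label_fn(params.unfreeze()) evaluates to
# {'Dense_0': {'bias': 'Dense_0', 'kernel': 'Dense_0'},
#  'Dense_1': {'bias': 'Dense_1', 'kernel': 'Dense_1'},
#  'head': {'bias': 'head', 'kernel': 'head'}},
# which is the label pytree optax.multi_transform uses to pick a transform per leaf.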
fake_grads = jax.tree_map(jnp.ones_like, params.unfreeze())
opt_state = tx.init(params.unfreeze())
updates, opt_state = tx.update(fake_grads, opt_state)
jax.tree_map(lambda x: jnp.sum(jnp.abs(x)), updates)
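# For reference (not part of the original answer): with all-ones fake gradients and
# plain SGD, the summed absolute updates printed above come out to roughly
# lr * number_of_parameters per layer, e.g. about 10.0 for the Dense_0 kernel
# (0.1 * 100) and 1.0 for its bias, 20.0 / 2.0 for Dense_1, and 30.0 / 3.0 for head,
# confirming that each layer received its own learning rate.

For completeness, here is a minimal sketch (not from the original answer) of using the optimizer above with real gradients; the sum-of-squares loss is purely illustrative:

# Hypothetical training step reusing model, params, x, tx and opt_state from above.
def loss_fn(p, batch):
  preds = model.apply({'params': p}, batch)
  return jnp.sum(preds ** 2)  # toy loss, for illustration only

grads = jax.grad(loss_fn)(params.unfreeze(), x)
updates, opt_state = tx.update(grads, opt_state)
new_params = optax.apply_updates(params.unfreeze(), updates)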
-
In addition to @andsteing's answer, and after a few discussions with him, we found a glitch in the previous answer while using The previous solution leads to a
This error is raised because
-
A slightly modified version of @andsteing's