
How to implement layer-wise learning rate decay? #2056

Answered by andsteing
andsteing asked this question in Q&A

(originally answered by @levskaya)

Adapted from #1453

This can easily be done with optax.multi_transform.

For Flax it can be very handy to use flax.traverse_util (flatten_dict / unflatten_dict) to build the second argument of optax.multi_transform, i.e. the function that maps every parameter path to a label:

import flax
import optax


def flattened_traversal(fn):
  """Returns function that is called with `(path, param)` instead of pytree."""
  def mask(tree):
    flat = flax.traverse_util.flatten_dict(tree)
    return flax.traverse_util.unflatten_dict(
        {k: fn(k, v) for k, v in flat.items()})
  return mask

# Specify layer-wise learning rate.
lrs = {'Dense_0': 0.1, 'Dense_1': 0.2, 'head': 0.3}
label_fn = flattened_traversal(lambda path, _: path[0])

tx = optax.multi_transform(
    {name: optax.sgd(lr) for name, lr in lrs.items()},
    label_fn)

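To connect this to the "decay" in the question title: the lrs dict above does not have to be written by hand, it can be generated from a base learning rate and a per-layer decay factor. A minimal sketch (base_lr, decay and the layer list are illustrative assumptions, not part of the original answer):

# Hypothetical geometric layer-wise decay: deeper layers get larger rates.
# This dict could be passed to multi_transform in place of the hand-written lrs.
layer_names = ['Dense_0', 'Dense_1', 'head']
base_lr, decay = 0.3, 0.5
decayed_lrs = {name: base_lr * decay**(len(layer_names) - 1 - i)
               for i, name in enumerate(layer_names)}
# -> {'Dense_0': 0.075, 'Dense_1': 0.15, 'head': 0.3}

And a minimal end-to-end sketch of using the resulting tx, assuming a toy model whose top-level parameter collections are named 'Dense_0', 'Dense_1' and 'head' (the model, shapes and dummy gradients below are assumptions for illustration only):

import jax
import jax.numpy as jnp
import flax.linen as nn

class Model(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.relu(nn.Dense(4)(x))          # params live under 'Dense_0'
    x = nn.relu(nn.Dense(4)(x))          # params live under 'Dense_1'
    return nn.Dense(1, name='head')(x)   # params live under 'head'

params = Model().init(jax.random.PRNGKey(0), jnp.ones([1, 8]))['params']
opt_state = tx.init(params)

# Dummy gradients, just to show that each top-level module is updated with
# its own learning rate from lrs.
grads = jax.tree_util.tree_map(jnp.ones_like, params)
updates, opt_state = tx.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)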
Answer selected by marcvanzee