How to fix some layers for transfer learning? #1706
-
Hi, suppose I need to do transfer learning based on a pretrained ResNet50, and I want to keep the first layer fixed (frozen) while allowing the other layers to update. How can I do this in Flax? I noticed #1176, which suggests using https://flax.readthedocs.io/en/latest/flax.optim.html#flax.optim.MultiOptimizer, but I have two concerns. First, flax.optim seems to be in the process of being replaced by optax, and optax does not have a similar API. Second, my project already uses optax, so I would rather not change the code just to freeze a layer; that seems cumbersome, and I think there should be an easier way to do it. The preferred way I can imagine is something like below:
-
Hi @wztdream, note that optax is now the recommended optimizer API (see here). You can freeze a subset of your params by using `optax.multi_transform`. One way of achieving this is to create a mask of your params that assigns one label to trainable parameters and another label to frozen parameters. Consider this simple parameter tree:
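(A minimal stand-in; the module names below are invented for this sketch.)

```python
import jax.numpy as jnp

# Hypothetical parameter tree; the module names are invented for this sketch.
params = {
    'frozen_backbone': {'kernel': jnp.ones((3, 3)), 'bias': jnp.zeros((3,))},
    'trainable_head':  {'kernel': jnp.ones((3, 2)), 'bias': jnp.zeros((2,))},
}
```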
We want to freeze the params of the layers whose names start with `frozen`, so we assign them the label `zero` (the name is arbitrary). The other parameters are assigned the label `adam` (also arbitrary).
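A minimal sketch of a `create_mask` helper that does this labeling, assuming `params` is a plain nested dict keyed by module name as in the tree above:

```python
from flax import traverse_util

def create_mask(params, label_fn):
    # Flatten the nested dict to {('module', ..., 'kernel'): leaf, ...},
    # label each leaf by its top-level module name, then restore the nesting.
    flat = traverse_util.flatten_dict(params)
    return traverse_util.unflatten_dict(
        {path: ('zero' if label_fn(path[0]) else 'adam') for path in flat}
    )
```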
And then:

```python
import jax
import jax.numpy as jnp
import optax

def zero_grads():
    # from https://github.com/deepmind/optax/issues/159#issuecomment-896459491
    def init_fn(_):
        return ()
    def update_fn(updates, state, params=None):
        return jax.tree_map(jnp.zeros_like, updates), ()
    return optax.GradientTransformation(init_fn, update_fn)

tx = optax.multi_transform({'adam': optax.adam(0.1), 'zero': zero_grads()},
                           create_mask(params, lambda s: s.startswith('frozen')))
```

This simply means that all parameters with the label `adam` are optimized with Adam, while all parameters with the label `zero` receive zero gradients and therefore stay unchanged. The only thing you have to do for this to work is give the modules you want to freeze a name that starts with `frozen`.

I recently did something similar, so I created a Colab with a full example: https://colab.research.google.com/drive/1g_pt2Rc3bv6H6qchvGHD-BpgF-Pt4vrC#scrollTo=WWHlukuvIpXb

Hope that helps!
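For completeness, a minimal sketch of how the resulting `tx` could be wired into a training loop; the `train_step` wiring and `loss_fn` below are assumed here, not part of the reply above:

```python
import jax
import optax

opt_state = tx.init(params)

@jax.jit
def train_step(params, opt_state, batch):
    # loss_fn(params, batch) -> scalar is assumed to be defined elsewhere.
    grads = jax.grad(loss_fn)(params, batch)
    updates, opt_state = tx.update(grads, opt_state, params)
    # Leaves labelled 'zero' receive zero updates, so the frozen layers stay unchanged.
    params = optax.apply_updates(params, updates)
    return params, opt_state
```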
-
Hey @matthias-wright, thanks a lot for the Google Colab, it's great! Do you know by any chance whether this can save a significant amount of memory? In your example, if the first layer is frozen, its gradients never need to be computed; ideally the activations of that layer would not even be saved during the forward pass, so that some memory is freed up. Imagine you want to fine-tune a large BERT model and freeze the first 12 layers, training only the final layer. In that case it would be very important that no activations for the first 12 layers are kept, in order to save memory. Is this the case here?