Preferred pattern for freezing layers/params #1176
Hello everyone! How would I go about freezing some of my layers/params (for transfer learning)? Let's say in the following example from the docs, how would I freeze only the kernel weight?

```python
from typing import Callable

import flax.linen as nn
from jax import lax


class SimpleDense(nn.Module):
  features: int
  kernel_init: Callable = nn.initializers.lecun_normal()
  bias_init: Callable = nn.initializers.zeros

  @nn.compact
  def __call__(self, inputs):
    kernel = self.param('kernel',
                        self.kernel_init,
                        (inputs.shape[-1], self.features))
    y = lax.dot_general(inputs, kernel,
                        (((inputs.ndim - 1,), (0,)), ((), ())),)
    bias = self.param('bias', self.bias_init, (self.features,))
    y = y + bias
    return y
```

And then in this simple network, how would I freeze only the first layer?

```python
class SimpleMLP(nn.Module):
  features1: int
  features2: int

  @nn.compact
  def __call__(self, x):
    x = nn.Dense(self.features1, name='dense1')(x)
    x = nn.Dense(self.features2, name='dense2')(x)
    return x
```

Thanks!
Answered by marcvanzee, Mar 23, 2021
You can do that using a [MultiOptimizer](https://flax.readthedocs.io/en/latest/flax.optim.html#flax.optim.MultiOptimizer). Our documentation contains some examples, please let me know if it is still unclear!
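For reference, here is a minimal sketch of what that can look like with the `flax.optim` API linked above (note that `flax.optim` has since been deprecated in favor of `optax`). The learning rate, input shape, PRNG key, and the choice to freeze `dense1` are illustrative assumptions, not part of the original answer; the key point is that any parameter not matched by a `MultiOptimizer` traversal simply never gets updated:

```python
import jax
import jax.numpy as jnp
from flax import optim  # deprecated in newer Flax releases

# SimpleMLP as defined in the question above.
model = SimpleMLP(features1=16, features2=4)
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 8)))['params']

# Select only the parameters we want to train. Parameters not matched by any
# traversal (here: everything under 'dense1') are left untouched, i.e. frozen.
trainable = optim.ModelParamTraversal(lambda path, _: 'dense1' not in path)

optimizer_def = optim.MultiOptimizer((trainable, optim.Adam(learning_rate=1e-3)))
optimizer = optimizer_def.create(params)

# A toy loss just to illustrate the update step.
def loss_fn(p):
  y = model.apply({'params': p}, jnp.ones((1, 8)))
  return jnp.mean(y ** 2)

# apply_gradient only updates the 'dense2' parameters; 'dense1' keeps its
# initial (e.g. pretrained) values.
grads = jax.grad(loss_fn)(optimizer.target)
optimizer = optimizer.apply_gradient(grads)
```

The same idea covers the `SimpleDense` question: a traversal such as `lambda path, _: 'kernel' not in path` would train only the bias and leave the kernel frozen.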