How to use non-learnable parameters / freeze parameters in Flax/JAX? #1150
Answered by marcvanzee
marcvanzee asked this question in Q&A
That is, parameters that are part of the model but are not updated after initialization. (Original question by @dpkingma)
Answered by marcvanzee on Mar 17, 2021
Answer selected by marcvanzee
The recommended solution is to use a `MultiOptimizer`; our documentation contains an example: https://flax.readthedocs.io/en/latest/flax.optim.html#multioptimizer

Another approach is to use `jax.lax.stop_gradient`: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.stop_gradient.html
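
To make the `jax.lax.stop_gradient` approach concrete, here is a minimal sketch in plain JAX. The parameter names (`w`, `scale`), the toy loss, and the hand-written SGD step are assumptions made up for illustration; they are not taken from the Flax documentation or from the original answer.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameters: "w" should be trained, "scale" should stay frozen.
params = {"w": jnp.array([1.0, 2.0]), "scale": jnp.array(3.0)}


def loss_fn(params, x):
    # stop_gradient leaves the forward value unchanged but blocks
    # differentiation through it, so d(loss)/d(scale) comes out as zero.
    scale = jax.lax.stop_gradient(params["scale"])
    pred = scale * jnp.dot(params["w"], x)
    return pred ** 2


x = jnp.array([0.5, -1.0])
grads = jax.grad(loss_fn)(params, x)

# Illustrative plain SGD step: a zero gradient means "scale" is not changed.
lr = 0.1
new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

print(grads["scale"])       # zero gradient for the frozen parameter
print(new_params["scale"])  # same value as before the update
```

Note that a zero gradient only keeps the parameter fixed under update rules that depend solely on the gradient. An optimizer that also applies weight decay would still change it, which is one reason excluding the parameter from the optimizer can be the more robust choice.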