How to use non-learnable parameters / freeze parameters in Flax/JAX? #1150

Answered by marcvanzee
marcvanzee asked this question in Q&A

The recommended solution is to use a MultiOptimizer, which applies an optimizer only to a selected subset of the parameters. Our documentation contains an example: https://flax.readthedocs.io/en/latest/flax.optim.html#multioptimizer
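To make the idea concrete, here is a minimal plain-JAX sketch of updating only a subset of the parameters. It does not use the MultiOptimizer API itself; the parameter names (`backbone`, `head`), the `trainable` mask, and the `sgd_step` helper are all hypothetical and chosen only for illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameter tree; the names are assumptions for this sketch.
params = {
    "backbone": {"w": jnp.ones((3, 2))},
    "head": {"w": jnp.ones((2, 1))},
}

# Boolean mask with the same structure as params: True means "train this leaf".
trainable = {
    "backbone": {"w": False},  # frozen
    "head": {"w": True},       # learnable
}

def loss_fn(params, x, y):
    h = jnp.tanh(x @ params["backbone"]["w"])
    pred = h @ params["head"]["w"]
    return jnp.mean((pred - y) ** 2)

def sgd_step(params, x, y, lr=0.1):
    grads = jax.grad(loss_fn)(params, x, y)
    # Apply the SGD update only where the mask is True; frozen leaves pass through.
    return jax.tree_util.tree_map(
        lambda p, g, t: p - lr * g if t else p,
        params, grads, trainable,
    )

x = jnp.ones((4, 3))
y = jnp.zeros((4, 1))
new_params = sgd_step(params, x, y)
```

After the step, `new_params["backbone"]["w"]` is identical to the original while `new_params["head"]["w"]` has moved. Because frozen leaves never pass through the update rule, they are also unaffected by anything the update rule might add on top of the gradient.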

Another approach is to use jax.lax.stop_gradient: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.stop_gradient.html
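As a rough sketch of the `stop_gradient` approach: wrapping a parameter in `jax.lax.stop_gradient` inside the loss makes its gradient zero, so a gradient-based update leaves it unchanged. The parameter names (`frozen`, `trainable`) and the toy loss below are assumptions made up for this example.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameters; names are for illustration only.
params = {
    "frozen": jnp.array([1.0, 2.0]),
    "trainable": jnp.array([3.0, 4.0]),
}

def loss_fn(params, x):
    # Treat the frozen parameter as a constant during differentiation.
    frozen = jax.lax.stop_gradient(params["frozen"])
    return jnp.sum((x * frozen + params["trainable"]) ** 2)

grads = jax.grad(loss_fn)(params, jnp.array([1.0, 1.0]))
# grads["frozen"]    is [0., 0.]
# grads["trainable"] is [8., 12.]
```

Note that this only zeroes the gradient. If the update rule also changes parameters independently of the gradient (weight decay, for instance), the "frozen" values may still drift, which is one reason to prefer restricting the optimizer to a subset of the parameters.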

@andsteing

Answer selected by marcvanzee