How is state handled differently than parameters? #2268

Answered by marcvanzee
nalzok asked this question in Q&A
"I have the impression that states like batch norm statistics are handled in the same fashion as model parameters"

No, this is not how it works. A variable dictionary in Flax has "variable collections" as its keys, and each key maps to the variables that belong to that collection.

  • One variable collection is params, where we store the trainable parameters.
  • Another is batch_stats, where we store batch statistics if using batch normalization. (If we do not use batch normalization, this collection is not present.)

The mutable argument to apply lets the user control which variable collections may be modified during the forward pass. For instance, when you apply a module, params is n…

Answer selected by nalzok