How is state handled differently than parameters? #2268
After reading Stateful Computations in JAX, I have the impression that state such as batch norm statistics is handled the same way as model parameters: both are essentially extra inputs to the function. If that is the case, why do we need a …? In other words, what does the …
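The pattern the question describes can be sketched in plain JAX. The state is an explicit input, and because JAX arrays are immutable, the function returns a new state alongside its output instead of updating anything in place. This is a minimal sketch: `normalize`, its momentum value, and the input data are hypothetical choices and are not taken from the JAX guide.

```python
import jax.numpy as jnp

def normalize(params, state, x, momentum=0.9):
    """Hypothetical batch-norm-like function with explicitly threaded state."""
    mean = x.mean(axis=0)
    var = x.var(axis=0)
    y = params['scale'] * (x - mean) / jnp.sqrt(var + 1e-5) + params['bias']
    # Build a new state dict; the incoming arrays are never modified.
    new_state = {
        'mean': momentum * state['mean'] + (1 - momentum) * mean,
        'var': momentum * state['var'] + (1 - momentum) * var,
    }
    return y, new_state

params = {'scale': jnp.ones(3), 'bias': jnp.zeros(3)}
state = {'mean': jnp.zeros(3), 'var': jnp.ones(3)}
x = jnp.arange(12.0).reshape(4, 3)

# The caller is responsible for passing the returned state into the next call.
y, state = normalize(params, state, x)
```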
No, this is not how it works. A variable dictionary in Flax has "variable collections" as its keys, and each collection maps to the variables stored in it.
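JAX arrays are immutable, so when you modify them you make a copy and use that. This is what we do in Flax as well: when you have mutable variable collections, `apply` returns the updated collections alongside the output:

```python
# params, old_batch_stats: variables from BN.init or from a previous step
vars_in = {'params': params, 'batch_stats': old_batch_stats}
# Only 'batch_stats' may be modified; apply returns (output, updated collections)
y, mutated_vars = BN.apply(vars_in, x, mutable=['batch_stats'])
new_batch_stats = mutated_vars['batch_stats']
```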
For instance, a model that uses batch normalization has two collections:

- `params`, where we store the trainable parameters.
- `batch_stats`, where we store batch statistics if using batch normalization. (If we do not use batch normalization, this collection is not present.)

The `mutable` argument to `apply` lets the user control which variable collections may be modified during the forward pass. For instance, when you apply a module, `params` is n…