Version 0.3.4
Possibly breaking changes:
- When calling `init`, the 'intermediates' collection is no longer mutable. Therefore, intermediates will no longer be returned from initialization by default (see the first sketch after this list).
- Don't update batch statistics during initialization.
- When not using any non-determinism (e.g., dropout), it is no longer necessary to specify the `deterministic` argument in `MultiHeadDotProductAttention` (see the second sketch after this list).
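
A minimal sketch of the new `init` behavior; the `Toy` module below is hypothetical and exists only to illustrate the change:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class Toy(nn.Module):  # hypothetical module, for illustration only
    @nn.compact
    def __call__(self, x):
        h = nn.Dense(4)(x)
        self.sow('intermediates', 'hidden', h)  # record an intermediate value
        return h

model = Toy()
x = jnp.ones((1, 3))

# 'intermediates' is no longer mutable during init, so it is not returned:
variables = model.init(jax.random.PRNGKey(0), x)
assert 'intermediates' not in variables

# To collect intermediates at initialization, opt in explicitly via the new
# `mutable` argument (see "Other changes" below):
variables = model.init(jax.random.PRNGKey(0), x, mutable=['params', 'intermediates'])
assert 'intermediates' in variables
```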
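
And a sketch of calling `MultiHeadDotProductAttention` without `deterministic`; the shapes and sizes are arbitrary:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

# With the default dropout_rate=0.0 there is no non-determinism, so the
# `deterministic` argument can simply be omitted:
attn = nn.MultiHeadDotProductAttention(num_heads=2)
x = jnp.ones((1, 8, 16))  # [batch, length, features]
variables = attn.init(jax.random.PRNGKey(0), x, x)
y = attn.apply(variables, x, x)
```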
Other changes:
- Rewrote various examples to use Optax instead of Flax optimizers (e.g., Imagenet, SST2).
- Added an NLP text classification example (on the SST-2 dataset) to `examples/sst2` that uses a bidirectional LSTM (BiLSTM) to encode the input text.
- Added `flax.training.train_state` to simplify using Optax optimizers (see the first sketch after this list).
- The `mutable` argument is now available on `Module.init` and `Module.init_with_output` (see the second sketch after this list).
- Bug fix: Correctly handle non-default parameters of Linen Modules with nested inheritance.
- Expose `dot_product_attention_weights`, allowing access to attention weights (see the third sketch after this list).
- `BatchNorm` instances will behave correctly during init when called multiple times.
- Added a more extensive "how to contribute" guide in `contributing.md`.
- Add proper cache behavior for `lift.jit`, fixing cache misses.
- Fix bug in `Embed` layer: make sure it behaves correctly when the embedding is an `np.array`.
- Fix `linen.Module` for deep inheritance chains.
- Fix bug in `DenseGeneral`: correctly expand bias to account for batch and non-contracting dimensions.
- Allow Flax lifted transforms to work on partially applied Modules.
- Make `MultiOptimizer` use `apply_gradient` instead of `apply_param_gradient`.
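
First, a minimal sketch of `flax.training.train_state.TrainState` with an Optax optimizer; the `Dense` model, learning rate, and all-zero "gradients" are placeholders:

```python
import jax
import jax.numpy as jnp
import optax
import flax.linen as nn
from flax.training import train_state

model = nn.Dense(features=3)
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 2)))['params']

# TrainState bundles the apply function, the parameters, and the Optax
# optimizer state into a single object.
state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=optax.sgd(learning_rate=0.1),
)

# apply_gradients updates both the parameters and the optimizer state.
grads = jax.tree_util.tree_map(jnp.zeros_like, state.params)  # placeholder gradients
state = state.apply_gradients(grads=grads)
```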
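
Second, a sketch of the `mutable` argument on `Module.init_with_output`; `BatchNorm` is used here only because it creates a non-'params' collection:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

model = nn.BatchNorm(use_running_average=False)
x = jnp.ones((4, 3))

# `mutable` controls which variable collections initialization may create and
# return; init_with_output additionally returns the model output.
y, variables = model.init_with_output(
    jax.random.PRNGKey(0), x, mutable=['params', 'batch_stats'])
assert 'batch_stats' in variables
```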
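
Finally, a sketch of the newly exposed `dot_product_attention_weights`; the shapes are illustrative, and the import path assumes the function lives in `flax.linen.attention`:

```python
import jax
from flax.linen.attention import dot_product_attention_weights

# query/key are shaped [batch, length, num_heads, head_dim].
query = jax.random.normal(jax.random.PRNGKey(0), (1, 8, 2, 4))
key = jax.random.normal(jax.random.PRNGKey(1), (1, 8, 2, 4))

# Returns the softmax-normalized attention weights, shaped
# [batch, num_heads, q_length, kv_length] -- here (1, 2, 8, 8).
weights = dot_product_attention_weights(query, key)
```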