Version 0.5.0
New features:
- Added flax.jax_utils.pad_shard_unpad() by @lucasb-eyer
- Implemented default dtype FLIP.
  The default dtype is now inferred from inputs and params rather than being hard-coded to float32.
  This is especially useful for complex numbers: the standard Modules no longer truncate
  complex numbers to their real component by default, and the complex dtype is preserved instead.
Bug fixes:
- Fix support for JAX's experimental_name_stack.
Breaking changes:
- In rare cases, the dtype of a layer can change due to the default dtype FLIP. See the "Backward compatibility" section of the proposal for more information.