Version 0.5.0

@jheek released this on 23 May 12:52

New features:

  • Added flax.jax_utils.pad_shard_unpad() by @lucasb-eyer.
  • Implemented the default dtype FLIP (Flax Improvement Proposal).
    The default dtype is now inferred from the inputs and params instead of being hard-coded to float32.
    This matters most for complex numbers: the standard Modules no longer truncate complex inputs to their
    real component by default, and the complex dtype is preserved instead.
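The pad/shard/unpad pattern behind the new helper can be sketched as follows. This is a hypothetical NumPy illustration, not Flax's implementation: `pad_shard_unpad_sketch` and its `num_devices` parameter are made up for this example, and devices are emulated with a Python loop, whereas the real `flax.jax_utils.pad_shard_unpad()` wraps a `jax.pmap`-ed function. The idea is to pad the batch so it splits evenly across devices, run the function on each shard, then drop the padding from the result.

```python
import numpy as np


def pad_shard_unpad_sketch(fn, batch, num_devices):
    """Hypothetical illustration of pad -> shard -> unpad.

    `fn` maps a per-device shard of shape (per_device, ...) to an array with
    the same leading axis. Devices are emulated with a loop, not jax.pmap.
    """
    n = batch.shape[0]
    per_device = -(-n // num_devices)  # ceiling division
    padded_n = per_device * num_devices
    pad_rows = padded_n - n

    # Pad the batch axis with zeros so it divides evenly across devices.
    padding = np.zeros((pad_rows,) + batch.shape[1:], dtype=batch.dtype)
    padded = np.concatenate([batch, padding], axis=0)

    # Shard: the new leading axis plays the role of the device axis.
    shards = padded.reshape((num_devices, per_device) + batch.shape[1:])
    outputs = np.stack([fn(shard) for shard in shards])

    # Un-shard back to one batch axis, then drop the padded rows.
    merged = outputs.reshape((padded_n,) + outputs.shape[2:])
    return merged[:n]


batch = np.arange(10, dtype=np.float32).reshape(10, 1)
result = pad_shard_unpad_sketch(lambda s: s * 2, batch, num_devices=4)
print(result.shape)  # (10, 1)
```

Padding with zeros is only harmless because the padded rows are discarded afterwards; a function that reduces across the batch (for example, a mean) would still see them.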

Bug fixes:

  • Fixed support for JAX's experimental_name_stack.

Breaking changes:

  • In rare cases, the dtype of a layer can change because of the default dtype FLIP. See the "Backward compatibility" section of the proposal for details.
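To check whether a layer is affected, the inference rule can be approximated as below. This is a hypothetical sketch: `infer_compute_dtype` and `dense_sketch` are illustrative names, and NumPy's `np.result_type` stands in for JAX's type promotion, whose rules differ from NumPy's in some corners. An explicit `dtype` always wins; otherwise the inputs and params are promoted to a common dtype. Passing an explicit `dtype=jnp.float32` to a Flax Module should restore the previous hard-coded behavior.

```python
import numpy as np


def infer_compute_dtype(*arrays, dtype=None):
    """Hypothetical helper approximating the default-dtype rule:
    an explicit dtype wins, otherwise promote all input/param dtypes.
    np.result_type is a stand-in for JAX's own promotion rules."""
    if dtype is not None:
        return np.dtype(dtype)
    return np.result_type(*arrays)


def dense_sketch(x, kernel, dtype=None):
    """Toy dense layer (no bias) that computes in the inferred dtype."""
    compute_dtype = infer_compute_dtype(x, kernel, dtype=dtype)
    return x.astype(compute_dtype) @ kernel.astype(compute_dtype)


kernel = np.ones((3, 4), dtype=np.float32)  # float32 params (Flax's default param_dtype)

x_complex = np.ones((2, 3), dtype=np.complex64)
print(dense_sketch(x_complex, kernel).dtype)  # complex64: complex dtype preserved

x_f64 = np.ones((2, 3), dtype=np.float64)
print(dense_sketch(x_f64, kernel).dtype)  # float64: promoted, unlike the old float32 default
print(dense_sketch(x_f64, kernel, dtype=np.float32).dtype)  # float32: old behavior pinned explicitly
```

In JAX, float64 inputs only appear when 64-bit mode is enabled, so the complex case is the more common source of a changed dtype.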