Version 0.3.5

Released by @jheek on 21 Sep 07:47.

Breaking changes:

  • You can no longer pass an int as the kernel_size for a flax.linen.Conv. Instead, a TypeError is raised stating that a tuple or list should be provided. The stride and dilation arguments now support broadcasting a single int value, because this is unambiguous once the kernel rank is known.
  • flax.linen.enable_named_call and flax.linen.disable_named_call now work anywhere, instead of only affecting Modules constructed after the enable/disable call. Additionally, there is now flax.linen.override_named_call, which provides a context manager to locally enable or disable named_call.
  • NamedTuples are no longer converted to tuples on assignment to a linen.Module.

New features:

  • Flax-internal stack frames are now removed from exception stack traces.
  • Added flax.linen.nowrap to decorate methods that should not be transformed because they are stateful.
  • Flax no longer uses implicit rank broadcasting. Thus, you can now use Flax with --jax_numpy_rank_promotion=raise.

Bugfixes:

  • linen Modules and dataclasses made with flax.struct.dataclass or flax.struct.PyTreeNode are now correctly recognized as dataclasses by static analysis tools like Pylance. Autocomplete of constructors has been verified to work in VS Code.
  • Fixed a bug in FrozenDict that prevented copying dicts with reserved names.
  • Fixed the serialization of named tuples: tuple fields are no longer stored in the state dict, and the named tuple class is no longer recreated (this was a bug).
  • Mixed precision training with float16 now works correctly with the attention layers.
  • Auto-generated linen Module hash, eq, and repr no longer fail by default on non-init attributes.