diff --git a/CHANGELOG.md b/CHANGELOG.md index 2bbcd73620..6a298007a8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -82,7 +82,7 @@ vNext 0.8.0 ----- -- Added [NNX](https://github.com/google/flax/tree/main/flax/experimental/nnx#nnx), a neural network library for JAX that provides a simple yet powerful module system that adheres to standard Python semantics. Its aim is to combine the robustness of Linen with a simplified, Pythonic API akin to that of PyTorch. +- Added [NNX](https://github.com/google/flax/tree/main/flax/nnx#nnx), a neural network library for JAX that provides a simple yet powerful module system that adheres to standard Python semantics. Its aim is to combine the robustness of Linen with a simplified, Pythonic API akin to that of PyTorch. - Added `nn.compact_name_scope` decorator that enables methods to act as compact name scopes as with regular Haiku methods. This makes porting Haiku code easier. - Add copy() method to Module. This is a user-friendly version of the internal clone() method with better defaults for common use cases. diff --git a/README.md b/README.md index f5dd4b09b7..6cc80978dc 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,8 @@ | [**What does Flax look like?**](#what-does-flax-look-like) | [**Documentation**](https://flax.readthedocs.io/) +**📣 NEW**: Check out the [**NNX**](https://flax.readthedocs.io/en/latest/nnx/index.html) API! + This README is a very short intro. **To learn everything you need to know about Flax, refer to our [full documentation](https://flax.readthedocs.io/).** Flax was originally started by engineers and researchers within the Brain Team in Google Research (in close collaboration with the JAX team), and is now developed jointly with the open source community. diff --git a/docs/api_reference/flax.experimental.nnx/nn/stochastic.rst b/docs/api_reference/flax.experimental.nnx/nn/stochastic.rst deleted file mode 100644 index 975b8bdb90..0000000000 --- a/docs/api_reference/flax.experimental.nnx/nn/stochastic.rst +++ /dev/null @@ -1,8 +0,0 @@ -Stochastic ------------------------- - -.. automodule:: flax.experimental.nnx -.. currentmodule:: flax.experimental.nnx - -.. autoclass:: Dropout - :members: \ No newline at end of file diff --git a/docs/api_reference/flax.experimental.nnx/training/optimizer.rst b/docs/api_reference/flax.experimental.nnx/training/optimizer.rst deleted file mode 100644 index a17b74e990..0000000000 --- a/docs/api_reference/flax.experimental.nnx/training/optimizer.rst +++ /dev/null @@ -1,8 +0,0 @@ -Optimizer ------------------------- - -.. automodule:: flax.experimental.nnx.optimizer -.. currentmodule:: flax.experimental.nnx.optimizer - -.. autoclass:: Optimizer - :members: diff --git a/docs/api_reference/flax.experimental.nnx/visualization.rst b/docs/api_reference/flax.experimental.nnx/visualization.rst deleted file mode 100644 index 0bbbb9872e..0000000000 --- a/docs/api_reference/flax.experimental.nnx/visualization.rst +++ /dev/null @@ -1,7 +0,0 @@ -visualization ------------------------- - -.. automodule:: flax.experimental.nnx -.. currentmodule:: flax.experimental.nnx - -.. 
autofunction:: display \ No newline at end of file diff --git a/docs/api_reference/flax.experimental.nnx/graph.rst b/docs/api_reference/flax.nnx/graph.rst similarity index 81% rename from docs/api_reference/flax.experimental.nnx/graph.rst rename to docs/api_reference/flax.nnx/graph.rst index a2d21b60e4..35d3939db5 100644 --- a/docs/api_reference/flax.experimental.nnx/graph.rst +++ b/docs/api_reference/flax.nnx/graph.rst @@ -1,8 +1,8 @@ graph ------------------------ -.. automodule:: flax.experimental.nnx -.. currentmodule:: flax.experimental.nnx +.. automodule:: flax.nnx +.. currentmodule:: flax.nnx .. autofunction:: split diff --git a/docs/api_reference/flax.experimental.nnx/helpers.rst b/docs/api_reference/flax.nnx/helpers.rst similarity index 69% rename from docs/api_reference/flax.experimental.nnx/helpers.rst rename to docs/api_reference/flax.nnx/helpers.rst index c0413acf55..f2b67522d7 100644 --- a/docs/api_reference/flax.experimental.nnx/helpers.rst +++ b/docs/api_reference/flax.nnx/helpers.rst @@ -1,8 +1,8 @@ helpers ------------------------ -.. automodule:: flax.experimental.nnx -.. currentmodule:: flax.experimental.nnx +.. automodule:: flax.nnx +.. currentmodule:: flax.nnx .. autoclass:: Dict :members: diff --git a/docs/api_reference/flax.experimental.nnx/index.rst b/docs/api_reference/flax.nnx/index.rst similarity index 73% rename from docs/api_reference/flax.experimental.nnx/index.rst rename to docs/api_reference/flax.nnx/index.rst index fb90e3d4e3..37a22d3118 100644 --- a/docs/api_reference/flax.experimental.nnx/index.rst +++ b/docs/api_reference/flax.nnx/index.rst @@ -1,7 +1,7 @@ -flax.experimental.nnx +flax.nnx ------------------------ -Experimental API. See the `NNX page `__ for more details. +Experimental API. See the `NNX page `__ for more details. .. toctree:: :maxdepth: 3 diff --git a/docs/api_reference/flax.experimental.nnx/module.rst b/docs/api_reference/flax.nnx/module.rst similarity index 61% rename from docs/api_reference/flax.experimental.nnx/module.rst rename to docs/api_reference/flax.nnx/module.rst index ffdff78a88..9e58068a8f 100644 --- a/docs/api_reference/flax.experimental.nnx/module.rst +++ b/docs/api_reference/flax.nnx/module.rst @@ -1,8 +1,8 @@ module ------------------------ -.. automodule:: flax.experimental.nnx -.. currentmodule:: flax.experimental.nnx +.. automodule:: flax.nnx +.. currentmodule:: flax.nnx .. autoclass:: Module :members: diff --git a/docs/api_reference/flax.experimental.nnx/nn/activations.rst b/docs/api_reference/flax.nnx/nn/activations.rst similarity index 89% rename from docs/api_reference/flax.experimental.nnx/nn/activations.rst rename to docs/api_reference/flax.nnx/nn/activations.rst index 0464975fed..db20ceb4d9 100644 --- a/docs/api_reference/flax.experimental.nnx/nn/activations.rst +++ b/docs/api_reference/flax.nnx/nn/activations.rst @@ -1,8 +1,8 @@ Activation functions ------------------------ -.. automodule:: flax.experimental.nnx -.. currentmodule:: flax.experimental.nnx +.. automodule:: flax.nnx +.. currentmodule:: flax.nnx .. autofunction:: celu .. 
autofunction:: elu diff --git a/docs/api_reference/flax.experimental.nnx/nn/attention.rst b/docs/api_reference/flax.nnx/nn/attention.rst similarity index 74% rename from docs/api_reference/flax.experimental.nnx/nn/attention.rst rename to docs/api_reference/flax.nnx/nn/attention.rst index a2137ac885..3a10c7728b 100644 --- a/docs/api_reference/flax.experimental.nnx/nn/attention.rst +++ b/docs/api_reference/flax.nnx/nn/attention.rst @@ -1,8 +1,8 @@ Attention ------------------------ -.. automodule:: flax.experimental.nnx -.. currentmodule:: flax.experimental.nnx +.. automodule:: flax.nnx +.. currentmodule:: flax.nnx .. autoclass:: MultiHeadAttention :members: diff --git a/docs/api_reference/flax.experimental.nnx/nn/index.rst b/docs/api_reference/flax.nnx/nn/index.rst similarity index 77% rename from docs/api_reference/flax.experimental.nnx/nn/index.rst rename to docs/api_reference/flax.nnx/nn/index.rst index a179948da6..abe4da330e 100644 --- a/docs/api_reference/flax.experimental.nnx/nn/index.rst +++ b/docs/api_reference/flax.nnx/nn/index.rst @@ -1,7 +1,7 @@ nn ---------------------------- -Experimental API. See the `NNX page `__ for more details. +Experimental API. See the `NNX page `__ for more details. .. toctree:: :maxdepth: 3 diff --git a/docs/api_reference/flax.experimental.nnx/nn/initializers.rst b/docs/api_reference/flax.nnx/nn/initializers.rst similarity index 86% rename from docs/api_reference/flax.experimental.nnx/nn/initializers.rst rename to docs/api_reference/flax.nnx/nn/initializers.rst index 0468f18703..a5734d8a45 100644 --- a/docs/api_reference/flax.experimental.nnx/nn/initializers.rst +++ b/docs/api_reference/flax.nnx/nn/initializers.rst @@ -1,8 +1,8 @@ Initializers ------------------------ -.. automodule:: flax.experimental.nnx.initializers -.. currentmodule:: flax.experimental.nnx.initializers +.. automodule:: flax.nnx.initializers +.. currentmodule:: flax.nnx.initializers .. autofunction:: constant .. autofunction:: delta_orthogonal diff --git a/docs/api_reference/flax.experimental.nnx/nn/linear.rst b/docs/api_reference/flax.nnx/nn/linear.rst similarity index 75% rename from docs/api_reference/flax.experimental.nnx/nn/linear.rst rename to docs/api_reference/flax.nnx/nn/linear.rst index 3206c4e841..0576820690 100644 --- a/docs/api_reference/flax.experimental.nnx/nn/linear.rst +++ b/docs/api_reference/flax.nnx/nn/linear.rst @@ -3,8 +3,8 @@ Linear NNX linear layer classes. -.. automodule:: flax.experimental.nnx -.. currentmodule:: flax.experimental.nnx +.. automodule:: flax.nnx +.. currentmodule:: flax.nnx .. autoclass:: Conv :members: diff --git a/docs/api_reference/flax.experimental.nnx/nn/normalization.rst b/docs/api_reference/flax.nnx/nn/normalization.rst similarity index 65% rename from docs/api_reference/flax.experimental.nnx/nn/normalization.rst rename to docs/api_reference/flax.nnx/nn/normalization.rst index 402fa83769..c35bc5e0fb 100644 --- a/docs/api_reference/flax.experimental.nnx/nn/normalization.rst +++ b/docs/api_reference/flax.nnx/nn/normalization.rst @@ -1,8 +1,8 @@ Normalization ------------------------ -.. automodule:: flax.experimental.nnx -.. currentmodule:: flax.experimental.nnx +.. automodule:: flax.nnx +.. currentmodule:: flax.nnx .. autoclass:: BatchNorm :members: diff --git a/docs/api_reference/flax.nnx/nn/stochastic.rst b/docs/api_reference/flax.nnx/nn/stochastic.rst new file mode 100644 index 0000000000..70f7c497a6 --- /dev/null +++ b/docs/api_reference/flax.nnx/nn/stochastic.rst @@ -0,0 +1,8 @@ +Stochastic +------------------------ + +.. 
automodule:: flax.nnx +.. currentmodule:: flax.nnx + +.. autoclass:: Dropout + :members: \ No newline at end of file diff --git a/docs/api_reference/flax.experimental.nnx/rnglib.rst b/docs/api_reference/flax.nnx/rnglib.rst similarity index 57% rename from docs/api_reference/flax.experimental.nnx/rnglib.rst rename to docs/api_reference/flax.nnx/rnglib.rst index 9defbc76f9..2db1d6d63c 100644 --- a/docs/api_reference/flax.experimental.nnx/rnglib.rst +++ b/docs/api_reference/flax.nnx/rnglib.rst @@ -1,8 +1,8 @@ rnglib ------------------------ -.. automodule:: flax.experimental.nnx -.. currentmodule:: flax.experimental.nnx +.. automodule:: flax.nnx +.. currentmodule:: flax.nnx .. autoclass:: Rngs :members: diff --git a/docs/api_reference/flax.experimental.nnx/spmd.rst b/docs/api_reference/flax.nnx/spmd.rst similarity index 69% rename from docs/api_reference/flax.experimental.nnx/spmd.rst rename to docs/api_reference/flax.nnx/spmd.rst index ed7af7f696..3429d898cc 100644 --- a/docs/api_reference/flax.experimental.nnx/spmd.rst +++ b/docs/api_reference/flax.nnx/spmd.rst @@ -1,8 +1,8 @@ spmd ------------------------ -.. automodule:: flax.experimental.nnx -.. currentmodule:: flax.experimental.nnx +.. automodule:: flax.nnx +.. currentmodule:: flax.nnx .. autofunction:: get_partition_spec .. autofunction:: get_named_sharding diff --git a/docs/api_reference/flax.experimental.nnx/training/index.rst b/docs/api_reference/flax.nnx/training/index.rst similarity index 71% rename from docs/api_reference/flax.experimental.nnx/training/index.rst rename to docs/api_reference/flax.nnx/training/index.rst index c9bb4aa39f..32404f1de7 100644 --- a/docs/api_reference/flax.experimental.nnx/training/index.rst +++ b/docs/api_reference/flax.nnx/training/index.rst @@ -1,7 +1,7 @@ training ---------------------------- -Experimental API. See the `NNX page `__ for more details. +Experimental API. See the `NNX page `__ for more details. .. toctree:: :maxdepth: 3 diff --git a/docs/api_reference/flax.experimental.nnx/training/metrics.rst b/docs/api_reference/flax.nnx/training/metrics.rst similarity index 65% rename from docs/api_reference/flax.experimental.nnx/training/metrics.rst rename to docs/api_reference/flax.nnx/training/metrics.rst index f0e5ea201b..e60c9d1c1d 100644 --- a/docs/api_reference/flax.experimental.nnx/training/metrics.rst +++ b/docs/api_reference/flax.nnx/training/metrics.rst @@ -1,8 +1,8 @@ Metrics ------------------------ -.. automodule:: flax.experimental.nnx.metrics -.. currentmodule:: flax.experimental.nnx.metrics +.. automodule:: flax.nnx.metrics +.. currentmodule:: flax.nnx.metrics .. autoclass:: Metric :members: diff --git a/docs/api_reference/flax.nnx/training/optimizer.rst b/docs/api_reference/flax.nnx/training/optimizer.rst new file mode 100644 index 0000000000..15966a1a2e --- /dev/null +++ b/docs/api_reference/flax.nnx/training/optimizer.rst @@ -0,0 +1,8 @@ +Optimizer +------------------------ + +.. automodule:: flax.nnx.optimizer +.. currentmodule:: flax.nnx.optimizer + +.. autoclass:: Optimizer + :members: diff --git a/docs/api_reference/flax.experimental.nnx/transforms.rst b/docs/api_reference/flax.nnx/transforms.rst similarity index 82% rename from docs/api_reference/flax.experimental.nnx/transforms.rst rename to docs/api_reference/flax.nnx/transforms.rst index bdf105feed..6750a109df 100644 --- a/docs/api_reference/flax.experimental.nnx/transforms.rst +++ b/docs/api_reference/flax.nnx/transforms.rst @@ -1,8 +1,8 @@ transforms ------------------------ -.. automodule:: flax.experimental.nnx -.. 
currentmodule:: flax.experimental.nnx +.. automodule:: flax.nnx +.. currentmodule:: flax.nnx .. autoclass:: JIT :members: diff --git a/docs/api_reference/flax.experimental.nnx/variables.rst b/docs/api_reference/flax.nnx/variables.rst similarity index 80% rename from docs/api_reference/flax.experimental.nnx/variables.rst rename to docs/api_reference/flax.nnx/variables.rst index b9f3d1dc54..54e4424633 100644 --- a/docs/api_reference/flax.experimental.nnx/variables.rst +++ b/docs/api_reference/flax.nnx/variables.rst @@ -1,8 +1,8 @@ variables ------------------------ -.. automodule:: flax.experimental.nnx -.. currentmodule:: flax.experimental.nnx +.. automodule:: flax.nnx +.. currentmodule:: flax.nnx .. autoclass:: BatchStat :members: diff --git a/docs/api_reference/flax.nnx/visualization.rst b/docs/api_reference/flax.nnx/visualization.rst new file mode 100644 index 0000000000..a189aae524 --- /dev/null +++ b/docs/api_reference/flax.nnx/visualization.rst @@ -0,0 +1,7 @@ +visualization +------------------------ + +.. automodule:: flax.nnx +.. currentmodule:: flax.nnx + +.. autofunction:: display \ No newline at end of file diff --git a/docs/api_reference/index.rst b/docs/api_reference/index.rst index 8448f316ac..2c0d360254 100644 --- a/docs/api_reference/index.rst +++ b/docs/api_reference/index.rst @@ -8,7 +8,7 @@ API Reference flax.core.frozen_dict flax.cursor flax.errors - flax.experimental.nnx/index + flax.nnx/index flax.jax_utils flax.linen/index flax.serialization diff --git a/docs/conf.py b/docs/conf.py index 2ee6faca2f..93d3d7009e 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -110,6 +110,16 @@ html_extra_path = ['robots.txt'] +# href with no underline and white bold text color +announcement = """ + + 📣 Check out the new NNX API! + +""" + html_theme_options = { 'repository_url': 'https://github.com/google/flax', 'use_repository_button': True, # add a 'link to repository' button @@ -122,6 +132,7 @@ }, 'prev_next_buttons_location': None, 'show_navbar_depth': 1, + 'announcement': announcement, } # -- Options for myst ---------------------------------------------- @@ -135,7 +146,7 @@ nb_execution_excludepatterns = [ 'quick_start.ipynb', # <-- times out 'transfer_learning.ipynb', # <-- transformers requires flax<=0.7.0 - 'flax/experimental/nnx', # exclude nnx + 'flax/nnx', # exclude nnx ] # raise exceptions on execution so CI can catch errors nb_execution_allow_errors = False @@ -151,7 +162,7 @@ doctest_global_setup = """ import jax import jax.numpy as jnp -from flax.experimental import nnx +from flax import nnx import logging as slog from absl import logging as alog diff --git a/docs/experimental/index.rst b/docs/experimental/index.rst deleted file mode 100644 index 368491ce3d..0000000000 --- a/docs/experimental/index.rst +++ /dev/null @@ -1,7 +0,0 @@ -Experimental -============= - -.. 
toctree:: - :maxdepth: 2 - - nnx/index \ No newline at end of file diff --git a/docs/guides/flax_fundamentals/flax_basics.ipynb b/docs/guides/flax_fundamentals/flax_basics.ipynb index e20069aebc..e8e43f21c1 100644 --- a/docs/guides/flax_fundamentals/flax_basics.ipynb +++ b/docs/guides/flax_fundamentals/flax_basics.ipynb @@ -951,7 +951,7 @@ "source": [ "### Exporting to Tensorflow's SavedModel with jax2tf\n", "\n", - "JAX released an experimental converter called [jax2tf](https://github.com/google/jax/tree/main/jax/experimental/jax2tf), which allows converting trained Flax models into Tensorflow's SavedModel format (so it can be used for [TF Hub](https://www.tensorflow.org/hub), [TF.lite](https://www.tensorflow.org/lite), [TF.js](https://www.tensorflow.org/js), or other downstream applications). The repository contains more documentation and has various examples for Flax." + "JAX released an experimental converter called [jax2tf](https://github.com/google/jax/tree/main/jax/jax2tf), which allows converting trained Flax models into Tensorflow's SavedModel format (so it can be used for [TF Hub](https://www.tensorflow.org/hub), [TF.lite](https://www.tensorflow.org/lite), [TF.js](https://www.tensorflow.org/js), or other downstream applications). The repository contains more documentation and has various examples for Flax." ] } ], diff --git a/docs/guides/flax_fundamentals/flax_basics.md b/docs/guides/flax_fundamentals/flax_basics.md index 0ce0f6f77f..52755e9b5c 100644 --- a/docs/guides/flax_fundamentals/flax_basics.md +++ b/docs/guides/flax_fundamentals/flax_basics.md @@ -469,4 +469,4 @@ Flax provides a handy wrapper - `TrainState` - that simplifies the above code. C ### Exporting to Tensorflow's SavedModel with jax2tf -JAX released an experimental converter called [jax2tf](https://github.com/google/jax/tree/main/jax/experimental/jax2tf), which allows converting trained Flax models into Tensorflow's SavedModel format (so it can be used for [TF Hub](https://www.tensorflow.org/hub), [TF.lite](https://www.tensorflow.org/lite), [TF.js](https://www.tensorflow.org/js), or other downstream applications). The repository contains more documentation and has various examples for Flax. +JAX released an experimental converter called [jax2tf](https://github.com/google/jax/tree/main/jax/jax2tf), which allows converting trained Flax models into Tensorflow's SavedModel format (so it can be used for [TF Hub](https://www.tensorflow.org/hub), [TF.lite](https://www.tensorflow.org/lite), [TF.js](https://www.tensorflow.org/js), or other downstream applications). The repository contains more documentation and has various examples for Flax. diff --git a/docs/index.rst b/docs/index.rst index be6781e82f..75f5d985fe 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -28,7 +28,7 @@ both in the open source community (like `Hugging Face `__) and at Google (like -`PaLM `__, +`Gemini `__, `Imagen `__, `Scenic `__, and `Big Vision `__). @@ -309,6 +309,8 @@ Notable examples in Flax include: +.. role:: bold + :class: bold .. toctree:: :hidden: @@ -325,4 +327,4 @@ Notable examples in Flax include: contributing experimental api_reference/index - experimental/index + NNX diff --git a/docs/experimental/nnx/index.rst b/docs/nnx/index.rst similarity index 69% rename from docs/experimental/nnx/index.rst rename to docs/nnx/index.rst index 9a7defeeb4..5865e6c17d 100644 --- a/docs/experimental/nnx/index.rst +++ b/docs/nnx/index.rst @@ -3,13 +3,11 @@ NNX ======== -NNX is a JAX-based neural network library designed for simplicity and power. 
Its modular -approach follows standard Python conventions, making it both intuitive and compatible with -the broader JAX ecosystem. - -.. note:: - NNX is currently in an experimental state and is subject to change. Linen is still the - recommended option for large-scale projects. Feedback and contributions are welcome! +NNX is a **N**\ eural **N**\ etwork library for JA\ **X** that focuses on providing the best +development experience, so building and experimenting with neural networks is easy and +intuitive. It achieves this by embracing Python’s object-oriented model and making it +compatible with JAX transforms, resulting in code that is easy to inspect, debug, and +analyze. Features ^^^^^^^^^ @@ -26,47 +24,47 @@ Features .. div:: sd-font-normal - Modules are standard Python classes, promoting ease of use and a more familiar - development experience. + NNX supports the use of regular Python objects, providing an intuitive + and predictable development experience. .. grid-item:: :columns: 12 12 12 6 - .. card:: Compatible + .. card:: Simple :class-card: sd-border-0 :shadow: none :class-title: sd-fs-5 .. div:: sd-font-normal - Effortlessly convert between Modules and pytrees using the Functional API for maximum - flexibility. + NNX relies on Python's object model, which results in simplicity for + the user and increases development speed. .. grid-item:: :columns: 12 12 12 6 - .. card:: Control + .. card:: Streamlined :class-card: sd-border-0 :shadow: none :class-title: sd-fs-5 .. div:: sd-font-normal - Manage a Module's state with precision using typed Variable collections, enabling fine-grained - control on JAX transformations. + NNX integrates user feedback and hands-on experience with Linen + into a new simplified API. .. grid-item:: :columns: 12 12 12 6 - .. card:: User-friendly + .. card:: Compatible :class-card: sd-border-0 :shadow: none :class-title: sd-fs-5 .. div:: sd-font-normal - NNX prioritizes simplicity for common use cases, building upon lessons learned from Linen - to provide a streamlined experience. + NNX makes it very easy to integrate objects with regular JAX code + via the `Functional API `__. Basic usage ^^^^^^^^^^^^ @@ -78,7 +76,7 @@ Basic usage .. testcode:: - from flax.experimental import nnx + from flax import nnx import optax @@ -110,7 +108,14 @@ Basic usage Installation ^^^^^^^^^^^^ -NNX is under active development, we recommend using the latest version from Flax's GitHub repository: + +Install NNX via pip: + +.. code-block:: bash + + pip install flax + +Or install the latest version from the repository: .. code-block:: bash @@ -150,7 +155,7 @@ Learn more .. 
card:: :material-regular:`menu_book;2em` API reference :class-card: sd-text-black sd-bg-light - :link: ../../api_reference/index.html + :link: ../api_reference/flax.nnx/index.html ---- diff --git a/docs/experimental/nnx/mnist_tutorial.ipynb b/docs/nnx/mnist_tutorial.ipynb similarity index 99% rename from docs/experimental/nnx/mnist_tutorial.ipynb rename to docs/nnx/mnist_tutorial.ipynb index c143ca57b2..6f990696e5 100644 --- a/docs/experimental/nnx/mnist_tutorial.ipynb +++ b/docs/nnx/mnist_tutorial.ipynb @@ -132,7 +132,7 @@ } ], "source": [ - "from flax.experimental import nnx # NNX API\n", + "from flax import nnx # NNX API\n", "from functools import partial\n", "\n", "class CNN(nnx.Module):\n", @@ -297,7 +297,7 @@ "id": "17", "metadata": {}, "source": [ - "The [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.experimental.nnx/transforms.html#flax.experimental.nnx.jit) decorator traces the `train_step` function for just-in-time compilation with \n", + "The [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) decorator traces the `train_step` function for just-in-time compilation with \n", "[XLA](https://www.tensorflow.org/xla), optimizing performance on \n", "hardware accelerators. `nnx.jit` is similar to [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit),\n", "except it can transforms functions that contain NNX objects as inputs and outputs.\n", diff --git a/docs/experimental/nnx/mnist_tutorial.md b/docs/nnx/mnist_tutorial.md similarity index 98% rename from docs/experimental/nnx/mnist_tutorial.md rename to docs/nnx/mnist_tutorial.md index e6510d2397..3c4ba09555 100644 --- a/docs/experimental/nnx/mnist_tutorial.md +++ b/docs/nnx/mnist_tutorial.md @@ -77,7 +77,7 @@ test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(1) Create a convolutional neural network with NNX by subclassing `nnx.Module`. ```{code-cell} ipython3 -from flax.experimental import nnx # NNX API +from flax import nnx # NNX API from functools import partial class CNN(nnx.Module): @@ -163,7 +163,7 @@ def train_step(model: CNN, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, b optimizer.update(grads) ``` -The [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.experimental.nnx/transforms.html#flax.experimental.nnx.jit) decorator traces the `train_step` function for just-in-time compilation with +The [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) decorator traces the `train_step` function for just-in-time compilation with [XLA](https://www.tensorflow.org/xla), optimizing performance on hardware accelerators. `nnx.jit` is similar to [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit), except it can transforms functions that contain NNX objects as inputs and outputs. 
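The tutorial hunks above describe `nnx.jit`: like `jax.jit`, except it can transform functions that take and return NNX objects and it propagates their in-place state updates. A minimal sketch of that pattern under the renamed import path; the `Model` class, shapes, and data are invented for illustration, while the `nnx.*` names (`Rngs`, `Linear`, `Optimizer`, `jit`, `value_and_grad`) are the ones this diff's own examples use:

```python
import jax.numpy as jnp
import optax
from flax import nnx  # previously: from flax.experimental import nnx


class Model(nnx.Module):
  def __init__(self, rngs: nnx.Rngs):
    self.linear = nnx.Linear(2, 1, rngs=rngs)

  def __call__(self, x):
    return self.linear(x)


model = Model(nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.sgd(1e-2))


@nnx.jit  # like jax.jit, but NNX objects may be inputs and outputs
def train_step(model, optimizer, x, y):
  def loss_fn(model):
    return jnp.mean((model(x) - y) ** 2)

  loss, grads = nnx.value_and_grad(loss_fn)(model)
  optimizer.update(grads)  # mutates optimizer and model state in place
  return loss


loss = train_step(model, optimizer, jnp.ones((4, 2)), jnp.zeros((4, 1)))
```

Because `train_step` mutates its arguments, no explicit state threading is needed, which is exactly the contract the `nnx.jit` docs in the hunk above spell out.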
diff --git a/docs/experimental/nnx/nnx_basics.ipynb b/docs/nnx/nnx_basics.ipynb similarity index 99% rename from docs/experimental/nnx/nnx_basics.ipynb rename to docs/nnx/nnx_basics.ipynb index 3a90e0d836..d033184421 100644 --- a/docs/experimental/nnx/nnx_basics.ipynb +++ b/docs/nnx/nnx_basics.ipynb @@ -23,7 +23,7 @@ "metadata": {}, "outputs": [], "source": [ - "from flax.experimental import nnx\n", + "from flax import nnx\n", "import jax\n", "import jax.numpy as jnp" ] diff --git a/docs/experimental/nnx/nnx_basics.md b/docs/nnx/nnx_basics.md similarity index 99% rename from docs/experimental/nnx/nnx_basics.md rename to docs/nnx/nnx_basics.md index c27ae068ac..b1c8841b9a 100644 --- a/docs/experimental/nnx/nnx_basics.md +++ b/docs/nnx/nnx_basics.md @@ -21,7 +21,7 @@ Despite its simplified implementation, NNX supports the same powerful design pat that have allowed Linen to scale effectively to large codebases. ```{code-cell} ipython3 -from flax.experimental import nnx +from flax import nnx import jax import jax.numpy as jnp ``` diff --git a/docs/experimental/nnx/transforms.rst b/docs/nnx/transforms.rst similarity index 96% rename from docs/experimental/nnx/transforms.rst rename to docs/nnx/transforms.rst index 9f35afcc26..76e807f241 100644 --- a/docs/experimental/nnx/transforms.rst +++ b/docs/nnx/transforms.rst @@ -9,7 +9,7 @@ First, let's set up imports and generate some dummy data: .. testcode:: NNX, JAX - from flax.experimental import nnx + from flax import nnx import jax x = jax.random.normal(jax.random.key(0), (1, 2)) @@ -24,7 +24,7 @@ even those whose state will be mutated, whereas they aren't recognized in JAX tr Therefore NNX transformations can transform functions that are not pure and make mutations and side-effects. -NNX's `Functional API `_ +NNX's `Functional API `_ provides a way to convert graph structures to pytrees and back, by doing this at every function boundary you can effectively use graph structures with any JAX transform and propagate state updates in a way consistent with functional purity. NNX custom transforms such as ``nnx.jit`` and ``nnx.grad`` diff --git a/flax/experimental/nnx.py b/flax/experimental/nnx.py new file mode 100644 index 0000000000..4899914293 --- /dev/null +++ b/flax/experimental/nnx.py @@ -0,0 +1,22 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl import logging + +from flax.nnx import * + + +logging.warning( + "Using 'flax.experimental.nnx' is deprecated. Please use 'flax.nnx' instead." 
+) \ No newline at end of file diff --git a/flax/experimental/nnx/.gitignore b/flax/nnx/.gitignore similarity index 100% rename from flax/experimental/nnx/.gitignore rename to flax/nnx/.gitignore diff --git a/flax/experimental/nnx/README.md b/flax/nnx/README.md similarity index 73% rename from flax/experimental/nnx/README.md rename to flax/nnx/README.md index cc00e1358e..854e0971d0 100644 --- a/flax/experimental/nnx/README.md +++ b/flax/nnx/README.md @@ -2,7 +2,7 @@ # NNX -_**N**eural **N**etworks for JA**X**_ - | [docs](https://flax.readthedocs.io/en/latest/experimental/nnx/index.html) | +_**N**eural **N**etworks for JA**X**_ - | [docs](https://flax.readthedocs.io/en/latest/nnx/index.html) | NNX is a JAX-based neural network library that focuses on providing the best development experience to make building and experimenting with neural networks as easy and intuitive as possible. @@ -28,7 +28,7 @@ a Module system that uses standard Python classes, and a set of transforms that JAX to handle objects. ```python -from flax.experimental import nnx +from flax import nnx import optax class Model(nnx.Module): @@ -58,7 +58,7 @@ def train_step(model, optimizer, x, y): return loss ``` -To learn more about the `Module` abstraction, check out our [NNX Basics](https://flax.readthedocs.io/en/latest/experimental/nnx/nnx_basics.html#) guide. +To learn more about the `Module` abstraction, check out our [NNX Basics](https://flax.readthedocs.io/en/latest/nnx/nnx_basics.html#) guide. ## Installation @@ -69,10 +69,10 @@ pip install git+https://github.com/google/flax.git ### Examples -* [LM1B](https://github.com/google/flax/tree/main/flax/experimental/nnx/examples/lm1b): A language model trained on the 1 Billion Word Benchmark dataset. +* [LM1B](https://github.com/google/flax/tree/main/flax/nnx/examples/lm1b): A language model trained on the 1 Billion Word Benchmark dataset. #### Toy Examples -* [Basic Example](https://github.com/google/flax/tree/main/flax/experimental/nnx/examples/toy_examples/02_lifted_transforms.py): Shows how to train a simple model using NNX. -* [Using the Functional API](https://github.com/google/flax/tree/main/flax/experimental/nnx/examples/toy_examples/01_functional_api.py): Shows how to train a simple model using the functional API. -* [Training a VAE](https://github.com/google/flax/tree/main/flax/experimental/nnx/examples/toy_examples/05_vae.py): Shows how to train a VAE on the binarized MNIST dataset. -* [Scan over layers](https://github.com/google/flax/tree/main/flax/experimental/nnx/examples/toy_examples/06_scan_over_layers.py): An contrived example that implements scan over layers with dropout and a share BatcNorm layer to showcase how lifted transforms can be implemented. It uses the functional API along with `jax.vmap` and `jax.lax.scan`. +* [Basic Example](https://github.com/google/flax/tree/main/flax/nnx/examples/toy_examples/02_lifted_transforms.py): Shows how to train a simple model using NNX. +* [Using the Functional API](https://github.com/google/flax/tree/main/flax/nnx/examples/toy_examples/01_functional_api.py): Shows how to train a simple model using the functional API. +* [Training a VAE](https://github.com/google/flax/tree/main/flax/nnx/examples/toy_examples/05_vae.py): Shows how to train a VAE on the binarized MNIST dataset. 
+* [Scan over layers](https://github.com/google/flax/tree/main/flax/nnx/examples/toy_examples/06_scan_over_layers.py): A contrived example that implements scan over layers with dropout and a shared BatchNorm layer to showcase how lifted transforms can be implemented. It uses the functional API along with `jax.vmap` and `jax.lax.scan`. diff --git a/flax/experimental/nnx/__init__.py b/flax/nnx/__init__.py similarity index 100% rename from flax/experimental/nnx/__init__.py rename to flax/nnx/__init__.py diff --git a/flax/experimental/nnx/docs/blog.md b/flax/nnx/docs/blog.md similarity index 100% rename from flax/experimental/nnx/docs/blog.md rename to flax/nnx/docs/blog.md diff --git a/flax/experimental/nnx/docs/demo.ipynb b/flax/nnx/docs/demo.ipynb similarity index 99% rename from flax/experimental/nnx/docs/demo.ipynb rename to flax/nnx/docs/demo.ipynb index ae71ad479a..a2521ef10f 100644 --- a/flax/experimental/nnx/docs/demo.ipynb +++ b/flax/nnx/docs/demo.ipynb @@ -17,7 +17,7 @@ "source": [ "import jax\n", "from jax import numpy as jnp\n", - "from flax.experimental import nnx" + "from flax import nnx" ] }, { diff --git a/flax/experimental/nnx/docs/demo.md b/flax/nnx/docs/demo.md similarity index 99% rename from flax/experimental/nnx/docs/demo.md rename to flax/nnx/docs/demo.md index 5d02e5da7c..f507f9c482 100644 --- a/flax/experimental/nnx/docs/demo.md +++ b/flax/nnx/docs/demo.md @@ -13,7 +13,7 @@ jupytext: ```{code-cell} ipython3 import jax from jax import numpy as jnp -from flax.experimental import nnx +from flax import nnx ``` ### [1] NNX is Pythonic diff --git a/flax/experimental/nnx/docs/images/stateful-transforms.png b/flax/nnx/docs/images/stateful-transforms.png similarity index 100% rename from flax/experimental/nnx/docs/images/stateful-transforms.png rename to flax/nnx/docs/images/stateful-transforms.png diff --git a/flax/experimental/nnx/docs/quick_start.ipynb b/flax/nnx/docs/quick_start.ipynb similarity index 99% rename from flax/experimental/nnx/docs/quick_start.ipynb rename to flax/nnx/docs/quick_start.ipynb index fc617db8a8..df64361b43 100644 --- a/flax/experimental/nnx/docs/quick_start.ipynb +++ b/flax/nnx/docs/quick_start.ipynb @@ -146,7 +146,7 @@ "source": [ "import jax\n", "import jax.numpy as jnp\n", - "from flax.experimental import nnx\n", + "from flax import nnx\n", "\n", "\n", "class CNN(nnx.Module):\n", diff --git a/flax/experimental/nnx/docs/tiny_nnx.ipynb b/flax/nnx/docs/tiny_nnx.ipynb similarity index 100% rename from flax/experimental/nnx/docs/tiny_nnx.ipynb rename to flax/nnx/docs/tiny_nnx.ipynb diff --git a/flax/experimental/nnx/docs/why.ipynb b/flax/nnx/docs/why.ipynb similarity index 99% rename from flax/experimental/nnx/docs/why.ipynb rename to flax/nnx/docs/why.ipynb index 04cad17dab..46caf8c4e8 100644 --- a/flax/experimental/nnx/docs/why.ipynb +++ b/flax/nnx/docs/why.ipynb @@ -7,7 +7,7 @@ "# Why NNX?\n", "\n", "\n", - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/flax/blob/main/flax/experimental/nnx/docs/why.ipynb)\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/flax/blob/main/flax/nnx/docs/why.ipynb)\n", "\n", "Four years ago we developed the Flax \"Linen\" API to support modeling research on JAX, with a focus on scaling scaling and performance. 
We've learned a lot from our users over these years.\n", "\n", @@ -25,8 +25,8 @@ "\n", "We'd love to hear from any of our users about their thoughts on these ideas.\n", "\n", - "[[nnx on github](https://github.com/google/flax/tree/main/flax/experimental/nnx)]\n", - "[[this doc on github](https://github.com/google/flax/blob/main/flax/experimental/nnx/docs/why.ipynb)]" + "[[nnx on github](https://github.com/google/flax/tree/main/flax/nnx)]\n", + "[[this doc on github](https://github.com/google/flax/blob/main/flax/nnx/docs/why.ipynb)]" ] }, { @@ -39,7 +39,7 @@ "from functools import partial\n", "import jax\n", "from jax import random, numpy as jnp\n", - "from flax.experimental import nnx" + "from flax import nnx" ] }, { diff --git a/flax/experimental/nnx/docs/why.md b/flax/nnx/docs/why.md similarity index 98% rename from flax/experimental/nnx/docs/why.md rename to flax/nnx/docs/why.md index 3dce4ad63d..07142c0f49 100644 --- a/flax/experimental/nnx/docs/why.md +++ b/flax/nnx/docs/why.md @@ -11,7 +11,7 @@ jupytext: # Why NNX? -[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/flax/blob/main/flax/experimental/nnx/docs/why.ipynb) +[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/flax/blob/main/flax/nnx/docs/why.ipynb) Four years ago we developed the Flax "Linen" API to support modeling research on JAX, with a focus on scaling scaling and performance. We've learned a lot from our users over these years. @@ -29,15 +29,15 @@ NNX is an attempt to keep the features that made Linen useful while introducing We'd love to hear from any of our users about their thoughts on these ideas. -[[nnx on github](https://github.com/google/flax/tree/main/flax/experimental/nnx)] -[[this doc on github](https://github.com/google/flax/blob/main/flax/experimental/nnx/docs/why.ipynb)] +[[nnx on github](https://github.com/google/flax/tree/main/flax/nnx)] +[[this doc on github](https://github.com/google/flax/blob/main/flax/nnx/docs/why.ipynb)] ```{code-cell} ! 
pip install -U git+https://github.com/google/flax.git from functools import partial import jax from jax import random, numpy as jnp -from flax.experimental import nnx +from flax import nnx ``` ### NNX is Pythonic diff --git a/flax/experimental/nnx/examples/lm1b/README.md b/flax/nnx/examples/lm1b/README.md similarity index 100% rename from flax/experimental/nnx/examples/lm1b/README.md rename to flax/nnx/examples/lm1b/README.md diff --git a/flax/experimental/nnx/examples/lm1b/configs/default.py b/flax/nnx/examples/lm1b/configs/default.py similarity index 100% rename from flax/experimental/nnx/examples/lm1b/configs/default.py rename to flax/nnx/examples/lm1b/configs/default.py diff --git a/flax/experimental/nnx/examples/lm1b/input_pipeline.py b/flax/nnx/examples/lm1b/input_pipeline.py similarity index 100% rename from flax/experimental/nnx/examples/lm1b/input_pipeline.py rename to flax/nnx/examples/lm1b/input_pipeline.py diff --git a/flax/experimental/nnx/examples/lm1b/input_pipeline_test.py b/flax/nnx/examples/lm1b/input_pipeline_test.py similarity index 98% rename from flax/experimental/nnx/examples/lm1b/input_pipeline_test.py rename to flax/nnx/examples/lm1b/input_pipeline_test.py index 4ead911fe3..e6287fac07 100644 --- a/flax/experimental/nnx/examples/lm1b/input_pipeline_test.py +++ b/flax/nnx/examples/lm1b/input_pipeline_test.py @@ -46,7 +46,7 @@ def _get_datasets(self): vocab_path = os.path.join(tempfile.mkdtemp(), 'sentencepiece_model') # Go two directories up to the root of the flax directory. - flax_root_dir = pathlib.Path(__file__).parents[5] + flax_root_dir = pathlib.Path(__file__).parents[4] data_dir = str(flax_root_dir) + '/.tfds/metadata' # pylint: disable=unused-variable with tfds.testing.mock_data(num_examples=128, data_dir=data_dir): diff --git a/flax/experimental/nnx/examples/lm1b/main.py b/flax/nnx/examples/lm1b/main.py similarity index 100% rename from flax/experimental/nnx/examples/lm1b/main.py rename to flax/nnx/examples/lm1b/main.py diff --git a/flax/experimental/nnx/examples/lm1b/models.py b/flax/nnx/examples/lm1b/models.py similarity index 99% rename from flax/experimental/nnx/examples/lm1b/models.py rename to flax/nnx/examples/lm1b/models.py index 1731ec7f31..bb80e1eee8 100644 --- a/flax/experimental/nnx/examples/lm1b/models.py +++ b/flax/nnx/examples/lm1b/models.py @@ -32,7 +32,7 @@ import numpy as np from jax import lax -from flax.experimental import nnx +from flax import nnx from configs import default Shape = tuple[int, ...] 
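All of these import-site rewrites are mechanical because the new `flax/experimental/nnx.py` shim introduced earlier in this diff star-re-exports `flax.nnx` and only logs a deprecation warning. A hedged sketch of what that means for downstream code during the transition, assuming the shim's `import *` picks up `flax.nnx`'s public names:

```python
# The legacy path keeps working: it logs a deprecation warning on first
# import and re-exports the very same objects as the new canonical path.
from flax import nnx                             # new canonical path
from flax.experimental import nnx as legacy_nnx  # deprecated shim

assert legacy_nnx.Linear is nnx.Linear
assert legacy_nnx.Rngs is nnx.Rngs
```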
diff --git a/flax/experimental/nnx/examples/lm1b/models_test.py b/flax/nnx/examples/lm1b/models_test.py similarity index 99% rename from flax/experimental/nnx/examples/lm1b/models_test.py rename to flax/nnx/examples/lm1b/models_test.py index 76296ae503..cc377eb333 100644 --- a/flax/experimental/nnx/examples/lm1b/models_test.py +++ b/flax/nnx/examples/lm1b/models_test.py @@ -27,7 +27,7 @@ from jax import random from flax import traverse_util -from flax.experimental import nnx +from flax import nnx from configs import default from models import TransformerConfig, TransformerLM from utils import HasCache @@ -35,7 +35,7 @@ jax.config.update('jax_disable_most_optimizations', True) # add project_root to import lm1b Linen model -project_root = str(Path(__file__).absolute().parents[5]) +project_root = str(Path(__file__).absolute().parents[4]) sys.path.append(project_root) from examples.lm1b.models import TransformerLM as TransformerLinen # type: ignore[import-error] diff --git a/flax/experimental/nnx/examples/lm1b/requirements.txt b/flax/nnx/examples/lm1b/requirements.txt similarity index 100% rename from flax/experimental/nnx/examples/lm1b/requirements.txt rename to flax/nnx/examples/lm1b/requirements.txt diff --git a/flax/experimental/nnx/examples/lm1b/temperature_sampler.py b/flax/nnx/examples/lm1b/temperature_sampler.py similarity index 100% rename from flax/experimental/nnx/examples/lm1b/temperature_sampler.py rename to flax/nnx/examples/lm1b/temperature_sampler.py diff --git a/flax/experimental/nnx/examples/lm1b/temperature_sampler_test.py b/flax/nnx/examples/lm1b/temperature_sampler_test.py similarity index 100% rename from flax/experimental/nnx/examples/lm1b/temperature_sampler_test.py rename to flax/nnx/examples/lm1b/temperature_sampler_test.py diff --git a/flax/experimental/nnx/examples/lm1b/tokenizer.py b/flax/nnx/examples/lm1b/tokenizer.py similarity index 100% rename from flax/experimental/nnx/examples/lm1b/tokenizer.py rename to flax/nnx/examples/lm1b/tokenizer.py diff --git a/flax/experimental/nnx/examples/lm1b/train.py b/flax/nnx/examples/lm1b/train.py similarity index 99% rename from flax/experimental/nnx/examples/lm1b/train.py rename to flax/nnx/examples/lm1b/train.py index ed3f5986d0..a137b9da12 100644 --- a/flax/experimental/nnx/examples/lm1b/train.py +++ b/flax/nnx/examples/lm1b/train.py @@ -42,7 +42,7 @@ from utils import HasCache, TrainState from flax import linen as nn -from flax.experimental import nnx +from flax import nnx from flax.training import checkpoints, common_utils @@ -605,9 +605,7 @@ def constructor(config: models.TransformerConfig, key: jax.Array): lambda x: x / denominator, metrics_sums ) # pylint: disable=cell-var-from-loop summary['learning_rate'] = lr - summary['perplexity'] = jnp.clip( - jnp.exp(summary['loss']), max=1.0e4 - ) + summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), max=1.0e4) summary = {'train_' + k: v for k, v in summary.items()} writer.write_scalars(step, summary) train_metrics = [] diff --git a/flax/experimental/nnx/examples/lm1b/train_test.py b/flax/nnx/examples/lm1b/train_test.py similarity index 97% rename from flax/experimental/nnx/examples/lm1b/train_test.py rename to flax/nnx/examples/lm1b/train_test.py index 9040c4f268..1f135048dc 100644 --- a/flax/experimental/nnx/examples/lm1b/train_test.py +++ b/flax/nnx/examples/lm1b/train_test.py @@ -59,7 +59,7 @@ def test_train_and_evaluate(self): workdir = tempfile.mkdtemp() # Go two directories up to the root of the flax directory. 
- flax_root_dir = pathlib.Path(__file__).parents[5] + flax_root_dir = pathlib.Path(__file__).parents[4] data_dir = str(flax_root_dir) + '/.tfds/metadata' # pylint: disable=unused-variable print('data_dir: ', data_dir) diff --git a/flax/experimental/nnx/examples/lm1b/utils.py b/flax/nnx/examples/lm1b/utils.py similarity index 99% rename from flax/experimental/nnx/examples/lm1b/utils.py rename to flax/nnx/examples/lm1b/utils.py index 1bf2d7d8c6..d2afc3c3bc 100644 --- a/flax/experimental/nnx/examples/lm1b/utils.py +++ b/flax/nnx/examples/lm1b/utils.py @@ -25,7 +25,7 @@ from configs import default from models import TransformerConfig, TransformerLM -from flax.experimental import nnx +from flax import nnx from flax.training import train_state Dtype = Any @@ -38,8 +38,7 @@ class TrainState(train_state.TrainState): @runtime_checkable class HasCache(Protocol): - def init_cache(self, input_shape: Shape, dtype: Dtype = jnp.float32): - ... + def init_cache(self, input_shape: Shape, dtype: Dtype = jnp.float32): ... # Mesh utils. diff --git a/flax/experimental/nnx/examples/toy_examples/01_functional_api.py b/flax/nnx/examples/toy_examples/01_functional_api.py similarity index 98% rename from flax/experimental/nnx/examples/toy_examples/01_functional_api.py rename to flax/nnx/examples/toy_examples/01_functional_api.py index bd6451555e..8f90a24ef6 100644 --- a/flax/experimental/nnx/examples/toy_examples/01_functional_api.py +++ b/flax/nnx/examples/toy_examples/01_functional_api.py @@ -18,7 +18,7 @@ import matplotlib.pyplot as plt import numpy as np -from flax.experimental import nnx +from flax import nnx X = np.linspace(0, 1, 100)[:, None] Y = 0.8 * X**2 + 0.1 + np.random.normal(0, 0.1, size=X.shape) diff --git a/flax/experimental/nnx/examples/toy_examples/02_lifted_transforms.py b/flax/nnx/examples/toy_examples/02_lifted_transforms.py similarity index 98% rename from flax/experimental/nnx/examples/toy_examples/02_lifted_transforms.py rename to flax/nnx/examples/toy_examples/02_lifted_transforms.py index a29efe153c..bb2238f7a2 100644 --- a/flax/experimental/nnx/examples/toy_examples/02_lifted_transforms.py +++ b/flax/nnx/examples/toy_examples/02_lifted_transforms.py @@ -19,7 +19,7 @@ import numpy as np import optax -from flax.experimental import nnx +from flax import nnx X = np.linspace(0, 1, 100)[:, None] Y = 0.8 * X**2 + 0.1 + np.random.normal(0, 0.1, size=X.shape) @@ -62,6 +62,7 @@ def __call__(self, x): tx = optax.sgd(1e-3) optimizer = nnx.Optimizer(model, tx) + @nnx.jit def train_step(model: MLP, optimizer: nnx.Optimizer, batch): x, y = batch diff --git a/flax/experimental/nnx/examples/toy_examples/05_vae.py b/flax/nnx/examples/toy_examples/05_vae.py similarity index 99% rename from flax/experimental/nnx/examples/toy_examples/05_vae.py rename to flax/nnx/examples/toy_examples/05_vae.py index 895dcd894b..7819c8dbec 100644 --- a/flax/experimental/nnx/examples/toy_examples/05_vae.py +++ b/flax/nnx/examples/toy_examples/05_vae.py @@ -22,7 +22,7 @@ import optax from datasets import load_dataset -from flax.experimental import nnx +from flax import nnx np.random.seed(42) latent_size = 32 diff --git a/flax/experimental/nnx/examples/toy_examples/06_scan_over_layers.py b/flax/nnx/examples/toy_examples/06_scan_over_layers.py similarity index 98% rename from flax/experimental/nnx/examples/toy_examples/06_scan_over_layers.py rename to flax/nnx/examples/toy_examples/06_scan_over_layers.py index 9a2b01727c..ad2b2edcea 100644 --- a/flax/experimental/nnx/examples/toy_examples/06_scan_over_layers.py +++ 
b/flax/nnx/examples/toy_examples/06_scan_over_layers.py @@ -17,7 +17,7 @@ import jax import jax.numpy as jnp -from flax.experimental import nnx +from flax import nnx class Block(nnx.Module): diff --git a/flax/experimental/nnx/examples/toy_examples/08_save_load_checkpoints.py b/flax/nnx/examples/toy_examples/08_save_load_checkpoints.py similarity index 98% rename from flax/experimental/nnx/examples/toy_examples/08_save_load_checkpoints.py rename to flax/nnx/examples/toy_examples/08_save_load_checkpoints.py index 281a290f1f..ea69079640 100644 --- a/flax/experimental/nnx/examples/toy_examples/08_save_load_checkpoints.py +++ b/flax/nnx/examples/toy_examples/08_save_load_checkpoints.py @@ -18,7 +18,7 @@ import jax.numpy as jnp import orbax.checkpoint as orbax -from flax.experimental import nnx +from flax import nnx class MLP(nnx.Module): diff --git a/flax/experimental/nnx/examples/toy_examples/09_parameter_surgery.py b/flax/nnx/examples/toy_examples/09_parameter_surgery.py similarity index 98% rename from flax/experimental/nnx/examples/toy_examples/09_parameter_surgery.py rename to flax/nnx/examples/toy_examples/09_parameter_surgery.py index c7f5dd07f7..11a785aaa6 100644 --- a/flax/experimental/nnx/examples/toy_examples/09_parameter_surgery.py +++ b/flax/nnx/examples/toy_examples/09_parameter_surgery.py @@ -15,7 +15,7 @@ import jax -from flax.experimental import nnx +from flax import nnx # lets pretend this function loads a pretrained model from a checkpoint diff --git a/flax/experimental/nnx/examples/toy_examples/requirements.txt b/flax/nnx/examples/toy_examples/requirements.txt similarity index 100% rename from flax/experimental/nnx/examples/toy_examples/requirements.txt rename to flax/nnx/examples/toy_examples/requirements.txt diff --git a/flax/experimental/nnx/nnx/__init__.py b/flax/nnx/nnx/__init__.py similarity index 100% rename from flax/experimental/nnx/nnx/__init__.py rename to flax/nnx/nnx/__init__.py diff --git a/flax/experimental/nnx/nnx/compat/__init__.py b/flax/nnx/nnx/compat/__init__.py similarity index 100% rename from flax/experimental/nnx/nnx/compat/__init__.py rename to flax/nnx/nnx/compat/__init__.py diff --git a/flax/experimental/nnx/nnx/compat/module.py b/flax/nnx/nnx/compat/module.py similarity index 94% rename from flax/experimental/nnx/nnx/compat/module.py rename to flax/nnx/nnx/compat/module.py index c152811a18..0af4d38f58 100644 --- a/flax/experimental/nnx/nnx/compat/module.py +++ b/flax/nnx/nnx/compat/module.py @@ -21,13 +21,13 @@ import typing as tp import typing_extensions as tpe -from flax.experimental.nnx.nnx import graph, rnglib -import flax.experimental.nnx.nnx.module as nnx_module -from flax.experimental.nnx.nnx.proxy_caller import ( +from flax.nnx.nnx import graph, rnglib +import flax.nnx.nnx.module as nnx_module +from flax.nnx.nnx.proxy_caller import ( CallableProxy, DelayedAccessor, ) -from flax.experimental.nnx.nnx.object import Object +from flax.nnx.nnx.object import Object M = tp.TypeVar('M', bound='Module') F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any]) @@ -141,8 +141,8 @@ def init(self: M) -> M: Example:: - >>> from flax.experimental import nnx - >>> from flax.experimental.nnx import compat as nnc + >>> from flax import nnx + >>> from flax.nnx import compat as nnc >>> import jax >>> import jax.numpy as jnp ... 
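Several docs touched in this diff (`transforms.rst`, `why.md`, the toy examples above) point to the Functional API as the bridge between NNX objects and plain JAX transforms, and `graph.py` below is where its `split`/`merge` functions live. A short sketch of that round trip, with arbitrary layer shapes and input:

```python
import jax
import jax.numpy as jnp
from flax import nnx

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
graphdef, state = nnx.split(model)  # static graph definition + pytree of state


@jax.jit  # plain jax.jit: only pytree-compatible pieces cross the boundary
def forward(graphdef, state, x):
  model = nnx.merge(graphdef, state)  # rebuild the object inside the transform
  y = model(x)
  _, state = nnx.split(model)         # capture any state mutations
  return y, state


y, state = forward(graphdef, state, jnp.ones((1, 2)))
```

Doing this split/merge at every function boundary is what lets custom transforms such as `nnx.jit` propagate state updates while preserving functional purity underneath.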
diff --git a/flax/experimental/nnx/nnx/compat/wrappers.py b/flax/nnx/nnx/compat/wrappers.py similarity index 90% rename from flax/experimental/nnx/nnx/compat/wrappers.py rename to flax/nnx/nnx/compat/wrappers.py index 50a954e65b..27c889c411 100644 --- a/flax/experimental/nnx/nnx/compat/wrappers.py +++ b/flax/nnx/nnx/compat/wrappers.py @@ -16,12 +16,12 @@ import typing as tp from typing import Any -from flax.experimental import nnx +from flax import nnx from flax import linen -from flax.experimental.nnx.nnx import variables as variableslib -from flax.experimental.nnx.nnx.module import GraphDef, Module -from flax.experimental.nnx.nnx.rnglib import Rngs -from flax.experimental.nnx.nnx.state import State +from flax.nnx.nnx import variables as variableslib +from flax.nnx.nnx.module import GraphDef, Module +from flax.nnx.nnx.rnglib import Rngs +from flax.nnx.nnx.state import State M = tp.TypeVar('M', bound=Module) @@ -107,5 +107,4 @@ def __call__( return out -class NNXWrapper(linen.Module): - ... +class NNXWrapper(linen.Module): ... diff --git a/flax/experimental/nnx/nnx/errors.py b/flax/nnx/nnx/errors.py similarity index 100% rename from flax/experimental/nnx/nnx/errors.py rename to flax/nnx/nnx/errors.py diff --git a/flax/experimental/nnx/nnx/filterlib.py b/flax/nnx/nnx/filterlib.py similarity index 100% rename from flax/experimental/nnx/nnx/filterlib.py rename to flax/nnx/nnx/filterlib.py diff --git a/flax/experimental/nnx/nnx/graph.py b/flax/nnx/nnx/graph.py similarity index 99% rename from flax/experimental/nnx/nnx/graph.py rename to flax/nnx/nnx/graph.py index 957418513a..b1efb090a0 100644 --- a/flax/experimental/nnx/nnx/graph.py +++ b/flax/nnx/nnx/graph.py @@ -26,22 +26,22 @@ import numpy as np import typing_extensions as tpe -from flax.experimental.nnx.nnx import ( +from flax.nnx.nnx import ( filterlib, reprlib, ) -from flax.experimental.nnx.nnx.proxy_caller import ( +from flax.nnx.nnx.proxy_caller import ( ApplyCaller, CallableProxy, DelayedAccessor, ) -from flax.experimental.nnx.nnx.state import ( +from flax.nnx.nnx.state import ( FlatState, State, StateLeaf, is_state_leaf, ) -from flax.experimental.nnx.nnx.variables import Variable, VariableState +from flax.nnx.nnx.variables import Variable, VariableState from flax.typing import Key, PathParts A = tp.TypeVar('A') @@ -69,6 +69,7 @@ NodeLeaf = tp.Union[Variable[tp.Any], np.ndarray, jax.Array] + @dataclasses.dataclass class GraphContext(threading.local): update_context_stacks: defaultdict[str, list[UpdateContext]] = ( @@ -831,6 +832,12 @@ def _graph_update_static( node_impl.set_key(node, name, value_updates) + +# -------------------------------------------------------- +# UpdateContext +# -------------------------------------------------------- + + # -------------------------------------------------------- # UpdateContext # -------------------------------------------------------- @@ -987,6 +994,7 @@ def merge( jax.tree_util.register_static(UpdateContext) + @dataclasses.dataclass class UpdateContextManager: tag: str @@ -1054,7 +1062,7 @@ def update_context(tag: str): Here is a simple example showing the use of ``update_context``:: - >>> from flax.experimental import nnx + >>> from flax import nnx ... >>> m1 = nnx.Dict({}) >>> with nnx.update_context('example') as ctx: @@ -1078,7 +1086,7 @@ def update_context(tag: str): current active context. current_update_context can be used as a way of accessing the current active context without having to pass it as a capture:: - >>> from flax.experimental import nnx + >>> from flax import nnx ... 
>>> m1 = nnx.Dict({}) >>> @jax.jit @@ -1378,7 +1386,7 @@ def iter_graph(node: tp.Any, /) -> tp.Iterator[tuple[PathParts, tp.Any]]: root. Repeated nodes are visited only once. Leaves include static values. Example:: - >>> from flax.experimental import nnx + >>> from flax import nnx >>> import jax.numpy as jnp ... >>> class Linear(nnx.Module): diff --git a/flax/experimental/nnx/nnx/helpers.py b/flax/nnx/nnx/helpers.py similarity index 95% rename from flax/experimental/nnx/nnx/helpers.py rename to flax/nnx/nnx/helpers.py index 90901c1f88..5667e38df1 100644 --- a/flax/experimental/nnx/nnx/helpers.py +++ b/flax/nnx/nnx/helpers.py @@ -34,11 +34,11 @@ import jax.numpy as jnp import optax -from flax.experimental.nnx.nnx.graph import Key -from flax.experimental.nnx.nnx.module import GraphDef, Module -from flax.experimental.nnx.nnx.proxy_caller import ApplyCaller -from flax.experimental.nnx.nnx.rnglib import Rngs -from flax.experimental.nnx.nnx.state import State +from flax.nnx.nnx.graph import Key +from flax.nnx.nnx.module import GraphDef, Module +from flax.nnx.nnx.proxy_caller import ApplyCaller +from flax.nnx.nnx.rnglib import Rngs +from flax.nnx.nnx.state import State from flax.training.train_state import struct A = tp.TypeVar('A') diff --git a/flax/experimental/nnx/nnx/ids.py b/flax/nnx/nnx/ids.py similarity index 100% rename from flax/experimental/nnx/nnx/ids.py rename to flax/nnx/nnx/ids.py diff --git a/flax/experimental/nnx/nnx/module.py b/flax/nnx/nnx/module.py similarity index 95% rename from flax/experimental/nnx/nnx/module.py rename to flax/nnx/nnx/module.py index 1cb578d83a..6f99558e7e 100644 --- a/flax/experimental/nnx/nnx/module.py +++ b/flax/nnx/nnx/module.py @@ -19,14 +19,14 @@ import jax.tree_util as jtu -from flax.experimental.nnx.nnx import ( +from flax.nnx.nnx import ( filterlib, graph, ) -from flax.experimental.nnx.nnx import variables as variableslib -from flax.experimental.nnx.nnx.graph import GraphDef -from flax.experimental.nnx.nnx.object import Object, ObjectMeta -from flax.experimental.nnx.nnx.state import State, StateLeaf +from flax.nnx.nnx import variables as variableslib +from flax.nnx.nnx.graph import GraphDef +from flax.nnx.nnx.object import Object, ObjectMeta +from flax.nnx.nnx.state import State, StateLeaf from flax.typing import Path, PathParts A = tp.TypeVar('A') @@ -83,7 +83,7 @@ def iter_modules(self) -> tp.Iterator[tuple[PathParts, Module]]: Example:: - >>> from flax.experimental import nnx + >>> from flax import nnx ... >>> class Block(nnx.Module): ... def __init__(self, din, dout, *, rngs: nnx.Rngs): @@ -116,7 +116,7 @@ def set_attributes( Example:: - >>> from flax.experimental import nnx + >>> from flax import nnx ... >>> class Block(nnx.Module): ... def __init__(self, din, dout, *, rngs: nnx.Rngs): @@ -174,7 +174,7 @@ def train(self, **attributes): Example:: - >>> from flax.experimental import nnx + >>> from flax import nnx ... >>> class Block(nnx.Module): ... def __init__(self, din, dout, *, rngs: nnx.Rngs): @@ -210,7 +210,7 @@ def eval(self, **attributes): Example:: - >>> from flax.experimental import nnx + >>> from flax import nnx ... >>> class Block(nnx.Module): ... 
def __init__(self, din, dout, *, rngs: nnx.Rngs): diff --git a/flax/experimental/nnx/nnx/nn/__init__.py b/flax/nnx/nnx/nn/__init__.py similarity index 100% rename from flax/experimental/nnx/nnx/nn/__init__.py rename to flax/nnx/nnx/nn/__init__.py diff --git a/flax/experimental/nnx/nnx/nn/activations.py b/flax/nnx/nnx/nn/activations.py similarity index 100% rename from flax/experimental/nnx/nnx/nn/activations.py rename to flax/nnx/nnx/nn/activations.py diff --git a/flax/experimental/nnx/nnx/nn/attention.py b/flax/nnx/nnx/nn/attention.py similarity index 98% rename from flax/experimental/nnx/nnx/nn/attention.py rename to flax/nnx/nnx/nn/attention.py index 8e66567c6e..d66400bc30 100644 --- a/flax/experimental/nnx/nnx/nn/attention.py +++ b/flax/nnx/nnx/nn/attention.py @@ -23,16 +23,16 @@ import jax.numpy as jnp from jax import lax, random -from flax.experimental import nnx -from flax.experimental.nnx.nnx import rnglib -from flax.experimental.nnx.nnx.module import Module, first_from -from flax.experimental.nnx.nnx.nn import initializers -from flax.experimental.nnx.nnx.nn.dtypes import promote_dtype -from flax.experimental.nnx.nnx.nn.linear import ( +from flax import nnx +from flax.nnx.nnx import rnglib +from flax.nnx.nnx.module import Module, first_from +from flax.nnx.nnx.nn import initializers +from flax.nnx.nnx.nn.dtypes import promote_dtype +from flax.nnx.nnx.nn.linear import ( LinearGeneral, default_kernel_init, ) -from flax.experimental.nnx.nnx.nn.normalization import LayerNorm +from flax.nnx.nnx.nn.normalization import LayerNorm from flax.typing import ( Dtype, Shape, @@ -40,6 +40,7 @@ PrecisionLike, DotGeneralT, ) + Array = jax.Array @@ -590,7 +591,7 @@ def init_cache(self, input_shape: Shape, dtype: Dtype = jnp.float32): Example usage:: - >>> from flax.experimental import nnx + >>> from flax import nnx >>> import jax.numpy as jnp ... >>> rngs = nnx.Rngs(42) diff --git a/flax/experimental/nnx/nnx/nn/dtypes.py b/flax/nnx/nnx/nn/dtypes.py similarity index 100% rename from flax/experimental/nnx/nnx/nn/dtypes.py rename to flax/nnx/nnx/nn/dtypes.py diff --git a/flax/experimental/nnx/nnx/nn/initializers.py b/flax/nnx/nnx/nn/initializers.py similarity index 95% rename from flax/experimental/nnx/nnx/nn/initializers.py rename to flax/nnx/nnx/nn/initializers.py index ce73718244..2a44d7147a 100644 --- a/flax/experimental/nnx/nnx/nn/initializers.py +++ b/flax/nnx/nnx/nn/initializers.py @@ -42,7 +42,7 @@ def zeros_init() -> Initializer: """Builds an initializer that returns a constant array full of zeros. >>> import jax, jax.numpy as jnp - >>> from flax.experimental.nnx import initializers + >>> from flax.nnx import initializers >>> zeros_initializer = initializers.zeros_init() >>> zeros_initializer(jax.random.key(42), (2, 3), jnp.float32) Array([[0., 0., 0.], @@ -55,7 +55,7 @@ def ones_init() -> Initializer: """Builds an initializer that returns a constant array full of ones. 
>>> import jax, jax.numpy as jnp - >>> from flax.experimental.nnx import initializers + >>> from flax.nnx import initializers >>> ones_initializer = initializers.ones_init() >>> ones_initializer(jax.random.key(42), (3, 2), jnp.float32) Array([[1., 1.], diff --git a/flax/experimental/nnx/nnx/nn/linear.py b/flax/nnx/nnx/nn/linear.py similarity index 97% rename from flax/experimental/nnx/nnx/nn/linear.py rename to flax/nnx/nnx/nn/linear.py index 0c8f0cd911..696aeac54d 100644 --- a/flax/experimental/nnx/nnx/nn/linear.py +++ b/flax/nnx/nnx/nn/linear.py @@ -36,10 +36,10 @@ import opt_einsum from flax.core.frozen_dict import FrozenDict -from flax.experimental import nnx -from flax.experimental.nnx.nnx import rnglib, variables -from flax.experimental.nnx.nnx.module import Module, first_from -from flax.experimental.nnx.nnx.nn import dtypes, initializers +from flax import nnx +from flax.nnx.nnx import rnglib, variables +from flax.nnx.nnx.module import Module, first_from +from flax.nnx.nnx.nn import dtypes, initializers from flax.typing import ( Dtype, Shape, @@ -110,7 +110,7 @@ class LinearGeneral(Module): Example usage:: - >>> from flax.experimental import nnx + >>> from flax import nnx >>> import jax, jax.numpy as jnp ... >>> # equivalent to `nnx.Linear(2, 4)` @@ -270,7 +270,7 @@ def __call__(self, inputs: Array) -> Array: contract_ind = tuple(range(n_batch_dims, n_axis + n_batch_dims)) inputs, kernel, bias = dtypes.promote_dtype( - (inputs, kernel, bias), dtype=self.dtype + (inputs, kernel, bias), dtype=self.dtype ) if self.dot_general_cls is not None: @@ -355,7 +355,7 @@ def __call__(self, inputs: Array) -> Array: bias = self.bias.value inputs, kernel, bias = dtypes.promote_dtype( - (inputs, kernel, bias), dtype=self.dtype + (inputs, kernel, bias), dtype=self.dtype ) y = self.dot_general( inputs, @@ -373,7 +373,7 @@ class Einsum(Module): Example usage:: - >>> from flax.experimental import nnx + >>> from flax import nnx >>> import jax.numpy as jnp ... >>> layer = nnx.Einsum('nta,hab->nthb', (8, 2, 4), (8, 4), rngs=nnx.Rngs(0)) @@ -463,12 +463,12 @@ def __call__( self._einsum_str_check(einsum_str) inputs, kernel, bias = dtypes.promote_dtype( - ( - inputs, - self.kernel.value, - self.bias.value if self.bias is not None else self.bias, - ), - dtype=self.dtype, + ( + inputs, + self.kernel.value, + self.bias.value if self.bias is not None else self.bias, + ), + dtype=self.dtype, ) y = jnp.einsum(einsum_str, inputs, kernel, precision=self.precision) @@ -706,7 +706,7 @@ def maybe_broadcast( bias = self.bias.value inputs, kernel, bias = dtypes.promote_dtype( - (inputs, kernel, bias), dtype=self.dtype + (inputs, kernel, bias), dtype=self.dtype ) y = self.conv_general_dilated( @@ -730,6 +730,7 @@ def maybe_broadcast( y = jnp.reshape(y, output_shape) return y + class ConvTranspose(Module): # features: int # kernel_size: Union[int, Sequence[int]] @@ -869,7 +870,7 @@ def maybe_broadcast( bias = self.bias.value if self.bias is not None else None inputs, kernel, bias = dtypes.promote_dtype( - (inputs, kernel, bias), dtype=self.dtype + (inputs, kernel, bias), dtype=self.dtype ) y = lax.conv_transpose( @@ -997,7 +998,7 @@ def __call__(self, inputs: Array) -> Array: # Use take because fancy indexing numpy arrays with JAX indices does not # work correctly. 
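# For reference, the gather that the comment above describes boils down to a
# plain `jnp.take` along axis 0 (standalone sketch; the array names here are
# illustrative, not part of this patch):
#
#   >>> import jax.numpy as jnp
#   >>> table = jnp.arange(12.0).reshape(4, 3)   # (num_embeddings, features)
#   >>> tokens = jnp.array([0, 2, 2])
#   >>> jnp.take(table, tokens, axis=0).shape    # one row per token id
#   (3, 3)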
(embedding,) = dtypes.promote_dtype( - (self.embedding.value,), dtype=self.dtype, inexact=False + (self.embedding.value,), dtype=self.dtype, inexact=False ) if self.num_embeddings == 1: return jnp.where( @@ -1022,6 +1023,6 @@ def attend(self, query: Array) -> Array: in NLP models. """ query, embedding = dtypes.promote_dtype( - (query, self.embedding.value), dtype=self.dtype + (query, self.embedding.value), dtype=self.dtype ) return jnp.dot(query, embedding.T) diff --git a/flax/experimental/nnx/nnx/nn/lora.py b/flax/nnx/nnx/nn/lora.py similarity index 89% rename from flax/experimental/nnx/nnx/nn/lora.py rename to flax/nnx/nnx/nn/lora.py index 2ac217efdd..96d495db53 100644 --- a/flax/experimental/nnx/nnx/nn/lora.py +++ b/flax/nnx/nnx/nn/lora.py @@ -32,14 +32,11 @@ import jax import jax.numpy as jnp -from flax.experimental import nnx -from flax.experimental.nnx.nnx import rnglib, variables -from flax.experimental.nnx.nnx.module import Module -from flax.experimental.nnx.nnx.nn import initializers -from flax.typing import ( - Dtype, - Initializer, -) +from flax.nnx.nnx import rnglib, variables +from flax.nnx.nnx.module import Module +from flax.nnx.nnx.nn import initializers +from flax.nnx.nnx.nn.linear import Linear +from flax.typing import Dtype, Initializer Array = jax.Array Axis = int @@ -49,7 +46,8 @@ default_kernel_init = initializers.lecun_normal() -class LoRAParam(variables.Variable[A]): pass +class LoRAParam(variables.Variable[A]): + pass class LoRA(Module): @@ -88,13 +86,14 @@ class LoRA(Module): kernel_init: initializer function for the weight matrices. lora_param_type: the type of the LoRA params. """ + def __init__( self, in_features: int, lora_rank: int, out_features: int, *, - base_module: tp.Optional[nnx.Module] = None, + base_module: tp.Optional[Module] = None, dtype: tp.Optional[Dtype] = None, param_dtype: Dtype = jnp.float32, kernel_init: Initializer = default_kernel_init, @@ -124,7 +123,7 @@ def __call__(self, x: jax.Array): return out -class LoRALinear(nnx.Linear): +class LoRALinear(Linear): """An `nnx.Linear` layer in which the output will be LoRAified. The model state structure will be compatible with that of Linear. @@ -159,6 +158,7 @@ class LoRALinear(nnx.Linear): kernel_init: initializer function for the weight matrices. lora_param_type: the type of the LoRA params. 
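Example usage (a sketch inferred from the call pattern in
``tests/nn/test_lora.py`` further down; the shapes are illustrative only)::

  >>> from flax import nnx
  >>> import jax.numpy as jnp
  ...
  >>> layer = nnx.LoRALinear(3, 3, lora_rank=4, rngs=nnx.Rngs(0))
  >>> y = layer(jnp.ones((1, 3)))  # base Linear(3 -> 3) output plus the rank-4 LoRA path
  >>> y.shape
  (1, 3)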
""" + def __init__( self, in_features: int, @@ -173,18 +173,18 @@ def __init__( **kwargs, ): super().__init__(in_features, out_features, rngs=rngs, **kwargs) - self.lora = LoRA(in_features, lora_rank, out_features, - dtype=lora_dtype, param_dtype=lora_param_dtype, - kernel_init=lora_kernel_init, lora_param_type=lora_param_type, - rngs=rngs) + self.lora = LoRA( + in_features, + lora_rank, + out_features, + dtype=lora_dtype, + param_dtype=lora_param_dtype, + kernel_init=lora_kernel_init, + lora_param_type=lora_param_type, + rngs=rngs, + ) def __call__(self, x: jax.Array): y = super().__call__(x) y += self.lora(x) return y - - - - - - diff --git a/flax/experimental/nnx/nnx/nn/normalization.py b/flax/nnx/nnx/nn/normalization.py similarity index 98% rename from flax/experimental/nnx/nnx/nn/normalization.py rename to flax/nnx/nnx/nn/normalization.py index f27d6b2798..c65754fda7 100644 --- a/flax/experimental/nnx/nnx/nn/normalization.py +++ b/flax/nnx/nnx/nn/normalization.py @@ -18,10 +18,10 @@ import jax.numpy as jnp from jax import lax -from flax.experimental import nnx -from flax.experimental.nnx.nnx import rnglib -from flax.experimental.nnx.nnx.module import Module, first_from -from flax.experimental.nnx.nnx.nn import dtypes, initializers +from flax import nnx +from flax.nnx.nnx import rnglib +from flax.nnx.nnx.module import Module, first_from +from flax.nnx.nnx.nn import dtypes, initializers from flax.typing import ( Array, Dtype, diff --git a/flax/experimental/nnx/nnx/nn/stochastic.py b/flax/nnx/nnx/nn/stochastic.py similarity index 96% rename from flax/experimental/nnx/nnx/nn/stochastic.py rename to flax/nnx/nnx/nn/stochastic.py index efd8f94f31..a2ee77bc27 100644 --- a/flax/experimental/nnx/nnx/nn/stochastic.py +++ b/flax/nnx/nnx/nn/stochastic.py @@ -34,8 +34,8 @@ import jax.numpy as jnp from jax import lax, random -from flax.experimental.nnx.nnx import rnglib -from flax.experimental.nnx.nnx.module import Module, first_from +from flax.nnx.nnx import rnglib +from flax.nnx.nnx.module import Module, first_from @dataclasses.dataclass diff --git a/flax/experimental/nnx/nnx/object.py b/flax/nnx/nnx/object.py similarity index 97% rename from flax/experimental/nnx/nnx/object.py rename to flax/nnx/nnx/object.py index cd0284cb36..9b2ae9a431 100644 --- a/flax/experimental/nnx/nnx/object.py +++ b/flax/nnx/nnx/object.py @@ -24,13 +24,13 @@ import jax import numpy as np -from flax.experimental.nnx.nnx import ( +from flax.nnx.nnx import ( errors, reprlib, tracers, ) -from flax.experimental.nnx.nnx import graph -from flax.experimental.nnx.nnx.variables import Variable, VariableState +from flax.nnx.nnx import graph +from flax.nnx.nnx.variables import Variable, VariableState from flax.typing import Key G = tp.TypeVar('G', bound='Object') diff --git a/flax/experimental/nnx/nnx/proxy_caller.py b/flax/nnx/nnx/proxy_caller.py similarity index 100% rename from flax/experimental/nnx/nnx/proxy_caller.py rename to flax/nnx/nnx/proxy_caller.py diff --git a/flax/experimental/nnx/nnx/reprlib.py b/flax/nnx/nnx/reprlib.py similarity index 100% rename from flax/experimental/nnx/nnx/reprlib.py rename to flax/nnx/nnx/reprlib.py diff --git a/flax/experimental/nnx/nnx/rnglib.py b/flax/nnx/nnx/rnglib.py similarity index 94% rename from flax/experimental/nnx/nnx/rnglib.py rename to flax/nnx/nnx/rnglib.py index 8554f85f4d..2f93457f11 100644 --- a/flax/experimental/nnx/nnx/rnglib.py +++ b/flax/nnx/nnx/rnglib.py @@ -33,12 +33,12 @@ import jax import jax.numpy as jnp -from flax.experimental.nnx.nnx import graph -from 
flax.experimental.nnx.nnx.state import State -from flax.experimental.nnx.nnx.variables import Variable -from flax.experimental.nnx.nnx import filterlib -from flax.experimental.nnx.nnx.filterlib import All -from flax.experimental.nnx.nnx.object import Object +from flax.nnx.nnx import graph +from flax.nnx.nnx.state import State +from flax.nnx.nnx.variables import Variable +from flax.nnx.nnx import filterlib +from flax.nnx.nnx.filterlib import All +from flax.nnx.nnx.object import Object Counts = list[int] AxesValue = tp.Union[int, None] @@ -63,6 +63,7 @@ class RngCount(RngState): class RngKey(RngState): tag: str + class RngKeyBackup(RngState): pass @@ -155,6 +156,7 @@ def __len__(self) -> int: def __contains__(self, name: tp.Any) -> bool: return name in vars(self) + class ForkStates(tp.NamedTuple): split_keys: State split_counts: State diff --git a/flax/experimental/nnx/nnx/spmd.py b/flax/nnx/nnx/spmd.py similarity index 98% rename from flax/experimental/nnx/nnx/spmd.py rename to flax/nnx/nnx/spmd.py index 20c0630173..fd7067c0ae 100644 --- a/flax/experimental/nnx/nnx/spmd.py +++ b/flax/nnx/nnx/spmd.py @@ -19,8 +19,8 @@ from jax.interpreters import pxla from jax.sharding import Mesh, PartitionSpec -from flax.experimental.nnx.nnx import variables -from flax.experimental.nnx.nnx.state import State +from flax.nnx.nnx import variables +from flax.nnx.nnx.state import State from flax.typing import ( Array, ArrayPytree, # pylint: disable=invalid-name diff --git a/flax/experimental/nnx/nnx/state.py b/flax/nnx/nnx/state.py similarity index 96% rename from flax/experimental/nnx/nnx/state.py rename to flax/nnx/nnx/state.py index dff6fec5d4..c9edd5a756 100644 --- a/flax/experimental/nnx/nnx/state.py +++ b/flax/nnx/nnx/state.py @@ -35,8 +35,8 @@ import numpy as np from flax import traverse_util -from flax.experimental.nnx.nnx import filterlib, reprlib -from flax.experimental.nnx.nnx.variables import VariableState +from flax.nnx.nnx import filterlib, reprlib +from flax.nnx.nnx.variables import VariableState from flax.typing import Key, PathParts A = tp.TypeVar('A') @@ -142,8 +142,7 @@ def from_flat_path( return cls(nested_state) @tp.overload - def split(self, first: filterlib.Filter, /) -> 'State': - ... + def split(self, first: filterlib.Filter, /) -> 'State': ... @tp.overload def split( @@ -152,8 +151,7 @@ def split( second: filterlib.Filter, /, *filters: filterlib.Filter, - ) -> tuple['State', ...]: - ... + ) -> tuple['State', ...]: ... def split( self, first: filterlib.Filter, /, *filters: filterlib.Filter @@ -179,8 +177,7 @@ def filter( self, first: filterlib.Filter, /, - ) -> 'State': - ... + ) -> 'State': ... @tp.overload def filter( @@ -189,8 +186,7 @@ def filter( second: filterlib.Filter, /, *filters: filterlib.Filter, - ) -> tuple['State', ...]: - ... + ) -> tuple['State', ...]: ... 
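# Note: `split` and `filter` (below) take the same variadic `Filter` arguments;
# `split` partitions the whole State (the filters must be exhaustive, with `...`
# available to catch any remainder), while `filter` keeps only the matching
# leaves and drops the rest. A usage sketch, using the aliases re-exported
# from `flax.nnx`:
#
#   >>> from flax import nnx
#   >>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
#   >>> params = nnx.state(model).filter(nnx.Param)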
def filter( self, @@ -239,6 +235,7 @@ def __sub__(self, other: 'State') -> 'State': return State.from_flat_path(diff) + def _state_flatten_with_keys(x: State): items = sorted(x._mapping.items()) children = tuple((jtu.DictKey(key), value) for key, value in items) diff --git a/flax/experimental/nnx/nnx/tracers.py b/flax/nnx/nnx/tracers.py similarity index 97% rename from flax/experimental/nnx/nnx/tracers.py rename to flax/nnx/nnx/tracers.py index 1e8688f4eb..c73e627e5e 100644 --- a/flax/experimental/nnx/nnx/tracers.py +++ b/flax/nnx/nnx/tracers.py @@ -20,7 +20,7 @@ import jax.core from jax.core import MainTrace -from flax.experimental.nnx.nnx import reprlib +from flax.nnx.nnx import reprlib @tp.runtime_checkable diff --git a/flax/experimental/nnx/nnx/training/__init__.py b/flax/nnx/nnx/training/__init__.py similarity index 100% rename from flax/experimental/nnx/nnx/training/__init__.py rename to flax/nnx/nnx/training/__init__.py diff --git a/flax/experimental/nnx/nnx/training/metrics.py b/flax/nnx/nnx/training/metrics.py similarity index 88% rename from flax/experimental/nnx/nnx/training/metrics.py rename to flax/nnx/nnx/training/metrics.py index 87ec7831b1..41b130fbf3 100644 --- a/flax/experimental/nnx/nnx/training/metrics.py +++ b/flax/nnx/nnx/training/metrics.py @@ -28,13 +28,14 @@ from __future__ import annotations import jax, jax.numpy as jnp -from flax.experimental.nnx.nnx.object import Object -from flax.experimental.nnx.nnx.variables import Variable -from flax.experimental.nnx.nnx import filterlib, graph +from flax.nnx.nnx.object import Object +from flax.nnx.nnx.variables import Variable +from flax.nnx.nnx import filterlib, graph import typing as tp -#TODO: add tests and docstrings +# TODO: add tests and docstrings + class MetricState(Variable): """Wrapper class for Metric Variables.""" @@ -45,13 +46,16 @@ class MetricState(Variable): class Metric(Object): def __init__(self): raise NotImplementedError('Must override `__init__()` method.') + def reset(self): raise NotImplementedError('Must override `reset()` method.') def update(self, **kwargs) -> None: raise NotImplementedError('Must override `update()` method.') + def compute(self): raise NotImplementedError('Must override `compute()` method.') + def split(self, *filters: filterlib.Filter): return graph.split(self, *filters) @@ -61,6 +65,7 @@ def __init__(self, argname: str = 'values'): self.argname = argname self.total = MetricState(jnp.array(0, dtype=jnp.float32)) self.count = MetricState(jnp.array(0, dtype=jnp.int32)) + def reset(self): self.total.value = jnp.array(0, dtype=jnp.float32) self.count.value = jnp.array(0, dtype=jnp.int32) @@ -69,19 +74,24 @@ def update(self, **kwargs): if self.argname not in kwargs: raise TypeError(f"Expected keyword argument '{self.argname}'") values: tp.Union[int, float, jax.Array] = kwargs[self.argname] - self.total.value += values if isinstance(values, (int, float)) else values.sum() + self.total.value += ( + values if isinstance(values, (int, float)) else values.sum() + ) self.count.value += 1 if isinstance(values, (int, float)) else values.size + def compute(self): return self.total.value / self.count.value + class Accuracy(Average): def update(self, *, logits: jax.Array, labels: jax.Array, **_): # type: ignore[override] if logits.ndim != labels.ndim + 1 or labels.dtype != jnp.int32: raise ValueError( - f"Expected labels.dtype==jnp.int32 and logits.ndim={logits.ndim}==" - f"labels.ndim+1={labels.ndim + 1}" + f'Expected labels.dtype==jnp.int32 and logits.ndim={logits.ndim}==' + 
f'labels.ndim+1={labels.ndim + 1}' ) - super().update(values=(logits.argmax(axis=-1)==labels)) + super().update(values=(logits.argmax(axis=-1) == labels)) + class MultiMetric(Metric): """MultiMetric class to store multiple metrics and update them in a single call. @@ -89,7 +99,7 @@ class MultiMetric(Metric): Example usage:: >>> import jax, jax.numpy as jnp - >>> from flax.experimental import nnx + >>> from flax import nnx ... >>> logits = jax.random.normal(jax.random.key(0), (5, 2)) >>> labels = jnp.array([1, 1, 0, 1, 0]) @@ -114,19 +124,26 @@ class MultiMetric(Metric): >>> metrics.compute() {'accuracy': Array(nan, dtype=float32), 'loss': Array(nan, dtype=float32)} """ + def __init__(self, **metrics): # TODO: raise error if a kwarg is passed that is in ('reset', 'update', 'compute'), since these names are reserved for methods self._metric_names = [] for metric_name, metric in metrics.items(): self._metric_names.append(metric_name) vars(self)[metric_name] = metric + def reset(self): for metric_name in self._metric_names: getattr(self, metric_name).reset() + def update(self, **updates): # TODO: should we give the option of updating only some of the metrics and not all? e.g. if for some kwargs==None, don't do update # TODO: should we raise an error if a kwarg is passed into **updates that has no match with any underlying metric? e.g. user typo for metric_name in self._metric_names: getattr(self, metric_name).update(**updates) + def compute(self): - return {f'{metric_name}': getattr(self, metric_name).compute() for metric_name in self._metric_names} \ No newline at end of file + return { + f'{metric_name}': getattr(self, metric_name).compute() + for metric_name in self._metric_names + } \ No newline at end of file diff --git a/flax/experimental/nnx/nnx/training/optimizer.py b/flax/nnx/nnx/training/optimizer.py similarity index 92% rename from flax/experimental/nnx/nnx/training/optimizer.py rename to flax/nnx/nnx/training/optimizer.py index c20fc35d4c..00215c0062 100644 --- a/flax/experimental/nnx/nnx/training/optimizer.py +++ b/flax/nnx/nnx/training/optimizer.py @@ -30,24 +30,27 @@ import jax.numpy as jnp import optax -from flax.experimental import nnx -from flax.experimental.nnx.nnx import filterlib, graph -from flax.experimental.nnx.nnx.object import Object -from flax.experimental.nnx.nnx.variables import Variable +from flax import nnx +from flax.nnx.nnx import filterlib, graph +from flax.nnx.nnx.object import Object +from flax.nnx.nnx.variables import Variable + +# TODO: add tests and docstrings -#TODO: add tests and docstrings class OptState(Variable): """Wrapper class for Optimizer Variables.""" + pass + class Optimizer(Object): """Simple train state for the common case with a single Optax optimizer. Example usage:: >>> import jax, jax.numpy as jnp - >>> from flax.experimental import nnx + >>> from flax import nnx >>> import optax ... 
>>> class Model(nnx.Module): @@ -134,13 +137,10 @@ def update(self, grads): """ params = nnx.state(self.model, nnx.Param) - updates, new_opt_state = self.tx.update( - grads, self.opt_state, params - ) + updates, new_opt_state = self.tx.update(grads, self.opt_state, params) new_params = optax.apply_updates(params, updates) assert isinstance(new_params, nnx.State) self.step.value += 1 nnx.update(self.model, new_params) self.opt_state = new_opt_state - diff --git a/flax/experimental/nnx/nnx/transforms.py b/flax/nnx/nnx/transforms.py similarity index 96% rename from flax/experimental/nnx/nnx/transforms.py rename to flax/nnx/nnx/transforms.py index 653d3b1161..8fc4a8979c 100644 --- a/flax/experimental/nnx/nnx/transforms.py +++ b/flax/nnx/nnx/transforms.py @@ -35,19 +35,19 @@ from flax import struct from flax.core.frozen_dict import FrozenDict -from flax.experimental.nnx.nnx import ( +from flax.nnx.nnx import ( filterlib, graph, rnglib, spmd, variables, ) -from flax.experimental.nnx.nnx.module import GraphDef, Module -from flax.experimental.nnx.nnx.proxy_caller import ( +from flax.nnx.nnx.module import GraphDef, Module +from flax.nnx.nnx.proxy_caller import ( CallableProxy, DelayedAccessor, ) -from flax.experimental.nnx.nnx.state import State +from flax.nnx.nnx.state import State from flax.typing import Leaf import jax from jax._src.tree_util import broadcast_prefix @@ -115,6 +115,7 @@ def check_and_call(accessor: DelayedAccessor, *args, **kwargs): UNSPECIFIED = object() + def _default_constrain_state(state: State) -> State: state_spec = spmd.get_partition_spec(state) state = jax.lax.with_sharding_constraint(state, state_spec) @@ -619,7 +620,7 @@ def grad( Example:: - >>> from flax.experimental import nnx + >>> from flax import nnx ... >>> m = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) >>> x = jnp.ones((1, 2)) @@ -779,6 +780,7 @@ def _call(self, accessor: DelayedAccessor, *args) -> tp.Any: # scan # ------------------------------- + @dataclasses.dataclass(frozen=True) class FlatDef(tp.Generic[A]): type: type[A] @@ -1333,22 +1335,21 @@ def __post_init__(self): class Remat(tp.Generic[M], LiftedModule[M]): - @staticmethod def constructor( - module_constructor: tp.Callable[..., MA], - prevent_cse: bool = True, - static_argnums: int | tuple[int, ...] = (), - policy: tp.Callable[..., bool] | None = None, + module_constructor: tp.Callable[..., MA], + prevent_cse: bool = True, + static_argnums: int | tuple[int, ...] = (), + policy: tp.Callable[..., bool] | None = None, ) -> tp.Callable[..., 'Remat[MA]']: def create_remat(*args, **kwargs): return Remat( - module_constructor=module_constructor, - module_init_args=args, - module_init_kwargs=kwargs, - prevent_cse=prevent_cse, - static_argnums=static_argnums, - policy=policy, + module_constructor=module_constructor, + module_init_args=args, + module_init_kwargs=kwargs, + prevent_cse=prevent_cse, + static_argnums=static_argnums, + policy=policy, ) return create_remat @@ -1461,38 +1462,37 @@ class VmapOptions: class Vmap(tp.Generic[M], LiftedModule[M]): - @staticmethod def constructor( - module_constructor: tp.Callable[..., MA], - *, - in_axes: int | None | tp.Sequence[tp.Any] = 0, - out_axes: tp.Any = 0, - axis_name: AxisName | None = None, - axis_size: int | None = None, - spmd_axis_name: AxisName | tuple[AxisName, ...] 
| None = None, - # nnx specific - in_axes_kwargs: tp.Any = 0, - state_axes: tp.Mapping[filterlib.Filter, int] = FrozenDict({...: 0}), - split_rngs: filterlib.Filter = ..., - transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), + module_constructor: tp.Callable[..., MA], + *, + in_axes: int | None | tp.Sequence[tp.Any] = 0, + out_axes: tp.Any = 0, + axis_name: AxisName | None = None, + axis_size: int | None = None, + spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None, + # nnx specific + in_axes_kwargs: tp.Any = 0, + state_axes: tp.Mapping[filterlib.Filter, int] = FrozenDict({...: 0}), + split_rngs: filterlib.Filter = ..., + transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), ) -> tp.Callable[..., 'Vmap[MA]']: def _create_vmap(*args, **kwargs): return Vmap( - module_constructor=module_constructor, - in_axes=in_axes, - out_axes=out_axes, - axis_size=axis_size, - axis_name=axis_name, - spmd_axis_name=spmd_axis_name, - # nnx specific - in_axes_kwargs=in_axes_kwargs, - state_axes=state_axes, - split_rngs=split_rngs, - transform_metadata=transform_metadata, - # submodule args - module_init_args=args, - module_init_kwargs=kwargs, + module_constructor=module_constructor, + in_axes=in_axes, + out_axes=out_axes, + axis_size=axis_size, + axis_name=axis_name, + spmd_axis_name=spmd_axis_name, + # nnx specific + in_axes_kwargs=in_axes_kwargs, + state_axes=state_axes, + split_rngs=split_rngs, + transform_metadata=transform_metadata, + # submodule args + module_init_args=args, + module_init_kwargs=kwargs, ) return _create_vmap @@ -1577,7 +1577,7 @@ def vmap_apply( # split module state filters = (*options.state_axes.keys(), ...) graphdef, rng_state, *vectorized_states, broadcast_state = ctx.split( # type: ignore[misc] - input_graph_nodes, rnglib.RngState, *filters + input_graph_nodes, rnglib.RngState, *filters ) # infer length @@ -1682,14 +1682,14 @@ def vmap_fn( # split module state ( - graphdef_out, - rng_state_out, - *vectorized_states_out, - broadcast_state_out, + graphdef_out, + rng_state_out, + *vectorized_states_out, + broadcast_state_out, ) = ctx.split( # type: ignore[misc] - (input_graph_nodes, output_graph_nodes), - rnglib.RngState, - *filters, + (input_graph_nodes, output_graph_nodes), + rnglib.RngState, + *filters, ) not_keys_out, split_keys_out, broadcast_keys_out = rng_state_out.split( @@ -1804,6 +1804,7 @@ def _eval_shape_fn(state: State, *args, **kwargs): out = graph.insert_graph_nodes(out, output_nodes) return out + # ------------------------------- # cond # ------------------------------- @@ -1814,6 +1815,7 @@ class CondStaticInputs(tp.Generic[A]): true_fun: tp.Callable[..., A] false_fun: tp.Callable[..., A] + jax.tree_util.register_static(CondStaticInputs) @@ -1867,4 +1869,4 @@ def cond( **kwargs, ) _operands_out, out = ctx.merge(graphdef_out, state_out) - return out \ No newline at end of file + return out diff --git a/flax/experimental/nnx/nnx/variables.py b/flax/nnx/nnx/variables.py similarity index 98% rename from flax/experimental/nnx/nnx/variables.py rename to flax/nnx/nnx/variables.py index bef44f5ddc..dafe286c70 100644 --- a/flax/experimental/nnx/nnx/variables.py +++ b/flax/nnx/nnx/variables.py @@ -34,8 +34,8 @@ import typing as tp from typing import Any -from flax.experimental import nnx -from flax.experimental.nnx.nnx import reprlib, tracers +from flax import nnx +from flax.nnx.nnx import reprlib, tracers import jax.tree_util as jtu A = tp.TypeVar('A') @@ -76,6 +76,7 @@ def __hash__(self): class _Missing: pass + MISSING = _Missing() @@ -224,8 +225,7 
@@ def __init__( if tp.TYPE_CHECKING: - def __getattr__(self, name: str) -> tp.Any: - ... + def __getattr__(self, name: str) -> tp.Any: ... else: def __setattr__(self, name: str, value: Any) -> None: @@ -304,12 +304,10 @@ def __eq__(self, other: object) -> bool: return type(self) is type(other) and vars(other) == vars(self) @tp.overload - def replace(self, value: B, **kwargs) -> 'Variable[B]': - ... + def replace(self, value: B, **kwargs) -> 'Variable[B]': ... @tp.overload - def replace(self, **kwargs) -> 'Variable[A]': - ... + def replace(self, **kwargs) -> 'Variable[A]': ... def replace(self, value: tp.Any = MISSING, **kwargs) -> 'Variable[tp.Any]': if value is not MISSING: @@ -571,12 +569,11 @@ class Intermediate(Variable[A]): class VariableState(tp.Generic[A], reprlib.Representable): - def __init__( - self, - type: type[Variable[tp.Any]], - value: A, - **metadata, + self, + type: type[Variable[tp.Any]], + value: A, + **metadata, ): self.type = type self.value = value @@ -584,8 +581,7 @@ def __init__( if tp.TYPE_CHECKING: - def __getattr__(self, name: str) -> tp.Any: - ... + def __getattr__(self, name: str) -> tp.Any: ... def __nnx_repr__(self): yield reprlib.Object(type=type(self)) diff --git a/flax/experimental/nnx/nnx/visualization.py b/flax/nnx/nnx/visualization.py similarity index 99% rename from flax/experimental/nnx/nnx/visualization.py rename to flax/nnx/nnx/visualization.py index 0f657363c7..03317b3e26 100644 --- a/flax/experimental/nnx/nnx/visualization.py +++ b/flax/nnx/nnx/visualization.py @@ -18,7 +18,7 @@ import jax -from flax.experimental import nnx +from flax import nnx penzai_installed = importlib.util.find_spec('penzai') is not None try: diff --git a/flax/experimental/nnx/scripts/requirements.txt b/flax/nnx/scripts/requirements.txt similarity index 100% rename from flax/experimental/nnx/scripts/requirements.txt rename to flax/nnx/scripts/requirements.txt diff --git a/flax/experimental/nnx/scripts/run-all-examples.bash b/flax/nnx/scripts/run-all-examples.bash similarity index 91% rename from flax/experimental/nnx/scripts/run-all-examples.bash rename to flax/nnx/scripts/run-all-examples.bash index 523fa3cf49..570e9c98e9 100644 --- a/flax/experimental/nnx/scripts/run-all-examples.bash +++ b/flax/nnx/scripts/run-all-examples.bash @@ -2,7 +2,7 @@ set -e cd ../../.. 
source .venv/bin/activate -cd flax/experimental/nnx +cd flax/nnx for f in $(find examples/toy_examples -name "*.py" -maxdepth 1); do echo -e "\n---------------------------------" diff --git a/flax/experimental/nnx/tests/__init__.py b/flax/nnx/tests/__init__.py similarity index 100% rename from flax/experimental/nnx/tests/__init__.py rename to flax/nnx/tests/__init__.py diff --git a/flax/experimental/nnx/tests/compat/test_module.py b/flax/nnx/tests/compat/test_module.py similarity index 97% rename from flax/experimental/nnx/tests/compat/test_module.py rename to flax/nnx/tests/compat/test_module.py index 70bd403c51..df76033510 100644 --- a/flax/experimental/nnx/tests/compat/test_module.py +++ b/flax/nnx/tests/compat/test_module.py @@ -17,8 +17,8 @@ import jax import jax.numpy as jnp -from flax.experimental import nnx -from flax.experimental.nnx import compat +from flax import nnx +from flax.nnx import compat class TestCompatModule: diff --git a/flax/experimental/nnx/tests/compat/test_wrappers.py b/flax/nnx/tests/compat/test_wrappers.py similarity index 93% rename from flax/experimental/nnx/tests/compat/test_wrappers.py rename to flax/nnx/tests/compat/test_wrappers.py index 1b5cd2bf7b..64f8c7743f 100644 --- a/flax/experimental/nnx/tests/compat/test_wrappers.py +++ b/flax/nnx/tests/compat/test_wrappers.py @@ -15,8 +15,8 @@ import jax from flax import linen -from flax.experimental import nnx -from flax.experimental.nnx import compat +from flax import nnx +from flax.nnx import compat class TestCompatibility: diff --git a/flax/experimental/nnx/tests/nn/test_attention.py b/flax/nnx/tests/nn/test_attention.py similarity index 99% rename from flax/experimental/nnx/tests/nn/test_attention.py rename to flax/nnx/tests/nn/test_attention.py index 489c786a50..9c45264d9c 100644 --- a/flax/experimental/nnx/tests/nn/test_attention.py +++ b/flax/nnx/tests/nn/test_attention.py @@ -16,7 +16,7 @@ from jax.lax import Precision from flax import linen -from flax.experimental import nnx +from flax import nnx from flax.typing import Dtype, PrecisionLike from numpy.testing import assert_array_equal diff --git a/flax/experimental/nnx/tests/nn/test_conv.py b/flax/nnx/tests/nn/test_conv.py similarity index 99% rename from flax/experimental/nnx/tests/nn/test_conv.py rename to flax/nnx/tests/nn/test_conv.py index f6a7739058..41a3a8044e 100644 --- a/flax/experimental/nnx/tests/nn/test_conv.py +++ b/flax/nnx/tests/nn/test_conv.py @@ -22,7 +22,7 @@ from numpy.testing import assert_array_equal from flax import linen -from flax.experimental import nnx +from flax import nnx from flax.typing import PaddingLike, Dtype, PrecisionLike diff --git a/flax/experimental/nnx/tests/nn/test_embed.py b/flax/nnx/tests/nn/test_embed.py similarity index 98% rename from flax/experimental/nnx/tests/nn/test_embed.py rename to flax/nnx/tests/nn/test_embed.py index bed5ab1a8d..faababe008 100644 --- a/flax/experimental/nnx/tests/nn/test_embed.py +++ b/flax/nnx/tests/nn/test_embed.py @@ -20,7 +20,7 @@ from numpy.testing import assert_array_equal from flax import linen -from flax.experimental import nnx +from flax import nnx from flax.typing import Dtype diff --git a/flax/experimental/nnx/tests/nn/test_linear.py b/flax/nnx/tests/nn/test_linear.py similarity index 99% rename from flax/experimental/nnx/tests/nn/test_linear.py rename to flax/nnx/tests/nn/test_linear.py index 944f03b97a..aa55eb6427 100644 --- a/flax/experimental/nnx/tests/nn/test_linear.py +++ b/flax/nnx/tests/nn/test_linear.py @@ -21,7 +21,7 @@ from numpy.testing import assert_array_equal 
from flax import linen -from flax.experimental import nnx +from flax import nnx from flax.typing import Dtype, PrecisionLike, Shape diff --git a/flax/experimental/nnx/tests/nn/test_lora.py b/flax/nnx/tests/nn/test_lora.py similarity index 94% rename from flax/experimental/nnx/tests/nn/test_lora.py rename to flax/nnx/tests/nn/test_lora.py index d619f1e812..b58db02456 100644 --- a/flax/experimental/nnx/tests/nn/test_lora.py +++ b/flax/nnx/tests/nn/test_lora.py @@ -17,7 +17,7 @@ from absl.testing import absltest import numpy as np -from flax.experimental import nnx +from flax import nnx class TestLora(absltest.TestCase): @@ -31,7 +31,6 @@ def test_basic(self): assert module.lora_b.value.shape == (2, 4) np.testing.assert_allclose(y, x @ module.lora_a.value @ module.lora_b.value) - def test_lora_base_module(self): rngs = nnx.Rngs(0) linear = nnx.Linear(3, 4, use_bias=False, rngs=rngs) @@ -45,14 +44,16 @@ def test_lora_base_module(self): assert module.base_module.bias.value == None assert module.lora_a.value.shape == (3, 2) assert module.lora_b.value.shape == (2, 4) - np.testing.assert_allclose(y, x @ linear.kernel.value + x @ module.lora_a.value @ module.lora_b.value) - + np.testing.assert_allclose( + y, x @ linear.kernel.value + x @ module.lora_a.value @ module.lora_b.value + ) def test_layer_swap_lora(self): class MLP(nnx.Module): def __init__(self, dim, rngs: nnx.Rngs): self.linear1 = nnx.Linear(dim, dim, rngs=rngs) self.linear2 = nnx.Linear(dim, dim, rngs=rngs) + def __call__(self, x): x = self.linear1(x) return self.linear2(x) @@ -72,12 +73,12 @@ def __call__(self, x): a, b = model.linear2.lora_a.value, model.linear2.lora_b.value np.testing.assert_allclose(y + model.linear1(x) @ a @ b, lora_y) - def test_layer_swap_loralinear(self): class MLP(nnx.Module): def __init__(self, dim, rngs: nnx.Rngs): self.linear1 = nnx.Linear(dim, dim, rngs=rngs) self.linear2 = nnx.Linear(dim, dim, rngs=rngs) + def __call__(self, x): x = self.linear1(x) return self.linear2(x) @@ -88,7 +89,9 @@ def __call__(self, x): y = model(x) # Replace one of the linear layers as LoRA linear layer. 
- _, state = nnx.split(model.linear2) # To keep the kernel and bias of linear2 + _, state = nnx.split( + model.linear2 + ) # To keep the kernel and bias of linear2 model.linear2 = nnx.LoRALinear(3, 3, lora_rank=4, rngs=rngs) nnx.update(model.linear2, state) lora_y = model(x) @@ -99,7 +102,6 @@ def __call__(self, x): a, b = model.linear2.lora.lora_a.value, model.linear2.lora.lora_b.value np.testing.assert_allclose(y + model.linear1(x) @ a @ b, lora_y) - def test_lora_param_type(self): rngs = nnx.Rngs(0) model = nnx.LoRA(3, 4, 2, lora_param_type=nnx.LoRAParam, rngs=rngs) @@ -117,4 +119,3 @@ def test_lora_param_type(self): if __name__ == '__main__': absltest.main() - diff --git a/flax/experimental/nnx/tests/nn/test_normalization.py b/flax/nnx/tests/nn/test_normalization.py similarity index 99% rename from flax/experimental/nnx/tests/nn/test_normalization.py rename to flax/nnx/tests/nn/test_normalization.py index 854c367ae3..3e30febcf6 100644 --- a/flax/experimental/nnx/tests/nn/test_normalization.py +++ b/flax/nnx/tests/nn/test_normalization.py @@ -20,7 +20,7 @@ from numpy.testing import assert_array_equal from flax import linen -from flax.experimental import nnx +from flax import nnx from flax.typing import Dtype diff --git a/flax/experimental/nnx/tests/nn/test_stochastic.py b/flax/nnx/tests/nn/test_stochastic.py similarity index 98% rename from flax/experimental/nnx/tests/nn/test_stochastic.py rename to flax/nnx/tests/nn/test_stochastic.py index f302a34f10..1ba6944ae5 100644 --- a/flax/experimental/nnx/tests/nn/test_stochastic.py +++ b/flax/nnx/tests/nn/test_stochastic.py @@ -16,7 +16,7 @@ import jax.numpy as jnp import numpy as np -from flax.experimental import nnx +from flax import nnx import pytest diff --git a/flax/experimental/nnx/tests/test_containers.py b/flax/nnx/tests/test_containers.py similarity index 97% rename from flax/experimental/nnx/tests/test_containers.py rename to flax/nnx/tests/test_containers.py index 582d661ab8..4757d494ee 100644 --- a/flax/experimental/nnx/tests/test_containers.py +++ b/flax/nnx/tests/test_containers.py @@ -13,7 +13,7 @@ # limitations under the License. 
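The test renames below all exercise the same import change; as a point of
reference, a minimal training-step sketch for the ``nnx.Optimizer`` and
``nnx.grad`` APIs touched in the ``optimizer.py`` and ``test_optimizer.py``
hunks above (the model, data, and learning rate are illustrative)::

  >>> from flax import nnx
  >>> import jax.numpy as jnp, optax
  >>> model = nnx.Linear(2, 4, rngs=nnx.Rngs(0))
  >>> optimizer = nnx.Optimizer(model, optax.sgd(1e-3))
  >>> x, y = jnp.ones((1, 2)), jnp.ones((1, 4))
  >>> loss_fn = lambda m: ((m(x) - y) ** 2).mean()
  >>> grads = nnx.grad(loss_fn, wrt=nnx.Param)(optimizer.model)
  >>> optimizer.update(grads)  # applies the optax updates and increments step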
-from flax.experimental import nnx +from flax import nnx class TestContainers: diff --git a/flax/experimental/nnx/tests/test_graph_utils.py b/flax/nnx/tests/test_graph_utils.py similarity index 99% rename from flax/experimental/nnx/tests/test_graph_utils.py rename to flax/nnx/tests/test_graph_utils.py index 64a07a1938..52ebcba756 100644 --- a/flax/experimental/nnx/tests/test_graph_utils.py +++ b/flax/nnx/tests/test_graph_utils.py @@ -17,7 +17,7 @@ import jax import pytest -from flax.experimental import nnx +from flax import nnx from flax import struct diff --git a/flax/experimental/nnx/tests/test_helpers.py b/flax/nnx/tests/test_helpers.py similarity index 90% rename from flax/experimental/nnx/tests/test_helpers.py rename to flax/nnx/tests/test_helpers.py index 4e84f3b30f..8a7cec4dbc 100644 --- a/flax/experimental/nnx/tests/test_helpers.py +++ b/flax/nnx/tests/test_helpers.py @@ -19,7 +19,8 @@ from numpy.testing import assert_array_equal from flax import linen -from flax.experimental import nnx +from flax import nnx + class TrainState(nnx.TrainState): batch_stats: nnx.State @@ -76,13 +77,17 @@ def test_nnx_linen_sequential_equivalence(self): rngs = nnx.Rngs(0) x = jax.random.uniform(key1, (3, 1, 5)) - model_nnx = nnx.Sequential(nnx.Linear(5, 4, rngs=rngs), nnx.Linear(4, 2, rngs=rngs)) + model_nnx = nnx.Sequential( + nnx.Linear(5, 4, rngs=rngs), nnx.Linear(4, 2, rngs=rngs) + ) model = linen.Sequential([linen.Dense(4), linen.Dense(2)]) variables = model.init(key2, x) for layer_index in range(2): for param in ('kernel', 'bias'): - variables['params'][f'layers_{layer_index}'][param] = getattr(model_nnx.layers[layer_index], param).value + variables['params'][f'layers_{layer_index}'][param] = getattr( + model_nnx.layers[layer_index], param + ).value out_nnx = model_nnx(x) out = model.apply(variables, x) assert_array_equal(out, out_nnx) @@ -90,7 +95,9 @@ def test_nnx_linen_sequential_equivalence(self): variables = model.init(key2, x) for layer_index in range(2): for param in ('kernel', 'bias'): - getattr(model_nnx.layers[layer_index], param).value = variables['params'][f'layers_{layer_index}'][param] + getattr(model_nnx.layers[layer_index], param).value = variables[ + 'params' + ][f'layers_{layer_index}'][param] out_nnx = model_nnx(x) out = model.apply(variables, x) - assert_array_equal(out, out_nnx) \ No newline at end of file + assert_array_equal(out, out_nnx) diff --git a/flax/experimental/nnx/tests/test_ids.py b/flax/nnx/tests/test_ids.py similarity index 95% rename from flax/experimental/nnx/tests/test_ids.py rename to flax/nnx/tests/test_ids.py index 9460e6724b..d72490c836 100644 --- a/flax/experimental/nnx/tests/test_ids.py +++ b/flax/nnx/tests/test_ids.py @@ -14,7 +14,7 @@ import copy -from flax.experimental.nnx.nnx import ids +from flax.nnx.nnx import ids class TestIds: diff --git a/flax/experimental/nnx/tests/test_integration.py b/flax/nnx/tests/test_integration.py similarity index 99% rename from flax/experimental/nnx/tests/test_integration.py rename to flax/nnx/tests/test_integration.py index 49b58af2a7..c473562b85 100644 --- a/flax/experimental/nnx/tests/test_integration.py +++ b/flax/nnx/tests/test_integration.py @@ -18,7 +18,7 @@ import jax.numpy as jnp import numpy as np -from flax.experimental import nnx +from flax import nnx A = tp.TypeVar('A') diff --git a/flax/experimental/nnx/tests/test_metrics.py b/flax/nnx/tests/test_metrics.py similarity index 96% rename from flax/experimental/nnx/tests/test_metrics.py rename to flax/nnx/tests/test_metrics.py index 2e0188ee7e..9a84cceb9a 
100644 --- a/flax/experimental/nnx/tests/test_metrics.py +++ b/flax/nnx/tests/test_metrics.py @@ -15,7 +15,7 @@ import jax import jax.numpy as jnp -from flax.experimental import nnx +from flax import nnx from absl.testing import parameterized @@ -58,7 +58,7 @@ def test_multimetric(self): metrics.update(logits=logits2, labels=labels2, values=batch_loss2) values = metrics.compute() self.assertEqual(values['accuracy'], 0.7) - self.assertEqual(values['loss'], 2.) + self.assertEqual(values['loss'], 2.0) metrics.reset() values = metrics.compute() diff --git a/flax/experimental/nnx/tests/test_module.py b/flax/nnx/tests/test_module.py similarity index 99% rename from flax/experimental/nnx/tests/test_module.py rename to flax/nnx/tests/test_module.py index 1d4724c6f7..1590d49342 100644 --- a/flax/experimental/nnx/tests/test_module.py +++ b/flax/nnx/tests/test_module.py @@ -21,15 +21,14 @@ import numpy as np import pytest -from flax.experimental import nnx +from flax import nnx A = TypeVar('A') class TestModule: def test_has_module_state(self): - class Foo(nnx.Module): - ... + class Foo(nnx.Module): ... foo = Foo() @@ -475,6 +474,7 @@ def __init__(self, din, dout, *, rngs: nnx.Rngs): raise_if_not_found=False, ) + class TestModulePytree: def test_tree_map(self): class Foo(nnx.Module, experimental_pytree=True): diff --git a/flax/experimental/nnx/tests/test_optimizer.py b/flax/nnx/tests/test_optimizer.py similarity index 94% rename from flax/experimental/nnx/tests/test_optimizer.py rename to flax/nnx/tests/test_optimizer.py index d1de7cc55b..a7e0310f18 100644 --- a/flax/experimental/nnx/tests/test_optimizer.py +++ b/flax/nnx/tests/test_optimizer.py @@ -17,7 +17,7 @@ import numpy as np import optax -from flax.experimental import nnx +from flax import nnx from absl.testing import parameterized @@ -26,6 +26,7 @@ class Model(nnx.Module): def __init__(self, in_features, out_features, rngs): self.linear1 = nnx.Linear(in_features, 3, rngs=rngs) self.linear2 = nnx.Linear(3, out_features, rngs=rngs) + def __call__(self, x): return self.linear2(self.linear1(x)) @@ -54,7 +55,9 @@ def test_jit(self, module_cls, jit_decorator, optimizer): x = jax.random.normal(jax.random.key(0), (1, 2)) y = jnp.ones((1, 4)) model = module_cls(2, 4, rngs=nnx.Rngs(0)) - tx = optimizer(1e-3) # TODO: this doesn't work with adam optimizer for some reason + tx = optimizer( + 1e-3 + ) # TODO: this doesn't work with adam optimizer for some reason state = nnx.Optimizer(model, tx) if jit_decorator == jax.jit: @@ -76,7 +79,7 @@ def jax_jit_train_step(graphdef, state, x, y): new_loss = loss_fn(*nnx.split(state.model), x, y) else: - loss_fn = lambda model, x, y: ((model(x)-y)**2).mean() + loss_fn = lambda model, x, y: ((model(x) - y) ** 2).mean() initial_loss = loss_fn(state.model, x, y) def nnx_jit_train_step(optimizer: nnx.Optimizer, x, y): @@ -109,7 +112,7 @@ def update(self, *, grads, **updates): # type: ignore[signature-mismatch] metrics = nnx.metrics.Average() state = TrainState(model, tx, metrics) - loss_fn = lambda model: ((model(x)-y)**2).mean() + loss_fn = lambda model: ((model(x) - y) ** 2).mean() grads = nnx.grad(loss_fn, wrt=nnx.Param)(state.model) state.update(grads=grads, values=loss_fn(state.model)) initial_loss = state.metrics.compute() diff --git a/flax/experimental/nnx/tests/test_partitioning.py b/flax/nnx/tests/test_partitioning.py similarity index 99% rename from flax/experimental/nnx/tests/test_partitioning.py rename to flax/nnx/tests/test_partitioning.py index b2d5fdfdc5..e390887aed 100644 --- 
a/flax/experimental/nnx/tests/test_partitioning.py +++ b/flax/nnx/tests/test_partitioning.py @@ -16,7 +16,7 @@ import jax import pytest -from flax.experimental import nnx +from flax import nnx class TestPartitioning: diff --git a/flax/experimental/nnx/tests/test_rngs.py b/flax/nnx/tests/test_rngs.py similarity index 99% rename from flax/experimental/nnx/tests/test_rngs.py rename to flax/nnx/tests/test_rngs.py index 23c505c30b..918a1be1ef 100644 --- a/flax/experimental/nnx/tests/test_rngs.py +++ b/flax/nnx/tests/test_rngs.py @@ -19,7 +19,7 @@ import jax.numpy as jnp import pytest -from flax.experimental import nnx +from flax import nnx class TestRngs: @@ -51,7 +51,6 @@ def test_rng_stream(self): assert rngs.params.key.value is key0 assert not jnp.allclose(key1, key2) - def test_rng_trace_level_constraints(self): rngs = nnx.Rngs(0) diff --git a/flax/experimental/nnx/tests/test_spmd.py b/flax/nnx/tests/test_spmd.py similarity index 98% rename from flax/experimental/nnx/tests/test_spmd.py rename to flax/nnx/tests/test_spmd.py index 83e2f1b10a..0353bfc535 100644 --- a/flax/experimental/nnx/tests/test_spmd.py +++ b/flax/nnx/tests/test_spmd.py @@ -19,7 +19,7 @@ from jax.experimental import mesh_utils from jax.sharding import Mesh, PartitionSpec -from flax.experimental import nnx +from flax import nnx class TestSPMD: diff --git a/flax/experimental/nnx/tests/test_state.py b/flax/nnx/tests/test_state.py similarity index 98% rename from flax/experimental/nnx/tests/test_state.py rename to flax/nnx/tests/test_state.py index be8a8a1782..e1884134ff 100644 --- a/flax/experimental/nnx/tests/test_state.py +++ b/flax/nnx/tests/test_state.py @@ -14,7 +14,7 @@ from absl.testing import absltest -from flax.experimental import nnx +from flax import nnx class StateTest(absltest.TestCase): diff --git a/flax/experimental/nnx/tests/test_transforms.py b/flax/nnx/tests/test_transforms.py similarity index 99% rename from flax/experimental/nnx/tests/test_transforms.py rename to flax/nnx/tests/test_transforms.py index 18721d941d..1d9c0b707b 100644 --- a/flax/experimental/nnx/tests/test_transforms.py +++ b/flax/nnx/tests/test_transforms.py @@ -22,7 +22,7 @@ import pytest from jax.experimental import mesh_utils -from flax.experimental import nnx +from flax import nnx class TestJIT: @@ -345,7 +345,6 @@ def constrain_object(m): m.kernel.value.sharding - class TestGrad: def test_grad(self): p1 = nnx.Param(10.0) @@ -1221,6 +1220,7 @@ def __call__(self, x: jax.Array) -> jax.Array: assert module.vmap_module.graphdef == 'hello' + class TestCond: def test_basic(self): class TimeStep(tp.NamedTuple): diff --git a/flax/experimental/nnx/tests/test_variable.py b/flax/nnx/tests/test_variable.py similarity index 98% rename from flax/experimental/nnx/tests/test_variable.py rename to flax/nnx/tests/test_variable.py index af297eeaea..de5c5c52c0 100644 --- a/flax/experimental/nnx/tests/test_variable.py +++ b/flax/nnx/tests/test_variable.py @@ -17,7 +17,7 @@ import jax import jax.numpy as jnp -from flax.experimental import nnx +from flax import nnx A = tp.TypeVar('A') diff --git a/pyproject.toml b/pyproject.toml index 9b3992306c..f10838f0c2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -115,7 +115,7 @@ ignore_missing_imports = true disable_error_code = "annotation-unchecked" # exclude nnx examples [[tool.mypy.overrides]] -module = "flax.experimental.nnx.examples.*" +module = "flax.nnx.examples.*" ignore_errors = true [tool.pytest.ini_options] diff --git a/tests/run_all_tests.sh b/tests/run_all_tests.sh index 8106f2c5c9..a130c26caf 100755 
--- a/tests/run_all_tests.sh +++ b/tests/run_all_tests.sh @@ -85,7 +85,7 @@ if $RUN_DOCTEST; then pytest -n auto flax \ --doctest-modules \ --suppress-no-test-exit-code \ - --ignore=flax/experimental/nnx/examples + --ignore=flax/nnx/examples fi # check that flax is running on editable mode @@ -112,7 +112,7 @@ if $RUN_PYTEST; then echo "pytest -n auto tests $PYTEST_OPTS $PYTEST_IGNORE" pytest -n auto tests $PYTEST_OPTS $PYTEST_IGNORE # Run nnx tests - pytest -n auto flax/experimental/nnx/tests $PYTEST_OPTS $PYTEST_IGNORE + pytest -n auto flax/nnx/tests $PYTEST_OPTS $PYTEST_IGNORE pytest -n auto docs/_ext/codediff_test.py $PYTEST_OPTS $PYTEST_IGNORE # Per-example tests. @@ -128,7 +128,7 @@ if $RUN_PYTEST; then pytest $egd done - for egd in $(find flax/experimental/nnx/examples -maxdepth 1 -mindepth 1 -type d); do + for egd in $(find flax/nnx/examples -maxdepth 1 -mindepth 1 -type d); do # skip if folder starts with "_" or is "toy_examples" if [[ $egd == *"_"* ]] || [[ $egd == *"toy_examples"* ]]; then continue @@ -140,7 +140,7 @@ fi if $RUN_PYTYPE; then echo "=== RUNNING PYTYPE ===" # Validate types in NNX examples. - for egd in $(find flax/experimental/nnx/examples -maxdepth 1 -mindepth 1 -type d); do + for egd in $(find flax/nnx/examples -maxdepth 1 -mindepth 1 -type d); do # skip if folder starts with "_" or is "toy_examples" if [[ $egd == *"_"* ]] || [[ $egd == *"toy_examples"* ]]; then continue @@ -148,11 +148,11 @@ if $RUN_PYTYPE; then # use cd to make sure pytype cache lives in example dir and doesn't name clash # use *.py to avoid importing configs as a top-level import which leads to import errors # because config files use relative imports (e.g. from config import ...). - (cd $egd ; pytype "*.py" --jobs auto --config ../../../../../pyproject.toml) + (cd $egd ; pytype "*.py" --jobs auto --config ../../../../pyproject.toml) done # Validate types in library code. pytype --jobs auto --config pyproject.toml flax/ \ - --exclude flax/experimental/nnx/examples + --exclude flax/nnx/examples # Validate types in examples. for egd in $(find examples -maxdepth 1 -mindepth 1 -type d); do
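Taken together, nearly every hunk in this patch is the same mechanical move of
``flax.experimental.nnx`` to ``flax.nnx``; for downstream code the migration is
a one-line import change (sketch; the layer call is illustrative)::

  >>> # before this patch:  from flax.experimental import nnx
  >>> from flax import nnx
  >>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))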