From 774ab4e97d20e73b5dc128bf61aed36450c99421 Mon Sep 17 00:00:00 2001 From: IvyZX Date: Thu, 3 Oct 2024 15:53:58 -0700 Subject: [PATCH] Landing page, glossary and misc --- docs/conf.py | 2 +- docs/index.rst | 6 +- docs_nnx/api_reference/flax.nnx/variables.rst | 2 + docs_nnx/conf.py | 2 +- docs_nnx/glossary.rst | 108 ++++-------------- docs_nnx/guides/bridge_guide.ipynb | 6 +- docs_nnx/guides/bridge_guide.md | 4 +- docs_nnx/guides/linen_to_nnx.rst | 2 + docs_nnx/index.rst | 33 ++++-- 9 files changed, 59 insertions(+), 106 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 8459f62da..bc0d98416 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -116,7 +116,7 @@ href="https://flax-nnx.readthedocs.io/en/latest/index.html" style="text-decoration: none; color: white;" > - Flax Linen [Explore the new Flax NNX API ✨] + This site covers the old Flax Linen API. [Explore the new Flax NNX API ✨] """ diff --git a/docs/index.rst b/docs/index.rst index 202f81e7e..2f0cfee61 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -4,7 +4,7 @@ contain the root `toctree` directive. ****************************** -Flax +Flax Linen ****************************** @@ -15,14 +15,14 @@ Flax ---- -Flax delivers an **end-to-end and flexible user experience for researchers +Flax Linen delivers an **end-to-end and flexible user experience for researchers who use JAX with neural networks**. Flax exposes the full power of `JAX `__. It is made up of loosely coupled libraries, which are showcased with end-to-end integrated `guides `__ and `examples `__. -Flax is used by +Flax Linen is used by `hundreds of projects (and growing) `__, both in the open source community (like `Hugging Face `__) diff --git a/docs_nnx/api_reference/flax.nnx/variables.rst b/docs_nnx/api_reference/flax.nnx/variables.rst index 54e442463..02fd3048c 100644 --- a/docs_nnx/api_reference/flax.nnx/variables.rst +++ b/docs_nnx/api_reference/flax.nnx/variables.rst @@ -18,5 +18,7 @@ variables :members: .. 
autoclass:: VariableMetadata :members: +.. autoclass:: VariableState + :members: .. autofunction:: with_metadata \ No newline at end of file diff --git a/docs_nnx/conf.py b/docs_nnx/conf.py index 24f748030..641080c28 100644 --- a/docs_nnx/conf.py +++ b/docs_nnx/conf.py @@ -116,7 +116,7 @@ href="https://flax-linen.readthedocs.io/en/latest" style="text-decoration: none; color: white;" > - Flax NNX [Click here for the old Flax Linen API] + This site covers the new Flax NNX API. [Click here for the old Flax Linen API] """ diff --git a/docs_nnx/glossary.rst b/docs_nnx/glossary.rst index 39aef0005..1ed754a09 100644 --- a/docs_nnx/glossary.rst +++ b/docs_nnx/glossary.rst @@ -2,111 +2,49 @@ Glossary ********* -For additional terms, refer to the `Jax glossary `__. +For additional terms, refer to the `JAX glossary `__. .. glossary:: - Bound Module - When a :class:`Module ` - is created through regular Python object construction (e.g. `module = SomeModule(args...)`, it is in an *unbound* state. This means that only - dataclass attributes are set, and no variables are bound to the module. When the pure - functions :meth:`Module.init() ` - or :meth:`Module.apply() ` - are called, Flax clones the Module and binds the variables to it, and the module's method code is - executed in a locally bound state, allowing things like calling submodules directly without - providing variables. For more details, refer to the - `module lifecycle `__. - - Compact / Non-compact Module - Modules with a single method are able to declare submodules and variables inline by - using the :func:`@nn.compact ` decorator. - These are referred to as “compact-style modules”, - whereas modules defining a :meth:`setup() ` method - (usually but not always with multiple callable methods) - are referred to as “setup-style modules”. To learn more, refer to the - `setup vs compact guide `__. + Filter + A way to extract only certain :term:`Variables` out of a :term:`Module`. 
Usually done by calling :meth:`nnx.split ` on the module. See the `Filter guide `__ to learn more. `Folding in `__ Generating a new PRNG key given an input PRNG key and integer. Typically used when you want to generate a new key but still be able to use the original rng key afterwards. You can also do this with `jax.random.split `__ but this will effectively create two RNG keys, which is slower. See how Flax generates new PRNG keys - automatically within ``Modules`` in our - `RNG guide `__. - - `FrozenDict `__ - An immutable dictionary which can be “`unfrozen `__” - to a regular, mutable dictionary. Internally, Flax uses FrozenDicts to ensure variable dicts - aren't accidentally mutated. Note: We are considering returning to regular dicts from our APIs, - and only using FrozenDicts internally. - (see `#1223 `__). + automatically in our + `RNG guide `__. - Functional core - The flax core library implements the simple container Scope API for threading - variables and PRNGs through a model, as well as the lifting machinery needed to - transform functions passing Scope objects. The python class-based module API - is built on top of this core library. - - Lazy initialization - Variables in Flax are initialized late, only when needed. That is, during normal - execution of a module, if a requested variable name isn’t found in the provided - variable collection data, we call the initializer function to create it. This - allows us to treat initialization and application under the same code-paths, - simplifying the use of JAX transforms with layers. + GraphDef + :class:`nnx.GraphDef`, a class that represents all the static, stateless, Pythonic parts of an :class:`nnx.Module` definition. Lifted transformation - Refer to the `Flax docs `__. + A wrapped version of the `JAX transformations `__ that allows the transformed function to take Flax :term:`Modules` as input or output. For example, the lifted version of `jax.jit `__ is :meth:`flax.nnx.jit `. 
See the `lifted transforms guide `__. + + Merge + See :term:`Split and merge`. Module - A dataclass allowing the definition and initialization of parameters in a + :class:`nnx.Module `, a dataclass allowing the definition and initialization of parameters in a referentially-transparent form. This is responsible for storing and updating variables - and parameters within itself. Modules can be readily transformed into functions, - allowing them to be trivially used with JAX transformations like `vmap` and `scan`. + and parameters within itself. Params / parameters - "params" is the canonical variable collection in the variable dictionary (dict). - The “params” collection generally contains the trainable weights. + :class:`nnx.Param `, a particular subclass of :class:`nnx.Variable ` that generally contains the trainable weights. - RNG sequences - Inside Flax :class:`Modules `, you can obtain a new - `PRNG `__ - key through :meth:`Module.make_rng() `. - These keys can be used to generate random numbers through - `JAX's functional random number generators `__. - Having different RNG sequences (e.g. for "params" and "dropout") allows fine-grained - control in a multi-host setup (e.g. initializing parameters identically on different - hosts, but have different dropout masks) and treating these sequences differently when - `lifting transformations `__. - See the `RNG guide `__ + RNG states + A Flax :class:`module ` can keep a reference to an :class:`RNG state object ` that can generate new JAX `PRNG `__ keys. The keys are used to generate random JAX arrays through `JAX's functional random number generators `__. + You can use an RNG state with different seeds to gain more fine-grained control over your model (e.g., independent random numbers for parameters and dropout masks). + See the `RNG guide `__ for more details. - Scope - A container class for holding the variables and PRNG keys for each layer. 
- - Shape inference - Modules do not need to specify the shape of the input array in their definitions. - Flax upon initialization inspects the input array, and infers the correct shapes - for parameters in the model. - - TrainState - Refer to :class:`flax.training.train_state.TrainState`. + Split and merge + :meth:`nnx.split `, a way to represent an `nnx.Module` as two parts: a static :term:`GraphDef ` that captures its Pythonic, static information, and one or more :term:`Variable state(s)` that capture its JAX arrays in the form of pytrees. They can be merged back into the original module with :meth:`nnx.merge `. Variable - The `weights / parameters / data / arrays `__ - residing in the leaves of :term:`variable collections`. - Variables are defined inside modules using :meth:`Module.variable() `. - A variable of collection "params" is simply called a param and can be set using - :meth:`Module.param() `. - - Variable collections - Entries in the variable dict, containing weights / parameters / data / arrays that - are used by the model. “params” is the canonical collection in the variable dict. - They are typically differentiable, updated by an outer SGD-like loop / optimizer, - rather than modified directly by forward-pass code. + The `weights / parameters / data / arrays `__ residing in a Flax :term:`Module`. Variables are defined inside modules as :class:`nnx.Variable ` or its subclasses. - `Variable dictionary `__ - A dictionary containing :term:`variable collections`. - Each variable collection is a mapping from a string name - (e.g., ":term:`params`" or "batch_stats") to a (possibly nested) - dictionary with :term:`Variables` as leaves, matching the submodule tree structure. - Read more about pytrees and leaves in the `Jax docs `__. \ No newline at end of file + Variable state + :class:`nnx.VariableState `, a purely functional pytree of all the :term:`Variables` inside a :term:`Module`. 
Since it's pure, it can be an input or output of a JAX transformation function. Obtained by :term:`splitting` the module. \ No newline at end of file diff --git a/docs_nnx/guides/bridge_guide.ipynb b/docs_nnx/guides/bridge_guide.ipynb index 25967dd8c..e41836a93 100644 --- a/docs_nnx/guides/bridge_guide.ipynb +++ b/docs_nnx/guides/bridge_guide.ipynb @@ -17,9 +17,9 @@ "\n", "**Note**:\n", "\n", - "This guide is about glueing Linen and NNX modules. To migrate an existing Linen module to NNX, check out the [Migrate from Flax Linen to Flax NNX](https://flax.readthedocs.io/en/latest/nnx/haiku_linen_vs_nnx.html) guide. \n", + "This guide is about glueing Linen and NNX modules. To migrate an existing Linen module to NNX, check out the [Migrate from Flax Linen to Flax NNX](https://flax-nnx.readthedocs.io/en/latest/guides/linen_to_nnx.html) guide. \n", "\n", - "And all built-in Linen layers should have equivalent NNX versions! Check out the list of [Built-in NNX layers](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/index.html)." + "And all built-in Linen layers should have equivalent NNX versions! Check out the list of [Built-in NNX layers](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/nn/index.html)." ] }, { @@ -682,7 +682,7 @@ "\n", "Flax uses a metadata wrapper box over the raw JAX array to annotate how a variable should be sharded.\n", "\n", - "In Linen, this is an optional feature that triggered by using `nn.with_partitioning` on initializers (see more on [Linen partition metadata guide](https://flax.readthedocs.io/en/latest/guides/parallel_training/flax_on_pjit.html)). In NNX, since all NNX variables are wrapped by `nnx.Variable` class anyway, that class will hold the sharding annotations too. 
\n", + "In Linen, this is an optional feature that triggered by using `nn.with_partitioning` on initializers (see more on [Linen partition metadata guide](https://flax.readthedocs.io/en/latest/guides/parallel_training/flax_on_pjit.html)). In NNX, since all NNX variables are wrapped by `nnx.Variable` class anyway, that class will hold the sharding annotations too.\n", "\n", "The `bridge.ToNNX` and `bridge.ToLinen` API will automatically convert the sharding annotations, if you use the built-in annotation methods (aka. `nn.with_partitioning` for Linen and `nnx.with_partitioning` for NNX)." ] diff --git a/docs_nnx/guides/bridge_guide.md b/docs_nnx/guides/bridge_guide.md index cfc15b17f..3f243ae2a 100644 --- a/docs_nnx/guides/bridge_guide.md +++ b/docs_nnx/guides/bridge_guide.md @@ -11,9 +11,9 @@ We hope this allows you to move and try out NNX at your own pace, and leverage t **Note**: -This guide is about glueing Linen and NNX modules. To migrate an existing Linen module to NNX, check out the [Migrate from Flax Linen to Flax NNX](https://flax.readthedocs.io/en/latest/nnx/haiku_linen_vs_nnx.html) guide. +This guide is about glueing Linen and NNX modules. To migrate an existing Linen module to NNX, check out the [Migrate from Flax Linen to Flax NNX](https://flax-nnx.readthedocs.io/en/latest/guides/linen_to_nnx.html) guide. -And all built-in Linen layers should have equivalent NNX versions! Check out the list of [Built-in NNX layers](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/index.html). +And all built-in Linen layers should have equivalent NNX versions! Check out the list of [Built-in NNX layers](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/nn/index.html). 
```python diff --git a/docs_nnx/guides/linen_to_nnx.rst b/docs_nnx/guides/linen_to_nnx.rst index 4cda701bd..d0c20fd09 100644 --- a/docs_nnx/guides/linen_to_nnx.rst +++ b/docs_nnx/guides/linen_to_nnx.rst @@ -6,6 +6,8 @@ models, and side-by-side comparisions to help you migrate your code from the Lin Before this guide, it's highly recommended to read through `The Basics of Flax NNX `__ to learn about the core concepts and code examples of Flax NNX. +This guide mainly covers converting arbitrary Linen code to NNX. If you want to play it safe and convert your codebase iteratively, check out the guide that allows you to `use NNX and Linen code together `__. + .. testsetup:: Linen, NNX diff --git a/docs_nnx/index.rst b/docs_nnx/index.rst index 8ee8676d4..ce1b81b2a 100644 --- a/docs_nnx/index.rst +++ b/docs_nnx/index.rst @@ -1,5 +1,5 @@ -Flax NNX +Flax ======== .. div:: sd-text-left sd-font-italic @@ -8,16 +8,19 @@ Flax NNX ---- -**Flax NNX is a simplified API that makes it easier to create, inspect, +Flax delivers an **end-to-end and flexible user experience for researchers +who use JAX with neural networks**. Flax +exposes the full power of `JAX `__. + +At its core is **Flax NNX, a simplified API that makes it easier to create, inspect, debug, and analyze neural networks in JAX.** It has first class support for Python reference semantics, allowing users to express their models using regular -Python objects. Flax NNX is an evolution of the previous Flax Linen APIs, and it takes years of +Python objects. Flax NNX is an evolution of the previous Flax Linen APIs, and it took years of experience to bring a simpler and more user-friendly experience. .. note:: - Flax Linen is not going to be deprecated in the near future as most of our users still + Flax Linen API is not going to be deprecated in the near future as most of our users still rely on this API, however new users are encouraged to use Flax NNX. 
- For existing Linen users to move to NNX, check out the `evolution guide `_. Features @@ -76,7 +79,7 @@ Features .. div:: sd-font-normal Flax NNX makes it very easy to integrate objects with regular JAX code - via the `Functional API `__. + via the `Functional API `__. Basic usage ^^^^^^^^^^^^ @@ -158,23 +161,30 @@ Learn more .. grid-item:: :columns: 6 6 6 4 - .. card:: :material-regular:`sync_alt;2em` Flax vs JAX Transformations + .. card:: :material-regular:`library_books;2em` Guides :class-card: sd-text-black sd-bg-light - :link: transforms.html + :link: guides/index.html .. grid-item:: :columns: 6 6 6 4 - .. card:: :material-regular:`transform;2em` Haiku and Flax Linen vs Flax NNX + .. card:: :material-regular:`transform;2em` Flax Linen to Flax NNX :class-card: sd-text-black sd-bg-light - :link: haiku_linen_vs_nnx.html + :link: guides/linen_to_nnx.html .. grid-item:: :columns: 6 6 6 4 .. card:: :material-regular:`menu_book;2em` API reference :class-card: sd-text-black sd-bg-light - :link: ../api_reference/flax.nnx/index.html + :link: /api_reference/index.html + + .. grid-item:: + :columns: 6 6 6 4 + + .. card:: :material-regular:`import_contacts;2em` Glossary + :class-card: sd-text-black sd-bg-light + :link: glossary.html ---- @@ -187,6 +197,7 @@ Learn more mnist_tutorial guides/index examples/index + glossary The Flax philosophy How to contribute api_reference/index