Skip to content

Commit

Permalink
Landing page, glossary and misc
Browse files Browse the repository at this point in the history
  • Loading branch information
IvyZX committed Oct 3, 2024
1 parent 9c162ab commit 774ab4e
Show file tree
Hide file tree
Showing 9 changed files with 59 additions and 106 deletions.
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@
href="https://flax-nnx.readthedocs.io/en/latest/index.html"
style="text-decoration: none; color: white;"
>
Flax Linen <span style="color: lightgray;">[Explore the new <b>Flax NNX</b> API ✨]</span>
This site covers the old Flax Linen API. <span style="color: lightgray;">[Explore the new <b>Flax NNX</b> API ✨]</span>
</a>
"""

Expand Down
6 changes: 3 additions & 3 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
contain the root `toctree` directive.
******************************
Flax
Flax Linen
******************************


Expand All @@ -15,14 +15,14 @@ Flax

----

Flax delivers an **end-to-end and flexible user experience for researchers
Flax Linen delivers an **end-to-end and flexible user experience for researchers
who use JAX with neural networks**. Flax
exposes the full power of `JAX <https://jax.readthedocs.io>`__. It is made up of
loosely coupled libraries, which are showcased with end-to-end integrated
`guides <https://flax.readthedocs.io/en/latest/guides/index.html>`__
and `examples <https://flax.readthedocs.io/en/latest/examples.html>`__.

Flax is used by
Flax Linen is used by
`hundreds of projects (and growing) <https://github.com/google/flax/network/dependents?package_id=UGFja2FnZS01MjEyMjA2MA%3D%3D>`__,
both in the open source community
(like `Hugging Face <https://huggingface.co/flax-community>`__)
Expand Down
2 changes: 2 additions & 0 deletions docs_nnx/api_reference/flax.nnx/variables.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,7 @@ variables
:members:
.. autoclass:: VariableMetadata
:members:
.. autoclass:: VariableState
:members:

.. autofunction:: with_metadata
2 changes: 1 addition & 1 deletion docs_nnx/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@
href="https://flax-linen.readthedocs.io/en/latest"
style="text-decoration: none; color: white;"
>
Flax NNX <span style="color: lightgray;">[Click here for the old <b>Flax Linen</b> API]</span>
This site covers the new Flax NNX API. <span style="color: lightgray;">[Click here for the old <b>Flax Linen</b> API]</span>
</a>
"""

Expand Down
108 changes: 23 additions & 85 deletions docs_nnx/glossary.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,111 +2,49 @@
Glossary
*********

For additional terms, refer to the `Jax glossary <https://jax.readthedocs.io/en/latest/glossary.html>`__.
For additional terms, refer to the `JAX glossary <https://jax.readthedocs.io/en/latest/glossary.html>`__.

.. glossary::

Bound Module
When a :class:`Module <flax.linen.Module>`
is created through regular Python object construction (e.g. `module = SomeModule(args...)`, it is in an *unbound* state. This means that only
dataclass attributes are set, and no variables are bound to the module. When the pure
functions :meth:`Module.init() <flax.linen.Module.init>`
or :meth:`Module.apply() <flax.linen.Module.apply>`
are called, Flax clones the Module and binds the variables to it, and the module's method code is
executed in a locally bound state, allowing things like calling submodules directly without
providing variables. For more details, refer to the
`module lifecycle <https://flax.readthedocs.io/en/latest/developer_notes/module_lifecycle.html>`__.

Compact / Non-compact Module
Modules with a single method are able to declare submodules and variables inline by
using the :func:`@nn.compact <flax.linen.compact>` decorator.
These are referred to as “compact-style modules”,
whereas modules defining a :meth:`setup() <flax.linen.Module.setup>` method
(usually but not always with multiple callable methods)
are referred to as “setup-style modules”. To learn more, refer to the
`setup vs compact guide <https://flax.readthedocs.io/en/latest/guides/flax_fundamentals/setup_or_nncompact.html>`__.
Filter
A way to extract only certain :term:`Variables<Variable>` out of a :term:`Module<Module>`. Usually done via calling :meth:`nnx.split <flax.nnx.split>` upon the module. See the `Filter guide <https://flax-nnx.readthedocs.io/en/latest/guides/filters_guide.html>`__ to learn more.

`Folding in <https://jax.readthedocs.io/en/latest/_autosummary/jax.random.fold_in.html>`__
Generating a new PRNG key given an input PRNG key and integer. Typically used when you want to
generate a new key but still be able to use the original rng key afterwards. You can also do this with
`jax.random.split <https://jax.readthedocs.io/en/latest/_autosummary/jax.random.split.html>`__
but this will effectively create two RNG keys, which is slower. See how Flax generates new PRNG keys
automatically within ``Modules`` in our
`RNG guide <https://flax.readthedocs.io/en/latest/guides/flax_fundamentals/rng_guide.html#how-self-make-rng-works-under-the-hood>`__.

`FrozenDict <https://flax.readthedocs.io/en/latest/api_reference/flax.core.frozen_dict.html#flax.core.frozen_dict.FrozenDict>`__
An immutable dictionary which can be “`unfrozen <https://flax.readthedocs.io/en/latest/api_reference/flax.core.frozen_dict.html#flax.core.frozen_dict.unfreeze>`__”
to a regular, mutable dictionary. Internally, Flax uses FrozenDicts to ensure variable dicts
aren't accidentally mutated. Note: We are considering returning to regular dicts from our APIs,
and only using FrozenDicts internally.
(see `#1223 <https://github.com/google/flax/issues/1223>`__).
automatically in our
`RNG guide <https://flax-nnx.readthedocs.io/en/latest/guides/randomness.html>`__.

Functional core
The flax core library implements the simple container Scope API for threading
variables and PRNGs through a model, as well as the lifting machinery needed to
transform functions passing Scope objects. The python class-based module API
is built on top of this core library.

Lazy initialization
Variables in Flax are initialized late, only when needed. That is, during normal
execution of a module, if a requested variable name isn’t found in the provided
variable collection data, we call the initializer function to create it. This
allows us to treat initialization and application under the same code-paths,
simplifying the use of JAX transforms with layers.
GraphDef
:class:`nnx.GraphDef<flax.nnx.GraphDef>`, a class that represents all the static, stateless, Pythonic part of an :class:`nnx.Module<flax.nnx.Module>` definition.

Lifted transformation
Refer to the `Flax docs <https://flax.readthedocs.io/en/latest/developer_notes/lift.html>`__.
A wrapped version of the `JAX transformations <https://flax-nnx.readthedocs.io/en/latest/guides/transforms.html>`__ that allows the transformed function to take Flax :term:`Modules<Module>` as input or output. For example, a lifted version of `jax.jit <https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit>`__ will be :meth:`flax.nnx.jit <flax.nnx.jit>`. See the `lifted transforms guide <https://flax-nnx.readthedocs.io/en/latest/guides/transforms.html>`__.

Merge
See :term:`Split and merge<Split and merge>`.

Module
A dataclass allowing the definition and initialization of parameters in a
:class:`nnx.Module <flax.nnx.Module>`, a dataclass allowing the definition and initialization of parameters in a
referentially-transparent form. This is responsible for storing and updating variables
and parameters within itself. Modules can be readily transformed into functions,
allowing them to be trivially used with JAX transformations like `vmap` and `scan`.
and parameters within itself.

Params / parameters
"params" is the canonical variable collection in the variable dictionary (dict).
The “params” collection generally contains the trainable weights.
:class:`nnx.Param <flax.nnx.Param>`, a particular subclass of :class:`nnx.Variable <flax.nnx.Variable>` that generally contains the trainable weights.

RNG sequences
Inside Flax :class:`Modules <flax.linen.Module>`, you can obtain a new
`PRNG <https://en.wikipedia.org/wiki/Pseudorandom_number_generator>`__
key through :meth:`Module.make_rng() <flax.linen.Module.make_rng>`.
These keys can be used to generate random numbers through
`JAX's functional random number generators <https://jax.readthedocs.io/en/latest/jax-101/05-random-numbers.html>`__.
Having different RNG sequences (e.g. for "params" and "dropout") allows fine-grained
control in a multi-host setup (e.g. initializing parameters identically on different
hosts, but have different dropout masks) and treating these sequences differently when
`lifting transformations <https://flax.readthedocs.io/en/latest/developer_notes/lift.html>`__.
See the `RNG guide <https://flax.readthedocs.io/en/latest/guides/flax_fundamentals/rng_guide.html>`__
RNG states
A Flax :class:`module <flax.nnx.Module>` can keep a reference of an :class:`RNG state object <flax.nnx.Rngs>` that can generate new JAX `PRNG <https://en.wikipedia.org/wiki/Pseudorandom_number_generator>`__ keys. They keys are used to generate random JAX arrays through `JAX's functional random number generators <https://jax.readthedocs.io/en/latest/jax-101/05-random-numbers.html>`__.
You can use an RNG state with different seeds to make more fine-grained control on your model (e.g., independent random numbers for parameters and dropout masks).
See the `RNG guide <https://flax-nnx.readthedocs.io/en/latest/guides/randomness.html>`__
for more details.

Scope
A container class for holding the variables and PRNG keys for each layer.

Shape inference
Modules do not need to specify the shape of the input array in their definitions.
Flax upon initialization inspects the input array, and infers the correct shapes
for parameters in the model.

TrainState
Refer to :class:`flax.training.train_state.TrainState`.
Split and merge
:meth:`nnx.split <flax.nnx.split>`, a way to represent an `nnx.Module` by two parts - a static :term:`GraphDef <GraphDef>` that captures its Pythonic, static information, and one or more :term:`Variable state(s)<Variable state>` that captures its JAX arrays in the form of pytrees. They can be merged back to the original module with :meth:`nnx.merge <flax.nnx.merge>`.

Variable
The `weights / parameters / data / arrays <https://flax.readthedocs.io/en/latest/api_reference/flax.linen/variable.html#flax.linen.Variable>`__
residing in the leaves of :term:`variable collections<Variable collections>`.
Variables are defined inside modules using :meth:`Module.variable() <flax.linen.Module.variable>`.
A variable of collection "params" is simply called a param and can be set using
:meth:`Module.param() <flax.linen.Module.param>`.

Variable collections
Entries in the variable dict, containing weights / parameters / data / arrays that
are used by the model. “params” is the canonical collection in the variable dict.
They are typically differentiable, updated by an outer SGD-like loop / optimizer,
rather than modified directly by forward-pass code.
The `weights / parameters / data / arrays <https://flax.readthedocs.io/en/latest/api_reference/flax.linen/variable.html#flax.linen.Variable>`__ residing in a Flax :term:`Module<Module>`. Variables are defined inside modules as :class:`nnx.Variable <flax.nnx.Variable>` or its subclasses.

`Variable dictionary <https://flax.readthedocs.io/en/latest/api_reference/flax.linen/variable.html>`__
A dictionary containing :term:`variable collections<Variable collections>`.
Each variable collection is a mapping from a string name
(e.g., ":term:`params<Params / parameters>`" or "batch_stats") to a (possibly nested)
dictionary with :term:`Variables<Variable>` as leaves, matching the submodule tree structure.
Read more about pytrees and leaves in the `Jax docs <https://jax.readthedocs.io/en/latest/pytrees.html>`__.
Variable state
:class:`nnx.VariableState <flax.nnx.VariableState>`, a purely functional pytree of all the :term:`Variables<Variable>` inside a :term:`Module<Module>`. Since it's pure, it can be an input or output of a JAX transformation function. Obtained by using :term:`splitting<Split and merge>` the module.
6 changes: 3 additions & 3 deletions docs_nnx/guides/bridge_guide.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
"\n",
"**Note**:\n",
"\n",
"This guide is about glueing Linen and NNX modules. To migrate an existing Linen module to NNX, check out the [Migrate from Flax Linen to Flax NNX](https://flax.readthedocs.io/en/latest/nnx/haiku_linen_vs_nnx.html) guide. \n",
"This guide is about glueing Linen and NNX modules. To migrate an existing Linen module to NNX, check out the [Migrate from Flax Linen to Flax NNX](https://flax-nnx.readthedocs.io/en/latest/guides/linen_to_nnx.html) guide. \n",
"\n",
"And all built-in Linen layers should have equivalent NNX versions! Check out the list of [Built-in NNX layers](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/index.html)."
"And all built-in Linen layers should have equivalent NNX versions! Check out the list of [Built-in NNX layers](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/nn/index.html)."
]
},
{
Expand Down Expand Up @@ -682,7 +682,7 @@
"\n",
"Flax uses a metadata wrapper box over the raw JAX array to annotate how a variable should be sharded.\n",
"\n",
"In Linen, this is an optional feature that triggered by using `nn.with_partitioning` on initializers (see more on [Linen partition metadata guide](https://flax.readthedocs.io/en/latest/guides/parallel_training/flax_on_pjit.html)). In NNX, since all NNX variables are wrapped by `nnx.Variable` class anyway, that class will hold the sharding annotations too. \n",
"In Linen, this is an optional feature that triggered by using `nn.with_partitioning` on initializers (see more on [Linen partition metadata guide](https://flax.readthedocs.io/en/latest/guides/parallel_training/flax_on_pjit.html)). In NNX, since all NNX variables are wrapped by `nnx.Variable` class anyway, that class will hold the sharding annotations too.\n",
"\n",
"The `bridge.ToNNX` and `bridge.ToLinen` API will automatically convert the sharding annotations, if you use the built-in annotation methods (aka. `nn.with_partitioning` for Linen and `nnx.with_partitioning` for NNX)."
]
Expand Down
4 changes: 2 additions & 2 deletions docs_nnx/guides/bridge_guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ We hope this allows you to move and try out NNX at your own pace, and leverage t

**Note**:

This guide is about glueing Linen and NNX modules. To migrate an existing Linen module to NNX, check out the [Migrate from Flax Linen to Flax NNX](https://flax.readthedocs.io/en/latest/nnx/haiku_linen_vs_nnx.html) guide.
This guide is about glueing Linen and NNX modules. To migrate an existing Linen module to NNX, check out the [Migrate from Flax Linen to Flax NNX](https://flax-nnx.readthedocs.io/en/latest/guides/linen_to_nnx.html) guide.

And all built-in Linen layers should have equivalent NNX versions! Check out the list of [Built-in NNX layers](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/index.html).
And all built-in Linen layers should have equivalent NNX versions! Check out the list of [Built-in NNX layers](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/nn/index.html).


```python
Expand Down
2 changes: 2 additions & 0 deletions docs_nnx/guides/linen_to_nnx.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ models, and side-by-side comparisions to help you migrate your code from the Lin

Before this guide, it's highly recommended to read through `The Basics of Flax NNX <https://flax-nnx.readthedocs.io/en/latest/nnx_basics.html>`__ to learn about the core concepts and code examples of Flax NNX.

This guide mainly covers converting arbitratry Linen code to NNX. If you want to play it safe and convert your codebase iteratively, check out the guide that allows you to `use NNX and Linen code together <https://flax-nnx.readthedocs.io/en/latest/guides/bridge_guide.html>`__


.. testsetup:: Linen, NNX

Expand Down
33 changes: 22 additions & 11 deletions docs_nnx/index.rst
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@

Flax NNX
Flax
========
.. div:: sd-text-left sd-font-italic

Expand All @@ -8,16 +8,19 @@ Flax NNX

----

**Flax NNX is a simplified API that makes it easier to create, inspect,
Flax delivers an **end-to-end and flexible user experience for researchers
who use JAX with neural networks**. Flax
exposes the full power of `JAX <https://jax.readthedocs.io>`__.

At its core is **Flax NNX, a simplified API that makes it easier to create, inspect,
debug, and analyze neural networks in JAX.** It has first class support
for Python reference semantics, allowing users to express their models using regular
Python objects. Flax NNX is an evolution of the previous Flax Linen APIs, and it takes years of
Python objects. Flax NNX is an evolution of the previous Flax Linen APIs, and it took years of
experience to bring a simpler and more user-friendly experience.

.. note::
Flax Linen is not going to be deprecated in the near future as most of our users still
Flax Linen API is not going to be deprecated in the near future as most of our users still
rely on this API, however new users are encouraged to use Flax NNX.

For existing Linen users to move to NNX, check out the `evolution guide <guides/linen_to_nnx.html>`_.

Features
Expand Down Expand Up @@ -76,7 +79,7 @@ Features
.. div:: sd-font-normal

Flax NNX makes it very easy to integrate objects with regular JAX code
via the `Functional API <nnx_basics.html#the-functional-api>`__.
via the `Functional API <nnx_basics.html#the-flax-functional-api>`__.

Basic usage
^^^^^^^^^^^^
Expand Down Expand Up @@ -158,23 +161,30 @@ Learn more
.. grid-item::
:columns: 6 6 6 4

.. card:: :material-regular:`sync_alt;2em` Flax vs JAX Transformations
.. card:: :material-regular:`library_books;2em` Guides
:class-card: sd-text-black sd-bg-light
:link: transforms.html
:link: guides/index.html

.. grid-item::
:columns: 6 6 6 4

.. card:: :material-regular:`transform;2em` Haiku and Flax Linen vs Flax NNX
.. card:: :material-regular:`transform;2em` Flax Linen to Flax NNX
:class-card: sd-text-black sd-bg-light
:link: haiku_linen_vs_nnx.html
:link: guides/linen_to_nnx.html

.. grid-item::
:columns: 6 6 6 4

.. card:: :material-regular:`menu_book;2em` API reference
:class-card: sd-text-black sd-bg-light
:link: ../api_reference/flax.nnx/index.html
:link: /api_reference/index.html

.. grid-item::
:columns: 6 6 6 4

.. card:: :material-regular:`import_contacts;2em` Glossary
:class-card: sd-text-black sd-bg-light
:link: glossary.html


----
Expand All @@ -187,6 +197,7 @@ Learn more
mnist_tutorial
guides/index
examples/index
glossary
The Flax philosophy <philosophyhttps://flax.readthedocs.io/en/latest/philosophy.html>
How to contribute <https://flax.readthedocs.io/en/latest/contributing.html>
api_reference/index

0 comments on commit 774ab4e

Please sign in to comment.