diff --git a/docs_nnx/guides/filters_guide.ipynb b/docs_nnx/guides/filters_guide.ipynb index a4dfabea..fbcbc5fd 100644 --- a/docs_nnx/guides/filters_guide.ipynb +++ b/docs_nnx/guides/filters_guide.ipynb @@ -5,12 +5,17 @@ "id": "95b08e64", "metadata": {}, "source": [ - "# Using Filters\n", + "# Using Filters, grouping NNX variables \n", "\n", - "> **Attention**: This page relates to the new Flax NNX API.\n", + "Flax NNX uses [`Filter`s](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html) extensively as a way to create [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) groups in APIs, such as [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split), [`nnx.state()`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.state), and many of the [Flax NNX transformations (transforms)](https://flax.readthedocs.io/en/latest/guides/jax_and_nnx_transforms.html).\n", "\n", - "Filters are used extensively in Flax NNX as a way to create `State` groups in APIs\n", - "such as `nnx.split`, `nnx.state`, and many of the Flax NNX transforms. 
For example:" + "In this guide you will learn how to:\n", + "\n", + "* Use [`Filter`s](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html) to group Flax NNX variables and states into subgroups;\n", + "* Understand relationships between types, such as [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) or [`nnx.BatchStat`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.BatchStat), and [`Filter`s](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html);\n", + "* Express your `Filter`s flexibly with [`nnx.filterlib.Filter`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html) language.\n", + "\n", + "In the following example [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) and [`nnx.BatchStat`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.BatchStat) are used as `Filter`s to split the model into two groups: one with the parameters and the other with the batch statistics:" ] }, { @@ -59,11 +64,7 @@ "id": "8f77e99a", "metadata": {}, "source": [ - "Here `nnx.Param` and `nnx.BatchStat` are used as Filters to split the model into two groups: one with the parameters and the other with the batch statistics. However, this begs the following questions:\n", - "\n", - "* What is a Filter?\n", - "* Why are types, such as `Param` or `BatchStat`, Filters?\n", - "* How is `State` grouped / filtered?" + "Let's dive deeper into `Filter`s." 
] }, { @@ -71,20 +72,25 @@ "id": "a0413d64", "metadata": {}, "source": [ - "## The Filter Protocol\n", + "## The `Filter` Protocol\n", "\n", - "In general Filter are predicate functions of the form:\n", + "In general, Flax `Filter`s are predicate functions of the form:\n", "\n", "```python\n", "\n", "(path: tuple[Key, ...], value: Any) -> bool\n", "\n", "```\n", - "where `Key` is a hashable and comparable type, `path` is a tuple of `Key`s representing the path to the value in a nested structure, and `value` is the value at the path. The function returns `True` if the value should be included in the group and `False` otherwise.\n", "\n", - "Types are obviously not functions of this form, so the reason why they are treated as Filters\n", - "is because, as we will see next, types and some other literals are converted to predicates. For example,\n", - "`Param` is roughly converted to a predicate like this:" + "where:\n", + "\n", + "- `Key` is a hashable and comparable type;\n", + "- `path` is a tuple of `Key`s representing the path to the value in a nested structure; and\n", + "- `value` is the value at the path.\n", + "\n", + "The function returns `True` if the value should be included in the group, and `False` otherwise.\n", + "\n", + "Types are not functions of this form. They are treated as `Filter`s because, as you will learn in the next section, types and some other literals are converted to _predicates_. For example, [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) is roughly converted to a predicate like this:" ] }, { @@ -117,9 +123,7 @@ "id": "a8a2641e", "metadata": {}, "source": [ - "Such function matches any value that is an instance of `Param` or any value that has a\n", - "`type` attribute that is a subclass of `Param`. 
Internally Flax NNX uses `OfType` which\n", - "defines a callable of this form for a given type:" + "Such a function matches any value that is an instance of [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) or any value that has a `type` attribute that is a subclass of [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param). Internally, Flax NNX uses `OfType`, which defines a callable of this form for a given type:" ] }, { @@ -149,14 +153,11 @@ "id": "87c06e39", "metadata": {}, "source": [ - "## The Filter DSL\n", + "## The `Filter` DSL\n", "\n", - "To avoid users having to create these functions, Flax NNX exposes a small DSL, formalized\n", - "as the `nnx.filterlib.Filter` type, which lets users pass types, booleans, ellipsis,\n", - "tuples/lists, etc, and converts them to the appropriate predicate internally.\n", + "Flax NNX exposes a small domain-specific language ([DSL](https://en.wikipedia.org/wiki/Domain-specific_language)), formalized as the [`nnx.filterlib.Filter`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html) type, which lets you pass types, booleans, ellipsis, tuples/lists, and so on, and converts them to the appropriate predicates internally. This means you don't have to write predicate functions like the ones in the previous section by hand.\n", "\n", - "Here is a list of all the callable Filters included in Flax NNX and their DSL literals\n", - "(when available):\n", + "Here is a list of all the callable `Filter`s included in Flax NNX, and their corresponding DSL literals (when available):\n", "\n", "\n", "| Literal | Callable | Description |\n", @@ -170,10 +171,14 @@ "| | `All(*filters)` | Matches values that match all of the inner `filters` |\n", "| | `Not(filter)` | Matches values that do not match the inner `filter` |\n", "\n", - "Let see the DSL in action with a `nnx.vmap` example. Lets say we want vectorized all parameters\n", - "and `dropout` Rng(Keys|Counts) on the 0th axis, and broadcasted the rest. 
To do so we can\n", - "use the following filters to define a `nnx.StateAxes` object that we can pass to `nnx.vmap`'s `in_axes`\n", - "to specify how `model`'s various substates should be vectorized:" + "\n", + "Let's check out the DSL in action by using [`nnx.vmap`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.vmap) as an example. Say you want to:\n", + "\n", + "1) Vectorize all parameters on the `0`th axis;\n", + "2) Vectorize the `'dropout'` `Rng(Keys|Counts)` on the `0`th axis as well; and\n", + "3) Broadcast the rest.\n", + "\n", + "To do this, you can use the following `Filter`s to define a `nnx.StateAxes` object that you can pass to `nnx.vmap`'s `in_axes` to specify how the `model`'s various sub-states should be vectorized:" ] }, { @@ -195,10 +200,9 @@ "id": "bd60f0e1", "metadata": {}, "source": [ - "Here `(nnx.Param, 'dropout')` expands to `Any(OfType(nnx.Param), WithTag('dropout'))` and `...`\n", - "expands to `Everything()`.\n", + "Here `(nnx.Param, 'dropout')` expands to `Any(OfType(nnx.Param), WithTag('dropout'))` and `...` expands to `Everything()`.\n", "\n", - "If you wish to manually convert literal into a predicate to can use `nnx.filterlib.to_predicate`:" + "If you wish to manually convert a literal into a predicate, you can use [`nnx.filterlib.to_predicate`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html#flax.nnx.filterlib.to_predicate):" ] }, { @@ -235,15 +239,15 @@ "id": "db9b4cf3", "metadata": {}, "source": [ - "## Grouping States\n", + "## Grouping `State`s\n", "\n", - "With the knowledge of Filters at hand, let's see how `nnx.split` is roughly implemented. Key ideas:\n", + "With the knowledge of `Filter`s from previous sections at hand, let's learn how to roughly implement [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split). 
Here are the key ideas:\n", "\n", - "* Use `nnx.graph.flatten` to get the `GraphDef` and `State` representation of the node.\n", - "* Convert all the filters to predicates.\n", + "* Use `nnx.graph.flatten` to get the [`GraphDef`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.GraphDef) and [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) representation of the node.\n", + "* Convert all the `Filter`s to predicates.\n", "* Use `State.flat_state` to get the flat representation of the state.\n", "* Traverse all the `(path, value)` pairs in the flat state and group them according to the predicates.\n", - "* Use `State.from_flat_state` to convert the flat states to nested `State`s." + "* Use `State.from_flat_state` to convert the flat states to nested [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State)s." ] }, { @@ -293,7 +297,7 @@ " )\n", " return graphdef, *states\n", "\n", - "# lets test it...\n", + "# Let's test it.\n", "foo = Foo()\n", "\n", "graphdef, params, batch_stats = split(foo, nnx.Param, nnx.BatchStat)\n", @@ -307,12 +311,14 @@ "id": "7b3aeac8", "metadata": {}, "source": [ - "One very important thing to note is that **filtering is order-dependent**. The first filter that\n", - "matches a value will keep it, therefore you should place more specific filters before more general\n", - "filters. For example if we create a `SpecialParam` type that is a subclass of `Param`, and a `Bar`\n", - "object that contains both types of parameters, if we try to split the `Param`s before the\n", - "`SpecialParam`s then all the values will be placed in the `Param` group and the `SpecialParam` group\n", - "will be empty because all `SpecialParam`s are also `Param`s:" + "**Note:** It's very important to know that **filtering is order-dependent**. 
The first `Filter` that matches a value will keep it, and therefore you should place more specific `Filter`s before more general `Filter`s.\n", + "\n", + "For example, as demonstrated below, if you:\n", + "\n", + "1) Create a `SpecialParam` type that is a subclass of [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param), and a `Bar` object (subclassing [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html)) that contains both types of parameters; and\n", + "2) Try to split the [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param)s before the `SpecialParam`s,\n", + "\n", + "then all the values will be placed in the [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) group, and the `SpecialParam` group will be empty because all `SpecialParam`s are also [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param)s:" ] }, { @@ -360,7 +366,7 @@ "id": "a9f0b7b8", "metadata": {}, "source": [ - "Reversing the order will make sure that the `SpecialParam` are captured first" + "Reversing the order ensures that the `SpecialParam`s are captured first:" ] }, { diff --git a/docs_nnx/guides/filters_guide.md b/docs_nnx/guides/filters_guide.md index dcd414d7..88a25a6a 100644 --- a/docs_nnx/guides/filters_guide.md +++ b/docs_nnx/guides/filters_guide.md @@ -8,12 +8,17 @@ jupytext: jupytext_version: 1.13.8 --- -# Using Filters +# Using Filters, grouping NNX variables -> **Attention**: This page relates to the new Flax NNX API. 
+Flax NNX uses [`Filter`s](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html) extensively as a way to create [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) groups in APIs, such as [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split), [`nnx.state()`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.state), and many of the [Flax NNX transformations (transforms)](https://flax.readthedocs.io/en/latest/guides/jax_and_nnx_transforms.html). -Filters are used extensively in Flax NNX as a way to create `State` groups in APIs -such as `nnx.split`, `nnx.state`, and many of the Flax NNX transforms. For example: +In this guide you will learn how to: + +* Use [`Filter`s](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html) to group Flax NNX variables and states into subgroups; +* Understand relationships between types, such as [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) or [`nnx.BatchStat`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.BatchStat), and [`Filter`s](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html); +* Express your `Filter`s flexibly with [`nnx.filterlib.Filter`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html) language. 
+ +In the following example [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) and [`nnx.BatchStat`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.BatchStat) are used as `Filter`s to split the model into two groups: one with the parameters and the other with the batch statistics: ```{code-cell} ipython3 from flax import nnx @@ -31,28 +36,29 @@ print(f'{params = }') print(f'{batch_stats = }') ``` -Here `nnx.Param` and `nnx.BatchStat` are used as Filters to split the model into two groups: one with the parameters and the other with the batch statistics. However, this begs the following questions: - -* What is a Filter? -* Why are types, such as `Param` or `BatchStat`, Filters? -* How is `State` grouped / filtered? +Let's dive deeper into `Filter`s. +++ -## The Filter Protocol +## The `Filter` Protocol -In general Filter are predicate functions of the form: +In general, Flax `Filter`s are predicate functions of the form: ```python (path: tuple[Key, ...], value: Any) -> bool ``` -where `Key` is a hashable and comparable type, `path` is a tuple of `Key`s representing the path to the value in a nested structure, and `value` is the value at the path. The function returns `True` if the value should be included in the group and `False` otherwise. -Types are obviously not functions of this form, so the reason why they are treated as Filters -is because, as we will see next, types and some other literals are converted to predicates. For example, -`Param` is roughly converted to a predicate like this: +where: + +- `Key` is a hashable and comparable type; +- `path` is a tuple of `Key`s representing the path to the value in a nested structure; and +- `value` is the value at the path. + +The function returns `True` if the value should be included in the group, and `False` otherwise. + +Types are not functions of this form. 
They are treated as `Filter`s because, as you will learn in the next section, types and some other literals are converted to _predicates_. For example, [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) is roughly converted to a predicate like this: ```{code-cell} ipython3 def is_param(path, value) -> bool: @@ -64,9 +70,7 @@ print(f'{is_param((), nnx.Param(0)) = }') print(f'{is_param((), nnx.VariableState(type=nnx.Param, value=0)) = }') ``` -Such function matches any value that is an instance of `Param` or any value that has a -`type` attribute that is a subclass of `Param`. Internally Flax NNX uses `OfType` which -defines a callable of this form for a given type: +Such a function matches any value that is an instance of [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) or any value that has a `type` attribute that is a subclass of [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param). Internally, Flax NNX uses `OfType`, which defines a callable of this form for a given type: ```{code-cell} ipython3 is_param = nnx.OfType(nnx.Param) @@ -75,14 +79,11 @@ print(f'{is_param((), nnx.Param(0)) = }') print(f'{is_param((), nnx.VariableState(type=nnx.Param, value=0)) = }') ``` -## The Filter DSL +## The `Filter` DSL -To avoid users having to create these functions, Flax NNX exposes a small DSL, formalized -as the `nnx.filterlib.Filter` type, which lets users pass types, booleans, ellipsis, -tuples/lists, etc, and converts them to the appropriate predicate internally. +Flax NNX exposes a small domain-specific language ([DSL](https://en.wikipedia.org/wiki/Domain-specific_language)), formalized as the [`nnx.filterlib.Filter`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html) type, which lets you pass types, booleans, ellipsis, tuples/lists, and so on, and converts them to the appropriate predicates internally. This means you don't have to write predicate functions like the ones in the previous section by hand. 
-Here is a list of all the callable Filters included in Flax NNX and their DSL literals -(when available): +Here is a list of all the callable `Filter`s included in Flax NNX, and their corresponding DSL literals (when available): | Literal | Callable | Description | @@ -96,10 +97,14 @@ Here is a list of all the callable Filters included in Flax NNX and their DSL li | | `All(*filters)` | Matches values that match all of the inner `filters` | | | `Not(filter)` | Matches values that do not match the inner `filter` | -Let see the DSL in action with a `nnx.vmap` example. Lets say we want vectorized all parameters -and `dropout` Rng(Keys|Counts) on the 0th axis, and broadcasted the rest. To do so we can -use the following filters to define a `nnx.StateAxes` object that we can pass to `nnx.vmap`'s `in_axes` -to specify how `model`'s various substates should be vectorized: + +Let's check out the DSL in action by using [`nnx.vmap`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.vmap) as an example. Say you want to: + +1) Vectorize all parameters on the `0`th axis; +2) Vectorize the `'dropout'` `Rng(Keys|Counts)` on the `0`th axis as well; and +3) Broadcast the rest. + +To do this, you can use the following `Filter`s to define a `nnx.StateAxes` object that you can pass to `nnx.vmap`'s `in_axes` to specify how the `model`'s various sub-states should be vectorized: ```{code-cell} ipython3 state_axes = nnx.StateAxes({(nnx.Param, 'dropout'): 0, ...: None}) @@ -109,10 +114,9 @@ def forward(model, x): ... ``` -Here `(nnx.Param, 'dropout')` expands to `Any(OfType(nnx.Param), WithTag('dropout'))` and `...` -expands to `Everything()`. +Here `(nnx.Param, 'dropout')` expands to `Any(OfType(nnx.Param), WithTag('dropout'))` and `...` expands to `Everything()`. 
-If you wish to manually convert literal into a predicate to can use `nnx.filterlib.to_predicate`: +If you wish to manually convert a literal into a predicate, you can use [`nnx.filterlib.to_predicate`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html#flax.nnx.filterlib.to_predicate): ```{code-cell} ipython3 is_param = nnx.filterlib.to_predicate(nnx.Param) @@ -126,15 +130,15 @@ print(f'{nothing = }') print(f'{params_or_dropout = }') ``` -## Grouping States +## Grouping `State`s -With the knowledge of Filters at hand, let's see how `nnx.split` is roughly implemented. Key ideas: +With the knowledge of `Filter`s from previous sections at hand, let's learn how to roughly implement [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split). Here are the key ideas: -* Use `nnx.graph.flatten` to get the `GraphDef` and `State` representation of the node. -* Convert all the filters to predicates. +* Use `nnx.graph.flatten` to get the [`GraphDef`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.GraphDef) and [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) representation of the node. +* Convert all the `Filter`s to predicates. * Use `State.flat_state` to get the flat representation of the state. * Traverse all the `(path, value)` pairs in the flat state and group them according to the predicates. -* Use `State.from_flat_state` to convert the flat states to nested `State`s. +* Use `State.from_flat_state` to convert the flat states to nested [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State)s. ```{code-cell} ipython3 from typing import Any @@ -158,7 +162,7 @@ def split(node, *filters): ) return graphdef, *states -# lets test it... +# Let's test it. 
foo = Foo() graphdef, params, batch_stats = split(foo, nnx.Param, nnx.BatchStat) @@ -167,12 +171,14 @@ print(f'{params = }') print(f'{batch_stats = }') ``` -One very important thing to note is that **filtering is order-dependent**. The first filter that -matches a value will keep it, therefore you should place more specific filters before more general -filters. For example if we create a `SpecialParam` type that is a subclass of `Param`, and a `Bar` -object that contains both types of parameters, if we try to split the `Param`s before the -`SpecialParam`s then all the values will be placed in the `Param` group and the `SpecialParam` group -will be empty because all `SpecialParam`s are also `Param`s: +**Note:** It's very important to know that **filtering is order-dependent**. The first `Filter` that matches a value will keep it, and therefore you should place more specific `Filter`s before more general `Filter`s. + +For example, as demonstrated below, if you: + +1) Create a `SpecialParam` type that is a subclass of [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param), and a `Bar` object (subclassing [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html)) that contains both types of parameters; and +2) Try to split the [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param)s before the `SpecialParam`s, + +then all the values will be placed in the [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) group, and the `SpecialParam` group will be empty because all `SpecialParam`s are also [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param)s: ```{code-cell} ipython3 class SpecialParam(nnx.Param): @@ -190,7 +196,7 @@ print(f'{params = }') print(f'{special_params = }') ``` -Reversing the order will make sure that the `SpecialParam` 
are captured first +Reversing the order ensures that the `SpecialParam`s are captured first: ```{code-cell} ipython3 graphdef, special_params, params = split(bar, SpecialParam, nnx.Param) # correct! diff --git a/docs_nnx/nnx_basics.ipynb b/docs_nnx/nnx_basics.ipynb index 351ae8b6..f5b74326 100644 --- a/docs_nnx/nnx_basics.ipynb +++ b/docs_nnx/nnx_basics.ipynb @@ -8,7 +8,22 @@ "\n", "Flax NNX is a new simplified API that is designed to make it easier to create, inspect, debug, and analyze neural networks in [JAX](https://jax.readthedocs.io/). It achieves this by adding first class support for Python reference semantics. This allows users to express their models using regular Python objects, which are modeled as PyGraphs (instead of pytrees), enabling reference sharing and mutability. Such API design should make PyTorch or Keras users feel at home.\n", "\n", - "To begin, install Flax with `pip` and import necessary dependencies:" + "In this guide you will learn about:\n", + "\n", + "- The Flax [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html) system: An example of creating and initializing a custom `Linear` layer.\n", + " - Stateful computation: An example of creating a Flax [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) and updating its value (such as state updates needed during the forward pass).\n", + " - Nested [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html)s: An MLP example with `Linear`, [`nnx.Dropout`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/stochastic.html#flax.nnx.Dropout), and [`nnx.BatchNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.BatchNorm) layers.\n", + " - Model surgery: An example of replacing custom `Linear` layers inside a model with custom `LoraLinear` layers.\n", + "- Flax transformations: An example of using 
[`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) for automatic state management.\n", + " - [`nnx.scan`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.scan) over layers.\n", + "- The Flax NNX Functional API: An example of a custom `StatefulLinear` layer with [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param)s with fine-grained control over the state.\n", + " - [`State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) and [`GraphDef`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.GraphDef).\n", + " - [`split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split), [`merge`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.merge), and `update`.\n", + " - Fine-grained [`State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) control: An example of using [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) type `Filter`s ([`nnx.filterlib.Filter`](https://flax.readthedocs.io/en/latest/guides/filters_guide.html)) to split into multiple [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State)s.\n", + "\n", + "## Setup\n", + "\n", + "Install Flax with `pip` and import the necessary dependencies:" ] }, { diff --git a/docs_nnx/nnx_basics.md b/docs_nnx/nnx_basics.md index fbf9be0a..61b96e2d 100644 --- a/docs_nnx/nnx_basics.md +++ b/docs_nnx/nnx_basics.md @@ -12,7 +12,22 @@ jupytext: Flax NNX is a new simplified API that is designed to make it easier to create, inspect, debug, and analyze neural networks in [JAX](https://jax.readthedocs.io/). It achieves this by adding first class support for Python reference semantics. 
This allows users to express their models using regular Python objects, which are modeled as PyGraphs (instead of pytrees), enabling reference sharing and mutability. Such API design should make PyTorch or Keras users feel at home. -To begin, install Flax with `pip` and import necessary dependencies: +In this guide you will learn about: + +- The Flax [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html) system: An example of creating and initializing a custom `Linear` layer. + - Stateful computation: An example of creating a Flax [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) and updating its value (such as state updates needed during the forward pass). + - Nested [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html)s: An MLP example with `Linear`, [`nnx.Dropout`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/stochastic.html#flax.nnx.Dropout), and [`nnx.BatchNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.BatchNorm) layers. + - Model surgery: An example of replacing custom `Linear` layers inside a model with custom `LoraLinear` layers. +- Flax transformations: An example of using [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) for automatic state management. + - [`nnx.scan`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.scan) over layers. +- The Flax NNX Functional API: An example of a custom `StatefulLinear` layer with [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param)s with fine-grained control over the state. 
+ - [`State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) and [`GraphDef`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.GraphDef). + - [`split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split), [`merge`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.merge), and `update`. + - Fine-grained [`State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) control: An example of using [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) type `Filter`s ([`nnx.filterlib.Filter`](https://flax.readthedocs.io/en/latest/guides/filters_guide.html)) to split into multiple [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State)s. + +## Setup + +Install Flax with `pip` and import the necessary dependencies: ```{code-cell} ipython3 :tags: [skip-execution]