Skip to content

Commit

Permalink
Upgrade NNX Filters guides
Browse files Browse the repository at this point in the history
  • Loading branch information
8bitmp3 committed Dec 12, 2024
1 parent 0bd3e83 commit 9dd1e4e
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 14 deletions.
13 changes: 6 additions & 7 deletions docs_nnx/guides/filters_guide.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,15 @@
"id": "95b08e64",
"metadata": {},
"source": [
"# Using `Filter`s\n",
"# Using Filters, grouping NNX variables \n",
"\n",
"Flax NNX uses [`Filter`s](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html) extensively as a way to create [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) groups in APIs, such as [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split), [`nnx.state()`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.state), and many of the [Flax NNX transformations (transforms)](https://flax.readthedocs.io/en/latest/guides/jax_and_nnx_transforms.html).\n",
"\n",
"In this guide you will learn:\n",
"In this guide you will learn how to:\n",
"\n",
"* What is a `Filter`?\n",
"* Why are types, such as [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) or [`nnx.BatchStat`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.BatchStat), treated as `Filter`s?\n",
"* What is the `Filter` domain specific language (DSL)?\n",
"* How is [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) grouped / filtered?\n",
"* Use [`Filter`s](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html) to group Flax NNX variables and states into subgroups;\n",
"* Understand relationships between types, such as [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) or [`nnx.BatchStat`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.BatchStat), and [`Filter`s](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html);\n",
"* Express your `Filter`s flexibly with [`nnx.filterlib.Filter`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html) language.\n",
"\n",
"In the following example [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) and [`nnx.BatchStat`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.BatchStat) are used as `Filter`s to split the model into two groups: one with the parameters and the other with the batch statistics:"
]
Expand Down Expand Up @@ -156,7 +155,7 @@
"source": [
"## The `Filter` DSL\n",
"\n",
"To help users avoid having to create functions mentioned in the previous section, Flax NNX exposes a small domain specific language ([DSL](https://en.wikipedia.org/wiki/Domain-specific_language)), formalized as the [`nnx.filterlib.Filter`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html) type. The `Filter` DSL allows users to pass types, booleans, ellipsis, tuples/lists, etc, and converts them to the appropriate predicate internally.\n",
"Flax NNX exposes a small domain specific language ([DSL](https://en.wikipedia.org/wiki/Domain-specific_language)), formalized as the [`nnx.filterlib.Filter`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html) type. This means users don't have to create functions like in the previous section.\n",
"\n",
"Here is a list of all the callable `Filter`s included in Flax NNX, and their corresponding DSL literals (when available):\n",
"\n",
Expand Down
13 changes: 6 additions & 7 deletions docs_nnx/guides/filters_guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,15 @@ jupytext:
jupytext_version: 1.13.8
---

# Using `Filter`s
# Using Filters, grouping NNX variables

Flax NNX uses [`Filter`s](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html) extensively as a way to create [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) groups in APIs, such as [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split), [`nnx.state()`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.state), and many of the [Flax NNX transformations (transforms)](https://flax.readthedocs.io/en/latest/guides/jax_and_nnx_transforms.html).

In this guide you will learn:
In this guide you will learn how to:

* What is a `Filter`?
* Why are types, such as [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) or [`nnx.BatchStat`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.BatchStat), treated as `Filter`s?
* What is the `Filter` domain specific language (DSL)?
* How is [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) grouped / filtered?
* Use [`Filter`s](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html) to group Flax NNX variables and states into subgroups;
* Understand relationships between types, such as [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) or [`nnx.BatchStat`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.BatchStat), and [`Filter`s](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html);
* Express your `Filter`s flexibly with [`nnx.filterlib.Filter`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html) language.

In the following example [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) and [`nnx.BatchStat`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.BatchStat) are used as `Filter`s to split the model into two groups: one with the parameters and the other with the batch statistics:

Expand Down Expand Up @@ -82,7 +81,7 @@ print(f'{is_param((), nnx.VariableState(type=nnx.Param, value=0)) = }')

## The `Filter` DSL

To help users avoid having to create functions mentioned in the previous section, Flax NNX exposes a small domain specific language ([DSL](https://en.wikipedia.org/wiki/Domain-specific_language)), formalized as the [`nnx.filterlib.Filter`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html) type. The `Filter` DSL allows users to pass types, booleans, ellipsis, tuples/lists, etc, and converts them to the appropriate predicate internally.
Flax NNX exposes a small domain specific language ([DSL](https://en.wikipedia.org/wiki/Domain-specific_language)), formalized as the [`nnx.filterlib.Filter`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html) type. This means users don't have to create functions like in the previous section.

Here is a list of all the callable `Filter`s included in Flax NNX, and their corresponding DSL literals (when available):

Expand Down

0 comments on commit 9dd1e4e

Please sign in to comment.