From 9dd1e4eb53e3a10d8d15c4700a60e16b374f483b Mon Sep 17 00:00:00 2001 From: 8bitmp3 <19637339+8bitmp3@users.noreply.github.com> Date: Thu, 12 Dec 2024 22:41:56 +0000 Subject: [PATCH] Upgrade NNX Filters guides --- docs_nnx/guides/filters_guide.ipynb | 13 ++++++------- docs_nnx/guides/filters_guide.md | 13 ++++++------- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/docs_nnx/guides/filters_guide.ipynb b/docs_nnx/guides/filters_guide.ipynb index 66f6fbcf..7fbec345 100644 --- a/docs_nnx/guides/filters_guide.ipynb +++ b/docs_nnx/guides/filters_guide.ipynb @@ -5,16 +5,15 @@ "id": "95b08e64", "metadata": {}, "source": [ - "# Using `Filter`s\n", + "# Using Filters, grouping NNX variables \n", "\n", "Flax NNX uses [`Filter`s](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html) extensively as a way to create [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) groups in APIs, such as [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split), [`nnx.state()`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.state), and many of the [Flax NNX transformations (transforms)](https://flax.readthedocs.io/en/latest/guides/jax_and_nnx_transforms.html).\n", "\n", - "In this guide you will learn:\n", + "In this guide you will learn how to:\n", "\n", - "* What is a `Filter`?\n", - "* Why are types, such as [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) or [`nnx.BatchStat`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.BatchStat), treated as `Filter`s?\n", - "* What is the `Filter` domain specific language (DSL)?\n", - "* How is [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) grouped / filtered?\n", + "* Use [`Filter`s](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html) to group Flax NNX variables and states into subgroups;\n", + "* Understand relationships between types, such as [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) or [`nnx.BatchStat`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.BatchStat), and [`Filter`s](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html);\n", + "* Express your `Filter`s flexibly with [`nnx.filterlib.Filter`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html) language.\n", "\n", "In the following example [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) and [`nnx.BatchStat`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.BatchStat) are used as `Filter`s to split the model into two groups: one with the parameters and the other with the batch statistics:" ] @@ -156,7 +155,7 @@ "source": [ "## The `Filter` DSL\n", "\n", - "To help users avoid having to create functions mentioned in the previous section, Flax NNX exposes a small domain specific language ([DSL](https://en.wikipedia.org/wiki/Domain-specific_language)), formalized as the [`nnx.filterlib.Filter`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html) type. The `Filter` DSL allows users to pass types, booleans, ellipsis, tuples/lists, etc, and converts them to the appropriate predicate internally.\n", + "Flax NNX exposes a small domain specific language ([DSL](https://en.wikipedia.org/wiki/Domain-specific_language)), formalized as the [`nnx.filterlib.Filter`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html) type. This means users don't have to create functions like in the previous section.\n", "\n", "Here is a list of all the callable `Filter`s included in Flax NNX, and their corresponding DSL literals (when available):\n", "\n", diff --git a/docs_nnx/guides/filters_guide.md b/docs_nnx/guides/filters_guide.md index 68e23bb5..27806566 100644 --- a/docs_nnx/guides/filters_guide.md +++ b/docs_nnx/guides/filters_guide.md @@ -8,16 +8,15 @@ jupytext: jupytext_version: 1.13.8 --- -# Using `Filter`s +# Using Filters, grouping NNX variables Flax NNX uses [`Filter`s](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html) extensively as a way to create [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) groups in APIs, such as [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split), [`nnx.state()`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.state), and many of the [Flax NNX transformations (transforms)](https://flax.readthedocs.io/en/latest/guides/jax_and_nnx_transforms.html). -In this guide you will learn: +In this guide you will learn how to: -* What is a `Filter`? -* Why are types, such as [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) or [`nnx.BatchStat`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.BatchStat), treated as `Filter`s? -* What is the `Filter` domain specific language (DSL)? -* How is [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) grouped / filtered? +* Use [`Filter`s](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html) to group Flax NNX variables and states into subgroups; +* Understand relationships between types, such as [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) or [`nnx.BatchStat`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.BatchStat), and [`Filter`s](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html); +* Express your `Filter`s flexibly with [`nnx.filterlib.Filter`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html) language. In the following example [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) and [`nnx.BatchStat`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.BatchStat) are used as `Filter`s to split the model into two groups: one with the parameters and the other with the batch statistics: @@ -82,7 +81,7 @@ print(f'{is_param((), nnx.VariableState(type=nnx.Param, value=0)) = }') ## The `Filter` DSL -To help users avoid having to create functions mentioned in the previous section, Flax NNX exposes a small domain specific language ([DSL](https://en.wikipedia.org/wiki/Domain-specific_language)), formalized as the [`nnx.filterlib.Filter`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html) type. The `Filter` DSL allows users to pass types, booleans, ellipsis, tuples/lists, etc, and converts them to the appropriate predicate internally. +Flax NNX exposes a small domain specific language ([DSL](https://en.wikipedia.org/wiki/Domain-specific_language)), formalized as the [`nnx.filterlib.Filter`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html) type. This means users don't have to create functions like in the previous section. Here is a list of all the callable `Filter`s included in Flax NNX, and their corresponding DSL literals (when available):