Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Nitting and adding links #4248

Merged
merged 1 commit into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions docs_nnx/guides/flax_gspmd.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,18 @@
"\n",
"Flax relies on [JAX](https://jax.readthedocs.io) for numeric computations and scaling the computations up across multiple devices, such as GPU and TPUs. At the core of scaling up is the [JAX just-in-time compiler `jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html). Throughout this guide, you will be using Flax’s own [`nnx.jit`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) which wraps around [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html) and works more conveniently with Flax NNX `Module`s.\n",
"\n",
"JAX compilation follows the [Single Program Multi Data (SPMD)](https://jax.readthedocs.io/en/latest/glossary.html#term-SPMD) paradigm. This means you write Python code as if it runs only on one device, and `jax.jit` will automatically compile and run it on multiple devices.\n",
"JAX compilation follows the [Single Program Multi Data (SPMD)](https://jax.readthedocs.io/en/latest/glossary.html#term-SPMD) paradigm. This means you write Python code as if it runs only on one device, and [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit) will automatically compile and run it on multiple devices.\n",
"\n",
"To ensure the compilation performance, you often need to instruct JAX how your model's variables need to be sharded across devices. This is where Flax NNX's Sharding Metadata API - [`flax.nnx.spmd`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html) - comes in. It helps you annotate your model variables with this information.\n",
"\n",
"> **Note to Flax Linen users**: The [`flax.nnx.spmd`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html) API is similar to what is described in [the Linen Flax on `(p)jit` guide](https://flax.readthedocs.io/en/latest/guides/parallel_training/flax_on_pjit.html) on the model definition level. However, the top-level code in Flax NNX is simpler due to the benefits brought by Flax NNX, and some text explanations will be more updated and clearer.\n",
"\n",
"If you are new parallelization in JAX, you can learn more about its APIs for scaling up in the following tutorials:\n",
"\n",
"- [Introduction to parallel programming](https://jax.readthedocs.io/en/latest/sharded-computation.html): A 101 level tutorial covering the basics of automatic parallelization with `jax.jit`, semi-automatic parallelization with [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html) and [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html), and manual sharding with [`shard_map`](https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.shard_map.shard_map.html#jax.experimental.shard_map.shard_map).\n",
"- [Introduction to parallel programming](https://jax.readthedocs.io/en/latest/sharded-computation.html): A 101 level tutorial covering the basics of automatic parallelization with [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit), semi-automatic parallelization with [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html) and [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html), and manual sharding with [`shard_map`](https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.shard_map.shard_map.html#jax.experimental.shard_map.shard_map).\n",
"- [JAX in multi-process environments](https://jax.readthedocs.io/en/latest/multi_process.html).\n",
"- [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html): A more detailed tutorial about parallelization with `jax.jit` and `jax.lax.with_sharding_constraint`. Study it after the [101]([Introduction to parallel programming](https://jax.readthedocs.io/en/latest/sharded-computation.html).\n",
"- [Manual parallelism with `shard_map`](https://jax.readthedocs.io/en/latest/notebooks/shard_map.html): Another more in-depth doc that follows the [101]([Introduction to parallel programming](https://jax.readthedocs.io/en/latest/sharded-computation.html)."
"- [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html): A more detailed tutorial about parallelization with [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit) and [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html). Study it after the [101](https://jax.readthedocs.io/en/latest/sharded-computation.html).\n",
"- [Manual parallelism with `shard_map`](https://jax.readthedocs.io/en/latest/notebooks/shard_map.html): Another more in-depth doc that follows the [101](https://jax.readthedocs.io/en/latest/sharded-computation.html)."
]
},
{
Expand Down Expand Up @@ -525,7 +525,7 @@
"\n",
"Note that with the correct sharding for all inputs, the output will be sharded in the most natural way even without JIT compilation. In the example below even without [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) on the output `y`, it was still sharded as `('data', None)`.\n",
"\n",
"> If you are interested in why: The second matmul of `DotReluDot.__call__` has two inputs of sharding `('data', 'model')` and `('model', None)`, in which both inputs' contraction axis are `model`. So a reduce-scatter matmul happened and will naturally shard the output as `('data', None)`. Check out the [JAX shard map collective guide](https://jax.readthedocs.io/en/latest/notebooks/shard_map.html#example-2-psum-scatter-the-result) and its examples if you want to learn mathematically how it happened at low level."
"> If you are interested in why: The second matmul of `DotReluDot.__call__` has two inputs of sharding `('data', 'model')` and `('model', None)`, in which both inputs' contraction axis are `model`. So a reduce-scatter matmul happened and will naturally shard the output as `('data', None)`. Check out the [JAX shard map collective guide](https://jax.readthedocs.io/en/latest/notebooks/shard_map.html#example-2-psum-scatter-the-result) and its examples if you want to learn mathematically how it happens at a low level."
]
},
{
Expand Down Expand Up @@ -850,7 +850,7 @@
"\n",
" * For a simpler model, this can save you a few extra lines of code of converting the logical naming back to the device naming.\n",
"\n",
" * Shardings of intermediate *activation* values can only be done via `jax.lax.with_sharding_constraint` and device mesh axis. So if you want super fine-grained control over your model's sharding, directly using device mesh axis names everywhere might be less confusing.\n",
" * Shardings of intermediate *activation* values can only be done via [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) and device mesh axis. So if you want super fine-grained control over your model's sharding, directly using device mesh axis names everywhere might be less confusing.\n",
"\n",
"* **Logical naming**: Helpful if you want to experiment around and find the most optimal partition layout for your *model weights*."
]
Expand Down
12 changes: 6 additions & 6 deletions docs_nnx/guides/flax_gspmd.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,18 @@ This guide demonstrates how to scale up [Flax NNX `Module`s](https://flax-nnx.re

Flax relies on [JAX](https://jax.readthedocs.io) for numeric computations and scaling the computations up across multiple devices, such as GPU and TPUs. At the core of scaling up is the [JAX just-in-time compiler `jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html). Throughout this guide, you will be using Flax’s own [`nnx.jit`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) which wraps around [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html) and works more conveniently with Flax NNX `Module`s.

JAX compilation follows the [Single Program Multi Data (SPMD)](https://jax.readthedocs.io/en/latest/glossary.html#term-SPMD) paradigm. This means you write Python code as if it runs only on one device, and `jax.jit` will automatically compile and run it on multiple devices.
JAX compilation follows the [Single Program Multi Data (SPMD)](https://jax.readthedocs.io/en/latest/glossary.html#term-SPMD) paradigm. This means you write Python code as if it runs only on one device, and [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit) will automatically compile and run it on multiple devices.

To ensure the compilation performance, you often need to instruct JAX how your model's variables need to be sharded across devices. This is where Flax NNX's Sharding Metadata API - [`flax.nnx.spmd`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html) - comes in. It helps you annotate your model variables with this information.

> **Note to Flax Linen users**: The [`flax.nnx.spmd`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html) API is similar to what is described in [the Linen Flax on `(p)jit` guide](https://flax.readthedocs.io/en/latest/guides/parallel_training/flax_on_pjit.html) on the model definition level. However, the top-level code in Flax NNX is simpler due to the benefits brought by Flax NNX, and some text explanations will be more updated and clearer.

If you are new parallelization in JAX, you can learn more about its APIs for scaling up in the following tutorials:

- [Introduction to parallel programming](https://jax.readthedocs.io/en/latest/sharded-computation.html): A 101 level tutorial covering the basics of automatic parallelization with `jax.jit`, semi-automatic parallelization with [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html) and [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html), and manual sharding with [`shard_map`](https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.shard_map.shard_map.html#jax.experimental.shard_map.shard_map).
- [Introduction to parallel programming](https://jax.readthedocs.io/en/latest/sharded-computation.html): A 101 level tutorial covering the basics of automatic parallelization with [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit), semi-automatic parallelization with [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html) and [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html), and manual sharding with [`shard_map`](https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.shard_map.shard_map.html#jax.experimental.shard_map.shard_map).
- [JAX in multi-process environments](https://jax.readthedocs.io/en/latest/multi_process.html).
- [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html): A more detailed tutorial about parallelization with `jax.jit` and `jax.lax.with_sharding_constraint`. Study it after the [101]([Introduction to parallel programming](https://jax.readthedocs.io/en/latest/sharded-computation.html).
- [Manual parallelism with `shard_map`](https://jax.readthedocs.io/en/latest/notebooks/shard_map.html): Another more in-depth doc that follows the [101]([Introduction to parallel programming](https://jax.readthedocs.io/en/latest/sharded-computation.html).
- [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html): A more detailed tutorial about parallelization with [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit) and [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html). Study it after the [101](https://jax.readthedocs.io/en/latest/sharded-computation.html).
- [Manual parallelism with `shard_map`](https://jax.readthedocs.io/en/latest/notebooks/shard_map.html): Another more in-depth doc that follows the [101](https://jax.readthedocs.io/en/latest/sharded-computation.html).

+++

Expand Down Expand Up @@ -243,7 +243,7 @@ Now, from initialization or from checkpoint, you have a sharded model. To carry

Note that with the correct sharding for all inputs, the output will be sharded in the most natural way even without JIT compilation. In the example below even without [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) on the output `y`, it was still sharded as `('data', None)`.

> If you are interested in why: The second matmul of `DotReluDot.__call__` has two inputs of sharding `('data', 'model')` and `('model', None)`, in which both inputs' contraction axis are `model`. So a reduce-scatter matmul happened and will naturally shard the output as `('data', None)`. Check out the [JAX shard map collective guide](https://jax.readthedocs.io/en/latest/notebooks/shard_map.html#example-2-psum-scatter-the-result) and its examples if you want to learn mathematically how it happened at low level.
> If you are interested in why: The second matmul of `DotReluDot.__call__` has two inputs of sharding `('data', 'model')` and `('model', None)`, in which both inputs' contraction axis are `model`. So a reduce-scatter matmul happened and will naturally shard the output as `('data', None)`. Check out the [JAX shard map collective guide](https://jax.readthedocs.io/en/latest/notebooks/shard_map.html#example-2-psum-scatter-the-result) and its examples if you want to learn mathematically how it happens at a low level.

```{code-cell} ipython3
# In data parallelism, the first dimension (batch) will be sharded on `data` axis.
Expand Down Expand Up @@ -384,6 +384,6 @@ Choosing when to use a device or logical axis depends on how much you want to co

* For a simpler model, this can save you a few extra lines of code of converting the logical naming back to the device naming.

* Shardings of intermediate *activation* values can only be done via `jax.lax.with_sharding_constraint` and device mesh axis. So if you want super fine-grained control over your model's sharding, directly using device mesh axis names everywhere might be less confusing.
* Shardings of intermediate *activation* values can only be done via [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) and device mesh axis. So if you want super fine-grained control over your model's sharding, directly using device mesh axis names everywhere might be less confusing.

* **Logical naming**: Helpful if you want to experiment around and find the most optimal partition layout for your *model weights*.
Loading
Loading