diff --git a/docs_nnx/guides/flax_gspmd.ipynb b/docs_nnx/guides/flax_gspmd.ipynb index 5787ab7501..44dcfd5135 100644 --- a/docs_nnx/guides/flax_gspmd.ipynb +++ b/docs_nnx/guides/flax_gspmd.ipynb @@ -6,7 +6,7 @@ "source": [ "# Scale up on multiple devices\n", "\n", - "This guide demonstrates how to scale up [Flax NNX `Module`s](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module) on multiple devices and hosts, such as GPUs, Google TPUs, and CPUs, using [JAX just-in-time compilation machinery (`jax.jit`)](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html)." + "This guide demonstrates how to scale up [Flax NNX `Module`s](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module) on [multiple devices and hosts](https://jax.readthedocs.io/en/latest/multi_process.html) - such as GPUs, Google TPUs, and CPUs - using the [JAX just-in-time compilation machinery (`jax.jit`)](https://jax.readthedocs.io/en/latest/jit-compilation.html) and [`flax.nnx.spmd`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html)." ] }, { @@ -16,9 +16,11 @@ "source": [ "## Overview\n", "\n", - "Flax relies on [JAX](https://jax.readthedocs.io) for numeric computations and scaling the computations up across multiple devices, such as GPU and TPUs. At the core of scaling up is the [JAX just-in-time compiler `jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html). Throughout this guide, you will be using Flax’s own [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) which wraps around [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html) and works more conveniently with Flax NNX `Module`s.\n", + "Flax relies on [JAX](https://jax.readthedocs.io) for numeric computations and scaling the computations up across multiple devices, such as GPUs and Google TPUs. 
At the core of scaling up is the [JAX just-in-time compiler `jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html). Throughout this guide, you will be using Flax’s own [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) transform, which wraps around [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html) and works more conveniently with Flax NNX `Module`s.\n", "\n", - "JAX compilation follows the [Single Program Multi Data (SPMD)](https://jax.readthedocs.io/en/latest/glossary.html#term-SPMD) paradigm. This means you write Python code as if it runs only on one device, and [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit) will automatically compile and run it on multiple devices.\n", + "> **Note:** To learn more about Flax’s transformations, such as `nnx.jit` and `nnx.vmap`, go to [Why Flax NNX? - Transforms](https://flax.readthedocs.io/en/latest/why.html#transforms), [Transformations](https://flax.readthedocs.io/en/latest/guides/transforms.html), and [Flax NNX vs JAX Transformations](https://flax.readthedocs.io/en/latest/guides/jax_and_nnx_transforms.html).\n", + "\n", + "JAX compilation follows the [Single Program Multi Data (SPMD)](https://jax.readthedocs.io/en/latest/glossary.html#term-SPMD) paradigm. This means you write Python code as if it runs only on one device, and [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit) will [automatically compile](https://jax.readthedocs.io/en/latest/jit-compilation.html#jit-compilation) and [run it](https://jax.readthedocs.io/en/latest/sharded-computation.html) on [multiple devices](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html).\n", "\n", "To ensure the compilation performance, you often need to instruct JAX how your model's variables need to be sharded across devices. 
This is where Flax NNX's Sharding Metadata API - [`flax.nnx.spmd`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html) - comes in. It helps you annotate your model variables with this information.\n", "\n", @@ -131,9 +133,9 @@ "source": [ "## Define a model with specified sharding\n", "\n", - "Next, create an example layer called `DotReluDot` that subclasses Flax [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module). This layer carries out two dot product multiplication upon the input `x`, and uses the `jax.nn.relu` (ReLU) activation function in-between.\n", - "\n", - "To annotate a model variable with their ideal sharding, you can use [`flax.nnx.with_partitioning`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.with_partitioning) to wrap over its initializer function. Essentially, this calls [`flax.nnx.with_metadata`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.with_metadata) which adds a `.sharding` attribute field to the corresponding [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable).\n", + "Next, create an example layer called `DotReluDot` that subclasses Flax [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module).\n", + "- This layer carries out two dot product multiplications upon the input `x`, and uses the `jax.nn.relu` (ReLU) activation function in-between.\n", + "- To annotate a model variable with their ideal sharding, you can use [`flax.nnx.with_partitioning`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.with_partitioning) to wrap over its initializer function. 
Essentially, this calls [`flax.nnx.with_metadata`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.with_metadata) which adds a `.sharding` attribute field to the corresponding [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable).\n", "\n", "> **Note:** This annotation will be [preserved and adjusted accordingly across lifted transformations in Flax NNX](https://flax.readthedocs.io/en/latest/guides/transforms.html#axes-metadata). This means if you use sharding annotations along with any transform that modifies axes (like [`nnx.vmap`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html), [`nnx.scan`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html)), you need to provide sharding of that additional axis via the `transform_metadata` arg. Check out the [Flax NNX transformations (transforms) guide](https://flax.readthedocs.io/en/latest/guides/transforms.html) to learn more." ] @@ -197,7 +199,7 @@ "source": [ "## Initialize a sharded model\n", "\n", - "Now, you have annotations attached to the [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable), but the actual weights haven't been sharded yet. If you just go ahead and create this model, all JAX arrays are still stuck in device `0`. In practice, you'd want to avoid this, because a large model will \"OOM\" (will cause the device to run out of memory) in this situation, while all the other devices are not utilized." + "Now, you have annotations attached to the Flax [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable), but the actual weights have not been sharded yet. If you just go ahead and create this model, all [`jax.Array`s](https://jax.readthedocs.io/en/latest/key-concepts.html#jax-arrays-jax-array) are still stuck in device `0`. 
In practice, you'd want to avoid this, because a large model will \"OOM\" (will cause the device to run out of memory) in this situation, while all the other devices are not utilized." ] }, { @@ -219,7 +221,7 @@ "source": [ "unsharded_model = DotReluDot(1024, rngs=nnx.Rngs(0))\n", "\n", - "# You have annotations sticked there, yay!\n", + "# You have annotations stuck there, yay!\n", "print(unsharded_model.dot1.kernel.sharding) # (None, 'model')\n", "print(unsharded_model.w2.sharding) # ('model', None)\n", "\n", @@ -232,7 +234,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Here, you should leverage JAX's compilation mechanism, via [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit), to create the sharded model. The key is to initialize a model and assign shardings upon the model state within a `jit`ted function:\n", + "Here, you should leverage JAX's compilation mechanism via Flax’s [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) to create the sharded model. The key is to initialize a model and assign shardings upon the model state within a `jit`ted function:\n", "\n", "1. Use [`nnx.get_partition_spec`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_partition_spec) to strip out the `.sharding` annotations attached upon model variables.\n", "\n", @@ -240,7 +242,7 @@ "\n", "1. Throw away the unsharded state and return the model based upon the sharded state.\n", "\n", - "1. Compile the whole function with `nnx.jit`, which allows the output to be a stateful NNX module.\n", + "1. Compile the whole function with `nnx.jit`, which allows the output to be a stateful Flax NNX `Module`.\n", "\n", "1. 
Run it under a device mesh context so that JAX knows which devices to shard it to.\n", "\n", @@ -291,7 +293,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "You can view the sharding of any 1-D or 2-D array with `jax.debug.visualize_array_sharding`:" + "You can view the sharding of any 1-D or 2-D array with [`jax.debug.visualize_array_sharding`](https://jax.readthedocs.io/en/latest/_autosummary/jax.debug.visualize_array_sharding.html):" ] }, { @@ -393,7 +395,7 @@ "\n", "> **Note:** Both [Introduction to parallel programming](https://jax.readthedocs.io/en/latest/sharded-computation.html) and [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) in the JAX documentation cover automatic parallelization with [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html), and semi-automatic parallelization with `jax.jit` and [`jax.lax.with_sharding_constraint](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) in greater detail.\n", "\n", - "You may have noticed you also used [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) once in the model definition too, to constraint the sharding of an intermediate value. This is just to show that you can always use it orthogonally with the Flax NNX API, if you want to explicitly shard values that are not model variables.\n", + "You may have noticed you also used [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) once in the model definition to constrain the sharding of an intermediate value. This is just to show that you can always use it orthogonally with the Flax NNX API if you want to explicitly shard values that are not model variables.\n", "\n", "This brings a question: Why use the Flax NNX Annotation API then? 
Why not just add JAX sharding constraints inside the model definition? The most important reason is that you still need the explicit annotations to load a sharded model from an on-disk checkpoint. This is described in the next section." ] @@ -402,13 +404,13 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Load sharded model from a checkpoint\n", + "## Load a sharded model from a checkpoint\n", "\n", - "Now you can initialize a sharded model without OOM, but what about loading it from a checkpoint on disk? JAX checkpointing libraries, such as [Orbax](https://orbax.readthedocs.io/en/latest/), usually support loading it sharded if a sharding pytree is given.\n", + "Now you have learned how to initialize a sharded model without OOM, but what about loading it from a checkpoint on disk? JAX checkpointing libraries, such as [Orbax](https://orbax.readthedocs.io/en/latest/), usually support loading a model in sharded form if a sharding pytree is provided.\n", "\n", - "You can generate such as sharding pytree with [`nnx.get_named_sharding`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_named_sharding). To avoid any real memory allocation, use the [`nnx.eval_shape`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.eval_shape) transform to generate a model of abstract JAX arrays, and only use its `.sharding` annotations to obtain the sharding tree.\n", + "You can generate such a sharding pytree with Flax’s [`nnx.get_named_sharding`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_named_sharding). To avoid any real memory allocation, use the [`nnx.eval_shape`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.eval_shape) transform to generate a model of abstract JAX arrays, and only use its `.sharding` annotations to obtain the sharding tree.\n", "\n", - "Below is an example demonstration using Orbax's `StandardCheckpointer` API. 
Check out [Orbax website](https://orbax.readthedocs.io/en/latest/) to learn their latest updates and recommended APIs." + "Below is an example that demonstrates using Orbax's `StandardCheckpointer` API. (Go to the [Orbax documentation site](https://orbax.readthedocs.io/en/latest/) to learn about their latest recommended APIs.)" ] }, { @@ -509,11 +511,13 @@ "source": [ "## Compile the training loop\n", "\n", - "Now, from initialization or from checkpoint, you have a sharded model. To carry out the compiled, scaled up training, you need to shard the inputs as well. In this data parallelism example, the training data has its batch dimension sharded across the `data` device axis, so you should put your data in sharding `('data', None)`. You can use [`jax.device_put`](https://jax.readthedocs.io/en/latest/_autosummary/jax.device_put.html#jax.device_put) for this.\n", + "Now, after either initialization or loading the checkpoint, you have a sharded model. To carry out the compiled, scaled-up training, you need to shard the inputs as well.\n", "\n", - "Note that with the correct sharding for all inputs, the output will be sharded in the most natural way even without JIT compilation. In the example below even without [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) on the output `y`, it was still sharded as `('data', None)`.\n", + "- In the data parallelism example, the training data has its batch dimension sharded across the `data` device axis, so you should put your data in sharding `('data', None)`. You can use [`jax.device_put`](https://jax.readthedocs.io/en/latest/_autosummary/jax.device_put.html#jax.device_put) for this.\n", + "- Note that with the correct sharding for all inputs, the output will be sharded in the most natural way even without `jit` compilation. 
\n", + "- In the example below, even without [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) on the output `y`, it was still sharded as `('data', None)`.\n", "\n", - "> If you are interested in why: The second matmul of `DotReluDot.__call__` has two inputs of sharding `('data', 'model')` and `('model', None)`, in which both inputs' contraction axis are `model`. So a reduce-scatter matmul happened and will naturally shard the output as `('data', None)`. Check out the [JAX shard map collective guide](https://jax.readthedocs.io/en/latest/notebooks/shard_map.html#example-2-psum-scatter-the-result) and its examples if you want to learn mathematically how it happens at a low level." + "> If you are interested in why: The second matmul of `DotReluDot.__call__` has two inputs of sharding `('data', 'model')` and `('model', None)`, in which both inputs' contraction axes are `model`. So a reduce-scatter matmul happened and will naturally shard the output as `('data', None)`. Check out the [JAX shard map collective guide](https://jax.readthedocs.io/en/latest/notebooks/shard_map.html#example-2-psum-scatter-the-result) and its examples if you want to learn mathematically how it happens at a low level." ] }, { @@ -563,23 +567,23 @@ } ], "source": [ - "# In data parallelism, the first dimension (batch) will be sharded on `data` axis.\n", + "# In data parallelism, the first dimension (batch) will be sharded on the `data` axis.\n", "data_sharding = NamedSharding(mesh, PartitionSpec('data', None))\n", "input = jax.device_put(jnp.ones((8, 1024)), data_sharding)\n", "\n", "with mesh:\n", "  output = sharded_model(input)\n", "print(output.shape)\n", - "jax.debug.visualize_array_sharding(output) # Also sharded as ('data', None)" + "jax.debug.visualize_array_sharding(output) # Also sharded as `('data', None)`." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Now the rest of the training loop is pretty conventional - almost the same as the example in [NNX Basics](https://flax.readthedocs.io/en/latest/nnx_basics.html#transforms), except that the inputs and labels are also explicitly sharded.\n", - "\n", - "[`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) will adjust and automatically choose the best layout based on how its inputs are already sharded, so try out different shardings for your own model and inputs." + "Now the rest of the training loop is pretty conventional - it is almost the same as the example in [Flax NNX Basics](https://flax.readthedocs.io/en/latest/nnx_basics.html#transforms):\n", + "- Except that the inputs and labels are also explicitly sharded.\n", + "- [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) will adjust and automatically choose the best layout based on how its inputs are already sharded, so try out different shardings for your own model and inputs." ] }, { @@ -628,7 +632,7 @@ "source": [ "## Profiling\n", "\n", - "If you are running on a TPU pod or a pod slice, you can create a custom `block_all()` utility function, as defined below, to measure the performance:" + "If you are using a Google TPU pod or a pod slice, you can create a custom `block_all()` utility function, as defined below, to measure the performance:" ] }, { @@ -664,7 +668,7 @@ "\n", "JAX's [automatic](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) [SPMD]((https://jax.readthedocs.io/en/latest/glossary.html#term-SPMD)) encourages users to explore different sharding layouts to find the optimal one. 
To this end, in Flax you have the option to annotate with more descriptive axis names (not just device mesh axis names like `'data'` and `'model'`), as long as you provide a mapping from your alias to the device mesh axes.\n", "\n", - "You can provide the mapping along with the annotation as another metadata of the corresponding [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable), or overwrite it at top-level. Check out the `LogicalDotReluDot` example below." + "You can provide the mapping along with the annotation as another metadata of the corresponding [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable), or overwrite it at top-level. Check out the `LogicalDotReluDot()` example below." ] }, { @@ -709,7 +713,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "If you didn't provide all `sharding_rule` annotations in model definition, you can write a few lines to add it to [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) of the model, before the call of [`nnx.get_partition_spec`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_partition_spec) or [`nnx.get_named_sharding`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_named_sharding)." + "If you didn't provide all `sharding_rule` annotations in the model definition, you can write a few lines to add it to Flax’s [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) of the model, before the call of [`nnx.get_partition_spec`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_partition_spec) or [`nnx.get_named_sharding`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_named_sharding)." 
] }, { @@ -830,9 +834,9 @@ "\n", " * For a simpler model, this can save you a few extra lines of code of converting the logical naming back to the device naming.\n", "\n", - " * Shardings of intermediate *activation* values can only be done via [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) and device mesh axis. So if you want super fine-grained control over your model's sharding, directly using device mesh axis names everywhere might be less confusing.\n", + " * Shardings of intermediate *activation* values can only be done via [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) and device mesh axis. Therefore, if you want super fine-grained control over your model's sharding, directly using device mesh axis names everywhere might be less confusing.\n", "\n", - "* **Logical naming**: Helpful if you want to experiment around and find the most optimal partition layout for your *model weights*." + "* **Logical naming**: This is helpful if you want to experiment around and find the most optimal partition layout for your *model weights*." ] } ], diff --git a/docs_nnx/guides/flax_gspmd.md b/docs_nnx/guides/flax_gspmd.md index 424520f24b..50441f941a 100644 --- a/docs_nnx/guides/flax_gspmd.md +++ b/docs_nnx/guides/flax_gspmd.md @@ -10,15 +10,17 @@ jupytext: # Scale up on multiple devices -This guide demonstrates how to scale up [Flax NNX `Module`s](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module) on multiple devices and hosts, such as GPUs, Google TPUs, and CPUs, using [JAX just-in-time compilation machinery (`jax.jit`)](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html). 
+This guide demonstrates how to scale up [Flax NNX `Module`s](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module) on [multiple devices and hosts](https://jax.readthedocs.io/en/latest/multi_process.html) - such as GPUs, Google TPUs, and CPUs - using the [JAX just-in-time compilation machinery (`jax.jit`)](https://jax.readthedocs.io/en/latest/jit-compilation.html) and [`flax.nnx.spmd`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html). +++ ## Overview -Flax relies on [JAX](https://jax.readthedocs.io) for numeric computations and scaling the computations up across multiple devices, such as GPU and TPUs. At the core of scaling up is the [JAX just-in-time compiler `jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html). Throughout this guide, you will be using Flax’s own [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) which wraps around [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html) and works more conveniently with Flax NNX `Module`s. +Flax relies on [JAX](https://jax.readthedocs.io) for numeric computations and scaling the computations up across multiple devices, such as GPUs and Google TPUs. At the core of scaling up is the [JAX just-in-time compiler `jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html). Throughout this guide, you will be using Flax’s own [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) transform, which wraps around [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html) and works more conveniently with Flax NNX `Module`s. -JAX compilation follows the [Single Program Multi Data (SPMD)](https://jax.readthedocs.io/en/latest/glossary.html#term-SPMD) paradigm. 
This means you write Python code as if it runs only on one device, and [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit) will automatically compile and run it on multiple devices. +> **Note:** To learn more about Flax’s transformations, such as `nnx.jit` and `nnx.vmap`, go to [Why Flax NNX? - Transforms](https://flax.readthedocs.io/en/latest/why.html#transforms), [Transformations](https://flax.readthedocs.io/en/latest/guides/transforms.html), and [Flax NNX vs JAX Transformations](https://flax.readthedocs.io/en/latest/guides/jax_and_nnx_transforms.html). + +JAX compilation follows the [Single Program Multi Data (SPMD)](https://jax.readthedocs.io/en/latest/glossary.html#term-SPMD) paradigm. This means you write Python code as if it runs only on one device, and [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit) will [automatically compile](https://jax.readthedocs.io/en/latest/jit-compilation.html#jit-compilation) and [run it](https://jax.readthedocs.io/en/latest/sharded-computation.html) on [multiple devices](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html). To ensure the compilation performance, you often need to instruct JAX how your model's variables need to be sharded across devices. This is where Flax NNX's Sharding Metadata API - [`flax.nnx.spmd`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html) - comes in. It helps you annotate your model variables with this information. @@ -79,9 +81,9 @@ print(mesh) ## Define a model with specified sharding -Next, create an example layer called `DotReluDot` that subclasses Flax [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module). This layer carries out two dot product multiplication upon the input `x`, and uses the `jax.nn.relu` (ReLU) activation function in-between. 
- -To annotate a model variable with their ideal sharding, you can use [`flax.nnx.with_partitioning`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.with_partitioning) to wrap over its initializer function. Essentially, this calls [`flax.nnx.with_metadata`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.with_metadata) which adds a `.sharding` attribute field to the corresponding [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable). +Next, create an example layer called `DotReluDot` that subclasses Flax [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module). +- This layer carries out two dot product multiplications upon the input `x`, and uses the `jax.nn.relu` (ReLU) activation function in-between. +- To annotate a model variable with their ideal sharding, you can use [`flax.nnx.with_partitioning`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.with_partitioning) to wrap over its initializer function. Essentially, this calls [`flax.nnx.with_metadata`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.with_metadata) which adds a `.sharding` attribute field to the corresponding [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable). > **Note:** This annotation will be [preserved and adjusted accordingly across lifted transformations in Flax NNX](https://flax.readthedocs.io/en/latest/guides/transforms.html#axes-metadata). 
This means if you use sharding annotations along with any transform that modifies axes (like [`nnx.vmap`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html), [`nnx.scan`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html)), you need to provide sharding of that additional axis via the `transform_metadata` arg. Check out the [Flax NNX transformations (transforms) guide](https://flax.readthedocs.io/en/latest/guides/transforms.html) to learn more. @@ -130,12 +132,12 @@ JAX's [Distributed arrays and automatic parallelization](https://jax.readthedocs ## Initialize a sharded model -Now, you have annotations attached to the [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable), but the actual weights haven't been sharded yet. If you just go ahead and create this model, all JAX arrays are still stuck in device `0`. In practice, you'd want to avoid this, because a large model will "OOM" (will cause the device to run out of memory) in this situation, while all the other devices are not utilized. +Now, you have annotations attached to the Flax [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable), but the actual weights have not been sharded yet. If you just go ahead and create this model, all [`jax.Array`s](https://jax.readthedocs.io/en/latest/key-concepts.html#jax-arrays-jax-array) are still stuck in device `0`. In practice, you'd want to avoid this, because a large model will "OOM" (will cause the device to run out of memory) in this situation, while all the other devices are not utilized. ```{code-cell} ipython3 unsharded_model = DotReluDot(1024, rngs=nnx.Rngs(0)) -# You have annotations sticked there, yay! +# You have annotations stuck there, yay! 
print(unsharded_model.dot1.kernel.sharding) # (None, 'model') print(unsharded_model.w2.sharding) # ('model', None) @@ -144,7 +146,7 @@ print(unsharded_model.dot1.kernel.value.sharding) # SingleDeviceSharding print(unsharded_model.w2.value.sharding) # SingleDeviceSharding ``` -Here, you should leverage JAX's compilation mechanism, via [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit), to create the sharded model. The key is to initialize a model and assign shardings upon the model state within a `jit`ted function: +Here, you should leverage JAX's compilation mechanism via Flax’s [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) to create the sharded model. The key is to initialize a model and assign shardings upon the model state within a `jit`ted function: 1. Use [`nnx.get_partition_spec`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_partition_spec) to strip out the `.sharding` annotations attached upon model variables. @@ -152,7 +154,7 @@ Here, you should leverage JAX's compilation mechanism, via [`nnx.jit`](https://f 1. Throw away the unsharded state and return the model based upon the sharded state. -1. Compile the whole function with `nnx.jit`, which allows the output to be a stateful NNX module. +1. Compile the whole function with `nnx.jit`, which allows the output to be a stateful Flax NNX `Module`. 1. Run it under a device mesh context so that JAX knows which devices to shard it to. 
@@ -184,7 +186,7 @@ assert sharded_model.w2.value.sharding.is_equivalent_to( ) ``` -You can view the sharding of any 1-D or 2-D array with `jax.debug.visualize_array_sharding`: +You can view the sharding of any 1-D or 2-D array with [`jax.debug.visualize_array_sharding`](https://jax.readthedocs.io/en/latest/_autosummary/jax.debug.visualize_array_sharding.html): ```{code-cell} ipython3 print("sharded_model.dot1.kernel (None, 'model') :") @@ -199,19 +201,19 @@ The key to shard a JAX array is to call [`jax.lax.with_sharding_constraint`](htt > **Note:** Both [Introduction to parallel programming](https://jax.readthedocs.io/en/latest/sharded-computation.html) and [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) in the JAX documentation cover automatic parallelization with [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html), and semi-automatic parallelization with `jax.jit` and [`jax.lax.with_sharding_constraint](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) in greater detail. -You may have noticed you also used [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) once in the model definition too, to constraint the sharding of an intermediate value. This is just to show that you can always use it orthogonally with the Flax NNX API, if you want to explicitly shard values that are not model variables. +You may have noticed you also used [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) once in the model definition to constrain the sharding of an intermediate value. This is just to show that you can always use it orthogonally with the Flax NNX API if you want to explicitly shard values that are not model variables. 
This brings a question: Why use the Flax NNX Annotation API then? Why not just add JAX sharding constraints inside the model definition? The most important reason is that you still need the explicit annotations to load a sharded model from an on-disk checkpoint. This is described in the next section. +++ -## Load sharded model from a checkpoint +## Load a sharded model from a checkpoint -Now you can initialize a sharded model without OOM, but what about loading it from a checkpoint on disk? JAX checkpointing libraries, such as [Orbax](https://orbax.readthedocs.io/en/latest/), usually support loading it sharded if a sharding pytree is given. +Now you have learned how to initialize a sharded model without running out of memory (OOM), but what about loading it from a checkpoint on disk? JAX checkpointing libraries, such as [Orbax](https://orbax.readthedocs.io/en/latest/), usually support loading a model in sharded form if a sharding pytree is provided. -You can generate such as sharding pytree with [`nnx.get_named_sharding`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_named_sharding). To avoid any real memory allocation, use the [`nnx.eval_shape`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.eval_shape) transform to generate a model of abstract JAX arrays, and only use its `.sharding` annotations to obtain the sharding tree. +You can generate such a sharding pytree with Flax’s [`nnx.get_named_sharding`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_named_sharding). To avoid any real memory allocation, use the [`nnx.eval_shape`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.eval_shape) transform to generate a model of abstract JAX arrays, and only use its `.sharding` annotations to obtain the sharding tree. -Below is an example demonstration using Orbax's `StandardCheckpointer` API.
Check out [Orbax website](https://orbax.readthedocs.io/en/latest/) to learn their latest updates and recommended APIs. +Below is an example that demonstrates using Orbax's `StandardCheckpointer` API. (Go to the [Orbax documentation site](https://orbax.readthedocs.io/en/latest/) to learn about their latest and most recommended APIs.) ```{code-cell} ipython3 import orbax.checkpoint as ocp @@ -239,26 +241,28 @@ jax.debug.visualize_array_sharding(loaded_sharded.w2.value) ## Compile the training loop -Now, from initialization or from checkpoint, you have a sharded model. To carry out the compiled, scaled up training, you need to shard the inputs as well. In this data parallelism example, the training data has its batch dimension sharded across the `data` device axis, so you should put your data in sharding `('data', None)`. You can use [`jax.device_put`](https://jax.readthedocs.io/en/latest/_autosummary/jax.device_put.html#jax.device_put) for this. +Now, after either initialization or loading the checkpoint, you have a sharded model. To carry out the compiled scaled up training, you need to shard the inputs as well. -Note that with the correct sharding for all inputs, the output will be sharded in the most natural way even without JIT compilation. In the example below even without [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) on the output `y`, it was still sharded as `('data', None)`. +- In the data parallelism example, the training data has its batch dimension sharded across the `data` device axis, so you should put your data in sharding `('data', None)`. You can use [`jax.device_put`](https://jax.readthedocs.io/en/latest/_autosummary/jax.device_put.html#jax.device_put) for this. +- Note that with the correct sharding for all inputs, the output will be sharded in the most natural way even without `jit` compilation. 
+- In the example below, even without [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) on the output `y`, it is still sharded as `('data', None)`. -> If you are interested in why: The second matmul of `DotReluDot.__call__` has two inputs of sharding `('data', 'model')` and `('model', None)`, in which both inputs' contraction axis are `model`. So a reduce-scatter matmul happened and will naturally shard the output as `('data', None)`. Check out the [JAX shard map collective guide](https://jax.readthedocs.io/en/latest/notebooks/shard_map.html#example-2-psum-scatter-the-result) and its examples if you want to learn mathematically how it happens at a low level. +> If you are interested in why: The second matmul of `DotReluDot.__call__` has two inputs with shardings `('data', 'model')` and `('model', None)`, so the contraction axis of both inputs is sharded over `model`. A reduce-scatter matmul therefore happens and naturally shards the output as `('data', None)`. Check out the [JAX shard map collective guide](https://jax.readthedocs.io/en/latest/notebooks/shard_map.html#example-2-psum-scatter-the-result) and its examples if you want to learn mathematically how it happens at a low level. ```{code-cell} ipython3 -# In data parallelism, the first dimension (batch) will be sharded on `data` axis. +# In data parallelism, the first dimension (batch) will be sharded on the `data` axis. data_sharding = NamedSharding(mesh, PartitionSpec('data', None)) input = jax.device_put(jnp.ones((8, 1024)), data_sharding) with mesh: output = sharded_model(input) print(output.shape) -jax.debug.visualize_array_sharding(output) # Also sharded as ('data', None) +jax.debug.visualize_array_sharding(output) # Also sharded as `('data', None)`.
``` -Now the rest of the training loop is pretty conventional - almost the same as the example in [NNX Basics](https://flax.readthedocs.io/en/latest/nnx_basics.html#transforms), except that the inputs and labels are also explicitly sharded. - -[`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) will adjust and automatically choose the best layout based on how its inputs are already sharded, so try out different shardings for your own model and inputs. +Now the rest of the training loop is pretty conventional - it is almost the same as the example in [Flax NNX Basics](https://flax.readthedocs.io/en/latest/nnx_basics.html#transforms): +- Except that the inputs and labels are also explicitly sharded. +- [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) will adjust and automatically choose the best layout based on how its inputs are already sharded, so try out different shardings for your own model and inputs. ```{code-cell} ipython3 optimizer = nnx.Optimizer(sharded_model, optax.adam(1e-3)) # reference sharing @@ -285,7 +289,7 @@ with mesh: ## Profiling -If you are running on a TPU pod or a pod slice, you can create a custom `block_all()` utility function, as defined below, to measure the performance: +If you are using a Google TPU pod or a pod slice, you can create a custom `block_all()` utility function, as defined below, to measure the performance: ```{code-cell} ipython3 %%timeit @@ -302,7 +306,7 @@ with mesh: JAX's [automatic](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) [SPMD]((https://jax.readthedocs.io/en/latest/glossary.html#term-SPMD)) encourages users to explore different sharding layouts to find the optimal one. 
To this end, in Flax you have the option to annotate with more descriptive axis names (not just device mesh axis names like `'data'` and `'model'`), as long as you provide a mapping from your alias to the device mesh axes. -You can provide the mapping along with the annotation as another metadata of the corresponding [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable), or overwrite it at top-level. Check out the `LogicalDotReluDot` example below. +You can provide the mapping along with the annotation as another metadata of the corresponding [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable), or overwrite it at top-level. Check out the `LogicalDotReluDot()` example below. ```{code-cell} ipython3 # The mapping from alias annotation to the device mesh. @@ -337,7 +341,7 @@ class LogicalDotReluDot(nnx.Module): return z ``` -If you didn't provide all `sharding_rule` annotations in model definition, you can write a few lines to add it to [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) of the model, before the call of [`nnx.get_partition_spec`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_partition_spec) or [`nnx.get_named_sharding`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_named_sharding). +If you didn't provide all `sharding_rule` annotations in the model definition, you can write a few lines to add it to Flax’s [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) of the model, before the call of [`nnx.get_partition_spec`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_partition_spec) or [`nnx.get_named_sharding`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.get_named_sharding). 
```{code-cell} ipython3 def add_sharding_rule(vs: nnx.VariableState) -> nnx.VariableState: @@ -384,6 +388,6 @@ Choosing when to use a device or logical axis depends on how much you want to co * For a simpler model, this can save you a few extra lines of code of converting the logical naming back to the device naming. - * Shardings of intermediate *activation* values can only be done via [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) and device mesh axis. So if you want super fine-grained control over your model's sharding, directly using device mesh axis names everywhere might be less confusing. + * Sharding of intermediate *activation* values can only be done via [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) and device mesh axis names. Therefore, if you want fine-grained control over your model's sharding, directly using device mesh axis names everywhere might be less confusing. -* **Logical naming**: Helpful if you want to experiment around and find the most optimal partition layout for your *model weights*. +* **Logical naming**: This is helpful if you want to experiment and find the optimal partition layout for your *model weights*.
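
To make the logical-naming option more concrete, here is a small sketch of how logical axis names could be resolved to device mesh axes. The helper `logical_to_mesh_spec` and both rule tables are hypothetical and written only for illustration; this is not the Flax implementation, which applies the rules for you.

```python
from jax.sharding import PartitionSpec

# Hypothetical rule table: logical axis name -> device mesh axis (None = replicated).
sharding_rules = (('batch', 'data'), ('hidden', 'model'), ('embed', None))

def logical_to_mesh_spec(logical_axes, rules):
  """Translate a tuple of logical axis names into a PartitionSpec of mesh axes."""
  mapping = dict(rules)
  return PartitionSpec(*(None if axis is None else mapping[axis] for axis in logical_axes))

# A kernel annotated with logical names ('embed', 'hidden').
kernel_spec = logical_to_mesh_spec(('embed', 'hidden'), sharding_rules)

# Trying a different layout only requires editing the rule table, not the model.
alt_rules = (('batch', 'data'), ('hidden', None), ('embed', 'model'))
alt_kernel_spec = logical_to_mesh_spec(('embed', 'hidden'), alt_rules)

print(kernel_spec, alt_kernel_spec)
```

The point of the indirection is that the model code keeps its descriptive axis names while the mapping to devices is changed in one place.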