Merge pull request #4401 from 8bitmp3:update-performance-doc
PiperOrigin-RevId: 704889940
Flax Authors committed Dec 11, 2024
2 parents a785bff + b69a46f commit 554b690
Showing 2 changed files with 50 additions and 20 deletions.
35 changes: 25 additions & 10 deletions docs_nnx/guides/performance.ipynb
@@ -4,22 +4,31 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Performance Considerations\n",
"Currently `nnx.jit` traverses the object graph in pure Python, this is slow and adds overhead. To solve this in general we will be developing a Rust extension called `flaxlib` (see first steps in #4196) to speedup some of the traversal logic in `graph.py`, similar to how JAX solved the same issue with `jaxlib` for standard pytrees. However, there's two things to consider:\n",
"# Performance considerations\n",
"\n",
"* The overhead is only relevant for small models. See [Asynchronous dispatch](#asynchronous-dispatch).\n",
"* You can remove the overhead by using `jax.jit` + `nnx.split` / `nnx.merge` to stage out the traversal logic. See [Lowering the Python Overhead](#lowering-the-python-overhead).\n",
"Currently, Flax [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) traverses the object graph in pure Python, which is slow and adds overhead. This is why in order to solve this the Flax team will be developing a Rust extension called `flaxlib` to speed up some of the traversal logic in [`graph.py`](https://github.com/google/flax/blob/main/flax/nnx/graph.py). This will be similar to how the JAX team resolved a similar issue by introducing [`jaxlib`](https://jax.readthedocs.io/en/latest/installation.html#installation) for standard [JAX pytrees](https://jax.readthedocs.io/en/latest/key-concepts.html#pytrees) (refer to the first steps in [Flax PR #4196](https://github.com/google/flax/pull/4196)).\n",
"\n",
"However, there are two things to consider:\n",
"\n",
"* The overhead is only relevant for small models (refer to [Asynchronous dispatch](#asynchronous-dispatch).\n",
"* You can remove the overhead by using [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html) + [`flax.nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split) / [`flax.nnx.merge`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.merge) to stage out the traversal logic (Refer to [Lowering the Python overhead](#lowering-the-python-overhead).\n",
"\n",
"\n",
"## Asynchronous dispatch\n",
"In [benchmarks/nnx_simple_training.py](https://github.com/google/flax/blob/main/benchmarks/nnx_simple_training.py) we are increasing the layer width (features per layer) and measuring the total training time for the same model trained both with `nnx.jit` and `jax.jit`. As you can see in the graph below, after a certain model size the time spent in the traversal is completely absorbed by async dispatch. This happens when Python is able to finish the current for loop step, and reach the next `train_step` and JAX is still not done with the previous `train_step`. \n",
"\n",
"In [benchmarks/nnx_simple_training.py](https://github.com/google/flax/blob/main/benchmarks/nnx_simple_training.py) we are increasing the layer width (features per layer) and measuring the total training time for the same model trained both with [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) and [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html).\n",
"\n",
"As demonstrated in the graph below, after a certain model size the time spent in the traversal is completely absorbed by async dispatch. This happens when Python is able to finish the current for loop step, and reach the next `train_step` and JAX is still not done with the previous `train_step`. \n",
"\n",
"![performance-graph](images/performance-graph.png)\n",
"\n",
"This means that you only need to worry about the `nnx.jit` overhead for small models. If you are working with a small model, check out the next section to see how you can remove the overhead.\n",
"\n",
"## Lowering the Python Overhead\n",
"To remove the python overhead you can use regular `jax.jit` in combination with `nnx.split` and `nnx.merge` to stage out the traversal logic. To learn how to do this, lets first create this simple model:"
"## Lowering the Python overhead\n",
"\n",
"To remove the Python overhead, you can use regular `jax.jit` in combination with `nnx.split` and `nnx.merge` to stage out the traversal logic.\n",
"\n",
"To learn how to do this, let’s first create the following simple `Model`:"
]
},
{
@@ -49,7 +58,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Lets say we have this `train_step` function that is using `nnx.jit` and takes in a `model`, `optimizer`, and `metrics`, all of which are Flax NNX objects:"
"Next, let’s create a `train_step()` function that uses [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit), taking in the `model`, `optimizer`, and `metrics`, all of which are Flax NNX objects:"
]
},
{
@@ -85,7 +94,11 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"To speed it up, before starting the training loop we can use `nnx.split` over the all the Flax NNX objects that are inputs to `train_step` to create a `graphdef` and `state` pytrees that are fast to traverse. Next we change `train_step` so accept `graphdef` and `state` and use `nnx.merge` and `nnx.split` at the beginning and end of `train_step` to switch back and forth between the objects and their pytree representations. Even though `nnx.split` and `nnx.merge` are slow it doesn't matter because they will only run once during tracing. With this in place, we can change the `train_step` function to use `jax.jit` instead of `nnx.jit`:"
"To speed this up, before starting the training loop we can use [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split) over all the Flax NNX objects that are inputs to `train_step()` to create `graphdef` and `state` pytrees that are faster to traverse.\n",
"\n",
"Next, we change `train_step()` to accept `graphdef` and `state`, and use [`nnx.merge`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.merge) and [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split) at the beginning and the end of `train_step()` to switch back and forth between the objects and their pytree representations. And even though `nnx.split` and `nnx.merge` are slow, it doesn't matter because they will run only once during tracing.\n",
"\n",
"With this in place, we can change the `train_step()` function to use `jax.jit` instead of `nnx.jit`:"
]
},
{
@@ -133,7 +146,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Notice that we only do this for `jit`, you can still use other transforms like `nnx.value_and_grad` shown in the example since their overhead is already absorbed by the outer `jit`. Also, after the training loop is done (or whenever need) `nnx.update` can be used to update Flax NNX objects like `model`, `optimizer`, and `metrics` to a new `state`."
"Notice that we only do this for `jit`. You can still use other [Flax transforms](https://flax.readthedocs.io/en/latest/guides/transforms.html#transformations) like [`nnx.value_and_grad`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.value_and_grad) shown in the above example since their overhead is already absorbed by the outer `jit`.\n",
"\n",
"And after the training loop is done (or whenever it is needed), we can use Flax [`nnx.update`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.update) to update Flax NNX objects like `model`, `optimizer`, and `metrics` to a new `state`."
]
}
],
35 changes: 25 additions & 10 deletions docs_nnx/guides/performance.md
@@ -8,22 +8,31 @@ jupytext:
jupytext_version: 1.13.8
---

# Performance Considerations
Currently `nnx.jit` traverses the object graph in pure Python, this is slow and adds overhead. To solve this in general we will be developing a Rust extension called `flaxlib` (see first steps in #4196) to speedup some of the traversal logic in `graph.py`, similar to how JAX solved the same issue with `jaxlib` for standard pytrees. However, there's two things to consider:
# Performance considerations

* The overhead is only relevant for small models. See [Asynchronous dispatch](#asynchronous-dispatch).
* You can remove the overhead by using `jax.jit` + `nnx.split` / `nnx.merge` to stage out the traversal logic. See [Lowering the Python Overhead](#lowering-the-python-overhead).
Currently, Flax [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) traverses the object graph in pure Python, which is slow and adds overhead. To address this, the Flax team will be developing a Rust extension called `flaxlib` to speed up some of the traversal logic in [`graph.py`](https://github.com/google/flax/blob/main/flax/nnx/graph.py) (refer to the first steps in [Flax PR #4196](https://github.com/google/flax/pull/4196)). This is similar to how the JAX team resolved the same kind of issue by introducing [`jaxlib`](https://jax.readthedocs.io/en/latest/installation.html#installation) for standard [JAX pytrees](https://jax.readthedocs.io/en/latest/key-concepts.html#pytrees).

However, there are two things to consider:

* The overhead is only relevant for small models (refer to [Asynchronous dispatch](#asynchronous-dispatch)).
* You can remove the overhead by using [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html) + [`flax.nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split) / [`flax.nnx.merge`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.merge) to stage out the traversal logic (refer to [Lowering the Python overhead](#lowering-the-python-overhead)).


## Asynchronous dispatch
In [benchmarks/nnx_simple_training.py](https://github.com/google/flax/blob/main/benchmarks/nnx_simple_training.py) we are increasing the layer width (features per layer) and measuring the total training time for the same model trained both with `nnx.jit` and `jax.jit`. As you can see in the graph below, after a certain model size the time spent in the traversal is completely absorbed by async dispatch. This happens when Python is able to finish the current for loop step, and reach the next `train_step` and JAX is still not done with the previous `train_step`.

In [benchmarks/nnx_simple_training.py](https://github.com/google/flax/blob/main/benchmarks/nnx_simple_training.py) we are increasing the layer width (features per layer) and measuring the total training time for the same model trained both with [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) and [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html).

As demonstrated in the graph below, after a certain model size the time spent in the traversal is completely absorbed by async dispatch. This happens when Python finishes the current `for` loop step and reaches the next `train_step` while JAX is still running the previous `train_step`.

![performance-graph](images/performance-graph.png)

This means that you only need to worry about the `nnx.jit` overhead for small models. If you are working with a small model, check out the next section to see how you can remove the overhead.
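
The absorption effect described above can be illustrated with a toy timeline model. This is only a sketch under simplifying assumptions, not how JAX actually schedules work: the function name `simulate_total_time` and its parameters are hypothetical, every step is assumed to cost the same fixed host time (`python_overhead`) and device time (`device_time`), and the device is assumed to run queued steps strictly in order.

```python
def simulate_total_time(python_overhead: float, device_time: float, steps: int) -> float:
    """Toy model of async dispatch: the host enqueues steps, the device runs them in order."""
    python_clock = 0.0  # time at which Python has finished dispatching the current step
    device_free = 0.0   # time at which the device finishes all work queued so far
    for _ in range(steps):
        # Host-side cost per step (for example, object graph traversal plus dispatch).
        python_clock += python_overhead
        # The device starts a step once it has been dispatched AND the previous step is done.
        start = max(python_clock, device_free)
        device_free = start + device_time
    return device_free


# Small model: device work is cheap, so the host overhead dominates the total time.
small_model_total = simulate_total_time(python_overhead=1.0, device_time=0.25, steps=10)  # 10.25

# Large model: the device is the bottleneck, so only the first dispatch is visible.
large_model_total = simulate_total_time(python_overhead=1.0, device_time=5.0, steps=10)  # 51.0

# Same large model with no host overhead at all, for comparison.
large_model_ideal = simulate_total_time(python_overhead=0.0, device_time=5.0, steps=10)  # 50.0
```

In this model the large-model run loses a single unit of time to the overhead (51.0 versus 50.0), while the small-model run is almost entirely overhead, which mirrors the qualitative shape of the benchmark graph.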

## Lowering the Python Overhead
To remove the python overhead you can use regular `jax.jit` in combination with `nnx.split` and `nnx.merge` to stage out the traversal logic. To learn how to do this, lets first create this simple model:
## Lowering the Python overhead

To remove the Python overhead, you can use regular `jax.jit` in combination with `nnx.split` and `nnx.merge` to stage out the traversal logic.

To learn how to do this, let’s first create the following simple `Model`:

```{code-cell}
from flax import nnx
@@ -43,7 +52,7 @@ class Model(nnx.Module):
return self.linear_out(x)
```

Lets say we have this `train_step` function that is using `nnx.jit` and takes in a `model`, `optimizer`, and `metrics`, all of which are Flax NNX objects:
Next, let’s create a `train_step()` function that uses [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit), taking in the `model`, `optimizer`, and `metrics`, all of which are Flax NNX objects:

```{code-cell}
model = Model(2, 64, 3, rngs=nnx.Rngs(0)) # eager initialization
@@ -69,7 +78,11 @@ for _ in range(10):
loss = train_step(model, optimizer, metrics, x, y)
```

To speed it up, before starting the training loop we can use `nnx.split` over the all the Flax NNX objects that are inputs to `train_step` to create a `graphdef` and `state` pytrees that are fast to traverse. Next we change `train_step` so accept `graphdef` and `state` and use `nnx.merge` and `nnx.split` at the beginning and end of `train_step` to switch back and forth between the objects and their pytree representations. Even though `nnx.split` and `nnx.merge` are slow it doesn't matter because they will only run once during tracing. With this in place, we can change the `train_step` function to use `jax.jit` instead of `nnx.jit`:
To speed this up, before starting the training loop we can use [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split) over all the Flax NNX objects that are inputs to `train_step()` to create `graphdef` and `state` pytrees that are faster to traverse.

Next, we change `train_step()` to accept `graphdef` and `state`, and use [`nnx.merge`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.merge) and [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split) at the beginning and the end of `train_step()` to switch back and forth between the objects and their pytree representations. And even though `nnx.split` and `nnx.merge` are slow, it doesn't matter because they will run only once during tracing.

With this in place, we can change the `train_step()` function to use `jax.jit` instead of `nnx.jit`:

```{code-cell}
model = Model(2, 64, 3, rngs=nnx.Rngs(0)) # eager initialization
@@ -107,4 +120,6 @@ for _ in range(10):
nnx.update((model, optimizer, metrics), state)
```

Notice that we only do this for `jit`, you can still use other transforms like `nnx.value_and_grad` shown in the example since their overhead is already absorbed by the outer `jit`. Also, after the training loop is done (or whenever need) `nnx.update` can be used to update Flax NNX objects like `model`, `optimizer`, and `metrics` to a new `state`.
Notice that we only do this for `jit`. You can still use other [Flax transforms](https://flax.readthedocs.io/en/latest/guides/transforms.html#transformations) like [`nnx.value_and_grad`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.value_and_grad) shown in the above example since their overhead is already absorbed by the outer `jit`.

And after the training loop is done (or whenever it is needed), we can use Flax [`nnx.update`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.update) to update Flax NNX objects like `model`, `optimizer`, and `metrics` to a new `state`.
