From a8e2ab1f2c7d8c8ee491a301e9844643b56edd8c Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Fri, 25 Oct 2024 10:36:06 +0100 Subject: [PATCH] [nnx] improve transforms guide --- docs_nnx/guides/transforms.ipynb | 100 ++++++++++++++++++++++--------- docs_nnx/guides/transforms.md | 36 ++++++++++- 2 files changed, 106 insertions(+), 30 deletions(-) diff --git a/docs_nnx/guides/transforms.ipynb b/docs_nnx/guides/transforms.ipynb index 28287fe7ec..4ad1e48af3 100644 --- a/docs_nnx/guides/transforms.ipynb +++ b/docs_nnx/guides/transforms.ipynb @@ -66,7 +66,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -78,7 +78,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -131,7 +131,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -143,7 +143,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -191,7 +191,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -203,7 +203,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -331,7 +331,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -343,7 +343,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -500,7 +500,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -512,7 +512,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -578,7 +578,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -590,7 +590,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -627,19 +627,69 @@ "nnx.display(weights)" ] }, + { + "cell_type": "markdown", + "id": "60eee7f9", + "metadata": {}, + "source": [ + "## Rules and limitations\n", + "In this section we will cover some rules and limitations apply when using Modules inside transformations.\n", + "\n", + "### Mutable Module cannot be passed by closure\n", + "\n", + "While Python allows for passing objects as closures to functions, this is generally not supported by Flax NNX transforms. The reason is that because Modules are mutable it is very easy to capture tracer into a Module created outside of the transform, this is silent error in JAX. To avoid this, Flax NNX checks that the Modules and Variables being mutated are passed as arguments to the transformed function.\n", + "\n", + "For example, if we a have stateful Module such as `Counter` that increments a counter every time it is called, and we try to pass it as a closure to a function decorated with `nnx.jit`, we would be leaking the tracer. However Flax NNX will raise an error instead to prevent this:" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "f8b95c03", + "metadata": {}, + "outputs": [], + "source": [ + "class Counter(nnx.Module):\n", + " def __init__(self):\n", + " self.count = nnx.Param(jnp.array(0))\n", + "\n", + " def increment(self):\n", + " self.count += jnp.array(1)\n", + "\n", + "counter = Counter()\n", + "\n", + "@nnx.jit\n", + "def f(x):\n", + " counter.increment()\n", + " return 2 * x\n", + "\n", + "try:\n", + " y = f(3)\n", + "except Exception as e:\n", + " print(e)" + ] + }, + { + "cell_type": "markdown", + "id": "6f37e23b", + "metadata": {}, + "source": [ + "To solve this issue pass all Module as arguments to the functions being transformed. In this case `f` should accept `counter` as an argument." + ] + }, { "cell_type": "markdown", "id": "75edf7a8", "metadata": {}, "source": [ - "## Consistent aliasing\n", + "### Consistent aliasing\n", "\n", "The main issue with allowing for reference semantics in transforms is that references can be shared across inputs and outputs. This can be problematic if it is not taken care of because it would lead to ill-defined or inconsistent behavior. In the example below you have a single `Weights` `nnx.Module` - `m` ` whose reference appears in multiple places in `arg1` and `arg2`. The problem here is that you also specify that you want to vectorize `arg1` in axis `0` and `arg2` in axis `1`. This would be fine in JAX because of referential transparency of pytrees. But this would be problematic in Flax NNX because you are trying to vectorize `m` in two different ways. Flax NNX will enforce consistency by raising an error." ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 11, "id": "46b1cc25", "metadata": {}, "outputs": [ @@ -648,14 +698,10 @@ "output_type": "stream", "text": [ "Inconsistent aliasing detected. The following nodes have different prefixes:\n", - "Node: \n", + "Node: \n", " param: 0\n", " param: 0\n", - " param: 1\n", - "Node: \n", - " : 0\n", - " : 0\n", - " : 1\n" + " param: 1\n" ] } ], @@ -688,7 +734,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 12, "id": "cca9cf31", "metadata": {}, "outputs": [ @@ -697,14 +743,10 @@ "output_type": "stream", "text": [ "Inconsistent aliasing detected. The following nodes have different prefixes:\n", - "Node: \n", + "Node: \n", " param: 0\n", " param: 0\n", - " param: 1\n", - "Node: \n", - " : 0\n", - " : 0\n", - " : 1\n" + " param: 1\n" ] } ], @@ -737,7 +779,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 13, "id": "d85c772c", "metadata": {}, "outputs": [ @@ -781,7 +823,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 14, "id": "358e51f7", "metadata": {}, "outputs": [ diff --git a/docs_nnx/guides/transforms.md b/docs_nnx/guides/transforms.md index 9df2c9abbf..0cd7046f3e 100644 --- a/docs_nnx/guides/transforms.md +++ b/docs_nnx/guides/transforms.md @@ -309,7 +309,41 @@ print(jnp.allclose(y1, y2)) nnx.display(weights) ``` -## Consistent aliasing +## Rules and limitations +In this section we will cover some rules and limitations apply when using Modules inside transformations. + +### Mutable Module cannot be passed by closure + +While Python allows for passing objects as closures to functions, this is generally not supported by Flax NNX transforms. The reason is that because Modules are mutable it is very easy to capture tracer into a Module created outside of the transform, this is silent error in JAX. To avoid this, Flax NNX checks that the Modules and Variables being mutated are passed as arguments to the transformed function. + +For example, if we a have stateful Module such as `Counter` that increments a counter every time it is called, and we try to pass it as a closure to a function decorated with `nnx.jit`, we would be leaking the tracer. However Flax NNX will raise an error instead to prevent this: + +```{code-cell} ipython3 +class Counter(nnx.Module): + def __init__(self): + self.count = nnx.Param(jnp.array(0)) + + def increment(self): + self.count += jnp.array(1) + +counter = Counter() + +@nnx.jit +def f(x): + counter.increment() + return 2 * x + +try: + y = f(3) +except Exception as e: + print(e) +``` + +To solve this issue pass all Module as arguments to the functions being transformed. In this case `f` should accept `counter` as an argument. + ++++ + +### Consistent aliasing The main issue with allowing for reference semantics in transforms is that references can be shared across inputs and outputs. This can be problematic if it is not taken care of because it would lead to ill-defined or inconsistent behavior. In the example below you have a single `Weights` `nnx.Module` - `m` ` whose reference appears in multiple places in `arg1` and `arg2`. The problem here is that you also specify that you want to vectorize `arg1` in axis `0` and `arg2` in axis `1`. This would be fine in JAX because of referential transparency of pytrees. But this would be problematic in Flax NNX because you are trying to vectorize `m` in two different ways. Flax NNX will enforce consistency by raising an error.