Skip to content

Commit

Permalink
Merge pull request #4421 from google:nnx-fix-transforms-guide
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 704863491
  • Loading branch information
Flax Authors committed Dec 10, 2024
2 parents 47122f2 + 10f3460 commit f09d105
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions docs_nnx/guides/transforms.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -684,7 +684,7 @@
"source": [
"### Consistent aliasing\n",
"\n",
"The main issue with allowing for reference semantics in transforms is that references can be shared across inputs and outputs. This can be problematic if it is not taken care of because it would lead to ill-defined or inconsistent behavior. In the example below you have a single `Weights` `nnx.Module` - `m` ` whose reference appears in multiple places in `arg1` and `arg2`. The problem here is that you also specify that you want to vectorize `arg1` in axis `0` and `arg2` in axis `1`. This would be fine in JAX because of referential transparency of pytrees. But this would be problematic in Flax NNX because you are trying to vectorize `m` in two different ways. Flax NNX will enforce consistency by raising an error."
"The main issue with allowing for reference semantics in transforms is that references can be shared across inputs and outputs. This can be problematic if it is not taken care of because it would lead to ill-defined or inconsistent behavior. In the example below you have a single `Weights` Module `m` whose reference appears in multiple places in `arg1` and `arg2`. The problem here is that you also specify that you want to vectorize `arg1` in axis `0` and `arg2` in axis `1`. This would be fine in JAX because of referential transparency of pytrees. But this would be problematic in Flax NNX because you are trying to vectorize `m` in two different ways. Flax NNX will enforce consistency by raising an error."
]
},
{
Expand Down Expand Up @@ -863,7 +863,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
"version": "3.11.9"
}
},
"nbformat": 4,
Expand Down
2 changes: 1 addition & 1 deletion docs_nnx/guides/transforms.md
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ To solve this issue pass all Module as arguments to the functions being transfor

### Consistent aliasing

The main issue with allowing for reference semantics in transforms is that references can be shared across inputs and outputs. This can be problematic if it is not taken care of because it would lead to ill-defined or inconsistent behavior. In the example below you have a single `Weights` `nnx.Module` - `m` ` whose reference appears in multiple places in `arg1` and `arg2`. The problem here is that you also specify that you want to vectorize `arg1` in axis `0` and `arg2` in axis `1`. This would be fine in JAX because of referential transparency of pytrees. But this would be problematic in Flax NNX because you are trying to vectorize `m` in two different ways. Flax NNX will enforce consistency by raising an error.
The main issue with allowing for reference semantics in transforms is that references can be shared across inputs and outputs. This can be problematic if it is not taken care of because it would lead to ill-defined or inconsistent behavior. In the example below you have a single `Weights` Module `m` whose reference appears in multiple places in `arg1` and `arg2`. The problem here is that you also specify that you want to vectorize `arg1` in axis `0` and `arg2` in axis `1`. This would be fine in JAX because of referential transparency of pytrees. But this would be problematic in Flax NNX because you are trying to vectorize `m` in two different ways. Flax NNX will enforce consistency by raising an error.

```{code-cell} ipython3
class Weights(nnx.Module):
Expand Down

0 comments on commit f09d105

Please sign in to comment.