# [nnx] Add support for python container types #3486

Merged 1 commit on Nov 15, 2023.
**flax/experimental/nnx/README.md** (79 changes: 40 additions & 39 deletions)
@@ -65,12 +65,12 @@ In this example `nnx.Rngs(0)` creates a `random.key` for `params` with seed `0`,

The [Functional API](#functional-api) converts an NNX Module from Python semantics into a pure pytree object with functional semantics. It is the recommended way to use NNX as it provides tight control over the state, allows you to use regular JAX transformations, and minimizes overhead. In this example the model will be trained using Stochastic Gradient Descent (SGD).

```diff
-params, counts, moduledef = model.split(nnx.Param, Count)
+params, counts, static = model.split(nnx.Param, Count)
 
 @jax.jit
 def train_step(params, counts, x, y):
   def loss_fn(params):
-    model = moduledef.merge(params, counts)
+    model = static.merge(params, counts)
     y_pred = model(x)
     loss = jax.numpy.mean((y_pred - y) ** 2)
     return loss, model.extract(Count)
@@ -84,7 +84,7 @@ def train_step(params, counts, x, y):
 
 # execute the training step
 params, counts = train_step(params, counts, x, y)
-model = moduledef.merge(params, counts)
+model = static.merge(params, counts)
 assert model.count == 2
```
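(`Count` above is a user-defined `Variable` subclass whose definition lies outside this hunk; a minimal sketch consistent with the usage shown, not the PR's exact code:)

```python
from flax.experimental import nnx

# hypothetical definition, consistent with `model.split(nnx.Param, Count)`
class Count(nnx.Variable):
  pass
```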

@@ -141,51 +141,52 @@ NNX Modules are normal python classes, they obey regular python semantics such a

```diff
 class Foo(nnx.Module):
-  def __init__(self, rngs: nnx.Rngs):
-    # node attributes
-    self.variable = nnx.Param(jnp.array(1))
-    self.implicit_param = jnp.array(3)
-    self.submodule = nnx.Linear(2, 4, rngs=rngs)
-    # static attributes
-    self.int = 1
-    self.float = 2.0
-    self.str = "hello"
-    self.list = [1, 2, 3]
+  def __init__(self, rngs: nnx.Rngs):
+    # node attributes
+    self.variable = nnx.Param(jnp.array(1))
+    self.submodule = nnx.Linear(2, 3, rngs=rngs)
+    self.container = [4, nnx.Linear(5, 6, rngs=rngs), 7]
+    # static attributes
+    self.int = 8
+    self.float = 9.0
+    self.str = 'hello'
 
 model = Foo(rngs=nnx.Rngs(0))
```

```diff
-As shown above, python container types such as `list`, `tuple`, and `dict` are treated as static attributes, if similar functionality is needed, NNX provides the `Sequence` and `Dict` Modules.
+As shown above, python container types such as `list`, `tuple`, and `dict` are treated as node attributes, which means you can naturally have e.g. `list`s or `dict`s of Modules.
```
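To illustrate the new behavior, a minimal sketch of a Module holding a `list` of sub-Modules (the `Blocks` class, its name, and the layer sizes are hypothetical, not part of this PR):

```python
from flax.experimental import nnx

class Blocks(nnx.Module):
  def __init__(self, rngs: nnx.Rngs):
    # a plain python list of Modules is now treated as a node attribute
    self.blocks = [nnx.Linear(2, 2, rngs=rngs) for _ in range(3)]

state, static = Blocks(nnx.Rngs(0)).split()
# list items are addressed by index in the flat state,
# e.g. 'blocks/0/kernel', 'blocks/1/bias', ...
```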

### Functional API

```diff
-NNX Modules are not pytrees so they cannot be passed to JAX transformations. In order to interact with JAX, a Module must be partitioned into `State` and `ModuleDef` objects. The `State` object is a flat dictionary-like pytree structure that contains all the deduplicated node attributes, and the `ModuleDef` contains the static attributes and structural information needed to reconstruct the Module.
+NNX Modules are not pytrees so they cannot be passed to JAX transformations. In order to interact with JAX, a Module must be partitioned into `State` and `GraphDef` objects. The `State` object is a flat dictionary-like pytree structure that contains all the deduplicated node attributes, and the `GraphDef` contains the static attributes and structural information needed to reconstruct the Module.
```

```diff
-state, moduledef = model.split()
+state, static = model.split()
```

```diff
-State({
-  'implicit_param': Param(value=Array(3)),
-  'submodule/bias': Param(value=Array(...)),
-  'submodule/kernel': Param(value=Array(...)),
-  'variable': Param(value=Array(1))
+state = State({
+  'variable': Array(1, dtype=int32, weak_type=True),
+  'submodule/kernel': Array(..., dtype=float32),
+  'submodule/bias': Array(..., dtype=float32),
+  'container/1/kernel': Array(..., dtype=float32),
+  'container/1/bias': Array(..., dtype=float32)
 })
```
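Since `State` is a pytree (see below), ordinary JAX tree utilities apply to it; for example, counting parameters (a sketch continuing from the `split` above):

```python
import jax

# every leaf in `state` is an array, so we can sum their sizes
n_params = sum(x.size for x in jax.tree_util.tree_leaves(state))
```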

```diff
-`State` and `ModuleDef` are pytrees so they can be passed to JAX transformations. Moreover, `ModuleDef` provides 2 very important methods: `merge` and `apply`. The `merge` method can be used to create a new `Module` from a `State` object:
+`State` and `GraphDef` are pytrees so they can be passed to JAX transformations. Moreover, `GraphDef` provides 2 very important methods: `merge` and `apply`. The `merge` method can be used to create a new `Module` from a `State` object:
```

```diff
-model = moduledef.merge(state)
+model = static.merge(state)
```
This can be used to e.g. recreate a module inside a JAX transformation. `apply` provides a functional interface to the module; it can be used to call any method or submodule and get the output and the updated state:

```diff
 # run __call__
-y, (state, moduledef) = moduledef.apply(state)(x)
+y, (state, static) = static.apply(state)(x)
 # run some_method
-y, (state, moduledef) = moduledef.apply(state).some_method(x)
+y, (state, static) = static.apply(state).some_method(x)
 # run submodule
-y, (state, moduledef) = moduledef.apply(state).submodule(x)
+y, (state, static) = static.apply(state).submodule(x)
```
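As a sketch of the recreate-inside-a-transformation pattern mentioned above (`forward` is a hypothetical wrapper, not part of the README):

```python
import jax

@jax.jit
def forward(state, x):
  # rebuild the Module from the pytree state; `static` is closed over
  model = static.merge(state)
  return model(x)
```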

`apply` can call any nested method or submodule as long as it can be accessed via the `.` or `[]` operators.
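For instance, indexing into the `container` attribute from the `Foo` example above (a sketch; the input shape follows the `nnx.Linear(5, 6)` stored at position 1):

```python
import jax.numpy as jnp

# call the Linear stored at container[1] through the functional API
y, (state, static) = static.apply(state).container[1](jnp.ones((1, 5)))
```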
@@ -196,14 +197,14 @@ In NNX you can filter based on any node type, most commonly you will want to fil
Here are various examples of how you can use the `split` method to split a module into multiple substates:

```diff
-# split the module into the state with all the nodes and the moduledef
-state, moduledef = model.split()
+# split the module into the state with all the nodes and the static information
+state, static = model.split()
 # verify that the state contains only params, else raise an error
-params, moduledef = model.split(nnx.Param)
+params, static = model.split(nnx.Param)
 # split the state into params and batch_stats, verify no nodes are left
-params, batch_stats, moduledef = model.split(nnx.Param, nnx.BatchStat)
+params, batch_stats, static = model.split(nnx.Param, nnx.BatchStat)
 # if there are any nodes left, use the `...` filter to capture them
-params, batch_stats, rest, moduledef = model.split(nnx.Param, nnx.BatchStat, ...)
+params, batch_stats, rest, static = model.split(nnx.Param, nnx.BatchStat, ...)
 # using `...` as the only filter is equivalent to not passing any filters,
 # i.e. model.split(...) behaves the same as model.split()
```
@@ -215,13 +216,13 @@
To reconstruct the module from a set of substates, you can use `merge` as usual but passing the substates as additional arguments:

```diff
-model = moduledef.merge(params, batch_stats, rest)
+model = static.merge(params, batch_stats, rest)
```

The same is true for `apply`.

```diff
-y, (state, moduledef) = moduledef.apply(params, batch_stats, rest)(x)
+y, (state, static) = static.apply(params, batch_stats, rest)(x)
```
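Putting the two together, a sketch of a jitted function over multiple substates (`eval_step` is a hypothetical name, and `static` is closed over as above):

```python
import jax

@jax.jit
def eval_step(params, batch_stats, rest, x):
  # the merge happens implicitly inside `apply`
  y, _ = static.apply(params, batch_stats, rest)(x)
  return y
```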

Note that `apply` will return a single `state` object; if you need to `split` the state you can use `State`'s own `split` method:
@@ -295,8 +296,8 @@ State({
If you use the functional API to call the module instead, the `Intermediate` nodes will be present in the output `state`. To retrieve the `Intermediate` nodes and optionally separate them from the output `state` you can use `State.split`:

```diff
-state, moduledef = model.split()
-y, (state, moduledef) = moduledef.apply(state)(jnp.ones((8, 12)))
+state, static = model.split()
+y, (state, static) = static.apply(state)(jnp.ones((8, 12)))
 # "pop" the intermediates from the state
 intermediates, state = state.split(nnx.Intermediate, ...)
```
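A sketch of inspecting the popped intermediates without materializing anything (the exact paths depend on where the model sows values):

```python
import jax
import jax.numpy as jnp

# print the shape of every sown value in the Intermediate substate
print(jax.tree_util.tree_map(jnp.shape, intermediates))
```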
@@ -344,10 +345,10 @@ class ScanMLP(nnx.Module):
```diff
   def __init__(self, dim: int, *, n_layers: int, rngs: nnx.Rngs):
     params_key = jax.random.split(rngs.params(), n_layers)
     self.n_layers = n_layers
-    state, moduledef = jax.vmap(
+    state, static = jax.vmap(
       lambda key: Block(dim, rngs=nnx.Rngs(params=key)).split()
     )(params_key)
-    self.layers = moduledef.merge(state)
+    self.layers = static.merge(state)
```
Note that we split the `params` key into `n_layers` keys so each layer has different parameters.
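Because `split` runs under `jax.vmap`, every parameter in `state` gains a leading `n_layers` axis; a quick sketch of checking this (it assumes `Block`, defined outside this hunk, holds a single `dim`-by-`dim` `Linear`; the path in the comment is hypothetical):

```python
mlp = ScanMLP(3, n_layers=4, rngs=nnx.Rngs(0))
params, _ = mlp.layers.split(nnx.Param)
# each kernel is stacked to (n_layers, dim, dim),
# e.g. a hypothetical params['linear/kernel'].shape == (4, 3, 3)
```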
@@ -359,11 +360,11 @@ apply it to the input `x`, passing the sliced `dropout_key` as part of the `Rngs
```diff
 def __call__(self, x: jax.Array, *, train: bool, rngs: nnx.Rngs) -> jax.Array:
   dropout_key = jax.random.split(rngs.dropout(), self.n_layers)
-  params, moduledef = self.layers.split(nnx.Param)
+  params, static = self.layers.split(nnx.Param)
 
   def scan_fn(x, inputs):
     params, dropout_key = inputs
-    module = moduledef.merge(params)
+    module = static.merge(params)
     x = module(x, train=train, rngs=nnx.Rngs(dropout=dropout_key))
     return x, module.extract(nnx.Param)
```
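The actual `jax.lax.scan` call is collapsed below this hunk; a hedged sketch of what that continuation looks like (the `self.layers.update(...)` write-back is an assumption about the Module API, not shown in this diff):

```python
# hypothetical continuation of __call__, not the PR's exact lines:
x, params = jax.lax.scan(scan_fn, x, (params, dropout_key))
self.layers.update(params)  # assumed API: write the scanned-out params back
return x
```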
**flax/experimental/nnx/__init__.py** (8 changes: 5 additions & 3 deletions)
```diff
@@ -18,19 +18,23 @@
 from flax.linen.pooling import pool as pool
 
 from .nnx import compatibility as compatibility
+from .nnx import graph_utils
 from .nnx.dataclasses import dataclass as dataclass
 from .nnx.dataclasses import field as field
 from .nnx.dataclasses import param_field as param_field
 from .nnx.dataclasses import treenode_field as treenode_field
 from .nnx.dataclasses import variable_field as variable_field
 from .nnx.errors import TraceContextError as TraceContextError
+from .nnx.filterlib import All as All
+from .nnx.filterlib import Not as Not
 from .nnx.flaglib import flags as flags
+from .nnx.graph_utils import GraphDef as GraphDef
 from .nnx.helpers import Dict as Dict
 from .nnx.helpers import Sequence as Sequence
 from .nnx.helpers import TrainState as TrainState
+from .nnx.module import GraphDef as GraphDef
 from .nnx.module import M as M
 from .nnx.module import Module as Module
-from .nnx.module import ModuleDef as ModuleDef
 from .nnx.module import merge as merge
 from .nnx.nn import initializers as initializers
 from .nnx.nn.activations import celu as celu
@@ -64,8 +68,6 @@
 from .nnx.nn.normalization import BatchNorm as BatchNorm
 from .nnx.nn.normalization import LayerNorm as LayerNorm
 from .nnx.nn.stochastic import Dropout as Dropout
-from .nnx.filterlib import All as All
-from .nnx.filterlib import Not as Not
 from .nnx.pytreelib import Pytree as Pytree
 from .nnx.pytreelib import TreeNode as TreeNode
 from .nnx.rnglib import Rngs as Rngs
```
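With this reorganization, `GraphDef` is exposed at the package root; a small usage sketch (the `Linear` sizes are arbitrary):

```python
from flax.experimental import nnx

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
state, graphdef = model.split()
assert isinstance(graphdef, nnx.GraphDef)
```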
**flax/experimental/nnx/docs/quick_start.ipynb** (8 changes: 4 additions & 4 deletions)
@@ -335,7 +335,7 @@

```diff
 "\n",
 "Now that we have a working model, let's see how to train it with `jax.jit` using NNX's Functional API. The `Module.split` method allows you to convert a Module into pytrees with functional semantics; this allows you to integrate with JAX's functional APIs like `jax.jit` and `jax.grad`.\n",
 "\n",
-"In this next example we will use the `.split` method to split the model into `params: State` and `moduledef: ModuleDef` objects. We pass the `\"params\"` filter to check that the Module's state only contains `Variables` with the `params` collection. Having `params` and `moduledef` it's pretty easy to implement a jitted `train_step` much like you would in Flax or Haiku. `ModuleDef` exposes an `apply` method which accepts some `State` and creates a function that runs the Module's `__call__` method. This function then returns the output of the Module along with the updated state."
+"In this next example we will use the `.split` method to split the model into `params: State` and `static: GraphDef` objects. We pass the `\"params\"` filter to check that the Module's state only contains `Variables` with the `params` collection. Having `params` and `static` it's pretty easy to implement a jitted `train_step` much like you would in Flax or Haiku. `GraphDef` exposes an `apply` method which accepts some `State` and creates a function that runs the Module's `__call__` method. This function then returns the output of the Module along with the updated state."
 ]
 },
 {
```
@@ -344,13 +344,13 @@

```diff
 "metadata": {},
 "outputs": [],
 "source": [
-"params, moduledef = model.split(\"params\")\n",
+"params, static = model.split(\"params\")\n",
 "\n",
 "\n",
 "@jax.jit\n",
 "def train_step(params: nnx.State, x, y):\n",
 "  def loss_fn(params):\n",
-"    logits, _updates = moduledef.apply(params)(x)\n",
+"    logits, _updates = static.apply(params)(x)\n",
 "    return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()\n",
 "\n",
 "  loss, grads = jax.value_and_grad(loss_fn)(params)\n",
```
@@ -420,7 +420,7 @@

```diff
 "outputs": [],
 "source": [
 "state = nnx.TrainState(\n",
-"  apply_fn=moduledef.apply,\n",
+"  apply_fn=static.apply,\n",
 "  params=params,\n",
 "  tx=optax.adam(0.001),\n",
 ")\n",
```