replace Context => Rngs
cgarciae committed Oct 28, 2023
1 parent db0638d commit d5d34f7
Showing 39 changed files with 2,696 additions and 2,718 deletions.
20 changes: 10 additions & 10 deletions docs/experimental/nnx/[REMOVE]why_nnx_examples/ex1_nnx.py
@@ -1,18 +1,18 @@
from flax.experimental import nnx
import jax
import jax.numpy as jnp

from flax.experimental import nnx


class Count(nnx.Variable):
pass


class Linear(nnx.Module):

def __init__(self, din: int, dout: int, *, ctx: nnx.Context):
def __init__(self, din: int, dout: int, *, ctx: nnx.Rngs):
self.din = din
self.dout = dout
key = ctx.make_rng("params")
key = ctx.make_rng('params')
self.w = nnx.Param(jax.random.uniform(key, (din, dout)))
self.b = nnx.Param(jnp.zeros((dout,)))
self.count = Count(0) # track the number of calls
@@ -22,12 +22,12 @@ def __call__(self, x) -> jax.Array:
return x @ self.w + self.b


model = Linear(din=5, dout=2, ctx=nnx.context(0))
model = Linear(din=5, dout=2, ctx=nnx.Rngs(0))
x = jnp.ones((1, 5))
y = model(x)

print("\n NNX")
print(f"{model.count = }")
print(f"{model.w = }")
print(f"{model.b = }")
print(f"{model = }")
print('\n NNX')
print(f'{model.count = }')
print(f'{model.w = }')
print(f'{model.b = }')
print(f'{model = }')
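Pieced together from the visible hunks, `ex1_nnx.py` after this commit reads roughly as follows. This is a sketch, not the file itself: the collapsed body of `__call__` (the `self.count += 1` update) is assumed from the README version of the same example, and `ctx.make_rng('params')` is kept exactly as the unchanged diff line shows it, assuming `nnx.Rngs` still exposes `make_rng`.

```python
import jax
import jax.numpy as jnp

from flax.experimental import nnx


class Count(nnx.Variable):
  pass


class Linear(nnx.Module):

  def __init__(self, din: int, dout: int, *, ctx: nnx.Rngs):
    self.din = din
    self.dout = dout
    key = ctx.make_rng('params')  # unchanged line; assumes nnx.Rngs keeps make_rng
    self.w = nnx.Param(jax.random.uniform(key, (din, dout)))
    self.b = nnx.Param(jnp.zeros((dout,)))
    self.count = Count(0)  # track the number of calls

  def __call__(self, x) -> jax.Array:
    self.count += 1  # assumed from the README version of this example
    return x @ self.w + self.b


model = Linear(din=5, dout=2, ctx=nnx.Rngs(0))
x = jnp.ones((1, 5))
y = model(x)

print('\n NNX')
print(f'{model.count = }')
print(f'{model.w = }')
print(f'{model.b = }')
print(f'{model = }')
```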
13 changes: 6 additions & 7 deletions docs/experimental/nnx/[REMOVE]why_nnx_examples/ex2_nnx.py
@@ -1,14 +1,13 @@
from ex1_nnx import Linear, nnx, jnp, jax

# ----------------
# example begin
# ----------------
import numpy as np
from ex1_nnx import Linear, jax, jnp, nnx

X = np.random.uniform(size=(1000, 1))
Y = 0.8 * X + 0.4 + np.random.normal(scale=0.1, size=(1000, 1))

model = Linear(1, 1, ctx=nnx.context(0))
model = Linear(1, 1, ctx=nnx.Rngs(0))


@nnx.jit
@@ -38,8 +37,8 @@ def eval_step(model: Linear, x, y):
if step % 100 == 0:
loss = eval_step(model, X, Y)

print(f"Step {step}: loss={loss:.4f}")
print(f'Step {step}: loss={loss:.4f}')

print(f"\n{model.w = }")
print(f"{model.b = }")
print(f"{model.count = }")
print(f'\n{model.w = }')
print(f'{model.b = }')
print(f'{model.count = }')
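The bodies of the training and evaluation functions are collapsed in this diff; only the `@nnx.jit` decorator and the `eval_step(model: Linear, x, y)` signature are visible. As a hypothetical illustration only, an `eval_step` consistent with the printed `loss={loss:.4f}` would compute a mean squared error; the body below is an assumption, not the file's actual code.

```python
def eval_step(model: Linear, x, y):
  # hypothetical body: mean squared error between predictions and targets
  y_pred = model(x)
  return jnp.mean((y - y_pred) ** 2)
```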
27 changes: 13 additions & 14 deletions docs/experimental/nnx/[REMOVE]why_nnx_examples/ex3_nnx.py
@@ -1,15 +1,15 @@
from flax.experimental import nnx
import jax

from flax.experimental import nnx

class Block(nnx.Module):

def __init__(self, dim: int, *, ctx: nnx.Context):
class Block(nnx.Module):
def __init__(self, dim: int, *, ctx: nnx.Rngs):
self.linear = nnx.Linear(dim, dim, ctx=ctx)
self.bn = nnx.BatchNorm(dim, ctx=ctx)
self.dropout = nnx.Dropout(0.5)

def __call__(self, x: jax.Array, *, ctx: nnx.Context) -> jax.Array:
def __call__(self, x: jax.Array, *, ctx: nnx.Rngs) -> jax.Array:
x = self.linear(x)
x = self.bn(x, ctx=ctx)
x = self.dropout(x, ctx=ctx)
@@ -21,17 +21,16 @@ def __call__(self, x: jax.Array, *, ctx: nnx.Context) -> jax.Array:


class ScanMLP(nnx.Module):

def __init__(self, dim: int, *, n_layers: int, ctx: nnx.Context):
def __init__(self, dim: int, *, n_layers: int, ctx: nnx.Rngs):
self.n_layers = n_layers
# partition Context and split the `params` key
keys, ctxdef = ctx.partition()
params_key = jax.random.split(keys["params"], n_layers)
params_key = jax.random.split(keys['params'], n_layers)

@partial(jax.vmap, out_axes=((0, None), None))
def create_block(params_key):
# merge back Context using the sliced `params` key
ctx = ctxdef.merge({"params": params_key})
ctx = ctxdef.merge({'params': params_key})
# create Block instance and return its partitions
return Block(dim, ctx=ctx).split(nnx.Param, nnx.BatchStat)

@@ -40,22 +39,22 @@ def create_block(params_key):
# merge to get a lifted Block instance
self.layers = moduledef.merge(params, batch_stats)

def __call__(self, x: jax.Array, *, ctx: nnx.Context):
def __call__(self, x: jax.Array, *, ctx: nnx.Rngs):
# partition Context and split the `dropout` key
keys, ctxdef = ctx.partition()
dropout_key = jax.random.split(keys["dropout"], self.n_layers)
dropout_key = jax.random.split(keys['dropout'], self.n_layers)
# partition Module to get params + batch_stats
(params, batch_stats), moduledef = self.layers.split(
nnx.Param, nnx.BatchStat
nnx.Param, nnx.BatchStat
)

def scan_fn(
carry: tuple[jax.Array, nnx.State], inputs: tuple[nnx.State, jax.Array]
carry: tuple[jax.Array, nnx.State], inputs: tuple[nnx.State, jax.Array]
):
(x, batch_stats), (params, dropout_key) = carry, inputs
# merge back Module and Context
module = moduledef.merge(params, batch_stats)
ctx = ctxdef.merge({"dropout": dropout_key})
ctx = ctxdef.merge({'dropout': dropout_key})
# forward pass
x = module(x, ctx=ctx)
# partition state and return
@@ -64,7 +63,7 @@ def scan_fn(

# call scan passing (x, batch_stats) as the carry, and (params, dropout_key) as the input
(x, batch_stats), params = jax.lax.scan(
scan_fn, (x, batch_stats), (params, dropout_key)
scan_fn, (x, batch_stats), (params, dropout_key)
)
# update layers state and return
self.layers.update((params, batch_stats))
19 changes: 11 additions & 8 deletions docs/experimental/nnx/[REMOVE]why_nnx_examples/ex4_nnx.py
@@ -1,20 +1,23 @@
import jax.numpy as jnp
import jax
import jax.numpy as jnp

#######
from flax.experimental import nnx


def load_pretrained_model():
ctx = nnx.context(0)
model = nnx.Sequence([
ctx = nnx.Rngs(0)
model = nnx.Sequence(
[
lambda x: x.reshape((x.shape[0], -1)),
nnx.Linear(784, 1024, ctx=ctx),
])
]
)
return model

class Classifier(nnx.Module):

def __init__(self, backbone: nnx.Sequence, *, ctx: nnx.Context):
class Classifier(nnx.Module):
def __init__(self, backbone: nnx.Sequence, *, ctx: nnx.Rngs):
self.backbone = backbone
self.head = nnx.Linear(1024, 10, ctx=ctx)

@@ -25,9 +28,9 @@ def __call__(self, x: jax.Array):


pretrained_model = load_pretrained_model()
model = Classifier(pretrained_model, ctx=nnx.context(0))
model = Classifier(pretrained_model, ctx=nnx.Rngs(0))

# forward pass
y = model(jnp.ones((1, 28, 28)))

print("state =", jax.tree_map(jnp.shape, model.get_state()))
print('state =', jax.tree_map(jnp.shape, model.get_state()))
78 changes: 39 additions & 39 deletions flax/experimental/nnx/README.md
@@ -38,8 +38,8 @@ import jax.numpy as jnp
class Count(nnx.Variable): pass

class Linear(nnx.Module):
def __init__(self, din: int, dout: int, *, ctx: nnx.Context):
key = ctx.rngs.params()
def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
key = rngs.params()
self.w = nnx.Param(jax.random.uniform(key, (din, dout)))
self.b = nnx.Param(jnp.zeros((dout,)))
self.count = Count(0) # track the number of calls
@@ -48,15 +48,15 @@ class Linear(nnx.Module):
self.count += 1
return x @ self.w + self.b

model = Linear(din=12, dout=2, ctx=nnx.context(0))
model = Linear(din=12, dout=2, rngs=nnx.Rngs(0))

# Forward pass and verify the call count
x = jnp.ones((8, 12))
y = model(x)
assert model.count == 1
```

In this example `nnx.context(0)` creates a `PRNGKey` for `params` with seed `0`; this is used by `ctx.rngs.<rng-name>()` inside `__init__` to generate a random key to initialize the parameters.
In this example `nnx.Rngs(0)` creates a `PRNGKey` for `params` with seed `0`; this is used by `rngs.<rng-name>()` inside `__init__` to generate a random key to initialize the parameters.
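A minimal sketch of the new `Rngs` object, based only on usages that appear elsewhere in this diff (`nnx.Rngs(0)`, `nnx.Rngs(params=key)`, `nnx.Rngs(dropout=key)`, `rngs.params()`, `rngs.dropout()`); it assumes `jax` is imported as in the snippets above:

```python
# integer seed, as in the example above
rngs = nnx.Rngs(0)
params_key = rngs.params()  # draw a key from the `params` stream

# explicit per-stream keys, as in the ScanMLP example further down
rngs = nnx.Rngs(params=jax.random.PRNGKey(0))
w_key = rngs.params()

dropout_rngs = nnx.Rngs(dropout=jax.random.PRNGKey(1))
dropout_key = dropout_rngs.dropout()
```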

### Training with the Functional API

@@ -150,18 +150,18 @@ NNX Modules are normal python classes, they obey regular python semantics such a

```python
class Foo(nnx.Module):
def __init__(self, ctx: nnx.Context):
def __init__(self, rngs: nnx.Rngs):
# node attributes
self.variable = nnx.Param(jnp.array(1))
self.implicit_param = jnp.array(3)
self.submodule = nnx.Linear(2, 4, ctx=ctx)
self.submodule = nnx.Linear(2, 4, rngs=rngs)
# static attributes
self.int = 1
self.float = 2.0
self.str = "hello"
self.list = [1, 2, 3]

model = Foo(ctx=nnx.context(0))
model = Foo(rngs=nnx.Rngs(0))
```
As shown above, Python container types such as `list`, `tuple`, and `dict` are treated as static attributes; if similar functionality is needed, NNX provides the `Sequence` and `Dict` Modules.
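For example, a sketch mirroring the `nnx.Sequence` usage in `ex4_nnx.py` from this same commit (the layer sizes here are illustrative, and `nnx.Dict` is assumed to be the analogous container for dict-valued attributes):

```python
# a plain `list` attribute would be treated as static; wrapping the callables in
# nnx.Sequence registers the Linear layers as proper sub-modules
model = nnx.Sequence(
    [
        lambda x: x.reshape((x.shape[0], -1)),  # flatten
        nnx.Linear(784, 1024, rngs=nnx.Rngs(0)),
        jax.nn.relu,
        nnx.Linear(1024, 10, rngs=nnx.Rngs(1)),
    ]
)
```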

@@ -275,8 +275,8 @@ Here is an example of how to create a `Linear` module that captures its output i

```python
class Linear(nnx.Module):
def __init__(self, din: int, dout: int, *, ctx: nnx.Context):
key = ctx.rngs.params()
def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
key = rngs.params()
self.w = nnx.Param(jax.random.uniform(key, (din, dout)))
self.b = nnx.Param(jnp.zeros((dout,)))

Expand All @@ -285,7 +285,7 @@ class Linear(nnx.Module):
self.y = nnx.Intermediate(y)
return y

model = Linear(12, 2, ctx=nnx.context(0))
model = Linear(12, 2, rngs=nnx.Rngs(0))
```
Since `y` is only created when the module is called, it is not available upon initialization. However, once you call the module, `y` will be created. It is recommended that you use `pop` to retrieve temporary collections like `Intermediate`:
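The snippet that follows this sentence is collapsed in the diff. As a hypothetical illustration only — the sentence above gives the method name `pop`, while the call signature below is an assumption:

```python
y = model(jnp.ones((1, 12)))
# assumed signature: pop the temporary Intermediate collection out of the module's state
intermediates = model.pop(nnx.Intermediate)
```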

@@ -317,7 +317,7 @@ Alternatively, you can use `State.extract` to retrieve the `Intermediate` nodes

NNX lifted transforms are analogous versions of JAX transforms, but they know how to work with Modules. They usually perform the following tasks:

* Handle the Module's substates and Context's RNG streams according to the transform's semantics.
* Handle the Module's substates and Rngs's RNG streams according to the transform's semantics.
* Properly propagate state in and out of the transform, including updating the input Module's state with updates that happen inside the transform (see the short sketch after this list).
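As a short sketch of these two points, reusing the first `Linear` definition above (the one that tracks `count`) and the `nnx.jit` lifted transform that appears in `ex2_nnx.py` of this commit — the propagation of the `count` update is exactly the behavior the second bullet describes:

```python
model = Linear(din=12, dout=2, rngs=nnx.Rngs(0))
assert model.count == 0

@nnx.jit
def forward(model: Linear, x):
  return model(x)  # increments model.count inside the transform

y = forward(model, jnp.ones((8, 12)))
assert model.count == 1  # the state update propagated back out of the transform
```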

Here's a diagram illustrating how lifted transformations work:
@@ -335,13 +335,13 @@ Here we will create an example of how to implement an MLP that uses "scan over l

```python
class Block(nnx.Module):
def __init__(self, dim: int, *, ctx: nnx.Context):
self.linear = nnx.Linear(dim, dim, ctx=ctx)
def __init__(self, dim: int, *, rngs: nnx.Rngs):
self.linear = nnx.Linear(dim, dim, rngs=rngs)
self.dropout = nnx.Dropout(0.5)

def __call__(self, x: jax.Array, *, train: bool, ctx: nnx.Context) -> jax.Array:
def __call__(self, x: jax.Array, *, train: bool, rngs: nnx.Rngs) -> jax.Array:
x = self.linear(x)
x = self.dropout(x, deterministic=not train, ctx=ctx)
x = self.dropout(x, deterministic=not train, rngs=rngs)
x = jax.nn.gelu(x)
return x
```
@@ -350,30 +350,30 @@ Now we will define `ScanMLP`. During `__init__`, instead of creating a list of `

```python
class ScanMLP(nnx.Module):
def __init__(self, dim: int, *, n_layers: int, ctx: nnx.Context):
params_key = jax.random.split(ctx.rngs.params(), n_layers)
def __init__(self, dim: int, *, n_layers: int, rngs: nnx.Rngs):
params_key = jax.random.split(rngs.params(), n_layers)
self.n_layers = n_layers
state, moduledef = jax.vmap(
lambda key: Block(dim, ctx=nnx.context(params=key)).split()
lambda key: Block(dim, rngs=nnx.Rngs(params=key)).split()
)(params_key)
self.layers = moduledef.merge(state)

```
Note that we split the `params` key into `n_layers` keys so each layer has different parameters.
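Concretely (a plain-JAX sketch, independent of NNX):

```python
# the leading axis of `params_key` is what jax.vmap maps over,
# so each Block is constructed with a different key
n_layers = 5
params_key = jax.random.split(jax.random.PRNGKey(0), n_layers)
print(params_key.shape)  # (5, 2) with the default (non-typed) PRNG keys
```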

Now we will define `__call__`. Here we need to split the `dropout` key into `n_layers` keys so each layer has a different dropout mask, and `split` the layers to get their `params`. Both `params` and `dropout_key` will be passed as inputs; `x` will be the carry value. Inside the `scan_fn` we will merge the `params` back into a `Block` module and
apply it to the input `x`, passing the sliced `dropout_key` as part of the `Context`.
apply it to the input `x`, passing the sliced `dropout_key` as part of the `Rngs`.


```python
def __call__(self, x: jax.Array, *, train: bool, ctx: nnx.Context) -> jax.Array:
dropout_key = jax.random.split(ctx.rngs.dropout(), self.n_layers)
def __call__(self, x: jax.Array, *, train: bool, rngs: nnx.Rngs) -> jax.Array:
dropout_key = jax.random.split(rngs.dropout(), self.n_layers)
params, moduledef = self.layers.split(nnx.Param)

def scan_fn(x, inputs):
params, dropout_key = inputs
module = moduledef.merge(params)
x = module(x, train=train, ctx=nnx.context(dropout=dropout_key))
x = module(x, train=train, rngs=nnx.Rngs(dropout=dropout_key))
return x, module.extract(nnx.Param)

x, params = jax.lax.scan(scan_fn, x, (params, dropout_key))
@@ -385,10 +385,10 @@ Finally we apply `jax.lax.scan`, update the `layers` state with the new `params`
Here is a simple way to test our `ScanMLP`:

```python
model = ScanMLP(10, n_layers=5, ctx=nnx.context(0))
model = ScanMLP(10, n_layers=5, rngs=nnx.Rngs(0))

x = jnp.ones((3, 10))
y = model(x, train=True, ctx=nnx.context(dropout=1))
y = model(x, train=True, rngs=nnx.Rngs(dropout=1))
```

For a more robust implementation with comments take a look at the [Scan over layers](https://github.com/cgarciae/nnx/blob/main/examples/06_scan_over_layers.py) example.
@@ -402,36 +402,36 @@ Here's an example of creating a module with shared state:

```python
class Block(nnx.Module):
def __init__(self, linear: nnx.Linear, *, ctx: nnx.Context):
def __init__(self, linear: nnx.Linear, *, rngs: nnx.Rngs):
self.linear = linear
self.bn = nnx.BatchNorm(2, ctx=ctx)
self.bn = nnx.BatchNorm(2, rngs=rngs)

def __call__(self, x, *, ctx: nnx.Context):
def __call__(self, x, *, rngs: nnx.Rngs):
x = self.linear(x)
x = self.bn(x, ctx=ctx)
x = self.bn(x, rngs=rngs)
x = nnx.relu(x)
return x

class Model(nnx.Module):
def __init__(self, *, ctx: nnx.Context):
shared = nnx.Linear(2, 2, ctx=ctx)
self.block1 = Block(shared, ctx=ctx)
self.block2 = Block(shared, ctx=ctx)

def __call__(self, x, *, ctx: nnx.Context):
x = self.block1(x, ctx=ctx)
x = self.block2(x, ctx=ctx)
def __init__(self, *, rngs: nnx.Rngs):
shared = nnx.Linear(2, 2, rngs=rngs)
self.block1 = Block(shared, rngs=rngs)
self.block2 = Block(shared, rngs=rngs)

def __call__(self, x):
x = self.block1(x)
x = self.block2(x)
return x
```

In this example, the `Model` module contains two instances of the `Block` module. Each instance shares the same `nnx.Linear` module. To run the model, you can use the Context `flags` argument to set the `use_running_average` flag for all `BatchNorm` modules.
In this example, the `Model` module contains two instances of the `Block` module. Each instance shares the same `nnx.Linear` module. To run the model, you can use the `nnx.flags` context manager to set the `use_running_average` flag for all `BatchNorm` modules.

Here's an example of computing the loss for a `Model` instance:

```python
def loss_fn(model: Model, x: jax.Array, y: jax.Array):
ctx = nnx.context(flags=dict(use_running_average=True))
y_pred = model(x, ctx=ctx)
with nnx.flags(use_running_average=True):
y_pred = model(x)
return jnp.mean((y - y_pred) ** 2)
```
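The README hunk ends here. As a rough sketch of how such a loss could be wired into a gradient step using only the functional pieces visible earlier in this file (`split`, `merge`, `update`), under the assumption that `State` objects behave as JAX pytrees (they are carried through `jax.lax.scan` above, which suggests they do):

```python
# split the stateful Model into pure pytrees plus a static definition
(params, batch_stats), moduledef = model.split(nnx.Param, nnx.BatchStat)

def pure_loss(params, x, y):
  # rebuild a Model from the (possibly traced) params and the frozen batch_stats
  model = moduledef.merge(params, batch_stats)
  with nnx.flags(use_running_average=True):
    y_pred = model(x)
  return jnp.mean((y - y_pred) ** 2)

grads = jax.grad(pure_loss)(params, x, y)
# plain SGD over the Param leaves, then write the new state back into the Model
params = jax.tree_map(lambda p, g: p - 1e-3 * g, params, grads)
model.update((params, batch_stats))
```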
