diff --git a/flax/experimental/nnx/README.md b/flax/experimental/nnx/README.md
index a142bd12ca..54e4a5e08a 100644
--- a/flax/experimental/nnx/README.md
+++ b/flax/experimental/nnx/README.md
@@ -4,7 +4,7 @@
 
 _**N**eural **N**etworks for JA**X**_
 
-NNX is a Neural Networks library for JAX that provides a simple yet powerful module system that adheres to standard Python semantics. Its aim is to combine the robustness of [Flax](https://flax.readthedocs.io/en/latest/) with a simplified, Pythonic API akin to that of [PyTorch](https://pytorch.org/).
+NNX is a Neural Networks library for JAX that provides a simple yet powerful module system that adheres to standard Python semantics. Its aim is to combine the robustness of Linen with a simplified, Pythonic API akin to that of [PyTorch](https://pytorch.org/).
 
 * **Pythonic**: Modules are just regular python classes, they contain their own state, are fully mutable, and allow sharing references between Modules.
 * **Compatible**: Easily convert back and forth between Modules and pytrees using the Functional API to integrate with any JAX API.
@@ -35,22 +35,25 @@ from flax.experimental import nnx
 import jax
 import jax.numpy as jnp
 
-class Count(nnx.Variable): pass
+class Count(nnx.Variable): pass  # typed Variable collections
 
 class Linear(nnx.Module):
-  def __init__(self, din, dout, *, rngs: nnx.Rngs): # <--- explicit RNG management
-    key = rngs() # <-------------------------------------------------|
-    self.w = nnx.Param(jax.random.uniform(key, (din, dout))) # typed Variable collections
-    self.b = nnx.Param(jnp.zeros((dout,))) # <--------------------------|
-    self.count = Count(0) # <------------------------------------------|
+  def __init__(self, din, dout, *, rngs: nnx.Rngs):  # explicit RNG management
+    key = rngs()
+    # put dynamic state in Variable types
+    self.w = nnx.Param(jax.random.uniform(key, (din, dout)))
+    self.b = nnx.Param(jnp.zeros((dout,)))
+    self.count = Count(0)
+    # other types are treated as static
+    self.din = din
+    self.dout = dout
 
   def __call__(self, x):
     self.count += 1  # inplace stateful updates
     return x @ self.w + self.b
 
-model = Linear(din=12, dout=2, rngs=nnx.Rngs(0)) # no need for `init`
-
-y = model(jnp.ones((8, 12))) # call module directly
+model = Linear(din=12, dout=2, rngs=nnx.Rngs(0))  # no special `init` method
+y = model(jnp.ones((8, 12)))  # call methods directly
 
 assert model.count == 1
 ```
diff --git a/flax/experimental/nnx/docs/why.ipynb b/flax/experimental/nnx/docs/why.ipynb
index 67f6b0d552..01b9c394a6 100644
--- a/flax/experimental/nnx/docs/why.ipynb
+++ b/flax/experimental/nnx/docs/why.ipynb
@@ -90,7 +90,7 @@
     "\n",
     "\n",
     "model = CounterLinear(4, 4, rngs=nnx.Rngs(0))  # no special `init` method\n",
-    "y = model(jnp.ones((2, 4)))  # call module directly\n",
+    "y = model(jnp.ones((2, 4)))  # call methods directly\n",
     "\n",
     "print(f'{model = }')"
   ]
diff --git a/pyproject.toml b/pyproject.toml
index 14c7f31ede..4ef5eb1218 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -112,6 +112,10 @@ module = [
   "yaml",
 ]
 ignore_missing_imports = true
+# exclude nnx from mypy checks
+[[tool.mypy.overrides]]
+module = "flax.experimental.nnx.*"
+ignore_errors = true
 
 [tool.pytest.ini_options]
 filterwarnings = [
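
The "Compatible" bullet in the README hunk above refers to NNX's Functional API, which converts a stateful `Module` to and from pytrees so it can cross `jax.jit` and other JAX transforms. The sketch below illustrates the idea; it is written against the `nnx.split` / `nnx.merge` / `nnx.update` spelling and explicit `.value` access of the stable `flax.nnx` API, so the exact names may differ from the experimental module touched by this PR:

```python
import jax
import jax.numpy as jnp
from flax import nnx  # assumption: stable API; this PR targets flax.experimental.nnx


class Count(nnx.Variable):  # typed Variable collection, as in the README snippet
  pass


class Linear(nnx.Module):
  def __init__(self, din, dout, *, rngs: nnx.Rngs):
    key = rngs.params()
    # dynamic state lives in Variable types ...
    self.w = nnx.Param(jax.random.uniform(key, (din, dout)))
    self.b = nnx.Param(jnp.zeros((dout,)))
    self.count = Count(jnp.array(0))
    # ... while plain attributes are treated as static
    self.din = din
    self.dout = dout

  def __call__(self, x):
    self.count.value += 1  # stable API uses .value; experimental nnx auto-unwraps
    return x @ self.w.value + self.b.value


model = Linear(din=12, dout=2, rngs=nnx.Rngs(0))

# split: Module -> (static graph definition, pytree of Variable state)
graphdef, state = nnx.split(model)


@jax.jit  # state is an ordinary pytree, so it crosses the jit boundary directly
def forward(graphdef, state, x):
  model = nnx.merge(graphdef, state)  # merge: pytree -> Module
  y = model(x)                        # the stateful call mutates `count`
  _, state = nnx.split(model)         # propagate the mutation out explicitly
  return y, state


y, state = forward(graphdef, state, jnp.ones((8, 12)))
nnx.update(model, state)  # write the updated state back into the original object
assert model.count.value == 1
```

Because `state` is a plain pytree, the same split/call/merge pattern composes with `jax.grad`, `jax.vmap`, and any other JAX API.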