Skip to content

Commit

Permalink
exclude nnx from mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Nov 7, 2023
1 parent 467162e commit c142e79
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 11 deletions.
23 changes: 13 additions & 10 deletions flax/experimental/nnx/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

_**N**eural **N**etworks for JA**X**_

NNX is a Neural Networks library for JAX that provides a simple yet powerful module system that adheres to standard Python semantics. Its aim is to combine the robustness of [Flax](https://flax.readthedocs.io/en/latest/) with a simplified, Pythonic API akin to that of [PyTorch](https://pytorch.org/).
NNX is a Neural Networks library for JAX that provides a simple yet powerful module system that adheres to standard Python semantics. Its aim is to combine the robustness of Linen with a simplified, Pythonic API akin to that of [PyTorch](https://pytorch.org/).

* **Pythonic**: Modules are just regular python classes, they contain their own state, are fully mutable, and allow sharing references between Modules.
* **Compatible**: Easily convert back and forth between Modules and pytrees using the Functional API to integrate with any JAX API.
Expand Down Expand Up @@ -35,22 +35,25 @@ from flax.experimental import nnx
import jax
import jax.numpy as jnp

class Count(nnx.Variable): pass
class Count(nnx.Variable): pass # typed Variable collections

class Linear(nnx.Module):
def __init__(self, din, dout, *, rngs: nnx.Rngs): # <--- explicit RNG management
key = rngs() # <-------------------------------------------------|
self.w = nnx.Param(jax.random.uniform(key, (din, dout))) # typed Variable collections
self.b = nnx.Param(jnp.zeros((dout,))) # <--------------------------|
self.count = Count(0) # <------------------------------------------|
def __init__(self, din, dout, *, rngs: nnx.Rngs): # explicit RNG management
key = rngs()
# put dynamic state in Variable types
self.w = nnx.Param(jax.random.uniform(key, (din, dout)))
self.b = nnx.Param(jnp.zeros((dout,)))
self.count = Count(0)
# other types as treated as static
self.din = din
self.dout = dout

def __call__(self, x):
self.count += 1 # inplace stateful updates
return x @ self.w + self.b

model = Linear(din=12, dout=2, rngs=nnx.Rngs(0)) # no need for `init`

y = model(jnp.ones((8, 12))) # call module directly
model = Linear(din=12, dout=2, rngs=nnx.Rngs(0)) # no special `init` method
y = model(jnp.ones((8, 12))) # call methods directly

assert model.count == 1
```
Expand Down
2 changes: 1 addition & 1 deletion flax/experimental/nnx/docs/why.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@
"\n",
"\n",
"model = CounterLinear(4, 4, rngs=nnx.Rngs(0)) # no special `init` method\n",
"y = model(jnp.ones((2, 4))) # call module directly\n",
"y = model(jnp.ones((2, 4))) # call methods directly\n",
"\n",
"print(f'{model = }')"
]
Expand Down
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,10 @@ module = [
"yaml",
]
ignore_missing_imports = true
# exclude nnx
[[tool.mypy.overrides]]
module = "flax.experimental.nnx.*"
ignore_errors = true

[tool.pytest.ini_options]
filterwarnings = [
Expand Down

0 comments on commit c142e79

Please sign in to comment.