Replies: 2 comments
-
Hey @jlperla! Please take a look at The Flax Philosophy. We avoid the use of combinators as much as possible; the only combinators we have are …
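For context, here is a minimal sketch (not from the original reply) of what that combinator-free style looks like in practice with `flax.nnx`: submodules are declared explicitly in `__init__` and wired with plain Python in `__call__`. The class name `TwoLayer` and the layer sizes are purely illustrative.

```python
import jax
import jax.numpy as jnp
from flax import nnx


class TwoLayer(nnx.Module):
    # Explicit submodules instead of a Sequential-style combinator.
    def __init__(self, din: int, dhidden: int, dout: int, *, rngs: nnx.Rngs):
        self.linear1 = nnx.Linear(din, dhidden, rngs=rngs)
        self.linear2 = nnx.Linear(dhidden, dout, rngs=rngs)

    def __call__(self, x: jax.Array) -> jax.Array:
        # Control flow is plain Python; no combinator needed.
        return self.linear2(nnx.relu(self.linear1(x)))


model = TwoLayer(2, 32, 1, rngs=nnx.Rngs(0))
y = model(jnp.ones((4, 2)))  # shape (4, 1)
```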
-
Thanks @cgarciae, I think I see the philosophy. As a JAX outsider, though, the mental model of performance in JAX is much more opaque than you may realize, so it might be useful to give some more guidance. The other thing is that random number generation in nnx looks cool, but it is unclear to me how to use it properly. Along those lines, can you critique this implementation and point out anything that would lead to hidden sub-optimal performance, incorrect use of random numbers, etc.? If helpful, I could clean this up as an example for the docs if you think that is valuable.

```python
import typing as tp
import jax
import jax.numpy as jnp
from flax import nnx
from flax.nnx.nnx import rnglib
from flax.typing import Dtype, PrecisionLike
class MLP(nnx.Module):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        *,
        width: int,
        depth: int,
        activation: tp.Callable,
        rngs: rnglib.Rngs,
        use_bias: bool = True,
        use_final_bias: bool = True,
        final_activation: tp.Optional[tp.Callable] = None,
        dtype: tp.Optional[Dtype] = None,
        param_dtype: Dtype = jnp.float32,
        precision: PrecisionLike = None,
    ):
        self.in_features = in_features
        self.out_features = out_features
        self.width = width
        self.depth = depth
        self.use_bias = use_bias
        self.use_final_bias = use_final_bias
        self.activation = activation
        self.final_activation = final_activation
        assert depth > 0  # skipping specialization of no hidden layers

        # Build the stack of linear layers with activations interleaved.
        self.layers = []
        self.layers.append(
            nnx.Linear(
                in_features,
                width,
                use_bias=use_bias,
                dtype=dtype,
                param_dtype=param_dtype,
                precision=precision,
                rngs=rngs,
            )
        )
        for i in range(self.depth - 1):
            self.layers.append(
                nnx.Linear(
                    width,
                    width,
                    use_bias=self.use_bias,
                    dtype=dtype,
                    param_dtype=param_dtype,
                    precision=precision,
                    rngs=rngs,
                )
            )
            self.layers.append(self.activation)
        self.layers.append(
            nnx.Linear(
                width,
                out_features,
                use_bias=self.use_final_bias,
                dtype=dtype,
                param_dtype=param_dtype,
                precision=precision,
                rngs=rngs,
            )
        )
        if self.final_activation is not None:
            self.layers.append(self.final_activation)

    def __call__(self, x: jax.Array) -> jax.Array:
        for layer in self.layers:
            x = layer(x)
        return x


if __name__ == "__main__":
    rngs = nnx.Rngs(0)
    n_in = 2
    n_out = 1
    depth = 3
    width = 128
    activation = nnx.relu
    model = MLP(n_in, n_out, width=width, depth=depth, activation=activation, rngs=rngs)

    # NOT SURE HOW THE nnx rngs can be split, etc.?
    x = jax.random.normal(rngs.next(), (n_in,))
    model(x)
    random_inputs = jax.random.normal(rngs.next(), (5, n_in))

    @nnx.jit
    def loss(f, batch):
        return jnp.mean(jax.vmap(f)(batch))

    val = loss(model, random_inputs)

    my_batch = jax.random.normal(rngs.next(), (20, n_in))

    @nnx.jit
    def loss_closure(f):
        return jnp.mean(jax.vmap(f)(my_batch))

    loss_val, loss_grad = nnx.value_and_grad(loss_closure)(model)
    print(val)
```
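On the RNG question in the code above, here is a minimal sketch (not part of the original post) of drawing fresh keys from named `nnx.Rngs` streams rather than splitting keys by hand. The stream name `data` is an arbitrary illustrative choice; `params` is the stream that NNX layers use for parameter initialization.

```python
import jax
from flax import nnx

# Named streams: 'params' feeds layer initializers, 'data' is an
# illustrative stream used here only to generate inputs.
rngs = nnx.Rngs(params=0, data=1)

key1 = rngs.data()  # each call returns a fresh key from that stream
key2 = rngs.data()  # a different key; no explicit jax.random.split needed
x = jax.random.normal(key1, (5, 2))
batch = jax.random.normal(key2, (20, 2))
```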
-
Is there a coding pattern (or even a class already in progress) for a flexible MLP-style wrapper in NNX? I'm thinking along the lines of https://docs.kidger.site/equinox/api/nn/mlp/ or https://pytorch.org/vision/main/generated/torchvision.ops.MLP.html, with a few extra features.
To me, at least, the key parameters are:
I don't mind trying to see if an RA can do a PR for this as practice with NNX if you give some hints. But in the meantime, is there a coding pattern which will work well with JAX/NNX?
Note that the PyTorch ones don't quite have these features, but the Equinox one does, after I put in a PR to it.