How to use jax.lax.scan in Flax? #1283
-
(This question was asked by @marcosrdac on #16, but I'm copying it here so it is more easily accessible to users.)
I plan to use Flax in my research on RNNs, but I have been struggling for days to understand some of the ideas behind the Flax implementation. I would really like to see an example of someone using the documented RNN cells; it would help me a lot! Anyway, I tried to use the lambda trick, but it ended up not working; I get an error with code like this:

```python
from typing import Any

import jax
import jax.numpy as jnp
import flax.linen as nn


class LRNNCell(nn.Module):
  @nn.compact
  def __call__(self, h, x):
    nh = h.shape[0]
    Whx = nn.Dense(nh)
    Whh = nn.Dense(nh, use_bias=False)
    Wyh = nn.Dense(1)
    h = nn.tanh(Whx(x) + Whh(h))
    y = nn.tanh(Wyh(h))
    return h, y


class LRNN(nn.Module):
  ny: Any
  nh: Any

  @nn.compact
  def __call__(self, x):
    h = jnp.zeros(self.nh)
    cell = LRNNCell()
    h, y = jax.lax.scan(lambda h, x: cell(h, x), h, x)
    return y[-self.ny:]
```

What am I missing?
-
In Linen, we implemented "lifted" JAX transforms that allow you to apply JAX transformations to Modules directly, so you can define `LRNNCell` using them as follows:

```python
import functools

import flax.linen as nn


class LRNNCell(nn.Module):
  @functools.partial(
      nn.transforms.scan,
      variable_broadcast='params',
      split_rngs={'params': False})
  @nn.compact
  def __call__(self, h, x):
    nh = h.shape[0]
    Whx = nn.Dense(nh)
    Whh = nn.Dense(nh, use_bias=False)
    Wyh = nn.Dense(1)
    h = nn.tanh(Whx(x) + Whh(h))
    y = nn.tanh(Wyh(h))
    return h, y
```

Both our sst2 and seq2seq examples use the `nn.transforms.scan` transform, so those may be helpful for getting a better understanding as well.
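For context, here is a minimal runnable sketch of how such a scanned cell could be initialized and applied. The input shapes, hidden size, and PRNG seed below are illustrative assumptions, not part of the original thread:

```python
import functools

import jax
import jax.numpy as jnp
import flax.linen as nn


class LRNNCell(nn.Module):
  # Same cell as above: nn.transforms.scan lifts jax.lax.scan so that the
  # 'params' collection is created once and broadcast across all scanned steps.
  @functools.partial(
      nn.transforms.scan,
      variable_broadcast='params',
      split_rngs={'params': False})
  @nn.compact
  def __call__(self, h, x):
    nh = h.shape[0]
    Whx = nn.Dense(nh)
    Whh = nn.Dense(nh, use_bias=False)
    Wyh = nn.Dense(1)
    h = nn.tanh(Whx(x) + Whh(h))
    y = nn.tanh(Wyh(h))
    return h, y


# Illustrative shapes (assumptions): 10 time steps, 3 input features, 4 hidden units.
xs = jnp.ones((10, 3))
h0 = jnp.zeros((4,))

cell = LRNNCell()
variables = cell.init(jax.random.PRNGKey(0), h0, xs)  # parameters shared across steps
h_final, ys = cell.apply(variables, h0, xs)  # scans over the leading axis of xs
print(h_final.shape)  # (4,)    final carry
print(ys.shape)       # (10, 1) one output per scanned step
```

The lifted `__call__` keeps the same signature as a `jax.lax.scan` body (carry plus one time step), and the transform threads the carry through the steps and stacks the per-step outputs along the leading axis.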