How to use jax.lax.scan in Flax? #1283

Answered by marcvanzee
marcvanzee asked this question in Q&A

In Linen, we implemented "lifted" JAX transforms that let you apply JAX transformations such as `jax.lax.scan` to Modules directly, so you can define `LRNNCell` with `nn.transforms.scan` as follows:

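```python
import functools

import flax.linen as nn


class LRNNCell(nn.Module):
  # Lift jax.lax.scan over __call__: `h` is the carry, `x` is scanned over axis 0.
  @functools.partial(
      nn.transforms.scan,
      variable_broadcast='params',   # share one set of params across all steps
      split_rngs={'params': False})  # don't split the params RNG per step
  @nn.compact
  def __call__(self, h, x):
      nh = h.shape[-1]  # hidden size (last axis, so a batched carry also works)
      Whx = nn.Dense(nh)
      Whh = nn.Dense(nh, use_bias=False)
      Wyh = nn.Dense(1)

      h = nn.tanh(Whx(x) + Whh(h))
      y = nn.tanh(Wyh(h))
      return h, y  # (new carry, per-step output stacked along axis 0)
```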

Both our sst2 and seq2seq examples use the `nn.transforms.scan` transform, so they may also help you get a better understanding.

Answer selected by marcvanzee