
Understanding vmap + make_rng #3393

Answered by chiamp
JamesAllingham asked this question in General
With nn.vmap, you could do something like this:

import jax
import jax.numpy as jnp
import flax.linen as nn

class RandomModule(nn.Module):
  @nn.compact
  def __call__(self, x):
    # Draw noise from the 'my_key' RNG stream.
    return jax.random.normal(self.make_rng("my_key"), x.shape)

class VmapRandomModule(nn.Module):
  @nn.compact
  def __call__(self, x):
    # split_rngs={'my_key': True} gives each batch element its own key.
    batch_random_module = nn.vmap(
        RandomModule,
        split_rngs={'my_key': True})
    return batch_random_module()(x)

x = jnp.arange(3)
a = VmapRandomModule()
a.apply({}, x, rngs={'my_key': jax.random.PRNGKey(0)})
# Output: Array([1.6090478 , 0.16792756, 0.1328707 ], dtype=float32)

Also, I think there is a typo in your code, since the variable x isn't defined:

class MyModule(nn.Module):
    @nn.compact
    def __call__(self):
       …

Answer selected by JamesAllingham