Not sure if this is a bug, but I'm having trouble using
Here, I'm just trying to use dropout repeatedly in a while loop. This crashes with the following (truncated) error:
It seems like the while loop severs some connection between the Module, which has access to the RNG collection, and the dropout module trying to use it. Any suggestions? Is this a bug or am I doing something wrong? If it is a bug, are there any workarounds? If it helps, I don't need to backprop through the while loop in my actual application.
Replies: 2 comments
Hey @gahdritz, the issue is that you need to use the split_rngs argument to decide how the different RNG streams will behave inside nn.while_loop, e.g.

split_rngs={'dropout': True}  # each iteration gets a different key
# or
split_rngs={'dropout': False}  # each iteration gets the same key

Here is the full example, fixed (I also replaced jnp.empty with random values, since I was getting zeros).

import jax
import jax.numpy as jnp
from flax import linen as nn

class MyModel(nn.Module):
  num_neurons: int

  @nn.compact
  def __call__(self, x, training: bool):
    x = nn.Dense(self.num_neurons)(x)

    def fn(mdl, x):
      # jax.debug.print("x: {x}", x=x)
      return nn.Dropout(rate=0.5, deterministic=not training)(x)

    body = lambda mdl, inp: (
      inp[0] + 1,
      fn(mdl, inp[1]),
    )
    num_iter = 3
    _, x = nn.while_loop(
      lambda _, x: x[0] < num_iter,
      body,
      self,
      (0, x),
      split_rngs={'dropout': False},
    )
    return x

root_key = jax.random.PRNGKey(seed=0)
main_key, params_key, dropout_key = jax.random.split(key=root_key, num=3)
my_model = MyModel(num_neurons=3)
x = jax.random.uniform(main_key, shape=(3, 4, 4))
# Dropout is disabled with `training=False` (that is, `deterministic=True`).
variables = my_model.init(params_key, x, training=False)
params = variables['params']
# Dropout is enabled with `training=True` (that is, `deterministic=False`).
y = my_model.apply({'params': params}, x, training=True, rngs={'dropout': dropout_key})
print(y)
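
For intuition about what the two split_rngs settings mean, here is a rough sketch in plain JAX, without Flax. It is not how nn.while_loop is implemented internally; the helper names dropout and loop_dropout and the split flag are made up for this illustration. It only shows the difference between reusing one key on every iteration and deriving a fresh key per iteration, and it may also serve as a workaround if you don't need Flax to manage the RNG inside the loop.

```python
import jax
import jax.numpy as jnp


def dropout(key, x, rate):
  # Hypothetical helper: inverted dropout. Zero each entry with
  # probability `rate` and rescale the kept entries by 1 / (1 - rate).
  keep = jax.random.bernoulli(key, 1.0 - rate, x.shape)
  return jnp.where(keep, x / (1.0 - rate), 0.0)


def loop_dropout(key, x, num_iter, rate, split):
  # Hypothetical helper: apply dropout `num_iter` times in a while loop.
  # `split` is a Python bool, so the branch below is resolved at trace time.
  def body(carry):
    i, key, x = carry
    if split:
      # Fresh subkey on every iteration: a different mask each time.
      key, subkey = jax.random.split(key)
    else:
      # Same key on every iteration: the same mask each time.
      subkey = key
    return i + 1, key, dropout(subkey, x, rate)

  def cond(carry):
    return carry[0] < num_iter

  _, _, x = jax.lax.while_loop(cond, body, (0, key, x))
  return x


key = jax.random.PRNGKey(0)
x = jnp.ones((4, 4))
shared = loop_dropout(key, x, num_iter=3, rate=0.5, split=False)
fresh = loop_dropout(key, x, num_iter=3, rate=0.5, split=True)
```

With split=False the same mask is applied three times, so the surviving entries are exactly the ones a single dropout call with that key would keep. With split=True each iteration draws a new mask, so an entry survives only if every mask keeps it.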