-
This is very similar to #2920, where I was trying the same thing but with Dropout. The following toy example, where I'm trying to run a Dense layer several times (with weight sharing) in a while loop, fails:
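For context, a minimal sketch of this kind of setup (a reconstruction rather than the original snippet; the names, shapes, and loop bound are illustrative) would apply a Dense layer inside nn.while_loop without the lifted-transform arguments that the reply below adds:

import jax
from flax import linen as nn


class MyModel(nn.Module):
    num_neurons: int

    @nn.compact
    def __call__(self, x):
        # One Dense layer, meant to be applied with the same weights on every
        # loop iteration.
        dense = nn.Dense(self.num_neurons, name="dense")

        def body(mdl, carry):
            i, x = carry
            return i + 1, dense(x)

        # Naive call: no broadcast_variables / split_rngs arguments.
        _, x = nn.while_loop(
            lambda mdl, carry: carry[0] < 3,
            body,
            self,
            (0, x),
        )
        return x


x = jax.random.normal(jax.random.PRNGKey(0), (3, 4, 4))
variables = MyModel(num_neurons=4).init(jax.random.PRNGKey(1), x)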
It crashes with the following error:
Any tips?
-
Hey @gahdritz, the error is a bit uninformative. Same as with nn.scan, you have to pass broadcast_variables='params' to indicate this collection will be shared, and you should also pass split_rngs for initialization. Here is the working example:
import jax
import jax.numpy as jnp
from flax import linen as nn


class MyModel(nn.Module):
    num_neurons: int

    @nn.compact
    def __call__(self, x, training: bool):
        dense = nn.Dense(self.num_neurons, name="dense")

        # Create the Dense parameters during init, before entering the loop,
        # so the broadcast 'params' collection already exists.
        if self.is_initializing():
            dense(x)

        def fn(mdl, x):
            # The same `dense` (and therefore the same weights) is applied on
            # every iteration of the loop.
            x = dense(x)
            return x

        body = lambda mdl, inp: (
            inp[0] + 1,
            jax.lax.stop_gradient(fn(mdl, inp[1])),
        )

        _, x = nn.while_loop(
            lambda _, x: x[0] < 3,  # run the body three times
            body,
            self,
            (0, x),
            broadcast_variables='params',   # share 'params' across iterations
            split_rngs={'params': False},   # don't split the params rng per step
        )
        return x


root_key = jax.random.PRNGKey(seed=0)
main_key, params_key = jax.random.split(key=root_key, num=2)

my_model = MyModel(num_neurons=4)
x = jax.random.normal(key=main_key, shape=(3, 4, 4))

variables = my_model.init(params_key, x, training=False)
params = variables['params']

y = my_model.apply({'params': params}, x, training=True)
print(y)

One thing that surprised me is that I had to condition on self.is_initializing().
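Presumably that init-time call is needed because the broadcast 'params' collection cannot be created inside the lifted loop body, so the Dense weights have to exist before nn.while_loop runs. As a quick check that the weights really are shared (a single Dense kernel in total, not one per iteration), you can continue the example above and inspect the parameter tree; this shape-printing snippet is an illustrative addition, not part of the original thread:

# Continues from the example above: `params` was returned by my_model.init.
# A single 'dense' entry (kernel + bias) means the same weights are reused on
# every iteration of the while loop.
print(jax.tree_util.tree_map(jnp.shape, params))
# Expected structure: {'dense': {'bias': (4,), 'kernel': (4, 4)}}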