Force no split in make_rng
#3113
Unanswered
zaccharieramzi asked this question in Q&A
Replies: 0 comments
I have the following situation: I am using a `Dropout` layer multiple times without an `nn.scan` or `nn.while_loop`, so I cannot use `split_rngs={"dropout": False}`. However, I would still like to use the same dropout mask twice.

Is it possible to specify "no split" to `make_rng` for certain collections?
If I just take the original dropout example, I would like to do something like:

and still have `jnp.sum(y == 0.) / (3*4*3) == 0.5`, approximately.

For more context, I am actually trying to implement Deep Equilibrium Models using `jaxopt` and `flax`, where the fixed-point defining function uses dropout.

I also tried to see if the `split_rngs` functionality could be extended to `jaxopt`, but I think it's going to be difficult.
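
To illustrate the behaviour being asked for, here is a minimal sketch in plain JAX, outside of Flax's `make_rng` machinery. It is not the snippet from the dropout example: `dropout` is a hypothetical helper (not Flax's `nn.Dropout`), and the sketch assumes inverted dropout with rate 0.5 on an all-ones input of shape `(3, 4, 3)`. Passing the same key to both calls is what stands in for "no split".

```python
import jax
import jax.numpy as jnp


def dropout(x, key, rate=0.5):
    """Inverted dropout driven by an explicit PRNG key.

    Hypothetical helper for illustration only; this is not Flax's nn.Dropout.
    """
    keep_prob = 1.0 - rate
    # The mask depends only on (key, shape), so reusing the key reuses the mask.
    mask = jax.random.bernoulli(key, p=keep_prob, shape=x.shape)
    return jnp.where(mask, x / keep_prob, 0.0)


key = jax.random.PRNGKey(0)
x = jnp.ones((3, 4, 3))

# One application of dropout.
y_once = dropout(x, key)

# Two applications with the SAME key: the second call zeroes exactly the
# entries the first one already zeroed, so the set of zeros does not grow.
y_twice = dropout(dropout(x, key), key)

same_zero_pattern = bool(jnp.array_equal(y_once == 0.0, y_twice == 0.0))
zero_fraction = float(jnp.sum(y_twice == 0.0) / x.size)
```

With independent keys for the two calls, the expected fraction of zeros would instead be about 0.75, since each call would drop roughly half of the surviving entries. The open question is how to get the shared-key behaviour through `make_rng` when the repeated calls are driven by an external solver rather than by `nn.scan`.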