
Dropout + while loop interaction #2920

Answered by cgarciae
gahdritz asked this question in Q&A

Hey @gahdritz, the issue is that you need to pass the split_rngs argument to nn.while_loop to decide how each RNG stream (here 'dropout') behaves inside the loop, e.g.

split_rngs={'dropout': True} # each iteration gets a different key
# or
split_rngs={'dropout': False} # each iteration gets the same key
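To make the difference between the two settings concrete, here is a minimal plain-JAX sketch. It is not Flax's implementation. It assumes, as far as we can tell from how the lifted loop behaves, that "split" means folding the iteration index into the stream's key on every iteration, while "not split" means reusing the same key. The helper dropout_masks is a hypothetical name for illustration.

```python
import jax
import jax.numpy as jnp


def dropout_masks(key, num_iters, shape, split):
  """Collect the keep-mask drawn on each iteration of a lax.while_loop.

  split=True folds the iteration index into the key (a fresh key per
  iteration); split=False reuses the same key on every iteration.
  """
  def cond(carry):
    i, _ = carry
    return i < num_iters

  def body(carry):
    i, masks = carry
    # `split` is a plain Python bool, so this branch is resolved at trace time.
    k = jax.random.fold_in(key, i) if split else key
    mask = jax.random.bernoulli(k, 0.5, shape)
    return i + 1, masks.at[i].set(mask)

  init = (jnp.array(0), jnp.zeros((num_iters,) + shape, dtype=bool))
  _, masks = jax.lax.while_loop(cond, body, init)
  return masks


key = jax.random.PRNGKey(0)
same = dropout_masks(key, 3, (64,), split=False)  # every row is the same mask
fresh = dropout_masks(key, 3, (64,), split=True)  # rows differ (with overwhelming probability)
```

The actual keys Flax derives will differ from these (it also mixes in module paths); only the same-versus-fresh pattern is meant to carry over.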

Here is the full example, fixed (I also replaced jnp.empty with a random value, since I was getting zeros):

import jax
import jax.numpy as jnp
from flax import linen as nn

class MyModel(nn.Module):
  num_neurons: int

  @nn.compact
  def __call__(self, x, training: bool):
    x = nn.Dense(self.num_neurons)(x)
    
    def fn(mdl, x):
        # jax.debug.print("x: {x}", x=x)
        return nn.Dropout(rate=0.5, d…

Replies: 2 comments

Answer selected by gahdritz

This discussion was converted from issue #2918 on March 02, 2023 16:05.