
Dense + while loop interaction #2922

Answered by cgarciae
gahdritz asked this question in Q&A

Hey @gahdritz, the error message is not very informative here. As with nn.scan, you have to pass broadcast_variables='params' to indicate that this collection is shared (read-only) across loop iterations, and you should also pass split_rngs for initialization. Here is a working example:

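import jax
import jax.numpy as jnp
from flax import linen as nn

class MyModel(nn.Module):
    num_neurons: int

    @nn.compact
    def __call__(self, x, training: bool):

        dense = nn.Dense(self.num_neurons, name="dense")

        # Call the layer once during init so its params exist before the loop;
        # inside the loop the 'params' collection is only read (broadcast).
        if self.is_initializing():
            dense(x)

        def fn(mdl, x):
            x = dense(x)
            return x

        # Loop body: the carry is (counter, activations).
        # stop_gradient blocks gradients through the Dense output of each step.
        body = lambda mdl, inp: (
            inp[0] + 1,
            jax.lax.stop_gradient(fn(mdl, inp[1])),
        )

        _, x = nn.w…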

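nn.while_loop is, roughly, a lifted form of jax.lax.while_loop. A minimal plain-JAX sketch (the function name looped_dense is hypothetical) shows what broadcasting the params amounts to: the kernel and bias are closed over, so every iteration reads the same values, and only the carry (step, h) is threaded through the loop.

```python
import jax
import jax.numpy as jnp


def looped_dense(kernel, bias, x, num_steps):
    """Apply the same affine map num_steps times with jax.lax.while_loop.

    kernel and bias are closed over (constant for every iteration);
    only the carry (step, h) changes from one iteration to the next.
    """

    def cond_fn(carry):
        step, _ = carry
        return step < num_steps

    def body_fn(carry):
        step, h = carry
        return step + 1, h @ kernel + bias

    _, h = jax.lax.while_loop(cond_fn, body_fn, (jnp.array(0, jnp.int32), x))
    return h


# Doubling three times: every entry is 8.0
out = looped_dense(2.0 * jnp.eye(3), jnp.zeros(3), jnp.ones((2, 3)), num_steps=3)
```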
Replies: 1 comment 1 reply

Answer selected by gahdritz
Category: Q&A
Labels: None yet
Participants: 2