Control Flow in Module methods #1754

Answered by jheek
stefano-1981 asked this question in Q&A
Parameter initialization is lazy in Flax. In this case, Dense derives its input features from the input it is called with, so its parameters are only created once the layer is actually called. Because of the control flow, init only runs one branch, so some layers that are used during apply never get initialized.

There are a bunch of ways to fix this. I would avoid using `where`, because it means you always pay the cost of both Dense layers.
One fix is to call both layers in both branches but discard one of the outputs (the dummy call will be optimized away when using jit or pmap).

You could also use a dedicated init function to make sure everything is initialized like this:


def init_fn(m):
  x = jax.random.randint(jax.random.PRNGKey(1), (7, 11), 1, 4)
  y = jax.random.randint(jax.random.PR…

Answer selected by stefano-1981