
How do I use static_argnums with flax.linen.checkpoint on a Module whose __call__ has a boolean control flag? #3631

Closed. Answered by ihh
ihh asked this question in Q&A

This turned out to be a very basic Python error. I was passing `deterministic` as a keyword argument, so it was not counted as a positional argument, and `static_argnums` (which refers only to positional arguments) never marked it as static. Passing it positionally works:

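```python
# Index 0 is the module itself, so index 1 is x and index 2 is deterministic.
CheckpointedMLPWithDropout = nn.checkpoint(MLPWithDropout, static_argnums=2)

...

# Pass the flag positionally (not as deterministic=...) so it is static.
vars = model.init(rng, x, True)

print(model.apply(vars, x, False, rngs={'dropout': rng}))
```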
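The same positional-only behaviour can be seen with plain `jax.checkpoint`, which also accepts `static_argnums`. This is a minimal sketch; the function `scale` and its `double` flag are hypothetical. For a plain function there is no module argument, so the flag sits at index 1. The keyword call is expected to fail, but the exact exception type may differ between JAX versions, so the sketch only records whether the call raised.

```python
import jax
import jax.numpy as jnp


def scale(x, double):
    # Python-level control flow on `double`: this only works when `double`
    # is a concrete Python bool, i.e. when it is treated as static.
    if double:
        return 2.0 * x
    return x


# Position 1 (`double`) is marked static. static_argnums refers to positions,
# so it only applies when `double` is actually passed positionally.
scale_remat = jax.checkpoint(scale, static_argnums=(1,))

x = jnp.arange(3.0)

# Positional: `double` stays a concrete bool and the `if` works.
positional_result = scale_remat(x, True)

# Keyword: `double` is not at position 1 of the call, so static_argnums
# does not cover it and the call fails (error type may vary by JAX version).
try:
    scale_remat(x, double=True)
    keyword_failed = False
except Exception:
    keyword_failed = True

print(positional_result)
print(keyword_failed)
```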
