Hi, I filed this as an issue, but I think it works better as a question, so I am reposting it here (I will delete or answer my own issue if it gets solved). I am hoping the community can help.

I have a subclassed `flax.linen.Module` that takes a boolean argument in its `__call__` method. My hope was to use the standard pattern of passing a boolean flag into the function. The `deterministic` variable cannot be traced, as it affects control flow, so instead (when applying function transformations) it is usually flagged using `static_argnums`, as in the example from the `jax.checkpoint` documentation. Again, standard JAX.

What I want is to use this same pattern, but with the Flax module. I take the use of "lifting" in the documentation of `flax.linen.checkpoint` to mean that it can be used on a Module the same way `jax.checkpoint` can be used on a function. And indeed the documentation says it has a `static_argnums` parameter, which (I think) should intuitively be applied to the arguments of `__call__`.

However, I can't get this to work. See for example the following code, directly modeled on the example in the `flax.linen.checkpoint` docs:
Running this, I hoped to get a scalar value (wrapped as a jnp array) as output. The documentation is a bit thin in this area, or at least I haven't been able to find the relevant sections. Any tips would be appreciated. I have also been looking through the flax/core/lift.py source, but there are rather a lot of layers of abstraction, which slows my comprehension. Again, any help is appreciated.
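For reference, the plain JAX pattern described above can be sketched roughly as follows. This is a minimal illustration, not the snippet from the `jax.checkpoint` documentation; the function `f` and its body are hypothetical.

```python
import functools

import jax
import jax.numpy as jnp


# Hypothetical function for illustration: `deterministic` drives Python
# control flow, so it has to be a concrete bool rather than a traced value.
@functools.partial(jax.checkpoint, static_argnums=(1,))
def f(x, deterministic):
    if deterministic:
        return jnp.sum(x)
    return jnp.sum(2.0 * x)


x = jnp.ones(3)

# `deterministic` is passed positionally, so index 1 in `static_argnums`
# refers to it.
value_det = f(x, True)
grad_det = jax.grad(f)(x, True)
grad_nondet = jax.grad(f)(x, False)
```

Because the flag is marked static, the `if` branch is resolved in ordinary Python before tracing, and each value of the flag gets its own traced computation.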
Replies: 1 comment
This turned out to be a very basic Python error: since I was passing `deterministic` as a named argument, it did not get counted as a positional argument. The following works: