How to pass a Flax dataclass to a JITed function where some fields are static and some are dynamic? #1217
Asked by marcvanzee in Q&A.
Original question by @malmaud.
Answered by marcvanzee on Apr 8, 2021.
Replies: 1 comment
This works:

```python
import flax.struct


class Params(flax.struct.PyTreeNode):
    a: bool = flax.struct.field(pytree_node=False)  # static
    b: int  # dynamic
```
Note that you can also decorate `Params` with `@flax.struct.dataclass` instead of subclassing `flax.struct.PyTreeNode`, but the latter currently gives better PyType support: with the decorator, PyType sometimes fails to understand the constructor or the `replace` method.