
How to pass a Flax dataclass to a JITed function where some fields are static and some are dynamic? #1217

Answered by marcvanzee
marcvanzee asked this question in Q&A
This works:

import flax.struct

class Params(flax.struct.PyTreeNode):
  a: bool = flax.struct.field(pytree_node=False)  # static: not a pytree leaf
  b: int  # dynamic: a pytree leaf, traced under jit

Note that you can also decorate Params with @flax.struct.dataclass rather than subclassing flax.struct.PyTreeNode, but the latter (subclassing) currently gives better pytype support. With the decorator, pytype sometimes complains that it does not understand the constructor or the replace method.

marcvanzee
Apr 8, 2021
Maintainer Author

Answer selected by marcvanzee