Why can TrainState be used as an argument of a jitted function without triggering an error? #1858
Hi,
But TrainState can be used in a jitted function without triggering any error, for example:
We have to mark …
When you look at the definition of the class `TrainState`, you can see that it extends `struct.PyTreeNode`. This means that it is a dataclass that acts like a JAX pytree node, so JAX knows how to flatten and unflatten it for JAX transformations such as `jax.jit`. The fields are flattened into array leaves that JAX can trace, which is why no error is raised.
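Here is a minimal sketch of the same idea using only JAX, assuming Flax is not available. `ToyState` and `PlainState` are hypothetical names for illustration. The sketch registers a frozen dataclass by hand with `jax.tree_util.register_pytree_node_class`, which approximates what subclassing `flax.struct.PyTreeNode` gives you; it is not Flax's actual implementation. The registered dataclass passes through `jax.jit` fine, while an unregistered plain class raises a `TypeError`.

```python
import dataclasses

import jax
import jax.numpy as jnp


# Hypothetical stand-in for a Flax TrainState: a frozen dataclass whose
# fields are registered as pytree children (similar in spirit to what
# flax.struct.PyTreeNode does for its subclasses).
@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class ToyState:
    step: jnp.ndarray
    params: dict

    def tree_flatten(self):
        # Children are traced by JAX; aux_data (None here) is static.
        return (self.step, self.params), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuild the dataclass from the (possibly traced) children.
        return cls(*children)


class PlainState:
    # Ordinary class that JAX does not know how to flatten.
    def __init__(self, step):
        self.step = step


@jax.jit
def increment(state):
    # Works for ToyState because JAX flattens it into its array leaves.
    return dataclasses.replace(state, step=state.step + 1)


state = ToyState(step=jnp.array(0), params={"w": jnp.ones(3)})
new_state = increment(state)
print(new_state.step)  # 1

# Leaves JAX sees after flattening: the step array and params["w"].
print(len(jax.tree_util.tree_leaves(state)))  # 2

try:
    increment(PlainState(jnp.array(0)))
    plain_failed = False
except TypeError:
    # JAX cannot interpret an unregistered object as an array argument.
    plain_failed = True
print(plain_failed)  # True
```

With Flax installed, subclassing `flax.struct.PyTreeNode` removes the need to write `tree_flatten`/`tree_unflatten` yourself, and `TrainState` relies on that mechanism.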