Why can TrainState be used as an argument of a jitted function without triggering an error? #1858

Answered by marcvanzee
wztdream asked this question in General

If you look at the definition of the TrainState class, you can see that it extends struct.PyTreeNode. This means that it is a dataclass that acts like a JAX pytree node, so JAX knows how to flatten and unflatten it for JAX transformations such as jit.

Answer selected by wztdream