Incorrect pjit definition of train_state object #1789
Answered by jheek
mattiasmar asked this question in Q&A
Is it obvious to anyone in this forum what explains the inconsistency in this Colab's pjit? This is a duplicate of the discussion jax-ml/jax#9203.
Answered by jheek on Jan 17, 2022
Answer selected by mattiasmar
The eval_shape trick you used is clever, but you fed the shapes into eval_shape as tuples, which causes incorrect results:
Each shape here is a tuple, which is itself a pytree, so eval_shape treats every dimension as a separate scalar leaf instead of reading the tuple as the shape of one array. You should simply pass model.params or jax.ShapeDtypeStruct instances to eval_shape.
Some minor issues:
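Returning to the eval_shape point above, here is a minimal sketch of the difference between passing a shape tuple and passing a jax.ShapeDtypeStruct or a real parameter pytree. The function double and the dictionary params are hypothetical stand-ins made up for illustration; they are not taken from the Colab in the question.

```python
import jax
import jax.numpy as jnp


def double(tree):
    # Hypothetical stand-in for a function that maps over a pytree of arrays.
    return jax.tree_util.tree_map(lambda a: a * 2, tree)


# Problematic: (8, 4) is a tuple, and a tuple is a pytree with two leaves.
# eval_shape sees two Python scalars, not one array of shape (8, 4).
wrong = jax.eval_shape(double, (8, 4))

# Intended: describe one array of shape (8, 4) without allocating it.
right = jax.eval_shape(double, jax.ShapeDtypeStruct((8, 4), jnp.float32))

# Also fine: pass the actual parameter pytree (hypothetical params here).
params = {"w": jnp.ones((8, 4)), "b": jnp.zeros((4,))}
from_params = jax.eval_shape(double, params)

print(wrong)        # a tuple of two scalar (shape ()) descriptions
print(right)        # one description with shape (8, 4)
print(from_params)  # a dict of descriptions mirroring params
```

In the first call the result is a tuple of two scalar entries, which is presumably the kind of mismatch that then propagates into the pjit partitioning specs.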