Incorrect pjit definition of train_state object #1789

Answered by jheek
mattiasmar asked this question in Q&A

The eval_shape trick you used is clever, but you passed the shapes into eval_shape as tuples, which produces incorrect results:

```python
params_shapes = jax.tree_map(lambda x: x.shape, model.params)
state_shapes = jax.eval_shape(get_initial_state, params_shapes)
```

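Each shape here is a tuple, and a tuple is itself a pytree, so eval_shape treats every integer dimension as a separate scalar leaf instead of as the shape of one array. Instead, pass model.params itself, or a tree of jax.ShapeDtypeStruct instances, to eval_shape.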
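A minimal sketch of the difference. It assumes a hypothetical get_initial_state that builds an Adam-style state (a step count plus two moment buffers shaped like the parameters) and a small hypothetical params tree standing in for model.params; neither is the asker's actual code.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for get_initial_state: an Adam-style state with a
# step count and two moment buffers shaped like the params.
def get_initial_state(params):
    zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
    return {"count": jnp.zeros((), jnp.int32), "mu": zeros, "nu": zeros}

# Hypothetical params tree standing in for model.params.
params = {"dense": {"kernel": jnp.ones((4, 8)), "bias": jnp.ones((8,))}}

# Buggy: each shape tuple is a pytree, so every integer becomes a scalar leaf.
params_shapes = jax.tree_util.tree_map(lambda x: x.shape, params)
bad_state = jax.eval_shape(get_initial_state, params_shapes)

# Fixed: describe each array with a ShapeDtypeStruct...
params_structs = jax.tree_util.tree_map(
    lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype), params)
good_state = jax.eval_shape(get_initial_state, params_structs)

# ...or simply pass the real params; eval_shape only reads shape and dtype.
also_good = jax.eval_shape(get_initial_state, params)

print(bad_state["mu"]["dense"]["kernel"])   # tuple of two scalar structs
print(good_state["mu"]["dense"]["kernel"])  # one struct with shape (4, 8)
```

With the tuple input, each kernel moment buffer comes back as a tuple of two scalar shapes rather than a single (4, 8) array shape, so any partition specs derived from that state cannot line up with the real optimizer state.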

Some minor issues:

  1. In get_opt_spec you return a FrozenDict (the result of set_partitions), while the corresponding node in the optimizer state is a plain dict, so the two trees have different structures.
  2. bias is rank 1, but you return the rank-2 partition spec P(None, None) instead of P(None).
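A sketch addressing both points and feeding the result into a sharded init, under stated assumptions: the params tree and get_initial_state are hypothetical stand-ins, the sharding rule in spec_for is invented for illustration (the original uses set_partitions rules that aren't shown), and the mesh uses a single device with a hypothetical axis named "mp". Building the spec tree by mapping over the eval_shape output keeps the container types identical to the state's; if you keep set_partitions, converting its result with flax.core.unfreeze may be another way to get plain dicts. The sketch uses jax.jit with out_shardings, since current JAX releases have folded pjit into jax.jit.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical stand-ins for model.params and get_initial_state.
def get_initial_state(params):
    zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
    return {"count": jnp.zeros((), jnp.int32), "mu": zeros, "nu": zeros}

params = {"dense": {"kernel": jnp.ones((4, 8)), "bias": jnp.ones((8,))}}

def spec_for(leaf):
    """Hypothetical rule; the key property is one spec entry per dimension."""
    rank = len(leaf.shape)
    if rank == 2:
        return P(None, "mp")        # shard a 2-D kernel's output dim over "mp"
    return P(*([None] * rank))      # bias -> P(None), scalar count -> P()

def get_opt_spec(state_shapes):
    # Mapping over the state tree reuses its container types, so the spec
    # tree cannot end up as a FrozenDict where the state holds a plain dict.
    return jax.tree_util.tree_map(spec_for, state_shapes)

state_shapes = jax.eval_shape(get_initial_state, params)
specs = get_opt_spec(state_shapes)

# One-device mesh with a hypothetical axis name, just for illustration.
mesh = Mesh(np.array(jax.devices()[:1]), ("mp",))
shardings = jax.tree_util.tree_map(lambda s: NamedSharding(mesh, s), specs)

# jit with out_shardings creates the optimizer state directly in its sharded layout.
init_fn = jax.jit(get_initial_state, out_shardings=shardings)
state = init_fn(params)
```

Because specs is derived from the same tree that eval_shape returns, its structure matches the real state leaf for leaf, which is what out_shardings requires.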

Answer selected by mattiasmar