Pjit of train step using Flax' train_state object #1792
mattiasmar asked this question in Q&A · Unanswered
Replies: 0 comments
Background: Flax's train_state object holds the model params and optimizer state whose parallelization is controlled by pjit, and it has already been used successfully by the pjit'ed optimizer.init (as reviewed in discussion #1789). The last cell of this colab (the cell that begins with "# Pjit of training step") demonstrates a failing attempt to reuse those very same PartitionSpecs for the in_axis_resources and out_axis_resources of the pjit'ed train_epoch method of the Flax MNIST example.

Could you tell me how to pjit the train_epoch method of the Flax MNIST example without breaking the train_state object? (My end goal is a generic way of pjit'ing models in Flax.)

Current error:
IndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(None, Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>, None). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).
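For what it's worth, the IndexError above is not specific to pjit: it is raised whenever an array is sliced with a traced (dynamic) index inside a jit-compiled function, which is what happens when train_epoch selects batches with indices computed from a permutation inside the compiled loop. A minimal reproduction and the lax.dynamic_slice workaround the message suggests (toy shapes, not the colab's actual data):

```python
import jax
import jax.numpy as jnp

data = jnp.arange(12.0).reshape(6, 2)  # toy "dataset": 6 examples, 2 features

@jax.jit
def take_batch(data, start):
    # data[start:start + 2] would raise the IndexError from the question,
    # because `start` is a tracer inside the jit'ed function.
    # lax.dynamic_slice accepts a dynamic start as long as the slice
    # *size* is static:
    return jax.lax.dynamic_slice(data, (start, 0), (2, 2))

batch = take_batch(data, 2)
print(batch)  # rows 2 and 3 of `data`
```

The key constraint is that the slice sizes `(2, 2)` are compile-time constants; only the start indices may be traced values.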
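One commonly suggested pattern for this situation is to pjit a single train_step and keep the Python-level batch loop of train_epoch outside the compiled function, so no dynamic slicing happens under jit. The sketch below is an illustrative assumption, not the colab's code: the toy loss, the one-axis mesh, and the PartitionSpecs are all made up, and it uses today's API, where pjit has been merged into jax.jit (in_shardings/out_shardings), rather than the jax.experimental.pjit in_axis_resources/out_axis_resources API from the time of this question:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One mesh axis, 'data', spanning all available devices.
mesh = Mesh(np.array(jax.devices()), ('data',))
replicated = NamedSharding(mesh, P())                 # params on every device
data_parallel = NamedSharding(mesh, P('data', None))  # batch split along axis 0

def train_step(params, batch):
    # Toy SGD step on a stand-in quadratic loss; the real example would
    # thread a Flax train_state through here instead of bare params.
    grads = jax.grad(lambda p: jnp.mean((batch @ p) ** 2))(params)
    return params - 0.1 * grads

p_train_step = jax.jit(
    train_step,
    in_shardings=(replicated, data_parallel),
    out_shardings=replicated,
)

params = jnp.ones((4,))
batch = jnp.ones((8, 4))
new_params = p_train_step(params, batch)
print(new_params.shape)  # (4,)
```

With this split, the epoch loop (shuffling, batching) stays in plain Python, and only the shape-static per-batch computation is compiled and sharded.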