Model Parallelism with TrainState #1988
-
Hello, in the HuggingFace library, TrainState is used for training non-parallel models, but apparently it could not be used once model parallelism was involved. Is there a reason why TrainState cannot be used with model parallelism?
-
There is no reason in principle why TrainState could not be used with model parallelism. The reason this doesn't work well (yet) in HuggingFace is, I believe, particular to their API, and it only applies to very large models. They are actually discussing this in huggingface/transformers#15766. If you are looking for a simple example that uses TrainState with pjit, this example may be useful: https://colab.sandbox.google.com/github/marcvanzee/flax/blob/pjit-example/examples/siren/siren.ipynb
-
Hi @marcvanzee, I am trying to run your Colab example, but it doesn't work. The following code block:
gives the following error:
I have tried different versions of JAX, but I always get the same error. Any idea what the problem is and how to fix it?
@patrickvonplaten