Current best practices to initialize massive (50B+ parameter) models #16944
-
Hi, I am working with GPT-style models and need to initialize a model at GPT-3 scale. Unfortunately, this means the model will run out of memory during initialization on CPU (or take an eternity to initialize layer by layer on CPU before shipping to GPU). In vanilla PyTorch I solved this with FSDP by constructing the model on the "meta" device and doing the full initialization on GPU afterward. What is the current best, most performant method to accomplish this with Lightning? Note: I found a reference to an init_meta_context(), but my pytorch-lightning (v1.9.0) has no such functionality.
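For concreteness, here is roughly the vanilla-PyTorch pattern I mean (a minimal sketch: `GPT` and `config` are placeholder names, and it assumes a recent PyTorch, ≥2.1, for `torch.device` as a context manager and `to_empty(recurse=...)`):

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Construct on the meta device: parameters carry only shape/dtype metadata,
# so nothing is allocated and construction is near-instant.
with torch.device("meta"):
    model = GPT(config)  # placeholder model class and config

def init_fn(module: torch.nn.Module) -> None:
    # Allocate real (uninitialized) storage for this submodule on the local
    # GPU, then run the module's own weight init. FSDP calls this once per
    # submodule that still holds meta parameters.
    module.to_empty(device=torch.cuda.current_device(), recurse=False)
    if hasattr(module, "reset_parameters"):
        module.reset_parameters()

fsdp_model = FSDP(model, param_init_fn=init_fn)
```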
-
Hi! The `init_meta_context` functionality was replaced with a `torchdistx` integration in #13868. You can do the following:

```python
from torchdistx.deferred_init import deferred_init

model = deferred_init(YourLightningModule)
```

And we'll materialize it for you in the Trainer. This is very experimental, and you might encounter installation issues. In the long term, we'll adopt the fake tensor mode from PyTorch: #16448. Otherwise, for a stable(r) solution, you can use the DeepSpeed integration: https://pytorch-lightning.readthedocs.io/en/stable/advanced/model_parallel.html#deepspeed-zero-stage-3
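A slightly fuller sketch of how this fits together (`BigModel` and its constructor arguments are placeholders; `deferred_init` forwards `*args`/`**kwargs` to the module's `__init__`):

```python
import pytorch_lightning as pl
from torchdistx.deferred_init import deferred_init

# Record the construction instead of executing it: no parameter memory is
# allocated here. Constructor args are forwarded to BigModel.__init__.
model = deferred_init(BigModel, hidden_size=12288, num_layers=96)

# The Trainer materializes the recorded module when the strategy sets it
# up, so parameters are created directly on the target devices.
trainer = pl.Trainer(accelerator="gpu", devices=8)
trainer.fit(model)
```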
-
Hey, in addition to @carmocca's response: for DeepSpeed you need to initialize the model within the `configure_sharded_model` hook.
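A minimal sketch of that hook (layer sizes are placeholders): with `strategy="deepspeed_stage_3"`, the Trainer enters DeepSpeed's sharded-init context around this hook, so layers are partitioned across GPUs as they are constructed.

```python
import pytorch_lightning as pl
import torch.nn as nn

class BigModel(pl.LightningModule):
    def configure_sharded_model(self):
        # Runs inside the strategy's sharding context (deepspeed.zero.Init
        # for ZeRO-3): each layer is partitioned as it is created, so the
        # full model never exists on a single device.
        self.net = nn.Sequential(
            nn.Linear(4096, 4096),
            nn.GELU(),
            nn.Linear(4096, 4096),
        )

    def forward(self, x):
        return self.net(x)

trainer = pl.Trainer(
    accelerator="gpu", devices=8, strategy="deepspeed_stage_3", precision=16
)
```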
-
Thank you, I got this working with the DeepSpeed integration (it was failing before because of my own errors). Just a small follow-up: does reloading the model from a checkpoint still function the same and utilize the `configure_sharded_model` hook?
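For context, what I mean by reloading (paths are placeholders; note that a DeepSpeed ZeRO-3 checkpoint in Lightning is saved as a directory of shards rather than a single file):

```python
# Resuming from a DeepSpeed checkpoint: pass the checkpoint directory's
# path as ckpt_path so model and optimizer state are restored through the
# same sharded setup the strategy used during training.
model = BigModel()
trainer = pl.Trainer(accelerator="gpu", devices=8, strategy="deepspeed_stage_3")
trainer.fit(model, ckpt_path="lightning_logs/version_0/checkpoints/epoch=1-step=500.ckpt")
```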