Sharding and training multiple models at once for a large scale reinforcement learning #13601
-
🤯 this is an epic application! I haven't read this paper yet, but from a quick skim it seems really interesting; I'll read it in more depth. From what you've described, the model weights are successfully sharded and kept on all devices (may I ask how many GPUs are being used?). It seems like your observation is that the activations are fairly large, and you'd like to try offloading/partitioning them. From what I see, this is a case for activation checkpointing, which enables partitioning of activations and potentially offloading them to CPU memory. Also you should check out the really helpful guide for transformer models here. You can either pass these arguments directly to the Strategy, or make your own custom DeepSpeed config.
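Something like this should be a reasonable starting point (a sketch; the exact stage/offload flags will depend on your setup, and the activation flags only take effect if the model checkpoints its activations):

```python
import pytorch_lightning as pl
from pytorch_lightning.strategies import DeepSpeedStrategy

# Stage 3 shards parameters, gradients and optimizer states across GPUs.
# partition_activations / cpu_checkpointing apply to activations saved
# via activation checkpointing inside the model.
strategy = DeepSpeedStrategy(
    stage=3,
    offload_optimizer=True,       # move optimizer states to CPU RAM
    offload_parameters=True,      # move sharded parameters to CPU RAM
    partition_activations=True,   # shard checkpointed activations across GPUs
    cpu_checkpointing=True,       # offload checkpointed activations to CPU
)

trainer = pl.Trainer(accelerator="gpu", devices=8, precision=16, strategy=strategy)
```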
-
@thomfoster have you cracked this? I am working on the exact same problem with some friends: using PPO to make GPT-J better at conversations (our reward model is trained on a large dataset of user conversations from our app chai.ml). I got good results applying PPO to GPT-2 as my initial policy, but I want to initialise it from GPT-J. I've stuck with DeepSpeed so far, but I'm not getting speedups from increasing the number of GPUs, so I must be doing something wrong.
-
Hey Lightning team (perhaps @SeanNaren would be best placed),
I'm currently replicating this paper by OpenAI, in which they fine-tune large language models with the PPO algorithm from reinforcement learning.
Whilst I have successfully used the Lightning Trainer and DeepSpeed integration in the past to train LLMs up to 20B parameters, I am struggling to get DeepSpeed to correctly shard the models in this case. This is because my LightningModule contains 4 transformers (actor, critic, reference and reward networks). Whilst only 2 of the transformers are being trained and require gradients, all the networks involved are multiple billions of parameters and do not fit onto a 40GB A100 without DeepSpeed.
A simple proof of concept that would be amazing to get up and running is below. The exact architectures / loss computation aren't important; I just want to prove out that I can do two forward passes with gradients and two forward passes without, and combine the outputs together.
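Roughly along these lines (a minimal sketch of the idea; the gpt2-xl checkpoints, the placeholder loss, and the `lightning_transformers` import path are stand-ins, not the real code):

```python
import torch
import pytorch_lightning as pl
from transformers import AutoModelForCausalLM
from lightning_transformers.utilities.deepspeed import (
    enable_transformers_pretrained_deepspeed_sharding,
)

class PPOProofOfConcept(pl.LightningModule):
    def setup(self, stage):
        # Must run before from_pretrained so weights are sharded
        # across devices as each model is instantiated.
        enable_transformers_pretrained_deepspeed_sharding(self)
        self.actor = AutoModelForCausalLM.from_pretrained("gpt2-xl")      # trained
        self.critic = AutoModelForCausalLM.from_pretrained("gpt2-xl")     # trained
        self.reference = AutoModelForCausalLM.from_pretrained("gpt2-xl")  # frozen
        self.reward = AutoModelForCausalLM.from_pretrained("gpt2-xl")     # frozen
        self.reference.requires_grad_(False)
        self.reward.requires_grad_(False)

    def training_step(self, batch, batch_idx):
        # Two forward passes with gradients...
        r1 = self.actor(batch["input_ids"]).logits
        r2 = self.critic(batch["input_ids"]).logits  # <- OOMs here
        # ...and two without.
        with torch.no_grad():
            r3 = self.reference(batch["input_ids"]).logits
            r4 = self.reward(batch["input_ids"]).logits
        # Placeholder loss: the real PPO objective combines all four outputs.
        return (r1 - r3).pow(2).mean() + (r2 - r4).pow(2).mean()

    def configure_optimizers(self):
        # Only the actor and critic are optimised; reference/reward are frozen.
        trainable = [p for p in self.parameters() if p.requires_grad]
        return torch.optim.AdamW(trainable, lr=1e-5)
```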
As you can see, I am initialising the models inside setup after calling the `enable_transformers_pretrained_deepspeed_sharding` method. Whilst the models initialise successfully, we run out of CUDA memory attempting to do the second forward pass (to get r2). Perhaps one way to solve this would be to somehow mark r2 as an activation in the training step so that it can be offloaded to RAM? (Although right now I'm not even sure that's the issue, and am struggling to debug haha!)
Any help would be much appreciated!
Best,
Thom
My trainer config is below:
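Something along these lines (a sketch; the device count, precision and offload flags are assumptions):

```python
import pytorch_lightning as pl
from pytorch_lightning.strategies import DeepSpeedStrategy

trainer = pl.Trainer(
    accelerator="gpu",
    devices=4,            # e.g. 4x A100 40GB
    precision=16,
    strategy=DeepSpeedStrategy(
        stage=3,
        offload_optimizer=True,
        offload_parameters=True,
    ),
    max_epochs=1,
)
```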