How to save a deepspeed stage 3 model with pickle or torch #8910
-
Hi, I'm trying to save a model trained using deepspeed stage 3 using this code:

With stage 2 it worked if I added this code:

But using stage=3 I get this error:

Traceback (most recent call last):

I also tried saving using torch.save, but got the same error. I also tried both pytorch-lightning versions 1.3.8 and 1.4.1.

cc: @SeanNaren
-
Hey @ViktorThink! Thanks for bringing this up, I think we can make this clearer in the documentation for next time. To save, I recommend using:

```python
import pytorch_lightning as pl
from pytorch_lightning.plugins import DeepSpeedPlugin

trainer = pl.Trainer(
    gpus=4,
    plugins=DeepSpeedPlugin(
        stage=3,
        cpu_offload=True,
        partition_activations=True,
    ),
    precision=16,
    accelerator="ddp",
)
trainer.fit(model, train_dataloader)
trainer.save_checkpoint('model.pt')
```

Note that when using DeepSpeed we save a directory, not a single file. More information can be read in the documentation here: https://pytorch-lightning.readthedocs.io/en/latest/advanced/advanced_gpu.html#deepspeed
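For completeness, here is a minimal sketch (an editorial addition, not from the original reply) of feeding that saved checkpoint directory back to the Trainer. It assumes the `resume_from_checkpoint` Trainer argument available in the 1.3/1.4 releases discussed in this thread, and the same `model` / `train_dataloader` objects as above:

```python
import pytorch_lightning as pl
from pytorch_lightning.plugins import DeepSpeedPlugin

# 'model.pt' is a directory of sharded DeepSpeed state, so a plain
# torch.load('model.pt') will not work on it; hand the path back to the Trainer instead.
trainer = pl.Trainer(
    gpus=4,
    plugins=DeepSpeedPlugin(stage=3, cpu_offload=True),
    precision=16,
    accelerator="ddp",
    resume_from_checkpoint="model.pt",  # path to the saved checkpoint directory
)
trainer.fit(model, train_dataloader)  # same LightningModule / dataloader as above
```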
-
Same error with pytorch-lightning 1.4.9 and 1.5.0rc1 on Python 3.7 and 3.8 (DeepSpeed version 0.5.4): after the evaluation phase, the checkpoint callback tries to save the sharded model to disk, but torch can't pickle it. With pytorch-lightning 1.4.9 I can use
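If you need consolidated weights while this is broken, one possible workaround (a sketch, not from this reply) is DeepSpeed's own `zero_to_fp32` utility, which reads a sharded checkpoint directory and returns a plain fp32 state dict that can be pickled with `torch.save`. It assumes the sharded checkpoint directory was actually written; the path below is hypothetical:

```python
import torch
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint

# Point this at the directory written by the checkpoint callback (assumed path).
state_dict = get_fp32_state_dict_from_zero_checkpoint("checkpoints/last.ckpt")

# The returned dict contains ordinary CPU fp32 tensors, so pickling works again.
torch.save(state_dict, "consolidated_fp32.pt")
```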
-
After some debugging with a user, I've come up with a final script to show how you can use the `convert_zero_checkpoint_to_fp32_state_dict` to generate a single file that can be loaded using pickle, or lightning:

```python
import os

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.plugins import DeepSpeedPlugin
from pytorch_lightning.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


if __name__ == "__main__":
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    test_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        enable_model_summary=False,
        strategy=DeepSpeedPlugin(stage=2),
        precision=16,
        gpus=2,
        callbacks=ModelCheckpoint(dirpath='checkpoints', save_last=True),
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)

    # once saved via the model checkpoint callback,
    # it saves a folder containing the deepspeed checkpoint rather than a single file
    checkpoint_path = "checkpoints/last.ckpt/"

    if trainer.is_global_zero:
        single_ckpt_path = "single_model.pt"
        # magically converts the folder into a single lightning loadable pytorch file (for ZeRO 1, 2 and 3)
        convert_zero_checkpoint_to_fp32_state_dict(checkpoint_path, single_ckpt_path)

        loaded_parameters = BoringModel.load_from_checkpoint(single_ckpt_path).parameters()
        model = model.cpu()

        # Assert model parameters are identical after loading
        for orig_param, saved_model_param in zip(model.parameters(), loaded_parameters):
            if model.dtype == torch.half:
                # moved model to float32 for comparison with single fp32 saved weights
                saved_model_param = saved_model_param.half()
            assert torch.equal(orig_param, saved_model_param)
```

The approach above, where we use the Trainer as an engine, still works, but now you'd need to pass the checkpoint path like so
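To close the loop on the original question (saving with pickle or torch): once the conversion above has produced `single_model.pt`, that file is an ordinary torch/pickle checkpoint rather than a sharded directory. A minimal sketch of loading it directly, assuming the paths and the `BoringModel` class from the script above (an editorial addition, not from the thread):

```python
import torch

# The converted checkpoint is a single pickle-based file, unlike the sharded directory,
# so plain torch.load works on it.
checkpoint = torch.load("single_model.pt", map_location="cpu")
state_dict = checkpoint["state_dict"]  # Lightning checkpoints keep the weights under this key

# Or load it straight back into the LightningModule defined in the script above:
# model = BoringModel.load_from_checkpoint("single_model.pt")
```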