Distributed training with multiple optimizers #10241
Hi, I have a model with multiple models inside the object (similar to a GAN), except that I want to use a single loss function with multiple optimizers. I am disabling automatic optimization via the `automatic_optimization` flag. The code is implemented and works in a single-GPU configuration. To speed up training, I need to use DDP across 3 GPU devices (1 node, multiple devices). The distributed training runs, but I am not sure whether it is working the way it is supposed to. Is there a way to verify this? My pseudo code is below; I can't share the full code due to NDAs.

My questions: for distributed training with multiple optimizers, will the code below work in the INTENDED way? What should the `training_step_end` function contain? And how do multiple optimizers update across different devices? Thank you.

EDIT by @akihironitta: provided script

```python
import <standard imports>
class some_model1(nn.Module):
    def __init__(self):
        super().__init__()
        ...  # some layers

    def forward(self, x, labels):
        ...  # some calculation on x and labels
        return feat1


class some_model2(nn.Module):
    def __init__(self):
        super().__init__()
        ...  # some layers

    def forward(self, x, labels):
        ...  # some calculation on x and labels
        return feat2


class some_model3(nn.Module):
    def __init__(self):
        super().__init__()
        ...  # some layers

    def forward(self, x, labels):
        ...  # some calculation on x and labels
        return feat3


class some_model4(nn.Module):
    def __init__(self):
        super().__init__()
        ...  # some layers

    def forward(self, x, labels):
        ...  # some calculation on x and labels
        return loss
class Model(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model1 = some_model1()
        self.model2 = some_model2()
        self.model3 = some_model3()
        self.sub_model = some_model4()
        self.common = nn.Sequential(some_layers)
        self.softmax_layer = nn.Softmax(dim=-1)
        self.automatic_optimization = False  # manual optimization

    def forward(self, input_dict, output_dict, *args, **kwargs):
        data1, data2, data3 = input_dict['data1'], input_dict['data2'], input_dict['data3']
        feat1 = self.model1(data1)
        pred1 = self.common(feat1)
        feat2 = self.model2(data2)
        pred2 = self.common(feat2)
        feat3 = self.model3(data3)
        pred3 = self.common(feat3)
        combined_metrics1 = ...  # some metric calculation based on pred1, pred2, pred3
        combined_metrics2 = ...  # some metric calculation based on output_dict
        return {'combined_metrics1': combined_metrics1, 'combined_metrics2': combined_metrics2}
    def training_step(self, batch, batch_idx, *args, **kwargs):
        input_dict, output_dict = batch
        opt1, opt2 = self.optimizers()
        opt1.zero_grad()
        opt2.zero_grad()
        outputs = self.forward(input_dict, output_dict)
        combined_metrics1 = outputs['combined_metrics1']
        combined_metrics2 = outputs['combined_metrics2']
        final_loss = some_loss_fn(combined_metrics1, combined_metrics2, output_dict)
        self.manual_backward(final_loss)
        opt1.step()
        opt2.step()
        return {'combined_metrics1': combined_metrics1, 'combined_metrics2': combined_metrics2}

    def training_step_end(self, outputs):
        return ??
    def configure_optimizers(self):
        normal_params = (
            list(self.model1.parameters())
            + list(self.model2.parameters())
            + list(self.model3.parameters())
        )
        self.normal_opt = optim.Adam(normal_params, lr=0.001)
        self.opt2 = optim.Adam(self.sub_model.parameters(), lr=0.001)
        return self.normal_opt, self.opt2
if __name__ == '__main__':
    from xyz import TrainDataloader
    from xyz import TestDataloader

    train_set = TrainDataloader(params, partition='train')
    test_set = TestDataloader(params, partition='test')
    data_loader_train = torch.utils.data.DataLoader(train_set, batch_size=16, shuffle=True, num_workers=1)
    data_loader_test = torch.utils.data.DataLoader(test_set, batch_size=1, shuffle=True, num_workers=1)

    class DataModule(pl.LightningDataModule):
        def __init__(self, batch_size: int = 32):
            super().__init__()
            self.batch_size = batch_size

        def setup(self, stage=None):
            print('Setting up the data loader')

        def train_dataloader(self):
            return data_loader_train

        def val_dataloader(self):
            return data_loader_test

        def test_dataloader(self):
            return data_loader_test

        def teardown(self, stage=None):
            # Used to clean up when the run is finished
            ...

    dm = DataModule()

    from pytorch_lightning.loggers import TensorBoardLogger
    logger = TensorBoardLogger("logs", name="Custom")

    model = Model()
    trainer = pl.Trainer(max_epochs=50, log_every_n_steps=2, gpus=2, accelerator='ddp', logger=logger)
    trainer.fit(model, datamodule=dm)
```
Replies: 1 comment
Your code looks good to me.
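For reference, with `automatic_optimization = False` the per-step pattern Lightning expects boils down to something like this (a trimmed sketch; `compute_loss` is just a placeholder for whatever produces your final loss):

```python
def training_step(self, batch, batch_idx):
    opt1, opt2 = self.optimizers()   # the two optimizers returned by configure_optimizers
    opt1.zero_grad()
    opt2.zero_grad()
    loss = self.compute_loss(batch)  # placeholder for the real loss computation
    self.manual_backward(loss)       # preferred over loss.backward() so Lightning applies precision/strategy handling
    opt1.step()
    opt2.step()
```

which is essentially what your `training_step` already does.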
No need to do anything there if you don't need to run anything at the end of `training_step`; see https://pytorch-lightning.readthedocs.io/en/1.6.5/common/lightning_module.html#training-step-end
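If you did want to use `training_step_end`, e.g. to log the metrics returned from `training_step`, a minimal sketch could look like this (the keys just mirror what the posted code returns, and it assumes they are scalar tensors):

```python
def training_step_end(self, outputs):
    # `outputs` is whatever training_step returned; under DDP this is the output
    # of the current process only (aggregation across GPUs happens under DP).
    self.log("combined_metrics1", outputs["combined_metrics1"])
    self.log("combined_metrics2", outputs["combined_metrics2"])
    return outputs
```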
DDP syncs gradients across the devices (overlapping the communication with backprop), and each device then updates its weights with the synced gradients, so the replicas stay identical. See the PyTorch documentation for details: https://pytorch.org/docs/1.12/notes/ddp.html
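If you want to convince yourself that the replicas really stay in sync, one rough way (just a sketch, not an official recipe) is to gather a parameter from every rank and compare, e.g. in an epoch-end hook of your `LightningModule`:

```python
import torch

def on_train_epoch_end(self):
    # Gather each parameter from every rank and check it matches rank 0.
    for name, p in self.named_parameters():
        gathered = self.all_gather(p.detach())  # shape: [world_size, *p.shape]
        if not torch.allclose(gathered, gathered[0]):
            print(f"rank {self.global_rank}: parameter {name} is out of sync")
```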