diff --git a/test/test_zero1.py b/test/test_zero1.py
index 22986757b43..e3dc5738f84 100644
--- a/test/test_zero1.py
+++ b/test/test_zero1.py
@@ -106,7 +106,8 @@ def test_zero1_load(self):
         torch.optim.SGD,
         lr=0.5,
         momentum=0.5,
-        grad_clipping=True)
+        grad_clipping=True,
+        save_master_weights=True)
 
     opt.step()
 
@@ -127,14 +128,28 @@ def test_zero1_load(self):
     # is same as what is directly used here.
     reloaded_opt.load_state_dict(orig_opt_state)
 
-    self.assertEqual(reloaded_opt['param_groups'],
+    reloaded_opt_state = reloaded_opt.state_dict()
+
+    self.assertEqual(reloaded_opt_state['param_groups'],
                      orig_opt_state['param_groups'])
 
-    self.assertEqual(reloaded_opt['state'], orig_opt_state['state'])
+    self.assertEqual(reloaded_opt_state['state'], orig_opt_state['state'])
+
+    self.assertEqual(reloaded_opt_state['base_state'],
+                     orig_opt_state['base_state'])
+
+    self.assertEqual(reloaded_opt_state['shape_info'],
+                     orig_opt_state['shape_info'])
+
+    sharded_master_weights = orig_opt_state['sharded_master_weights']
 
-    self.assertEqual(reloaded_opt['base_state'], orig_opt_state['base_state'])
+    for param_group, loaded_param_groups in zip(
+        reloaded_opt.base_optimizer.param_groups,
+        orig_opt_state['param_groups']):
+      for param, loaded_param_idx in zip(param_group['params'],
+                                         loaded_param_groups['params']):
 
-    self.assertEqual(reloaded_opt['shape_info'], orig_opt_state['shape_info'])
+        self.assertEqual(param, sharded_master_weights[loaded_param_idx])
 
 
 def _mp_fn(index):
diff --git a/torch_xla/distributed/zero_redundancy_optimizer.py b/torch_xla/distributed/zero_redundancy_optimizer.py
index 0714fa42106..25cfdd97740 100644
--- a/torch_xla/distributed/zero_redundancy_optimizer.py
+++ b/torch_xla/distributed/zero_redundancy_optimizer.py
@@ -523,12 +523,13 @@ def load_state_dict(self, state_dict):
     self.base_optimizer.load_state_dict(tmp)
     if 'sharded_master_weights' in state_dict:
       master_weights = state_dict['sharded_master_weights']
-      index = 0
-      for param_group, sharded_param_group in zip(
-          self.param_groups, self.base_optimizer.param_groups):
-        for param, shard in zip(param_group['params'],
-                                sharded_param_group['params']):
-          shard.data.copy_(master_weights[index])
+      for param_group, sharded_param_group, loaded_param_groups in zip(
+          self.param_groups, self.base_optimizer.param_groups,
+          state_dict['param_groups']):
+        for param, shard, loaded_param_idx in zip(
+            param_group['params'], sharded_param_group['params'],
+            loaded_param_groups['params']):
+          shard.data.copy_(master_weights[loaded_param_idx])
           # set dummy gradient for allgather to be triggered.
           if self.use_grad_acc_hook:
             # Create main gradients
@@ -540,7 +541,6 @@ def load_state_dict(self, state_dict):
             param.main_grad = shard.main_grad
           else:
             param.grad = torch.zeros_like(param.data)
-          index += 1
       torch_xla.sync()
       # add `torch_xla.sync()` around allgather to avoid large number of
      # compilation
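
For context, below is a minimal usage sketch of the round trip this patch exercises: build the ZeRO-1 optimizer with save_master_weights=True, step it, save its state_dict, then load that state_dict into a freshly constructed optimizer, which now copies each sharded master weight by the parameter index recorded in param_groups instead of a running counter. The model, input, and surrounding setup are illustrative assumptions, not part of the patch.

# Illustrative sketch only -- the model and setup below are assumptions.
import torch
import torch.nn as nn
import torch_xla
import torch_xla.core.xla_model as xm
from torch_xla.distributed.zero_redundancy_optimizer import ZeroRedundancyOptimizer

device = xm.xla_device()
model = nn.Linear(32, 32).to(device)

opt = ZeroRedundancyOptimizer(
    model.parameters(),
    torch.optim.SGD,
    lr=0.5,
    momentum=0.5,
    grad_clipping=True,
    save_master_weights=True)  # include sharded master weights in state_dict()

model(torch.randn(8, 32, device=device)).sum().backward()
opt.step()
torch_xla.sync()

# Save the full optimizer state, including 'sharded_master_weights'.
saved_state = opt.state_dict()

# Later: rebuild an equivalent optimizer and restore the state.
# load_state_dict() copies each saved master weight into the matching shard
# using the parameter indices stored in 'param_groups'.
reloaded_opt = ZeroRedundancyOptimizer(
    model.parameters(),
    torch.optim.SGD,
    lr=0.5,
    momentum=0.5,
    grad_clipping=True,
    save_master_weights=True)
reloaded_opt.load_state_dict(saved_state)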