
Commit c0f1e62

Fixed reloading of sharded master weights (#9224)
1 parent ff57e53 commit c0f1e62

2 files changed: +27 −12 lines

test/test_zero1.py
Lines changed: 20 additions & 5 deletions

@@ -106,7 +106,8 @@ def test_zero1_load(self):
         torch.optim.SGD,
         lr=0.5,
         momentum=0.5,
-        grad_clipping=True)
+        grad_clipping=True,
+        save_master_weights=True)
 
     opt.step()
 
@@ -127,14 +128,28 @@ def test_zero1_load(self):
     # is same as what is directly used here.
     reloaded_opt.load_state_dict(orig_opt_state)
 
-    self.assertEqual(reloaded_opt['param_groups'],
+    reloaded_opt_state = reloaded_opt.state_dict()
+
+    self.assertEqual(reloaded_opt_state['param_groups'],
                      orig_opt_state['param_groups'])
 
-    self.assertEqual(reloaded_opt['state'], orig_opt_state['state'])
+    self.assertEqual(reloaded_opt_state['state'], orig_opt_state['state'])
+
+    self.assertEqual(reloaded_opt_state['base_state'],
+                     orig_opt_state['base_state'])
+
+    self.assertEqual(reloaded_opt_state['shape_info'],
+                     orig_opt_state['shape_info'])
+
+    sharded_master_weights = orig_opt_state['sharded_master_weights']
 
-    self.assertEqual(reloaded_opt['base_state'], orig_opt_state['base_state'])
+    for param_group, loaded_param_groups in zip(
+        reloaded_opt.base_optimizer.param_groups,
+        orig_opt_state['param_groups']):
+      for param, loaded_param_idx in zip(param_group['params'],
+                                         loaded_param_groups['params']):
 
-    self.assertEqual(reloaded_opt['shape_info'], orig_opt_state['shape_info'])
+        self.assertEqual(param, sharded_master_weights[loaded_param_idx])
 
 
 def _mp_fn(index):
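The new assertions rely on how an optimizer state dict references parameters: each entry in state_dict()['param_groups'] lists integer parameter indices under 'params', and the saved 'sharded_master_weights' is looked up by those same indices. The sketch below illustrates that lookup with plain Python stand-ins; the concrete values and the structure of the stand-in reloaded param groups are assumptions for illustration, not copied from the test.

# Standalone sketch (no torch_xla needed) of the lookup the updated test
# performs after load_state_dict(). All values are made-up stand-ins.
orig_opt_state = {
    'param_groups': [{'lr': 0.5, 'params': [0, 1]},
                     {'lr': 0.1, 'params': [2]}],
    # keyed by the same integer indices recorded under 'params' above
    'sharded_master_weights': {0: [1.0, 2.0], 1: [3.0], 2: [4.0, 5.0]},
}

# Stand-in for reloaded_opt.base_optimizer.param_groups: the shards the base
# optimizer holds after load_state_dict(), in the same group/param order.
reloaded_base_param_groups = [{'params': [[1.0, 2.0], [3.0]]},
                              {'params': [[4.0, 5.0]]}]

sharded_master_weights = orig_opt_state['sharded_master_weights']
for param_group, loaded_param_group in zip(reloaded_base_param_groups,
                                           orig_opt_state['param_groups']):
  for param, loaded_param_idx in zip(param_group['params'],
                                     loaded_param_group['params']):
    # every reloaded shard must equal the master weight saved under its index
    assert param == sharded_master_weights[loaded_param_idx]
print('reloaded shards match the saved sharded master weights')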

torch_xla/distributed/zero_redundancy_optimizer.py
Lines changed: 7 additions & 7 deletions

@@ -523,12 +523,13 @@ def load_state_dict(self, state_dict):
     self.base_optimizer.load_state_dict(tmp)
     if 'sharded_master_weights' in state_dict:
       master_weights = state_dict['sharded_master_weights']
-      index = 0
-      for param_group, sharded_param_group in zip(
-          self.param_groups, self.base_optimizer.param_groups):
-        for param, shard in zip(param_group['params'],
-                                sharded_param_group['params']):
-          shard.data.copy_(master_weights[index])
+      for param_group, sharded_param_group, loaded_param_groups in zip(
+          self.param_groups, self.base_optimizer.param_groups,
+          state_dict['param_groups']):
+        for param, shard, loaded_param_idx in zip(
+            param_group['params'], sharded_param_group['params'],
+            loaded_param_groups['params']):
+          shard.data.copy_(master_weights[loaded_param_idx])
           # set dummy gradient for allgather to be triggered.
           if self.use_grad_acc_hook:
             # Create main gradients
@@ -540,7 +541,6 @@ def load_state_dict(self, state_dict):
             param.main_grad = shard.main_grad
           else:
             param.grad = torch.zeros_like(param.data)
-          index += 1
       torch_xla.sync()
       # add `torch_xla.sync()` around allgather to avoid large number of
       # compilation
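The substance of the fix: load_state_dict() used to copy master weights into the shards by a running counter, which assumes the enumeration order of the live param groups matches the indices recorded in the checkpoint's 'param_groups'. The new code reads the index stored next to each parameter and looks the master weight up by that index. The sketch below contrasts the two lookups with plain Python stand-ins; the swapped-index scenario is a hypothetical illustration of how a counter can drift from the recorded indices, not necessarily the exact failure this commit addresses.

# Plain-Python contrast of the old counter-based lookup vs. the new
# index-based lookup in load_state_dict(). Data below is hypothetical.
master_weights = {0: 'fp32 shard #0', 1: 'fp32 shard #1'}

# Checkpoint records indices 1 then 0 for this group (hypothetical ordering).
loaded_param_groups = [{'params': [1, 0]}]
# Live (sharded) params iterated in the same positional order.
live_param_groups = [{'params': ['param_a', 'param_b']}]

# Old behaviour: a running counter, assumed to equal the recorded index.
index = 0
old_copy = {}
for group in live_param_groups:
  for param in group['params']:
    old_copy[param] = master_weights[index]
    index += 1

# New behaviour: use the index recorded alongside each parameter.
new_copy = {}
for group, loaded_group in zip(live_param_groups, loaded_param_groups):
  for param, loaded_idx in zip(group['params'], loaded_group['params']):
    new_copy[param] = master_weights[loaded_idx]

print(old_copy)  # {'param_a': 'fp32 shard #0', 'param_b': 'fp32 shard #1'}
print(new_copy)  # {'param_a': 'fp32 shard #1', 'param_b': 'fp32 shard #0'}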

0 commit comments
