
Commit 5dfdc7c

enable merging parameters for diloco
1 parent a19a0e3

File tree

1 file changed: +35 -6 lines changed


torchft/local_sgd.py

Lines changed: 35 additions & 6 deletions
@@ -213,8 +213,14 @@ def __init__(
         self.should_quantize = should_quantize

         self._grads: Dict[str, torch.Tensor] = {}
+
+        # Used to save global parameters so that they can be restored in case
+        # commit fails
         self.original_parameters: Dict[str, torch.Tensor] = {}

+        # Used to mix the local and global parameters
+        self._local_parameters: Dict[str, torch.Tensor] = {}
+
         for name, p in self._model_fragment.named_parameters():
             if isinstance(p, DTensor):
                 p = extract_local_tensor(p.data)
@@ -237,6 +243,14 @@ def save_parameters(self) -> None:
                 param_to_local = extract_local_tensor(p.data)
                 self.original_parameters[name].copy_(param_to_local, non_blocking=True)

+    def _save_local_parameters(self) -> None:
+        """
+        Saves a copy of the model's parameters.
+        """
+        with torch.no_grad():
+            for name, p in self._model_fragment.named_parameters():
+                self._local_parameters[name] = extract_local_tensor(p.data)
+
     @torch.profiler.record_function("torchft::local_sgd::restore_parameters")
     def restore_parameters(self) -> None:
         with torch.no_grad():
@@ -287,6 +301,21 @@ def _set_grads(self) -> None:
                 else:
                     p.grad = self._grads[name]

+    def _clear_local_parameters(self) -> None:
+        """
+        Clears the saved copy of the model's parameters
+        """
+        self._local_parameters = {}
+
+    def _merge_parameters(self) -> None:
+        """
+        Merges the local and global parameters.
+        """
+        for name, p in self._model_fragment.named_parameters():
+            torch.lerp(
+                p.data, self._local_parameters[name], 1 - self._fragment_update_alpha
+            )
+
     @torch.profiler.record_function("torchft::local_sgd::wait")
     def wait(self) -> None:
         """
@@ -370,6 +399,8 @@ def perform_sync(self) -> bool:

         self.wait()

+        # save the parameters so they can be used for merging
+        self._save_local_parameters()
         # Restore the parameters back to the previous state
         self.restore_parameters()

@@ -391,8 +422,12 @@ def perform_sync(self) -> bool:
             self._set_grads()
             self._outer_optimizer.step()
             self.save_parameters()
+            self._merge_parameters()
         self._outer_optimizer.zero_grad()

+        # free up memory
+        self._clear_local_parameters()
+
         return should_commit

     def _average_grads(self) -> None:
@@ -544,12 +579,6 @@ def __init__(
         if fragment_update_alpha < 0 or fragment_update_alpha > 1:
             raise ValueError("fragment_update_alpha must be between 0 and 1")

-        # TODO: Support `fragment_update_alpha`
-        if fragment_update_alpha != 0.0:
-            raise ValueError(
-                "Merging local parameters with global parameters is not supported yet"
-            )
-
         super().__init__()
         self._manager = manager

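Putting the perform_sync changes together: the live ("local") weights are copied with _save_local_parameters() before restore_parameters() brings back the last committed weights, the outer optimizer then steps and save_parameters() records the new global weights, _merge_parameters() blends the saved local copy back in, and _clear_local_parameters() drops the temporary copy. A rough self-contained sketch of that ordering, using a simplified stand-in class rather than the actual torchft classes, and assuming the pseudo-gradient averaging and the commit decision happen elsewhere:

from typing import Dict

import torch


class _FragmentSyncSketch:
    # Illustrative stand-in for the sync flow this commit sets up; not torchft code.

    def __init__(
        self,
        model: torch.nn.Module,
        outer_optimizer: torch.optim.Optimizer,
        fragment_update_alpha: float,
    ) -> None:
        self._model = model
        self._outer_optimizer = outer_optimizer
        self._fragment_update_alpha = fragment_update_alpha
        # Last committed ("global") parameters.
        self.original_parameters: Dict[str, torch.Tensor] = {
            name: p.detach().clone() for name, p in model.named_parameters()
        }
        # Temporary copy of the live ("local") parameters, filled during sync.
        self._local_parameters: Dict[str, torch.Tensor] = {}

    def perform_sync(self, should_commit: bool) -> bool:
        with torch.no_grad():
            # 1. Save the local parameters so they can be merged back in later.
            for name, p in self._model.named_parameters():
                self._local_parameters[name] = p.detach().clone()
            # 2. Restore the previously committed parameters before stepping.
            for name, p in self._model.named_parameters():
                p.copy_(self.original_parameters[name])
            if should_commit:
                # 3. Outer optimizer step (pseudo-gradients are assumed to have
                #    been set on p.grad elsewhere), then re-save the result.
                self._outer_optimizer.step()
                for name, p in self._model.named_parameters():
                    self.original_parameters[name].copy_(p)
                # 4. Merge: p <- alpha * global + (1 - alpha) * local.
                for name, p in self._model.named_parameters():
                    p.lerp_(self._local_parameters[name], 1 - self._fragment_update_alpha)
        self._outer_optimizer.zero_grad()
        # 5. Free the temporary copy.
        self._local_parameters = {}
        return should_commit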