@@ -213,8 +213,14 @@ def __init__(
213
213
self .should_quantize = should_quantize
214
214
215
215
self ._grads : Dict [str , torch .Tensor ] = {}
216
+
217
+ # Used to save global parameters so that they can be restored in case
218
+ # commit fails
216
219
self .original_parameters : Dict [str , torch .Tensor ] = {}
217
220
221
+ # Used to mix the local and global parameters
222
+ self ._local_parameters : Dict [str , torch .Tensor ] = {}
223
+
218
224
for name , p in self ._model_fragment .named_parameters ():
219
225
if isinstance (p , DTensor ):
220
226
p = extract_local_tensor (p .data )
@@ -237,6 +243,14 @@ def save_parameters(self) -> None:
237
243
param_to_local = extract_local_tensor (p .data )
238
244
self .original_parameters [name ].copy_ (param_to_local , non_blocking = True )
239
245
246
+ def _save_local_parameters (self ) -> None :
247
+ """
248
+ Saves a copy of the model's parameters.
249
+ """
250
+ with torch .no_grad ():
251
+ for name , p in self ._model_fragment .named_parameters ():
252
+ self ._local_parameters [name ] = extract_local_tensor (p .data )
253
+
240
254
@torch .profiler .record_function ("torchft::local_sgd::restore_parameters" )
241
255
def restore_parameters (self ) -> None :
242
256
with torch .no_grad ():
@@ -290,6 +304,21 @@ def _set_grads(self) -> None:
290
304
else :
291
305
p .grad = self ._grads [name ]
292
306
307
+ def _clear_local_parameters (self ) -> None :
308
+ """
309
+ Clears the saved copy of the model's parameters
310
+ """
311
+ self ._local_parameters = {}
312
+
313
+ def _merge_parameters (self ) -> None :
314
+ """
315
+ Merges the local and global parameters.
316
+ """
317
+ for name , p in self ._model_fragment .named_parameters ():
318
+ torch .lerp (
319
+ p .data , self ._local_parameters [name ], 1 - self ._fragment_update_alpha
320
+ )
321
+
293
322
@torch .profiler .record_function ("torchft::local_sgd::wait" )
294
323
def wait (self ) -> None :
295
324
"""
@@ -373,6 +402,8 @@ def perform_sync(self) -> bool:
373
402
374
403
self .wait ()
375
404
405
+ # save the parameters so they can be used for merging
406
+ self ._save_local_parameters ()
376
407
# Restore the parameters back to the previous state
377
408
self .restore_parameters ()
378
409
@@ -394,8 +425,12 @@ def perform_sync(self) -> bool:
394
425
self ._set_grads ()
395
426
self ._outer_optimizer .step ()
396
427
self .save_parameters ()
428
+ self ._merge_parameters ()
397
429
self ._outer_optimizer .zero_grad ()
398
430
431
+ # free up memory
432
+ self ._clear_local_parameters ()
433
+
399
434
return should_commit
400
435
401
436
def _average_grads (self ) -> None :
@@ -547,12 +582,6 @@ def __init__(
547
582
if fragment_update_alpha < 0 or fragment_update_alpha > 1 :
548
583
raise ValueError ("fragment_update_alpha must be between 0 and 1" )
549
584
550
- # TODO: Support `fragment_update_alpha`
551
- if fragment_update_alpha != 0.0 :
552
- raise ValueError (
553
- "Merging local parameters with global parameters is not supported yet"
554
- )
555
-
556
585
super ().__init__ ()
557
586
self ._manager = manager
558
587
0 commit comments