@@ -213,8 +213,14 @@ def __init__(
213
213
self .should_quantize = should_quantize
214
214
215
215
self ._grads : Dict [str , torch .Tensor ] = {}
216
+
217
+ # Used to save global parameters so that they can be restored in case
218
+ # commit fails
216
219
self .original_parameters : Dict [str , torch .Tensor ] = {}
217
220
221
+ # Used to mix the local and global parameters
222
+ self ._local_parameters : Dict [str , torch .Tensor ] = {}
223
+
218
224
for name , p in self ._model_fragment .named_parameters ():
219
225
if isinstance (p , DTensor ):
220
226
p = extract_local_tensor (p .data )
@@ -237,6 +243,14 @@ def save_parameters(self) -> None:
237
243
param_to_local = extract_local_tensor (p .data )
238
244
self .original_parameters [name ].copy_ (param_to_local , non_blocking = True )
239
245
246
+ def _save_local_parameters (self ) -> None :
247
+ """
248
+ Saves a copy of the model's parameters.
249
+ """
250
+ with torch .no_grad ():
251
+ for name , p in self ._model_fragment .named_parameters ():
252
+ self ._local_parameters [name ] = extract_local_tensor (p .data )
253
+
240
254
@torch .profiler .record_function ("torchft::local_sgd::restore_parameters" )
241
255
def restore_parameters (self ) -> None :
242
256
with torch .no_grad ():
@@ -293,6 +307,21 @@ def _set_grads(self) -> None:
293
307
# No longer needed
294
308
del self ._grads [name ]
295
309
310
+ def _clear_local_parameters (self ) -> None :
311
+ """
312
+ Clears the saved copy of the model's parameters
313
+ """
314
+ self ._local_parameters = {}
315
+
316
+ def _merge_parameters (self ) -> None :
317
+ """
318
+ Merges the local and global parameters.
319
+ """
320
+ for name , p in self ._model_fragment .named_parameters ():
321
+ torch .lerp (
322
+ p .data , self ._local_parameters [name ], 1 - self ._fragment_update_alpha
323
+ )
324
+
296
325
@torch .profiler .record_function ("torchft::local_sgd::wait" )
297
326
def wait (self ) -> None :
298
327
"""
@@ -376,6 +405,8 @@ def perform_sync(self) -> bool:
376
405
377
406
self .wait ()
378
407
408
+ # save the parameters so they can be used for merging
409
+ self ._save_local_parameters ()
379
410
# Restore the parameters back to the previous state
380
411
self .restore_parameters ()
381
412
@@ -397,8 +428,12 @@ def perform_sync(self) -> bool:
397
428
self ._set_grads ()
398
429
self ._outer_optimizer .step ()
399
430
self .save_parameters ()
431
+ self ._merge_parameters ()
400
432
self ._outer_optimizer .zero_grad ()
401
433
434
+ # free up memory
435
+ self ._clear_local_parameters ()
436
+
402
437
return should_commit
403
438
404
439
def _average_grads (self ) -> None :
@@ -550,12 +585,6 @@ def __init__(
550
585
if fragment_update_alpha < 0 or fragment_update_alpha > 1 :
551
586
raise ValueError ("fragment_update_alpha must be between 0 and 1" )
552
587
553
- # TODO: Support `fragment_update_alpha`
554
- if fragment_update_alpha != 0.0 :
555
- raise ValueError (
556
- "Merging local parameters with global parameters is not supported yet"
557
- )
558
-
559
588
super ().__init__ ()
560
589
self ._manager = manager
561
590
0 commit comments