@@ -213,8 +213,14 @@ def __init__(
        self.should_quantize = should_quantize

        self._grads: Dict[str, torch.Tensor] = {}
+
+        # Used to save global parameters so that they can be restored in case
+        # commit fails
        self.original_parameters: Dict[str, torch.Tensor] = {}

+        # Used to mix the local and global parameters
+        self._local_parameters: Dict[str, torch.Tensor] = {}
+
        for name, p in self._model_fragment.named_parameters():
            if isinstance(p, DTensor):
                p = extract_local_tensor(p.data)
@@ -237,6 +243,14 @@ def save_parameters(self) -> None:
                param_to_local = extract_local_tensor(p.data)
                self.original_parameters[name].copy_(param_to_local, non_blocking=True)

+    def _save_local_parameters(self) -> None:
+        """
+        Saves a copy of the model's parameters.
+        """
+        with torch.no_grad():
+            for name, p in self._model_fragment.named_parameters():
+                self._local_parameters[name] = extract_local_tensor(p.data)
+
    @torch.profiler.record_function("torchft::local_sgd::restore_parameters")
    def restore_parameters(self) -> None:
        with torch.no_grad():
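The snapshot here differs from save_parameters() above: save_parameters() reuses preallocated buffers via copy_(..., non_blocking=True), while _save_local_parameters() materializes a fresh copy on every sync and releases it once the merge is done. A standalone sketch of the two styles with a plain tensor (hypothetical names; it assumes extract_local_tensor returns a detached copy of the local shard):

    import torch

    p = torch.randn(4)  # stands in for one model parameter's local shard

    # save_parameters() style: copy into a preallocated buffer, no per-sync allocation
    preallocated = torch.empty_like(p)
    preallocated.copy_(p, non_blocking=True)

    # _save_local_parameters() style: allocate a fresh snapshot each sync;
    # dropping the reference later (_clear_local_parameters) frees the memory
    snapshot = p.clone()
    assert torch.equal(preallocated, snapshot)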
@@ -287,6 +301,21 @@ def _set_grads(self) -> None:
                else:
                    p.grad = self._grads[name]

+    def _clear_local_parameters(self) -> None:
+        """
+        Clears the saved copy of the model's parameters
+        """
+        self._local_parameters = {}
+
+    def _merge_parameters(self) -> None:
+        """
+        Merges the local and global parameters.
+        """
+        for name, p in self._model_fragment.named_parameters():
+            p.data.lerp_(
+                self._local_parameters[name], 1 - self._fragment_update_alpha
+            )
+
    @torch.profiler.record_function("torchft::local_sgd::wait")
    def wait(self) -> None:
        """
@@ -370,6 +399,8 @@ def perform_sync(self) -> bool:

        self.wait()

+        # save the parameters so they can be used for merging
+        self._save_local_parameters()
        # Restore the parameters back to the previous state
        self.restore_parameters()

@@ -391,8 +422,12 @@ def perform_sync(self) -> bool:
            self._set_grads()
            self._outer_optimizer.step()
            self.save_parameters()
+            self._merge_parameters()
        self._outer_optimizer.zero_grad()

+        # free up memory
+        self._clear_local_parameters()
+
        return should_commit

    def _average_grads(self) -> None:
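Taken together, perform_sync now snapshots the local parameters, restores the last committed global state, applies the outer optimizer step, mixes the snapshot back in, and finally drops it. A minimal end-to-end sketch with a single plain tensor (hypothetical values; the real code iterates named_parameters() and handles DTensors):

    import torch

    alpha = 0.5  # stand-in for _fragment_update_alpha
    param = torch.nn.Parameter(torch.ones(3))
    original = param.data.clone()            # save_parameters(): last committed state

    with torch.no_grad():
        param.add_(1.0)                      # inner steps drift the local copy

    local_snapshot = param.data.clone()      # _save_local_parameters()
    param.data.copy_(original)               # restore_parameters()

    # ... _set_grads() and the outer optimizer step update param.data here ...

    param.data.lerp_(local_snapshot, 1 - alpha)  # _merge_parameters()
    local_snapshot = None                        # _clear_local_parameters(): free memory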
@@ -544,12 +579,6 @@ def __init__(
        if fragment_update_alpha < 0 or fragment_update_alpha > 1:
            raise ValueError("fragment_update_alpha must be between 0 and 1")

-        # TODO: Support `fragment_update_alpha`
-        if fragment_update_alpha != 0.0:
-            raise ValueError(
-                "Merging local parameters with global parameters is not supported yet"
-            )
-
        super().__init__()
        self._manager = manager
