import logging
import math
import threading
+ from contextlib import nullcontext
from types import TracebackType
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Type
@@ -197,9 +198,10 @@ def __init__(
        self._outer_optimizer = outer_optimizer

        # Stores pending all reduce
-         self._allreduce_futures: list[
-             torch.futures.Future[None] | torch.futures.Future[torch.Tensor]
-         ] = []
+         self._allreduce_futures: list[torch.futures.Future[torch.Tensor]] = []
+         self._stream: Optional[torch.cuda.Stream] = (
+             torch.cuda.Stream() if torch.cuda.is_available() else None
+         )

        if bucket_cap_mb is not None:
            self.bucket_cap_mb = int(bucket_cap_mb * 1024 * 1024)
@@ -222,13 +224,15 @@ def __init__(
                t = t.pin_memory()
            self.original_parameters[name] = t

+     @torch.profiler.record_function("torchft::local_sgd::save_parameters")
    def save_parameters(self) -> None:
        with torch.no_grad():
            # TODO: consider running copy on a separate stream
            for name, p in self._model_fragment.named_parameters():
                param_to_local = extract_local_tensor(p.data)
                self.original_parameters[name].copy_(param_to_local, non_blocking=True)

+     @torch.profiler.record_function("torchft::local_sgd::restore_parameters")
    def restore_parameters(self) -> None:
        with torch.no_grad():
            # TODO: consider running copy on a separate stream
@@ -248,6 +252,7 @@ def restore_parameters(self) -> None:
                else:
                    p.data.copy_(self.original_parameters[name], non_blocking=False)

+     @torch.profiler.record_function("torchft::local_sgd::wait")
    def wait(self) -> None:
        """
        Waits for the previously scheduled allreduce to finish
@@ -256,6 +261,9 @@ def wait(self) -> None:
        for work in self._allreduce_futures:
            work.wait()

+         if self._stream is not None:
+             self._stream.synchronize()
+
        self._allreduce_futures = []

    def should_prepare_fragment(self, step: int) -> bool:
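To make the new stream bookkeeping concrete, here is a minimal, self-contained sketch (not torchft's implementation; the class and attribute names are illustrative) of the pattern `wait()` follows after this change: drain the pending allreduce futures, then synchronize the optional side stream before touching the parameters.

```python
from typing import List, Optional

import torch


class PendingAllreduces:
    """Illustrative stand-in for the fragment's future/stream bookkeeping."""

    def __init__(self) -> None:
        self.futures: List[torch.futures.Future] = []
        # Only allocate a side stream when CUDA is actually available.
        self.stream: Optional[torch.cuda.Stream] = (
            torch.cuda.Stream() if torch.cuda.is_available() else None
        )

    def wait(self) -> None:
        # Block until every scheduled allreduce has produced a result.
        for fut in self.futures:
            fut.wait()
        # Work queued on the side stream (e.g. pseudogradient math) must
        # also finish before the parameters are read or updated.
        if self.stream is not None:
            self.stream.synchronize()
        self.futures = []
```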
@@ -272,22 +280,31 @@ def should_sync_fragment(self, step: int) -> bool:
        step_to_sync = step - self._fragment_sync_offset - self._fragment_sync_delay
        return step_to_sync % self._sync_every == 0

+     @torch.profiler.record_function("torchft::local_sgd::prepare_sync")
    def prepare_sync(self) -> None:
        """
        Calculates the pseudogradients, averages them across the manager group, and
        starts an allreduce on the pseudogradients without waiting for it to finish.
        """
-         # Set the .grad field of each parameter to its pseudogradient
-         for name, p in self._model_fragment.named_parameters():
-             local_param = extract_local_tensor(p.data)
-             pseudogradient = local_param - self.original_parameters[name].to(p.device)
-             if isinstance(p, DTensor):
-                 p.grad._local_tensor = pseudogradient
-             else:
-                 p.grad = pseudogradient
+         with (
+             torch.cuda.stream(self._stream)
+             if self._stream is not None
+             else nullcontext()
+         ):
+             # Set the .grad field of each parameter to its pseudogradient
+             for name, p in self._model_fragment.named_parameters():
+                 local_param = extract_local_tensor(p.data)
+                 pseudogradient = local_param - self.original_parameters[name].to(
+                     p.device
+                 )
+                 if isinstance(p, DTensor):
+                     p.grad._local_tensor = pseudogradient
+                 else:
+                     p.grad = pseudogradient

-         self._average_grads()
+             self._average_grads()

+     @torch.profiler.record_function("torchft::local_sgd::perform_sync")
    def perform_sync(self) -> bool:
        """
        Overrides the sync method to wait for the scheduled allreduce to finish and
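The new `prepare_sync()` body boils down to the pattern sketched below: compute the pseudogradients under a side CUDA stream when one is available, otherwise fall back to `nullcontext`. The sketch is illustrative only; it uses plain tensors (no DTensor, no manager allreduce), and `compute_pseudogradients` and its arguments are hypothetical names, not torchft's API.

```python
from contextlib import nullcontext
from typing import Dict, Optional

import torch


def compute_pseudogradients(
    model: torch.nn.Module,
    snapshot: Dict[str, torch.Tensor],
    stream: Optional[torch.cuda.Stream],
) -> None:
    # Run on the side stream if we have one, otherwise a no-op context.
    ctx = torch.cuda.stream(stream) if stream is not None else nullcontext()
    with ctx:
        for name, p in model.named_parameters():
            # Pseudogradient = current local weights minus the snapshot taken
            # before the inner steps; stored in .grad so an outer optimizer
            # can apply it later.
            p.grad = p.data - snapshot[name].to(p.device)
        # The real code would launch an (async) allreduce of the .grad
        # tensors here, on the same stream.


# Example usage: snapshot, train locally (omitted), then derive pseudogradients.
model = torch.nn.Linear(4, 4)
snapshot = {n: p.detach().clone() for n, p in model.named_parameters()}
stream = torch.cuda.Stream() if torch.cuda.is_available() else None
compute_pseudogradients(model, snapshot, stream)
```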
@@ -467,16 +484,6 @@ def __init__(
        if fragment_update_alpha < 0 or fragment_update_alpha > 1:
            raise ValueError("fragment_update_alpha must be between 0 and 1")

-         # TODO: Support multiple fragments
-         # This requires changing the manager to support `should_commit` for each
-         # fragment separately.
-         if len(model_fragments) != 1:
-             raise ValueError("Multiple fragments are not supported yet")
-
-         # TODO: Support `fragment_sync_delay`
-         if fragment_sync_delay != 0:
-             raise ValueError("Fragment synchronization delay is not supported yet")
-
        # TODO: Support `fragment_update_alpha`
        if fragment_update_alpha != 0.0:
            raise ValueError(
@@ -522,6 +529,8 @@ def __init__(
                use_bucketization,
                bucket_cap_mb,
                should_quantize,
+                 fragment_sync_delay,
+                 fragment_update_alpha,
            )
            for i, model_fragment in enumerate(model_fragments)
        ]
@@ -606,16 +615,20 @@ def _step_post_hook(
        step = self._local_step

        # Start sending fragments
-         for fragment in self._fragments:
+         for i, fragment in enumerate(self._fragments):
            if not fragment.should_prepare_fragment(step):
                continue

+             logger.debug(f"preparing fragment {i} at step {step}")
+
            fragment.prepare_sync()

-         for fragment in self._fragments:
+         for i, fragment in enumerate(self._fragments):
            if not fragment.should_sync_fragment(step):
                continue

+             logger.debug(f"syncing fragment {i} at step {step}")
+
            if not fragment.perform_sync():
                # Cancel all the previously scheduled allreduce by simply
                # waiting for them. They should have failed but lets be
@@ -655,3 +668,17 @@ def _step_post_hook(
        # training data by looping here. Otherwise that training data goes to
        # waste after recovery
        self._quorum_loop()
+
+         # We need to make sure `_local_step` is still
+         # the same across all replicas if `quorum_id` changed.
+         #
+         # We can't guarantee that a majority of replicas in this new quorum
+         # has the latest `max_step`.
+         #
+         # TODO: This guarantee is currently lacking
+         # in torchft unless `shrink_only` is set.
+         #
+         # After the quorum though, everyone will have the same
+         # `local_step` because replicas with the chosen
+         # `max_step` will have the same `local_step`. That is
+         # because we don't take additional steps after commit.
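Taken together, the `_step_post_hook` hunks above implement a two-phase loop: every fragment that is due first starts its (async) sync, and only then do we block on the ones whose results must land at this step. The sketch below summarizes that control flow; `Fragment` is a stand-in protocol built from the method names in this file, not torchft's actual class, and `step_post_hook` is a hypothetical free function.

```python
import logging
from typing import Protocol, Sequence

logger = logging.getLogger(__name__)


class Fragment(Protocol):
    def should_prepare_fragment(self, step: int) -> bool: ...
    def prepare_sync(self) -> None: ...
    def should_sync_fragment(self, step: int) -> bool: ...
    def perform_sync(self) -> bool: ...


def step_post_hook(fragments: Sequence[Fragment], step: int) -> None:
    # Phase 1: start the (async) allreduce for every fragment that is due.
    for i, fragment in enumerate(fragments):
        if fragment.should_prepare_fragment(step):
            logger.debug(f"preparing fragment {i} at step {step}")
            fragment.prepare_sync()

    # Phase 2: wait for and apply the syncs that must land at this step.
    for i, fragment in enumerate(fragments):
        if fragment.should_sync_fragment(step):
            logger.debug(f"syncing fragment {i} at step {step}")
            if not fragment.perform_sync():
                # A failed commit means the pending results are discarded;
                # recovery is handled by the caller.
                return
```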