@@ -257,17 +257,30 @@ def restore_parameters(self) -> None:
                 else:
                     p.data.copy_(self.original_parameters[name], non_blocking=False)

+    def _save_grads(self) -> None:
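+        # Compute the pseudogradient of every parameter (current local value minus the saved original) and stash it in self._grads.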
+        with torch.no_grad():
+            for name, p in self._model_fragment.named_parameters():
+                local_param = extract_local_tensor(p.data)
+                pseudogradient = local_param - self.original_parameters[name].to(p.device)
+                self._grads[name] = pseudogradient
+
     def _set_grads(self) -> None:
         """
         Sets the gradients of the model fragment from the allreduce result
         """
-        for name, p in self._model_fragment.named_parameters():
-            if isinstance(p, DTensor):
-                p.grad._local_tensor = self._grads[name]
-            else:
-                p.grad = self._grads[name]
-
-            del self._grads[name]
+        with torch.no_grad():
+            for name, p in self._model_fragment.named_parameters():
+                # avoid copying the gradient, it should be on the same device
+                if isinstance(p, DTensor):
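+                    # wrap the local allreduce result back into a DTensor matching the parameter's mesh and placements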
+                    p.grad = DTensor.from_local(
+                        self._grads[name],
+                        p.device_mesh,
+                        p.placements,
+                        shape=p.shape,
+                        stride=p.stride(),
+                    )
+                else:
+                    p.grad = self._grads[name]

     @torch.profiler.record_function("torchft::local_sgd::wait")
     def wait(self) -> None:
@@ -304,14 +317,9 @@ def prepare_sync(self) -> None:
         Calculate the pseugradient, average them across the manager group and starts
         allreduce on the pseudo-gradients but doesn't wait for it to finish.
         """
-        # Set the .grad field of each parameter to its pseudogradient
-        for name, p in self._model_fragment.named_parameters():
-            local_param = extract_local_tensor(p.data)
-            pseudogradient = local_param - self.original_parameters[name].to(p.device)
-            if isinstance(p, DTensor):
-                self._grads[name] = pseudogradient
-            else:
-                self._grads[name] = pseudogradient
+        self._save_grads()
+
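+        # no allreduce futures should still be pending before new ones are started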
+        assert len(self._allreduce_futures) == 0

         # Make sure tensors are available to `_stream`
         if self._stream is not None:
@@ -371,18 +379,12 @@ def _allreduce_per_param(self) -> None:
         """Performs allreduce on each gradient tensor separately (original method)."""
         for name, p in self._model_fragment.named_parameters():
             # Perform allreduce on the pseudogradients
-            assert p.grad is not None
-            if isinstance(p, DTensor):
-                work = self._manager.allreduce(
-                    self._grads[name], should_quantize=self.should_quantize
-                )
-            else:
-                work = self._manager.allreduce(
-                    self._grads[name], should_quantize=self.should_quantize
-                )
+            work = self._manager.allreduce(
+                self._grads[name], should_quantize=self.should_quantize
+            )
             self._allreduce_futures.append(work)

-    def bucketize_and_allreduce(
+    def _bucketize_and_allreduce(
         self,
         tensors: List[torch.Tensor],
         bucket_size_bytes: int,
@@ -439,10 +441,9 @@ def _allreduce_bucketized(self) -> None:
         """
         Averages gradients using bucketized allreduce with a fixed buffer.
         """
-        grads = [
-            p.grad for p in self._model_fragment.parameters() if p.grad is not None
-        ]
-        self.bucketize_and_allreduce(
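+        # read the pseudogradients saved by _save_grads rather than p.grad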
+        grads = list(self._grads.values())
+        assert len(grads) > 0, "No gradients to allreduce"
+        self._bucketize_and_allreduce(
             grads,
             bucket_size_bytes=self.bucket_cap_mb,
         )
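For context, the pattern the new _save_grads helper centralizes is the usual local-SGD pseudogradient flow: snapshot the parameters, let inner steps drift them, take the difference, average it across workers, and hand it to an outer optimizer via .grad. Below is a minimal single-process sketch of that flow, assuming a toy model and plain tensors in place of DTensors and the torchft manager; the names are illustrative stand-ins, not torchft APIs, and the cross-worker averaging is only marked by a comment.

import torch

model = torch.nn.Linear(4, 4)
# snapshot, playing the role of self.original_parameters
original = {name: p.detach().clone() for name, p in model.named_parameters()}
inner_opt = torch.optim.SGD(model.parameters(), lr=0.1)

# a few local/inner steps that drift the parameters away from the snapshot
for _ in range(3):
    inner_opt.zero_grad()
    model(torch.randn(2, 4)).sum().backward()
    inner_opt.step()

grads = {}
with torch.no_grad():
    for name, p in model.named_parameters():
        grads[name] = p.data - original[name]  # pseudogradient, mirroring _save_grads

# ... here torchft would average `grads` across the replica group via manager.allreduce ...

with torch.no_grad():
    for name, p in model.named_parameters():
        p.data.copy_(original[name])  # restore_parameters: rewind to the snapshot
        p.grad = grads[name]          # _set_grads: expose the averaged pseudogradient to an outer optimizer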