@@ -257,17 +257,35 @@ def restore_parameters(self) -> None:
257
257
else :
258
258
p .data .copy_ (self .original_parameters [name ], non_blocking = False )
259
259
260
+ def _save_grads (self ) -> None :
261
+ with torch .no_grad ():
262
+ for name , p in self ._model_fragment .named_parameters ():
263
+ if isinstance (p , DTensor ):
264
+ local_param = p .to_local ()
265
+ else :
266
+ local_param = p
267
+ pseudogradient = local_param - self .original_parameters [name ].to (
268
+ p .device
269
+ )
270
+ self ._grads [name ] = pseudogradient
271
+
260
272
def _set_grads (self ) -> None :
261
273
"""
262
274
Sets the gradients of the model fragment from the allreduce result
263
275
"""
264
- for name , p in self ._model_fragment .named_parameters ():
265
- if isinstance (p , DTensor ):
266
- p .grad ._local_tensor = self ._grads [name ]
267
- else :
268
- p .grad = self ._grads [name ]
269
-
270
- del self ._grads [name ]
276
+ with torch .no_grad ():
277
+ for name , p in self ._model_fragment .named_parameters ():
278
+ # avoid copying the gradient, it should be on the same device
279
+ if isinstance (p , DTensor ):
280
+ p .grad = DTensor .from_local (
281
+ self ._grads [name ],
282
+ p .device_mesh ,
283
+ p .placements ,
284
+ shape = p .shape ,
285
+ stride = p .stride (),
286
+ )
287
+ else :
288
+ p .grad = self ._grads [name ]
271
289
272
290
@torch .profiler .record_function ("torchft::local_sgd::wait" )
273
291
def wait (self ) -> None :
@@ -304,14 +322,9 @@ def prepare_sync(self) -> None:
304
322
Calculate the pseugradient, average them across the manager group and starts
305
323
allreduce on the pseudo-gradients but doesn't wait for it to finish.
306
324
"""
307
- # Set the .grad field of each parameter to its pseudogradient
308
- for name , p in self ._model_fragment .named_parameters ():
309
- local_param = extract_local_tensor (p .data )
310
- pseudogradient = local_param - self .original_parameters [name ].to (p .device )
311
- if isinstance (p , DTensor ):
312
- self ._grads [name ] = pseudogradient
313
- else :
314
- self ._grads [name ] = pseudogradient
325
+ self ._save_grads ()
326
+
327
+ assert len (self ._allreduce_futures ) == 0
315
328
316
329
# Make sure tensors are available to `_stream`
317
330
if self ._stream is not None :
@@ -371,18 +384,12 @@ def _allreduce_per_param(self) -> None:
371
384
"""Performs allreduce on each gradient tensor separately (original method)."""
372
385
for name , p in self ._model_fragment .named_parameters ():
373
386
# Perform allreduce on the pseudogradients
374
- assert p .grad is not None
375
- if isinstance (p , DTensor ):
376
- work = self ._manager .allreduce (
377
- self ._grads [name ], should_quantize = self .should_quantize
378
- )
379
- else :
380
- work = self ._manager .allreduce (
381
- self ._grads [name ], should_quantize = self .should_quantize
382
- )
387
+ work = self ._manager .allreduce (
388
+ self ._grads [name ], should_quantize = self .should_quantize
389
+ )
383
390
self ._allreduce_futures .append (work )
384
391
385
- def bucketize_and_allreduce (
392
+ def _bucketize_and_allreduce (
386
393
self ,
387
394
tensors : List [torch .Tensor ],
388
395
bucket_size_bytes : int ,
@@ -439,10 +446,9 @@ def _allreduce_bucketized(self) -> None:
439
446
"""
440
447
Averages gradients using bucketized allreduce with a fixed buffer.
441
448
"""
442
- grads = [
443
- p .grad for p in self ._model_fragment .parameters () if p .grad is not None
444
- ]
445
- self .bucketize_and_allreduce (
449
+ grads = list (self ._grads .values ())
450
+ assert len (grads ) > 0 , "No gradients to allreduce"
451
+ self ._bucketize_and_allreduce (
446
452
grads ,
447
453
bucket_size_bytes = self .bucket_cap_mb ,
448
454
)
0 commit comments