
Commit e9c0fd6

fix gradient allreduce
Summary:
- fix setting `_local_tensor` of a DTensor directly
- fix bucketized allreduce (`_allreduce_bucketized`) to not use `parameter.grad`
- simplify some code
1 parent 078b6c0 commit e9c0fd6
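
Editor's note on the first summary bullet: assigning to a DTensor's private `_local_tensor` field bypasses the wrapper's shape/stride/placement bookkeeping, so the commit rebuilds the gradient with the public `DTensor.from_local` constructor instead (see the diff below). The following is a minimal, single-process sketch of that pattern, not part of the commit; the 1-rank gloo group, the CPU mesh, and the `averaged` tensor are illustrative stand-ins.

```python
# Minimal single-process sketch (not from the commit) of the DTensor grad pattern
# the diff switches to. The gloo group, mesh, and `averaged` tensor are stand-ins.
import os

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Replicate, distribute_tensor  # torch.distributed._tensor on older PyTorch

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)
mesh = init_device_mesh("cpu", (1,))

param = distribute_tensor(torch.randn(4, 4), mesh, [Replicate()]).requires_grad_()
averaged = torch.zeros(4, 4)  # stand-in for the allreduced local pseudogradient

# Fragile (what the commit removes): poking the private `_local_tensor` field of an
# existing grad skips DTensor's metadata checks.
# param.grad._local_tensor = averaged

# What the commit does instead: wrap the local result in a fresh DTensor that reuses
# the parameter's mesh and placements, then assign it as the grad.
param.grad = DTensor.from_local(
    averaged,
    param.device_mesh,
    param.placements,
    shape=param.shape,
    stride=param.stride(),
)
print(type(param.grad).__name__, param.grad.placements)

dist.destroy_process_group()
```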

1 file changed: +30 / -29 lines


torchft/local_sgd.py

Lines changed: 30 additions & 29 deletions
```diff
@@ -257,17 +257,30 @@ def restore_parameters(self) -> None:
                 else:
                     p.data.copy_(self.original_parameters[name], non_blocking=False)
 
+    def _save_grads(self) -> None:
+        with torch.no_grad():
+            for name, p in self._model_fragment.named_parameters():
+                local_param = extract_local_tensor(p.data)
+                pseudogradient = local_param - self.original_parameters[name].to(p.device)
+                self._grads[name] = pseudogradient
+
     def _set_grads(self) -> None:
         """
         Sets the gradients of the model fragment from the allreduce result
         """
-        for name, p in self._model_fragment.named_parameters():
-            if isinstance(p, DTensor):
-                p.grad._local_tensor = self._grads[name]
-            else:
-                p.grad = self._grads[name]
-
-            del self._grads[name]
+        with torch.no_grad():
+            for name, p in self._model_fragment.named_parameters():
+                # avoid copying the gradient, it should be on the same device
+                if isinstance(p, DTensor):
+                    p.grad = DTensor.from_local(
+                        self._grads[name],
+                        p.device_mesh,
+                        p.placements,
+                        shape=p.shape,
+                        stride=p.stride(),
+                    )
+                else:
+                    p.grad = self._grads[name]
 
     @torch.profiler.record_function("torchft::local_sgd::wait")
     def wait(self) -> None:
@@ -304,14 +317,9 @@ def prepare_sync(self) -> None:
         Calculate the pseugradient, average them across the manager group and starts
         allreduce on the pseudo-gradients but doesn't wait for it to finish.
         """
-        # Set the .grad field of each parameter to its pseudogradient
-        for name, p in self._model_fragment.named_parameters():
-            local_param = extract_local_tensor(p.data)
-            pseudogradient = local_param - self.original_parameters[name].to(p.device)
-            if isinstance(p, DTensor):
-                self._grads[name] = pseudogradient
-            else:
-                self._grads[name] = pseudogradient
+        self._save_grads()
+
+        assert len(self._allreduce_futures) == 0
 
         # Make sure tensors are available to `_stream`
         if self._stream is not None:
@@ -371,18 +379,12 @@ def _allreduce_per_param(self) -> None:
         """Performs allreduce on each gradient tensor separately (original method)."""
         for name, p in self._model_fragment.named_parameters():
             # Perform allreduce on the pseudogradients
-            assert p.grad is not None
-            if isinstance(p, DTensor):
-                work = self._manager.allreduce(
-                    self._grads[name], should_quantize=self.should_quantize
-                )
-            else:
-                work = self._manager.allreduce(
-                    self._grads[name], should_quantize=self.should_quantize
-                )
+            work = self._manager.allreduce(
+                self._grads[name], should_quantize=self.should_quantize
+            )
             self._allreduce_futures.append(work)
 
-    def bucketize_and_allreduce(
+    def _bucketize_and_allreduce(
         self,
         tensors: List[torch.Tensor],
         bucket_size_bytes: int,
@@ -439,10 +441,9 @@ def _allreduce_bucketized(self) -> None:
         """
         Averages gradients using bucketized allreduce with a fixed buffer.
         """
-        grads = [
-            p.grad for p in self._model_fragment.parameters() if p.grad is not None
-        ]
-        self.bucketize_and_allreduce(
+        grads = list(self._grads.values())
+        assert len(grads) > 0, "No gradients to allreduce"
+        self._bucketize_and_allreduce(
             grads,
             bucket_size_bytes=self.bucket_cap_mb,
         )
```
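
Editor's note on the second summary bullet: after this change, `_allreduce_bucketized` takes its inputs from the `self._grads` dict filled by `_save_grads` (`grads = list(self._grads.values())`) rather than from `parameter.grad`, so bucketing operates on plain local tensors and never touches DTensor internals. The sketch below is a generic, synchronous illustration of bucketize-then-allreduce over such a dict, not torchft's `_bucketize_and_allreduce` (which uses a fixed buffer, asynchronous futures, and optional quantization); `allreduce_fn` and the greedy packing loop are assumptions for illustration.

```python
# Generic sketch of bucketized allreduce over a dict of local gradient tensors.
# `allreduce_fn` is a stand-in for the manager's collective; the packing policy is
# illustrative, not torchft's.
from typing import Callable, Dict, List

import torch


def bucketize_and_allreduce_sketch(
    grads: Dict[str, torch.Tensor],
    bucket_size_bytes: int,
    allreduce_fn: Callable[[torch.Tensor], None],
) -> None:
    tensors: List[torch.Tensor] = list(grads.values())
    assert len(tensors) > 0, "No gradients to allreduce"

    i = 0
    while i < len(tensors):
        # Greedily pack consecutive tensors into one flat bucket.
        bucket: List[torch.Tensor] = []
        nbytes = 0
        while i < len(tensors):
            t_bytes = tensors[i].numel() * tensors[i].element_size()
            if bucket and nbytes + t_bytes > bucket_size_bytes:
                break
            bucket.append(tensors[i])
            nbytes += t_bytes
            i += 1

        flat = torch.cat([t.reshape(-1) for t in bucket])
        allreduce_fn(flat)  # one collective per bucket instead of one per tensor

        # Scatter the averaged values back into the original gradient tensors.
        offset = 0
        for t in bucket:
            t.copy_(flat[offset : offset + t.numel()].view_as(t))
            offset += t.numel()


# Toy usage: a no-op "allreduce" on a single process.
grads = {"w": torch.randn(4, 4), "b": torch.randn(4)}
bucketize_and_allreduce_sketch(grads, 32 * 1024 * 1024, allreduce_fn=lambda t: None)
```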
