Commit 7aba771

fix gradient allreduce
Summary:

- fix setting `_local_tensor` of a DTensor directly
- fix bucketized allreduce so it no longer uses `parameter.grad`
- simplify some code
1 parent a86ed0d commit 7aba771


2 files changed: +38 −29 lines changed


torchft/local_sgd.py

Lines changed: 35 additions & 29 deletions
@@ -257,17 +257,35 @@ def restore_parameters(self) -> None:
                 else:
                     p.data.copy_(self.original_parameters[name], non_blocking=False)
 
+    def _save_grads(self) -> None:
+        with torch.no_grad():
+            for name, p in self._model_fragment.named_parameters():
+                if isinstance(p, DTensor):
+                    local_param = p.to_local()
+                else:
+                    local_param = p
+                pseudogradient = local_param - self.original_parameters[name].to(
+                    p.device
+                )
+                self._grads[name] = pseudogradient
+
     def _set_grads(self) -> None:
         """
         Sets the gradients of the model fragment from the allreduce result
         """
-        for name, p in self._model_fragment.named_parameters():
-            if isinstance(p, DTensor):
-                p.grad._local_tensor = self._grads[name]
-            else:
-                p.grad = self._grads[name]
-
-            del self._grads[name]
+        with torch.no_grad():
+            for name, p in self._model_fragment.named_parameters():
+                # avoid copying the gradient, it should be on the same device
+                if isinstance(p, DTensor):
+                    p.grad = DTensor.from_local(
+                        self._grads[name],
+                        p.device_mesh,
+                        p.placements,
+                        shape=p.shape,
+                        stride=p.stride(),
+                    )
+                else:
+                    p.grad = self._grads[name]
 
     @torch.profiler.record_function("torchft::local_sgd::wait")
     def wait(self) -> None:
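The `_set_grads` hunk above replaces direct assignment to `p.grad._local_tensor` with `DTensor.from_local`, which wraps the allreduced local buffer back into a DTensor carrying the parameter's mesh, placements, shape, and stride. A minimal single-rank sketch of that wrapping step, assuming a CPU gloo group and the `torch.distributed.tensor` import path of recent PyTorch releases; the setup code and names here are illustrative, not torchft code:

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Replicate

# single-process gloo group, only so a device mesh can be built for the demo
dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29501", rank=0, world_size=1)
mesh = init_device_mesh("cpu", (1,))

# a parameter held as a replicated DTensor on the one-rank mesh
param = DTensor.from_local(torch.randn(4, 4), mesh, [Replicate()])

# stand-in for the allreduced local pseudogradient
local_grad = torch.ones(4, 4)

# wrap the plain local tensor back into a DTensor matching the parameter's
# layout, rather than mutating param.grad._local_tensor in place
grad = DTensor.from_local(
    local_grad,
    param.device_mesh,
    param.placements,
    shape=param.shape,
    stride=param.stride(),
)
assert isinstance(grad, DTensor) and grad.to_local().equal(local_grad)

dist.destroy_process_group()
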
@@ -304,14 +322,9 @@ def prepare_sync(self) -> None:
         Calculate the pseugradient, average them across the manager group and starts
         allreduce on the pseudo-gradients but doesn't wait for it to finish.
         """
-        # Set the .grad field of each parameter to its pseudogradient
-        for name, p in self._model_fragment.named_parameters():
-            local_param = extract_local_tensor(p.data)
-            pseudogradient = local_param - self.original_parameters[name].to(p.device)
-            if isinstance(p, DTensor):
-                self._grads[name] = pseudogradient
-            else:
-                self._grads[name] = pseudogradient
+        self._save_grads()
+
+        assert len(self._allreduce_futures) == 0
 
         # Make sure tensors are available to `_stream`
         if self._stream is not None:
@@ -371,18 +384,12 @@ def _allreduce_per_param(self) -> None:
         """Performs allreduce on each gradient tensor separately (original method)."""
         for name, p in self._model_fragment.named_parameters():
             # Perform allreduce on the pseudogradients
-            assert p.grad is not None
-            if isinstance(p, DTensor):
-                work = self._manager.allreduce(
-                    self._grads[name], should_quantize=self.should_quantize
-                )
-            else:
-                work = self._manager.allreduce(
-                    self._grads[name], should_quantize=self.should_quantize
-                )
+            work = self._manager.allreduce(
+                self._grads[name], should_quantize=self.should_quantize
+            )
             self._allreduce_futures.append(work)
 
-    def bucketize_and_allreduce(
+    def _bucketize_and_allreduce(
         self,
         tensors: List[torch.Tensor],
         bucket_size_bytes: int,
@@ -439,10 +446,9 @@ def _allreduce_bucketized(self) -> None:
         """
         Averages gradients using bucketized allreduce with a fixed buffer.
         """
-        grads = [
-            p.grad for p in self._model_fragment.parameters() if p.grad is not None
-        ]
-        self.bucketize_and_allreduce(
+        grads = list(self._grads.values())
+        assert len(grads) > 0, "No gradients to allreduce"
+        self._bucketize_and_allreduce(
             grads,
             bucket_size_bytes=self.bucket_cap_mb,
         )
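`_allreduce_bucketized` now pulls the pseudogradients from `self._grads` rather than `parameter.grad`, so bucketing no longer depends on autograd having populated `.grad`. torchft's own `_bucketize_and_allreduce` is not reproduced here; the following standalone sketch only illustrates the general pattern of packing tensors into fixed-size flat buckets, running a collective on each flat buffer, and copying the results back. The helper name and the doubling "allreduce" callback are invented for the example, echoing the test's fake_allreduce below:

from typing import Callable, List

import torch


def bucketize_and_allreduce_sketch(
    tensors: List[torch.Tensor],
    bucket_size_bytes: int,
    allreduce_: Callable[[torch.Tensor], None],
) -> None:
    """Pack tensors into flat buckets of at most bucket_size_bytes, run the
    collective on each flat buffer, then copy the results back in place."""
    bucket: List[torch.Tensor] = []
    bucket_bytes = 0

    def flush() -> None:
        nonlocal bucket, bucket_bytes
        if not bucket:
            return
        flat = torch.cat([t.flatten() for t in bucket])
        allreduce_(flat)  # in-place collective on the flat buffer
        offset = 0
        for t in bucket:
            n = t.numel()
            t.copy_(flat[offset : offset + n].view_as(t))
            offset += n
        bucket, bucket_bytes = [], 0

    for t in tensors:
        nbytes = t.numel() * t.element_size()
        if bucket and bucket_bytes + nbytes > bucket_size_bytes:
            flush()
        bucket.append(t)
        bucket_bytes += nbytes
    flush()


# usage: a fake "allreduce" that doubles values, as in the test below
grads = [torch.ones(3), torch.ones(2, 2)]
bucketize_and_allreduce_sketch(grads, 10 * 1024 * 1024, lambda b: b.mul_(2))
assert all(g.eq(2).all() for g in grads)
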

torchft/local_sgd_test.py

Lines changed: 3 additions & 0 deletions
@@ -255,6 +255,9 @@ def fake_allreduce(
         diloco._fragments[0].bucket_cap_mb = 10 * 1024 * 1024
 
         # Run only bucketized logic
+        for name, param in model.named_parameters():
+            assert param.grad is not None
+            diloco._fragments[0]._grads[name] = param.grad
         diloco._fragments[0]._average_grads()
 
         # Expect grads to have been doubled
