diff --git a/torchft/local_sgd.py b/torchft/local_sgd.py
index 38adb42..d8a944f 100644
--- a/torchft/local_sgd.py
+++ b/torchft/local_sgd.py
@@ -257,17 +257,41 @@ def restore_parameters(self) -> None:
                 else:
                     p.data.copy_(self.original_parameters[name], non_blocking=False)
 
+    def _save_grads(self) -> None:
+        """
+        Saves pseudo-gradients of the parameters
+        """
+        with torch.no_grad():
+            for name, p in self._model_fragment.named_parameters():
+                if isinstance(p, DTensor):
+                    local_param = p.to_local()
+                else:
+                    local_param = p
+                pseudogradient = local_param - self.original_parameters[name].to(
+                    p.device
+                )
+                self._grads[name] = pseudogradient
+
     def _set_grads(self) -> None:
         """
         Sets the gradients of the model fragment from the allreduce result
         """
-        for name, p in self._model_fragment.named_parameters():
-            if isinstance(p, DTensor):
-                p.grad._local_tensor = self._grads[name]
-            else:
-                p.grad = self._grads[name]
+        with torch.no_grad():
+            for name, p in self._model_fragment.named_parameters():
+                # avoid copying the gradient, it should be on the same device
+                if isinstance(p, DTensor):
+                    p.grad = DTensor.from_local(
+                        self._grads[name],
+                        p.device_mesh,
+                        p.placements,
+                        shape=p.shape,
+                        stride=p.stride(),
+                    )
+                else:
+                    p.grad = self._grads[name]
 
-            del self._grads[name]
+                # No longer needed
+                del self._grads[name]
 
     @torch.profiler.record_function("torchft::local_sgd::wait")
     def wait(self) -> None:
@@ -304,14 +328,9 @@ def prepare_sync(self) -> None:
         Calculate the pseugradient, average them across the manager group and starts
         allreduce on the pseudo-gradients but doesn't wait for it to finish.
         """
-        # Set the .grad field of each parameter to its pseudogradient
-        for name, p in self._model_fragment.named_parameters():
-            local_param = extract_local_tensor(p.data)
-            pseudogradient = local_param - self.original_parameters[name].to(p.device)
-            if isinstance(p, DTensor):
-                self._grads[name] = pseudogradient
-            else:
-                self._grads[name] = pseudogradient
+        self._save_grads()
+
+        assert len(self._allreduce_futures) == 0
 
         # Make sure tensors are available to `_stream`
         if self._stream is not None:
@@ -371,18 +390,12 @@ def _allreduce_per_param(self) -> None:
         """Performs allreduce on each gradient tensor separately (original method)."""
         for name, p in self._model_fragment.named_parameters():
             # Perform allreduce on the pseudogradients
-            assert p.grad is not None
-            if isinstance(p, DTensor):
-                work = self._manager.allreduce(
-                    self._grads[name], should_quantize=self.should_quantize
-                )
-            else:
-                work = self._manager.allreduce(
-                    self._grads[name], should_quantize=self.should_quantize
-                )
+            work = self._manager.allreduce(
+                self._grads[name], should_quantize=self.should_quantize
+            )
             self._allreduce_futures.append(work)
 
-    def bucketize_and_allreduce(
+    def _bucketize_and_allreduce(
         self,
         tensors: List[torch.Tensor],
         bucket_size_bytes: int,
@@ -439,10 +452,9 @@ def _allreduce_bucketized(self) -> None:
         """
         Averages gradients using bucketized allreduce with a fixed buffer.
         """
-        grads = [
-            p.grad for p in self._model_fragment.parameters() if p.grad is not None
-        ]
-        self.bucketize_and_allreduce(
+        grads = list(self._grads.values())
+        assert len(grads) > 0, "No gradients to allreduce"
+        self._bucketize_and_allreduce(
             grads,
             bucket_size_bytes=self.bucket_cap_mb,
         )
diff --git a/torchft/local_sgd_test.py b/torchft/local_sgd_test.py
index 035d955..26c0208 100644
--- a/torchft/local_sgd_test.py
+++ b/torchft/local_sgd_test.py
@@ -52,6 +52,16 @@ def _copy_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Ten
     return {name: value.clone().detach() for name, value in state_dict.items()}
 
 
+class TinyModel(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.w1 = nn.Parameter(torch.tensor([1.0, 2.0]))
+        self.w2 = nn.Parameter(torch.tensor([3.0, 4.0, 5.0]))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return x @ self.w1.unsqueeze(0).T + self.w2.sum()
+
+
 class LocalSGDTest(TestCase):
     def test_local_sgd_healthy(self) -> None:
         model = SimpleModel()
@@ -216,24 +226,10 @@ def test_diloco_allreduce_call_efficiency(
         self.assertEqual(int(allreduce_calls), int(param_count))
 
     def test_bucketization_correctness(self) -> None:
-        class TinyModel(nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.w1 = nn.Parameter(torch.tensor([1.0, 2.0]))
-                self.w2 = nn.Parameter(torch.tensor([3.0, 4.0, 5.0]))
-
-            def forward(self, x):
-                return x @ self.w1.unsqueeze(0).T + self.w2.sum()
-
         model = TinyModel()
         inner_opt = torch.optim.SGD(model.parameters(), lr=0.1)
         outer_opt = torch.optim.SGD(model.parameters(), lr=0.1)
 
-        # Manually assign fake gradients
-        grads = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0, 5.0])]
-        for p, g in zip(model.parameters(), grads):
-            p.grad = g.clone()
-
         manager = create_autospec(Manager)
         manager._use_async_quorum = False
         manager.should_commit.return_value = True
@@ -254,10 +250,71 @@ def fake_allreduce(
         )
         diloco._fragments[0].bucket_cap_mb = 10 * 1024 * 1024
 
+        # Manually assign fake gradients
+        grads = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0, 5.0])]
+        for g, (name, param) in zip(grads, model.named_parameters()):
+            diloco._fragments[0]._grads[name] = g.clone()
+
         # Run only bucketized logic
         diloco._fragments[0]._average_grads()
 
+        # The parameter gradients should not be set
+        for param in model.parameters():
+            self.assertEqual(param.grad, None)
+
+        diloco._fragments[0]._set_grads()
+
         # Expect grads to have been doubled
         expected_grads = [g * 2 for g in grads]
         for param, expected in zip(model.parameters(), expected_grads):
             torch.testing.assert_close(param.grad, expected, rtol=1e-5, atol=1e-8)
+
+    def test_gradient_correctness(self) -> None:
+        model = TinyModel()
+        inner_opt = torch.optim.SGD(model.parameters(), lr=0.1)
+        outer_opt = torch.optim.SGD(model.parameters(), lr=0.1)
+
+        manager = create_autospec(Manager)
+        manager._use_async_quorum = False
+        manager.should_commit.return_value = True
+
+        # Define fake allreduce: multiplies buffer by 2
+        def fake_allreduce(
+            tensor: Tensor, should_quantize: bool
+        ) -> torch.futures.Future[Tensor]:
+            tensor.mul_(2)
+            fut = torch.futures.Future()  # pyre-fixme[29]: not a function
+            fut.set_result(tensor)
+            return fut
+
+        manager.allreduce.side_effect = fake_allreduce
+
+        diloco = DiLoCo(manager, [model], inner_opt, outer_opt, sync_every=2)
+
+        # save original parameters
+        diloco._fragments[0].save_parameters()
+
+        # change the model's parameters
+        for p in model.parameters():
+            p.data.add_(2)
+
+        # calculate and set the gradients
+        diloco._fragments[0]._save_grads()
+
+        # calculate
+        diloco._fragments[0]._average_grads()
+
+        # The parameter gradients should not be set
+        for param in model.parameters():
+            self.assertEqual(param.grad, None)
+
+        diloco._fragments[0]._set_grads()
+
+        # we added 2 to the parameters, then multiplied the gradients by 2
+        # so we should expect the model's gradient to be 4
+        expected_grad = 4
+        for param in model.parameters():
+            assert param.grad is not None
+            t = torch.empty_like(param.grad)
+            t.fill_(expected_grad)
+            torch.testing.assert_close(param.grad, t)
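
Note: the save -> average -> set flow that this change isolates into _save_grads() / _average_grads() / _set_grads() can be sketched standalone roughly as below. This is a minimal single-process sketch, not the torchft implementation: allreduce_avg is a hypothetical stand-in for Manager.allreduce, and DTensor handling, streams, and restore_parameters() are omitted.

    import torch
    import torch.nn as nn

    def allreduce_avg(t: torch.Tensor) -> torch.Tensor:
        # hypothetical stand-in for Manager.allreduce; with one process the average is a no-op
        return t

    model = nn.Linear(4, 2)
    # analogous to save_parameters(): snapshot parameters before the inner steps
    original = {n: p.detach().clone() for n, p in model.named_parameters()}

    # ... inner optimizer steps would mutate the parameters here ...
    with torch.no_grad():
        for p in model.parameters():
            p.add_(0.1)

    # analogous to _save_grads(): pseudo-gradient = current parameter - saved original
    grads = {n: p.detach() - original[n] for n, p in model.named_parameters()}

    # analogous to _average_grads(): allreduce the stored pseudo-gradients, not p.grad
    for n in grads:
        grads[n] = allreduce_avg(grads[n])

    # analogous to _set_grads(): only now materialize the averaged result onto p.grad
    with torch.no_grad():
        for n, p in model.named_parameters():
            p.grad = grads[n]

    # the outer optimizer then consumes p.grad as usual
    torch.optim.SGD(model.parameters(), lr=0.1).step()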