From 21e9c2399f4a8b6550a6dd2e7051e2c3f4e5dc03 Mon Sep 17 00:00:00 2001
From: Tushar Jain
Date: Mon, 30 Jun 2025 08:26:05 -0700
Subject: [PATCH] enable merging parameters for diloco

Summary:
- merge the local and global parameters of the model after synchronization
- add the "alpha" parameter to the integration tests

Test Plan:
```
pytest -vs ./torchft/local_sgd_integ_test.py
```
---
 torchft/local_sgd.py            | 39 ++++++++++++++++++++++++++++-----
 torchft/local_sgd_integ_test.py | 11 ++++++----
 2 files changed, 40 insertions(+), 10 deletions(-)

diff --git a/torchft/local_sgd.py b/torchft/local_sgd.py
index 1e409a7..d6bd78a 100644
--- a/torchft/local_sgd.py
+++ b/torchft/local_sgd.py
@@ -213,8 +213,14 @@ def __init__(
         self.should_quantize = should_quantize
 
         self._grads: Dict[str, torch.Tensor] = {}
+
+        # Used to save global parameters so that they can be restored in case
+        # commit fails
         self.original_parameters: Dict[str, torch.Tensor] = {}
 
+        # Used to mix the local and global parameters
+        self._local_parameters: Dict[str, torch.Tensor] = {}
+
         for name, p in self._model_fragment.named_parameters():
             if isinstance(p, DTensor):
                 p = extract_local_tensor(p.data)
@@ -237,6 +243,14 @@ def save_parameters(self) -> None:
                 param_to_local = extract_local_tensor(p.data)
                 self.original_parameters[name].copy_(param_to_local, non_blocking=True)
 
+    def _save_local_parameters(self) -> None:
+        """
+        Saves a copy of the model's local parameters before synchronization.
+        """
+        with torch.no_grad():
+            for name, p in self._model_fragment.named_parameters():
+                self._local_parameters[name] = extract_local_tensor(p.data)
+
     @torch.profiler.record_function("torchft::local_sgd::restore_parameters")
     def restore_parameters(self) -> None:
         with torch.no_grad():
@@ -293,6 +307,19 @@ def _set_grads(self) -> None:
                 # No longer needed
                 del self._grads[name]
 
+    def _clear_local_parameters(self) -> None:
+        """
+        Clears the saved copy of the model's local parameters.
+        """
+        self._local_parameters = {}
+
+    def _merge_parameters(self) -> None:
+        """
+        Merges the local parameters into the synchronized global parameters.
+ """ + for name, p in self._model_fragment.named_parameters(): + p.data.lerp(self._local_parameters[name], 1 - self._fragment_update_alpha) + @torch.profiler.record_function("torchft::local_sgd::wait") def wait(self) -> None: """ @@ -382,6 +409,8 @@ def perform_sync(self) -> bool: self.wait() + # save the parameters so they can be used for merging + self._save_local_parameters() # Restore the parameters back to the previous state self.restore_parameters() @@ -404,8 +433,12 @@ def perform_sync(self) -> bool: self._set_grads() self._outer_optimizer.step() self.save_parameters() + self._merge_parameters() self._outer_optimizer.zero_grad() + # free up memory + self._clear_local_parameters() + return should_commit def _average_grads(self) -> None: @@ -557,12 +590,6 @@ def __init__( if fragment_update_alpha < 0 or fragment_update_alpha > 1: raise ValueError("fragment_update_alpha must be between 0 and 1") - # TODO: Support `fragment_update_alpha` - if fragment_update_alpha != 0.0: - raise ValueError( - "Merging local parameters with global parameters is not supported yet" - ) - super().__init__() self._manager = manager diff --git a/torchft/local_sgd_integ_test.py b/torchft/local_sgd_integ_test.py index f0edcec..355e487 100644 --- a/torchft/local_sgd_integ_test.py +++ b/torchft/local_sgd_integ_test.py @@ -589,18 +589,19 @@ def test_streaming_diloco_recovery(self, use_cuda: bool) -> None: self.assertEqual(event_injectors[1].count[EventInjectorEvent.Failure], 1) - CONFIG: list[tuple[bool, int, int]] = [ - (use_cuda, n_fragments, fragment_sync_delay) + CONFIG: list[tuple[bool, int, int, float]] = [ + (use_cuda, n_fragments, fragment_sync_delay, alpha) for use_cuda in [False] for n_fragments in [1, 2] for fragment_sync_delay in [0, 1] + for alpha in [0.0, 0.5, 1.0] ] # pyre-fixme[56]: Pyre was not able to infer the type of argument @skipIf(sys.platform == "darwin", "not reliable on mac") @parameterized.expand(CONFIG) def test_streaming_diloco_upscale( - self, use_cuda: bool, n_fragments: int, fragment_sync_delay: int + self, use_cuda: bool, n_fragments: int, fragment_sync_delay: int, alpha: float ) -> None: # Skip the test if use_cuda is True and there are not enough GPUs if use_cuda and torch.cuda.device_count() < 2: @@ -642,6 +643,7 @@ def test_streaming_diloco_upscale( "diloco_args": { "fragment_sync_delay": fragment_sync_delay, "sync_every": 4, + "fragment_update_alpha": alpha, }, }, ) @@ -681,7 +683,7 @@ def test_streaming_diloco_upscale( @skipIf(sys.platform == "darwin", "not reliable on mac") @parameterized.expand(CONFIG) def test_streaming_diloco_commit_failure( - self, use_cuda: bool, n_fragments: int, fragment_sync_delay: int + self, use_cuda: bool, n_fragments: int, fragment_sync_delay: int, alpha: float ) -> None: # Skip the test if use_cuda is True and there are not enough GPUs if use_cuda and torch.cuda.device_count() < 2: @@ -719,6 +721,7 @@ def test_streaming_diloco_commit_failure( "diloco_args": { "fragment_sync_delay": fragment_sync_delay, "sync_every": 4, + "fragment_update_alpha": alpha, }, }, )