From 5687088ba2626188ac9959c613943008d4c816f9 Mon Sep 17 00:00:00 2001 From: Max Kovalenko Date: Thu, 21 Nov 2024 16:38:13 +0200 Subject: [PATCH 1/2] Stage3: Use new torch grad accumulation hooks API * This commit addresses an issue reported in: https://github.com/microsoft/DeepSpeed/issues/6718 * The existing code has been using the grad_acc node hook to reduce params grads. The constructs such as param.data = replicated_tensor.data used in allgather_params(..) are compiled into param.set() causing the hook assigned to the grad_acc node not being called. * This is a known torch issue https://github.com/pytorch/pytorch/issues/139742. * The above caused accuracy issues and could be temporarily solved by simply disabling the torch compile when activation checkpointing is used. * This commit provides a clean solution by replacing the hook on a grad_acc node to a hook using a new and robust hook API on a param itself: param.register_post_accumulate_grad_hook(..) --- deepspeed/runtime/zero/stage3.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 99a5ecf41a2f..c45a10b77c8b 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -1156,7 +1156,6 @@ def overlapping_partition_gradients_reduce_epilogue(self): def create_reduce_and_remove_grad_hooks(self): print_rank_0(f'[Begin] Create gradient reduction hooks') - self.grad_accs = [] self.leaf_parameters = defaultdict(list) for i, param_group in enumerate(self.fp16_groups): for param in param_group: @@ -1169,15 +1168,13 @@ def create_reduce_and_remove_grad_hooks(self): #print(f"After all gather {param.device}, {param.shape}") def wrapper(param): - param_tmp = param.expand_as(param) - grad_acc = param_tmp.grad_fn.next_functions[0][0] @instrument_w_nvtx def reduce_partition_and_remove_grads(*notneeded): self.reduce_ready_partitions_and_remove_grads(param) - self._grad_acc_hooks.append(grad_acc.register_hook(reduce_partition_and_remove_grads)) - self.grad_accs.append(grad_acc) + self._grad_acc_hooks.append( + param.register_post_accumulate_grad_hook(reduce_partition_and_remove_grads)) #print(f"param grad fn {param.expand_as(param).grad_fn}") if z3_leaf_parameter(param): From 1f2dcedd86406a583fd006cb9a5971fe3191c4b9 Mon Sep 17 00:00:00 2001 From: Max Kovalenko Date: Sun, 15 Dec 2024 14:40:30 +0200 Subject: [PATCH 2/2] Stage3: Use torch grad accumulation hooks API per torch version. Reject compile on torch < 2.1. --- deepspeed/runtime/compiler.py | 3 ++- deepspeed/runtime/zero/stage3.py | 4 ++-- deepspeed/utils/torch.py | 9 +++++++++ 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/deepspeed/runtime/compiler.py b/deepspeed/runtime/compiler.py index fa9220f4fcd0..be778b83f8bb 100644 --- a/deepspeed/runtime/compiler.py +++ b/deepspeed/runtime/compiler.py @@ -4,6 +4,7 @@ # DeepSpeed Team import torch +from deepspeed.utils.torch import required_torch_version try: from torch.compiler import is_compiling as torch_is_compiling @@ -16,7 +17,7 @@ def is_compile_supported(): - return hasattr(torch, "compiler") and hasattr(torch.nn.Module, "compile") + return required_torch_version(min_version=2.1) def disable(func): diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 89d29879978c..28f91cb9b3ab 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -16,6 +16,7 @@ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors from deepspeed.runtime.base_optimizer import ZeROOptimizer from deepspeed.utils import logger +from deepspeed.utils.torch import register_grad_hook from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced, all_to_all_quant_reduce, all_to_all_loco_quant_reduce from deepspeed.runtime.utils import inf, is_model_parallel_parameter, get_only_unique_item @@ -1176,8 +1177,7 @@ def wrapper(param): def reduce_partition_and_remove_grads(*notneeded): self.reduce_ready_partitions_and_remove_grads(param) - self._grad_acc_hooks.append( - param.register_post_accumulate_grad_hook(reduce_partition_and_remove_grads)) + self._grad_acc_hooks.append(register_grad_hook(param, reduce_partition_and_remove_grads)) #print(f"param grad fn {param.expand_as(param).grad_fn}") if z3_leaf_parameter(param): diff --git a/deepspeed/utils/torch.py b/deepspeed/utils/torch.py index eb22d3561035..1d32775fe64a 100644 --- a/deepspeed/utils/torch.py +++ b/deepspeed/utils/torch.py @@ -20,3 +20,12 @@ def required_torch_version(min_version=None, max_version=None): return False return True + + +def register_grad_hook(param, hook): + if required_torch_version(min_version=2.1): + return param.register_post_accumulate_grad_hook(hook) + else: + param_tmp = param.expand_as(param) + grad_acc = param_tmp.grad_fn.next_functions[0][0] + return grad_acc.register_hook(hook)