diff --git a/megatron/core/optimizer/cpu_offloading/hybrid_optimizer.py b/megatron/core/optimizer/cpu_offloading/hybrid_optimizer.py index 28487c3b367..2c3123e229a 100644 --- a/megatron/core/optimizer/cpu_offloading/hybrid_optimizer.py +++ b/megatron/core/optimizer/cpu_offloading/hybrid_optimizer.py @@ -4,6 +4,18 @@ import torch +# SparseRL-Sync integration: sparse_diff_context wraps the param.copy_() so the +# attached SparseManager can snapshot pre-state, then diff against post-state to +# build per-param sparse-update indices. Falls back to nullcontext when the +# sparse_update package is not installed so the upstream behavior is unchanged. +try: + from sparse_update import sparse_diff_context +except ImportError: + from contextlib import nullcontext + + def sparse_diff_context(*args, **kwargs): + return nullcontext() + def _param_generator(cpu_optimizer): for group in cpu_optimizer.param_groups: @@ -121,7 +133,8 @@ def param_copy_back_gpu_hook(optimizer, args, kwargs): with torch.cuda.stream(self._h2d_stream): for param in _param_generator(optimizer): gpu_param = self.cpu_copys_map_gpu_param[param] - gpu_param.data.copy_(param.data, non_blocking=True) + with sparse_diff_context(gpu_param, param): + gpu_param.data.copy_(param.data, non_blocking=True) self._d2h_stream.record_event().wait(torch.cuda.current_stream()) return param_copy_back_gpu_hook @@ -137,7 +150,8 @@ def fp32_param_copy_back_gpu_hook(optimizer, args, kwargs): if param in self.param_to_fp32_param: fp32_param = self.param_to_fp32_param[param] - param.data.copy_(fp32_param.data) + with sparse_diff_context(param, fp32_param): + param.data.copy_(fp32_param.data) return fp32_param_copy_back_gpu_hook diff --git a/megatron/core/optimizer/distrib_optimizer.py b/megatron/core/optimizer/distrib_optimizer.py index eac21a3ea8e..7b35417b6bc 100644 --- a/megatron/core/optimizer/distrib_optimizer.py +++ b/megatron/core/optimizer/distrib_optimizer.py @@ -53,6 +53,22 @@ from .optimizer import MixedPrecisionOptimizer, _zero_grad_group_helper, param_group_identifier_keys from .optimizer_config import OptimizerConfig +# SparseRL-Sync integration: init_sparse_manager binds the optimizer's shard +# views to a SparseManager that tracks per-rollout dp-local diff indices. +# sparse_diff_context wraps the in-place shard_model_param.copy_() that +# realises the new training weights. Both fall back to no-ops when +# sparse_update is not importable so the upstream behaviour is preserved. +try: + from sparse_update import init_sparse_manager, sparse_diff_context +except ImportError: + from contextlib import nullcontext + + def sparse_diff_context(*args, **kwargs): + return nullcontext() + + def init_sparse_manager(*args, **kwargs): + return None + logger = getLogger(__name__) @@ -413,6 +429,15 @@ def _build_model_and_main_param_groups( shard_float16_params_this_group.append(shard_model_param) shard_fp32_from_float16_params_this_group.append(shard_main_param) + # SparseRL-Sync: bind the shard views to a SparseManager so + # later sparse_diff_context() calls can recover the owner. + init_sparse_manager( + model_param=model_param, + shard_model_weight=shard_model_param, + shard_main_weight=shard_main_param, + param_range=param_range, + ) + # fp32 params. elif model_param.type() == 'torch.cuda.FloatTensor': shard_model_param = model_param.view(-1)[param_range.start : param_range.end] @@ -424,6 +449,15 @@ def _build_model_and_main_param_groups( if hasattr(model_param, 'shared'): shard_model_param.shared = model_param.shared + # SparseRL-Sync: fp32 params share the same shard view for + # both "model" and "main" sides; pass it on both slots. + init_sparse_manager( + model_param=model_param, + shard_model_weight=shard_model_param, + shard_main_weight=shard_model_param, + param_range=param_range, + ) + else: raise TypeError( 'Wrapped parameters must be one of ' @@ -2456,7 +2490,11 @@ def copy_group_params(shard_main_groups, model_groups): # FP8 params are quantized in the above "quantize_param_shard" function. continue else: - shard_model_param.data.copy_(shard_main_param) + # SparseRL-Sync: capture the pre/post state of the + # shard copy so the per-param diff indices can be + # computed by the manager. + with sparse_diff_context(shard_model_param, shard_main_param): + shard_model_param.data.copy_(shard_main_param) # Copy shard groups to model groups. copy_group_params(self.shard_fp32_from_float16_groups, self.model_float16_groups) diff --git a/megatron/core/pipeline_parallel/p2p_communication.py b/megatron/core/pipeline_parallel/p2p_communication.py index f18309217c3..ac839c21f18 100644 --- a/megatron/core/pipeline_parallel/p2p_communication.py +++ b/megatron/core/pipeline_parallel/p2p_communication.py @@ -26,22 +26,22 @@ def _batched_p2p_ops( ops = [] if tensor_send_prev is not None: send_prev_op = torch.distributed.P2POp( - torch.distributed.isend, tensor_send_prev, prev_pipeline_rank, + torch.distributed.isend, tensor_send_prev, prev_pipeline_rank, group ) ops.append(send_prev_op) if tensor_recv_prev is not None: recv_prev_op = torch.distributed.P2POp( - torch.distributed.irecv, tensor_recv_prev, prev_pipeline_rank, + torch.distributed.irecv, tensor_recv_prev, prev_pipeline_rank, group ) ops.append(recv_prev_op) if tensor_send_next is not None: send_next_op = torch.distributed.P2POp( - torch.distributed.isend, tensor_send_next, next_pipeline_rank, + torch.distributed.isend, tensor_send_next, next_pipeline_rank, group ) ops.append(send_next_op) if tensor_recv_next is not None: recv_next_op = torch.distributed.P2POp( - torch.distributed.irecv, tensor_recv_next, next_pipeline_rank, + torch.distributed.irecv, tensor_recv_next, next_pipeline_rank, group ) ops.append(recv_next_op) if len(ops) > 0: