Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions megatron/core/optimizer/cpu_offloading/hybrid_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,18 @@

import torch

# SparseRL-Sync integration: sparse_diff_context wraps the param.copy_() so the
# attached SparseManager can snapshot pre-state, then diff against post-state to
# build per-param sparse-update indices. Falls back to nullcontext when the
# sparse_update package is not installed so the upstream behavior is unchanged.
try:
from sparse_update import sparse_diff_context
except ImportError:
from contextlib import nullcontext

def sparse_diff_context(*args, **kwargs):
return nullcontext()


def _param_generator(cpu_optimizer):
for group in cpu_optimizer.param_groups:
Expand Down Expand Up @@ -121,7 +133,8 @@ def param_copy_back_gpu_hook(optimizer, args, kwargs):
with torch.cuda.stream(self._h2d_stream):
for param in _param_generator(optimizer):
gpu_param = self.cpu_copys_map_gpu_param[param]
gpu_param.data.copy_(param.data, non_blocking=True)
with sparse_diff_context(gpu_param, param):
gpu_param.data.copy_(param.data, non_blocking=True)
self._d2h_stream.record_event().wait(torch.cuda.current_stream())

return param_copy_back_gpu_hook
Expand All @@ -137,7 +150,8 @@ def fp32_param_copy_back_gpu_hook(optimizer, args, kwargs):

if param in self.param_to_fp32_param:
fp32_param = self.param_to_fp32_param[param]
param.data.copy_(fp32_param.data)
with sparse_diff_context(param, fp32_param):
param.data.copy_(fp32_param.data)

return fp32_param_copy_back_gpu_hook

Expand Down
40 changes: 39 additions & 1 deletion megatron/core/optimizer/distrib_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,22 @@
from .optimizer import MixedPrecisionOptimizer, _zero_grad_group_helper, param_group_identifier_keys
from .optimizer_config import OptimizerConfig

# SparseRL-Sync integration: init_sparse_manager binds the optimizer's shard
# views to a SparseManager that tracks per-rollout dp-local diff indices.
# sparse_diff_context wraps the in-place shard_model_param.copy_() that
# realises the new training weights. Both fall back to no-ops when
# sparse_update is not importable so the upstream behaviour is preserved.
try:
from sparse_update import init_sparse_manager, sparse_diff_context
except ImportError:
from contextlib import nullcontext

def sparse_diff_context(*args, **kwargs):
return nullcontext()

def init_sparse_manager(*args, **kwargs):
return None

logger = getLogger(__name__)


Expand Down Expand Up @@ -413,6 +429,15 @@ def _build_model_and_main_param_groups(
shard_float16_params_this_group.append(shard_model_param)
shard_fp32_from_float16_params_this_group.append(shard_main_param)

# SparseRL-Sync: bind the shard views to a SparseManager so
# later sparse_diff_context() calls can recover the owner.
init_sparse_manager(
model_param=model_param,
shard_model_weight=shard_model_param,
shard_main_weight=shard_main_param,
param_range=param_range,
)

# fp32 params.
elif model_param.type() == 'torch.cuda.FloatTensor':
shard_model_param = model_param.view(-1)[param_range.start : param_range.end]
Expand All @@ -424,6 +449,15 @@ def _build_model_and_main_param_groups(
if hasattr(model_param, 'shared'):
shard_model_param.shared = model_param.shared

# SparseRL-Sync: fp32 params share the same shard view for
# both "model" and "main" sides; pass it on both slots.
init_sparse_manager(
model_param=model_param,
shard_model_weight=shard_model_param,
shard_main_weight=shard_model_param,
param_range=param_range,
)

else:
raise TypeError(
'Wrapped parameters must be one of '
Expand Down Expand Up @@ -2456,7 +2490,11 @@ def copy_group_params(shard_main_groups, model_groups):
# FP8 params are quantized in the above "quantize_param_shard" function.
continue
else:
shard_model_param.data.copy_(shard_main_param)
# SparseRL-Sync: capture the pre/post state of the
# shard copy so the per-param diff indices can be
# computed by the manager.
with sparse_diff_context(shard_model_param, shard_main_param):
shard_model_param.data.copy_(shard_main_param)

# Copy shard groups to model groups.
copy_group_params(self.shard_fp32_from_float16_groups, self.model_float16_groups)
Expand Down
8 changes: 4 additions & 4 deletions megatron/core/pipeline_parallel/p2p_communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,22 +26,22 @@ def _batched_p2p_ops(
ops = []
if tensor_send_prev is not None:
send_prev_op = torch.distributed.P2POp(
torch.distributed.isend, tensor_send_prev, prev_pipeline_rank,
torch.distributed.isend, tensor_send_prev, prev_pipeline_rank, group
)
ops.append(send_prev_op)
if tensor_recv_prev is not None:
recv_prev_op = torch.distributed.P2POp(
torch.distributed.irecv, tensor_recv_prev, prev_pipeline_rank,
torch.distributed.irecv, tensor_recv_prev, prev_pipeline_rank, group
)
ops.append(recv_prev_op)
if tensor_send_next is not None:
send_next_op = torch.distributed.P2POp(
torch.distributed.isend, tensor_send_next, next_pipeline_rank,
torch.distributed.isend, tensor_send_next, next_pipeline_rank, group
)
ops.append(send_next_op)
if tensor_recv_next is not None:
recv_next_op = torch.distributed.P2POp(
torch.distributed.irecv, tensor_recv_next, next_pipeline_rank,
torch.distributed.irecv, tensor_recv_next, next_pipeline_rank, group
)
ops.append(recv_next_op)
if len(ops) > 0:
Expand Down