Commits
68 commits
96051cc
Add --moe-use-device-initiated-grouped-gemm to allow token_per_expert…
QiZhangNV Nov 3, 2025
92e247f
Initial change for packed offloading
vasunvidia Nov 17, 2025
fa8da97
Bug fix
Nov 17, 2025
31b2ba9
Mem Opt
vasunvidia Nov 17, 2025
78dfacd
Handle MXFP8Tensor offload
Nov 20, 2025
0c0a75e
Enable Packed offloading to CPU pinned memory with PACKED_OFFLOAD_CPU=1
Nov 20, 2025
6703445
Enable activation truncation for first step
Nov 21, 2025
955fbba
Overflow check and assert
Nov 22, 2025
cf7b68b
Check in temporary solution for detecting overflow in receiving buffer
nanz-nv Nov 22, 2025
dc4e973
Reconstruct the stash buffer into a 2D structure
nanz-nv Nov 23, 2025
683d283
Refactor the code to check overflow in HybridEP receiving buffer
nanz-nv Nov 24, 2025
9c65eea
Use CPU offloading context manager as a WAR for now to WAR the proble…
nanz-nv Nov 24, 2025
7c2aa7c
Add support for paged stashing
nanz-nv Nov 25, 2025
f44a426
Add the feature of speculative CE stashing
nanz-nv Nov 26, 2025
1bbaf54
Fix PP schedule
Nov 26, 2025
629bf22
Use common buffer across VP for paged stashing
vasunvidia Nov 26, 2025
50c6c17
Disable Packed Offloading for validation
Nov 27, 2025
32fbc15
Fix perf issue in packed stash/pop kernels
nanz-nv Nov 27, 2025
bff7e8b
Minor fix for tensor allocation and padding requirement on budget
nanz-nv Dec 7, 2025
94c14bc
Packed/paged offloading is currently not stream-safe. Need to put stash…
nanz-nv Dec 7, 2025
7b0ef46
add new hybrid ep
Autumn1998 Dec 9, 2025
6905e2c
Remove the overflow check in framework because it is now done by hybr…
nanz-nv Dec 10, 2025
9c056df
Fix one merge conflict
nanz-nv Dec 10, 2025
669d9f7
Code cleanup
vasunvidia Dec 11, 2025
66ebb1e
Add second autograd to avoid triple buffering
vasunvidia Dec 12, 2025
535b277
Avoid unnecessary wait_stream for reload in case of 1f1b
vasunvidia Dec 12, 2025
cb71c66
Check in dynamic-shape-aware SwiGLU triton kernel
nanz-nv Dec 18, 2025
c308899
Major cleanup and refactor
nanz-nv Dec 18, 2025
4c1b01b
Check in paged_stash.py that was omitted in the previous commit
nanz-nv Dec 18, 2025
3536250
Remove d2d page feature for now
nanz-nv Dec 18, 2025
90b02d5
Update added arguments and add compatibility check
nanz-nv Dec 18, 2025
fb8fc21
refine overflow check
nanz-nv Dec 18, 2025
27352b5
Fixing lint issues
nanz-nv Dec 19, 2025
84ba8b8
Minor refactor
vasunvidia Jan 8, 2026
e32a28b
Add unit test for Paged Stashing
vasunvidia Jan 9, 2026
6ca4a01
1. allocate stashing buffer based on avg token count if STASH_BUFFER_…
nanz-nv Jan 22, 2026
e88df64
Reenable overlapping of stashing kernels
nanz-nv Jan 23, 2026
10ed85b
Remove a buggy/redundant reset
nanz-nv Feb 3, 2026
62ffb30
Cleanup moe-expert-rank-capacity-factor argument.
vasunvidia Feb 9, 2026
19b62d2
Update moe_use_device_initiated_grouped_gemm check for paged stashing…
vasunvidia Feb 21, 2026
7c868e9
Remove the WAR of running warmup on a side stream
nanz-nv Mar 17, 2026
b815f99
Fix for data_iterator type check in Paged Stashing fallback
vasunvidia Mar 18, 2026
ac42b99
Change to support eager-mode fallback for validation
vasunvidia Mar 18, 2026
5cff7a9
Revert "Check in dynamic-shape-aware SwiGLU triton kernel"
nanz-nv Mar 18, 2026
6dd213b
Fixed some minor issues
nanz-nv Mar 18, 2026
b28f812
Fix the unit test
nanz-nv Mar 18, 2026
2e92588
Initial commit for spill to cpu feature
nanz-nv Mar 14, 2026
58a97c1
Move paged stashing knobs from env vars to transformer_config knobs
nanz-nv Mar 18, 2026
79522cc
Refactor the knobs a bit so it is more intuitive
nanz-nv Mar 18, 2026
b3be4de
Use get_attr_wrapped_model util to access moe and mtp layers
vasunvidia Mar 18, 2026
3fc366e
Refactor the unit test for paged stashing
nanz-nv Mar 20, 2026
7a23c78
Clean up after rebase
nanz-nv Mar 21, 2026
b4e1e56
skip routed expert padding
zhongbozhu Mar 24, 2026
06d8a85
Refactor/clean-up logging
nanz-nv Mar 25, 2026
fb620fc
Resolve review feedback
nanz-nv Mar 25, 2026
09ef7af
Fix fallback data read for PP=1
vasunvidia Mar 25, 2026
f9a5fcf
Paged stashing refactor
vasunvidia Mar 26, 2026
25be640
Remove logical_shape check
vasunvidia Mar 26, 2026
1d3755a
Remove paged_stash_set_last_layer
vasunvidia Mar 26, 2026
f227fd6
Cleanup PadUnpadFunction
vasunvidia Mar 26, 2026
b5cb760
Remove stash modules and remove stashing code for non-fused grouped gemm
nanz-nv Mar 30, 2026
de6c6eb
Remove dead code
nanz-nv Mar 30, 2026
84d1803
Fix TE import problem in experts.py
nanz-nv Mar 31, 2026
b49c1a0
Fixed merge conflict
nanz-nv Mar 31, 2026
2617ff9
Address reviewer's comments
nanz-nv Mar 31, 2026
a037128
Review comments
vasunvidia Apr 2, 2026
05ea747
Add PagedStashRunner for overflow detection for pure M-LM training
vasunvidia Apr 2, 2026
a133251
Release stashing buffer before fallback to restore the memory
nanz-nv Apr 3, 2026
19 changes: 17 additions & 2 deletions megatron/core/full_cuda_graph.py
@@ -4,6 +4,7 @@

import logging

import gc
import torch

from megatron.core.tensor_parallel.random import get_all_rng_states
@@ -180,12 +181,10 @@ def __call__(self, *args, **kwargs):
torch.cuda.synchronize()
torch.distributed.barrier()
logger.info(f'CUDA graph capture done for {training_str}!!!')

if FullCudaGraphWrapper.cuda_graph[training_str] is None:
FullCudaGraphWrapper.result[training_str] = self.forward_backward_func(*args, **kwargs)
else:
FullCudaGraphWrapper.cuda_graph[training_str].replay()

self.next_iter(training_str)
return FullCudaGraphWrapper.result[training_str]

@@ -196,3 +195,19 @@ def curr_iter(self, stage):
def next_iter(self, stage):
"""Increment current training/validation iteration."""
FullCudaGraphWrapper.curr_iteration[stage] += 1

def reset_cuda_graph(self, stage=None):
"""Reset CUDA graph."""
if stage is None or stage == 'training':
if FullCudaGraphWrapper.cuda_graph['training'] is not None:
del FullCudaGraphWrapper.cuda_graph['training']
FullCudaGraphWrapper.cuda_graph['training'] = None
FullCudaGraphWrapper.result['training'] = None
FullCudaGraphWrapper.curr_iteration['training'] = 0
if stage is None or stage == 'validation':
if FullCudaGraphWrapper.cuda_graph['validation'] is not None:
del FullCudaGraphWrapper.cuda_graph['validation']
FullCudaGraphWrapper.cuda_graph['validation'] = None
FullCudaGraphWrapper.result['validation'] = None
FullCudaGraphWrapper.curr_iteration['validation'] = 0
gc.collect()
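
Note: the reset_cuda_graph path above frees captured graphs by dropping the last Python reference and forcing a garbage-collection pass, so the graph's memory pool can be reclaimed before a later re-capture (for example, when falling back to eager mode for validation). A minimal self-contained sketch of that release pattern, with illustrative tensors that are not part of this PR:

import gc

import torch

if torch.cuda.is_available():
    static_x = torch.zeros(1024, device="cuda")
    static_x * 2.0  # warm up once outside capture so lazy kernel init does not happen mid-capture

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):  # capture a tiny graph; a real wrapper captures a full step
        static_y = static_x * 2.0
    graph.replay()  # cheap re-execution of the captured work

    del graph                 # drop the reference to the captured graph ...
    gc.collect()              # ... and collect it so its memory pool is released
    torch.cuda.empty_cache()  # return cached blocks before re-capture or an eager fallback
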
12 changes: 12 additions & 0 deletions megatron/core/models/gpt/gpt_model.py
@@ -26,6 +26,7 @@
from megatron.core.tensor_parallel import gather_from_sequence_parallel_region
from megatron.core.transformer.enums import CudaGraphScope, ModelType
from megatron.core.transformer.linear_cross_entropy import LinearCrossEntropyModule
from megatron.core.transformer.moe.paged_stash import paged_stash_init_chunk_handler
from megatron.core.transformer.multi_token_prediction import (
MultiTokenPredictionBlock,
mtp_on_this_rank,
@@ -483,6 +484,12 @@ def preprocess_for_fine_grained_offloading(self):
off_interface.mark_not_offload(param)
self.disable_param_offloading = False

def preprocess_for_paged_stash(self):
"""Preprocess for paged stash."""
return paged_stash_init_chunk_handler(
vp_size=self.config.virtual_pipeline_model_parallel_size, vp_stage=self.vp_stage
)

def forward(
self,
input_ids: Tensor,
@@ -519,6 +526,9 @@ def forward(
if self.config.fine_grained_activation_offloading:
self.preprocess_for_fine_grained_offloading()

if self.config.moe_paged_stash:
self.preprocess_for_paged_stash()

inference_context = deprecate_inference_params(inference_context, inference_params)

preproc_output = self._preprocess(
@@ -823,6 +833,8 @@ def build_schedule_plan(

if self.config.fine_grained_activation_offloading:
self.preprocess_for_fine_grained_offloading()
if self.config.moe_paged_stash:
self.preprocess_for_paged_stash()

from ..common.model_chunk_schedule_plan import TransformerModelChunkSchedulePlan

10 changes: 10 additions & 0 deletions megatron/core/pipeline_parallel/schedules.py
@@ -13,6 +13,7 @@
FineGrainedActivationOffloadingInterface as off_interface,
)
from megatron.core.pipeline_parallel.multimodule_communicator import MultiModulePipelineCommunicator
from megatron.core.transformer.moe.paged_stash import paged_stash_reset
from megatron.core.pipeline_parallel.p2p_communication import P2PCommunicator
from megatron.core.pipeline_parallel.utils import (
is_pp_first_stage,
@@ -638,6 +639,9 @@ def forward_backward_no_pipelining(
if config.timers is not None:
config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time)

if config.moe_paged_stash:
paged_stash_reset(enabled=config.moe_paged_stash and not forward_only, config=config)

no_sync_func = config.no_sync_func
if no_sync_func is None:
no_sync_func = contextlib.nullcontext
@@ -1082,6 +1086,9 @@
adjust_tensor_shapes_fn is None
), "adjust_tensor_shapes_fn is not supported for interleaved pipeline parallelism"

if config.moe_paged_stash:
paged_stash_reset(enabled=config.moe_paged_stash and not forward_only, config=config)

if config.overlap_p2p_comm and config.batch_p2p_comm:
raise ValueError("Can not use both overlap_p2p_comm and batch_p2p_comm")

@@ -2284,6 +2291,9 @@ def forward_backward_pipelining_without_interleaving(
if config.timers is not None:
config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time)

if config.moe_paged_stash:
paged_stash_reset(enabled=config.moe_paged_stash and not forward_only, config=config)

# Disable async grad reductions
no_sync_func = config.no_sync_func
if no_sync_func is None:
53 changes: 39 additions & 14 deletions megatron/core/transformer/moe/experts.py
@@ -5,6 +5,7 @@
from collections.abc import Callable
from copy import deepcopy
from dataclasses import dataclass
from contextlib import nullcontext
from functools import partial
from itertools import chain
from math import ceil
@@ -41,6 +42,12 @@
from megatron.core.transformer.moe.moe_utils import (
ProcessGroupCollection,
get_align_size_for_quantization,
skip_routed_expert_padding,
)
from megatron.core.transformer.moe.paged_stash import (
get_paged_stash_context,
paged_stash_group_commit,
paged_stash_group_start,
)
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.utils import (
@@ -51,6 +58,7 @@

if HAVE_TE:
from megatron.core.extensions.transformer_engine import Fp8Padding, Fp8Unpadding
import transformer_engine as te
else:
Fp8Padding, Fp8Unpadding = None, None

@@ -915,8 +923,7 @@ def _fused_forward(

# Apply padding if needed
unpadded_tokens_per_expert = None
if self.config.moe_router_padding_for_quantization:
# Padding has already been applied in router
if skip_routed_expert_padding(self.config):
pass
elif self.config.fp8 or self.config.fp4:
tokens_per_expert = tokens_per_expert.tolist()
@@ -931,19 +938,38 @@
tokens_per_expert = torch.tensor(
tokens_per_expert, dtype=torch.int, device=permuted_probs.device
)

# Call fused impl
output = ops(
permuted_local_hidden_states,
tokens_per_expert, # FC1
permuted_probs, # Scaled SwiGLU
tokens_per_expert, # FC2
)

# if the number of tokens is 0, pad the hidden states to 256

if self.config.moe_paged_stash:
permuted_local_hidden_states = paged_stash_group_start(permuted_local_hidden_states)
max_num_tokens = permuted_local_hidden_states.shape[0]
# Average/expected tokens is a pre-padding estimate used by paged stashing heuristics.
# moe_expert_rank_capacity_factor is required when moe_paged_stash is enabled.
cap_factor = self.config.moe_expert_rank_capacity_factor
avg_num_tokens = (
int(max_num_tokens // cap_factor) if cap_factor is not None and cap_factor > 0 else None
)
stash_context = get_paged_stash_context(
name="grouped_mlp",
max_num_tokens=max_num_tokens,
num_tokens_tensor=tokens_per_expert.sum(),
avg_num_tokens=avg_num_tokens,
)
else:
stash_context = nullcontext()
with stash_context:
# Call fused impl
output = ops(
permuted_local_hidden_states,
tokens_per_expert, # FC1
permuted_probs, # Scaled SwiGLU
tokens_per_expert, # FC2
)
# Remove padding if needed
if unpadded_tokens_per_expert is not None:
output = self.quantization_unpadding(output, unpadded_tokens_per_expert)

if self.config.moe_paged_stash:
output = paged_stash_group_commit(output, name="grouped_mlp")
return output

def bias_act_func(self, intermediate_parallel, bias_parallel, permuted_probs):
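
Note: avg_num_tokens above is simply the capacity-padded row count divided by the per-rank capacity factor; if the factor is unset or non-positive, no estimate is passed. A worked sketch with illustrative numbers (not taken from this PR):

max_num_tokens = 16384  # rows in permuted_local_hidden_states after capacity padding (illustrative)
cap_factor = 2.0        # config.moe_expert_rank_capacity_factor (illustrative)
avg_num_tokens = int(max_num_tokens // cap_factor) if cap_factor is not None and cap_factor > 0 else None
assert avg_num_tokens == 8192  # the pre-padding estimate handed to the stash-buffer sizing heuristic
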
@@ -1033,8 +1059,7 @@ def forward(
unpadded_tokens_per_expert = None
tokens_per_expert: list[int] = tokens_per_expert.tolist()
permuted_probs = permuted_probs.unsqueeze(-1)
if self.config.moe_router_padding_for_quantization:
# Padding has already been applied in router
if skip_routed_expert_padding(self.config):
pass
elif self.config.fp8 or self.config.fp4:
unpadded_tokens_per_expert = tokens_per_expert
16 changes: 16 additions & 0 deletions megatron/core/transformer/moe/moe_utils.py
@@ -1321,6 +1321,22 @@ def get_align_size_for_quantization(config: TransformerConfig) -> int:
return 16


def skip_routed_expert_padding(config: TransformerConfig) -> bool:
"""Whether the expert module should skip quantization padding.

Returns True when padding is already applied by the router or the
HybridEP dispatcher.
"""
if config.moe_router_padding_for_quantization:
return True
if (
config.moe_token_dispatcher_type == "flex"
and config.moe_flex_dispatcher_backend == "hybridep"
):
return True
return False


# TODO(Hepteract): delete the usage of the global parallel_state.
# Initialize process groups with the global parallel_state.
def get_default_pg_collection() -> ProcessGroupCollection:
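
Note: a hedged usage sketch of skip_routed_expert_padding above, using a minimal stand-in for the config fields it reads (the field values, and the non-hybridep backend name, are illustrative assumptions rather than values from this PR):

from dataclasses import dataclass

from megatron.core.transformer.moe.moe_utils import skip_routed_expert_padding


@dataclass
class _StubConfig:
    # Stand-in exposing only the TransformerConfig fields the helper reads.
    moe_router_padding_for_quantization: bool = False
    moe_token_dispatcher_type: str = "alltoall"
    moe_flex_dispatcher_backend: str = "deepep"


# Padding was already applied in the router.
assert skip_routed_expert_padding(_StubConfig(moe_router_padding_for_quantization=True))
# Padding is handled by the HybridEP dispatcher.
assert skip_routed_expert_padding(
    _StubConfig(moe_token_dispatcher_type="flex", moe_flex_dispatcher_backend="hybridep")
)
# Otherwise the expert module applies quantization padding itself.
assert not skip_routed_expert_padding(_StubConfig())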