Skip to content

Paged Stashing#2690

Draft
nanz-nv wants to merge 68 commits intoNVIDIA:devfrom
vasunvidia:paged_offloading
Draft

Paged Stashing#2690
nanz-nv wants to merge 68 commits intoNVIDIA:devfrom
vasunvidia:paged_offloading

Conversation

@nanz-nv
Copy link
Copy Markdown
Contributor

@nanz-nv nanz-nv commented Dec 17, 2025

Main contributors (Equal Contribution, sorted alphabetically): Nan Zheng (@nanz-nv), Vasudevan Rengasamy (@vasunvidia)
Other contributors (sorted alphabetically): Dennis Liu(@Victarry), Hongbin Liu(@lhb8125), Qi Zhang(@QiZhangNV), Robin Zhang(@buptzyb), Tong Liu(@Autumn1998), Zijie Yan(@yanring)

Background

In token-dropless MoE training, the number of tokens received by each expert might vary, resulting in dynamic shaped tensors. Dynamic shaped tensors are naturally supported by PyTorch, thanks to its eager mode nature. This is done by creating a tensor lazily when the shape of the tensor is known at run-time. Albeit working well in eager mode, dynamic shaped tensor poses challenges for CUDA graphs because the the size of a tensor cannot be dynamically adjusted at runtime without the intervene of the host. In order to remove the sync and enable CUDA graph, one solution is to oversize the buffer in the expert part. This however causes significantly higher memory consumption compared to the eager-mode baseline through the form of memory fragmentation.

image

Idea overview

To address this problem, paged stashing decouples the need of oversized buffers for compute and the need of a properly sized buffer for storing activations for the backward pass. Paged stashing achieves this through adding one level of indirection: stashing and restoring. The stash operation copies the activation from the oversized static buffer to a pre-allocated stashing buffer after the forward for that module is done, and the restore operation does the reverse operation during the backward pass.

image

The key of saving memory lies in the fact that the stash operation packs the variable-size activation into a contiguous stashing buffer to reduce memory fragmentation. For simple scheduling where the activation allocation and deallocation follows a first-in-last-out pattern, stash and restore can be done easily in a bump-allocation manner. To accommodate complicated scheduling schedules, e.g. pipeline parallel, paging can be used, hence the name paged stashing.

page management

To accomodate complex scheduling such as that needed in pipeline parallelism, activations are partitioned into pages and a light-weight memory management kernel is in charge of allocate and deallocate pages for stashing. Pages are managed by lightweight GPU memory management kernels that can be fused with the stash/restore GPU kernels. It maintains a freelist which is implemented as a circular buffer. Each freelist keeps track of one type of pages.

CPU offloading

Paged stashing naturally supports offloading. When the stashing buffer is a pinned CPU tensor, the activation is offloaded to the host memory during forward and is reloaded to the GPU during backward.
Furthermore, one can easily extend the paging management system to accommodate partial offloading or on-demand offloading. This feature is currently WIP.

scheduling

Overlapping stashing and restore operations with compute can be implemented by inserting two autograd functions before and after the expert compute layer: pre-scheduler and post-scheduler that schedules stash and restore operations. The roles of these autograd functions are enumerated below:

  • Pre-scheduler forward: Wait for previous stash op. to complete, free the max-capacity sized temporary activations for the completed stash op. The wait is performed here instead of Post-scheduler forward to reduce the peak memory usage since the following expert compute layer will allocate another set of max-capacity sized temporary activations.
  • Post-scheduler forward: Since this is after experts compute, stashing operations for the current layer activations are scheduled here. If the next layer in the execution is a backward pass layer, schedule restore operations for the next layer.
    Additionally, in case of pipeline parallelism, this can be used to record the pipeline schedule during the first iteration.
  • Post-scheduler backward: Wait for previous stash op. to complete, free the max-capacity sized temporary activations for the completed stash op. The wait is performed here instead of Pre-scheduler backward to reduce the peak memory usage since the following expert compute BPROP layer will allocate another set of max-capacity sized temporary activations.
    Wait for restore operation for the current layer to complete. Additionally, in case of pipeline parallelism, this can be used to record the pipeline schedule during the first iteration.
  • Pre-scheduler backward: If the next layer in the execution is a backward pass layer, schedule restore operations for the next layer.

@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot bot commented Dec 17, 2025

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@Victarry
Copy link
Copy Markdown
Contributor

/ok to test 3e8c042

@github-actions
Copy link
Copy Markdown
Contributor

Thank you for your contribution!

NVIDIA Megatron-LM is currently transitioning to development on Github. We will aim to review your PR after we complete our transition and stabilize our Github development process.

Thank you for your understanding.

@nanz-nv nanz-nv force-pushed the paged_offloading branch 3 times, most recently from 3cd7a47 to b5b19b0 Compare March 23, 2026 05:46
@yanring yanring requested a review from buptzyb March 24, 2026 07:20
if is_te_min_version("2.10.0"):
assert os.getenv("NVTE_CPU_OFFLOAD_V1", "0") == "1", \
"For fine-grained activation offloading with TE >= 2.10.0, NVTE_CPU_OFFLOAD_V1 should be set to 1 to avoid offloading weights."
assert not args.moe_paged_stash, "Fine-grained activation offloading and paged stash cannot be enabled at the same time"
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this assertion added?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was there due to historical reasons. Just removed it. Thanks for catching that.


def paged_stash_group_commit(tensor, name=None):
"""Mark the end of a layer group and prepare for stash/reload."""
rank = torch.distributed.get_rank()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[IMPORTANT Unused Variable] rank = torch.distributed.get_rank() is computed but never used.

This is called on every expert layer forward pass. While the overhead is small, it's unnecessary dead code.

Suggestion: Remove this line.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed

Comment on lines +720 to +723
count = 0
for item in self.paged_tensors_to_reload:
if len(self.paged_tensors_to_reload[item]) > 0:
count += 1
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[IMPORTANT Unused Variable] count is computed by iterating over all reload queues but is never read or used.

Suggestion: Remove these 4 lines (720-723).

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed


def get_schedule_layer(self, vp_stage, layer_no, microbatch_no):
"""Get the schedule layer."""
return vp_stage * 1000000 + layer_no * 1000 + microbatch_no
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggestion: Add an assertion to validate the ranges:

assert layer_no < 1000 and microbatch_no < 1000, "Schedule encoding overflow"

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

mhc_is_last_in_recompute_block[l_no]
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[SUGGESTION] This line changed from an empty line to a line with trailing whitespace — unintentional diff noise. Consider reverting to keep the diff clean.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed

Comment on lines +982 to +999
def __enter__(self):
from megatron.core.extensions.transformer_engine import cpu_offload

if cpu_offload is not None:
cpu_offload.CPUOffloadEnabled = True
# Call the underlying context manager's __enter__
result = self.saved_tensors_context.__enter__()

# Add more custom logic after entering if needed
return result

def __exit__(self, *args: Any):
# Call the underlying context manager's __exit__
result = self.saved_tensors_context.__exit__(*args)
from megatron.core.extensions.transformer_engine import cpu_offload

if cpu_offload is not None:
cpu_offload.CPUOffloadEnabled = False
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[SUGGESTION] PagedStashContext unconditionally sets CPUOffloadEnabled = True on enter and False on exit without saving/restoring the original value.

Given that transformer_config already asserts moe_paged_stash cannot coexist with cpu_offloading, the intent of toggling CPUOffloadEnabled here is unclear. If this is needed for TE internal behavior, a comment explaining why would help. If not, consider removing it to avoid accidentally enabling CPU offload in unexpected contexts.

If it must stay, save and restore the original value:

def __enter__(self):
    if cpu_offload is not None:
        self._prev_offload = cpu_offload.CPUOffloadEnabled
        cpu_offload.CPUOffloadEnabled = True
    ...
def __exit__(self, *args):
    if cpu_offload is not None:
        cpu_offload.CPUOffloadEnabled = self._prev_offload

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed

int(max_num_tokens // cap_factor) if cap_factor is not None and cap_factor > 0 else None
)
stash_context = get_paged_stash_context(
name="expert_fc1_fused",
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems the activation save for fc2 is also included in the stash context?
Such that we could change the context name expert_fc1_fused to grouped_mlp

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the suggestion. Renamed

Comment on lines +663 to +664
while len(self.paged_tensors_to_stash) > 0:
paged_tensor = self.paged_tensors_to_stash.pop(0)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
while len(self.paged_tensors_to_stash) > 0:
paged_tensor = self.paged_tensors_to_stash.pop(0)
self.paged_tensors_to_stash.clear()

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

@buptzyb
Copy link
Copy Markdown
Contributor

buptzyb commented Apr 1, 2026

Does paged stashing support partial cudagraph? We have many cases where attention has dynamic shapes (varlen, dynamic cp, KDA, ...), so only the moe part is capturable.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.