Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 161 additions & 7 deletions src/mcore_bridge/model/modules/mtp_layer.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
import torch
import transformer_engine
import warnings
from contextlib import nullcontext
from functools import partial
from megatron.core import InferenceParams
from megatron.core import InferenceParams, parallel_state, tensor_parallel
from megatron.core.enums import Fp8Recipe
from megatron.core.fp8_utils import get_fp8_context
Comment on lines +7 to +9
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To fully support FP4 quantization recomputation as referenced in the _checkpointed_forward logic, you should also import get_fp4_context.

Suggested change
- from megatron.core.enums import Fp8Recipe
- from megatron.core.fp8_utils import get_fp8_context
+ from megatron.core.enums import Fp8Recipe
+ from megatron.core.fp4_utils import get_fp4_context
+ from megatron.core.fp8_utils import get_fp8_context

from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.tensor_parallel.mappings import (gather_from_sequence_parallel_region,
gather_from_tensor_model_parallel_region,
Expand Down Expand Up @@ -94,11 +96,6 @@ def forward(
rotary_pos_emb = torch.roll(rotary_pos_emb, shifts=-layer_number, dims=0)
if self.config.recompute_granularity == 'full' and self.training:
hidden_states = self._checkpointed_forward(
partial(
self._proj_and_transformer_layer,
packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
),
hidden_states=hidden_states,
decoder_input=decoder_input,
attention_mask=attention_mask,
Expand All @@ -109,6 +106,8 @@ def forward(
rotary_pos_sin=rotary_pos_sin,
attention_bias=attention_bias,
inference_params=inference_params,
packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
)
else:
hidden_states = self._proj_and_transformer_layer(
Expand All @@ -127,6 +126,161 @@ def forward(
)
return hidden_states, input_ids, position_ids, decoder_input

# Code borrowed from NVIDIA/Megatron-LM
def _checkpointed_forward(
self,
hidden_states: torch.Tensor,
decoder_input: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
context: Optional[torch.Tensor] = None,
context_mask: Optional[torch.Tensor] = None,
rotary_pos_emb: Optional[torch.Tensor] = None,
rotary_pos_cos: Optional[torch.Tensor] = None,
rotary_pos_sin: Optional[torch.Tensor] = None,
attention_bias: Optional[torch.Tensor] = None,
inference_params: Optional[InferenceParams] = None,
packed_seq_params: Optional[PackedSeqParams] = None,
sequence_len_offset: Optional[torch.Tensor] = None,
):
"""Forward through ``_proj_and_transformer_layer`` with activation
recomputation.

Mirrors ``transformer_block._checkpointed_forward``:

* Non-tensor objects (``attention_bias``, ``inference_params``,
``packed_seq_params``) are captured by the ``custom_forward``
closure; only tensor / ``None`` arguments flow positionally
through the underlying checkpoint primitive. This is required
by both backends: ``tensor_parallel.checkpoint`` because its
``save_for_backward`` only accepts tensors and ``None``, and
``te_checkpoint`` because its reentrant implementation only
tracks positional tensor inputs as checkpoint inputs (kwarg
tensors are not represented in the recompute backward path).
Comment on lines +165 to +173
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The docstring incorrectly identifies attention_bias as a non-tensor object. Furthermore, to improve memory efficiency, tensors that do not require gradients (like attention_mask, rotary_pos_emb, and sequence_len_offset) should be captured in the closure rather than passed positionally to the checkpoint function, which avoids unnecessary stashing. This aligns with the implementation in transformer_block.py.

Suggested change
- * Non-tensor objects (``attention_bias``, ``inference_params``,
- ``packed_seq_params``) are captured by the ``custom_forward``
- closure; only tensor / ``None`` arguments flow positionally
- through the underlying checkpoint primitive. This is required
- by both backends: ``tensor_parallel.checkpoint`` because its
- ``save_for_backward`` only accepts tensors and ``None``, and
- ``te_checkpoint`` because its reentrant implementation only
- tracks positional tensor inputs as checkpoint inputs (kwarg
- tensors are not represented in the recompute backward path).
+ * Non-gradient tensors and objects (``attention_mask``, ``rotary_pos_emb``,
+ ``attention_bias``, ``inference_params``, ``packed_seq_params``,
+ ``sequence_len_offset``) are captured by the ``custom_forward``
+ closure; only tensors requiring gradients (``hidden_states``,
+ ``decoder_input``) flow positionally through the underlying
+ checkpoint primitive. This is required by both backends:
+ ``tensor_parallel.checkpoint`` because its ``save_for_backward``
+ only accepts tensors and ``None``, and ``te_checkpoint`` because
+ its reentrant implementation only tracks positional tensor inputs
+ as checkpoint inputs.

* Quantized recipes (fp8, fp4) route through ``te_checkpoint``;
everything else uses ``tensor_parallel.checkpoint``.
* Only ``fp8 + delayed scaling`` needs an outer quantization
context entered before ``te_checkpoint``; see the
``outer_quantization_context`` block below.
"""

def custom_forward(
    hidden_states,
    decoder_input,
    attention_mask,
    context,
    context_mask,
    rotary_pos_emb,
    rotary_pos_cos,
    rotary_pos_sin,
    sequence_len_offset,
):
    """Recompute entry point handed to the checkpoint primitive.

    The nine parameters are the values the checkpoint call forwards
    positionally. ``attention_bias``, ``inference_params`` and
    ``packed_seq_params`` are not parameters; they are read from the
    enclosing ``_checkpointed_forward`` scope.

    Returns:
        Whatever ``self._proj_and_transformer_layer`` returns for the
        combined set of arguments.
    """
    # Values that arrive through the checkpoint call.
    forwarded = dict(
        hidden_states=hidden_states,
        decoder_input=decoder_input,
        attention_mask=attention_mask,
        context=context,
        context_mask=context_mask,
        rotary_pos_emb=rotary_pos_emb,
        rotary_pos_cos=rotary_pos_cos,
        rotary_pos_sin=rotary_pos_sin,
        sequence_len_offset=sequence_len_offset,
    )
    # Values picked up from the enclosing scope rather than passed in.
    captured = dict(
        attention_bias=attention_bias,
        inference_params=inference_params,
        packed_seq_params=packed_seq_params,
    )
    return self._proj_and_transformer_layer(**forwarded, **captured)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The custom_forward implementation is missing the inner quantization context logic for FP8 and FP4, which is necessary for correct activation recomputation. Additionally, the signature should be simplified to only take tensors that require gradients (like hidden_states and decoder_input), capturing the rest in the closure for better memory efficiency.

        def custom_forward(hidden_states, decoder_input):
            """Reviewer-proposed recompute entry point taking only two positional inputs.

            Only ``hidden_states`` and ``decoder_input`` are parameters; every
            other argument forwarded to ``_proj_and_transformer_layer`` is read
            from the enclosing scope. The call runs inside a quantization
            context chosen from ``self.config`` (fp8 first, then fp4, otherwise
            a no-op ``nullcontext``).
            """
            # Get appropriate inner quantization context
            # NOTE(review): ``self.layer_number - 1`` is passed as the second
            # argument; presumably a zero-based layer index — confirm against
            # the ``get_fp8_context`` / ``get_fp4_context`` signatures.
            if self.config.fp8:
                inner_quantization_context = get_fp8_context(self.config, self.layer_number - 1)
            elif self.config.fp4:
                # NOTE(review): ``get_fp4_context`` must be in scope for this
                # branch; it is not among the imports visible in this diff.
                inner_quantization_context = get_fp4_context(self.config, self.layer_number - 1)
            else:
                inner_quantization_context = nullcontext()

            # Run the wrapped layer with the selected context active.
            with inner_quantization_context:
                return self._proj_and_transformer_layer(
                    hidden_states=hidden_states,
                    decoder_input=decoder_input,
                    attention_mask=attention_mask,
                    context=context,
                    context_mask=context_mask,
                    rotary_pos_emb=rotary_pos_emb,
                    rotary_pos_cos=rotary_pos_cos,
                    rotary_pos_sin=rotary_pos_sin,
                    attention_bias=attention_bias,
                    inference_params=inference_params,
                    packed_seq_params=packed_seq_params,
                    sequence_len_offset=sequence_len_offset,
                )


# Decide the outer quantization context, matching
# ``transformer_block._checkpointed_forward``. Only ``fp8 + delayed
# scaling`` needs an active context at the ``te_checkpoint`` entry
# point: TE's ``_CheckpointFunction.forward`` samples
# ``FP8GlobalStateManager.is_fp8_enabled()`` there to gate the
# phase-1 amax-buffer stash that phase-2 backward looks up via
# ``global_fp8_buffer_pos_fwd_recompute``. With fp8 only entered
# *inside* ``_proj_and_transformer_layer``, TE samples fp8 as off,
# phase-1 skips the stash, and phase-2 raises ``KeyError``.
# Non-delayed fp8 recipes (MXFP8BlockScaling, Float8CurrentScaling)
# and fp4 (NVFP4BlockScaling) treat the stash/lookup as a noop, so
# the inner context entered inside ``_proj_and_transformer_layer``
# is sufficient.
if self.config.fp8 and self.config.fp8_recipe == Fp8Recipe.delayed:
outer_quantization_context = get_fp8_context(self.config)
else:
outer_quantization_context = nullcontext()

def checkpoint_handler():
    """Run ``custom_forward`` through the checkpoint backend chosen by the config.

    When ``self.config.fp8`` or ``self.config.fp4`` is truthy the call goes
    through ``te_checkpoint``; otherwise it goes through
    ``tensor_parallel.checkpoint``. Both backends receive the same nine
    positional inputs, in ``custom_forward``'s parameter order.
    """
    # Positional inputs shared by both backends. The order must match the
    # parameter order of ``custom_forward``.
    forward_inputs = (
        hidden_states,
        decoder_input,
        attention_mask,
        context,
        context_mask,
        rotary_pos_emb,
        rotary_pos_cos,
        rotary_pos_sin,
        sequence_len_offset,
    )

    if not (self.config.fp8 or self.config.fp4):
        # tensor_parallel.checkpoint stashes args via autograd's
        # ``save_for_backward``, which only accepts tensors and ``None``.
        # The non-tensor objects (``attention_bias``, ``inference_params``,
        # ``packed_seq_params``) are therefore left to the
        # ``custom_forward`` closure instead of being passed here.
        return tensor_parallel.checkpoint(
            custom_forward,
            self.config.distribute_saved_activations,
            *forward_inputs,
        )

    # fp4 quantization is internally implemented via TE's ``fp8_autocast``
    # (see ``fp4_utils.get_fp4_context``), so quantized recompute on either
    # fp8 or fp4 must go through ``te_checkpoint``. Matches
    # ``transformer_block``'s policy.
    from megatron.core.extensions.transformer_engine import te_checkpoint

    return te_checkpoint(
        custom_forward,
        self.config.distribute_saved_activations,
        tensor_parallel.random.get_cuda_rng_tracker,
        parallel_state.get_tensor_model_parallel_group(),
        *forward_inputs,
    )
Comment on lines +234 to +265
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Update the checkpoint calls to match the simplified custom_forward signature. Only hidden_states and decoder_input (and potentially context/context_mask if they were supported and required gradients) should be passed positionally.

                return te_checkpoint(
                    custom_forward,
                    self.config.distribute_saved_activations,
                    tensor_parallel.random.get_cuda_rng_tracker,
                    parallel_state.get_tensor_model_parallel_group(),
                    hidden_states,
                    decoder_input,
                )
            else:
                return tensor_parallel.checkpoint(
                    custom_forward,
                    self.config.distribute_saved_activations,
                    hidden_states,
                    decoder_input,
                )


if self.config.recompute_method == 'uniform':
# Uniformly divide the total number of Transformer layers and checkpoint
# the input activation of each divided chunk.
# A method to further reduce memory usage by reducing the number of checkpoints.
assert (self.config.recompute_num_layers == 1), 'recompute_num_layers must be 1 for MTP recompute'
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This assertion might be too restrictive if recompute_num_layers is set to a value greater than 1 globally. It is safer to check if it is at least 1, as MTP only contains a single layer to checkpoint.

Suggested change
- assert (self.config.recompute_num_layers == 1), 'recompute_num_layers must be 1 for MTP recompute'
+ assert (self.config.recompute_num_layers >= 1), 'recompute_num_layers must be at least 1 for MTP recompute'

with outer_quantization_context:
outputs = checkpoint_handler()
elif self.config.recompute_method == 'block':
# TODO: implement block-based recompute for MTP
warnings.warn("recompute_method == 'block' is not supported for MTP yet."
' Skipping recompute.')
outputs = self._proj_and_transformer_layer(
hidden_states=hidden_states,
decoder_input=decoder_input,
attention_mask=attention_mask,
context=context,
context_mask=context_mask,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_cos=rotary_pos_cos,
rotary_pos_sin=rotary_pos_sin,
attention_bias=attention_bias,
inference_params=inference_params,
packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
)
Comment on lines +274 to +292
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The 'block' recompute method can be supported for MTP by simply checking if recompute_num_layers >= 1. Since MTP wraps a single transformer layer, this is equivalent to checkpointing the entire module.

Suggested change
elif self.config.recompute_method == 'block':
# TODO: implement block-based recompute for MTP
warnings.warn("recompute_method == 'block' is not supported for MTP yet."
' Skipping recompute.')
outputs = self._proj_and_transformer_layer(
hidden_states=hidden_states,
decoder_input=decoder_input,
attention_mask=attention_mask,
context=context,
context_mask=context_mask,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_cos=rotary_pos_cos,
rotary_pos_sin=rotary_pos_sin,
attention_bias=attention_bias,
inference_params=inference_params,
packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
)
elif self.config.recompute_method == 'block':
if self.config.recompute_num_layers >= 1:
with outer_quantization_context:
outputs = checkpoint_handler()
else:
outputs = self._proj_and_transformer_layer(
hidden_states=hidden_states,
decoder_input=decoder_input,
attention_mask=attention_mask,
context=context,
context_mask=context_mask,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_cos=rotary_pos_cos,
rotary_pos_sin=rotary_pos_sin,
attention_bias=attention_bias,
inference_params=inference_params,
packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
)

else:
raise ValueError('Invalid activation recompute method.')

return outputs

def _concat_embeddings(self, hidden_states: torch.Tensor, decoder_input: torch.Tensor):
"""
Concatenate the tokens before sending to transformer layer.
Expand Down
Loading