Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions src/mcore_bridge/model/gpt_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ def forward(
input_tensor = self.get_input_tensor()
input_tensor, mtp_decoder_input = input_tensor.chunk(2, dim=0)
self.set_input_tensor(input_tensor)
kwargs = {}
extra_block_kwargs = extra_block_kwargs or {}
full_attention_mask = attention_mask
if isinstance(full_attention_mask, dict):
full_attention_mask = full_attention_mask['full_attention']
Expand All @@ -296,7 +296,7 @@ def forward(
if self.config.sequence_parallel and tp_size > 1:
assert padding_mask.shape[1] % tp_size == 0, f'padding_mask.shape: {padding_mask.shape}'
padding_mask = torch.chunk(padding_mask, tp_size, dim=1)[mpu.get_tensor_model_parallel_rank()]
kwargs['padding_mask'] = padding_mask.contiguous()
extra_block_kwargs['padding_mask'] = padding_mask.contiguous()
# Run decoder.
hidden_states = self.decoder(
hidden_states=decoder_input,
Expand All @@ -307,8 +307,7 @@ def forward(
rotary_pos_sin=rotary_pos_sin,
packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
**(extra_block_kwargs or {}),
**kwargs,
**extra_block_kwargs,
)

# MTP: https://github.com/NVIDIA/Megatron-LM/issues/1661
Expand Down
24 changes: 13 additions & 11 deletions src/mcore_bridge/model/mm_gpts/gemma4.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,12 +152,15 @@ def __init__(
if self.use_alternative_attention else text_config.num_key_value_heads)
# Shared KV across the trailing layers
self.num_kv_shared_layers = getattr(text_config, 'num_kv_shared_layers', 0)
first_kv_shared_layer_idx = config.num_layers - self.num_kv_shared_layers
self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0
prev_layers = text_config.layer_types[:first_kv_shared_layer_idx]
self.store_full_length_kv = not self.is_kv_shared_layer and layer_idx == len(
prev_layers) - 1 - prev_layers[::-1].index(text_config.layer_types[layer_idx])

if self.num_kv_shared_layers:
first_kv_shared_layer_idx = config.num_layers - self.num_kv_shared_layers
self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0
prev_layers = text_config.layer_types[:first_kv_shared_layer_idx]
self.store_full_length_kv = not self.is_kv_shared_layer and layer_idx == len(
prev_layers) - 1 - prev_layers[::-1].index(text_config.layer_types[layer_idx])
else:
self.is_kv_shared_layer = False
self.store_full_length_kv = False
orig_kv_channels = config.kv_channels
orig_num_query_groups = config.num_query_groups
orig_k_layernorm = submodules.k_layernorm
Expand Down Expand Up @@ -698,6 +701,7 @@ class Gemma4TransformerLayer(TransformerLayer):
def __init__(self, config, submodules, *args, **kwargs):
super().__init__(config, submodules, *args, **kwargs)
text_config = config.hf_config.text_config
self.layer_type = text_config.layer_types[self.layer_number - 1]
self.enable_moe_block = text_config.enable_moe_block
if self.enable_moe_block:
self.experts_mlp = self._build_mlp(submodules.experts_mlp)
Expand Down Expand Up @@ -748,6 +752,8 @@ def __init__(self, config, submodules, *args, **kwargs):
TENorm, hidden_size=hidden_size, config=self.config, eps=eps)

def _forward_attention(self, hidden_states: Tensor, **kwargs):
kwargs['rotary_pos_emb'] = kwargs['rotary_pos_emb'][self.layer_type]
kwargs['attention_mask'] = kwargs['attention_mask'][self.layer_type]
context = kwargs.pop('context', None)
residual = hidden_states
input_layernorm_output = self.input_layernorm(hidden_states)
Expand Down Expand Up @@ -805,13 +811,9 @@ class Gemma4GPTModel(MultimodalGPTModel):
class Gemma4TransformerBlock(TransformerBlock):

def _layer_forward(self, layer, hidden_states, **kwargs):
layer_number = layer.layer_number - 1
per_layer_inputs = kwargs.pop('per_layer_inputs', None)
if per_layer_inputs is not None:
kwargs['per_layer_input'] = per_layer_inputs[:, :, layer_number]
layer_type = self.config.hf_config.text_config.layer_types[layer_number]
kwargs['rotary_pos_emb'] = kwargs['rotary_pos_emb'][layer_type]
kwargs['attention_mask'] = kwargs['attention_mask'][layer_type]
kwargs['per_layer_input'] = per_layer_inputs[:, :, layer.layer_number - 1]
return super()._layer_forward(layer, hidden_states, **kwargs)


Expand Down
262 changes: 253 additions & 9 deletions src/mcore_bridge/model/modules/mtp_layer.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
import torch
import transformer_engine
import warnings
from contextlib import nullcontext
from functools import partial
from megatron.core import InferenceParams
from megatron.core import InferenceParams, parallel_state, tensor_parallel
from megatron.core.enums import Fp8Recipe
from megatron.core.fp8_utils import get_fp8_context
Comment on lines +7 to +9
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To fully support FP4 quantization recomputation as referenced in the _checkpointed_forward logic, you should also import get_fp4_context.

Suggested change
from megatron.core.enums import Fp8Recipe
from megatron.core.fp8_utils import get_fp8_context
from megatron.core.enums import Fp8Recipe
from megatron.core.fp4_utils import get_fp4_context
from megatron.core.fp8_utils import get_fp8_context

from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.tensor_parallel.mappings import (gather_from_sequence_parallel_region,
gather_from_tensor_model_parallel_region,
Expand All @@ -16,6 +18,8 @@

from mcore_bridge.utils import roll_tensor

from .transformer_block import _checkpoint_flatten, _checkpoint_unflatten, _TensorIdx

try:
from megatron.core.typed_torch import apply_module
except ImportError:
Expand Down Expand Up @@ -72,6 +76,7 @@ def forward(
embedding=None,
decoder_input=None,
layer_number: Optional[int] = None,
**kwargs,
):
assert context is None, 'multi token prediction + cross attention is not yet supported.'
if layer_number is None:
Expand All @@ -88,17 +93,20 @@ def forward(
packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd'
if self.config.position_embedding_type == 'rope' and packed_seq:
assert position_ids.shape[0] == 1, f'position_ids.shape: {position_ids.shape}'
rotary_pos_emb = rotary_pos_emb[position_ids[0]]
if isinstance(rotary_pos_emb, dict):
for k, v in rotary_pos_emb.items():
rotary_pos_emb[k] = v[position_ids[0]]
else:
rotary_pos_emb = rotary_pos_emb[position_ids[0]]
else:
# mrope or not packed_seq
rotary_pos_emb = torch.roll(rotary_pos_emb, shifts=-layer_number, dims=0)
if isinstance(rotary_pos_emb, dict):
for k, v in rotary_pos_emb.items():
rotary_pos_emb[k] = torch.roll(v, shifts=-layer_number, dims=0)
else:
rotary_pos_emb = torch.roll(rotary_pos_emb, shifts=-layer_number, dims=0)
if self.config.recompute_granularity == 'full' and self.training:
hidden_states = self._checkpointed_forward(
partial(
self._proj_and_transformer_layer,
packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
),
hidden_states=hidden_states,
decoder_input=decoder_input,
attention_mask=attention_mask,
Expand All @@ -109,6 +117,9 @@ def forward(
rotary_pos_sin=rotary_pos_sin,
attention_bias=attention_bias,
inference_params=inference_params,
packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
**kwargs,
)
else:
hidden_states = self._proj_and_transformer_layer(
Expand All @@ -124,9 +135,242 @@ def forward(
inference_params=inference_params,
packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
**kwargs,
)
return hidden_states, input_ids, position_ids, decoder_input

# Code borrowed from NVIDIA/Megatron-LM
def _checkpointed_forward(
self,
hidden_states: torch.Tensor,
decoder_input: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
context: Optional[torch.Tensor] = None,
context_mask: Optional[torch.Tensor] = None,
rotary_pos_emb: Optional[torch.Tensor] = None,
rotary_pos_cos: Optional[torch.Tensor] = None,
rotary_pos_sin: Optional[torch.Tensor] = None,
attention_bias: Optional[torch.Tensor] = None,
inference_params: Optional[InferenceParams] = None,
packed_seq_params: Optional[PackedSeqParams] = None,
sequence_len_offset: Optional[torch.Tensor] = None,
**kwargs,
):
"""Forward through ``_proj_and_transformer_layer`` with activation
recomputation.

Mirrors ``transformer_block._checkpointed_forward``:

* Non-tensor objects (``attention_bias``, ``inference_params``,
``packed_seq_params``) are captured by the ``custom_forward``
closure; only tensor / ``None`` arguments flow positionally
through the underlying checkpoint primitive. This is required
by both backends: ``tensor_parallel.checkpoint`` because its
``save_for_backward`` only accepts tensors and ``None``, and
``te_checkpoint`` because its reentrant implementation only
tracks positional tensor inputs as checkpoint inputs (kwarg
tensors are not represented in the recompute backward path).
Comment on lines +165 to +173
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The docstring incorrectly identifies attention_bias as a non-tensor object. Furthermore, to improve memory efficiency, tensors that do not require gradients (like attention_mask, rotary_pos_emb, and sequence_len_offset) should be captured in the closure rather than passed positionally to the checkpoint function, which avoids unnecessary stashing. This aligns with the implementation in transformer_block.py.

Suggested change
* Non-tensor objects (``attention_bias``, ``inference_params``,
``packed_seq_params``) are captured by the ``custom_forward``
closure; only tensor / ``None`` arguments flow positionally
through the underlying checkpoint primitive. This is required
by both backends: ``tensor_parallel.checkpoint`` because its
``save_for_backward`` only accepts tensors and ``None``, and
``te_checkpoint`` because its reentrant implementation only
tracks positional tensor inputs as checkpoint inputs (kwarg
tensors are not represented in the recompute backward path).
* Non-gradient tensors and objects (``attention_mask``, ``rotary_pos_emb``,
``attention_bias``, ``inference_params``, ``packed_seq_params``,
``sequence_len_offset``) are captured by the ``custom_forward``
closure; only tensors requiring gradients (``hidden_states``,
``decoder_input``) flow positionally through the underlying
checkpoint primitive. This is required by both backends:
``tensor_parallel.checkpoint`` because its ``save_for_backward``
only accepts tensors and ``None``, and ``te_checkpoint`` because
its reentrant implementation only tracks positional tensor inputs
as checkpoint inputs.

* Quantized recipes (fp8, fp4) route through ``te_checkpoint``;
everything else uses ``tensor_parallel.checkpoint``.
* Only ``fp8 + delayed scaling`` needs an outer quantization
context entered before ``te_checkpoint``; see the
``outer_quantization_context`` block below.
"""

# Variables that don't require gradients can be captured via closure.
_ckpt_attention_mask = attention_mask
_ckpt_rotary_pos_emb = rotary_pos_emb
extra_kwargs_keys = tuple(kwargs.keys())
_extra_flat_tensors = []
_extra_schemas = [_checkpoint_flatten(v, _extra_flat_tensors) for v in kwargs.values()]

def custom_forward(hidden_states, decoder_input, context, context_mask, rotary_pos_cos, rotary_pos_sin,
sequence_len_offset, *extra_flat):
rebuilt = [_checkpoint_unflatten(s, extra_flat) for s in _extra_schemas]
extra_kwargs = dict(zip(extra_kwargs_keys, rebuilt))
return self._proj_and_transformer_layer(
hidden_states=hidden_states,
decoder_input=decoder_input,
attention_mask=_ckpt_attention_mask,
context=context,
context_mask=context_mask,
rotary_pos_emb=_ckpt_rotary_pos_emb,
rotary_pos_cos=rotary_pos_cos,
rotary_pos_sin=rotary_pos_sin,
attention_bias=attention_bias,
inference_params=inference_params,
packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
**extra_kwargs,
)

# Decide the outer quantization context, matching
# ``transformer_block._checkpointed_forward``. Only ``fp8 + delayed
# scaling`` needs an active context at the ``te_checkpoint`` entry
# point: TE's ``_CheckpointFunction.forward`` samples
# ``FP8GlobalStateManager.is_fp8_enabled()`` there to gate the
# phase-1 amax-buffer stash that phase-2 backward looks up via
# ``global_fp8_buffer_pos_fwd_recompute``. With fp8 only entered
# *inside* ``_proj_and_transformer_layer``, TE samples fp8 as off,
# phase-1 skips the stash, and phase-2 raises ``KeyError``.
# Non-delayed fp8 recipes (MXFP8BlockScaling, Float8CurrentScaling)
# and fp4 (NVFP4BlockScaling) treat the stash/lookup as a noop, so
# the inner context entered inside ``_proj_and_transformer_layer``
# is sufficient.
if self.config.fp8 and self.config.fp8_recipe == Fp8Recipe.delayed:
outer_quantization_context = get_fp8_context(self.config)
else:
outer_quantization_context = nullcontext()

def checkpoint_handler():
"""Determines whether to use the `te_checkpoint` or `tensor_parallel.checkpoint`"""
# fp4 quantization is internally implemented via TE's
# ``fp8_autocast`` (see ``fp4_utils.get_fp4_context``), so
# quantized recompute on either fp8 or fp4 must go through
# ``te_checkpoint``. Matches ``transformer_block``'s policy.
if self.config.fp8 or self.config.fp4:
from megatron.core.extensions.transformer_engine import te_checkpoint

return te_checkpoint(
custom_forward,
self.config.distribute_saved_activations,
tensor_parallel.random.get_cuda_rng_tracker,
parallel_state.get_tensor_model_parallel_group(),
hidden_states,
decoder_input,
context,
context_mask,
rotary_pos_cos,
rotary_pos_sin,
sequence_len_offset,
*_extra_flat_tensors,
)
else:
# tensor_parallel.checkpoint stashes args via autograd's
# ``save_for_backward``, which only accepts tensors and ``None``.
# Pass tensor / ``None`` args positionally and capture the
# non-tensor objects (``attention_bias``, ``inference_params``,
# ``packed_seq_params``) via the ``custom_forward`` closure.
return tensor_parallel.checkpoint(
custom_forward,
self.config.distribute_saved_activations,
hidden_states,
decoder_input,
context,
context_mask,
rotary_pos_cos,
rotary_pos_sin,
sequence_len_offset,
*_extra_flat_tensors,
)
Comment on lines +234 to +265
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Update the checkpoint calls to match the simplified custom_forward signature. Only hidden_states and decoder_input (and potentially context/context_mask if they were supported and required gradients) should be passed positionally.

                return te_checkpoint(
                    custom_forward,
                    self.config.distribute_saved_activations,
                    tensor_parallel.random.get_cuda_rng_tracker,
                    parallel_state.get_tensor_model_parallel_group(),
                    hidden_states,
                    decoder_input,
                )
            else:
                return tensor_parallel.checkpoint(
                    custom_forward,
                    self.config.distribute_saved_activations,
                    hidden_states,
                    decoder_input,
                )


if self.config.recompute_method == 'uniform':
# Uniformly divide the total number of Transformer layers and checkpoint
# the input activation of each divided chunk.
# A method to further reduce memory usage reducing checkpoints.
assert (self.config.recompute_num_layers == 1), 'recompute_num_layers must be 1 for MTP recompute'
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This assertion might be too restrictive if recompute_num_layers is set to a value greater than 1 globally. It is safer to check if it is at least 1, as MTP only contains a single layer to checkpoint.

Suggested change
assert (self.config.recompute_num_layers == 1), 'recompute_num_layers must be 1 for MTP recompute'
assert (self.config.recompute_num_layers >= 1), 'recompute_num_layers must be at least 1 for MTP recompute'

with outer_quantization_context:
outputs = checkpoint_handler()
elif self.config.recompute_method == 'block':
# TODO: implement block-based recompute for MTP
warnings.warn("recompute_method == 'block' is not supported for MTP yet."
' Skipping recompute.')
outputs = self._proj_and_transformer_layer(
hidden_states=hidden_states,
decoder_input=decoder_input,
attention_mask=attention_mask,
context=context,
context_mask=context_mask,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_cos=rotary_pos_cos,
rotary_pos_sin=rotary_pos_sin,
attention_bias=attention_bias,
inference_params=inference_params,
packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
**kwargs,
)
Comment on lines +274 to +292
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The 'block' recompute method can be supported for MTP by simply checking if recompute_num_layers >= 1. Since MTP wraps a single transformer layer, this is equivalent to checkpointing the entire module.

Suggested change
elif self.config.recompute_method == 'block':
# TODO: implement block-based recompute for MTP
warnings.warn("recompute_method == 'block' is not supported for MTP yet."
' Skipping recompute.')
outputs = self._proj_and_transformer_layer(
hidden_states=hidden_states,
decoder_input=decoder_input,
attention_mask=attention_mask,
context=context,
context_mask=context_mask,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_cos=rotary_pos_cos,
rotary_pos_sin=rotary_pos_sin,
attention_bias=attention_bias,
inference_params=inference_params,
packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
)
elif self.config.recompute_method == 'block':
if self.config.recompute_num_layers >= 1:
with outer_quantization_context:
outputs = checkpoint_handler()
else:
outputs = self._proj_and_transformer_layer(
hidden_states=hidden_states,
decoder_input=decoder_input,
attention_mask=attention_mask,
context=context,
context_mask=context_mask,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_cos=rotary_pos_cos,
rotary_pos_sin=rotary_pos_sin,
attention_bias=attention_bias,
inference_params=inference_params,
packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
)

else:
raise ValueError('Invalid activation recompute method.')

return outputs

def _proj_and_transformer_layer(
self,
hidden_states: torch.Tensor,
decoder_input: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
context: Optional[torch.Tensor] = None,
context_mask: Optional[torch.Tensor] = None,
rotary_pos_emb: Optional[torch.Tensor] = None,
rotary_pos_cos: Optional[torch.Tensor] = None,
rotary_pos_sin: Optional[torch.Tensor] = None,
attention_bias: Optional[torch.Tensor] = None,
inference_params: Optional[InferenceParams] = None,
packed_seq_params: Optional[PackedSeqParams] = None,
sequence_len_offset: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
"""
Concatenates embeddings with hidden states and then applies transformer layer forward.
"""
padding_mask = kwargs.pop('padding_mask', None)
if padding_mask is not None:
kwargs['padding_mask'] = padding_mask
if self.config.sequence_parallel:
rng_context = tensor_parallel.get_cuda_rng_tracker().fork()
else:
rng_context = nullcontext()

# Unlike transformer_block.py which needs to support mixed-precision in
# different layers,currently MTP only use global fp8 context.
if self.config.fp8:
fp8_context = get_fp8_context(self.config)
transformer_layer_fp8_context = get_fp8_context(self.config)
else:
fp8_context = nullcontext()
transformer_layer_fp8_context = nullcontext()

# TODO: currently ignoring FP4 in MTP layers because we need more numerical validation
with rng_context:
with fp8_context:
hidden_states = self._concat_embeddings(hidden_states, decoder_input)

# Use a separate fp8 context for the transformer layer. This is to ensure that when the
# transformer layer is cudagraphed, the FP8GlobalStateManager.is_first_fp8_module() is
# True so that the fp8 weight caching can be triggered correctly.
with transformer_layer_fp8_context:
if getattr(self, 'mtp_layer_pattern', None) is not None:
hidden_states = self.transformer_layer(
hidden_states=hidden_states,
attention_mask=attention_mask,
rotary_pos_emb=rotary_pos_emb,
inference_context=inference_params,
packed_seq_params=packed_seq_params,
**kwargs,
)
else:
# GPT path: single TransformerLayer
hidden_states, _ = self.transformer_layer(
hidden_states=hidden_states,
attention_mask=attention_mask,
context=context,
context_mask=context_mask,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_cos=rotary_pos_cos,
rotary_pos_sin=rotary_pos_sin,
attention_bias=attention_bias,
inference_params=inference_params,
packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
**kwargs,
)

if not getattr(self, 'mhc_enabled', False):
hidden_states = self._postprocess(hidden_states)

return hidden_states

def _concat_embeddings(self, hidden_states: torch.Tensor, decoder_input: torch.Tensor):
"""
Concatenate the tokens before sending to transformer layer.
Expand Down
Loading