diff --git a/src/mcore_bridge/model/gpt_model.py b/src/mcore_bridge/model/gpt_model.py index 1a87e50..35d89f8 100644 --- a/src/mcore_bridge/model/gpt_model.py +++ b/src/mcore_bridge/model/gpt_model.py @@ -283,7 +283,7 @@ def forward( input_tensor = self.get_input_tensor() input_tensor, mtp_decoder_input = input_tensor.chunk(2, dim=0) self.set_input_tensor(input_tensor) - kwargs = {} + extra_block_kwargs = extra_block_kwargs or {} full_attention_mask = attention_mask if isinstance(full_attention_mask, dict): full_attention_mask = full_attention_mask['full_attention'] @@ -296,7 +296,7 @@ def forward( if self.config.sequence_parallel and tp_size > 1: assert padding_mask.shape[1] % tp_size == 0, f'padding_mask.shape: {padding_mask.shape}' padding_mask = torch.chunk(padding_mask, tp_size, dim=1)[mpu.get_tensor_model_parallel_rank()] - kwargs['padding_mask'] = padding_mask.contiguous() + extra_block_kwargs['padding_mask'] = padding_mask.contiguous() # Run decoder. hidden_states = self.decoder( hidden_states=decoder_input, @@ -307,8 +307,7 @@ def forward( rotary_pos_sin=rotary_pos_sin, packed_seq_params=packed_seq_params, sequence_len_offset=sequence_len_offset, - **(extra_block_kwargs or {}), - **kwargs, + **extra_block_kwargs, ) # MTP: https://github.com/NVIDIA/Megatron-LM/issues/1661 diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py index 4a8f729..23caa9b 100644 --- a/src/mcore_bridge/model/mm_gpts/gemma4.py +++ b/src/mcore_bridge/model/mm_gpts/gemma4.py @@ -152,12 +152,15 @@ def __init__( if self.use_alternative_attention else text_config.num_key_value_heads) # Shared KV across the trailing layers self.num_kv_shared_layers = getattr(text_config, 'num_kv_shared_layers', 0) - first_kv_shared_layer_idx = config.num_layers - self.num_kv_shared_layers - self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0 - prev_layers = text_config.layer_types[:first_kv_shared_layer_idx] - self.store_full_length_kv = not self.is_kv_shared_layer and layer_idx == len( - prev_layers) - 1 - prev_layers[::-1].index(text_config.layer_types[layer_idx]) - + if self.num_kv_shared_layers: + first_kv_shared_layer_idx = config.num_layers - self.num_kv_shared_layers + self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0 + prev_layers = text_config.layer_types[:first_kv_shared_layer_idx] + self.store_full_length_kv = not self.is_kv_shared_layer and layer_idx == len( + prev_layers) - 1 - prev_layers[::-1].index(text_config.layer_types[layer_idx]) + else: + self.is_kv_shared_layer = False + self.store_full_length_kv = False orig_kv_channels = config.kv_channels orig_num_query_groups = config.num_query_groups orig_k_layernorm = submodules.k_layernorm @@ -698,6 +701,7 @@ class Gemma4TransformerLayer(TransformerLayer): def __init__(self, config, submodules, *args, **kwargs): super().__init__(config, submodules, *args, **kwargs) text_config = config.hf_config.text_config + self.layer_type = text_config.layer_types[self.layer_number - 1] self.enable_moe_block = text_config.enable_moe_block if self.enable_moe_block: self.experts_mlp = self._build_mlp(submodules.experts_mlp) @@ -748,6 +752,8 @@ def __init__(self, config, submodules, *args, **kwargs): TENorm, hidden_size=hidden_size, config=self.config, eps=eps) def _forward_attention(self, hidden_states: Tensor, **kwargs): + kwargs['rotary_pos_emb'] = kwargs['rotary_pos_emb'][self.layer_type] + kwargs['attention_mask'] = kwargs['attention_mask'][self.layer_type] context = kwargs.pop('context', None) residual = hidden_states input_layernorm_output = self.input_layernorm(hidden_states) @@ -805,13 +811,9 @@ class Gemma4GPTModel(MultimodalGPTModel): class Gemma4TransformerBlock(TransformerBlock): def _layer_forward(self, layer, hidden_states, **kwargs): - layer_number = layer.layer_number - 1 per_layer_inputs = kwargs.pop('per_layer_inputs', None) if per_layer_inputs is not None: - kwargs['per_layer_input'] = per_layer_inputs[:, :, layer_number] - layer_type = self.config.hf_config.text_config.layer_types[layer_number] - kwargs['rotary_pos_emb'] = kwargs['rotary_pos_emb'][layer_type] - kwargs['attention_mask'] = kwargs['attention_mask'][layer_type] + kwargs['per_layer_input'] = per_layer_inputs[:, :, layer.layer_number - 1] return super()._layer_forward(layer, hidden_states, **kwargs) diff --git a/src/mcore_bridge/model/modules/mtp_layer.py b/src/mcore_bridge/model/modules/mtp_layer.py index 5398b71..587ac7b 100644 --- a/src/mcore_bridge/model/modules/mtp_layer.py +++ b/src/mcore_bridge/model/modules/mtp_layer.py @@ -1,9 +1,12 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import torch import transformer_engine +import warnings from contextlib import nullcontext -from functools import partial -from megatron.core import InferenceParams +from megatron.core import InferenceParams, parallel_state, tensor_parallel +from megatron.core.enums import Fp8Recipe +from megatron.core.extensions.transformer_engine import te_checkpoint +from megatron.core.fp8_utils import get_fp8_context from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.tensor_parallel.mappings import (gather_from_sequence_parallel_region, gather_from_tensor_model_parallel_region, @@ -16,6 +19,8 @@ from mcore_bridge.utils import roll_tensor +from .transformer_block import _checkpoint_flatten, _checkpoint_unflatten, _TensorIdx + try: from megatron.core.typed_torch import apply_module except ImportError: @@ -72,6 +77,7 @@ def forward( embedding=None, decoder_input=None, layer_number: Optional[int] = None, + **kwargs, ): assert context is None, 'multi token prediction + cross attention is not yet supported.' if layer_number is None: @@ -88,17 +94,20 @@ def forward( packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd' if self.config.position_embedding_type == 'rope' and packed_seq: assert position_ids.shape[0] == 1, f'position_ids.shape: {position_ids.shape}' - rotary_pos_emb = rotary_pos_emb[position_ids[0]] + if isinstance(rotary_pos_emb, dict): + for k, v in rotary_pos_emb.items(): + rotary_pos_emb[k] = v[position_ids[0]] + else: + rotary_pos_emb = rotary_pos_emb[position_ids[0]] else: # mrope or not packed_seq - rotary_pos_emb = torch.roll(rotary_pos_emb, shifts=-layer_number, dims=0) + if isinstance(rotary_pos_emb, dict): + for k, v in rotary_pos_emb.items(): + rotary_pos_emb[k] = torch.roll(v, shifts=-layer_number, dims=0) + else: + rotary_pos_emb = torch.roll(rotary_pos_emb, shifts=-layer_number, dims=0) if self.config.recompute_granularity == 'full' and self.training: hidden_states = self._checkpointed_forward( - partial( - self._proj_and_transformer_layer, - packed_seq_params=packed_seq_params, - sequence_len_offset=sequence_len_offset, - ), hidden_states=hidden_states, decoder_input=decoder_input, attention_mask=attention_mask, @@ -109,6 +118,9 @@ def forward( rotary_pos_sin=rotary_pos_sin, attention_bias=attention_bias, inference_params=inference_params, + packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, + **kwargs, ) else: hidden_states = self._proj_and_transformer_layer( @@ -124,9 +136,241 @@ def forward( inference_params=inference_params, packed_seq_params=packed_seq_params, sequence_len_offset=sequence_len_offset, + **kwargs, ) return hidden_states, input_ids, position_ids, decoder_input + # Code borrowed from NVIDIA/Megatron-LM + def _checkpointed_forward( + self, + hidden_states: torch.Tensor, + decoder_input: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + context: Optional[torch.Tensor] = None, + context_mask: Optional[torch.Tensor] = None, + rotary_pos_emb: Optional[torch.Tensor] = None, + rotary_pos_cos: Optional[torch.Tensor] = None, + rotary_pos_sin: Optional[torch.Tensor] = None, + attention_bias: Optional[torch.Tensor] = None, + inference_params: Optional[InferenceParams] = None, + packed_seq_params: Optional[PackedSeqParams] = None, + sequence_len_offset: Optional[torch.Tensor] = None, + **kwargs, + ): + """Forward through ``_proj_and_transformer_layer`` with activation + recomputation. + + Mirrors ``transformer_block._checkpointed_forward``: + + * Non-tensor objects (``attention_bias``, ``inference_params``, + ``packed_seq_params``) are captured by the ``custom_forward`` + closure; only tensor / ``None`` arguments flow positionally + through the underlying checkpoint primitive. This is required + by both backends: ``tensor_parallel.checkpoint`` because its + ``save_for_backward`` only accepts tensors and ``None``, and + ``te_checkpoint`` because its reentrant implementation only + tracks positional tensor inputs as checkpoint inputs (kwarg + tensors are not represented in the recompute backward path). + * Quantized recipes (fp8, fp4) route through ``te_checkpoint``; + everything else uses ``tensor_parallel.checkpoint``. + * Only ``fp8 + delayed scaling`` needs an outer quantization + context entered before ``te_checkpoint``; see the + ``outer_quantization_context`` block below. + """ + + # Variables that don't require gradients can be captured via closure. + _ckpt_attention_mask = attention_mask + _ckpt_rotary_pos_emb = rotary_pos_emb + extra_kwargs_keys = tuple(kwargs.keys()) + _extra_flat_tensors = [] + _extra_schemas = [_checkpoint_flatten(v, _extra_flat_tensors) for v in kwargs.values()] + + def custom_forward(hidden_states, decoder_input, context, context_mask, rotary_pos_cos, rotary_pos_sin, + sequence_len_offset, *extra_flat): + rebuilt = [_checkpoint_unflatten(s, extra_flat) for s in _extra_schemas] + extra_kwargs = dict(zip(extra_kwargs_keys, rebuilt)) + return self._proj_and_transformer_layer( + hidden_states=hidden_states, + decoder_input=decoder_input, + attention_mask=_ckpt_attention_mask, + context=context, + context_mask=context_mask, + rotary_pos_emb=_ckpt_rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + attention_bias=attention_bias, + inference_params=inference_params, + packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, + **extra_kwargs, + ) + + # Decide the outer quantization context, matching + # ``transformer_block._checkpointed_forward``. Only ``fp8 + delayed + # scaling`` needs an active context at the ``te_checkpoint`` entry + # point: TE's ``_CheckpointFunction.forward`` samples + # ``FP8GlobalStateManager.is_fp8_enabled()`` there to gate the + # phase-1 amax-buffer stash that phase-2 backward looks up via + # ``global_fp8_buffer_pos_fwd_recompute``. With fp8 only entered + # *inside* ``_proj_and_transformer_layer``, TE samples fp8 as off, + # phase-1 skips the stash, and phase-2 raises ``KeyError``. + # Non-delayed fp8 recipes (MXFP8BlockScaling, Float8CurrentScaling) + # and fp4 (NVFP4BlockScaling) treat the stash/lookup as a noop, so + # the inner context entered inside ``_proj_and_transformer_layer`` + # is sufficient. + if self.config.fp8 and self.config.fp8_recipe == Fp8Recipe.delayed: + outer_quantization_context = get_fp8_context(self.config) + else: + outer_quantization_context = nullcontext() + + def checkpoint_handler(): + """Determines whether to use the `te_checkpoint` or `tensor_parallel.checkpoint`""" + # fp4 quantization is internally implemented via TE's + # ``fp8_autocast`` (see ``fp4_utils.get_fp4_context``), so + # quantized recompute on either fp8 or fp4 must go through + # ``te_checkpoint``. Matches ``transformer_block``'s policy. + if self.config.fp8 or self.config.fp4: + + return te_checkpoint( + custom_forward, + self.config.distribute_saved_activations, + tensor_parallel.random.get_cuda_rng_tracker, + parallel_state.get_tensor_model_parallel_group(), + hidden_states, + decoder_input, + context, + context_mask, + rotary_pos_cos, + rotary_pos_sin, + sequence_len_offset, + *_extra_flat_tensors, + ) + else: + # tensor_parallel.checkpoint stashes args via autograd's + # ``save_for_backward``, which only accepts tensors and ``None``. + # Pass tensor / ``None`` args positionally and capture the + # non-tensor objects (``attention_bias``, ``inference_params``, + # ``packed_seq_params``) via the ``custom_forward`` closure. + return tensor_parallel.checkpoint( + custom_forward, + self.config.distribute_saved_activations, + hidden_states, + decoder_input, + context, + context_mask, + rotary_pos_cos, + rotary_pos_sin, + sequence_len_offset, + *_extra_flat_tensors, + ) + + if self.config.recompute_method == 'uniform': + # Uniformly divide the total number of Transformer layers and checkpoint + # the input activation of each divided chunk. + # A method to further reduce memory usage reducing checkpoints. + assert (self.config.recompute_num_layers == 1), 'recompute_num_layers must be 1 for MTP recompute' + with outer_quantization_context: + outputs = checkpoint_handler() + elif self.config.recompute_method == 'block': + # TODO: implement block-based recompute for MTP + warnings.warn("recompute_method == 'block' is not supported for MTP yet." + ' Skipping recompute.') + outputs = self._proj_and_transformer_layer( + hidden_states=hidden_states, + decoder_input=decoder_input, + attention_mask=attention_mask, + context=context, + context_mask=context_mask, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + attention_bias=attention_bias, + inference_params=inference_params, + packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, + **kwargs, + ) + else: + raise ValueError('Invalid activation recompute method.') + + return outputs + + def _proj_and_transformer_layer( + self, + hidden_states: torch.Tensor, + decoder_input: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + context: Optional[torch.Tensor] = None, + context_mask: Optional[torch.Tensor] = None, + rotary_pos_emb: Optional[torch.Tensor] = None, + rotary_pos_cos: Optional[torch.Tensor] = None, + rotary_pos_sin: Optional[torch.Tensor] = None, + attention_bias: Optional[torch.Tensor] = None, + inference_params: Optional[InferenceParams] = None, + packed_seq_params: Optional[PackedSeqParams] = None, + sequence_len_offset: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """ + Concatenates embeddings with hidden states and then applies transformer layer forward. + """ + padding_mask = kwargs.pop('padding_mask', None) + if padding_mask is not None: + kwargs['padding_mask'] = padding_mask + if self.config.sequence_parallel: + rng_context = tensor_parallel.get_cuda_rng_tracker().fork() + else: + rng_context = nullcontext() + + # Unlike transformer_block.py which needs to support mixed-precision in + # different layers,currently MTP only use global fp8 context. + if self.config.fp8: + fp8_context = get_fp8_context(self.config) + transformer_layer_fp8_context = get_fp8_context(self.config) + else: + fp8_context = nullcontext() + transformer_layer_fp8_context = nullcontext() + + # TODO: currently ignoring FP4 in MTP layers because we need more numerical validation + with rng_context: + with fp8_context: + hidden_states = self._concat_embeddings(hidden_states, decoder_input) + + # Use a separate fp8 context for the transformer layer. This is to ensure that when the + # transformer layer is cudagraphed, the FP8GlobalStateManager.is_first_fp8_module() is + # True so that the fp8 weight caching can be triggered correctly. + with transformer_layer_fp8_context: + if getattr(self, 'mtp_layer_pattern', None) is not None: + hidden_states = self.transformer_layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + rotary_pos_emb=rotary_pos_emb, + inference_context=inference_params, + packed_seq_params=packed_seq_params, + **kwargs, + ) + else: + # GPT path: single TransformerLayer + hidden_states, _ = self.transformer_layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + context=context, + context_mask=context_mask, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + attention_bias=attention_bias, + inference_params=inference_params, + packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, + **kwargs, + ) + + if not getattr(self, 'mhc_enabled', False): + hidden_states = self._postprocess(hidden_states) + + return hidden_states + def _concat_embeddings(self, hidden_states: torch.Tensor, decoder_input: torch.Tensor): """ Concatenate the tokens before sending to transformer layer.