diff --git a/megatron/core/dist_checkpointing/strategies/common.py b/megatron/core/dist_checkpointing/strategies/common.py index 41c21d93d7d..ef80f72d6fe 100644 --- a/megatron/core/dist_checkpointing/strategies/common.py +++ b/megatron/core/dist_checkpointing/strategies/common.py @@ -86,7 +86,7 @@ def load_common(self, checkpoint_dir: Union[str, Path]): msc = MultiStorageClientFeature.import_package() return msc.torch.load(load_path, map_location='cpu') else: - return torch.load(load_path, map_location='cpu') + return torch.load(load_path, map_location='cpu', weights_only=False) except FileNotFoundError as e: err_msg = f'Common file {load_path} does not exist' if MultiStorageClientFeature.is_enabled(): diff --git a/megatron/core/dist_checkpointing/strategies/torch.py b/megatron/core/dist_checkpointing/strategies/torch.py index a5b6c009ba4..22794d7e60d 100644 --- a/megatron/core/dist_checkpointing/strategies/torch.py +++ b/megatron/core/dist_checkpointing/strategies/torch.py @@ -503,10 +503,12 @@ def __init__( def _validate_global_shapes(self, metadata, sharded_tensors): for sh_ten in sharded_tensors: if sh_ten.key not in metadata.state_dict_metadata: - raise KeyError( - f"{sh_ten.key} from model not in state dict:" - f" {sorted(metadata.state_dict_metadata.keys())}" - ) + # raise KeyError( + # f"{sh_ten.key} from model not in state dict:" + # f" {sorted(metadata.state_dict_metadata.keys())}" + # ) + print(f"{sh_ten.key} from model not in state dict, will skip") + continue loaded_shape = metadata.state_dict_metadata[sh_ten.key].size expected_shape = sh_ten.global_shape if loaded_shape != expected_shape: @@ -530,7 +532,7 @@ def _temporarily_bypass_shape_validation(self): tensor_metadata = self.metadata.state_dict_metadata metadata_with_sizes = [ (tensor_metadata[key], tensor_metadata[key].size, sharded_tensor) - for key, sharded_tensor in self.allow_shape_mismatch_sharded_tensors.items() + for key, sharded_tensor in self.allow_shape_mismatch_sharded_tensors.items() if key in tensor_metadata ] try: # Temporarily set sizes to expected shapes @@ -802,6 +804,7 @@ def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path) -> St planner=MCoreLoadPlanner( shapes_validation_sharded_tensors=flexible_shape_sharded_tensors, allow_shape_mismatch_sharded_tensors=allow_shape_mismatch_sharded_tensors, + allow_partial_load=True, ), ) diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index ef8527e9e5e..bfa50d49345 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -639,6 +639,7 @@ def __init__( self.te_quant_params: Optional[TEQuantizationParams] = None for param in self.parameters(): + setattr(param, "parallel_mode", parallel_mode) if is_expert: # Reduce the gradient on the expert_data_parallel group for expert linear layers setattr(param, "allreduce", not self.expert_parallel) @@ -1455,6 +1456,61 @@ def sharded_state_dict( if HAVE_TE and is_te_min_version("1.9.0.dev0"): + def ceil_div(x: int, y: int) -> int: + return (x + y - 1) // y + + class _FakeInt4QuantizationSTE(torch.autograd.Function): + @staticmethod + def forward(ctx, x, group_size): + m, n = x.shape + block_size_m, block_size_n = 1, group_size + + + m_padded = ceil_div(m, block_size_m) * block_size_m + n_padded = ceil_div(n, block_size_n) * block_size_n + + x_padded = torch.zeros( + (m_padded, n_padded), + dtype=x.dtype, device=x.device + ) + x_padded[:m, :n] = x + + x_view = x_padded.view( + m_padded // block_size_m, + block_size_m, + n_padded // block_size_n, + block_size_n + ) + + x_max = x_view.abs().float().amax(dim=(1, 3), keepdim=True) + q_max = 7 + x_scale = x_max / q_max + + x_scale = x_scale.clamp(min=1e-5) + + x_div = x_view / x_scale + x_round = torch.round(x_div) + + x_q_clamped = x_round.clamp(-q_max, q_max) + + x_dequant_view = x_q_clamped * x_scale + + x_dequant_full = x_dequant_view.view_as(x_padded) + x_out = x_dequant_full[:m, :n].contiguous().to(x.dtype) + + return x_out + + @staticmethod + def backward(ctx, grad_output): + return grad_output, None + + def fake_int4_quantization_ste(x, group_size): + x_out = _FakeInt4QuantizationSTE.apply(x, group_size) + + if hasattr(x, 'main_grad'): + x_out.main_grad = x.main_grad + + return x_out class TEGroupedLinear(te.pytorch.GroupedLinear): """ @@ -1671,6 +1727,20 @@ def forward(self, x, m_splits): return out return out, None + def _get_weight_tensors(self): + """Get the weight tensors of the module.""" + weight_tensors = super()._get_weight_tensors() + + if os.getenv("OPEN_TRAINING_INT4_FAKE_QAT_FLAG", "0") == "1": + group_size = int(os.getenv("OPEN_TRAINING_INT4_GROUP_SIZE", "128")) + + weight_tensors = [ + fake_int4_quantization_ste(w, group_size) + for w in weight_tensors + ] + + return weight_tensors + def _encode_extra_state(self, state): # TE 2.0 changed the format of extra_state to be a byte tensor if is_te_min_version("2.0.0"): diff --git a/megatron/core/fusions/fused_mla_yarn_rope_apply.py b/megatron/core/fusions/fused_mla_yarn_rope_apply.py index 1fd5dcfae37..c9aeef1f076 100644 --- a/megatron/core/fusions/fused_mla_yarn_rope_apply.py +++ b/megatron/core/fusions/fused_mla_yarn_rope_apply.py @@ -385,6 +385,7 @@ def rotary_fwd_kv_kernel( SIN, emb_dim: tl.constexpr, k_dim: tl.constexpr, + k_dim_ceil: tl.constexpr, v_dim: tl.constexpr, head_num: tl.constexpr, batch_size, @@ -434,21 +435,27 @@ def rotary_fwd_kv_kernel( cos_right = tl.load(COS + token_idx * emb_dim + emb_dim // 2 + tl.arange(0, emb_dim // 2)) sin_right = tl.load(SIN + token_idx * emb_dim + emb_dim // 2 + tl.arange(0, emb_dim // 2)) - KV_ptr = KV + pid_m * stride_kv_seq + pid_head * BLOCK_H * stride_kv_nheads - kv_off = tl.arange(0, BLOCK_H)[:, None] * stride_kv_nheads - mask = kv_off < head_num * stride_kv_nheads - k_in_off = kv_off + tl.arange(0, k_dim)[None, :] - v_in_off = kv_off + k_dim + tl.arange(0, v_dim)[None, :] - k = tl.load(KV_ptr + k_in_off, mask=mask) - v = tl.load(KV_ptr + v_in_off, mask=mask) + KV_ptr = KV + pid_m * stride_kv_seq # + pid_head * BLOCK_H * stride_kv_nheads + ki_range = tl.arange(0, BLOCK_H)[:, None] + pid_head * BLOCK_H + kj_range = tl.arange(0, k_dim_ceil)[None, :] + mask_k = (ki_range < head_num) & (kj_range < k_dim) + mask_v = ki_range < head_num + k_off = ki_range * stride_kv_nheads + kj_range + if v_dim > 0: + v_off = ki_range * stride_kv_nheads + k_dim + tl.arange(0, v_dim)[None, :] + v = tl.load(KV_ptr + v_off, mask=mask_v) + else: + v = tl.zeros((BLOCK_H, 1), dtype=KV.dtype.element_ty) + k = tl.load(KV_ptr + k_off, mask=mask_k) - K_ptr = O_KEY + pid_m * stride_k_seq + pid_head * BLOCK_H * stride_k_nheads - V_ptr = O_VALUE + pid_m * stride_v_seq + pid_head * BLOCK_H * stride_v_nheads + K_ptr = O_KEY + pid_m * stride_k_seq # + pid_head * BLOCK_H * stride_k_nheads + V_ptr = O_VALUE + pid_m * stride_v_seq # + pid_head * BLOCK_H * stride_v_nheads - k_out_off = tl.arange(0, BLOCK_H)[:, None] * stride_k_nheads + tl.arange(0, k_dim)[None, :] - v_out_off = tl.arange(0, BLOCK_H)[:, None] * stride_v_nheads + tl.arange(0, v_dim)[None, :] - tl.store(K_ptr + k_out_off, k, mask=mask) - tl.store(V_ptr + v_out_off, v, mask=mask) + k_out_off = ki_range * stride_k_nheads + kj_range + tl.store(K_ptr + k_out_off, k, mask=mask_k) + if v_dim > 0: + v_out_off = ki_range * stride_v_nheads + tl.arange(0, v_dim)[None, :] + tl.store(V_ptr + v_out_off, v, mask=mask_v) EMB = K_POS_EMB + pid_m * stride_emb_seq # x1 = t[..., 0::2], x2 = t[..., 1::2] @@ -460,14 +467,16 @@ def rotary_fwd_kv_kernel( x_left = x_left.expand_dims(0).broadcast_to(BLOCK_H, emb_dim // 2) x_right = x_right.expand_dims(0).broadcast_to(BLOCK_H, emb_dim // 2) + x_range = tl.arange(0, BLOCK_H)[:, None] + pid_head * BLOCK_H + mask_x = x_range < head_num x_left_off = ( - tl.arange(0, BLOCK_H)[:, None] * stride_k_nheads + x_range * stride_k_nheads + k_dim + tl.arange(0, emb_dim // 2)[None, :] ) x_right_off = x_left_off + emb_dim // 2 - tl.store(K_ptr + x_left_off, x_left, mask=mask) - tl.store(K_ptr + x_right_off, x_right, mask=mask) + tl.store(K_ptr + x_left_off, x_left, mask=mask_x) + tl.store(K_ptr + x_right_off, x_right, mask=mask_x) @triton.autotune( @@ -493,6 +502,7 @@ def rotary_bwd_kv_kernel( SIN, emb_dim: tl.constexpr, k_dim: tl.constexpr, + k_dim_ceil: tl.constexpr, v_dim: tl.constexpr, head_num: tl.constexpr, batch_size, @@ -533,27 +543,32 @@ def rotary_bwd_kv_kernel( else: token_idx = _get_thd_token_idx(cu_seqlens_kv, pid_m, seq_num, cp_rank, cp_size) - dKV_ptr = dKV + pid_m * stride_dkv_seq + pid_head * BLOCK_H * stride_dkv_nheads - dkv_off = tl.arange(0, BLOCK_H)[:, None] * stride_dkv_nheads - mask = dkv_off < head_num * stride_dkv_nheads - dk_out_off = dkv_off + tl.arange(0, k_dim)[None, :] - dv_out_off = dkv_off + k_dim + tl.arange(0, v_dim)[None, :] - - dK_ptr = dK + pid_m * stride_dk_seq + pid_head * BLOCK_H * stride_dk_nheads - dV_ptr = dV + pid_m * stride_dv_seq + pid_head * BLOCK_H * stride_dv_nheads - dk_in_off = tl.arange(0, BLOCK_H)[:, None] * stride_dk_nheads + tl.arange(0, k_dim)[None, :] - dv_in_off = tl.arange(0, BLOCK_H)[:, None] * stride_dv_nheads + tl.arange(0, v_dim)[None, :] - dk = tl.load(dK_ptr + dk_in_off, mask=mask) - dv = tl.load(dV_ptr + dv_in_off, mask=mask) - tl.store(dKV_ptr + dk_out_off, dk, mask=mask) - tl.store(dKV_ptr + dv_out_off, dv, mask=mask) + dKV_ptr = dKV + pid_m * stride_dkv_seq # + pid_head * BLOCK_H * stride_dkv_nheads + ki_range = tl.arange(0, BLOCK_H)[:, None] + pid_head * BLOCK_H + kj_range = tl.arange(0, k_dim_ceil)[None, :] + mask_k = (ki_range < head_num) & (kj_range < k_dim) + mask_v = ki_range < head_num + dk_out_off = ki_range * stride_dkv_nheads + kj_range + + dK_ptr = dK + pid_m * stride_dk_seq # + pid_head * BLOCK_H * stride_dk_nheads + dV_ptr = dV + pid_m * stride_dv_seq # + pid_head * BLOCK_H * stride_dv_nheads + dk_in_off = ki_range * stride_dk_nheads + kj_range + + dk = tl.load(dK_ptr + dk_in_off, mask=mask_k) + tl.store(dKV_ptr + dk_out_off, dk, mask=mask_k) + + if v_dim > 0: + dv_out_off = ki_range * stride_dkv_nheads + k_dim + tl.arange(0, v_dim)[None, :] + dv_in_off = ki_range * stride_dv_nheads + tl.arange(0, v_dim)[None, :] + dv = tl.load(dV_ptr + dv_in_off, mask=mask_v) + tl.store(dKV_ptr + dv_out_off, dv, mask=mask_v) if pid_head == 0: x_left_accum = tl.zeros((BLOCK_H, emb_dim // 2), dtype=tl.float32) x_right_accum = tl.zeros((BLOCK_H, emb_dim // 2), dtype=tl.float32) for i in tl.static_range(triton.cdiv(head_num, BLOCK_H)): - dK_ptr = dK + pid_m * stride_dk_seq + i * BLOCK_H * stride_dk_nheads - x_off = tl.arange(0, BLOCK_H)[:, None] * stride_dk_nheads + k_dim + dK_ptr = dK + pid_m * stride_dk_seq # + i * BLOCK_H * stride_dk_nheads + x_off = tl.arange(0, BLOCK_H)[:, None] * stride_dk_nheads + k_dim + i * BLOCK_H * stride_dk_nheads mask = x_off < head_num * stride_dk_nheads x_left_off = x_off + tl.arange(0, emb_dim // 2)[None, :] x_right_off = x_left_off + emb_dim // 2 @@ -632,6 +647,7 @@ def forward( o_key = kv.new_empty(total_seqlen, nheads, emb_dim + k_dim) o_value = kv.new_empty(total_seqlen, nheads, v_dim) + k_dim_ceil = triton.next_power_of_2(k_dim) grid = lambda META: (total_seqlen, triton.cdiv(nheads, META["BLOCK_H"])) rotary_fwd_kv_kernel[grid]( @@ -643,6 +659,7 @@ def forward( sin, emb_dim, k_dim, + k_dim_ceil, v_dim, nheads, batch_size, @@ -700,6 +717,7 @@ def backward(ctx, dk, dv): d_kv = dk.new_empty(total_seqlen, nheads, ctx.k_dim + ctx.v_dim) d_emb = dk.new_empty(total_seqlen, 1, ctx.emb_dim) + k_dim_ceil = triton.next_power_of_2(ctx.k_dim) grid = lambda META: (total_seqlen, triton.cdiv(nheads, META["BLOCK_H"])) rotary_bwd_kv_kernel[grid]( @@ -711,6 +729,7 @@ def backward(ctx, dk, dv): sin, ctx.emb_dim, ctx.k_dim, + k_dim_ceil, ctx.v_dim, nheads, batch_size, diff --git a/megatron/core/models/common/language_module/language_module.py b/megatron/core/models/common/language_module/language_module.py index b0fa6126b63..931cc6c7892 100644 --- a/megatron/core/models/common/language_module/language_module.py +++ b/megatron/core/models/common/language_module/language_module.py @@ -1,7 +1,7 @@ # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. import logging import os -from typing import Optional, Tuple +from typing import Any, Dict, Literal, Optional, Tuple import torch from torch import Tensor @@ -14,6 +14,10 @@ except: te_parallel_cross_entropy = None from megatron.core.fusions.fused_cross_entropy import fused_vocab_parallel_cross_entropy +try: + from cut_cross_entropy import linear_cross_entropy +except ImportError: + linear_cross_entropy = None from megatron.core.pipeline_parallel.utils import ( is_pp_first_stage, is_pp_last_stage, @@ -125,6 +129,76 @@ def check_and_set_env_variable( check_and_set_env_variable("NVTE_FUSED_ATTN", 1, AttnBackend.auto) check_and_set_env_variable("NVTE_UNFUSED_ATTN", 1, AttnBackend.auto) + def compute_output_layer_and_language_model_loss( + self, + hidden: Tensor, + labels: Optional[Tensor], + weight: Tensor = None, + sequence_parallel_enabled: bool = False, + column_parallel_linear: torch.nn.Module = None, + col_linear_kwargs: Dict[str, Any] = {}, + reduction: Literal["none", "sum", "mean"] = "none", + ignore_index: int = -100, + ) -> Tensor: + """Computes the language model logits and loss (Cross entropy across vocabulary) + + Args: + hidden (Tensor): The hidden states from the transformer model + labels (Optional[Tensor]): The labels of dimension [batch size, seq length] + weight (Tensor): The weight tensor of shape [vocab size, hidden size]. + Required if using fused linear cross entropy. + column_parallel_linear (torch.nn.Module): The column parallel linear + layer to use for computing logits when not using fused linear cross entropy. + col_linear_kwargs (Dict[str, Any]): Additional kwargs for column parallel linear layer + reduction (Optional[str]): The reduction method. Defaults to "none", and can be + one of "none", "sum", "mean". + ignore_index (Optional[int]): The index to ignore in the loss calculation. + Defaults to -100. + + Returns: + Tensor: Loss tensor of dimensions [batch size, sequence_length]. + """ + if ( + self.config.cross_entropy_loss_fusion + and self.config.cross_entropy_fusion_impl == 'linear' + ): + assert ( + weight is not None + ), "weight cannot be None when using fused linear cross entropy." + assert ( + labels is not None + ), "labels cannot be None when using fused linear cross entropy." + # [b s] => [s b] + labels = labels.transpose(0, 1).contiguous() + loss = linear_cross_entropy( + hidden, + weight, + labels, + tp_group=self.pg_collection.tp, + sequence_parallel=sequence_parallel_enabled, + reduction=reduction, + ignore_index=ignore_index, + ) + + # [s b] => [b, s] + loss = loss.view_as(labels).transpose(0, 1).contiguous() + return loss + else: + assert ( + column_parallel_linear is not None + ), "column_parallel_linear cannot be None when not using fused linear cross entropy." + # output + output_layer_params = {k: v.detach() for k, v in column_parallel_linear.named_parameters()} + output_layer_buffers = dict(column_parallel_linear.named_buffers()) + logits, _ = torch.func.functional_call( + column_parallel_linear, + {**output_layer_params, **output_layer_buffers}, + (hidden,), + col_linear_kwargs, + ) + + return self.compute_language_model_loss(labels, logits) + def compute_language_model_loss(self, labels: Tensor, logits: Tensor) -> Tensor: """Computes the language model loss (Cross entropy across vocabulary) diff --git a/megatron/core/models/gpt/gpt_layer_specs.py b/megatron/core/models/gpt/gpt_layer_specs.py index 49501ee54eb..ddaae617139 100755 --- a/megatron/core/models/gpt/gpt_layer_specs.py +++ b/megatron/core/models/gpt/gpt_layer_specs.py @@ -181,6 +181,9 @@ def get_gpt_layer_with_transformer_engine_spec( use_te_activation_func: bool = False, use_kitchen_attention: bool = False, kitchen_attention_backend: str = "sdpa", + fallback_to_eager_attn: bool = False, + post_self_attn_layernorm: bool = False, + post_mlp_layernorm: bool = False, ) -> ModuleSpec: """Use this spec to use lower-level Transformer Engine modules (required for fp8 training). @@ -363,6 +366,87 @@ def get_gpt_layer_local_spec( moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm, ) + return get_transformer_layer_spec_for_backend( + backend=backend, + attention=attention, + mlp=mlp, + sharded_state_dict_keys_map=sharded_state_dict_keys_map, + normalization=normalization, + ) + + +def get_transformer_layer_spec_for_backend( + backend: BackendSpecProvider, + attention: ModuleSpec, + mlp: ModuleSpec, + sharded_state_dict_keys_map: Optional[dict] = None, + normalization: Optional[str] = None, + post_self_attn_layernorm: bool = False, + post_mlp_layernorm: bool = False, +) -> ModuleSpec: + """Helper function to get module spec for TransformerLayer""" + + rms_norm = normalization == "RMSNorm" + + input_layernorm = ( + IdentityOp + if attention.metainfo["fuse_input_layernorm"] + else backend.layer_norm(rms_norm=rms_norm, for_qk=False) + ) + pre_mlp_layernorm = ( + IdentityOp + if mlp.metainfo["fuse_pre_mlp_layernorm"] + else backend.layer_norm(rms_norm=rms_norm, for_qk=False) + ) + + transformer_layer = ModuleSpec( + module=TransformerLayer, + submodules=TransformerLayerSubmodules( + input_layernorm=input_layernorm, + self_attention=attention, + self_attn_bda=get_bias_dropout_add, + post_self_attn_layernorm=TENorm if post_self_attn_layernorm else IdentityOp, + pre_mlp_layernorm=pre_mlp_layernorm, + mlp=mlp, + mlp_bda=get_bias_dropout_add, + post_mlp_layernorm=TENorm if post_mlp_layernorm else IdentityOp, + sharded_state_dict_keys_map=sharded_state_dict_keys_map, + ), + ) + return transformer_layer + + +def get_attention_module_spec_for_backend( + backend: BackendSpecProvider, + sharded_state_dict_keys_map: dict, + experimental_attention_variant: Optional[str] = None, + qk_layernorm: Optional[bool] = False, + qk_l2_norm: Optional[bool] = False, + multi_latent_attention: Optional[bool] = False, + mla_down_proj_use_column_parallel: Optional[bool] = False, + normalization: Optional[str] = None, + fallback_to_eager_attn: Optional[bool] = False, +) -> ModuleSpec: + """Helper function to get module spec for Attention""" + + if experimental_attention_variant is not None: + return get_experimental_attention_variant_module_spec_for_backend( + backend, + sharded_state_dict_keys_map, + experimental_attention_variant, + qk_layernorm, + qk_l2_norm, + multi_latent_attention, + mla_down_proj_use_column_parallel, + normalization, + fallback_to_eager_attn, + ) + + # Adjust for RMS norm. + rms_norm = normalization == "RMSNorm" + qk_norm = backend.layer_norm(rms_norm=rms_norm, for_qk=True) + + core_attention = backend.core_attention() if not fallback_to_eager_attn else DotProductAttention if multi_latent_attention: assert qk_l2_norm is False, "qk_l2_norm is not supported with MLA." return ModuleSpec( diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py index e287344c13d..170072027ab 100644 --- a/megatron/core/models/gpt/gpt_model.py +++ b/megatron/core/models/gpt/gpt_model.py @@ -482,6 +482,7 @@ def forward( inference_params: Optional[BaseInferenceContext] = None, loss_mask: Optional[Tensor] = None, padding_mask: Optional[Tensor] = None, + mtp_kwargs: Optional[dict] = {}, ) -> Tensor: """Forward function of the GPT Model This function passes the input tensors through the embedding layer, and then the decoder and finally into the post @@ -554,6 +555,7 @@ def forward( runtime_gather_output=runtime_gather_output, extra_block_kwargs=extra_block_kwargs, inference_context=inference_context, + mtp_kwargs=mtp_kwargs, ) def _postprocess( @@ -575,6 +577,7 @@ def _postprocess( runtime_gather_output=None, extra_block_kwargs=None, inference_context=None, + mtp_kwargs={}, ): """Postprocesses decoder hidden states to generate logits or compute loss. @@ -589,7 +592,8 @@ def _postprocess( output_weight = None if self.share_embeddings_and_output_weights: output_weight = self.shared_embedding_or_output_weight() - if mtp_in_postprocess: + + if mtp_in_postprocess and mtp_kwargs.get('mtp_labels', None) is not None: hidden_states = self.mtp( input_ids=input_ids, position_ids=position_ids, @@ -615,6 +619,9 @@ def _postprocess( if loss_mask is None: # if loss_mask is not provided, use all ones as loss_mask loss_mask = torch.ones_like(mtp_labels) + else: + # Otherwise, roll the loss_mask to keep up with the mtp_labels + loss_mask, _ = roll_tensor(loss_mask, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params) for mtp_layer_number in range(self.config.mtp_num_layers): # output mtp_logits, _ = self.output_layer( diff --git a/megatron/core/optimizer/distrib_optimizer.py b/megatron/core/optimizer/distrib_optimizer.py index 4192b0bb73c..248634a9e68 100644 --- a/megatron/core/optimizer/distrib_optimizer.py +++ b/megatron/core/optimizer/distrib_optimizer.py @@ -680,6 +680,8 @@ def state_dict(self): # TE FusedAdam will not accumulate step for empty param groups, so we need to # align the step across param groups. param_group["step"] = int(step) + if "step" in param_group and param_group["step"] is None: + del param_group["step"] # Grad scaler state. if self.grad_scaler: @@ -1662,9 +1664,6 @@ def sharded_param_state_dp_reshardable( tensors[key] = LocalNonpersistentObject(tensors[key]) continue if key == 'step': - # The optimizer state of STEP is a 0-dim tensor and is handled - # separately via param_groups, not as part of the gradient buffer. - tensors[key] = LocalNonpersistentObject(tensors[key]) continue assert tensors[key].shape == (gbuf_local_end - gbuf_local_start,), ( tensors[key].shape, diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py index 0bcea5687a4..a23bc5bc6ca 100644 --- a/megatron/core/parallel_state.py +++ b/megatron/core/parallel_state.py @@ -11,6 +11,7 @@ import numpy as np import torch +import torch.distributed as dist from .utils import GlobalMemoryBuffer, GlobalSymmetricMemoryBuffer, is_torch_min_version diff --git a/megatron/core/pipeline_parallel/p2p_communication.py b/megatron/core/pipeline_parallel/p2p_communication.py index ac839c21f18..f18309217c3 100644 --- a/megatron/core/pipeline_parallel/p2p_communication.py +++ b/megatron/core/pipeline_parallel/p2p_communication.py @@ -26,22 +26,22 @@ def _batched_p2p_ops( ops = [] if tensor_send_prev is not None: send_prev_op = torch.distributed.P2POp( - torch.distributed.isend, tensor_send_prev, prev_pipeline_rank, group + torch.distributed.isend, tensor_send_prev, prev_pipeline_rank, ) ops.append(send_prev_op) if tensor_recv_prev is not None: recv_prev_op = torch.distributed.P2POp( - torch.distributed.irecv, tensor_recv_prev, prev_pipeline_rank, group + torch.distributed.irecv, tensor_recv_prev, prev_pipeline_rank, ) ops.append(recv_prev_op) if tensor_send_next is not None: send_next_op = torch.distributed.P2POp( - torch.distributed.isend, tensor_send_next, next_pipeline_rank, group + torch.distributed.isend, tensor_send_next, next_pipeline_rank, ) ops.append(send_next_op) if tensor_recv_next is not None: recv_next_op = torch.distributed.P2POp( - torch.distributed.irecv, tensor_recv_next, next_pipeline_rank, group + torch.distributed.irecv, tensor_recv_next, next_pipeline_rank, ) ops.append(recv_next_op) if len(ops) > 0: diff --git a/megatron/core/transformer/moe/router.py b/megatron/core/transformer/moe/router.py index 4be97401748..2f1b323c01a 100644 --- a/megatron/core/transformer/moe/router.py +++ b/megatron/core/transformer/moe/router.py @@ -205,6 +205,12 @@ def __init__( self.router_replay = None if self.config.moe_enable_routing_replay: self.router_replay = RouterReplay() + # SiRL adapter: register with sirl's RoutingReplay for actor.py + try: + from sirl.utils.routing_replay import register_routing_replay + register_routing_replay(self) + except ImportError: + pass def _maintain_float32_expert_bias(self): """ diff --git a/megatron/core/transformer/multi_token_prediction.py b/megatron/core/transformer/multi_token_prediction.py index 2edb652bfc6..40fbd186c17 100755 --- a/megatron/core/transformer/multi_token_prediction.py +++ b/megatron/core/transformer/multi_token_prediction.py @@ -7,6 +7,7 @@ import torch from torch import Tensor +import warnings from megatron.core import InferenceParams, parallel_state, tensor_parallel from megatron.core.dist_checkpointing.mapping import ShardedStateDict @@ -709,17 +710,19 @@ def _get_embeddings( cp_group=self.cp_group, packed_seq_params=packed_seq_params, ) - position_ids, _ = roll_tensor( - position_ids, - shifts=-1, - dims=-1, - cp_group=self.cp_group, - packed_seq_params=packed_seq_params, - ) + if position_ids is not None: + position_ids, _ = roll_tensor( + position_ids, + shifts=-1, + dims=-1, + cp_group=self.cp_group, + packed_seq_params=packed_seq_params, + ) # embedding decoder_input = embedding(input_ids=input_ids, position_ids=position_ids) + decoder_input = decoder_input.detach() - hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) + hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=False) return input_ids, position_ids, decoder_input, hidden_states @@ -821,6 +824,51 @@ def _postprocess(self, hidden_states: torch.Tensor): return hidden_states def _checkpointed_forward(self, forward_func, *args, **kwargs): + """Wrap `forward_func` with activation checkpointing while only passing tensors. + + Non-tensor arguments (e.g., configuration objects, None) are captured via closure so + that checkpoint implementations never receive them directly, avoiding save_for_backward + issues with non-tensor inputs. + """ + + # TODO(jiajun): Is there any better implementation here? + positional_specs = [] + kw_specs = [] + tensor_args: List[torch.Tensor] = [] + + for arg in args: + if torch.is_tensor(arg): + positional_specs.append(('tensor', len(tensor_args))) + tensor_args.append(arg) + else: + positional_specs.append(('const', arg)) + + for key, value in kwargs.items(): + if torch.is_tensor(value): + kw_specs.append((key, ('tensor', len(tensor_args)))) + tensor_args.append(value) + else: + kw_specs.append((key, ('const', value))) + + def run(*flat_tensor_args): + rebuilt_args = [] + for spec_type, payload in positional_specs: + if spec_type == 'tensor': + rebuilt_args.append(flat_tensor_args[payload]) + else: + rebuilt_args.append(payload) + + rebuilt_kwargs = {} + for key, (spec_type, payload) in kw_specs: + if spec_type == 'tensor': + rebuilt_kwargs[key] = flat_tensor_args[payload] + else: + rebuilt_kwargs[key] = payload + + return forward_func(*rebuilt_args, **rebuilt_kwargs) + + tensor_args_tuple = tuple(tensor_args) + def checkpoint_handler(): """Determines whether to use the `te_checkpoint` or `tensor_parallel.checkpoint`""" if self.config.fp8: @@ -831,12 +879,11 @@ def checkpoint_handler(): self.config.distribute_saved_activations, tensor_parallel.random.get_cuda_rng_tracker, parallel_state.get_tensor_model_parallel_group(), - *args, - **kwargs, + *tensor_args_tuple, ) else: return tensor_parallel.checkpoint( - forward_func, self.config.distribute_saved_activations, *args, *kwargs.values() + run, self.config.distribute_saved_activations, *tensor_args_tuple ) if self.config.recompute_method == 'uniform': diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index eaae585905e..0f01f6bc055 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -227,6 +227,9 @@ class TransformerConfig(ModelParallelConfig): attention_output_gate: bool = False """Whether to apply output gate to the attention layers.""" + post_self_attn_layernorm: bool = False + post_mlp_layernorm: bool = False + test_mode: bool = False """Whether to run real-time tests.""" diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index a5eaec92866..3b19403ce64 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -224,6 +224,7 @@ class TransformerLayerSubmodules: input_layernorm: Union[ModuleSpec, type] = IdentityOp self_attention: Union[ModuleSpec, type] = IdentityOp self_attn_bda: Union[ModuleSpec, type] = IdentityFuncOp + post_self_attn_layernorm: Union[ModuleSpec, type] = IdentityOp pre_cross_attn_layernorm: Union[ModuleSpec, type] = IdentityOp cross_attention: Union[ModuleSpec, type] = IdentityOp @@ -232,6 +233,7 @@ class TransformerLayerSubmodules: pre_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp mlp: Union[ModuleSpec, type] = IdentityOp mlp_bda: Union[ModuleSpec, type] = IdentityFuncOp + post_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp # Mapping for sharded tensor keys to be applied in `sharded_state_dict` method sharded_state_dict_keys_map: Dict[str, str] = field(default_factory=dict) @@ -311,6 +313,13 @@ def __init__( # [Module 3: BiasDropoutFusion] self.self_attn_bda = build_module(submodules.self_attn_bda) + self.post_self_attn_layernorm = build_module( + submodules.post_self_attn_layernorm, + config=self.config, + hidden_size=self.config.hidden_size, + eps=self.config.layernorm_epsilon, + ) + # [Module 4: Post SelfAttention] Optional Layernorm after self-attn self.pre_cross_attn_layernorm = build_module( submodules.pre_cross_attn_layernorm, @@ -376,6 +385,13 @@ def __init__( self.is_moe_layer = isinstance(self.mlp, MoELayer) + self.post_mlp_layernorm = build_module( + submodules.post_mlp_layernorm, + config=self.config, + hidden_size=self.config.hidden_size, + eps=self.config.layernorm_epsilon + ) + self.recompute_input_layernorm = False self.recompute_pre_mlp_layernorm = False self.recompute_mlp = False @@ -615,6 +631,10 @@ def _forward_attention( attention_output_with_bias[0] ) + attention_output, attention_output_bias = attention_output_with_bias + attention_output = self.post_self_attn_layernorm(attention_output) + attention_output_with_bias = (attention_output, attention_output_bias) + # TODO: could we move `bias_dropout_add_exec_handler` itself # inside the module provided in the `bias_dropout_add_spec` module? nvtx_range_push(suffix="self_attn_bda") @@ -755,6 +775,20 @@ def _forward_mlp(self, hidden_states, inference_context=None, padding_mask=None) self._set_fc2_residual(residual) mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output, padding_mask=padding_mask) + mlp_output, mlp_output_bias = mlp_output_with_bias + mlp_output = self.post_mlp_layernorm(mlp_output) + mlp_output_with_bias = (mlp_output, mlp_output_bias) + + mlp_output, mlp_output_bias = mlp_output_with_bias + mlp_output = self.post_mlp_layernorm(mlp_output) + mlp_output_with_bias = (mlp_output, mlp_output_bias) + + if self.recompute_pre_mlp_layernorm: + # discard the output of the pre-mlp layernorm and register the recompute + # as a gradient hook of mlp_output_with_bias[0] + self.pre_mlp_norm_checkpoint.discard_output_and_register_recompute( + mlp_output_with_bias[0] + ) nvtx_range_pop(suffix="mlp") if ( diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 46f3c28b1da..ab9ecb95591 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1365,6 +1365,9 @@ def core_transformer_config_from_args(args, config_class=None): kw_args['inference_sampling_seed'] = args.seed + kw_args['post_self_attn_layernorm'] = args.post_self_attn_layernorm + kw_args['post_mlp_layernorm'] = args.post_mlp_layernorm + # handle quantization config # NOTE: Kitchen arguments are only added to the namespace when # Kitchen library is available. @@ -1678,6 +1681,10 @@ def _add_network_size_args(parser): group.add_argument('--make-vocab-size-divisible-by', type=int, default=128, help='Pad the vocab size to be divisible by this value.' 'This is added for computational efficieny reasons.') + # All slime-patch args that overlap with ArgumentGroupFactory(TransformerConfig) + # are removed. Only add args that are truly new: + group.add_argument('--use-gated-attention', action='store_true', + help='If set, use gated attention as in Qwen3Next') group.add_argument('--openai-gelu', action='store_true', help='Use OpenAIs GeLU implementation. This option' 'should not be used unless for backward compatibility' diff --git a/megatron/training/tokenizer/tokenizer.py b/megatron/training/tokenizer/tokenizer.py index 33340a5e978..0cbfd801892 100644 --- a/megatron/training/tokenizer/tokenizer.py +++ b/megatron/training/tokenizer/tokenizer.py @@ -143,8 +143,8 @@ def __init__(self, pretrained_model_name_or_path, trust_remote_code=False, **kwa # TODO(bnorick): download tokenizer once to lustre and use force offline to make sure all tasks read it from there self._tokenizer = transformers.AutoTokenizer.from_pretrained( pretrained_model_name_or_path=pretrained_model_name_or_path, - trust_remote_code=trust_remote_code, - **kwargs + trust_remote_code=True, + **kwargs, ) self._vocab = self._tokenizer.get_vocab() self._inv_vocab = {token_id: token for token, token_id in self._vocab.items()}