diff --git a/README.md b/README.md index 4633ce4..0aeb2e7 100644 --- a/README.md +++ b/README.md @@ -127,7 +127,7 @@ The following is the list of models supported by MCore-Bridge: | Series | model_type | | -------- | ------------------------------------------------------------ | | Qwen | qwen2, qwen2_moe
qwen3, qwen3_moe, qwen3_next | -| DeepSeek | deepseek_v3, deepseek_v32 | +| DeepSeek | deepseek_v3, deepseek_v32, deepseek_v4 | | GLM | glm4, glm4_moe, glm4_moe_lite
glm_moe_dsa | | MiniMax | minimax_m2 | | Kimi | kimi_k2, kimi_k25 | diff --git a/README_zh.md b/README_zh.md index 9542c31..1254356 100644 --- a/README_zh.md +++ b/README_zh.md @@ -123,7 +123,7 @@ uv pip install -e . --torch-backend=auto | 系列 | model_type | | -------- | ------------------------------------------------------------ | | Qwen | qwen2, qwen2_moe
qwen3, qwen3_moe, qwen3_next | -| DeepSeek | deepseek_v3, deepseek_v32 | +| DeepSeek | deepseek_v3, deepseek_v32, deepseek_v4 | | GLM | glm4, glm4_moe, glm4_moe_lite
glm_moe_dsa | | MiniMax | minimax_m2 | | Kimi | kimi_k2, kimi_k25 | diff --git a/src/mcore_bridge/bridge/gpt_bridge.py b/src/mcore_bridge/bridge/gpt_bridge.py index 99092fd..b7f17b7 100644 --- a/src/mcore_bridge/bridge/gpt_bridge.py +++ b/src/mcore_bridge/bridge/gpt_bridge.py @@ -357,18 +357,18 @@ def _broadcast_ep_pp(self, tensor, is_expert): dist.all_reduce(src_rank, group=pp_group) src_rank = dist.get_global_rank(pp_group, src_rank.item()) meta_data = torch.zeros(10, dtype=torch.int64, device='cuda') - dtype_mapping = {torch.float64: 0, torch.float32: 1, torch.float16: 2, torch.bfloat16: 3, torch.uint8: 4} - dtype_mapping_r = {v: k for k, v in dtype_mapping.items()} + dtype_mapping = [torch.float64, torch.float32, torch.float16, torch.bfloat16, torch.uint8, torch.int32] + dtype_mapping_r = {v: k for k, v in enumerate(dtype_mapping)} if tensor is None: dist.broadcast(meta_data, src=src_rank, group=pp_group) shape = meta_data[1:1 + meta_data[0]].tolist() - dtype = dtype_mapping_r[meta_data[-1].item()] + dtype = dtype_mapping[meta_data[-1].item()] tensor = torch.empty(shape, device='cuda', dtype=dtype) dist.broadcast(tensor, src=src_rank, group=pp_group) else: meta_data[0] = tensor.ndim meta_data[1:1 + tensor.ndim] = torch.tensor(tensor.shape, dtype=torch.int64, device='cuda') - meta_data[-1] = dtype_mapping[tensor.dtype] + meta_data[-1] = dtype_mapping_r[tensor.dtype] dist.broadcast(meta_data, src=src_rank, group=pp_group) dist.broadcast(tensor, src=src_rank, group=pp_group) return tensor @@ -529,7 +529,7 @@ def _filter_prefix(state_dict, prefix: str): return state_dict return {k: v for k, v in state_dict.items() if k.startswith(prefix)} - def _reduce_tensor_pp_group(self, tensor, to_mcore, dtype=torch.long, op=dist.ReduceOp.MAX): + def _reduce_tensor_pp_group(self, tensor, to_mcore, dtype=torch.bool, op=dist.ReduceOp.MAX): if to_mcore: return tensor tensor = torch.tensor([tensor], dtype=dtype, device='cuda') @@ -678,14 +678,17 @@ def _set_qk_layernorm(self, mg_attn, hf_state_dict, to_mcore, **kwargs): self._set_state_dict(mg_attn, 'q_layernorm.weight', hf_state_dict, self.hf_q_norm_key, to_mcore) self._set_state_dict(mg_attn, 'k_layernorm.weight', hf_state_dict, self.hf_k_norm_key, to_mcore) - def _set_router(self, mg_mlp, hf_state_dict, to_mcore): + def _set_router(self, mg_mlp, hf_state_dict, to_mcore, **kwargs): + moe_router_enable_expert_bias = kwargs.get('moe_router_enable_expert_bias') + if moe_router_enable_expert_bias is None: + moe_router_enable_expert_bias = self.config.moe_router_enable_expert_bias hf_gate_key = self.hf_gate_key if self.llm_model_type == 'gpt_oss': hf_gate_key = 'router.weight' self._set_state_dict(mg_mlp, 'router.weight', hf_state_dict, hf_gate_key, to_mcore) if self.config.add_bias_linear: self._set_state_dict(mg_mlp, 'router.bias', hf_state_dict, hf_gate_key.replace('weight', 'bias'), to_mcore) - if self.config.moe_router_enable_expert_bias: + if moe_router_enable_expert_bias: self._set_state_dict(mg_mlp, 'router.expert_bias', hf_state_dict, self.hf_expert_bias_key, to_mcore) def _set_moe_state( @@ -746,7 +749,7 @@ def _get_hf_experts_attr(self, is_mtp: bool = False): return True, True if self.model_type in {'glm4v_moe', 'kimi_vl', 'qwen3_omni_moe', 'qwen3_5_moe'} or self.llm_model_type in { 'qwen2_moe', 'qwen3_moe', 'deepseek_v2', 'deepseek_v3', 'kimi_k2', 'dots1', 'ernie4_5_moe', 'glm4_moe', - 'glm4_moe_lite', 'minimax_m2', 'olmoe', 'qwen3_next', 'glm_moe_dsa', 'deepseek_v32' + 'glm4_moe_lite', 'minimax_m2', 'olmoe', 'qwen3_next', 'glm_moe_dsa', 'deepseek_v32', 'deepseek_v4' }: return False, False elif self.model_type in {'qwen3_vl_moe', 'llama4', 'gemma4'} or self.llm_model_type in {'gpt_oss'}: @@ -1619,6 +1622,28 @@ def _set_layer_mlp(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: bool self.hf_post_attention_layernorm_key, to_mcore) return hf_state_dict + def _set_hyper_connection(self, mg_layer, hf_state_dict, layer_idx, to_mcore): + + for key, hf_key in zip(['self_attention_hyper_connection', 'mlp_hyper_connection'], ['attn', 'ffn']): + hyper_connection = None if mg_layer is None else getattr(mg_layer, key) + self._set_state_dict(hyper_connection, 'mapping_proj.weight', hf_state_dict, f'hc_{hf_key}_fn', to_mcore) + self._set_state_dict(hyper_connection, 'bias', hf_state_dict, f'hc_{hf_key}_base', to_mcore) + has_hyper_connection = hyper_connection is not None + has_hyper_connection = self._reduce_tensor_pp_group(has_hyper_connection, to_mcore) + if has_hyper_connection: + if to_mcore: + alpha = hf_state_dict[f'hc_{hf_key}_scale'].load() + for i, alpha_suffix in enumerate(['pre', 'post', 'res']): + getattr(hyper_connection, f'alpha_{alpha_suffix}').data[:] = alpha[i] + else: + alpha = None + if hyper_connection is not None: + alpha = [] + for i, alpha_suffix in enumerate(['pre', 'post', 'res']): + alpha.append(getattr(hyper_connection, f'alpha_{alpha_suffix}', None)) + alpha = torch.concat(alpha, dim=0) + hf_state_dict[f'hc_{hf_key}_scale'] = self._get_weight(alpha, 'alpha')[0] + def _set_layer_state(self, mg_layer, hf_state_dict, hf_prefix: str, layer_idx: int, to_mcore: bool): hf_prefix = f'{hf_prefix}{layer_idx}.' if to_mcore: @@ -1627,6 +1652,9 @@ def _set_layer_state(self, mg_layer, hf_state_dict, hf_prefix: str, layer_idx: i hf_state_dict = {} hf_state_dict.update(self._set_layer_attn(mg_layer, hf_state_dict, layer_idx, to_mcore)) hf_state_dict.update(self._set_layer_mlp(mg_layer, hf_state_dict, layer_idx, to_mcore)) + if self.config.enable_hyper_connections: + self._set_hyper_connection(mg_layer, hf_state_dict, layer_idx, to_mcore) + if to_mcore: hf_state_dict = {} else: @@ -1678,14 +1706,18 @@ def _convert_post_process(self, mg_model, hf_state_dict, hf_prefix: str, to_mcor self._set_state_dict(lm_model, 'output_layer.weight', hf_state_dict, hf_lm_head_key, to_mcore) elif to_mcore and lm_model.output_layer.weight is not None: self._set_state_dict(lm_model, 'output_layer.weight', hf_state_dict, self.hf_embed_key, to_mcore) - self._set_state_dict(lm_model, 'decoder.final_layernorm.weight', hf_state_dict, self.hf_final_layernorm_key, - to_mcore) + self._set_final_layernorm(lm_model, hf_state_dict, to_mcore) + if to_mcore: hf_state_dict = {} else: hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix) return hf_state_dict + def _set_final_layernorm(self, lm_model, hf_state_dict, to_mcore): + self._set_state_dict(lm_model, 'decoder.final_layernorm.weight', hf_state_dict, self.hf_final_layernorm_key, + to_mcore) + def _convert_hf_state_dict(self, hf_state_dict, to_mcore): res = {} for k, v in hf_state_dict.items(): @@ -1713,6 +1745,7 @@ def _convert(self, mg_models, hf_state_dict, hf_prefix: str, to_mcore: bool, tqd if to_mcore: yield else: + hf_state_dict = self._convert_hf_state_dict(hf_state_dict, to_mcore) yield from list(self._add_prefix(hf_state_dict, hf_prefix).items()) hf_state_dict = {} layer_idx = 0 @@ -1746,6 +1779,7 @@ def _convert(self, mg_models, hf_state_dict, hf_prefix: str, to_mcore: bool, tqd if to_mcore: yield else: + res = self._convert_hf_state_dict(res, to_mcore) yield from list(self._add_prefix(res, hf_prefix).items()) hf_state_dict = {} @@ -1761,6 +1795,7 @@ def _convert(self, mg_models, hf_state_dict, hf_prefix: str, to_mcore: bool, tqd if to_mcore: yield else: + res = self._convert_hf_state_dict(res, to_mcore) yield from list(self._add_prefix(res, hf_prefix).items()) hf_state_dict = {} if not to_mcore or is_pp_last_stage: diff --git a/src/mcore_bridge/config/model_config.py b/src/mcore_bridge/config/model_config.py index 09703e0..46faa08 100644 --- a/src/mcore_bridge/config/model_config.py +++ b/src/mcore_bridge/config/model_config.py @@ -197,7 +197,7 @@ class ModelConfig(TransformerConfig): linear_decoupled_in_proj: bool = False # dsa - experimental_attention_variant: Optional[Literal['gated_delta_net', 'dsa']] = None + experimental_attention_variant: Optional[Literal['gated_delta_net', 'dsa', 'dsv4_hybrid']] = None dsa_indexer_n_heads: Optional[int] = None dsa_indexer_head_dim: Optional[int] = None dsa_indexer_topk: Optional[int] = None @@ -205,6 +205,18 @@ class ModelConfig(TransformerConfig): dsa_indexer_use_sparse_loss: bool = False dsa_indexer_rotary_interleaved: bool = False + # deepseek-v4 + csa_window_size: int = 128 + csa_compress_ratios: Optional[List[int]] = None + csa_compress_rotary_base: float = 40000.0 + o_groups: int = 8 + o_lora_rank: int = 1024 + enable_hyper_connections: bool = False + num_residual_streams: int = 4 + mhc_sinkhorn_iterations: int = 20 + mhc_init_gating_factor: float = 0.01 + moe_n_hash_layers: int = 0 + # mtp mtp_decoder_input_detach: bool = False mtp_shared_weights: bool = False @@ -290,6 +302,7 @@ def __post_init__(self): if self.add_bias_linear: self.add_qkv_bias = True + self.actual_vocab_size = self.padded_vocab_size self.batch_p2p_comm = not self.overlap_p2p_comm if self.swiglu: self.activation_func = F.silu @@ -315,6 +328,8 @@ def __post_init__(self): self.mtp_num_layers = 1 else: self.mtp_unroll_steps = self.mtp_num_layers + if self.csa_compress_ratios is not None and self.mtp_num_layers is not None: + self.csa_compress_ratios += [0] * self.mtp_num_layers super().__post_init__() self._check_npu() diff --git a/src/mcore_bridge/config/parser.py b/src/mcore_bridge/config/parser.py index cabea5a..c09734f 100644 --- a/src/mcore_bridge/config/parser.py +++ b/src/mcore_bridge/config/parser.py @@ -56,6 +56,15 @@ 'dsa_indexer_head_dim': ['index_head_dim'], 'dsa_indexer_topk': ['index_topk'], 'dsa_indexer_rotary_interleaved': ['indexer_rope_interleave'], + # deepseek_v4 + 'csa_compress_ratios': ['compress_rates'], + 'csa_compress_rotary_base': ['compress_rope_theta'], + 'o_groups': ['o_groups'], + 'o_lora_rank': ['o_lora_rank'], + 'num_residual_streams': ['hc_mult'], + 'mhc_sinkhorn_iterations': ['hc_sinkhorn_iters'], + 'moe_n_hash_layers': ['mlp_layer_types'], + 'activation_func_clamp_value': ['swiglu_limit'], # other 'original_max_position_embeddings': ['original_max_position_embeddings'], 'partial_rotary_factor': ['partial_rotary_factor'], @@ -88,7 +97,7 @@ def _convert_config(config, _internal_call=False) -> Dict[str, Any]: else: continue else: - if k == 'kv_lora_rank': + if k in {'q_lora_rank', 'kv_lora_rank'}: megatron_config['multi_latent_attention'] = True elif k == 'hf_model_type': if _internal_call: @@ -134,7 +143,7 @@ def hf_to_mcore_config(hf_config: PretrainedConfig) -> Dict[str, Any]: res.pop('ffn_hidden_size', None) if llm_model_type in {'qwen2_moe', 'qwen3_next'} or hf_model_type == 'qwen3_5_moe': res['moe_shared_expert_gate'] = True - if llm_model_type in {'deepseek', 'deepseek_v2', 'deepseek_v3', 'kimi_k2', 'deepseek_v32', 'dots1' + if llm_model_type in {'deepseek', 'deepseek_v2', 'deepseek_v3', 'kimi_k2', 'deepseek_v32', 'dots1', 'deepseek_v4' } or hf_model_type == 'kimi_vl': if llm_model_type != 'deepseek': res['qk_layernorm'] = True @@ -143,6 +152,17 @@ def hf_to_mcore_config(hf_config: PretrainedConfig) -> Dict[str, Any]: res['moe_router_score_function'] = 'sigmoid' elif llm_model_type == 'deepseek_v32': res['experimental_attention_variant'] = 'dsa' + elif llm_model_type == 'deepseek_v4': + if 'v_head_dim' not in res: + res['v_head_dim'] = res['kv_channels'] + res['experimental_attention_variant'] = 'dsv4_hybrid' + res['moe_router_enable_expert_bias'] = True + res['csa_window_size'] = window_size + res['enable_hyper_connections'] = True + csa_compress_ratios = res.pop('csa_compress_ratios', None) + res['csa_compress_ratios'] = [csa_compress_ratios.get(layer_type, 0) for layer_type in layer_types] + moe_n_hash_layers = res.pop('moe_n_hash_layers', None) + res['moe_n_hash_layers'] = len([layer for layer in moe_n_hash_layers if layer == 'hash_moe']) elif llm_model_type == 'hunyuan': # Since HunYuan’s attention applies RoPE before using q/k_layernorm, # which is incompatible with megatron-core, support is not provided here. diff --git a/src/mcore_bridge/model/constant.py b/src/mcore_bridge/model/constant.py index 9b8dc1b..9708f6a 100644 --- a/src/mcore_bridge/model/constant.py +++ b/src/mcore_bridge/model/constant.py @@ -9,6 +9,7 @@ class LLMModelType: minimax_m2 = 'minimax_m2' hy_v3 = 'hy_v3' bailing_moe = 'bailing_moe' + deepseek_v4 = 'deepseek_v4' qwen3_emb = 'qwen3_emb' diff --git a/src/mcore_bridge/model/gpt_model.py b/src/mcore_bridge/model/gpt_model.py index 35d89f8..617c622 100644 --- a/src/mcore_bridge/model/gpt_model.py +++ b/src/mcore_bridge/model/gpt_model.py @@ -66,19 +66,14 @@ def __init__( ): vocab_size = math.ceil( config.padded_vocab_size / config.tensor_model_parallel_size) * config.tensor_model_parallel_size - hf_rope_scaling = config.rope_scaling + self.hf_rope_scaling = config.rope_scaling if config.multi_latent_attention: config.rope_type = 'rope' # use transformers implementation # Set default value, the following content will not be used. (dummy) config.mscale_all_dim = 0. config.cache_mla_latents = False config.rotary_scaling_factor = 40 - if hf_rope_scaling and hf_rope_scaling['rope_type'] == 'yarn': - # softmax_scale - config.mscale = hf_rope_scaling['mscale'] - config.mscale_all_dim = hf_rope_scaling['mscale_all_dim'] - config.rotary_scaling_factor = hf_rope_scaling['factor'] - self.hf_rope_scaling = hf_rope_scaling + self._init_mla_softmax_scale(config) super().__init__( config, transformer_layer_spec, @@ -136,6 +131,13 @@ def _value(self): attention.config = copy.copy(attention.config) attention.config.apply_rope_fusion = False + def _init_mla_softmax_scale(self, config): + if self.hf_rope_scaling and self.hf_rope_scaling['rope_type'] == 'yarn': + # softmax_scale + config.mscale = self.hf_rope_scaling['mscale'] + config.mscale_all_dim = self.hf_rope_scaling['mscale_all_dim'] + config.rotary_scaling_factor = self.hf_rope_scaling['factor'] + def _preprocess( self, input_ids: torch.Tensor, @@ -297,6 +299,10 @@ def forward( assert padding_mask.shape[1] % tp_size == 0, f'padding_mask.shape: {padding_mask.shape}' padding_mask = torch.chunk(padding_mask, tp_size, dim=1)[mpu.get_tensor_model_parallel_rank()] extra_block_kwargs['padding_mask'] = padding_mask.contiguous() + + if self.config.moe_n_hash_layers > 0: + extra_block_kwargs['input_ids'] = input_ids + # Run decoder. hidden_states = self.decoder( hidden_states=decoder_input, diff --git a/src/mcore_bridge/model/gpts/__init__.py b/src/mcore_bridge/model/gpts/__init__.py index 52b007f..6eb44db 100644 --- a/src/mcore_bridge/model/gpts/__init__.py +++ b/src/mcore_bridge/model/gpts/__init__.py @@ -1,2 +1,2 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -from . import bailing_moe, glm4, hunyuan, llm, minimax_m2, olmoe, qwen3_emb, qwen3_next +from . import bailing_moe, deepseek_v4, glm4, hunyuan, llm, minimax_m2, olmoe, qwen3_emb, qwen3_next diff --git a/src/mcore_bridge/model/gpts/deepseek_v4.py b/src/mcore_bridge/model/gpts/deepseek_v4.py new file mode 100644 index 0000000..f0f4b85 --- /dev/null +++ b/src/mcore_bridge/model/gpts/deepseek_v4.py @@ -0,0 +1,450 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +import copy +import torch +from megatron.core import tensor_parallel +from megatron.core.models.common.embeddings import apply_rotary_pos_emb +from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding +from megatron.core.transformer.transformer_block import TransformerBlock as McoreTransformerBlock +from typing import Optional + +from mcore_bridge.bridge import GPTBridge + +from ..constant import ModelType +from ..gpt_model import GPTModel +from ..register import ModelLoader, ModelMeta, register_model +from ..rope import get_rope_inv_freq + +try: + from megatron.core.pipeline_parallel.fine_grained_activation_offload import \ + FineGrainedActivationOffloadingInterface as off_interface + from megatron.core.transformer.experimental_attention_variant.deepseek_v4_hybrid_attention import \ + DSv4HybridSelfAttention as McoreDSv4HybridSelfAttention + from megatron.core.transformer.experimental_attention_variant.deepseek_v4_hybrid_attention import _q_rms_norm + from megatron.core.typed_torch import apply_module +except ImportError: + McoreDSv4HybridSelfAttention = object + _q_rms_norm = None + apply_module = None + off_interface = None + + +class DSv4HybridSelfAttention(McoreDSv4HybridSelfAttention): + + def __init__(self, config, *args, **kwargs): + assert McoreDSv4HybridSelfAttention is not object, ( + 'Please install the Megatron-Core dev branch: ' + '`pip install git+https://github.com/NVIDIA/Megatron-LM@dev`') + super().__init__(config, *args, **kwargs) + self.layer_type = self.config.hf_config.layer_types[self.layer_number - 1] + self.rope_layer_type = 'main' if self.layer_type == 'sliding_attention' else 'compress' + + def get_query_key_value_tensors( + self, + hidden_states, + key_value_states=None, + position_ids=None, + packed_seq_params=None, + inference_context=None, + rotary_pos_emb=None, + *, + inference_params=None, + ): + """ + Derives `query`, `key` and `value` tensors from `hidden_states`. + """ + # s = sequence length, b = batch size, h = hidden size, n = num attention heads + # Attention heads [s, b, n*h] + assert (hidden_states.ndim == 3), f"hidden_states should be 3D, [s, b, n*h], got {hidden_states.ndim}D" + if packed_seq_params is not None: + assert (packed_seq_params.local_cp_size + is None), 'dynamic_context_parallel is not supported with MLA yet and is planned for future. \ + Please disable dynamic_context_parallel.' + + assert (inference_context is None + and inference_params is None), 'Inference is not supported for DSv4HybridSelfAttention.' + + if packed_seq_params is not None and packed_seq_params.qkv_format == 'thd': + cu_seqlens_q = packed_seq_params.cu_seqlens_q + cu_seqlens_kv = packed_seq_params.cu_seqlens_kv + else: + cu_seqlens_q = cu_seqlens_kv = None + + # ========================================= + # QKV down projection and layernorm + # ========================================= + # q_compressed: [s, b, q_lora_rank] + q_compressed, _ = self.linear_q_down_proj(hidden_states) + + kv_compressed = hidden_states + + if packed_seq_params is not None: + # If sequence packing, TE expect [t, h, d] shaped qkv input. + # In Megatron-Core, the qkv shape is [t, 1, h, d]. + # So we need to reshape qkv from [t, 1, h, d] to [t, h, d]. + q_compressed = q_compressed.squeeze(1) + + # ========================================= + # Apply norm + # ========================================= + + if self.config.q_lora_rank is not None: + # q_compressed: [num_tokens, q_lora_rank] + q_compressed = apply_module(self.q_layernorm)(q_compressed) + + # ========================================= + # QKV up projection and RoPE apply + # ========================================= + + def qkv_up_proj_and_rope_apply(q_compressed, kv_compressed, rotary_pos_emb): + """ + Apply the up projection and RoPE to the query and key. + When sequence packing enabled, the input tensors adopt a packed shape of [t, ...]; + otherwise, they maintain the unpacked shape [s, b, ...]. In subsequent code comments, + we uniformly use [num_tokens, ...] to denote [s, b, ...] or [t, ...] for two cases. + """ + # q_compressed: [num_tokens, q_lora_rank] + # q: [num_tokens, n * (qk_head_dim + qk_pos_emb_head_dim)] + q, _ = self.linear_q_up_proj(q_compressed) + + # q: [num_tokens, n, q_head_dim] + q = q.view(*q.size()[:-1], self.num_attention_heads_per_partition, self.q_head_dim) + q = _q_rms_norm(q, self.config.layernorm_epsilon) + + kv, _ = self.linear_kv_proj(kv_compressed) + kv = self.kv_layernorm(kv) + + # [num_tokens, qk_pos_emb_head_dim] -> [num_tokens, 1, qk_pos_emb_head_dim] + q_len = q.size()[0] + if packed_seq_params is None or self.config.context_parallel_size == 1: + # Shorten rotary_pos_emb to the sequence length when inference_params + # is not provided. This makes sure we can run forward directly with + # any sequence length. During training, the sequence length is always + # the full rotary_pos_emb length, except for sequence packing + CP. + # When sequence packing and context parallel are both enabled, the + # position embedding will not split rotary_pos_emb, so it may exceed + # the sequence length on this CP rank, but we need the full rotary_pos_emb + # to cover the full sequence, so we do not shorten it here. + rotary_pos_emb = rotary_pos_emb[0:q_len] + + # q_no_pe: [num_tokens, n, qk_head_dim] + # q_pos_emb: [num_tokens, n, qk_pos_emb_head_dim] + pos_dim = self.config.qk_pos_emb_head_dim + q_no_pe, q_pos_emb = torch.split(q, [q.shape[-1] - pos_dim, pos_dim], dim=-1) + + # RoPE and query (shared for wkv and latent) + # q_pos_emb: [num_tokens, n, qk_pos_emb_head_dim] + q_pos_emb = apply_rotary_pos_emb( + q_pos_emb, + rotary_pos_emb, + config=self.config, + cu_seqlens=cu_seqlens_q, + cp_group=self.pg_collection.cp, + mla_rotary_interleaved=True, + mla_output_remove_interleaving=True, + ) + # query: [num_tokens, n, (qk_head_dim + v_head_dim)] + query = torch.cat([q_no_pe, q_pos_emb], dim=-1) + + kv_no_pe, k_pos_emb = torch.split(kv, [kv.size(-1) - pos_dim, pos_dim], dim=-1) + + # k_pos_emb:[num_tokens, 1, qk_pos_emb_head_dim] + k_pos_emb = apply_rotary_pos_emb( + k_pos_emb, + rotary_pos_emb, + config=self.config, + cu_seqlens=cu_seqlens_kv, + cp_group=self.pg_collection.cp, + mla_rotary_interleaved=True, + mla_output_remove_interleaving=True, + ) + + # Single head: key = value = [num_tokens, 1, v_head_dim] + kv = torch.cat([kv_no_pe, k_pos_emb], dim=-1).unsqueeze(-2) + key = kv + value = kv + + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + + return query, key, value + + if self.recompute_up_proj: + quantization = self.config.fp8 or self.config.fp4 + self.qkv_up_checkpoint = tensor_parallel.CheckpointWithoutOutput(fp8=quantization) + query, key, value = self.qkv_up_checkpoint.checkpoint(qkv_up_proj_and_rope_apply, q_compressed, + kv_compressed, rotary_pos_emb) + else: + query, key, value = qkv_up_proj_and_rope_apply(q_compressed, kv_compressed, rotary_pos_emb) + + return query, key, value, q_compressed, kv_compressed + + def forward( + self, + hidden_states, + attention_mask, + key_value_states=None, + inference_context=None, + rotary_pos_emb=None, + rotary_pos_cos=None, + rotary_pos_sin=None, + rotary_pos_cos_sin=None, + attention_bias=None, + packed_seq_params=None, + position_ids=None, + sequence_len_offset=None, + *, + inference_params=None, + ): + """Forward pass for DeepSeek-v4 Hybrid Attention""" + rotary_pos_emb = rotary_pos_emb[self.rope_layer_type] + assert (attention_bias is None), 'Attention bias should not be passed into DSv4HybridAttention.' + assert (rotary_pos_cos is None + and rotary_pos_sin is None), 'DSv4HybridAttention does not support Flash Decoding' + assert (not rotary_pos_cos_sin), 'Flash-infer rope has not been tested with DSv4HybridAttention.' + assert (inference_context is None + and inference_params is None), 'Inference is not supported for DSv4HybridAttention.' + + # ===================== + # Query, Key, and Value + # ===================== + # Get the query, key and value tensors based on the type of attention - + # self or cross attn. + query, key, value, q_compressed, kv_compressed = self.get_query_key_value_tensors( + hidden_states, + key_value_states, + position_ids, + packed_seq_params, + rotary_pos_emb=rotary_pos_emb, + inference_context=inference_context, + ) + + # TODO: Currently, TE can only accept contiguous tensors for MLA + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + + # ================================== + # core attention computation + # ================================== + # Need corresponding TE change + core_attn_manager = off_interface(self.offload_core_attention and self.training, query, 'core_attn') + with core_attn_manager as query: + core_attn_out = self.core_attention( + query, + key, + value, + attention_mask, + packed_seq_params=packed_seq_params, + x=hidden_states, + qr=q_compressed, + ) + core_attn_out = core_attn_manager.group_offload(core_attn_out, forced_released_tensors=[query, key, value]) + + if packed_seq_params is not None and packed_seq_params.qkv_format == 'thd': + # reshape to same output shape as unpacked case + # (t, np, hn) -> (t, b=1, h=np*hn) + # t is the pack size = sum (sq_i) + # note that batch is a dummy dimension in the packed case + core_attn_out = core_attn_out.reshape(core_attn_out.size(0), 1, -1) + + if self.recompute_up_proj: + assert self.qkv_up_checkpoint is not None + self.qkv_up_checkpoint.discard_output_and_register_recompute(core_attn_out) + self.qkv_up_checkpoint = None + + # inverse RoPE on last qk_pos_emb_head_dim of each head + seq_len = core_attn_out.size(0) + n_heads = self.num_attention_heads_per_partition + pos_dim = self.config.qk_pos_emb_head_dim + core_attn_out = core_attn_out.view(seq_len, core_attn_out.size(1), n_heads, -1) + packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd' + if packed_seq: + cu_seqlens_kv = ( + packed_seq_params.cu_seqlens_kv_padded + if packed_seq_params.cu_seqlens_kv_padded is not None else packed_seq_params.cu_seqlens_kv) + else: + cu_seqlens_kv = None + + content_part, rot_part = torch.split(core_attn_out, [core_attn_out.size(-1) - pos_dim, pos_dim], dim=-1) + rot_part = apply_rotary_pos_emb( + rot_part, + rotary_pos_emb, + self.config, + cu_seqlens=cu_seqlens_kv, + cp_group=self.pg_collection.cp, + mla_rotary_interleaved=True, + inverse=True, + mla_output_remove_interleaving=True, + ) + core_attn_out = torch.cat([content_part, rot_part], dim=-1) + core_attn_out = core_attn_out.view(seq_len, core_attn_out.size(1), -1) + + # Grouped output + core_attn_out = core_attn_out.view(core_attn_out.size(0), core_attn_out.size(1), self.o_local_groups, -1) + wo_a_weight = self.linear_o_group_proj.view(self.o_local_groups, self.config.o_lora_rank, -1) + core_attn_out = torch.einsum('...gd,grd->...gr', core_attn_out, wo_a_weight) + core_attn_out = core_attn_out.reshape(*core_attn_out.shape[:-2], -1) + + # ================= + # Output. [sq, b, h] + # ================= + attn_proj_manager = off_interface(self.offload_attn_proj, core_attn_out, 'attn_proj') + with attn_proj_manager as core_attn_out: + output, bias = self.linear_proj(core_attn_out) + output = attn_proj_manager.group_offload(output, forced_released_tensors=[core_attn_out]) + + return output, bias + + +class DeepseekV4GPTModel(GPTModel): + + def _init_mla_softmax_scale(self, config): + pass + + def _get_rotary_pos_emb(self, decoder_input, position_ids, packed_seq_params, inference_context=None): + rotary_seq_len = RotaryEmbedding.get_rotary_seq_len(self, inference_context, self.decoder, decoder_input, + self.config, packed_seq_params) + packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd' + rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len, packed_seq=packed_seq) + compress_rotary_pos_emb = self.compress_rotary_pos_emb(rotary_seq_len, packed_seq=packed_seq) + rotary_pos_emb = {'main': rotary_pos_emb, 'compress': compress_rotary_pos_emb} + return rotary_pos_emb, None, None + + def _set_inv_freq(self): + rope_scaling = self.config.rope_scaling + self.config.rope_scaling = rope_scaling['main'] + new_inv_freq, attention_scaling = get_rope_inv_freq(self.config) + self.rotary_pos_emb.inv_freq = new_inv_freq.to(self.rotary_pos_emb.inv_freq.device) + self.config.attention_scaling = attention_scaling + # compress + self.compress_rotary_pos_emb = copy.copy(self.rotary_pos_emb) + self.config.rope_scaling = rope_scaling['compress'] + new_inv_freq, attention_scaling = get_rope_inv_freq(self.config) + self.compress_rotary_pos_emb.inv_freq = new_inv_freq + self.config.compress_attention_scaling = attention_scaling + + self.config.rope_scaling = rope_scaling + + +class DeepseekV4Loader(ModelLoader): + model_cls = DeepseekV4GPTModel + transformer_block = McoreTransformerBlock + + def get_transformer_layer_spec(self, vp_stage: Optional[int] = None): + from megatron.core.models.gpt.experimental_attention_variant_module_specs import \ + get_transformer_block_with_experimental_attention_variant_spec + transformer_layer_spec = get_transformer_block_with_experimental_attention_variant_spec(self.config, vp_stage) + for layer_spec in transformer_layer_spec.layer_specs: + layer_spec.submodules.self_attention.module = DSv4HybridSelfAttention + return transformer_layer_spec + + +class DeepseekV4Bridge(GPTBridge): + hf_mtp_prefix = 'model.mtp' + hf_embed_key = 'model.embed.weight' + hf_attn_prefix = 'attn' + hf_mlp_prefix = 'ffn' + hf_lm_head_key = 'model.head.weight' + hf_score_key = 'model.score.weight' + hf_input_layernorm_key = 'attn_norm.weight' + hf_post_attention_layernorm_key = 'ffn_norm.weight' + hf_expert_bias_key = 'gate.bias' + + def _convert_hf_state_dict(self, hf_state_dict, to_mcore): + res = super()._convert_hf_state_dict(hf_state_dict, to_mcore) + if to_mcore: + res = self._add_prefix(res, 'model.') + elif not to_mcore: + res = self._remove_prefix(res, 'model.') + return res + + def _set_moe_state( + self, + mg_mlp, + hf_state_dict, + hf_prefix: str, + layer_idx: int, + to_mcore: bool, + is_mtp: bool = False, + ): + if to_mcore: + hf_state_dict = { + k.replace('.w1.', '.gate_proj.').replace('.w3.', '.up_proj.').replace('.w2.', '.down_proj.'): v + for k, v in hf_state_dict.items() + } + hf_state_dict = super()._set_moe_state(mg_mlp, hf_state_dict, hf_prefix, layer_idx, to_mcore, is_mtp) + if not to_mcore: + hf_state_dict = { + k.replace('.gate_proj.', '.w1.').replace('.up_proj.', '.w3.').replace('.down_proj.', '.w2.'): v + for k, v in hf_state_dict.items() + } + return hf_state_dict + + def _set_mla_attn_state( + self, + mg_attn, + hf_state_dict, + hf_prefix: str, + layer_idx: int, + to_mcore: bool, + ): + if to_mcore: + hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) + else: + hf_state_dict = {} + self._set_state_dict(mg_attn, 'linear_proj.weight', hf_state_dict, 'wo_b.weight', to_mcore) + self._set_state_dict(mg_attn, 'linear_o_group_proj', hf_state_dict, 'wo_a.weight', to_mcore) + self._set_state_dict(mg_attn, 'linear_q_down_proj.weight', hf_state_dict, 'wq_a.weight', to_mcore) + self._set_state_dict(mg_attn, 'linear_q_up_proj.weight', hf_state_dict, 'wq_b.weight', to_mcore) + self._set_state_dict(mg_attn, 'linear_kv_proj.weight', hf_state_dict, 'wkv.weight', to_mcore) + self._set_state_dict(mg_attn, 'core_attention.attn_sink', hf_state_dict, 'attn_sink', to_mcore) + if self.config.qk_layernorm: + self._set_state_dict(mg_attn, 'q_layernorm.weight', hf_state_dict, 'q_norm.weight', to_mcore) + self._set_state_dict(mg_attn, 'kv_layernorm.weight', hf_state_dict, 'kv_norm.weight', to_mcore) + has_compressor = False if mg_attn is None else mg_attn.core_attention.compressor is not None + has_indexer = False if mg_attn is None else mg_attn.core_attention.indexer is not None + has_compressor = self._reduce_tensor_pp_group(has_compressor, to_mcore) + has_indexer = self._reduce_tensor_pp_group(has_indexer, to_mcore) + if has_compressor: + for mg_key, hf_key in zip(['ape', 'linear_wkv.weight', 'linear_wgate.weight', 'norm.weight'], + ['ape', 'wkv.weight', 'wgate.weight', 'norm.weight']): + self._set_state_dict(mg_attn, f'core_attention.compressor.{mg_key}', hf_state_dict, + f'compressor.{hf_key}', to_mcore) + if has_indexer: + for mg_key, hf_key in zip(['linear_wq_b.weight', 'linear_weights_proj.weight'], + ['wq_b.weight', 'weights_proj.weight']): + self._set_state_dict(mg_attn, f'core_attention.indexer.{mg_key}', hf_state_dict, f'indexer.{hf_key}', + to_mcore) + for mg_key, hf_key in zip(['ape', 'linear_wkv.weight', 'linear_wgate.weight', 'norm.weight'], + ['ape', 'wkv.weight', 'wgate.weight', 'norm.weight']): + self._set_state_dict(mg_attn, f'core_attention.indexer.compressor.{mg_key}', hf_state_dict, + f'indexer.compressor.{hf_key}', to_mcore) + + if to_mcore: + hf_state_dict = {} + else: + hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix) + return hf_state_dict + + def _set_final_layernorm(self, lm_model, hf_state_dict, to_mcore): + super()._set_final_layernorm(lm_model, hf_state_dict, to_mcore) + for key in ['hc_head_base', 'hc_head_fn', 'hc_head_scale']: + self._set_state_dict(lm_model, f'decoder.{key}', hf_state_dict, f'model.{key}', to_mcore) + + def _set_router(self, mg_mlp, hf_state_dict, to_mcore, **kwargs): + is_hash_layer = False if mg_mlp is None else mg_mlp.router.is_hash_layer + is_hash_layer = self._reduce_tensor_pp_group(is_hash_layer, to_mcore) + if is_hash_layer: + self._set_state_dict(mg_mlp, 'router.tid2eid', hf_state_dict, 'gate.tid2eid', to_mcore) + kwargs['moe_router_enable_expert_bias'] = False + super()._set_router(mg_mlp, hf_state_dict, to_mcore, **kwargs) + + +register_model( + ModelMeta( + ModelType.deepseek_v4, + ['deepseek_v4'], + bridge_cls=DeepseekV4Bridge, + loader=DeepseekV4Loader, + )) diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py index 23caa9b..504343c 100644 --- a/src/mcore_bridge/model/mm_gpts/gemma4.py +++ b/src/mcore_bridge/model/mm_gpts/gemma4.py @@ -448,7 +448,7 @@ def _set_qkv(self, mg_attn, hf_state_dict, to_mcore: bool, **kwargs): if use_alternative_attention else text_config.num_key_value_heads) return super()._set_qkv(mg_attn, hf_state_dict, to_mcore, **kwargs) - def _set_router(self, mg_mlp, hf_state_dict, to_mcore): + def _set_router(self, mg_mlp, hf_state_dict, to_mcore, **kwargs): self._set_state_dict(mg_mlp, 'router.weight', hf_state_dict, 'router.proj.weight', to_mcore) for key in ['per_expert_scale', 'scale']: self._set_state_dict(mg_mlp, key, hf_state_dict, f'router.{key}', to_mcore) diff --git a/src/mcore_bridge/model/modules/transformer_layer.py b/src/mcore_bridge/model/modules/transformer_layer.py index 75a3a46..342e4f1 100644 --- a/src/mcore_bridge/model/modules/transformer_layer.py +++ b/src/mcore_bridge/model/modules/transformer_layer.py @@ -227,8 +227,11 @@ def _apply_rotary_pos_emb_bshd( t: torch.Tensor, freqs: torch.Tensor, rotary_interleaved: bool = False, + mla_rotary_interleaved: Optional[bool] = None, multi_latent_attention: Optional[bool] = None, mscale: float = 1.0, + inverse: bool = False, + mla_output_remove_interleaving: bool = False, **kwargs, ) -> torch.Tensor: """Apply rotary positional embedding to input tensor T. @@ -249,7 +252,9 @@ def _apply_rotary_pos_emb_bshd( t, t_pass = t[..., :rot_dim], t[..., rot_dim:] if multi_latent_attention is None: multi_latent_attention = self.config.multi_latent_attention - if multi_latent_attention: + if mla_rotary_interleaved is None: + mla_rotary_interleaved = multi_latent_attention + if mla_rotary_interleaved: x1 = t[..., 0::2] x2 = t[..., 1::2] t = torch.cat((x1, x2), dim=-1) @@ -258,8 +263,17 @@ def _apply_rotary_pos_emb_bshd( # second part is sine component, need to change signs with _rotate_half method cos_ = (torch.cos(freqs) * mscale).to(t.dtype) sin_ = (torch.sin(freqs) * mscale).to(t.dtype) + if inverse: + sin_ = -sin_ t = (t * cos_) + (rope_utils._rotate_half(t, rotary_interleaved) * sin_) + # Fallback to original permutation + # DSv4 applies rope on V and O, so we need to uninterleave the tensor. + # The existing MLA code is safe because the dot product is permutation-invariant. + if mla_rotary_interleaved and mla_output_remove_interleaving: + x1, x2 = torch.chunk(t, 2, dim=-1) + t = torch.stack((x1, x2), dim=-1).flatten(start_dim=-2) + return torch.cat((t, t_pass), dim=-1) rope_utils._apply_rotary_pos_emb_bshd = _apply_rotary_pos_emb_bshd diff --git a/src/mcore_bridge/model/register.py b/src/mcore_bridge/model/register.py index c6816a0..3ac468b 100644 --- a/src/mcore_bridge/model/register.py +++ b/src/mcore_bridge/model/register.py @@ -132,6 +132,8 @@ def _set_shared_expert_gate(self, transformer_layer_spec): layer_spec.submodules.mlp.submodules.shared_experts.params = {'gate': True} def _set_transformer_layer(self, transformer_layer_spec): + if self.config.enable_hyper_connections: + return for layer_spec in transformer_layer_spec.layer_specs: if layer_spec.module is McoreTransformerLayer: layer_spec.module = TransformerLayer diff --git a/src/mcore_bridge/model/rope.py b/src/mcore_bridge/model/rope.py index a514f43..0db60a0 100644 --- a/src/mcore_bridge/model/rope.py +++ b/src/mcore_bridge/model/rope.py @@ -25,11 +25,15 @@ def __init__(self, **kwargs): def _get_dummy_config(config): + if config.multi_latent_attention and config.partial_rotary_factor is None: + head_dim = config.qk_pos_emb_head_dim + else: + head_dim = config.kv_channels dummy_config = DummyConfig( rope_scaling=config.rope_scaling, rope_theta=config.rotary_base, max_position_embeddings=config.max_position_embeddings, - head_dim=config.qk_pos_emb_head_dim if config.multi_latent_attention else config.kv_channels, + head_dim=head_dim, hidden_size=config.hidden_size, num_attention_heads=config.num_attention_heads, )