diff --git a/README.md b/README.md
index 4633ce4..0aeb2e7 100644
--- a/README.md
+++ b/README.md
@@ -127,7 +127,7 @@ The following is the list of models supported by MCore-Bridge:
| Series | model_type |
| -------- | ------------------------------------------------------------ |
| Qwen | qwen2, qwen2_moe<br>qwen3, qwen3_moe, qwen3_next |
-| DeepSeek | deepseek_v3, deepseek_v32 |
+| DeepSeek | deepseek_v3, deepseek_v32, deepseek_v4 |
| GLM | glm4, glm4_moe, glm4_moe_lite<br>glm_moe_dsa |
| MiniMax | minimax_m2 |
| Kimi | kimi_k2, kimi_k25 |
diff --git a/README_zh.md b/README_zh.md
index 9542c31..1254356 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -123,7 +123,7 @@ uv pip install -e . --torch-backend=auto
| 系列 | model_type |
| -------- | ------------------------------------------------------------ |
| Qwen | qwen2, qwen2_moe<br>qwen3, qwen3_moe, qwen3_next |
-| DeepSeek | deepseek_v3, deepseek_v32 |
+| DeepSeek | deepseek_v3, deepseek_v32, deepseek_v4 |
| GLM | glm4, glm4_moe, glm4_moe_lite<br>glm_moe_dsa |
| MiniMax | minimax_m2 |
| Kimi | kimi_k2, kimi_k25 |
diff --git a/src/mcore_bridge/bridge/gpt_bridge.py b/src/mcore_bridge/bridge/gpt_bridge.py
index 99092fd..b7f17b7 100644
--- a/src/mcore_bridge/bridge/gpt_bridge.py
+++ b/src/mcore_bridge/bridge/gpt_bridge.py
@@ -357,18 +357,18 @@ def _broadcast_ep_pp(self, tensor, is_expert):
dist.all_reduce(src_rank, group=pp_group)
src_rank = dist.get_global_rank(pp_group, src_rank.item())
meta_data = torch.zeros(10, dtype=torch.int64, device='cuda')
- dtype_mapping = {torch.float64: 0, torch.float32: 1, torch.float16: 2, torch.bfloat16: 3, torch.uint8: 4}
- dtype_mapping_r = {v: k for k, v in dtype_mapping.items()}
+ dtype_mapping = [torch.float64, torch.float32, torch.float16, torch.bfloat16, torch.uint8, torch.int32]
+ dtype_mapping_r = {v: k for k, v in enumerate(dtype_mapping)}
if tensor is None:
dist.broadcast(meta_data, src=src_rank, group=pp_group)
shape = meta_data[1:1 + meta_data[0]].tolist()
- dtype = dtype_mapping_r[meta_data[-1].item()]
+ dtype = dtype_mapping[meta_data[-1].item()]
tensor = torch.empty(shape, device='cuda', dtype=dtype)
dist.broadcast(tensor, src=src_rank, group=pp_group)
else:
meta_data[0] = tensor.ndim
meta_data[1:1 + tensor.ndim] = torch.tensor(tensor.shape, dtype=torch.int64, device='cuda')
- meta_data[-1] = dtype_mapping[tensor.dtype]
+ meta_data[-1] = dtype_mapping_r[tensor.dtype]
dist.broadcast(meta_data, src=src_rank, group=pp_group)
dist.broadcast(tensor, src=src_rank, group=pp_group)
return tensor
@@ -529,7 +529,7 @@ def _filter_prefix(state_dict, prefix: str):
return state_dict
return {k: v for k, v in state_dict.items() if k.startswith(prefix)}
- def _reduce_tensor_pp_group(self, tensor, to_mcore, dtype=torch.long, op=dist.ReduceOp.MAX):
+ def _reduce_tensor_pp_group(self, tensor, to_mcore, dtype=torch.bool, op=dist.ReduceOp.MAX):
if to_mcore:
return tensor
tensor = torch.tensor([tensor], dtype=dtype, device='cuda')
@@ -678,14 +678,17 @@ def _set_qk_layernorm(self, mg_attn, hf_state_dict, to_mcore, **kwargs):
self._set_state_dict(mg_attn, 'q_layernorm.weight', hf_state_dict, self.hf_q_norm_key, to_mcore)
self._set_state_dict(mg_attn, 'k_layernorm.weight', hf_state_dict, self.hf_k_norm_key, to_mcore)
- def _set_router(self, mg_mlp, hf_state_dict, to_mcore):
+ def _set_router(self, mg_mlp, hf_state_dict, to_mcore, **kwargs):
+ moe_router_enable_expert_bias = kwargs.get('moe_router_enable_expert_bias')
+ if moe_router_enable_expert_bias is None:
+ moe_router_enable_expert_bias = self.config.moe_router_enable_expert_bias
hf_gate_key = self.hf_gate_key
if self.llm_model_type == 'gpt_oss':
hf_gate_key = 'router.weight'
self._set_state_dict(mg_mlp, 'router.weight', hf_state_dict, hf_gate_key, to_mcore)
if self.config.add_bias_linear:
self._set_state_dict(mg_mlp, 'router.bias', hf_state_dict, hf_gate_key.replace('weight', 'bias'), to_mcore)
- if self.config.moe_router_enable_expert_bias:
+ if moe_router_enable_expert_bias:
self._set_state_dict(mg_mlp, 'router.expert_bias', hf_state_dict, self.hf_expert_bias_key, to_mcore)
def _set_moe_state(
@@ -746,7 +749,7 @@ def _get_hf_experts_attr(self, is_mtp: bool = False):
return True, True
if self.model_type in {'glm4v_moe', 'kimi_vl', 'qwen3_omni_moe', 'qwen3_5_moe'} or self.llm_model_type in {
'qwen2_moe', 'qwen3_moe', 'deepseek_v2', 'deepseek_v3', 'kimi_k2', 'dots1', 'ernie4_5_moe', 'glm4_moe',
- 'glm4_moe_lite', 'minimax_m2', 'olmoe', 'qwen3_next', 'glm_moe_dsa', 'deepseek_v32'
+ 'glm4_moe_lite', 'minimax_m2', 'olmoe', 'qwen3_next', 'glm_moe_dsa', 'deepseek_v32', 'deepseek_v4'
}:
return False, False
elif self.model_type in {'qwen3_vl_moe', 'llama4', 'gemma4'} or self.llm_model_type in {'gpt_oss'}:
@@ -1619,6 +1622,28 @@ def _set_layer_mlp(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: bool
self.hf_post_attention_layernorm_key, to_mcore)
return hf_state_dict
+    def _set_hyper_connection(self, mg_layer, hf_state_dict, layer_idx, to_mcore):
+        """Convert the hyper-connection weights of one layer between HF and mcore.
+
+        For both the attention and the MLP hyper-connection this maps
+        ``mapping_proj.weight`` <-> ``hc_{attn,ffn}_fn``, ``bias`` <-> ``hc_{attn,ffn}_base`` and the
+        three ``alpha_pre`` / ``alpha_post`` / ``alpha_res`` attributes <-> a single
+        ``hc_{attn,ffn}_scale`` entry.
+
+        Args:
+            mg_layer: The mcore layer; may be ``None``, in which case no local module is touched.
+            hf_state_dict: HF-side weights; read when ``to_mcore`` is true, written otherwise.
+            layer_idx: Not used by this method.
+            to_mcore: ``True`` to load HF weights into mcore, ``False`` to export to HF.
+        """
+        alpha_suffixes = ('pre', 'post', 'res')
+        for key, hf_key in zip(['self_attention_hyper_connection', 'mlp_hyper_connection'], ['attn', 'ffn']):
+            hyper_connection = None if mg_layer is None else getattr(mg_layer, key)
+            self._set_state_dict(hyper_connection, 'mapping_proj.weight', hf_state_dict, f'hc_{hf_key}_fn', to_mcore)
+            self._set_state_dict(hyper_connection, 'bias', hf_state_dict, f'hc_{hf_key}_base', to_mcore)
+            has_hyper_connection = hyper_connection is not None
+            # `_reduce_tensor_pp_group` returns its input unchanged when `to_mcore` is true.
+            has_hyper_connection = self._reduce_tensor_pp_group(has_hyper_connection, to_mcore)
+            if not has_hyper_connection:
+                continue
+            if to_mcore:
+                # Row i of the HF scale entry is written into the i-th alpha attribute.
+                alpha = hf_state_dict[f'hc_{hf_key}_scale'].load()
+                for i, alpha_suffix in enumerate(alpha_suffixes):
+                    getattr(hyper_connection, f'alpha_{alpha_suffix}').data[:] = alpha[i]
+            else:
+                alpha = None
+                if hyper_connection is not None:
+                    # No `None` default on getattr: a missing alpha attribute should fail here with
+                    # an AttributeError instead of reaching `torch.concat` as a `None` element.
+                    alpha = torch.concat(
+                        [getattr(hyper_connection, f'alpha_{alpha_suffix}') for alpha_suffix in alpha_suffixes],
+                        dim=0)
+                # Called even when `alpha` is None (module not local to this rank).
+                hf_state_dict[f'hc_{hf_key}_scale'] = self._get_weight(alpha, 'alpha')[0]
+
def _set_layer_state(self, mg_layer, hf_state_dict, hf_prefix: str, layer_idx: int, to_mcore: bool):
hf_prefix = f'{hf_prefix}{layer_idx}.'
if to_mcore:
@@ -1627,6 +1652,9 @@ def _set_layer_state(self, mg_layer, hf_state_dict, hf_prefix: str, layer_idx: i
hf_state_dict = {}
hf_state_dict.update(self._set_layer_attn(mg_layer, hf_state_dict, layer_idx, to_mcore))
hf_state_dict.update(self._set_layer_mlp(mg_layer, hf_state_dict, layer_idx, to_mcore))
+ if self.config.enable_hyper_connections:
+ self._set_hyper_connection(mg_layer, hf_state_dict, layer_idx, to_mcore)
+
if to_mcore:
hf_state_dict = {}
else:
@@ -1678,14 +1706,18 @@ def _convert_post_process(self, mg_model, hf_state_dict, hf_prefix: str, to_mcor
self._set_state_dict(lm_model, 'output_layer.weight', hf_state_dict, hf_lm_head_key, to_mcore)
elif to_mcore and lm_model.output_layer.weight is not None:
self._set_state_dict(lm_model, 'output_layer.weight', hf_state_dict, self.hf_embed_key, to_mcore)
- self._set_state_dict(lm_model, 'decoder.final_layernorm.weight', hf_state_dict, self.hf_final_layernorm_key,
- to_mcore)
+ self._set_final_layernorm(lm_model, hf_state_dict, to_mcore)
+
if to_mcore:
hf_state_dict = {}
else:
hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix)
return hf_state_dict
+    def _set_final_layernorm(self, lm_model, hf_state_dict, to_mcore):
+        """Convert the decoder's final layernorm weight.
+
+        Maps ``decoder.final_layernorm.weight`` on ``lm_model`` to the HF key
+        ``self.hf_final_layernorm_key`` via ``_set_state_dict``, in the direction given by ``to_mcore``.
+        """
+        self._set_state_dict(lm_model, 'decoder.final_layernorm.weight', hf_state_dict, self.hf_final_layernorm_key,
+                             to_mcore)
+
def _convert_hf_state_dict(self, hf_state_dict, to_mcore):
res = {}
for k, v in hf_state_dict.items():
@@ -1713,6 +1745,7 @@ def _convert(self, mg_models, hf_state_dict, hf_prefix: str, to_mcore: bool, tqd
if to_mcore:
yield
else:
+ hf_state_dict = self._convert_hf_state_dict(hf_state_dict, to_mcore)
yield from list(self._add_prefix(hf_state_dict, hf_prefix).items())
hf_state_dict = {}
layer_idx = 0
@@ -1746,6 +1779,7 @@ def _convert(self, mg_models, hf_state_dict, hf_prefix: str, to_mcore: bool, tqd
if to_mcore:
yield
else:
+ res = self._convert_hf_state_dict(res, to_mcore)
yield from list(self._add_prefix(res, hf_prefix).items())
hf_state_dict = {}
@@ -1761,6 +1795,7 @@ def _convert(self, mg_models, hf_state_dict, hf_prefix: str, to_mcore: bool, tqd
if to_mcore:
yield
else:
+ res = self._convert_hf_state_dict(res, to_mcore)
yield from list(self._add_prefix(res, hf_prefix).items())
hf_state_dict = {}
if not to_mcore or is_pp_last_stage:
diff --git a/src/mcore_bridge/config/model_config.py b/src/mcore_bridge/config/model_config.py
index 09703e0..46faa08 100644
--- a/src/mcore_bridge/config/model_config.py
+++ b/src/mcore_bridge/config/model_config.py
@@ -197,7 +197,7 @@ class ModelConfig(TransformerConfig):
linear_decoupled_in_proj: bool = False
# dsa
- experimental_attention_variant: Optional[Literal['gated_delta_net', 'dsa']] = None
+ experimental_attention_variant: Optional[Literal['gated_delta_net', 'dsa', 'dsv4_hybrid']] = None
dsa_indexer_n_heads: Optional[int] = None
dsa_indexer_head_dim: Optional[int] = None
dsa_indexer_topk: Optional[int] = None
@@ -205,6 +205,18 @@ class ModelConfig(TransformerConfig):
dsa_indexer_use_sparse_loss: bool = False
dsa_indexer_rotary_interleaved: bool = False
+ # deepseek-v4
+ csa_window_size: int = 128
+ csa_compress_ratios: Optional[List[int]] = None
+ csa_compress_rotary_base: float = 40000.0
+ o_groups: int = 8
+ o_lora_rank: int = 1024
+ enable_hyper_connections: bool = False
+ num_residual_streams: int = 4
+ mhc_sinkhorn_iterations: int = 20
+ mhc_init_gating_factor: float = 0.01
+ moe_n_hash_layers: int = 0
+
# mtp
mtp_decoder_input_detach: bool = False
mtp_shared_weights: bool = False
@@ -290,6 +302,7 @@ def __post_init__(self):
if self.add_bias_linear:
self.add_qkv_bias = True
+ self.actual_vocab_size = self.padded_vocab_size
self.batch_p2p_comm = not self.overlap_p2p_comm
if self.swiglu:
self.activation_func = F.silu
@@ -315,6 +328,8 @@ def __post_init__(self):
self.mtp_num_layers = 1
else:
self.mtp_unroll_steps = self.mtp_num_layers
+ if self.csa_compress_ratios is not None and self.mtp_num_layers is not None:
+ self.csa_compress_ratios += [0] * self.mtp_num_layers
super().__post_init__()
self._check_npu()
diff --git a/src/mcore_bridge/config/parser.py b/src/mcore_bridge/config/parser.py
index cabea5a..c09734f 100644
--- a/src/mcore_bridge/config/parser.py
+++ b/src/mcore_bridge/config/parser.py
@@ -56,6 +56,15 @@
'dsa_indexer_head_dim': ['index_head_dim'],
'dsa_indexer_topk': ['index_topk'],
'dsa_indexer_rotary_interleaved': ['indexer_rope_interleave'],
+ # deepseek_v4
+ 'csa_compress_ratios': ['compress_rates'],
+ 'csa_compress_rotary_base': ['compress_rope_theta'],
+ 'o_groups': ['o_groups'],
+ 'o_lora_rank': ['o_lora_rank'],
+ 'num_residual_streams': ['hc_mult'],
+ 'mhc_sinkhorn_iterations': ['hc_sinkhorn_iters'],
+ 'moe_n_hash_layers': ['mlp_layer_types'],
+ 'activation_func_clamp_value': ['swiglu_limit'],
# other
'original_max_position_embeddings': ['original_max_position_embeddings'],
'partial_rotary_factor': ['partial_rotary_factor'],
@@ -88,7 +97,7 @@ def _convert_config(config, _internal_call=False) -> Dict[str, Any]:
else:
continue
else:
- if k == 'kv_lora_rank':
+ if k in {'q_lora_rank', 'kv_lora_rank'}:
megatron_config['multi_latent_attention'] = True
elif k == 'hf_model_type':
if _internal_call:
@@ -134,7 +143,7 @@ def hf_to_mcore_config(hf_config: PretrainedConfig) -> Dict[str, Any]:
res.pop('ffn_hidden_size', None)
if llm_model_type in {'qwen2_moe', 'qwen3_next'} or hf_model_type == 'qwen3_5_moe':
res['moe_shared_expert_gate'] = True
- if llm_model_type in {'deepseek', 'deepseek_v2', 'deepseek_v3', 'kimi_k2', 'deepseek_v32', 'dots1'
+ if llm_model_type in {'deepseek', 'deepseek_v2', 'deepseek_v3', 'kimi_k2', 'deepseek_v32', 'dots1', 'deepseek_v4'
} or hf_model_type == 'kimi_vl':
if llm_model_type != 'deepseek':
res['qk_layernorm'] = True
@@ -143,6 +152,17 @@ def hf_to_mcore_config(hf_config: PretrainedConfig) -> Dict[str, Any]:
res['moe_router_score_function'] = 'sigmoid'
elif llm_model_type == 'deepseek_v32':
res['experimental_attention_variant'] = 'dsa'
+ elif llm_model_type == 'deepseek_v4':
+ if 'v_head_dim' not in res:
+ res['v_head_dim'] = res['kv_channels']
+ res['experimental_attention_variant'] = 'dsv4_hybrid'
+ res['moe_router_enable_expert_bias'] = True
+ res['csa_window_size'] = window_size
+ res['enable_hyper_connections'] = True
+ csa_compress_ratios = res.pop('csa_compress_ratios', None)
+ res['csa_compress_ratios'] = [csa_compress_ratios.get(layer_type, 0) for layer_type in layer_types]
+ moe_n_hash_layers = res.pop('moe_n_hash_layers', None)
+ res['moe_n_hash_layers'] = len([layer for layer in moe_n_hash_layers if layer == 'hash_moe'])
elif llm_model_type == 'hunyuan':
# Since HunYuan’s attention applies RoPE before using q/k_layernorm,
# which is incompatible with megatron-core, support is not provided here.
diff --git a/src/mcore_bridge/model/constant.py b/src/mcore_bridge/model/constant.py
index 9b8dc1b..9708f6a 100644
--- a/src/mcore_bridge/model/constant.py
+++ b/src/mcore_bridge/model/constant.py
@@ -9,6 +9,7 @@ class LLMModelType:
minimax_m2 = 'minimax_m2'
hy_v3 = 'hy_v3'
bailing_moe = 'bailing_moe'
+ deepseek_v4 = 'deepseek_v4'
qwen3_emb = 'qwen3_emb'
diff --git a/src/mcore_bridge/model/gpt_model.py b/src/mcore_bridge/model/gpt_model.py
index 35d89f8..617c622 100644
--- a/src/mcore_bridge/model/gpt_model.py
+++ b/src/mcore_bridge/model/gpt_model.py
@@ -66,19 +66,14 @@ def __init__(
):
vocab_size = math.ceil(
config.padded_vocab_size / config.tensor_model_parallel_size) * config.tensor_model_parallel_size
- hf_rope_scaling = config.rope_scaling
+ self.hf_rope_scaling = config.rope_scaling
if config.multi_latent_attention:
config.rope_type = 'rope' # use transformers implementation
# Set default value, the following content will not be used. (dummy)
config.mscale_all_dim = 0.
config.cache_mla_latents = False
config.rotary_scaling_factor = 40
- if hf_rope_scaling and hf_rope_scaling['rope_type'] == 'yarn':
- # softmax_scale
- config.mscale = hf_rope_scaling['mscale']
- config.mscale_all_dim = hf_rope_scaling['mscale_all_dim']
- config.rotary_scaling_factor = hf_rope_scaling['factor']
- self.hf_rope_scaling = hf_rope_scaling
+ self._init_mla_softmax_scale(config)
super().__init__(
config,
transformer_layer_spec,
@@ -136,6 +131,13 @@ def _value(self):
attention.config = copy.copy(attention.config)
attention.config.apply_rope_fusion = False
+    def _init_mla_softmax_scale(self, config):
+        """Copy the HF YaRN rope-scaling values used for the MLA softmax scale onto ``config``.
+
+        Does nothing unless ``self.hf_rope_scaling`` is truthy and its ``'rope_type'`` is ``'yarn'``.
+        """
+        rope_scaling = self.hf_rope_scaling
+        if not rope_scaling or rope_scaling['rope_type'] != 'yarn':
+            return
+        # softmax_scale
+        config.mscale = rope_scaling['mscale']
+        config.mscale_all_dim = rope_scaling['mscale_all_dim']
+        config.rotary_scaling_factor = rope_scaling['factor']
+
def _preprocess(
self,
input_ids: torch.Tensor,
@@ -297,6 +299,10 @@ def forward(
assert padding_mask.shape[1] % tp_size == 0, f'padding_mask.shape: {padding_mask.shape}'
padding_mask = torch.chunk(padding_mask, tp_size, dim=1)[mpu.get_tensor_model_parallel_rank()]
extra_block_kwargs['padding_mask'] = padding_mask.contiguous()
+
+ if self.config.moe_n_hash_layers > 0:
+ extra_block_kwargs['input_ids'] = input_ids
+
# Run decoder.
hidden_states = self.decoder(
hidden_states=decoder_input,
diff --git a/src/mcore_bridge/model/gpts/__init__.py b/src/mcore_bridge/model/gpts/__init__.py
index 52b007f..6eb44db 100644
--- a/src/mcore_bridge/model/gpts/__init__.py
+++ b/src/mcore_bridge/model/gpts/__init__.py
@@ -1,2 +1,2 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
-from . import bailing_moe, glm4, hunyuan, llm, minimax_m2, olmoe, qwen3_emb, qwen3_next
+from . import bailing_moe, deepseek_v4, glm4, hunyuan, llm, minimax_m2, olmoe, qwen3_emb, qwen3_next
diff --git a/src/mcore_bridge/model/gpts/deepseek_v4.py b/src/mcore_bridge/model/gpts/deepseek_v4.py
new file mode 100644
index 0000000..f0f4b85
--- /dev/null
+++ b/src/mcore_bridge/model/gpts/deepseek_v4.py
@@ -0,0 +1,450 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import copy
+import torch
+from megatron.core import tensor_parallel
+from megatron.core.models.common.embeddings import apply_rotary_pos_emb
+from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
+from megatron.core.transformer.transformer_block import TransformerBlock as McoreTransformerBlock
+from typing import Optional
+
+from mcore_bridge.bridge import GPTBridge
+
+from ..constant import ModelType
+from ..gpt_model import GPTModel
+from ..register import ModelLoader, ModelMeta, register_model
+from ..rope import get_rope_inv_freq
+
+try:
+ from megatron.core.pipeline_parallel.fine_grained_activation_offload import \
+ FineGrainedActivationOffloadingInterface as off_interface
+ from megatron.core.transformer.experimental_attention_variant.deepseek_v4_hybrid_attention import \
+ DSv4HybridSelfAttention as McoreDSv4HybridSelfAttention
+ from megatron.core.transformer.experimental_attention_variant.deepseek_v4_hybrid_attention import _q_rms_norm
+ from megatron.core.typed_torch import apply_module
+except ImportError:
+ McoreDSv4HybridSelfAttention = object
+ _q_rms_norm = None
+ apply_module = None
+ off_interface = None
+
+
+class DSv4HybridSelfAttention(McoreDSv4HybridSelfAttention):
+
+    def __init__(self, config, *args, **kwargs):
+        """Build the attention module and record which RoPE table this layer uses.
+
+        Raises:
+            ImportError: If the Megatron-Core DSv4 hybrid attention class could not be imported
+                (the module-level fallback leaves ``McoreDSv4HybridSelfAttention`` set to ``object``).
+        """
+        if McoreDSv4HybridSelfAttention is object:
+            # Explicit raise instead of `assert`: assertions are removed under `python -O`, which
+            # would let construction continue into `object.__init__` and fail with an unrelated error.
+            raise ImportError('Please install the Megatron-Core dev branch: '
+                              '`pip install git+https://github.com/NVIDIA/Megatron-LM@dev`')
+        super().__init__(config, *args, **kwargs)
+        # `layer_number - 1` indexes the HF `layer_types` list (presumably layer_number is 1-based — confirm).
+        self.layer_type = self.config.hf_config.layer_types[self.layer_number - 1]
+        # Key into the rotary embedding dict consumed by `forward`.
+        self.rope_layer_type = 'main' if self.layer_type == 'sliding_attention' else 'compress'
+
+    def get_query_key_value_tensors(
+        self,
+        hidden_states,
+        key_value_states=None,
+        position_ids=None,
+        packed_seq_params=None,
+        inference_context=None,
+        rotary_pos_emb=None,
+        *,
+        inference_params=None,
+    ):
+        """
+        Derives `query`, `key` and `value` tensors from `hidden_states`.
+
+        `key_value_states` and `position_ids` are accepted but not read by this method.
+        Inference (`inference_context` / `inference_params`) is rejected with an assertion.
+
+        Returns:
+            A tuple `(query, key, value, q_compressed, kv_compressed)`. `key` and `value` are the
+            same projected tensor; `kv_compressed` is the untouched `hidden_states`.
+        """
+        # s = sequence length, b = batch size, h = hidden size, n = num attention heads
+        # Attention heads [s, b, n*h]
+        assert (hidden_states.ndim == 3), f"hidden_states should be 3D, [s, b, n*h], got {hidden_states.ndim}D"
+        if packed_seq_params is not None:
+            # Adjacent string literals instead of a backslash continuation inside the literal,
+            # which would embed the next line's indentation in the message.
+            assert packed_seq_params.local_cp_size is None, (
+                'dynamic_context_parallel is not supported with MLA yet and is planned for future. '
+                'Please disable dynamic_context_parallel.')
+
+        assert (inference_context is None
+                and inference_params is None), 'Inference is not supported for DSv4HybridSelfAttention.'
+
+        if packed_seq_params is not None and packed_seq_params.qkv_format == 'thd':
+            cu_seqlens_q = packed_seq_params.cu_seqlens_q
+            cu_seqlens_kv = packed_seq_params.cu_seqlens_kv
+        else:
+            cu_seqlens_q = cu_seqlens_kv = None
+
+        # =========================================
+        # QKV down projection and layernorm
+        # =========================================
+        # q_compressed: [s, b, q_lora_rank]
+        q_compressed, _ = self.linear_q_down_proj(hidden_states)
+
+        # There is no KV down projection here: the KV projection reads hidden_states directly.
+        kv_compressed = hidden_states
+
+        if packed_seq_params is not None:
+            # If sequence packing, TE expect [t, h, d] shaped qkv input.
+            # In Megatron-Core, the qkv shape is [t, 1, h, d].
+            # So we need to reshape qkv from [t, 1, h, d] to [t, h, d].
+            q_compressed = q_compressed.squeeze(1)
+
+        # =========================================
+        # Apply norm
+        # =========================================
+
+        if self.config.q_lora_rank is not None:
+            # q_compressed: [num_tokens, q_lora_rank]
+            q_compressed = apply_module(self.q_layernorm)(q_compressed)
+
+        # =========================================
+        # QKV up projection and RoPE apply
+        # =========================================
+
+        def qkv_up_proj_and_rope_apply(q_compressed, kv_compressed, rotary_pos_emb):
+            """
+            Apply the up projection and RoPE to the query and key.
+            When sequence packing enabled, the input tensors adopt a packed shape of [t, ...];
+            otherwise, they maintain the unpacked shape [s, b, ...]. In subsequent code comments,
+            we uniformly use [num_tokens, ...] to denote [s, b, ...] or [t, ...] for two cases.
+            """
+            # q_compressed: [num_tokens, q_lora_rank]
+            # q: [num_tokens, n * q_head_dim]
+            q, _ = self.linear_q_up_proj(q_compressed)
+
+            # q: [num_tokens, n, q_head_dim]
+            q = q.view(*q.size()[:-1], self.num_attention_heads_per_partition, self.q_head_dim)
+            q = _q_rms_norm(q, self.config.layernorm_epsilon)
+
+            kv, _ = self.linear_kv_proj(kv_compressed)
+            kv = self.kv_layernorm(kv)
+
+            # Length of the leading (sequence / token) dimension of q.
+            q_len = q.size()[0]
+            if packed_seq_params is None or self.config.context_parallel_size == 1:
+                # Shorten rotary_pos_emb to the sequence length when inference_params
+                # is not provided. This makes sure we can run forward directly with
+                # any sequence length. During training, the sequence length is always
+                # the full rotary_pos_emb length, except for sequence packing + CP.
+                # When sequence packing and context parallel are both enabled, the
+                # position embedding will not split rotary_pos_emb, so it may exceed
+                # the sequence length on this CP rank, but we need the full rotary_pos_emb
+                # to cover the full sequence, so we do not shorten it here.
+                rotary_pos_emb = rotary_pos_emb[0:q_len]
+
+            # q_no_pe: [num_tokens, n, q_head_dim - qk_pos_emb_head_dim]
+            # q_pos_emb: [num_tokens, n, qk_pos_emb_head_dim]
+            pos_dim = self.config.qk_pos_emb_head_dim
+            q_no_pe, q_pos_emb = torch.split(q, [q.shape[-1] - pos_dim, pos_dim], dim=-1)
+
+            # RoPE on the trailing pos_dim channels of the query
+            # q_pos_emb: [num_tokens, n, qk_pos_emb_head_dim]
+            q_pos_emb = apply_rotary_pos_emb(
+                q_pos_emb,
+                rotary_pos_emb,
+                config=self.config,
+                cu_seqlens=cu_seqlens_q,
+                cp_group=self.pg_collection.cp,
+                mla_rotary_interleaved=True,
+                mla_output_remove_interleaving=True,
+            )
+            # query: [num_tokens, n, q_head_dim]
+            query = torch.cat([q_no_pe, q_pos_emb], dim=-1)
+
+            # Same split for kv: RoPE only touches the trailing pos_dim channels.
+            kv_no_pe, k_pos_emb = torch.split(kv, [kv.size(-1) - pos_dim, pos_dim], dim=-1)
+
+            k_pos_emb = apply_rotary_pos_emb(
+                k_pos_emb,
+                rotary_pos_emb,
+                config=self.config,
+                cu_seqlens=cu_seqlens_kv,
+                cp_group=self.pg_collection.cp,
+                mla_rotary_interleaved=True,
+                mla_output_remove_interleaving=True,
+            )
+
+            # Single head: a head axis of size 1 is inserted and the same tensor serves as key and value.
+            kv = torch.cat([kv_no_pe, k_pos_emb], dim=-1).unsqueeze(-2)
+            key = kv
+            value = kv
+
+            query = query.contiguous()
+            key = key.contiguous()
+            value = value.contiguous()
+
+            return query, key, value
+
+        if self.recompute_up_proj:
+            quantization = self.config.fp8 or self.config.fp4
+            # Kept on self so `forward` can register the recompute after core attention.
+            self.qkv_up_checkpoint = tensor_parallel.CheckpointWithoutOutput(fp8=quantization)
+            query, key, value = self.qkv_up_checkpoint.checkpoint(qkv_up_proj_and_rope_apply, q_compressed,
+                                                                  kv_compressed, rotary_pos_emb)
+        else:
+            query, key, value = qkv_up_proj_and_rope_apply(q_compressed, kv_compressed, rotary_pos_emb)
+
+        return query, key, value, q_compressed, kv_compressed
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask,
+        key_value_states=None,
+        inference_context=None,
+        rotary_pos_emb=None,
+        rotary_pos_cos=None,
+        rotary_pos_sin=None,
+        rotary_pos_cos_sin=None,
+        attention_bias=None,
+        packed_seq_params=None,
+        position_ids=None,
+        sequence_len_offset=None,
+        *,
+        inference_params=None,
+    ):
+        """Forward pass for DeepSeek-v4 Hybrid Attention
+
+        `rotary_pos_emb` must be a mapping indexable by `self.rope_layer_type`
+        (despite the `None` default, it is indexed unconditionally on the first line).
+        `attention_bias`, the `rotary_pos_cos*` arguments and the inference arguments are
+        rejected by assertions; `sequence_len_offset` is not read.
+
+        Returns:
+            The `(output, bias)` pair produced by `self.linear_proj`.
+        """
+        # Select this layer's rotary table before anything else uses it.
+        rotary_pos_emb = rotary_pos_emb[self.rope_layer_type]
+        assert (attention_bias is None), 'Attention bias should not be passed into DSv4HybridAttention.'
+        assert (rotary_pos_cos is None
+                and rotary_pos_sin is None), 'DSv4HybridAttention does not support Flash Decoding'
+        assert (not rotary_pos_cos_sin), 'Flash-infer rope has not been tested with DSv4HybridAttention.'
+        assert (inference_context is None
+                and inference_params is None), 'Inference is not supported for DSv4HybridAttention.'
+
+        # =====================
+        # Query, Key, and Value
+        # =====================
+        # Get the query, key and value tensors based on the type of attention -
+        # self or cross attn.
+        query, key, value, q_compressed, kv_compressed = self.get_query_key_value_tensors(
+            hidden_states,
+            key_value_states,
+            position_ids,
+            packed_seq_params,
+            rotary_pos_emb=rotary_pos_emb,
+            inference_context=inference_context,
+        )
+
+        # TODO: Currently, TE can only accept contiguous tensors for MLA
+        query = query.contiguous()
+        key = key.contiguous()
+        value = value.contiguous()
+
+        # ==================================
+        # core attention computation
+        # ==================================
+        # Need corresponding TE change
+        core_attn_manager = off_interface(self.offload_core_attention and self.training, query, 'core_attn')
+        with core_attn_manager as query:
+            # The raw hidden states and the normalized compressed query are passed alongside q/k/v.
+            core_attn_out = self.core_attention(
+                query,
+                key,
+                value,
+                attention_mask,
+                packed_seq_params=packed_seq_params,
+                x=hidden_states,
+                qr=q_compressed,
+            )
+        core_attn_out = core_attn_manager.group_offload(core_attn_out, forced_released_tensors=[query, key, value])
+
+        if packed_seq_params is not None and packed_seq_params.qkv_format == 'thd':
+            # reshape to same output shape as unpacked case
+            # (t, np, hn) -> (t, b=1, h=np*hn)
+            # t is the pack size = sum (sq_i)
+            # note that batch is a dummy dimension in the packed case
+            core_attn_out = core_attn_out.reshape(core_attn_out.size(0), 1, -1)
+
+        if self.recompute_up_proj:
+            # The checkpoint object was stored on self by get_query_key_value_tensors.
+            assert self.qkv_up_checkpoint is not None
+            self.qkv_up_checkpoint.discard_output_and_register_recompute(core_attn_out)
+            self.qkv_up_checkpoint = None
+
+        # inverse RoPE on last qk_pos_emb_head_dim of each head
+        seq_len = core_attn_out.size(0)
+        n_heads = self.num_attention_heads_per_partition
+        pos_dim = self.config.qk_pos_emb_head_dim
+        core_attn_out = core_attn_out.view(seq_len, core_attn_out.size(1), n_heads, -1)
+        packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd'
+        if packed_seq:
+            # NOTE(review): the forward RoPE in get_query_key_value_tensors uses the unpadded
+            # cu_seqlens_q / cu_seqlens_kv, while this inverse RoPE prefers cu_seqlens_kv_padded —
+            # confirm the two are meant to differ.
+            cu_seqlens_kv = (
+                packed_seq_params.cu_seqlens_kv_padded
+                if packed_seq_params.cu_seqlens_kv_padded is not None else packed_seq_params.cu_seqlens_kv)
+        else:
+            cu_seqlens_kv = None
+
+        content_part, rot_part = torch.split(core_attn_out, [core_attn_out.size(-1) - pos_dim, pos_dim], dim=-1)
+        rot_part = apply_rotary_pos_emb(
+            rot_part,
+            rotary_pos_emb,
+            self.config,
+            cu_seqlens=cu_seqlens_kv,
+            cp_group=self.pg_collection.cp,
+            mla_rotary_interleaved=True,
+            inverse=True,
+            mla_output_remove_interleaving=True,
+        )
+        core_attn_out = torch.cat([content_part, rot_part], dim=-1)
+        # Merge the head axis back into the last dimension.
+        core_attn_out = core_attn_out.view(seq_len, core_attn_out.size(1), -1)
+
+        # Grouped output
+        # Split the last dimension into o_local_groups groups and project each group with its own
+        # slice of linear_o_group_proj down to o_lora_rank.
+        core_attn_out = core_attn_out.view(core_attn_out.size(0), core_attn_out.size(1), self.o_local_groups, -1)
+        wo_a_weight = self.linear_o_group_proj.view(self.o_local_groups, self.config.o_lora_rank, -1)
+        core_attn_out = torch.einsum('...gd,grd->...gr', core_attn_out, wo_a_weight)
+        core_attn_out = core_attn_out.reshape(*core_attn_out.shape[:-2], -1)
+
+        # =================
+        # Output. [sq, b, h]
+        # =================
+        attn_proj_manager = off_interface(self.offload_attn_proj, core_attn_out, 'attn_proj')
+        with attn_proj_manager as core_attn_out:
+            output, bias = self.linear_proj(core_attn_out)
+        output = attn_proj_manager.group_offload(output, forced_released_tensors=[core_attn_out])
+
+        return output, bias
+
+
+class DeepseekV4GPTModel(GPTModel):
+    """GPT model for DeepSeek-V4 that keeps two rotary embeddings: ``'main'`` and ``'compress'``."""
+
+    def _init_mla_softmax_scale(self, config):
+        """Override the base hook with a no-op: ``config`` is left untouched."""
+
+    def _get_rotary_pos_emb(self, decoder_input, position_ids, packed_seq_params, inference_context=None):
+        """Return ``({'main': ..., 'compress': ...}, None, None)`` rotary embeddings.
+
+        ``position_ids`` is not read. The dict keys match the ``rope_layer_type`` values the
+        attention layers use to pick their table.
+        """
+        rotary_seq_len = RotaryEmbedding.get_rotary_seq_len(self, inference_context, self.decoder, decoder_input,
+                                                            self.config, packed_seq_params)
+        packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd'
+        rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len, packed_seq=packed_seq)
+        compress_rotary_pos_emb = self.compress_rotary_pos_emb(rotary_seq_len, packed_seq=packed_seq)
+        rotary_pos_emb = {'main': rotary_pos_emb, 'compress': compress_rotary_pos_emb}
+        return rotary_pos_emb, None, None
+
+    def _set_inv_freq(self):
+        """Install the ``'main'`` and ``'compress'`` inverse frequencies and attention scalings.
+
+        ``config.rope_scaling`` is expected to be a mapping with ``'main'`` and ``'compress'``
+        entries. Each entry is temporarily assigned to ``config.rope_scaling`` so
+        ``get_rope_inv_freq`` can read it; the original mapping is restored afterwards.
+        """
+        rope_scaling = self.config.rope_scaling
+        try:
+            self.config.rope_scaling = rope_scaling['main']
+            new_inv_freq, attention_scaling = get_rope_inv_freq(self.config)
+            self.rotary_pos_emb.inv_freq = new_inv_freq.to(self.rotary_pos_emb.inv_freq.device)
+            self.config.attention_scaling = attention_scaling
+            # compress
+            # NOTE(review): copy.copy is a shallow copy — this assumes `inv_freq` is a plain
+            # attribute and not state shared with the original module; confirm.
+            self.compress_rotary_pos_emb = copy.copy(self.rotary_pos_emb)
+            self.config.rope_scaling = rope_scaling['compress']
+            new_inv_freq, attention_scaling = get_rope_inv_freq(self.config)
+            # Same device placement as the main embedding above.
+            self.compress_rotary_pos_emb.inv_freq = new_inv_freq.to(self.rotary_pos_emb.inv_freq.device)
+            self.config.compress_attention_scaling = attention_scaling
+        finally:
+            # Restore the full mapping even if a lookup or `get_rope_inv_freq` raises, so the
+            # shared config is never left pointing at one sub-entry.
+            self.config.rope_scaling = rope_scaling
+
+
+class DeepseekV4Loader(ModelLoader):
+    """Loader that pairs ``DeepseekV4GPTModel`` with the DSv4 hybrid self-attention module."""
+    model_cls = DeepseekV4GPTModel
+    transformer_block = McoreTransformerBlock
+
+    def get_transformer_layer_spec(self, vp_stage: Optional[int] = None):
+        """Build the experimental-attention block spec and use ``DSv4HybridSelfAttention`` in every layer."""
+        from megatron.core.models.gpt.experimental_attention_variant_module_specs import \
+            get_transformer_block_with_experimental_attention_variant_spec as build_block_spec
+        block_spec = build_block_spec(self.config, vp_stage)
+        for spec in block_spec.layer_specs:
+            # Replace the stock attention class with the local subclass.
+            spec.submodules.self_attention.module = DSv4HybridSelfAttention
+        return block_spec
+
+
+class DeepseekV4Bridge(GPTBridge):
+    """Weight bridge between the DeepSeek-V4 HF checkpoint layout and the mcore model."""
+    hf_mtp_prefix = 'model.mtp'
+    hf_embed_key = 'model.embed.weight'
+    hf_attn_prefix = 'attn'
+    hf_mlp_prefix = 'ffn'
+    hf_lm_head_key = 'model.head.weight'
+    hf_score_key = 'model.score.weight'
+    hf_input_layernorm_key = 'attn_norm.weight'
+    hf_post_attention_layernorm_key = 'ffn_norm.weight'
+    hf_expert_bias_key = 'gate.bias'
+
+    def _convert_hf_state_dict(self, hf_state_dict, to_mcore):
+        """Run the base conversion, then add (import) or strip (export) the ``'model.'`` key prefix."""
+        converted = super()._convert_hf_state_dict(hf_state_dict, to_mcore)
+        fix_prefix = self._add_prefix if to_mcore else self._remove_prefix
+        return fix_prefix(converted, 'model.')
+
+    def _set_moe_state(
+        self,
+        mg_mlp,
+        hf_state_dict,
+        hf_prefix: str,
+        layer_idx: int,
+        to_mcore: bool,
+        is_mtp: bool = False,
+    ):
+        """Wrap the base MoE conversion with ``w1/w3/w2`` <-> ``gate_proj/up_proj/down_proj`` key renaming."""
+        hf_names = ('.w1.', '.w3.', '.w2.')
+        mcore_names = ('.gate_proj.', '.up_proj.', '.down_proj.')
+
+        def rename(state_dict, old_names, new_names):
+            # Apply the three substring replacements, in order, to every key.
+            renamed = {}
+            for name, value in state_dict.items():
+                for old, new in zip(old_names, new_names):
+                    name = name.replace(old, new)
+                renamed[name] = value
+            return renamed
+
+        if to_mcore:
+            hf_state_dict = rename(hf_state_dict, hf_names, mcore_names)
+        hf_state_dict = super()._set_moe_state(mg_mlp, hf_state_dict, hf_prefix, layer_idx, to_mcore, is_mtp)
+        if not to_mcore:
+            hf_state_dict = rename(hf_state_dict, mcore_names, hf_names)
+        return hf_state_dict
+
+    def _set_mla_attn_state(
+        self,
+        mg_attn,
+        hf_state_dict,
+        hf_prefix: str,
+        layer_idx: int,
+        to_mcore: bool,
+    ):
+        """Convert one attention layer's weights, including the optional compressor and indexer."""
+        hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) if to_mcore else {}
+
+        # (mcore key, HF key) pairs, converted in this order.
+        attn_keys = (
+            ('linear_proj.weight', 'wo_b.weight'),
+            ('linear_o_group_proj', 'wo_a.weight'),
+            ('linear_q_down_proj.weight', 'wq_a.weight'),
+            ('linear_q_up_proj.weight', 'wq_b.weight'),
+            ('linear_kv_proj.weight', 'wkv.weight'),
+            ('core_attention.attn_sink', 'attn_sink'),
+        )
+        for mg_key, hf_key in attn_keys:
+            self._set_state_dict(mg_attn, mg_key, hf_state_dict, hf_key, to_mcore)
+        if self.config.qk_layernorm:
+            for mg_key, hf_key in (('q_layernorm.weight', 'q_norm.weight'), ('kv_layernorm.weight',
+                                                                             'kv_norm.weight')):
+                self._set_state_dict(mg_attn, mg_key, hf_state_dict, hf_key, to_mcore)
+
+        has_compressor = mg_attn is not None and mg_attn.core_attention.compressor is not None
+        has_indexer = mg_attn is not None and mg_attn.core_attention.indexer is not None
+        has_compressor = self._reduce_tensor_pp_group(has_compressor, to_mcore)
+        has_indexer = self._reduce_tensor_pp_group(has_indexer, to_mcore)
+
+        # The compressor key set is shared by the top-level compressor and the indexer's compressor.
+        compressor_keys = (
+            ('ape', 'ape'),
+            ('linear_wkv.weight', 'wkv.weight'),
+            ('linear_wgate.weight', 'wgate.weight'),
+            ('norm.weight', 'norm.weight'),
+        )
+        if has_compressor:
+            for mg_key, hf_key in compressor_keys:
+                self._set_state_dict(mg_attn, f'core_attention.compressor.{mg_key}', hf_state_dict,
+                                     f'compressor.{hf_key}', to_mcore)
+        if has_indexer:
+            indexer_keys = (('linear_wq_b.weight', 'wq_b.weight'), ('linear_weights_proj.weight',
+                                                                    'weights_proj.weight'))
+            for mg_key, hf_key in indexer_keys:
+                self._set_state_dict(mg_attn, f'core_attention.indexer.{mg_key}', hf_state_dict, f'indexer.{hf_key}',
+                                     to_mcore)
+            for mg_key, hf_key in compressor_keys:
+                self._set_state_dict(mg_attn, f'core_attention.indexer.compressor.{mg_key}', hf_state_dict,
+                                     f'indexer.compressor.{hf_key}', to_mcore)
+
+        return {} if to_mcore else self._add_prefix(hf_state_dict, hf_prefix)
+
+    def _set_final_layernorm(self, lm_model, hf_state_dict, to_mcore):
+        """Convert the final layernorm, then the three ``hc_head_*`` entries stored on the decoder."""
+        super()._set_final_layernorm(lm_model, hf_state_dict, to_mcore)
+        for suffix in ('base', 'fn', 'scale'):
+            name = f'hc_head_{suffix}'
+            self._set_state_dict(lm_model, f'decoder.{name}', hf_state_dict, f'model.{name}', to_mcore)
+
+    def _set_router(self, mg_mlp, hf_state_dict, to_mcore, **kwargs):
+        """Convert router weights; hash layers also carry ``tid2eid`` and skip the expert bias."""
+        hash_layer = mg_mlp is not None and mg_mlp.router.is_hash_layer
+        hash_layer = self._reduce_tensor_pp_group(hash_layer, to_mcore)
+        if hash_layer:
+            self._set_state_dict(mg_mlp, 'router.tid2eid', hf_state_dict, 'gate.tid2eid', to_mcore)
+            # Tell the base implementation not to convert `router.expert_bias` for this layer.
+            kwargs['moe_router_enable_expert_bias'] = False
+        super()._set_router(mg_mlp, hf_state_dict, to_mcore, **kwargs)
+
+
+register_model(
+ ModelMeta(
+ ModelType.deepseek_v4,
+ ['deepseek_v4'],
+ bridge_cls=DeepseekV4Bridge,
+ loader=DeepseekV4Loader,
+ ))
diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py
index 23caa9b..504343c 100644
--- a/src/mcore_bridge/model/mm_gpts/gemma4.py
+++ b/src/mcore_bridge/model/mm_gpts/gemma4.py
@@ -448,7 +448,7 @@ def _set_qkv(self, mg_attn, hf_state_dict, to_mcore: bool, **kwargs):
if use_alternative_attention else text_config.num_key_value_heads)
return super()._set_qkv(mg_attn, hf_state_dict, to_mcore, **kwargs)
- def _set_router(self, mg_mlp, hf_state_dict, to_mcore):
+ def _set_router(self, mg_mlp, hf_state_dict, to_mcore, **kwargs):
self._set_state_dict(mg_mlp, 'router.weight', hf_state_dict, 'router.proj.weight', to_mcore)
for key in ['per_expert_scale', 'scale']:
self._set_state_dict(mg_mlp, key, hf_state_dict, f'router.{key}', to_mcore)
diff --git a/src/mcore_bridge/model/modules/transformer_layer.py b/src/mcore_bridge/model/modules/transformer_layer.py
index 75a3a46..342e4f1 100644
--- a/src/mcore_bridge/model/modules/transformer_layer.py
+++ b/src/mcore_bridge/model/modules/transformer_layer.py
@@ -227,8 +227,11 @@ def _apply_rotary_pos_emb_bshd(
t: torch.Tensor,
freqs: torch.Tensor,
rotary_interleaved: bool = False,
+ mla_rotary_interleaved: Optional[bool] = None,
multi_latent_attention: Optional[bool] = None,
mscale: float = 1.0,
+ inverse: bool = False,
+ mla_output_remove_interleaving: bool = False,
**kwargs,
) -> torch.Tensor:
"""Apply rotary positional embedding to input tensor T.
@@ -249,7 +252,9 @@ def _apply_rotary_pos_emb_bshd(
t, t_pass = t[..., :rot_dim], t[..., rot_dim:]
if multi_latent_attention is None:
multi_latent_attention = self.config.multi_latent_attention
- if multi_latent_attention:
+ if mla_rotary_interleaved is None:
+ mla_rotary_interleaved = multi_latent_attention
+ if mla_rotary_interleaved:
x1 = t[..., 0::2]
x2 = t[..., 1::2]
t = torch.cat((x1, x2), dim=-1)
@@ -258,8 +263,17 @@ def _apply_rotary_pos_emb_bshd(
# second part is sine component, need to change signs with _rotate_half method
cos_ = (torch.cos(freqs) * mscale).to(t.dtype)
sin_ = (torch.sin(freqs) * mscale).to(t.dtype)
+ if inverse:
+ sin_ = -sin_
t = (t * cos_) + (rope_utils._rotate_half(t, rotary_interleaved) * sin_)
+ # Restore the original element order on the output.
+ # DSv4 applies rope on V and O, so the even/odd regrouping done for the rotation must be undone.
+ # The existing MLA code is safe because the dot product is permutation-invariant.
+ if mla_rotary_interleaved and mla_output_remove_interleaving:
+ x1, x2 = torch.chunk(t, 2, dim=-1)
+ t = torch.stack((x1, x2), dim=-1).flatten(start_dim=-2)
+
return torch.cat((t, t_pass), dim=-1)
rope_utils._apply_rotary_pos_emb_bshd = _apply_rotary_pos_emb_bshd
diff --git a/src/mcore_bridge/model/register.py b/src/mcore_bridge/model/register.py
index c6816a0..3ac468b 100644
--- a/src/mcore_bridge/model/register.py
+++ b/src/mcore_bridge/model/register.py
@@ -132,6 +132,8 @@ def _set_shared_expert_gate(self, transformer_layer_spec):
layer_spec.submodules.mlp.submodules.shared_experts.params = {'gate': True}
def _set_transformer_layer(self, transformer_layer_spec):
+ if self.config.enable_hyper_connections:
+ return
for layer_spec in transformer_layer_spec.layer_specs:
if layer_spec.module is McoreTransformerLayer:
layer_spec.module = TransformerLayer
diff --git a/src/mcore_bridge/model/rope.py b/src/mcore_bridge/model/rope.py
index a514f43..0db60a0 100644
--- a/src/mcore_bridge/model/rope.py
+++ b/src/mcore_bridge/model/rope.py
@@ -25,11 +25,15 @@ def __init__(self, **kwargs):
def _get_dummy_config(config):
+ if config.multi_latent_attention and config.partial_rotary_factor is None:
+ head_dim = config.qk_pos_emb_head_dim
+ else:
+ head_dim = config.kv_channels
dummy_config = DummyConfig(
rope_scaling=config.rope_scaling,
rope_theta=config.rotary_base,
max_position_embeddings=config.max_position_embeddings,
- head_dim=config.qk_pos_emb_head_dim if config.multi_latent_attention else config.kv_channels,
+ head_dim=head_dim,
hidden_size=config.hidden_size,
num_attention_heads=config.num_attention_heads,
)