Skip to content
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ The following is the list of models supported by MCore-Bridge:
| Series | model_type |
| -------- | ------------------------------------------------------------ |
| Qwen | qwen2, qwen2_moe<br />qwen3, qwen3_moe, qwen3_next |
| DeepSeek | deepseek_v3, deepseek_v32 |
| DeepSeek | deepseek_v3, deepseek_v32, deepseek_v4 |
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The PR adds deepseek_v4 to the list of supported models, but the actual implementation appears to be missing. The file src/mcore_bridge/model/gpts/deepseek_v4.py is empty in the provided context, and there are no changes to model registration or configuration logic to support this new model type. Please ensure the implementation is included or clarify if it relies on an existing model type.

| GLM | glm4, glm4_moe, glm4_moe_lite<br />glm_moe_dsa |
| MiniMax | minimax_m2 |
| Kimi | kimi_k2, kimi_k25 |
Expand Down
2 changes: 1 addition & 1 deletion README_zh.md
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ uv pip install -e . --torch-backend=auto
| 系列 | model_type |
| -------- | ------------------------------------------------------------ |
| Qwen | qwen2, qwen2_moe<br />qwen3, qwen3_moe, qwen3_next |
| DeepSeek | deepseek_v3, deepseek_v32 |
| DeepSeek | deepseek_v3, deepseek_v32, deepseek_v4 |
| GLM | glm4, glm4_moe, glm4_moe_lite<br />glm_moe_dsa |
| MiniMax | minimax_m2 |
| Kimi | kimi_k2, kimi_k25 |
Expand Down
30 changes: 27 additions & 3 deletions src/mcore_bridge/bridge/gpt_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,7 +746,7 @@ def _get_hf_experts_attr(self, is_mtp: bool = False):
return True, True
if self.model_type in {'glm4v_moe', 'kimi_vl', 'qwen3_omni_moe', 'qwen3_5_moe'} or self.llm_model_type in {
'qwen2_moe', 'qwen3_moe', 'deepseek_v2', 'deepseek_v3', 'kimi_k2', 'dots1', 'ernie4_5_moe', 'glm4_moe',
'glm4_moe_lite', 'minimax_m2', 'olmoe', 'qwen3_next', 'glm_moe_dsa', 'deepseek_v32'
'glm4_moe_lite', 'minimax_m2', 'olmoe', 'qwen3_next', 'glm_moe_dsa', 'deepseek_v32', 'deepseek_v4'
}:
return False, False
elif self.model_type in {'qwen3_vl_moe', 'llama4', 'gemma4'} or self.llm_model_type in {'gpt_oss'}:
Expand Down Expand Up @@ -1619,6 +1619,23 @@ def _set_layer_mlp(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: bool
self.hf_post_attention_layernorm_key, to_mcore)
return hf_state_dict

def _set_hyper_connection(self, mg_layer, hf_state_dict, layer_idx, to_mcore):

for key, hf_key in zip(['self_attention_hyper_connection', 'mlp_hyper_connection'], ['attn', 'ffn']):
hyper_connection = None if mg_layer is None else getattr(mg_layer, key)
Comment thread
Jintao-Huang marked this conversation as resolved.
self._set_state_dict(hyper_connection, 'mapping_proj.weight', hf_state_dict, f'hc_{hf_key}_fn', to_mcore)
self._set_state_dict(hyper_connection, 'bias', hf_state_dict, f'hc_{hf_key}_base', to_mcore)
if hyper_connection is not None:
if to_mcore:
alpha = hf_state_dict[f'hc_{hf_key}_scale'].load()
for i, alpha_suffix in enumerate(['pre', 'post', 'res']):
getattr(hyper_connection, f'alpha_{alpha_suffix}').data[:] = alpha[i]
else:
alpha = []
for i, alpha_suffix in enumerate(['pre', 'post', 'res']):
alpha.append(getattr(hyper_connection, f'alpha_{alpha_suffix}'))
hf_state_dict[f'hc_{hf_key}_scale'] = torch.concat(alpha)
Comment thread
Jintao-Huang marked this conversation as resolved.
Outdated

def _set_layer_state(self, mg_layer, hf_state_dict, hf_prefix: str, layer_idx: int, to_mcore: bool):
hf_prefix = f'{hf_prefix}{layer_idx}.'
if to_mcore:
Expand All @@ -1627,6 +1644,9 @@ def _set_layer_state(self, mg_layer, hf_state_dict, hf_prefix: str, layer_idx: i
hf_state_dict = {}
hf_state_dict.update(self._set_layer_attn(mg_layer, hf_state_dict, layer_idx, to_mcore))
hf_state_dict.update(self._set_layer_mlp(mg_layer, hf_state_dict, layer_idx, to_mcore))
if self.config.enable_hyper_connections:
self._set_hyper_connection(mg_layer, hf_state_dict, layer_idx, to_mcore)

if to_mcore:
hf_state_dict = {}
else:
Expand Down Expand Up @@ -1678,14 +1698,18 @@ def _convert_post_process(self, mg_model, hf_state_dict, hf_prefix: str, to_mcor
self._set_state_dict(lm_model, 'output_layer.weight', hf_state_dict, hf_lm_head_key, to_mcore)
elif to_mcore and lm_model.output_layer.weight is not None:
self._set_state_dict(lm_model, 'output_layer.weight', hf_state_dict, self.hf_embed_key, to_mcore)
self._set_state_dict(lm_model, 'decoder.final_layernorm.weight', hf_state_dict, self.hf_final_layernorm_key,
to_mcore)
self._set_final_layernorm(lm_model, hf_state_dict, to_mcore)

if to_mcore:
hf_state_dict = {}
else:
hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix)
return hf_state_dict

def _set_final_layernorm(self, lm_model, hf_state_dict, to_mcore):
self._set_state_dict(lm_model, 'decoder.final_layernorm.weight', hf_state_dict, self.hf_final_layernorm_key,
to_mcore)

def _convert_hf_state_dict(self, hf_state_dict, to_mcore):
res = {}
for k, v in hf_state_dict.items():
Expand Down
16 changes: 15 additions & 1 deletion src/mcore_bridge/config/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,14 +197,27 @@ class ModelConfig(TransformerConfig):
linear_decoupled_in_proj: bool = False

# dsa
experimental_attention_variant: Optional[Literal['gated_delta_net', 'dsa']] = None
experimental_attention_variant: Optional[Literal['gated_delta_net', 'dsa', 'dsv4_hybrid']] = None
dsa_indexer_n_heads: Optional[int] = None
dsa_indexer_head_dim: Optional[int] = None
dsa_indexer_topk: Optional[int] = None
dsa_indexer_loss_coeff: Optional[float] = None
dsa_indexer_use_sparse_loss: bool = False
dsa_indexer_rotary_interleaved: bool = False

# deepseek-v4
csa_window_size: int = 128
csa_compress_ratios: Optional[List[int]] = None
csa_compress_rotary_base: float = 40000.0
o_groups: int = 8
o_lora_rank: int = 1024
enable_hyper_connections: bool = False
num_residual_streams: int = 4
mhc_sinkhorn_iterations: int = 20
mhc_init_gating_factor: float = 0.01
moe_n_hash_layers: int = 0
actual_vocab_size: Optional[int] = None

# mtp
mtp_decoder_input_detach: bool = False
mtp_shared_weights: bool = False
Expand Down Expand Up @@ -290,6 +303,7 @@ def __post_init__(self):

if self.add_bias_linear:
self.add_qkv_bias = True
self.actual_vocab_size = self.padded_vocab_size
self.batch_p2p_comm = not self.overlap_p2p_comm
if self.swiglu:
self.activation_func = F.silu
Expand Down
24 changes: 22 additions & 2 deletions src/mcore_bridge/config/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,15 @@
'dsa_indexer_head_dim': ['index_head_dim'],
'dsa_indexer_topk': ['index_topk'],
'dsa_indexer_rotary_interleaved': ['indexer_rope_interleave'],
# deepseek_v4
'csa_compress_ratios': ['compress_rates'],
'csa_compress_rotary_base': ['compress_rope_theta'],
'o_groups': ['o_groups'],
'o_lora_rank': ['o_lora_rank'],
'num_residual_streams': ['hc_mult'],
'mhc_sinkhorn_iterations': ['hc_sinkhorn_iters'],
'moe_n_hash_layers': ['mlp_layer_types'],
'activation_func_clamp_value': ['swiglu_limit'],
# other
'original_max_position_embeddings': ['original_max_position_embeddings'],
'partial_rotary_factor': ['partial_rotary_factor'],
Expand Down Expand Up @@ -88,7 +97,7 @@ def _convert_config(config, _internal_call=False) -> Dict[str, Any]:
else:
continue
else:
if k == 'kv_lora_rank':
if k in {'q_lora_rank', 'kv_lora_rank'}:
megatron_config['multi_latent_attention'] = True
elif k == 'hf_model_type':
if _internal_call:
Expand Down Expand Up @@ -134,7 +143,7 @@ def hf_to_mcore_config(hf_config: PretrainedConfig) -> Dict[str, Any]:
res.pop('ffn_hidden_size', None)
if llm_model_type in {'qwen2_moe', 'qwen3_next'} or hf_model_type == 'qwen3_5_moe':
res['moe_shared_expert_gate'] = True
if llm_model_type in {'deepseek', 'deepseek_v2', 'deepseek_v3', 'kimi_k2', 'deepseek_v32', 'dots1'
if llm_model_type in {'deepseek', 'deepseek_v2', 'deepseek_v3', 'kimi_k2', 'deepseek_v32', 'dots1', 'deepseek_v4'
} or hf_model_type == 'kimi_vl':
if llm_model_type != 'deepseek':
res['qk_layernorm'] = True
Expand All @@ -143,6 +152,17 @@ def hf_to_mcore_config(hf_config: PretrainedConfig) -> Dict[str, Any]:
res['moe_router_score_function'] = 'sigmoid'
elif llm_model_type == 'deepseek_v32':
res['experimental_attention_variant'] = 'dsa'
elif llm_model_type == 'deepseek_v4':
if 'v_head_dim' not in res:
res['v_head_dim'] = res['kv_channels']
res['experimental_attention_variant'] = 'dsv4_hybrid'
res['csa_window_size'] = window_size
res['enable_hyper_connections'] = True
res.pop('partial_rotary_factor', None)
csa_compress_ratios = res.pop('csa_compress_ratios', None)
res['csa_compress_ratios'] = [csa_compress_ratios.get(layer_type, 0) for layer_type in layer_types]
moe_n_hash_layers = res.pop('moe_n_hash_layers', None)
res['moe_n_hash_layers'] = len([layer for layer in moe_n_hash_layers if layer == 'hash_moe'])
Comment thread
Jintao-Huang marked this conversation as resolved.
elif llm_model_type == 'hunyuan':
# Since HunYuan’s attention applies RoPE before using q/k_layernorm,
# which is incompatible with megatron-core, support is not provided here.
Expand Down
1 change: 1 addition & 0 deletions src/mcore_bridge/model/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ class LLMModelType:
minimax_m2 = 'minimax_m2'
hy_v3 = 'hy_v3'
bailing_moe = 'bailing_moe'
deepseek_v4 = 'deepseek_v4'

qwen3_emb = 'qwen3_emb'

Expand Down
23 changes: 15 additions & 8 deletions src/mcore_bridge/model/gpt_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,19 +66,14 @@ def __init__(
):
vocab_size = math.ceil(
config.padded_vocab_size / config.tensor_model_parallel_size) * config.tensor_model_parallel_size
hf_rope_scaling = config.rope_scaling
self.hf_rope_scaling = config.rope_scaling
if config.multi_latent_attention:
config.rope_type = 'rope' # use transformers implementation
# Set default value, the following content will not be used. (dummy)
config.mscale_all_dim = 0.
config.cache_mla_latents = False
config.rotary_scaling_factor = 40
if hf_rope_scaling and hf_rope_scaling['rope_type'] == 'yarn':
# softmax_scale
config.mscale = hf_rope_scaling['mscale']
config.mscale_all_dim = hf_rope_scaling['mscale_all_dim']
config.rotary_scaling_factor = hf_rope_scaling['factor']
self.hf_rope_scaling = hf_rope_scaling
self._init_mla_softmax_scale(config)
super().__init__(
config,
transformer_layer_spec,
Expand Down Expand Up @@ -136,6 +131,13 @@ def _value(self):
attention.config = copy.copy(attention.config)
attention.config.apply_rope_fusion = False

def _init_mla_softmax_scale(self, config):
if self.hf_rope_scaling and self.hf_rope_scaling['rope_type'] == 'yarn':
# softmax_scale
config.mscale = self.hf_rope_scaling['mscale']
config.mscale_all_dim = self.hf_rope_scaling['mscale_all_dim']
config.rotary_scaling_factor = self.hf_rope_scaling['factor']

def _preprocess(
self,
input_ids: torch.Tensor,
Expand Down Expand Up @@ -297,6 +299,11 @@ def forward(
assert padding_mask.shape[1] % tp_size == 0, f'padding_mask.shape: {padding_mask.shape}'
padding_mask = torch.chunk(padding_mask, tp_size, dim=1)[mpu.get_tensor_model_parallel_rank()]
kwargs['padding_mask'] = padding_mask.contiguous()

extra_block_kwargs = extra_block_kwargs or {}
if self.config.moe_n_hash_layers > 0:
extra_block_kwargs['input_ids'] = input_ids

# Run decoder.
hidden_states = self.decoder(
hidden_states=decoder_input,
Expand All @@ -307,7 +314,7 @@ def forward(
rotary_pos_sin=rotary_pos_sin,
packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
**(extra_block_kwargs or {}),
**extra_block_kwargs,
**kwargs,
)

Expand Down
2 changes: 1 addition & 1 deletion src/mcore_bridge/model/gpts/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
from . import bailing_moe, glm4, hunyuan, llm, minimax_m2, olmoe, qwen3_emb, qwen3_next
from . import bailing_moe, deepseek_v4, glm4, hunyuan, llm, minimax_m2, olmoe, qwen3_emb, qwen3_next
Loading
Loading