diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index 21235e305db4..e3084195cd50 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -404,6 +404,8 @@ th {
 | `OPTForCausalLM` | OPT, OPT-IML | `facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc. | ✅︎ | ✅︎ |
 | `OrionForCausalLM` | Orion | `OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc. | | ✅︎ |
 | `OuroForCausalLM` | ouro | `ByteDance/Ouro-1.4B`, `ByteDance/Ouro-2.6B`, etc. | ✅︎ | |
+| `PanguEmbeddedForCausalLM` | openPangu-Embedded-7B | `FreedomIntelligence/openPangu-Embedded-7B-V1.1` | ✅︎ | ✅︎ |
+| `PanguUltraMoEForCausalLM` | openPangu-Ultra-MoE-718B | `FreedomIntelligence/openPangu-Ultra-MoE-718B-V1.1` | ✅︎ | ✅︎ |
 | `PhiForCausalLM` | Phi | `microsoft/phi-1_5`, `microsoft/phi-2`, etc. | ✅︎ | ✅︎ |
 | `Phi3ForCausalLM` | Phi-4, Phi-3 | `microsoft/Phi-4-mini-instruct`, `microsoft/Phi-4`, `microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, `microsoft/Phi-3-medium-128k-instruct`, etc. | ✅︎ | ✅︎ |
 | `PhiMoEForCausalLM` | Phi-3.5-MoE | `microsoft/Phi-3.5-MoE-instruct`, etc. | ✅︎ | ✅︎ |
diff --git a/tests/models/registry.py b/tests/models/registry.py
index 00fe99980500..a4bcddc50c1a 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -363,6 +363,11 @@ def check_available_online(
     "Olmo2ForCausalLM": _HfExamplesInfo("allenai/OLMo-2-0425-1B"),
     "Olmo3ForCausalLM": _HfExamplesInfo("shanearora/2025-sep-a-base-model"),
     "OlmoeForCausalLM": _HfExamplesInfo("allenai/OLMoE-1B-7B-0924-Instruct"),
+    "OpenPanguMTPModel": _HfExamplesInfo(
+        "FreedomIntelligence/openPangu-Ultra-MoE-718B-V1.1",
+        trust_remote_code=True,
+        is_available_online=False,
+    ),
     "OPTForCausalLM": _HfExamplesInfo(
         "facebook/opt-125m", {"1b": "facebook/opt-iml-max-1.3b"}
     ),
@@ -370,6 +375,14 @@
         "OrionStarAI/Orion-14B-Chat", trust_remote_code=True
     ),
     "OuroForCausalLM": _HfExamplesInfo("ByteDance/Ouro-1.4B", trust_remote_code=True),
+    "PanguEmbeddedForCausalLM": _HfExamplesInfo(
+        "FreedomIntelligence/openPangu-Embedded-7B-V1.1", trust_remote_code=True
+    ),
+    "PanguUltraMoEForCausalLM": _HfExamplesInfo(
+        "FreedomIntelligence/openPangu-Ultra-MoE-718B-V1.1",
+        trust_remote_code=True,
+        is_available_online=False,
+    ),
     "PersimmonForCausalLM": _HfExamplesInfo("adept/persimmon-8b-chat"),
     "PhiForCausalLM": _HfExamplesInfo("microsoft/phi-2"),
     "Phi3ForCausalLM": _HfExamplesInfo("microsoft/Phi-3-mini-4k-instruct"),
diff --git a/vllm/config/model.py b/vllm/config/model.py
index 2e80df431103..17d3162695b5 100644
--- a/vllm/config/model.py
+++ b/vllm/config/model.py
@@ -1231,6 +1231,8 @@ def is_deepseek_mla(self) -> bool:
             "kimi_k2",
             "kimi_linear",
             "longcat_flash",
+            "pangu_ultra_moe",
+            "pangu_ultra_moe_mtp",
         ):
             return self.hf_text_config.kv_lora_rank is not None
         elif self.hf_text_config.model_type == "eagle":
@@ -1379,6 +1381,7 @@ def get_layers_start_end_indices(
             or self.hf_config.model_type == "glm4_moe_mtp"
             or self.hf_config.model_type == "ernie_mtp"
             or self.hf_config.model_type == "qwen3_next_mtp"
+            or self.hf_config.model_type == "pangu_ultra_moe_mtp"
         ):
             total_num_hidden_layers = getattr(
                 self.hf_text_config, "num_nextn_predict_layers", 0
diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py
index af1d640f8acc..873dfd017069 100644
--- a/vllm/config/speculative.py
+++ b/vllm/config/speculative.py
@@ -41,6 +41,7 @@
     "qwen3_next_mtp",
     "mimo_mtp",
     "longcat_flash_mtp",
+    "pangu_ultra_moe_mtp",
     "mtp",
     "suffix",
 ]
@@ -51,6 +52,7 @@
"ernie_mtp", "qwen3_next_mtp", "longcat_flash_mtp", + "pangu_ultra_moe_mtp", ) @@ -179,6 +181,13 @@ def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig: hf_config.update( {"n_predict": n_predict, "architectures": ["DeepSeekMTPModel"]} ) + if hf_config.model_type in ("pangu_ultra_moe"): + hf_config.model_type = "pangu_ultra_moe_mtp" + if hf_config.model_type == "pangu_ultra_moe_mtp": + n_predict = getattr(hf_config, "num_nextn_predict_layers", None) + hf_config.update( + {"n_predict": n_predict, "architectures": ["OpenPanguMTPModel"]} + ) if hf_config.architectures[0] == "MiMoForCausalLM": hf_config.model_type = "mimo_mtp" diff --git a/vllm/model_executor/models/openpangu.py b/vllm/model_executor/models/openpangu.py new file mode 100644 index 000000000000..457498d995f8 --- /dev/null +++ b/vllm/model_executor/models/openpangu.py @@ -0,0 +1,1078 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import typing +from collections.abc import Callable, Iterable +from typing import Any + +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.attention import Attention, AttentionType +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, ParallelConfig, VllmConfig +from vllm.distributed import ( + get_ep_group, + get_pp_group, + get_tensor_model_parallel_world_size, + get_tp_group, + tensor_model_parallel_all_gather, +) +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import SharedFusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mla import MLAModules, MultiHeadLatentAttentionWrapper +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, + maybe_remap_kv_scale_name, +) +from vllm.model_executor.models.interfaces import ( + MixtureOfExperts, + SupportsLoRA, + SupportsPP, +) +from vllm.model_executor.models.utils import ( + AutoWeightsLoader, + PPMissingLayer, + extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, + sequence_parallel_chunk, +) +from vllm.sequence import IntermediateTensors + + +def check_ffn_act_fn(act_fn: str): + if act_fn != "silu": + raise ValueError( + f"Unsupported activation: {act_fn}. Only silu is supported for now." 
+ ) + + +class OpenPanguMLP(nn.Module): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: QuantizationConfig | None = None, + bias: bool = False, + reduce_results: bool = True, + is_sequence_parallel=False, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=bias, + quant_config=quant_config, + disable_tp=is_sequence_parallel, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=bias, + quant_config=quant_config, + reduce_results=reduce_results, + disable_tp=is_sequence_parallel, + prefix=f"{prefix}.down_proj", + ) + + check_ffn_act_fn(hidden_act) + self.act_fn = SiluAndMul() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.down_proj(self.act_fn(self.gate_up_proj(x)[0]))[0] + + +class OpenPanguMoE(nn.Module): + def __init__( + self, + config: PretrainedConfig, + parallel_config: ParallelConfig, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tp_group().rank_in_group + + self.routed_scaling_factor = config.routed_scaling_factor + self.ep_group = get_ep_group().device_group + self.ep_rank = self.ep_group.rank() + self.ep_size = self.ep_group.size() + self.n_routed_experts: int = config.n_routed_experts + self.n_shared_experts: int = config.n_shared_experts + + self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe + check_ffn_act_fn(config.hidden_act) + + self.gate = ReplicatedLinear( + config.hidden_size, + config.n_routed_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate", + ) + self.gate.e_score_correction_bias = None + + # Load balancing settings. 
+ eplb_config = parallel_config.eplb_config + self.enable_eplb = parallel_config.enable_eplb + + self.n_redundant_experts = eplb_config.num_redundant_experts + self.n_logical_experts = self.n_routed_experts + self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts + self.n_local_physical_experts = self.n_physical_experts // self.ep_size + + self.physical_expert_start = self.ep_rank * self.n_local_physical_experts + self.physical_expert_end = ( + self.physical_expert_start + self.n_local_physical_experts + ) + + if config.n_shared_experts is not None: + intermediate_size = config.moe_intermediate_size * config.n_shared_experts + self.shared_experts = OpenPanguMLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + is_sequence_parallel=self.is_sequence_parallel, + reduce_results=False, + prefix=f"{prefix}.shared_experts", + ) + else: + self.shared_experts = None + + self.experts = SharedFusedMoE( + shared_experts=self.shared_experts, + num_experts=config.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + use_grouped_topk=True, + num_expert_group=1, + topk_group=1, + prefix=f"{prefix}.experts", + scoring_func="sigmoid", + # we do scaling outside, set factor to 1.0 to avoid double mul + routed_scaling_factor=1.0, + e_score_correction_bias=self.gate.e_score_correction_bias, + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts, + is_sequence_parallel=self.is_sequence_parallel, + ) + + def forward( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + num_tokens, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + + if self.is_sequence_parallel: + hidden_states = sequence_parallel_chunk(hidden_states) + + router_logits, _ = self.gate(hidden_states) + + fused_moe_out = self.experts( + hidden_states=hidden_states, router_logits=router_logits + ) + + shared_output, final_hidden_states = fused_moe_out + if self.shared_experts is None: + assert shared_output is None + + if hidden_states.dtype != torch.float16: + final_hidden_states *= self.routed_scaling_factor + elif self.shared_experts is not None: + assert shared_output is not None + shared_output *= 1.0 / self.routed_scaling_factor + + if self.shared_experts is not None: + assert shared_output is not None + final_hidden_states += shared_output + + if self.is_sequence_parallel: + final_hidden_states = tensor_model_parallel_all_gather( + final_hidden_states, 0 + ) + final_hidden_states = final_hidden_states[:num_tokens] + elif self.tp_size > 1: + final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( + final_hidden_states + ) + + return final_hidden_states.view(num_tokens, hidden_dim) + + +class OpenPanguMLAAttention(nn.Module): + def __init__( + self, + config: PretrainedConfig, + hidden_size: int, + num_heads: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + v_head_dim: int, + q_lora_rank: int | None, + kv_lora_rank: int, + rope_theta: float = 10000, + max_position_embeddings: int = 8192, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.num_heads = num_heads + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = 
qk_rope_head_dim + self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim + self.v_head_dim = v_head_dim + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + self.rope_theta = rope_theta + + self.tp_size = get_tensor_model_parallel_world_size() + if num_heads % self.tp_size != 0: + raise ValueError( + f"num_heads {num_heads} is not divisible by tp_size {self.tp_size}." + ) + self.num_local_heads = num_heads // self.tp_size + + self.scaling = self.qk_head_dim**-0.5 + self.max_position_embeddings = max_position_embeddings + + self.prefix = prefix + + if self.q_lora_rank is not None: + self.fused_qkv_a_proj = MergedColumnParallelLinear( + self.hidden_size, + [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.fused_qkv_a_proj", + disable_tp=True, + ) + self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps) + self.q_b_proj = ColumnParallelLinear( + q_lora_rank, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_b_proj", + ) + else: + self.q_proj = ColumnParallelLinear( + self.hidden_size, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_proj", + ) + self.kv_a_proj_with_mqa = ReplicatedLinear( + self.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_a_proj_with_mqa", + ) + + self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps) + self.kv_b_proj = ColumnParallelLinear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_b_proj", + ) + + self.o_proj = RowParallelLinear( + self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + # TODO: remove hard coding + rope_scaling = { + "beta_fast": 32, + "beta_slow": 1, + "factor": 1, + "mscale": 1.0, + "mscale_all_dim": 1.0, + "original_max_position_embeddings": max_position_embeddings, + "type": "yarn", + "rope_type": "deepseek_yarn", + } + self.rotary_emb = get_rope( + qk_rope_head_dim, + rotary_dim=qk_rope_head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=False, + ) + + mla_modules = MLAModules( + kv_a_layernorm=self.kv_a_layernorm, + kv_b_proj=self.kv_b_proj, + rotary_emb=self.rotary_emb, + o_proj=self.o_proj, + fused_qkv_a_proj=self.fused_qkv_a_proj + if self.q_lora_rank is not None + else None, + kv_a_proj_with_mqa=self.kv_a_proj_with_mqa + if self.q_lora_rank is None + else None, + q_a_layernorm=self.q_a_layernorm if self.q_lora_rank is not None else None, + q_b_proj=self.q_b_proj if self.q_lora_rank is not None else None, + q_proj=self.q_proj if self.q_lora_rank is None else None, + indexer=None, + is_sparse=False, + topk_indices_buffer=None, + ) + + self.mla_attn = MultiHeadLatentAttentionWrapper( + self.hidden_size, + self.num_local_heads, + self.scaling, + self.qk_nope_head_dim, + self.qk_rope_head_dim, + self.v_head_dim, + self.q_lora_rank, + self.kv_lora_rank, + mla_modules, + cache_config, + quant_config, + prefix, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + return self.mla_attn(positions, hidden_states) + + +class OpenPanguEmbeddedAttention(nn.Module): + def __init__( + self, + config: PretrainedConfig, + hidden_size: int, + num_heads: int, + 
num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: dict[str, Any] | None = None, + max_position_embeddings: int = 8192, + quant_config: QuantizationConfig | None = None, + bias: bool = False, + bias_o_proj: bool = False, + cache_config: CacheConfig | None = None, + prefix: str = "", + attn_type: str = AttentionType.DECODER, + ) -> None: + super().__init__() + layer_idx = extract_layer_index(prefix) + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + if self.total_num_heads % tp_size != 0: + raise ValueError( + f"total_num_heads {self.total_num_heads} " + f"is not divisible by tp_size {tp_size}." + ) + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads > tp_size and self.total_num_kv_heads % tp_size != 0: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel ranks. + raise ValueError( + "Number of KV heads is greater than TP size, " + f"but total_num_kv_heads {self.total_num_kv_heads} " + f"is not divisible by tp_size {tp_size}." + ) + elif ( + self.total_num_kv_heads < tp_size and tp_size % self.total_num_kv_heads != 0 + ): + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel ranks. + raise ValueError( + f"Number of KV heads is less than TP size, but tp_size {tp_size} " + f"is not divisible by total_num_kv_heads {self.total_num_kv_heads}." + ) + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + head_dim = getattr(config, "head_dim", None) + if head_dim is None: + head_dim = self.hidden_size // self.total_num_heads + self.head_dim = head_dim + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + self.o_proj = RowParallelLinear( + input_size=self.total_num_heads * self.head_dim, + output_size=hidden_size, + bias=bias_o_proj, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self._init_rotary_emb( + config, rope_scaling=rope_scaling, quant_config=quant_config + ) + + if hasattr(config, "interleaved_sliding_window"): + interleaved_sliding_window = config.interleaved_sliding_window + if isinstance(interleaved_sliding_window, int): + sliding_window = interleaved_sliding_window + elif isinstance(interleaved_sliding_window, list): + sw_idx = layer_idx % len(interleaved_sliding_window) + sliding_window = interleaved_sliding_window[sw_idx] + else: + raise ValueError( + f"{type(interleaved_sliding_window)} " + "for interleaved_sliding_window is not supported." 
+ ) + else: + sliding_window = None + + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + per_layer_sliding_window=sliding_window, + attn_type=attn_type, + prefix=f"{prefix}.attn", + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + def _init_rotary_emb( + self, + config: PretrainedConfig, + rope_scaling: dict[str, Any] | None, + quant_config: QuantizationConfig | None, + ) -> None: + is_neox_style = True + is_gguf = quant_config and quant_config.get_name() == "gguf" + if is_gguf and config.model_type == "PanguEmbedded": + is_neox_style = False + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=self.rope_theta, + rope_scaling=rope_scaling, + is_neox_style=is_neox_style, + ) + + +class OpenPanguDecoderLayer(nn.Module): + def __init__( + self, + config: PretrainedConfig, + prefix: str, + vllm_config: VllmConfig, + ) -> None: + super().__init__() + + if config is None: + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + parallel_config = vllm_config.parallel_config + + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) + + layer_idx = int(prefix.split(sep=".")[-1]) + self.layer_idx = layer_idx + + self.use_mla = ( + hasattr(config, "qk_nope_head_dim") + and hasattr(config, "qk_rope_head_dim") + and hasattr(config, "v_head_dim") + and hasattr(config, "kv_lora_rank") + ) + if self.use_mla: + self.self_attn = OpenPanguMLAAttention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + qk_nope_head_dim=config.qk_nope_head_dim, + qk_rope_head_dim=config.qk_rope_head_dim, + v_head_dim=config.v_head_dim, + q_lora_rank=( + config.q_lora_rank if hasattr(config, "q_lora_rank") else None + ), + kv_lora_rank=config.kv_lora_rank, + rope_theta=rope_theta, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + else: + attention_bias = getattr(config, "attention_bias", False) or getattr( + config, "bias", False + ) + bias_o_proj = attention_bias + if hasattr(config, "qkv_bias"): + attention_bias = config.qkv_bias + # By default, PanguEmbedded uses causal attention + # as it is a decoder-only model. 
+ # You can override the HF config with `is_causal=False` to enable + # bidirectional attention, which is used in some embedding models + if getattr(config, "is_causal", True): + attn_type = AttentionType.DECODER + else: + attn_type = AttentionType.ENCODER_ONLY + self.self_attn = OpenPanguEmbeddedAttention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=getattr( + config, "num_key_value_heads", config.num_attention_heads + ), + rope_theta=rope_theta, + rope_scaling=getattr(config, "rope_scaling", None), + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=attention_bias, + bias_o_proj=bias_o_proj, + cache_config=cache_config, + prefix=f"{prefix}.self_attn", + attn_type=attn_type, + ) + + if ( + getattr(config, "n_routed_experts", None) is not None + and layer_idx >= config.first_k_dense_replace + ): + self.mlp = OpenPanguMoE( + config=config, + parallel_config=parallel_config, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + else: + self.mlp = OpenPanguMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + bias=getattr(config, "mlp_bias", False), + prefix=f"{prefix}.mlp", + ) + self.routed_scaling_factor = getattr(config, "routed_scaling_factor", None) + self.num_hidden_layers = config.num_hidden_layers + self.first_k_dense_replace = getattr( + config, "first_k_dense_replace", self.num_hidden_layers + ) + + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.tp_group = get_tp_group().device_group + self.sandwich_norm = getattr(config, "sandwich_norm", False) + if self.sandwich_norm: + self.pre_mlp_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_mlp_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: torch.Tensor | None, + ) -> torch.Tensor: + if residual is None: + residual = hidden_states.clone() + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + if ( + self.routed_scaling_factor is not None + and hidden_states.dtype == torch.float16 + ): + # Fix FP16 overflow + # We scale both hidden_states and residual before + # rmsnorm, and rmsnorm result would not affect by scale. + hidden_states *= 1.0 / self.routed_scaling_factor + if self.layer_idx == 0: + # The residual is shared by all layers, we only scale it on + # first layer. 
+ residual *= 1.0 / self.routed_scaling_factor + + if self.sandwich_norm: + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states, residual = self.pre_mlp_layernorm(hidden_states, residual) + else: + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual + ) + + # Fully Connected + hidden_states = self.mlp(hidden_states) + + if ( + self.routed_scaling_factor is not None + and isinstance(self.mlp, OpenPanguMLP) + and hidden_states.dtype == torch.float16 + ): + hidden_states *= 1.0 / self.routed_scaling_factor + + if self.sandwich_norm: + hidden_states = self.post_mlp_layernorm(hidden_states) + + return hidden_states, residual + + +@support_torch_compile +class OpenPanguModel(nn.Module): + fall_back_to_pt_during_load = False + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + eplb_config = vllm_config.parallel_config.eplb_config + self.config = config + self.num_redundant_experts = eplb_config.num_redundant_experts + + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + if get_pp_group().is_first_rank or ( + config.tie_word_embeddings and get_pp_group().is_last_rank + ): + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.embed_tokens", + ) + else: + self.embed_tokens = PPMissingLayer() + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: OpenPanguDecoderLayer(config, prefix, vllm_config), + prefix=f"{prefix}.layers", + ) + + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer(positions, hidden_states, residual) + + if not get_pp_group().is_last_rank: + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def load_attn_mlp_weight( + self, + attn_mlp_replace_mapping: list[tuple[str, str, int]], + params_dict: dict[str, Any], + weight_name: str, + loaded_weight: torch.Tensor, + loaded_params: set[str], + ) -> bool: + for param_name, origin_name, shard_id in attn_mlp_replace_mapping: + if origin_name not in weight_name or ( + ("mlp.experts." 
in weight_name) and weight_name not in params_dict + ): + continue + weight_name_mapped = weight_name.replace(origin_name, param_name) + if ( + param_name == "fused_qkv_a_proj" + and weight_name_mapped not in params_dict + ): + continue + else: + weight_name = weight_name_mapped + if weight_name.endswith(".bias") and weight_name not in params_dict: + continue + if is_pp_missing_parameter(weight_name, self): + continue + + param = params_dict[weight_name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + loaded_params.add(weight_name) + return True + return False + + def load_expert_weight( + self, + expert_merge_mapping: list[tuple[str, str, int, str]], + params_dict: dict[str, Any], + weight_name: str, + loaded_weight: torch.Tensor, + loaded_params: set[str], + flag_dict: dict[str, bool], + ) -> bool: + for mapping in expert_merge_mapping: + param_name, origin_name, expert_id, shard_id = mapping + if origin_name not in weight_name: + continue + flag_dict["is_expert_weight"] = True + weight_name_mapped = weight_name.replace(origin_name, param_name) + if is_pp_missing_parameter(weight_name_mapped, self): + continue + param = params_dict[weight_name_mapped] + weight_loader = typing.cast(Callable[..., bool], param.weight_loader) + success = weight_loader( + param, + loaded_weight, + weight_name_mapped, + shard_id=shard_id, + expert_id=expert_id, + return_success=True, + ) + if success: + weight_name = weight_name_mapped + loaded_params.add(weight_name_mapped) + return True + return False + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + attn_mlp_replace_mapping = [ + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".fused_qkv_a_proj", ".q_a_proj", 0), + (".fused_qkv_a_proj", ".kv_a_proj_with_mqa", 1), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + has_experts = hasattr(self.config, "n_routed_experts") + if has_experts: + expert_merge_mapping = SharedFusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.n_routed_experts, + num_redundant_experts=self.num_redundant_experts, + ) + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if self.config.tie_word_embeddings and "lm_head.weight" in name: + continue + + if ( + "layers" in name + and hasattr(self.config, "num_nextn_predict_layers") + and (self.config.num_nextn_predict_layers > 0) + ): + layer_idx = int(name.split("layers.")[-1].split(".")[0]) + mtp_idx = layer_idx - self.config.num_hidden_layers + if mtp_idx >= 0 and mtp_idx < self.config.num_nextn_predict_layers: + continue # skip spec decode layers for main model + + flag_dict = {"is_expert_weight": False} + if ( + self.load_attn_mlp_weight( + attn_mlp_replace_mapping, + params_dict, + name, + loaded_weight, + loaded_params, + ) + or has_experts + and self.load_expert_weight( + expert_merge_mapping, + params_dict, + name, + loaded_weight, + loaded_params, + flag_dict, + ) + ): + continue + else: + if flag_dict["is_expert_weight"]: + continue + if name.endswith(".bias") and name not in params_dict: + continue + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, 
"weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class OpenPanguModelBase(nn.Module, SupportsPP, SupportsLoRA): + packed_modules_mapping = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + + self.fuse_qkv_a_proj = ( + hasattr(config, "q_lora_rank") and config.q_lora_rank is not None + ) + if self.fuse_qkv_a_proj: + self.packed_modules_mapping["fused_qkv_a_proj"] = [ + "q_a_proj", + "kv_a_proj_with_mqa", + ] + + self.model = OpenPanguModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) + if get_pp_group().is_last_rank: + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) + if config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + else: + self.lm_head = PPMissingLayer() + self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors + ) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor | None: + logits = self.logits_processor(self.lm_head, hidden_states) + return logits + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) + + +class OpenPanguMoEModel(OpenPanguModelBase, MixtureOfExperts): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + config = vllm_config.model_config.hf_config + + # Set MoE hyperparameters + self.expert_weights = [] + self.num_moe_layers = config.num_hidden_layers - config.first_k_dense_replace + self.num_expert_groups = 1 + + self.moe_layers: list[SharedFusedMoE] = [] + example_moe = None + for layer in self.model.layers: + if isinstance(layer, PPMissingLayer): + continue + + assert isinstance(layer, OpenPanguDecoderLayer) + if isinstance(layer.mlp, OpenPanguMoE): + # Pick last one layer since the first ones may be dense layers. 
+ example_moe = layer.mlp + self.moe_layers.append(layer.mlp.experts) + + if example_moe is None: + raise RuntimeError("No MOE layer found in model.layers.") + + self.num_logical_experts = example_moe.n_logical_experts + self.num_physical_experts = example_moe.n_physical_experts + self.num_local_physical_experts = example_moe.n_local_physical_experts + self.n_routed_experts = example_moe.n_routed_experts + self.n_shared_experts = example_moe.n_shared_experts + self.num_redundant_experts = example_moe.n_redundant_experts + + def set_eplb_state( + self, + expert_load_view: torch.Tensor, + logical_to_physical_map: torch.Tensor, + logical_replica_count: torch.Tensor, + ) -> None: + for layer_idx, layer in enumerate(self.moe_layers): + # Register the expert weights. + self.expert_weights.append(layer.get_expert_weights()) + layer.set_eplb_state( + moe_layer_idx=layer_idx, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, + ) + + def update_physical_experts_metadata( + self, + num_physical_experts: int, + num_local_physical_experts: int, + ) -> None: + assert self.num_local_physical_experts == num_local_physical_experts + self.num_physical_experts = num_physical_experts + self.num_local_physical_experts = num_local_physical_experts + self.num_redundant_experts = num_physical_experts - self.num_logical_experts + for layer in self.model.layers: + if isinstance(layer.mlp, OpenPanguMoE): + moe = layer.mlp + moe.n_local_physical_experts = num_local_physical_experts + moe.n_physical_experts = num_physical_experts + moe.n_redundant_experts = self.num_redundant_experts + moe.experts.update_expert_map() + + +class OpenPanguEmbeddedModel(OpenPanguModelBase): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + + +class PanguEmbeddedForCausalLM(OpenPanguEmbeddedModel): + pass + + +class PanguUltraMoEForCausalLM(OpenPanguMoEModel): + pass diff --git a/vllm/model_executor/models/openpangu_mtp.py b/vllm/model_executor/models/openpangu_mtp.py new file mode 100644 index 000000000000..f4049f2d3970 --- /dev/null +++ b/vllm/model_executor/models/openpangu_mtp.py @@ -0,0 +1,265 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Adapted from +# https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/models/deepseek_mtp.py +from collections.abc import Iterable + +import torch +import torch.nn as nn + +from vllm.compilation.decorators import support_torch_compile +from vllm.config import VllmConfig + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding, +) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.deepseek_mtp import ( + DeepSeekMultiTokenPredictor, + DeepSeekMultiTokenPredictorLayer, + SharedHead, +) +from vllm.model_executor.models.utils import maybe_prefix +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsPP +from .openpangu import OpenPanguDecoderLayer + + +class OpenPanguMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer): + def __init__(self, vllm_config: VllmConfig, prefix: str) -> None: + nn.Module.__init__(self) + + config = vllm_config.speculative_config.draft_model_config.hf_config + self.config = config + quant_config = vllm_config.quant_config + + self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.eh_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False) + self.shared_head = SharedHead( + config=config, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "shared_head"), + ) + self.mtp_block = OpenPanguDecoderLayer(config, prefix, vllm_config) + + +class OpenPanguMultiTokenPredictor(DeepSeekMultiTokenPredictor): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + nn.Module.__init__(self) + config = vllm_config.model_config.hf_config + self.mtp_start_layer_idx = config.num_hidden_layers + self.num_mtp_layers = config.num_nextn_predict_layers + # to map the exact layer index from weights + self.layers = torch.nn.ModuleDict( + { + str(idx): OpenPanguMultiTokenPredictorLayer( + vllm_config, f"{prefix}.layers.{idx}" + ) + for idx in range( + self.mtp_start_layer_idx, + self.mtp_start_layer_idx + self.num_mtp_layers, + ) + } + ) + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.logits_processor = LogitsProcessor(config.vocab_size) + + +@support_torch_compile +class OpenPanguMTP(nn.Module, SupportsPP): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + self.config = vllm_config.model_config.hf_config + self.model = OpenPanguMultiTokenPredictor( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + spec_step_idx: int = 0, + ) -> torch.Tensor: + hidden_states = self.model( + input_ids, + positions, + hidden_states, + inputs_embeds, + spec_step_idx, + ) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + spec_step_idx: int = 0, + ) -> torch.Tensor | 
None: + return self.model.compute_logits(hidden_states, spec_step_idx) + + def get_spec_layer(self, name): + if ( + "layers" in name + and hasattr(self.config, "num_nextn_predict_layers") + and self.config.num_nextn_predict_layers > 0 + ): + layer_idx = int(name.split("layers.")[-1].split(".")[0]) + mtp_idx = layer_idx - self.config.num_hidden_layers + if mtp_idx >= 0 and mtp_idx < self.config.num_nextn_predict_layers: + return layer_idx + return None + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ("fused_qkv_a_proj", "q_a_proj", 0), + ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1), + ] + + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.n_routed_experts, + ) + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + spec_layer = self.get_spec_layer(name) + if spec_layer is None: + continue + + name = self._rewrite_spec_layer_name(spec_layer, name) + for param_name, weight_name, shard_id in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if ("mlp.experts." in name) and name not in params_dict: + continue + name_mapped = name.replace(weight_name, param_name) + + # QKV fusion is optional, fall back to normal + # weight loading if it's not enabled + if ( + param_name == "fused_qkv_a_proj" + ) and name_mapped not in params_dict: + continue + else: + name = name_mapped + + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader( + param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id, + ) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if ( + spec_layer != self.model.mtp_start_layer_idx + and ".layers" not in name + ): + continue + + param = params_dict[name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str: + """ + Rewrite the weight name to match the format of the original model. + Add .mtp_block for modules in transformer layer block for spec layer + and rename shared layer weights to be top level. 
+ """ + spec_layer_weight_names = [ + "embed_tokens", + "enorm", + "hnorm", + "eh_proj", + "shared_head", + ] + shared_weight_names = ["embed_tokens"] + spec_layer_weight = False + shared_weight = False + for weight_name in spec_layer_weight_names: + if weight_name in name: + spec_layer_weight = True + if weight_name in shared_weight_names: + shared_weight = True + break + if not spec_layer_weight: + # treat rest weights as weights for transformer layer block + name = name.replace( + f"model.layers.{spec_layer}.", f"model.layers.{spec_layer}.mtp_block." + ) + elif shared_weight: + # treat shared weights as top level weights + name = name.replace(f"model.layers.{spec_layer}.", "model.") + return name diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index d9299697fcb0..dddbc88069ef 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -149,6 +149,8 @@ "OPTForCausalLM": ("opt", "OPTForCausalLM"), "OrionForCausalLM": ("orion", "OrionForCausalLM"), "OuroForCausalLM": ("ouro", "OuroForCausalLM"), + "PanguEmbeddedForCausalLM": ("openpangu", "PanguEmbeddedForCausalLM"), + "PanguUltraMoEForCausalLM": ("openpangu", "PanguUltraMoEForCausalLM"), "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"), "PhiForCausalLM": ("phi", "PhiForCausalLM"), "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"), @@ -406,6 +408,7 @@ "LongCatFlashMTPModel": ("longcat_flash_mtp", "LongCatFlashMTP"), "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"), "MedusaModel": ("medusa", "Medusa"), + "OpenPanguMTPModel": ("openpangu_mtp", "OpenPanguMTP"), "Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"), # Temporarily disabled. # # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1. diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 1e18eea2330a..75a4140fd655 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -316,7 +316,12 @@ def propose( positions = target_positions[:, last_token_indices] else: positions = target_positions[last_token_indices] - if self.method in ("deepseek_mtp", "ernie_mtp", "longcat_flash_mtp"): + if self.method in ( + "deepseek_mtp", + "ernie_mtp", + "longcat_flash_mtp", + "pangu_ultra_moe_mtp", + ): hidden_states = self.hidden_states[last_token_indices] else: hidden_states = hidden_states[last_token_indices]
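Usage note (illustrative, not part of the diff): the sketch below shows how the newly registered
architectures could be exercised with vLLM's offline API once this change is in place. The model IDs
and `trust_remote_code=True` mirror the tests/models/registry.py entries above; the speculative-decoding
settings (method name "pangu_ultra_moe_mtp", num_speculative_tokens=1) and tensor_parallel_size=8 are
assumptions for illustration, not values taken from this diff.

    from vllm import LLM, SamplingParams

    # Dense model, served by PanguEmbeddedForCausalLM.
    llm = LLM(
        model="FreedomIntelligence/openPangu-Embedded-7B-V1.1",
        trust_remote_code=True,
    )
    outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
    print(outputs[0].outputs[0].text)

    # MoE model, served by PanguUltraMoEForCausalLM, with the MTP layers
    # (OpenPanguMTPModel) acting as the speculative-decoding drafter.
    # Parallelism and speculative settings here are illustrative assumptions;
    # the 718B model requires a multi-GPU deployment in practice.
    llm_moe = LLM(
        model="FreedomIntelligence/openPangu-Ultra-MoE-718B-V1.1",
        trust_remote_code=True,
        tensor_parallel_size=8,
        speculative_config={
            "method": "pangu_ultra_moe_mtp",
            "num_speculative_tokens": 1,
        },
    )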