1 change: 1 addition & 0 deletions docs/models/supported_models.md
@@ -404,6 +404,7 @@ th {
| `OPTForCausalLM` | OPT, OPT-IML | `facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc. | ✅︎ | ✅︎ |
| `OrionForCausalLM` | Orion | `OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc. | | ✅︎ |
| `OuroForCausalLM` | ouro | `ByteDance/Ouro-1.4B`, `ByteDance/Ouro-2.6B`, etc. | ✅︎ | |
| `PanguEmbeddedForCausalLM` | openPangu Embedded | `FreedomIntelligence/openPangu-Embedded-7B` | ✅︎ | ✅︎ |
| `PhiForCausalLM` | Phi | `microsoft/phi-1_5`, `microsoft/phi-2`, etc. | ✅︎ | ✅︎ |
| `Phi3ForCausalLM` | Phi-4, Phi-3 | `microsoft/Phi-4-mini-instruct`, `microsoft/Phi-4`, `microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, `microsoft/Phi-3-medium-128k-instruct`, etc. | ✅︎ | ✅︎ |
| `PhiMoEForCausalLM` | Phi-3.5-MoE | `microsoft/Phi-3.5-MoE-instruct`, etc. | ✅︎ | ✅︎ |
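For context, a minimal offline-inference sketch of loading the checkpoint listed in the new table row with vLLM's `LLM` API. This is illustrative only and not part of the diff; the `trust_remote_code=True` flag is an assumption carried over from the test registry entry added below, and the prompt and sampling settings are arbitrary.

```python
# Illustrative sketch only, not part of this diff: load the newly documented
# openPangu Embedded checkpoint through vLLM's offline LLM API.
from vllm import LLM, SamplingParams

llm = LLM(
    model="FreedomIntelligence/openPangu-Embedded-7B",
    trust_remote_code=True,  # assumption: mirrors the test registry entry in this PR
)
sampling = SamplingParams(temperature=0.7, max_tokens=64)
outputs = llm.generate(["Briefly explain grouped-query attention."], sampling)
print(outputs[0].outputs[0].text)
```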
3 changes: 3 additions & 0 deletions tests/models/registry.py
@@ -370,6 +370,9 @@ def check_available_online(
"OrionStarAI/Orion-14B-Chat", trust_remote_code=True
),
"OuroForCausalLM": _HfExamplesInfo("ByteDance/Ouro-1.4B", trust_remote_code=True),
"PanguEmbeddedForCausalLM": _HfExamplesInfo(
"FreedomIntelligence/openPangu-Embedded-7B", trust_remote_code=True
),
"PersimmonForCausalLM": _HfExamplesInfo("adept/persimmon-8b-chat"),
"PhiForCausalLM": _HfExamplesInfo("microsoft/phi-2"),
"Phi3ForCausalLM": _HfExamplesInfo("microsoft/Phi-3-mini-4k-instruct"),
370 changes: 370 additions & 0 deletions vllm/model_executor/models/pangu.py
@@ -0,0 +1,370 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""Native OpenPangu Embedded model implementation."""

from collections.abc import Iterable

import torch
from torch import nn
from transformers import PretrainedConfig

from vllm.attention import Attention, AttentionType
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (
    get_pp_group,
    get_tensor_model_parallel_world_size,
)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.models.utils import (
    AutoWeightsLoader,
    PPMissingLayer,
    make_empty_intermediate_tensors_factory,
    make_layers,
)
from vllm.sequence import IntermediateTensors


class PanguMLP(nn.Module):
    """Feed-forward network for PanguEmbedded layers."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        *,
        bias: bool,
        quant_config: QuantizationConfig | None,
        prefix: str,
    ) -> None:
        super().__init__()
        self.gate_proj = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_proj",
        )
        self.up_proj = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        self.act_fn = get_act_fn(hidden_act)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        gate, _ = self.gate_proj(hidden_states)
        up, _ = self.up_proj(hidden_states)
        hidden_states = self.act_fn(gate) * up
        hidden_states, _ = self.down_proj(hidden_states)
        return hidden_states


class PanguAttention(nn.Module):
    """Self-attention block with GQA."""

    def __init__(
        self,
        config: PretrainedConfig,
        *,
        cache_config: CacheConfig | None,
        quant_config: QuantizationConfig | None,
        prefix: str,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.total_num_heads = config.num_attention_heads
        self.total_num_kv_heads = getattr(
            config, "num_key_value_heads", config.num_attention_heads
        )
        tp_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        if self.total_num_kv_heads >= tp_size:
            assert self.total_num_kv_heads % tp_size == 0
        else:
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = getattr(
            config,
            "head_dim",
            self.hidden_size // self.total_num_heads,
        )
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        rope_theta = getattr(config, "rope_theta", 10000.0)
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is not None:
            rope_scaling = dict(rope_scaling)
            original_max_position = getattr(
                config, "original_max_position_embeddings", None
            )
            if original_max_position is not None:
                rope_scaling.setdefault(
                    "original_max_position_embeddings", original_max_position
                )
        max_position_embeddings = getattr(config, "max_position_embeddings", 2048)

        bias = getattr(config, "bias", False)
        self.q_proj = ColumnParallelLinear(
            self.hidden_size,
            self.total_num_heads * self.head_dim,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.q_proj",
        )
        self.k_proj = ColumnParallelLinear(
            self.hidden_size,
            self.total_num_kv_heads * self.head_dim,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.k_proj",
        )
        self.v_proj = ColumnParallelLinear(
            self.hidden_size,
            self.total_num_kv_heads * self.head_dim,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.v_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            self.hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=True,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            attn_type=AttentionType.DECODER,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        q, _ = self.q_proj(hidden_states)
        k, _ = self.k_proj(hidden_states)
        v, _ = self.v_proj(hidden_states)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class PanguDecoderLayer(nn.Module):
    """Single decoder block for PanguEmbedded."""

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        config: PretrainedConfig | None = None,
    ) -> None:
        super().__init__()
        config = config or vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = self.get_quant_config(vllm_config)

        self.hidden_size = config.hidden_size
        self.self_attn = PanguAttention(
            config,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = PanguMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            bias=getattr(config, "bias", False),
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )
        self.input_layernorm = RMSNorm(
            config.hidden_size, eps=getattr(config, "rms_norm_eps", 1e-5)
        )
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=getattr(config, "rms_norm_eps", 1e-5)
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)

        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual

    def get_quant_config(self, vllm_config: VllmConfig) -> QuantizationConfig | None:
        return vllm_config.quant_config


class PanguModel(nn.Module):
    """Backbone model for OpenPangu Embedded."""

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[nn.Module] = PanguDecoderLayer,
    ) -> None:
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.quant_config = quant_config
        lora_vocab = (
            (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
            if lora_config
            else 0
        )
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size
        if get_pp_group().is_first_rank or (
            getattr(config, "tie_word_embeddings", True) and get_pp_group().is_last_rank
        ):
            self.embed_tokens = VocabParallelEmbedding(
Review comment: The missing line `prefix=f"{prefix}.embed_tokens",` may disable quantization support. As @jeejeelee pointed out, this PR should be merged into #27521.

                self.vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                quant_config=quant_config,
            )
        else:
            self.embed_tokens = PPMissingLayer()

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: layer_type(vllm_config=vllm_config, prefix=prefix),
            prefix=f"{prefix}.layers",
        )

        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(
                config.hidden_size, eps=getattr(config, "rms_norm_eps", 1e-5)
            )
        else:
            self.norm = PPMissingLayer()

        self.aux_hidden_state_layers: tuple[int, ...] = ()
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        aux_hidden_states: list[torch.Tensor] = []
        for idx, layer in enumerate(self.layers[self.start_layer : self.end_layer]):
            if residual is None:
                aux_hidden_states.append(hidden_states)
            else:
                aux_hidden_states.append(hidden_states + residual)
            hidden_states, residual = layer(positions, hidden_states, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )

        hidden_states, _ = self.norm(hidden_states, residual)

        if aux_hidden_states:
            return hidden_states, aux_hidden_states
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)


class PanguForCausalLM(LlamaForCausalLM, SupportsLoRA, SupportsPP):
    """Causal LM head for OpenPangu Embedded."""

    packed_modules_mapping: dict[str, list[str]] = {}
    mistral_mapping: dict[str, str] = {}

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(
            vllm_config=vllm_config,
            prefix=prefix,
            layer_type=PanguDecoderLayer,
        )

    def _init_model(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[nn.Module] = PanguDecoderLayer,
    ):
        return PanguModel(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type)
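The assertions in `PanguAttention.__init__` encode how query and KV heads are partitioned across tensor-parallel ranks. The snippet below just walks through that arithmetic outside of vLLM, with assumed head counts and TP size chosen for illustration, to make the constraint concrete.

```python
# Standalone illustration of PanguAttention's head-partitioning constraints.
# The head counts and TP size below are assumptions chosen for the example.
total_num_heads = 32     # assumed config.num_attention_heads
total_num_kv_heads = 8   # assumed config.num_key_value_heads
tp_size = 4

assert total_num_heads % tp_size == 0          # query heads must split evenly
num_heads = total_num_heads // tp_size         # -> 8 query heads per rank
if total_num_kv_heads >= tp_size:
    assert total_num_kv_heads % tp_size == 0   # KV heads split evenly across ranks
else:
    assert tp_size % total_num_kv_heads == 0   # otherwise each KV head is shared/replicated
num_kv_heads = max(1, total_num_kv_heads // tp_size)  # -> 2 KV heads per rank
print(num_heads, num_kv_heads)
```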
1 change: 1 addition & 0 deletions vllm/model_executor/models/registry.py
@@ -149,6 +149,7 @@
"OPTForCausalLM": ("opt", "OPTForCausalLM"),
"OrionForCausalLM": ("orion", "OrionForCausalLM"),
"OuroForCausalLM": ("ouro", "OuroForCausalLM"),
"PanguEmbeddedForCausalLM": ("pangu", "PanguForCausalLM"),
"PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
"PhiForCausalLM": ("phi", "PhiForCausalLM"),
"Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
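A quick way to sanity-check the new mapping once this branch is installed. This is an illustrative snippet, not part of the PR, and assumes the `ModelRegistry.get_supported_archs()` helper from vLLM's public API.

```python
# Illustrative check, not part of this diff: confirm the architecture name
# added to the registry above is now discoverable through ModelRegistry.
from vllm import ModelRegistry

assert "PanguEmbeddedForCausalLM" in ModelRegistry.get_supported_archs()
print("PanguEmbeddedForCausalLM is registered")
```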