Commit 5afb5b8

Add native OpenPangu Embedded backend to vLLM
Signed-off-by: YoussefEssDS <[email protected]>
Parent: 6c317a6

File tree: 3 files changed, +379 -0 lines


tests/models/registry.py

Lines changed: 3 additions & 0 deletions
@@ -370,6 +370,9 @@ def check_available_online(
         "OrionStarAI/Orion-14B-Chat", trust_remote_code=True
     ),
     "OuroForCausalLM": _HfExamplesInfo("ByteDance/Ouro-1.4B", trust_remote_code=True),
+    "PanguEmbeddedForCausalLM": _HfExamplesInfo(
+        "FreedomIntelligence/openPangu-Embedded-7B"
+    ),
     "PersimmonForCausalLM": _HfExamplesInfo("adept/persimmon-8b-chat"),
     "PhiForCausalLM": _HfExamplesInfo("microsoft/phi-2"),
     "Phi3ForCausalLM": _HfExamplesInfo("microsoft/Phi-3-mini-4k-instruct"),
vllm/model_executor/models/pangu.py

Lines changed: 375 additions & 0 deletions
@@ -0,0 +1,375 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""Native OpenPangu Embedded model implementation."""

from collections.abc import Iterable
from typing import Any

import torch
from torch import nn
from transformers import PretrainedConfig

from vllm.attention import Attention, AttentionType
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (
    get_pp_group,
    get_tensor_model_parallel_world_size,
)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
from vllm.model_executor.models.utils import (
    AutoWeightsLoader,
    PPMissingLayer,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
from vllm.sequence import IntermediateTensors


class PanguMLP(nn.Module):
    """Feed-forward network for PanguEmbedded layers."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        *,
        bias: bool,
        quant_config: QuantizationConfig | None,
        prefix: str,
    ) -> None:
        super().__init__()
        self.gate_proj = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_proj",
        )
        self.up_proj = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        self.act_fn = get_act_fn(hidden_act)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        gate, _ = self.gate_proj(hidden_states)
        up, _ = self.up_proj(hidden_states)
        hidden_states = self.act_fn(gate) * up
        hidden_states, _ = self.down_proj(hidden_states)
        return hidden_states


class PanguAttention(nn.Module):
    """Self-attention block with GQA."""

    def __init__(
        self,
        config: PretrainedConfig,
        *,
        cache_config: CacheConfig | None,
        quant_config: QuantizationConfig | None,
        prefix: str,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.total_num_heads = config.num_attention_heads
        self.total_num_kv_heads = getattr(
            config, "num_key_value_heads", config.num_attention_heads
        )
        tp_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        if self.total_num_kv_heads >= tp_size:
            assert self.total_num_kv_heads % tp_size == 0
        else:
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = getattr(
            config,
            "head_dim",
            self.hidden_size // self.total_num_heads,
        )
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        rope_theta = getattr(config, "rope_theta", 10000.0)
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is not None:
            rope_scaling = dict(rope_scaling)
            original_max_position = getattr(
                config, "original_max_position_embeddings", None
            )
            if original_max_position is not None:
                rope_scaling.setdefault(
                    "original_max_position_embeddings", original_max_position
                )
        max_position_embeddings = getattr(
            config, "max_position_embeddings", 2048
        )

        bias = getattr(config, "bias", False)
        self.q_proj = ColumnParallelLinear(
            self.hidden_size,
            self.total_num_heads * self.head_dim,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.q_proj",
        )
        self.k_proj = ColumnParallelLinear(
            self.hidden_size,
            self.total_num_kv_heads * self.head_dim,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.k_proj",
        )
        self.v_proj = ColumnParallelLinear(
            self.hidden_size,
            self.total_num_kv_heads * self.head_dim,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.v_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            self.hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=True,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            attn_type=AttentionType.DECODER,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        q, _ = self.q_proj(hidden_states)
        k, _ = self.k_proj(hidden_states)
        v, _ = self.v_proj(hidden_states)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class PanguDecoderLayer(nn.Module):
    """Single decoder block for PanguEmbedded."""

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        config: PretrainedConfig | None = None,
    ) -> None:
        super().__init__()
        config = config or vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = self.get_quant_config(vllm_config)

        self.hidden_size = config.hidden_size
        self.self_attn = PanguAttention(
            config,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = PanguMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            bias=getattr(config, "bias", False),
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )
        self.input_layernorm = RMSNorm(
            config.hidden_size, eps=getattr(config, "rms_norm_eps", 1e-5)
        )
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=getattr(config, "rms_norm_eps", 1e-5)
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)

        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual
        )
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual

    def get_quant_config(self, vllm_config: VllmConfig) -> QuantizationConfig | None:
        return vllm_config.quant_config


class PanguModel(nn.Module):
    """Backbone model for OpenPangu Embedded."""

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[nn.Module] = PanguDecoderLayer,
    ) -> None:
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.quant_config = quant_config
        lora_vocab = (
            (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
            if lora_config
            else 0
        )
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size
        if get_pp_group().is_first_rank or (
            getattr(config, "tie_word_embeddings", True)
            and get_pp_group().is_last_rank
        ):
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                quant_config=quant_config,
            )
        else:
            self.embed_tokens = PPMissingLayer()

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: layer_type(vllm_config=vllm_config, prefix=prefix),
            prefix=f"{prefix}.layers",
        )

        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(
                config.hidden_size, eps=getattr(config, "rms_norm_eps", 1e-5)
            )
        else:
            self.norm = PPMissingLayer()

        self.aux_hidden_state_layers: tuple[int, ...] = ()
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        aux_hidden_states: list[torch.Tensor] = []
        for idx, layer in enumerate(self.layers[self.start_layer : self.end_layer]):
            if idx in self.aux_hidden_state_layers:
                aux_hidden_states.append(hidden_states + residual)
            hidden_states, residual = layer(positions, hidden_states, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )

        hidden_states, _ = self.norm(hidden_states, residual)

        if aux_hidden_states:
            return hidden_states, aux_hidden_states
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)


class PanguForCausalLM(LlamaForCausalLM, SupportsLoRA, SupportsPP):
    """Causal LM head for OpenPangu Embedded."""

    packed_modules_mapping: dict[str, list[str]] = {}
    mistral_mapping: dict[str, str] = {}

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(
            vllm_config=vllm_config,
            prefix=prefix,
            layer_type=PanguDecoderLayer,
        )

    def _init_model(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[nn.Module] = PanguDecoderLayer,
    ):
        return PanguModel(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type)
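
Once the backbone above is wired into the registry (next file), the model can be driven through vLLM's standard offline inference API. The snippet below is a minimal sketch and not part of this commit: it assumes the FreedomIntelligence/openPangu-Embedded-7B checkpoint referenced in the test registry entry and the stock LLM / SamplingParams interface.

from vllm import LLM, SamplingParams

# Hypothetical smoke test: load the checkpoint through the native
# PanguEmbeddedForCausalLM backend added in this commit.
llm = LLM(model="FreedomIntelligence/openPangu-Embedded-7B")
params = SamplingParams(temperature=0.7, max_tokens=64)

# Generate one completion; each RequestOutput carries the decoded text.
outputs = llm.generate(["Briefly introduce the openPangu-Embedded model."], params)
print(outputs[0].outputs[0].text)

The same checkpoint should also be servable through the OpenAI-compatible entrypoint, e.g. "vllm serve FreedomIntelligence/openPangu-Embedded-7B".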

vllm/model_executor/models/registry.py

Lines changed: 1 addition & 0 deletions
@@ -149,6 +149,7 @@
     "OPTForCausalLM": ("opt", "OPTForCausalLM"),
     "OrionForCausalLM": ("orion", "OrionForCausalLM"),
     "OuroForCausalLM": ("ouro", "OuroForCausalLM"),
+    "PanguEmbeddedForCausalLM": ("pangu", "PanguForCausalLM"),
     "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
     "PhiForCausalLM": ("phi", "PhiForCausalLM"),
     "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
