Commit 80fbcca

Fix Pangu aux-state indexing and apply ruff format
Signed-off-by: YoussefEssDS <[email protected]>
1 parent 5afb5b8

2 files changed (+6, -13)

tests/models/registry.py (1 addition, 1 deletion)

@@ -371,7 +371,7 @@ def check_available_online(
     ),
     "OuroForCausalLM": _HfExamplesInfo("ByteDance/Ouro-1.4B", trust_remote_code=True),
     "PanguEmbeddedForCausalLM": _HfExamplesInfo(
-        "FreedomIntelligence/openPangu-Embedded-7B"
+        "FreedomIntelligence/openPangu-Embedded-7B", trust_remote_code=True
     ),
     "PersimmonForCausalLM": _HfExamplesInfo("adept/persimmon-8b-chat"),
     "PhiForCausalLM": _HfExamplesInfo("microsoft/phi-2"),

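Context, not part of the diff: trust_remote_code=True lets transformers import the custom configuration/modeling code that the openPangu Hub repository presumably ships, matching the neighbouring Ouro entry. A minimal illustrative sketch of what the flag gates; this is not the vLLM test harness itself.

```python
# Illustrative sketch only, not vLLM test code. Assumes the Hub repo defines
# remote (custom) configuration/modeling classes, which trust_remote_code gates.
from transformers import AutoConfig

config = AutoConfig.from_pretrained(
    "FreedomIntelligence/openPangu-Embedded-7B",
    trust_remote_code=True,  # allow the repo's own config class to be imported
)
print(type(config).__name__)
```
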
vllm/model_executor/models/pangu.py (5 additions, 12 deletions)

@@ -4,7 +4,6 @@
 """Native OpenPangu Embedded model implementation."""

 from collections.abc import Iterable
-from typing import Any

 import torch
 from torch import nn
@@ -22,14 +21,13 @@
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
-from vllm.model_executor.models.llama import LlamaForCausalLM
 from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
+from vllm.model_executor.models.llama import LlamaForCausalLM
 from vllm.model_executor.models.utils import (
     AutoWeightsLoader,
     PPMissingLayer,
     make_empty_intermediate_tensors_factory,
     make_layers,
-    maybe_prefix,
 )
 from vllm.sequence import IntermediateTensors

@@ -124,9 +122,7 @@ def __init__(
         rope_scaling.setdefault(
             "original_max_position_embeddings", original_max_position
         )
-        max_position_embeddings = getattr(
-            config, "max_position_embeddings", 2048
-        )
+        max_position_embeddings = getattr(config, "max_position_embeddings", 2048)

         bias = getattr(config, "bias", False)
         self.q_proj = ColumnParallelLinear(
@@ -244,9 +240,7 @@ def forward(
             positions=positions,
             hidden_states=hidden_states,
         )
-        hidden_states, residual = self.post_attention_layernorm(
-            hidden_states, residual
-        )
+        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
         return hidden_states, residual

@@ -280,8 +274,7 @@ def __init__(
         self.vocab_size = config.vocab_size + lora_vocab
         self.org_vocab_size = config.vocab_size
         if get_pp_group().is_first_rank or (
-            getattr(config, "tie_word_embeddings", True)
-            and get_pp_group().is_last_rank
+            getattr(config, "tie_word_embeddings", True) and get_pp_group().is_last_rank
         ):
             self.embed_tokens = VocabParallelEmbedding(
                 self.vocab_size,
@@ -333,7 +326,7 @@ def forward(

         aux_hidden_states: list[torch.Tensor] = []
         for idx, layer in enumerate(self.layers[self.start_layer : self.end_layer]):
-            if idx in self.aux_hidden_state_layers:
+            if self.start_layer + idx in self.aux_hidden_state_layers:
                 aux_hidden_states.append(hidden_states + residual)
             hidden_states, residual = layer(positions, hidden_states, residual)
