From 5afb5b84261ac0597219a9b2d95bb3d84280ecd9 Mon Sep 17 00:00:00 2001 From: YoussefEssDS Date: Sun, 2 Nov 2025 13:55:37 +0000 Subject: [PATCH 1/5] Add native OpenPangu Embedded backend to vLLM Signed-off-by: YoussefEssDS --- tests/models/registry.py | 3 + vllm/model_executor/models/pangu.py | 375 +++++++++++++++++++++++++ vllm/model_executor/models/registry.py | 1 + 3 files changed, 379 insertions(+) create mode 100644 vllm/model_executor/models/pangu.py diff --git a/tests/models/registry.py b/tests/models/registry.py index 8e1dd4ba91f1..be3ec6a38f35 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -370,6 +370,9 @@ def check_available_online( "OrionStarAI/Orion-14B-Chat", trust_remote_code=True ), "OuroForCausalLM": _HfExamplesInfo("ByteDance/Ouro-1.4B", trust_remote_code=True), + "PanguEmbeddedForCausalLM": _HfExamplesInfo( + "FreedomIntelligence/openPangu-Embedded-7B" + ), "PersimmonForCausalLM": _HfExamplesInfo("adept/persimmon-8b-chat"), "PhiForCausalLM": _HfExamplesInfo("microsoft/phi-2"), "Phi3ForCausalLM": _HfExamplesInfo("microsoft/Phi-3-mini-4k-instruct"), diff --git a/vllm/model_executor/models/pangu.py b/vllm/model_executor/models/pangu.py new file mode 100644 index 000000000000..9fe0e02770f4 --- /dev/null +++ b/vllm/model_executor/models/pangu.py @@ -0,0 +1,375 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Native OpenPangu Embedded model implementation.""" + +from collections.abc import Iterable +from typing import Any + +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.attention import Attention, AttentionType +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import ( + get_pp_group, + get_tensor_model_parallel_world_size, +) +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding +from vllm.model_executor.models.llama import LlamaForCausalLM +from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP +from vllm.model_executor.models.utils import ( + AutoWeightsLoader, + PPMissingLayer, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) +from vllm.sequence import IntermediateTensors + + +class PanguMLP(nn.Module): + """Feed-forward network for PanguEmbedded layers.""" + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + *, + bias: bool, + quant_config: QuantizationConfig | None, + prefix: str, + ) -> None: + super().__init__() + self.gate_proj = ColumnParallelLinear( + hidden_size, + intermediate_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_proj", + ) + self.up_proj = ColumnParallelLinear( + hidden_size, + intermediate_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) + self.act_fn = get_act_fn(hidden_act) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + gate, _ = self.gate_proj(hidden_states) + up, _ = self.up_proj(hidden_states) + 
hidden_states = self.act_fn(gate) * up + hidden_states, _ = self.down_proj(hidden_states) + return hidden_states + + +class PanguAttention(nn.Module): + """Self-attention block with GQA.""" + + def __init__( + self, + config: PretrainedConfig, + *, + cache_config: CacheConfig | None, + quant_config: QuantizationConfig | None, + prefix: str, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + self.total_num_heads = config.num_attention_heads + self.total_num_kv_heads = getattr( + config, "num_key_value_heads", config.num_attention_heads + ) + tp_size = get_tensor_model_parallel_world_size() + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + if self.total_num_kv_heads >= tp_size: + assert self.total_num_kv_heads % tp_size == 0 + else: + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = getattr( + config, + "head_dim", + self.hidden_size // self.total_num_heads, + ) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + + rope_theta = getattr(config, "rope_theta", 10000.0) + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is not None: + rope_scaling = dict(rope_scaling) + original_max_position = getattr( + config, "original_max_position_embeddings", None + ) + if original_max_position is not None: + rope_scaling.setdefault( + "original_max_position_embeddings", original_max_position + ) + max_position_embeddings = getattr( + config, "max_position_embeddings", 2048 + ) + + bias = getattr(config, "bias", False) + self.q_proj = ColumnParallelLinear( + self.hidden_size, + self.total_num_heads * self.head_dim, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.q_proj", + ) + self.k_proj = ColumnParallelLinear( + self.hidden_size, + self.total_num_kv_heads * self.head_dim, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.k_proj", + ) + self.v_proj = ColumnParallelLinear( + self.hidden_size, + self.total_num_kv_heads * self.head_dim, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.v_proj", + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + self.hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=True, + ) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + attn_type=AttentionType.DECODER, + prefix=f"{prefix}.attn", + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + q, _ = self.q_proj(hidden_states) + k, _ = self.k_proj(hidden_states) + v, _ = self.v_proj(hidden_states) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class PanguDecoderLayer(nn.Module): + """Single decoder block for PanguEmbedded.""" + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + config: PretrainedConfig | None = None, + ) -> None: + super().__init__() + config = config or vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = self.get_quant_config(vllm_config) + + 
self.hidden_size = config.hidden_size + self.self_attn = PanguAttention( + config, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + self.mlp = PanguMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + bias=getattr(config, "bias", False), + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm( + config.hidden_size, eps=getattr(config, "rms_norm_eps", 1e-5) + ) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=getattr(config, "rms_norm_eps", 1e-5) + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: torch.Tensor | None, + ) -> tuple[torch.Tensor, torch.Tensor]: + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual + ) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + def get_quant_config(self, vllm_config: VllmConfig) -> QuantizationConfig | None: + return vllm_config.quant_config + + +class PanguModel(nn.Module): + """Backbone model for OpenPangu Embedded.""" + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + layer_type: type[nn.Module] = PanguDecoderLayer, + ) -> None: + super().__init__() + + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.quant_config = quant_config + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + if get_pp_group().is_first_rank or ( + getattr(config, "tie_word_embeddings", True) + and get_pp_group().is_last_rank + ): + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + quant_config=quant_config, + ) + else: + self.embed_tokens = PPMissingLayer() + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: layer_type(vllm_config=vllm_config, prefix=prefix), + prefix=f"{prefix}.layers", + ) + + if get_pp_group().is_last_rank: + self.norm = RMSNorm( + config.hidden_size, eps=getattr(config, "rms_norm_eps", 1e-5) + ) + else: + self.norm = PPMissingLayer() + + self.aux_hidden_state_layers: tuple[int, ...] 
= () + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor | None, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + aux_hidden_states: list[torch.Tensor] = [] + for idx, layer in enumerate(self.layers[self.start_layer : self.end_layer]): + if idx in self.aux_hidden_state_layers: + aux_hidden_states.append(hidden_states + residual) + hidden_states, residual = layer(positions, hidden_states, residual) + + if not get_pp_group().is_last_rank: + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) + + hidden_states, _ = self.norm(hidden_states, residual) + + if aux_hidden_states: + return hidden_states, aux_hidden_states + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) + + +class PanguForCausalLM(LlamaForCausalLM, SupportsLoRA, SupportsPP): + """Causal LM head for OpenPangu Embedded.""" + + packed_modules_mapping: dict[str, list[str]] = {} + mistral_mapping: dict[str, str] = {} + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__( + vllm_config=vllm_config, + prefix=prefix, + layer_type=PanguDecoderLayer, + ) + + def _init_model( + self, + vllm_config: VllmConfig, + prefix: str = "", + layer_type: type[nn.Module] = PanguDecoderLayer, + ): + return PanguModel(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 7eca1a09e536..c5f2dc02401d 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -149,6 +149,7 @@ "OPTForCausalLM": ("opt", "OPTForCausalLM"), "OrionForCausalLM": ("orion", "OrionForCausalLM"), "OuroForCausalLM": ("ouro", "OuroForCausalLM"), + "PanguEmbeddedForCausalLM": ("pangu", "PanguForCausalLM"), "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"), "PhiForCausalLM": ("phi", "PhiForCausalLM"), "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"), From 80fbcca085455784344a362e329fa109761abf99 Mon Sep 17 00:00:00 2001 From: YoussefEssDS Date: Sun, 2 Nov 2025 16:09:00 +0000 Subject: [PATCH 2/5] Fix Pangu aux-state indexing and apply ruff format Signed-off-by: YoussefEssDS --- tests/models/registry.py | 2 +- vllm/model_executor/models/pangu.py | 17 +++++------------ 2 files changed, 6 insertions(+), 13 deletions(-) diff --git a/tests/models/registry.py b/tests/models/registry.py index be3ec6a38f35..bdc8c74e0944 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -371,7 +371,7 @@ def check_available_online( ), "OuroForCausalLM": _HfExamplesInfo("ByteDance/Ouro-1.4B", trust_remote_code=True), "PanguEmbeddedForCausalLM": _HfExamplesInfo( - "FreedomIntelligence/openPangu-Embedded-7B" + 
"FreedomIntelligence/openPangu-Embedded-7B, trust_remote_code=True" ), "PersimmonForCausalLM": _HfExamplesInfo("adept/persimmon-8b-chat"), "PhiForCausalLM": _HfExamplesInfo("microsoft/phi-2"), diff --git a/vllm/model_executor/models/pangu.py b/vllm/model_executor/models/pangu.py index 9fe0e02770f4..ff35dd19b410 100644 --- a/vllm/model_executor/models/pangu.py +++ b/vllm/model_executor/models/pangu.py @@ -4,7 +4,6 @@ """Native OpenPangu Embedded model implementation.""" from collections.abc import Iterable -from typing import Any import torch from torch import nn @@ -22,14 +21,13 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding -from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP +from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.models.utils import ( AutoWeightsLoader, PPMissingLayer, make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix, ) from vllm.sequence import IntermediateTensors @@ -124,9 +122,7 @@ def __init__( rope_scaling.setdefault( "original_max_position_embeddings", original_max_position ) - max_position_embeddings = getattr( - config, "max_position_embeddings", 2048 - ) + max_position_embeddings = getattr(config, "max_position_embeddings", 2048) bias = getattr(config, "bias", False) self.q_proj = ColumnParallelLinear( @@ -244,9 +240,7 @@ def forward( positions=positions, hidden_states=hidden_states, ) - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual - ) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) return hidden_states, residual @@ -280,8 +274,7 @@ def __init__( self.vocab_size = config.vocab_size + lora_vocab self.org_vocab_size = config.vocab_size if get_pp_group().is_first_rank or ( - getattr(config, "tie_word_embeddings", True) - and get_pp_group().is_last_rank + getattr(config, "tie_word_embeddings", True) and get_pp_group().is_last_rank ): self.embed_tokens = VocabParallelEmbedding( self.vocab_size, @@ -333,7 +326,7 @@ def forward( aux_hidden_states: list[torch.Tensor] = [] for idx, layer in enumerate(self.layers[self.start_layer : self.end_layer]): - if idx in self.aux_hidden_state_layers: + if self.start_layer + idx in self.aux_hidden_state_layers: aux_hidden_states.append(hidden_states + residual) hidden_states, residual = layer(positions, hidden_states, residual) From 232c9c5fda4ced51ed78316185ea74e1b21eabb6 Mon Sep 17 00:00:00 2001 From: YoussefEssDS Date: Mon, 3 Nov 2025 03:30:47 +0000 Subject: [PATCH 3/5] Guard aux residual collection & update supported models docs Signed-off-by: YoussefEssDS --- docs/models/supported_models.md | 1 + tests/models/registry.py | 2 +- vllm/model_executor/models/pangu.py | 4 +++- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index fd25647dce54..5459853068b7 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -676,6 +676,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen | `Ovis` | Ovis2, Ovis1.6 | T + I+ | `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis1.6-Llama3.2-3B`, etc. | | ✅︎ | | `Ovis2_5` | Ovis2.5 | T + I+ + V | `AIDC-AI/Ovis2.5-9B`, etc. 
| | | | `PaliGemmaForConditionalGeneration` | PaliGemma, PaliGemma 2 | T + IE | `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc. | | ✅︎ | +| `PanguEmbeddedForCausalLM` | openPangu Embedded | `FreedomIntelligence/openPangu-Embedded-7B` | | ✅︎ | | `Phi3VForCausalLM` | Phi-3-Vision, Phi-3.5-Vision | T + IE+ | `microsoft/Phi-3-vision-128k-instruct`, `microsoft/Phi-3.5-vision-instruct`, etc. | | ✅︎ | | `Phi4MMForCausalLM` | Phi-4-multimodal | T + I+ / T + A+ / I+ + A+ | `microsoft/Phi-4-multimodal-instruct`, etc. | ✅︎ | ✅︎ | | `Phi4MultimodalForCausalLM` | Phi-4-multimodal (HF Transformers) | T + I+ / T + A+ / I+ + A+ | `microsoft/Phi-4-multimodal-instruct` (with revision `refs/pr/70`), etc. | ✅︎ | ✅︎ | diff --git a/tests/models/registry.py b/tests/models/registry.py index bdc8c74e0944..9e89bea3c771 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -371,7 +371,7 @@ def check_available_online( ), "OuroForCausalLM": _HfExamplesInfo("ByteDance/Ouro-1.4B", trust_remote_code=True), "PanguEmbeddedForCausalLM": _HfExamplesInfo( - "FreedomIntelligence/openPangu-Embedded-7B, trust_remote_code=True" + "FreedomIntelligence/openPangu-Embedded-7B", trust_remote_code=True ), "PersimmonForCausalLM": _HfExamplesInfo("adept/persimmon-8b-chat"), "PhiForCausalLM": _HfExamplesInfo("microsoft/phi-2"), diff --git a/vllm/model_executor/models/pangu.py b/vllm/model_executor/models/pangu.py index ff35dd19b410..9fec985e23fe 100644 --- a/vllm/model_executor/models/pangu.py +++ b/vllm/model_executor/models/pangu.py @@ -326,7 +326,9 @@ def forward( aux_hidden_states: list[torch.Tensor] = [] for idx, layer in enumerate(self.layers[self.start_layer : self.end_layer]): - if self.start_layer + idx in self.aux_hidden_state_layers: + if residual is None: + aux_hidden_states.append(hidden_states) + else: aux_hidden_states.append(hidden_states + residual) hidden_states, residual = layer(positions, hidden_states, residual) From dd7cf363379de7fe79bca35318623f8ebaafa46e Mon Sep 17 00:00:00 2001 From: YoussefEssDS Date: Mon, 3 Nov 2025 04:13:46 +0000 Subject: [PATCH 4/5] Add missing doc entry Signed-off-by: YoussefEssDS --- docs/models/supported_models.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 5459853068b7..728fb4600548 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -676,7 +676,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen | `Ovis` | Ovis2, Ovis1.6 | T + I+ | `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis1.6-Llama3.2-3B`, etc. | | ✅︎ | | `Ovis2_5` | Ovis2.5 | T + I+ + V | `AIDC-AI/Ovis2.5-9B`, etc. | | | | `PaliGemmaForConditionalGeneration` | PaliGemma, PaliGemma 2 | T + IE | `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc. | | ✅︎ | -| `PanguEmbeddedForCausalLM` | openPangu Embedded | `FreedomIntelligence/openPangu-Embedded-7B` | | ✅︎ | +| `PanguEmbeddedForCausalLM` | openPangu Embedded | T | `FreedomIntelligence/openPangu-Embedded-7B` | ✅︎ | ✅︎ | | `Phi3VForCausalLM` | Phi-3-Vision, Phi-3.5-Vision | T + IE+ | `microsoft/Phi-3-vision-128k-instruct`, `microsoft/Phi-3.5-vision-instruct`, etc. | | ✅︎ | | `Phi4MMForCausalLM` | Phi-4-multimodal | T + I+ / T + A+ / I+ + A+ | `microsoft/Phi-4-multimodal-instruct`, etc. 
| ✅︎ | ✅︎ | | `Phi4MultimodalForCausalLM` | Phi-4-multimodal (HF Transformers) | T + I+ / T + A+ / I+ + A+ | `microsoft/Phi-4-multimodal-instruct` (with revision `refs/pr/70`), etc. | ✅︎ | ✅︎ | From 2ce7be45bf5ce3dafdb8e092182e30a15bd9a28e Mon Sep 17 00:00:00 2001 From: YoussefEssDS Date: Mon, 3 Nov 2025 04:53:02 +0000 Subject: [PATCH 5/5] Fix model placement in docs Signed-off-by: YoussefEssDS --- docs/models/supported_models.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 728fb4600548..d1484522c224 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -404,6 +404,7 @@ th { | `OPTForCausalLM` | OPT, OPT-IML | `facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc. | ✅︎ | ✅︎ | | `OrionForCausalLM` | Orion | `OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc. | | ✅︎ | | `OuroForCausalLM` | ouro | `ByteDance/Ouro-1.4B`, `ByteDance/Ouro-2.6B`, etc. | ✅︎ | | +| `PanguEmbeddedForCausalLM` | openPangu Embedded |`FreedomIntelligence/openPangu-Embedded-7B` | ✅︎ | ✅︎ | | `PhiForCausalLM` | Phi | `microsoft/phi-1_5`, `microsoft/phi-2`, etc. | ✅︎ | ✅︎ | | `Phi3ForCausalLM` | Phi-4, Phi-3 | `microsoft/Phi-4-mini-instruct`, `microsoft/Phi-4`, `microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, `microsoft/Phi-3-medium-128k-instruct`, etc. | ✅︎ | ✅︎ | | `PhiMoEForCausalLM` | Phi-3.5-MoE | `microsoft/Phi-3.5-MoE-instruct`, etc. | ✅︎ | ✅︎ | @@ -676,7 +677,6 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen | `Ovis` | Ovis2, Ovis1.6 | T + I+ | `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis1.6-Llama3.2-3B`, etc. | | ✅︎ | | `Ovis2_5` | Ovis2.5 | T + I+ + V | `AIDC-AI/Ovis2.5-9B`, etc. | | | | `PaliGemmaForConditionalGeneration` | PaliGemma, PaliGemma 2 | T + IE | `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc. | | ✅︎ | -| `PanguEmbeddedForCausalLM` | openPangu Embedded | T | `FreedomIntelligence/openPangu-Embedded-7B` | ✅︎ | ✅︎ | | `Phi3VForCausalLM` | Phi-3-Vision, Phi-3.5-Vision | T + IE+ | `microsoft/Phi-3-vision-128k-instruct`, `microsoft/Phi-3.5-vision-instruct`, etc. | | ✅︎ | | `Phi4MMForCausalLM` | Phi-4-multimodal | T + I+ / T + A+ / I+ + A+ | `microsoft/Phi-4-multimodal-instruct`, etc. | ✅︎ | ✅︎ | | `Phi4MultimodalForCausalLM` | Phi-4-multimodal (HF Transformers) | T + I+ / T + A+ / I+ + A+ | `microsoft/Phi-4-multimodal-instruct` (with revision `refs/pr/70`), etc. | ✅︎ | ✅︎ |
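
For reference, a minimal offline-inference sketch that exercises the `PanguEmbeddedForCausalLM` backend registered by these patches. It uses only vLLM's public `LLM`/`SamplingParams` API; the checkpoint name and `trust_remote_code=True` mirror the `_HfExamplesInfo` entry added to tests/models/registry.py, while the prompt, temperature, and token budget are arbitrary placeholders, not values taken from the patches.

# Minimal usage sketch (assumed invocation, not part of the patches):
# load the newly registered openPangu Embedded checkpoint through vLLM's
# offline LLM API and generate a short completion.
from vllm import LLM, SamplingParams

llm = LLM(
    model="FreedomIntelligence/openPangu-Embedded-7B",
    trust_remote_code=True,  # matches the registry entry added in PATCH 3/5
)
sampling_params = SamplingParams(temperature=0.0, max_tokens=64)
outputs = llm.generate(["What is pipeline parallelism?"], sampling_params)
print(outputs[0].outputs[0].text)

The same model should also be reachable through the registry lookup path, since registry.py now maps the `PanguEmbeddedForCausalLM` architecture string to the `pangu` module's `PanguForCausalLM` class.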
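
The bare asserts in `PanguAttention.__init__` encode the GQA head-partitioning rule: query heads must divide evenly across tensor-parallel ranks, and KV heads are either sharded (when there are at least as many KV heads as ranks) or replicated (when there are fewer). The standalone sketch below restates that rule outside vLLM; `partition_heads` is a hypothetical helper introduced here for illustration, not a function from the patch.

# Illustrative sketch of the KV-head partitioning enforced in PanguAttention.
def partition_heads(
    total_heads: int, total_kv_heads: int, tp_size: int
) -> tuple[int, int]:
    # Query heads are always sharded evenly across TP ranks.
    assert total_heads % tp_size == 0
    if total_kv_heads >= tp_size:
        # Enough KV heads to shard: each rank gets an equal slice.
        assert total_kv_heads % tp_size == 0
    else:
        # Fewer KV heads than ranks: each KV head is replicated on
        # tp_size // total_kv_heads ranks, so the division must be exact.
        assert tp_size % total_kv_heads == 0
    num_heads = total_heads // tp_size
    num_kv_heads = max(1, total_kv_heads // tp_size)
    return num_heads, num_kv_heads

# Example: 32 query heads and 8 KV heads on tp_size=4
# -> 8 query heads and 2 KV heads per rank.
print(partition_heads(32, 8, 4))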