diff --git a/examples/eagle/convert_checkpoint.py b/examples/eagle/convert_checkpoint.py
index 1632e1e218..b1eff4b1a9 100644
--- a/examples/eagle/convert_checkpoint.py
+++ b/examples/eagle/convert_checkpoint.py
@@ -5,7 +5,7 @@
 from pathlib import Path
 
 from tqdm import tqdm
-from transformers import LlamaConfig
+from transformers import AutoConfig
 
 import tensorrt_llm
 from tensorrt_llm.mapping import Mapping
@@ -14,7 +14,6 @@
 from tensorrt_llm.models.model_weights_loader import ModelWeightsLoader
 from tensorrt_llm.quantization import QuantAlgo
 
-
 def parse_arguments():
     parser = argparse.ArgumentParser()
     parser.add_argument('--model_dir', type=str, default=None)
@@ -280,7 +279,7 @@ def copy(tensors):
     hf_config = None
     eagle_model_dir = args.model_dir if args.eagle_model_dir is None else args.eagle_model_dir
     if args.model_dir is not None:
-        hf_config = LlamaConfig.from_pretrained(args.model_dir)
+        hf_config = AutoConfig.from_pretrained(args.model_dir)
         args.model_type = hf_config.model_type
 
         args.n_head = hf_config.num_attention_heads
@@ -322,7 +321,7 @@ def copy(tensors):
         else:
             args.head_size_eagle = args.head_dim_eagle
     else:
-        hf_config_eagle = LlamaConfig.from_pretrained(args.eagle_model_dir)
+        hf_config_eagle = AutoConfig.from_pretrained(args.eagle_model_dir)
         args.n_head_eagle = hf_config_eagle.num_attention_heads
         args.inter_size_eagle = hf_config_eagle.intermediate_size
         args.n_layer_eagle = hf_config_eagle.num_hidden_layers
@@ -368,7 +367,7 @@ def copy(tensors):
         args.rotary_scaling = rotary_scaling
 
     eagle_net_config = {
-        'architecture': "LlamaForCausalLM",
+        'architecture': "Qwen2ForCausalLM",
         'dtype': args.dtype,
         'logits_dtype': 'float32',
         'num_hidden_layers': args.n_layer_eagle,
@@ -395,7 +394,9 @@ def copy(tensors):
         'use_parallel_embedding': args.use_parallel_embedding,
         'embedding_sharding_dim': args.embedding_sharding_dim,
         'head_dim': args.head_dim_eagle,
-        'head_size': args.head_size_eagle
+        'head_size': args.head_size_eagle,
+        'qwen_type': 'qwen2',
+        'seq_length': 8192
     }
 
     config = {
@@ -430,7 +431,9 @@ def copy(tensors):
         'max_non_leaves_per_layer': args.max_non_leaves_per_layer,
        'eagle_net_config': eagle_net_config,
         'head_dim': args.head_dim,
-        'head_size': args.head_size
+        'head_size': args.head_size,
+        'qwen_type': 'qwen2',
+        'seq_length': 8192
     }
 
     assert args.max_draft_len <= 256, "args.max_draft_len > 256 is not supported"
@@ -487,4 +490,4 @@ def copy(tensors):
     tok = time.time()
     t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
 
-    print(f'Total time of converting checkpoints: {t}')
+    print(f'Total time of converting checkpoints: {t}')
\ No newline at end of file
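With AutoConfig in place of LlamaConfig, the converter can read both the base model's and the draft model's configs regardless of architecture, and the extra 'qwen_type' / 'seq_length' keys are what the Qwen-based checkpoint config expects. A minimal sketch of that flow, assuming a local HF checkpoint directory (the function name, the field selection, and the 8192 fallback are illustrative, not part of the patch):

from transformers import AutoConfig


def sketch_eagle_net_config(eagle_model_dir: str, dtype: str = "bfloat16") -> dict:
    # Mirrors, in abridged form, the eagle_net_config dict the patched converter builds.
    hf_cfg = AutoConfig.from_pretrained(eagle_model_dir)
    return {
        'architecture': "Qwen2ForCausalLM",   # draft net is now a Qwen decoder
        'dtype': dtype,
        'num_hidden_layers': hf_cfg.num_hidden_layers,
        'num_attention_heads': hf_cfg.num_attention_heads,
        'hidden_size': hf_cfg.hidden_size,
        'intermediate_size': hf_cfg.intermediate_size,
        'qwen_type': 'qwen2',                 # consumed by the Qwen-side config class
        'seq_length': getattr(hf_cfg, 'seq_length', 8192),
    }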
diff --git a/tensorrt_llm/models/eagle/config.py b/tensorrt_llm/models/eagle/config.py
index 6124884f72..06c578db1b 100644
--- a/tensorrt_llm/models/eagle/config.py
+++ b/tensorrt_llm/models/eagle/config.py
@@ -16,15 +16,15 @@
 import json
 from typing import Optional, Union
 
-from transformers import LlamaConfig
+from transformers import AutoConfig as HFModelConfig
 
 from ...mapping import Mapping
 from ..convert_utils import infer_dtype
-from ..llama.config import LLaMAConfig
+from ..qwen.config import QWenConfig as TRTCModelconfig
 from ..modeling_utils import QuantAlgo, QuantConfig
 
 
-class EagleConfig(LLaMAConfig):
+class EagleConfig(TRTCModelconfig):
 
     def __init__(self,
                  *,
@@ -35,7 +35,7 @@ def __init__(self,
         self.num_eagle_layers = num_eagle_layers
         self.max_non_leaves_per_layer = max_non_leaves_per_layer
         self.max_draft_len = max_draft_len
-        self.eagle_net_config = LLaMAConfig.from_dict(
+        self.eagle_net_config = TRTCModelconfig.from_dict(
             kwargs["eagle_net_config"])
         del kwargs["eagle_net_config"]
         super().__init__(**kwargs)
@@ -73,7 +73,7 @@ def from_hugging_face(
         hf_config = None
         hf_config_or_dir if speculative_config_or_dir is None else speculative_config_or_dir
         if hf_config_or_dir is not None:
-            hf_config = LlamaConfig.from_pretrained(hf_config_or_dir)
+            hf_config = HFModelConfig.from_pretrained(hf_config_or_dir)
             hf_config.model_type
 
             n_head = hf_config.num_attention_heads
@@ -91,7 +91,7 @@ def from_hugging_face(
             if hasattr(hf_config, 'head_dim'):
                 head_dim = hf_config.head_dim
             else:
-                head_dim = hf_config.n_embd // hf_config.n_head
+                head_dim = n_embd // n_head
             if hasattr(hf_config, 'head_size'):
                 head_size = hf_config.head_size
             else:
@@ -107,7 +107,7 @@ def from_hugging_face(
             rms_norm_eps_eagle = hf_config_eagle['rms_norm_eps']
             n_positions_eagle = hf_config_eagle['max_position_embeddings']
         else:
-            hf_config_eagle = LlamaConfig.from_pretrained(
+            hf_config_eagle = HFModelConfig.from_pretrained(
                 speculative_config_or_dir)
             n_head_eagle = hf_config_eagle.num_attention_heads
             inter_size_eagle = hf_config_eagle.intermediate_size
             n_layer_eagle = hf_config_eagle.num_hidden_layers
@@ -125,7 +125,7 @@ def from_hugging_face(
         rotary_scaling = rotary_scaling
 
         eagle_net_config = {
-            'architecture': "LlamaForCausalLM",
+            'architecture': "Qwen2ForCausalLM",
             'dtype': dtype,
             'logits_dtype': 'float32',
             'num_hidden_layers': n_layer_eagle,
@@ -152,7 +152,9 @@ def from_hugging_face(
             'use_parallel_embedding': kwargs['use_parallel_embedding'],
             'embedding_sharding_dim': kwargs['embedding_sharding_dim'],
             'head_dim': head_dim,
-            'head_size': head_size
+            'head_size': head_size,
+            'qwen_type': 'qwen2',
+            'seq_length': 8192
         }
 
         config = {
@@ -185,7 +187,9 @@ def from_hugging_face(
             'num_eagle_layers': kwargs['speculative_config'].num_eagle_layers,
             'max_non_leaves_per_layer':
             kwargs['speculative_config'].max_non_leaves_per_layer,
-            'eagle_net_config': eagle_net_config
+            'eagle_net_config': eagle_net_config,
+            'qwen_type': 'qwen2',
+            'seq_length': 8192
         }
         if quant_config:
             config['quantization']['quant_algo'] = quant_config.quant_algo
@@ -218,4 +222,4 @@ def from_hugging_face(
         except IOError:
             pass
 
-        return cls.from_dict(config)
+        return cls.from_dict(config)
\ No newline at end of file
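EagleConfig now inherits from QWenConfig (imported as TRTCModelconfig) and parses the nested eagle_net_config through that same class, which is why both the outer config and the nested one carry 'qwen_type' and 'seq_length'. A toy illustration of that control flow under simplified stand-in classes (ToyQwenConfig and ToyEagleConfig are hypothetical, not TRT-LLM API):

class ToyQwenConfig:
    # Stand-in for the Qwen-side config: 'qwen_type' and 'seq_length' are expected fields.
    def __init__(self, architecture, qwen_type, seq_length=8192, **extra):
        self.architecture = architecture
        self.qwen_type = qwen_type
        self.seq_length = seq_length
        self.extra = extra

    @classmethod
    def from_dict(cls, d):
        return cls(**d)


class ToyEagleConfig(ToyQwenConfig):
    # Same pattern as the patched EagleConfig.__init__: pop the nested draft-net
    # config, parse it with the Qwen config class, then init the base class.
    def __init__(self, *, eagle_net_config, **kwargs):
        self.eagle_net_config = ToyQwenConfig.from_dict(eagle_net_config)
        super().__init__(**kwargs)


cfg = ToyEagleConfig(
    architecture="EagleForCausalLM",
    qwen_type="qwen2",
    eagle_net_config={"architecture": "Qwen2ForCausalLM", "qwen_type": "qwen2"})
print(cfg.eagle_net_config.architecture)  # Qwen2ForCausalLM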
diff --git a/tensorrt_llm/models/eagle/model.py b/tensorrt_llm/models/eagle/model.py
index 01b3905ba7..19f1ba5138 100644
--- a/tensorrt_llm/models/eagle/model.py
+++ b/tensorrt_llm/models/eagle/model.py
@@ -22,7 +22,9 @@
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.models.generation_mixin import GenerationMixin
-from tensorrt_llm.models.llama.model import LLaMAForCausalLM, LLaMAModel
+from tensorrt_llm.models.qwen.model import QWenModel as TRTModel
+from tensorrt_llm.models.qwen.model import QWenForCausalLM as TRTModelForCausalLM
+
 from tensorrt_llm.models.model_weights_loader import ModelWeightsLoader
 
 from ..._common import default_net, default_trtnet
 
@@ -531,7 +533,7 @@ class EagleNet(Module):
 
     def __init__(self, config, logits_dtype):
         super().__init__()
-        self.drafter = LLaMAModel(config)
+        self.drafter = TRTModel(config)
 
         self.config = config
         self.logits_dtype = logits_dtype
@@ -575,7 +577,7 @@ def forward(self,
         return None, hidden_states, cache
 
 
-class EagleForCausalLM(LLaMAForCausalLM):
+class EagleForCausalLM(TRTModelForCausalLM):
     config_class = EagleConfig
 
     def __init__(self, config: EagleConfig):
@@ -1324,4 +1326,4 @@ def copy(tensors):
                         eagle_loader.load("eagle_nets." + tllm_key))
             base_loader.fill(tllm_weights)
 
-        return model
+        return model
\ No newline at end of file
diff --git a/tensorrt_llm/models/qwen/config.py b/tensorrt_llm/models/qwen/config.py
index 5ab3648938..64daa0a1b4 100644
--- a/tensorrt_llm/models/qwen/config.py
+++ b/tensorrt_llm/models/qwen/config.py
@@ -33,6 +33,7 @@ def __init__(self,
                  moe: Optional[Union[MoeConfig, dict]] = None,
                  num_labels: int = 1,
                  **kwargs):
+
         self.mlp_bias = mlp_bias
         self.attn_bias = attn_bias
         self.rotary_base = rotary_base
@@ -40,6 +41,8 @@ def __init__(self,
         self.disable_weight_only_quant_plugin = disable_weight_only_quant_plugin
         self.num_labels = num_labels
         self.use_logn_attn = use_logn_attn
+        self.fc_after_embed = False
+
         if moe is None:
             # Legacy MOE config fields
             moe = MoeConfig(num_experts=kwargs.pop('moe_num_experts', 0),
@@ -65,6 +68,7 @@ def to_dict(self):
             'disable_weight_only_quant_plugin'] = self.disable_weight_only_quant_plugin
         output['use_logn_attn'] = self.use_logn_attn
         output['moe'] = self.moe.to_dict()
+        output['fc_after_embed'] = self.fc_after_embed
         return output
 
     @classmethod
@@ -105,7 +109,7 @@ def from_hugging_face(cls,
            hf_config.architectures = ['Qwen2ForCausalLM']
 
        valid_types = ('qwen', 'qwen2', 'qwen2_moe', 'qwen2_llava_onevision',
-                       'qwen2_vl', 'qwen2_audio')
+                       'qwen2_vl', 'qwen2_audio', 'qwen3')
        assert qwen_type in valid_types, f"Unsupported Qwen type: {qwen_type}, only {valid_types} are acceptable."
        num_key_value_heads = getattr(hf_config, "num_key_value_heads",
                                      hf_config.num_attention_heads)
@@ -114,7 +118,12 @@ def from_hugging_face(cls,
        hidden_act = getattr(hf_config, "hidden_act", "silu")
        if qwen_type == "qwen2_moe":
            hidden_act = "swiglu"
+
        attn_bias = True  # All existing Qwen models have attn bias
+        if qwen_type == 'qwen3':
+            # Qwen3 dense models default to attention_bias=False
+            attn_bias = getattr(hf_config, "attention_bias", attn_bias)
+            # Qwen3 dense configs provide head_dim explicitly
+            head_size = getattr(hf_config, "head_dim", head_size)
        rotary_scaling = getattr(hf_config, "rope_scaling", None)
        seq_length = getattr(hf_config, "seq_length", 8192)
        use_logn_attn = getattr(hf_config, "use_logn_attn", False)
@@ -187,4 +196,4 @@ def from_hugging_face(cls,
                   quantization=quant_config,
                   num_labels=num_labels,
                   tie_word_embeddings=tie_word_embeddings,
-                   **kwargs)
+                   **kwargs)
\ No newline at end of file
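The qwen/config.py change is what lets 'qwen3' pass the valid_types check; the only HF-config differences it has to absorb for dense Qwen3 are attention_bias (False by default) and an explicit head_dim. A hedged sketch of that override logic pulled out as a helper (the helper name, signature, and the model_type check are illustrative only):

def resolve_qwen_attention_params(hf_config, default_head_size):
    """Return (attn_bias, head_size) the way the patched from_hugging_face resolves them."""
    attn_bias = True  # all pre-Qwen3 variants use attention bias
    head_size = default_head_size
    if getattr(hf_config, "model_type", None) == "qwen3":
        # dense Qwen3 checkpoints ship attention_bias=False and an explicit head_dim
        attn_bias = getattr(hf_config, "attention_bias", attn_bias)
        head_size = getattr(hf_config, "head_dim", head_size)
    return attn_bias, head_size

With a Qwen2 config this returns (True, default_head_size); with a dense Qwen3 config it returns (False, hf_config.head_dim).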
diff --git a/tensorrt_llm/models/qwen/model.py b/tensorrt_llm/models/qwen/model.py
index 60a2f8b38a..68fe2f14a7 100644
--- a/tensorrt_llm/models/qwen/model.py
+++ b/tensorrt_llm/models/qwen/model.py
@@ -21,7 +21,7 @@
 from tqdm import tqdm
 
 from ..._utils import pad_vocab_size
-from ...functional import Tensor, recv, send
+from ...functional import Tensor, recv, send, LayerNormType, concat
 from ...layers import (MOE, Attention, AttentionMaskType, ColumnLinear,
                        Embedding, GatedMLP, RmsNorm, SharedMoE)
 from ...layers.moe import MOEWeightWrapper
@@ -50,10 +50,17 @@ def __init__(self, config: QWenConfig, layer_idx: int):
         self.tp_group = config.mapping.tp_group
         self.tp_size = config.mapping.tp_size
 
-        self.input_layernorm = RmsNorm(normalized_shape=config.hidden_size,
+        self.fc_after_embed = getattr(config, 'fc_after_embed', False)  # defaults to False
+
+        # not present in the EAGLE draft (EagleNet) weights
+        if not self.fc_after_embed:
+            self.input_layernorm = RmsNorm(normalized_shape=config.hidden_size,
                                        eps=config.norm_epsilon,
                                        dtype=dtype)
+        layernorm_type = LayerNormType.RmsNorm
+        qk_layernorm = (config.qwen_type == "qwen3")
+
 
         layers_range = config.mapping.pp_layers(config.num_hidden_layers)
         local_layer_idx = layer_idx - layers_range[0]
         self.attention = Attention(
@@ -78,7 +85,10 @@ def __init__(self, config: QWenConfig, layer_idx: int):
             cp_group=config.mapping.cp_group,
             quant_mode=config.quant_mode,
             use_logn_scaling=config.use_logn_attn,
-            dense_bias=False)
+            dense_bias=False,
+            qk_layernorm=qk_layernorm,
+            layernorm_type=layernorm_type,
+            eps=config.norm_epsilon)
 
         if config.moe.has_moe():
             mlp_kwargs = {'moe_config': config.moe, 'mapping': config.mapping}
@@ -127,7 +137,9 @@ def forward(
         mrope_params=None,
     ):
         residual = hidden_states
-        hidden_states = self.input_layernorm(hidden_states)
+        if not self.fc_after_embed:
+            hidden_states = self.input_layernorm(hidden_states)
+
         attention_output = self.attention(
             hidden_states,
             attention_mask=attention_mask,
@@ -167,11 +179,24 @@ def __init__(self, config: QWenConfig) -> None:
                                              dtype=config.dtype)
 
         self.layers = DecoderLayerList(QWenDecoderLayer, config)
-
-        if self.mapping.is_last_pp_rank():
-            self.ln_f = RmsNorm(normalized_shape=config.hidden_size,
-                                eps=config.norm_epsilon,
-                                dtype=config.dtype)
+        self.fc_after_embed = getattr(config, 'fc_after_embed', False)  # defaults to False
+
+        # the EAGLE draft net has no final layernorm of its own
+        if self.mapping.is_last_pp_rank():
+            if not self.fc_after_embed:
+                self.ln_f = RmsNorm(normalized_shape=config.hidden_size,
+                                    eps=config.norm_epsilon,
+                                    dtype=config.dtype)
+
+
+        if self.fc_after_embed:
+            self.fc = ColumnLinear(2 * config.hidden_size,
+                                   config.hidden_size,
+                                   bias=False,
+                                   dtype=config.dtype,
+                                   tp_group=config.mapping.tp_group,
+                                   tp_size=config.mapping.tp_size,
+                                   gather_output=True)
 
     def forward(self,
                 input_ids: Tensor,
@@ -183,6 +208,7 @@ def forward(self,
                 attention_params=None,
                 mrope_params=None,
                 hidden_states=None,
+                hidden_states_for_embed=None,
                 prompt_embedding_table: Optional[Tensor] = None,
                 prompt_tasks: Optional[Tensor] = None,
                 prompt_vocab_size: Optional[Tensor] = None,
@@ -197,6 +223,11 @@ def forward(self,
         else:
             hidden_states = recv(hidden_states, self.mapping.prev_pp_rank())
 
+        if hidden_states_for_embed is not None:
+            hidden_states = concat([hidden_states, hidden_states_for_embed],
+                                   dim=-1)
+            hidden_states = self.fc(hidden_states)
+
         hidden_states = self.layers.forward(
             hidden_states,
             use_cache=use_cache,
@@ -210,8 +241,10 @@ def forward(self,
         if use_cache:
             hidden_states, presents = hidden_states
 
+
         if self.mapping.is_last_pp_rank():
-            hidden_states = self.ln_f(hidden_states)
+            if not self.fc_after_embed:
+                hidden_states = self.ln_f(hidden_states)
         else:
             hidden_states = send(hidden_states, self.mapping.next_pp_rank())
 
@@ -336,7 +369,7 @@ def from_hugging_face(
                 "mlp.shared_expert_gate": "mlp.shared_expert_gate",
                 "fc": ["up_proj", "gate_proj"],
             }
-        elif config.qwen_type in {"qwen2", "qwen2_vl"
+        elif config.qwen_type in {"qwen2", "qwen2_vl", "qwen3"
                                   } and config.tie_word_embeddings:
             custom_dict = {"lm_head": "model.embed_tokens"}
         elif config.architecture == "Qwen2ForSequenceClassification":
@@ -353,6 +386,18 @@ def from_hugging_face(
                 "transformer": "language_model.model",
                 "lm_head": "language_model.lm_head",
             }
+
+        if config.tie_word_embeddings:
+            config.share_embedding_table = True
+            config.use_parallel_embedding = True
+
+        # Qwen3 dense models require Q/K layernorm
+        if config.qwen_type == "qwen3":
+            custom_dict.update({
+                "q_layernorm": "q_norm",
+                "k_layernorm": "k_norm"
+            })
+
         loader = ModelWeightsLoader(hf_model_dir, custom_dict)
         model = cls(config)
         if config.qwen_type == "qwen" and model.config.mapping.has_tp():
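fc_after_embed is the EAGLE-specific input path: the draft net concatenates the token embedding with the base model's hidden state and projects 2*H back to H with a bias-free linear layer, while skipping the per-layer input_layernorm and the final ln_f. A standalone PyTorch sketch of that shape contract (illustrative only, not TRT-LLM code; class and variable names are made up):

import torch
import torch.nn as nn


class ToyEagleInputProjection(nn.Module):
    """concat([embeds, base_hidden], dim=-1) followed by a bias-free 2H -> H projection."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.fc = nn.Linear(2 * hidden_size, hidden_size, bias=False)

    def forward(self, embeds: torch.Tensor, base_hidden: torch.Tensor) -> torch.Tensor:
        return self.fc(torch.cat([embeds, base_hidden], dim=-1))


proj = ToyEagleInputProjection(hidden_size=16)
out = proj(torch.randn(2, 4, 16), torch.randn(2, 4, 16))
print(out.shape)  # torch.Size([2, 4, 16])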
@@ -515,4 +560,4 @@ def quantize(
         )
 
     def use_lora(self, lora_config: LoraConfig):
-        use_lora(self, lora_config, self.trtllm_modules_to_hf_modules)
+        use_lora(self, lora_config, self.trtllm_modules_to_hf_modules)
\ No newline at end of file
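The qk_layernorm=True / layernorm_type=RmsNorm wiring and the q_norm/k_norm weight mapping exist because Qwen3 applies an RMSNorm over the head dimension of the query and key projections before RoPE. A small PyTorch sketch of that operation (illustrative; the eps value and tensor layout are assumptions):

import torch
import torch.nn as nn


class HeadwiseRMSNorm(nn.Module):
    """RMSNorm over head_dim, applied per attention head to Q and K (Qwen3-style)."""

    def __init__(self, head_dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(head_dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch, num_heads, seq_len, head_dim]
        variance = x.float().pow(2).mean(dim=-1, keepdim=True)
        normed = x.float() * torch.rsqrt(variance + self.eps)
        return normed.to(x.dtype) * self.weight


q_norm = HeadwiseRMSNorm(head_dim=128)
q = torch.randn(1, 8, 4, 128)
print(q_norm(q).shape)  # torch.Size([1, 8, 4, 128])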