feat(eagle): support qwen in eagle1/2 #5352

Open · wants to merge 1 commit into base: main
19 changes: 11 additions & 8 deletions examples/eagle/convert_checkpoint.py
@@ -5,7 +5,7 @@
from pathlib import Path

from tqdm import tqdm
- from transformers import LlamaConfig
+ from transformers import AutoConfig

import tensorrt_llm
from tensorrt_llm.mapping import Mapping
@@ -14,7 +14,6 @@
from tensorrt_llm.models.model_weights_loader import ModelWeightsLoader
from tensorrt_llm.quantization import QuantAlgo


def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument('--model_dir', type=str, default=None)
@@ -280,7 +279,7 @@ def copy(tensors):
hf_config = None
eagle_model_dir = args.model_dir if args.eagle_model_dir is None else args.eagle_model_dir
if args.model_dir is not None:
- hf_config = LlamaConfig.from_pretrained(args.model_dir)
+ hf_config = AutoConfig.from_pretrained(args.model_dir)

args.model_type = hf_config.model_type
args.n_head = hf_config.num_attention_heads
@@ -322,7 +321,7 @@ def copy(tensors):
else:
args.head_size_eagle = args.head_dim_eagle
else:
- hf_config_eagle = LlamaConfig.from_pretrained(args.eagle_model_dir)
+ hf_config_eagle = AutoConfig.from_pretrained(args.eagle_model_dir)
args.n_head_eagle = hf_config_eagle.num_attention_heads
args.inter_size_eagle = hf_config_eagle.intermediate_size
args.n_layer_eagle = hf_config_eagle.num_hidden_layers
@@ -368,7 +367,7 @@ def copy(tensors):
args.rotary_scaling = rotary_scaling

eagle_net_config = {
- 'architecture': "LlamaForCausalLM",
+ 'architecture': "Qwen2ForCausalLM",
'dtype': args.dtype,
'logits_dtype': 'float32',
'num_hidden_layers': args.n_layer_eagle,
@@ -395,7 +394,9 @@ def copy(tensors):
'use_parallel_embedding': args.use_parallel_embedding,
'embedding_sharding_dim': args.embedding_sharding_dim,
'head_dim': args.head_dim_eagle,
- 'head_size': args.head_size_eagle
+ 'head_size': args.head_size_eagle,
+ "qwen_type":"qwen2",
+ "seq_length":8192
}

config = {
@@ -430,7 +431,9 @@ def copy(tensors):
'max_non_leaves_per_layer': args.max_non_leaves_per_layer,
'eagle_net_config': eagle_net_config,
'head_dim': args.head_dim,
- 'head_size': args.head_size
+ 'head_size': args.head_size,
+ "qwen_type":"qwen2",
+ "seq_length":8192
}

assert args.max_draft_len <= 256, "args.max_draft_len > 256 is not supported"
@@ -487,4 +490,4 @@ def copy(tensors):

tok = time.time()
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
- print(f'Total time of converting checkpoints: {t}')
+ print(f'Total time of converting checkpoints: {t}')
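The switch from `LlamaConfig` to `AutoConfig` is what lets the converter accept Qwen2 checkpoints: `AutoConfig` dispatches on the checkpoint's `model_type`, and both the Llama and Qwen2 config classes expose the attribute names the script reads. A minimal sketch, not part of the PR; the model id is only an example:

```python
# Sketch: AutoConfig resolves to the matching config class per checkpoint,
# so the same attribute reads work for Llama and Qwen2 base models.
from transformers import AutoConfig

hf_config = AutoConfig.from_pretrained("Qwen/Qwen2-7B-Instruct")  # example id
print(hf_config.model_type)            # "qwen2"
print(hf_config.num_attention_heads)   # read as args.n_head above
print(hf_config.num_hidden_layers)
print(hf_config.intermediate_size)
```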
26 changes: 15 additions & 11 deletions tensorrt_llm/models/eagle/config.py
@@ -16,15 +16,15 @@
import json
from typing import Optional, Union

- from transformers import LlamaConfig
+ from transformers import AutoConfig as HFModelConfig

from ...mapping import Mapping
from ..convert_utils import infer_dtype
- from ..llama.config import LLaMAConfig
+ from ..qwen.config import QWenConfig as TRTCModelconfig
from ..modeling_utils import QuantAlgo, QuantConfig


- class EagleConfig(LLaMAConfig):
+ class EagleConfig(TRTCModelconfig):

def __init__(self,
*,
@@ -35,7 +35,7 @@ def __init__(self,
self.num_eagle_layers = num_eagle_layers
self.max_non_leaves_per_layer = max_non_leaves_per_layer
self.max_draft_len = max_draft_len
- self.eagle_net_config = LLaMAConfig.from_dict(
+ self.eagle_net_config = TRTCModelconfig.from_dict(
kwargs["eagle_net_config"])
del kwargs["eagle_net_config"]
super().__init__(**kwargs)
@@ -73,7 +73,7 @@ def from_hugging_face(
hf_config = None
hf_config_or_dir if speculative_config_or_dir is None else speculative_config_or_dir
if hf_config_or_dir is not None:
- hf_config = LlamaConfig.from_pretrained(hf_config_or_dir)
+ hf_config = HFModelConfig.from_pretrained(hf_config_or_dir)

hf_config.model_type
n_head = hf_config.num_attention_heads
@@ -91,7 +91,7 @@ def from_hugging_face(
if hasattr(hf_config, 'head_dim'):
head_dim = hf_config.head_dim
else:
- head_dim = hf_config.n_embd // hf_config.n_head
+ head_dim = n_embd // n_head
if hasattr(hf_config, 'head_size'):
head_size = hf_config.head_size
else:
@@ -107,7 +107,7 @@ def from_hugging_face(
rms_norm_eps_eagle = hf_config_eagle['rms_norm_eps']
n_positions_eagle = hf_config_eagle['max_position_embeddings']
else:
- hf_config_eagle = LlamaConfig.from_pretrained(
+ hf_config_eagle = HFModelConfig.from_pretrained(
speculative_config_or_dir)
n_head_eagle = hf_config_eagle.num_attention_heads
inter_size_eagle = hf_config_eagle.intermediate_size
@@ -125,7 +125,7 @@ def from_hugging_face(
rotary_scaling = rotary_scaling

eagle_net_config = {
- 'architecture': "LlamaForCausalLM",
+ 'architecture': "Qwen2ForCausalLM",
'dtype': dtype,
'logits_dtype': 'float32',
'num_hidden_layers': n_layer_eagle,
@@ -152,7 +152,9 @@ def from_hugging_face(
'use_parallel_embedding': kwargs['use_parallel_embedding'],
'embedding_sharding_dim': kwargs['embedding_sharding_dim'],
'head_dim': head_dim,
- 'head_size': head_size
+ 'head_size': head_size,
+ "qwen_type":"qwen2",
+ "seq_length":8192
}

config = {
@@ -185,7 +187,9 @@ def from_hugging_face(
'num_eagle_layers': kwargs['speculative_config'].num_eagle_layers,
'max_non_leaves_per_layer':
kwargs['speculative_config'].max_non_leaves_per_layer,
- 'eagle_net_config': eagle_net_config
+ 'eagle_net_config': eagle_net_config,
+ "qwen_type":"qwen2",
+ "seq_length":8192
}
if quant_config:
config['quantization']['quant_algo'] = quant_config.quant_algo
@@ -218,4 +222,4 @@ def from_hugging_face(
except IOError:
pass

- return cls.from_dict(config)
+ return cls.from_dict(config)
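The `head_dim` fallback fixed above previously dereferenced `hf_config.n_embd` and `hf_config.n_head`, attributes the Llama and Qwen HF configs generally do not expose; the patch reuses the already-extracted `n_embd` and `n_head` locals instead. A standalone illustration of the intended logic (the helper name is mine, not the PR's):

```python
# Sketch of the head_dim fallback: prefer an explicit head_dim on the HF
# config, otherwise derive it from hidden_size / num_attention_heads.
def resolve_head_dim(hf_config):
    n_head = hf_config.num_attention_heads
    n_embd = hf_config.hidden_size
    if getattr(hf_config, "head_dim", None) is not None:
        return hf_config.head_dim
    return n_embd // n_head
```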
10 changes: 6 additions & 4 deletions tensorrt_llm/models/eagle/model.py
@@ -22,7 +22,9 @@

from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.generation_mixin import GenerationMixin
- from tensorrt_llm.models.llama.model import LLaMAForCausalLM, LLaMAModel
+ from tensorrt_llm.models.qwen.model import QWenModel as TRTModel
+ from tensorrt_llm.models.qwen.model import QWenForCausalLM as TRTModelForCausalLM

from tensorrt_llm.models.model_weights_loader import ModelWeightsLoader

from ..._common import default_net, default_trtnet
@@ -531,7 +533,7 @@ class EagleNet(Module):

def __init__(self, config, logits_dtype):
super().__init__()
- self.drafter = LLaMAModel(config)
+ self.drafter = TRTModel(config)
self.config = config
self.logits_dtype = logits_dtype

@@ -575,7 +577,7 @@ def forward(self,
return None, hidden_states, cache


- class EagleForCausalLM(LLaMAForCausalLM):
+ class EagleForCausalLM(TRTModelForCausalLM):
config_class = EagleConfig

def __init__(self, config: EagleConfig):
@@ -1324,4 +1326,4 @@ def copy(tensors):
eagle_loader.load("eagle_nets." + tllm_key))
base_loader.fill(tllm_weights)

- return model
+ return model
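With these aliases, the drafter network inside `EagleNet` and the base class of `EagleForCausalLM` are both built from the TRT-LLM Qwen model family instead of the Llama one. A simplified sketch of the relationship, with the class body reduced to the changed lines (`EagleNetSketch` is a stand-in name, not the real class):

```python
# Sketch: the EAGLE drafter now wraps the TRT-LLM Qwen transformer.
from tensorrt_llm.models.qwen.model import QWenModel


class EagleNetSketch:  # stand-in for the real EagleNet(Module)
    def __init__(self, config, logits_dtype):
        self.drafter = QWenModel(config)  # was: LLaMAModel(config)
        self.config = config
        self.logits_dtype = logits_dtype
```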
13 changes: 11 additions & 2 deletions tensorrt_llm/models/qwen/config.py
@@ -33,13 +33,16 @@ def __init__(self,
moe: Optional[Union[MoeConfig, dict]] = None,
num_labels: int = 1,
**kwargs):

self.mlp_bias = mlp_bias
self.attn_bias = attn_bias
self.rotary_base = rotary_base
self.rotary_scaling = rotary_scaling
self.disable_weight_only_quant_plugin = disable_weight_only_quant_plugin
self.num_labels = num_labels
self.use_logn_attn = use_logn_attn
+ self.fc_after_embed = False

if moe is None:
# Legacy MOE config fields
moe = MoeConfig(num_experts=kwargs.pop('moe_num_experts', 0),
@@ -65,6 +68,7 @@ def to_dict(self):
'disable_weight_only_quant_plugin'] = self.disable_weight_only_quant_plugin
output['use_logn_attn'] = self.use_logn_attn
output['moe'] = self.moe.to_dict()
+ output['fc_after_embed'] = self.fc_after_embed
return output

@classmethod
@@ -105,7 +109,7 @@ def from_hugging_face(cls,
hf_config.architectures = ['Qwen2ForCausalLM']

valid_types = ('qwen', 'qwen2', 'qwen2_moe', 'qwen2_llava_onevision',
- 'qwen2_vl', 'qwen2_audio')
+ 'qwen2_vl', 'qwen2_audio', 'qwen3')
assert qwen_type in valid_types, f"Unsupported Qwen type: {qwen_type}, only {valid_types} are acceptable."
num_key_value_heads = getattr(hf_config, "num_key_value_heads",
hf_config.num_attention_heads)
@@ -114,7 +118,12 @@ def from_hugging_face(cls,
hidden_act = getattr(hf_config, "hidden_act", "silu")
if qwen_type == "qwen2_moe":
hidden_act = "swiglu"

+ attn_bias = True  # All existing Qwen models have attn bias
+ if qwen_type == 'qwen3':
+     attn_bias = getattr(hf_config, "attention_bias", attn_bias)  # qwen3 dense attn_bias is false
+     head_size = getattr(hf_config, "head_dim", head_size)  # qwen3 dense config contains head_dim

rotary_scaling = getattr(hf_config, "rope_scaling", None)
seq_length = getattr(hf_config, "seq_length", 8192)
use_logn_attn = getattr(hf_config, "use_logn_attn", False)
@@ -187,4 +196,4 @@ def from_hugging_face(cls,
quantization=quant_config,
num_labels=num_labels,
tie_word_embeddings=tie_word_embeddings,
- **kwargs)
+ **kwargs)
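The Qwen3 branch added above reflects two config differences from earlier Qwen generations: dense Qwen3 checkpoints ship an explicit `head_dim` and set `attention_bias` to `False`. A sketch of the same probing logic against a HF config; the model id is illustrative and this assumes a transformers version that recognizes `qwen3`:

```python
# Sketch: derive attn_bias and head_size the way the patched config code does.
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("Qwen/Qwen3-8B")            # illustrative id
head_size = cfg.hidden_size // cfg.num_attention_heads        # legacy fallback
attn_bias = True                                              # Qwen1/Qwen2 use attention bias
if cfg.model_type == "qwen3":
    attn_bias = getattr(cfg, "attention_bias", attn_bias)     # False for dense Qwen3
    head_size = getattr(cfg, "head_dim", head_size)           # Qwen3 provides head_dim
print(attn_bias, head_size)
```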