Skip to content

Commit f66e47d

Browse files
committed
fix - fix qwen35 moe mtp
1 parent 76480d9 commit f66e47d

File tree

1 file changed

+33
-2
lines changed

1 file changed

+33
-2
lines changed

rtp_llm/models/qwen3_next/qwen3_next_mtp.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from rtp_llm.model_factory_register import register_model
55
from rtp_llm.model_loader.model_weight_info import ModelWeightInfo
66
from rtp_llm.model_loader.weight_module import AtomicWeight, WeightModule
7-
from rtp_llm.models.qwen3_next.qwen3_next import Qwen3Next
7+
from rtp_llm.models.qwen3_next.qwen3_next import Qwen3Next, Qwen35Moe
88
from rtp_llm.models.qwen3_next.qwen3_next_weight import Qwen3NextWeight, plus_one
99
from rtp_llm.ops import HybridAttentionType
1010
from rtp_llm.utils.model_weight import CkptWeightInfo, W, identity, transpose
@@ -111,8 +111,39 @@ def get_weight_cls():
111111
return Qwen3NextMTPWeight
112112

113113

114-
class Qwen35MoeMTP(Qwen3NextMTP):
class Qwen35MoeMTP(Qwen35Moe):
    """MTP variant of ``Qwen35Moe``.

    Overrides config creation, Python-model construction and the weight
    class used for loading.
    """

    @classmethod
    def _create_config(cls, ckpt_path: str) -> ModelConfig:
        """Return the parent config for ``ckpt_path`` adjusted for MTP.

        The adjusted config has a single layer (``num_layers = 1``) whose
        hybrid attention type is ``HybridAttentionType.NONE``, lists layer 0
        in ``moe_layer_index`` and sets ``is_mtp``.
        """
        mtp_config = super()._create_config(ckpt_path)
        # The MTP model's attention is MQA, not linear attention.
        hybrid_config = mtp_config.hybrid_attention_config
        hybrid_config.hybrid_attention_types = [HybridAttentionType.NONE]
        mtp_config.moe_layer_index = [0]
        mtp_config.num_layers = 1
        mtp_config.is_mtp = True
        return mtp_config

    def _create_python_model(self) -> Optional[Any]:
        """Build a ``Qwen3NextMTPModel`` and store it on ``self.py_model``.

        The model class is imported inside the method. Nothing is returned
        explicitly, so the call evaluates to ``None``.
        """
        from rtp_llm.models_py.model_desc.qwen3_next_mtp import Qwen3NextMTPModel

        # Read the configs up front, in the same order as before the rewrite.
        model_cfg, parallel_cfg, attention_cfg, kernel_cfg, experts_cfg = (
            self.model_config,
            self.parallelism_config,
            self.fmha_config,
            self.hw_kernel_config,
            self.moe_config,
        )
        mtp_model = Qwen3NextMTPModel(
            model_cfg,
            parallel_cfg,
            self.weight,
            max_generate_batch_size=self.max_generate_batch_size,
            moe_config=experts_cfg,
            fmha_config=attention_cfg,
            py_hw_kernel_config=kernel_cfg,
            device_resource_config=self.device_resource_config,
        )
        self.py_model = mtp_model

    @staticmethod
    def get_weight_cls():
        """Return the weight class for this model (``Qwen35MoeMTPWeight``)."""
        return Qwen35MoeMTPWeight
118149

0 commit comments

Comments
 (0)