|
4 | 4 | from rtp_llm.model_factory_register import register_model |
5 | 5 | from rtp_llm.model_loader.model_weight_info import ModelWeightInfo |
6 | 6 | from rtp_llm.model_loader.weight_module import AtomicWeight, WeightModule |
7 | | -from rtp_llm.models.qwen3_next.qwen3_next import Qwen3Next |
| 7 | +from rtp_llm.models.qwen3_next.qwen3_next import Qwen3Next, Qwen35Moe |
8 | 8 | from rtp_llm.models.qwen3_next.qwen3_next_weight import Qwen3NextWeight, plus_one |
9 | 9 | from rtp_llm.ops import HybridAttentionType |
10 | 10 | from rtp_llm.utils.model_weight import CkptWeightInfo, W, identity, transpose |
@@ -111,8 +111,39 @@ def get_weight_cls(): |
111 | 111 | return Qwen3NextMTPWeight |
112 | 112 |
|
113 | 113 |
|
114 | | -class Qwen35MoeMTP(Qwen3NextMTP): |
| 114 | +class Qwen35MoeMTP(Qwen35Moe): |
115 | 115 | @classmethod |
| 116 | + def _create_config(cls, ckpt_path: str) -> ModelConfig: |
| 117 | + config = super()._create_config(ckpt_path) |
| 118 | + # mtp model attention is mqa, not linear |
| 119 | + config.hybrid_attention_config.hybrid_attention_types = [ |
| 120 | + HybridAttentionType.NONE |
| 121 | + ] |
| 122 | + config.moe_layer_index = [0] |
| 123 | + config.num_layers = 1 |
| 124 | + config.is_mtp = True |
| 125 | + return config |
| 126 | + |
| 127 | + def _create_python_model(self) -> Optional[Any]: |
| 128 | + from rtp_llm.models_py.model_desc.qwen3_next_mtp import Qwen3NextMTPModel |
| 129 | + |
| 130 | + model_config = self.model_config |
| 131 | + parallelism_config = self.parallelism_config |
| 132 | + fmha_config = self.fmha_config |
| 133 | + py_hw_kernel_config = self.hw_kernel_config |
| 134 | + moe_config = self.moe_config |
| 135 | + self.py_model = Qwen3NextMTPModel( |
| 136 | + model_config, |
| 137 | + parallelism_config, |
| 138 | + self.weight, |
| 139 | + max_generate_batch_size=self.max_generate_batch_size, |
| 140 | + moe_config=moe_config, |
| 141 | + fmha_config=fmha_config, |
| 142 | + py_hw_kernel_config=py_hw_kernel_config, |
| 143 | + device_resource_config=self.device_resource_config, |
| 144 | + ) |
| 145 | + |
| 146 | + @staticmethod |
116 | 147 | def get_weight_cls(): |
117 | 148 | return Qwen35MoeMTPWeight |
118 | 149 |
|
|
0 commit comments