 from lightllm.utils.dist_utils import get_world_size, get_rank
 import threading
 from lightllm.common.quantization import vLLMFP8w8a8QuantizationMethod
+import os
 
 try:
     HAS_VLLM = True
@@ -28,6 +29,8 @@ def __init__(
         self.tp_rank_ = get_rank()
         self.experts_up_projs = [None] * self.n_routed_experts
         self.experts_gate_projs = [None] * self.n_routed_experts
+        self.expert_gate_up_proj_etp = None
+        self.expert_down_proj_etp = None
         self.w2_list = [None] * self.n_routed_experts
         self.quant_method = None
         self.lock = threading.Lock()
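
The two new `*_etp` buffers are allocated lazily in `_load_hf_weights_etp` below, each packing this rank's expert shards into one contiguous tensor. A minimal shape sketch with hypothetical sizes (the real `inter_size`/`hidden_size` come from the model config):

```python
import torch

# hypothetical sizes, for illustration only
n_expert_ep, inter_size, hidden_size = 2, 1408, 2048

gate_proj = torch.empty(inter_size, hidden_size)
up_proj = torch.empty(inter_size, hidden_size)

# gate and up are concatenated along dim 0, so the packed buffer holds
# one (2 * inter_size, hidden_size) slab per locally owned expert:
fused = torch.cat([gate_proj, up_proj], dim=0)
assert fused.shape == (2 * inter_size, hidden_size)
# expert_gate_up_proj_etp then has shape (n_expert_ep, 2 * inter_size, hidden_size)
```
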
@@ -36,9 +39,10 @@ def set_quant_method(self, quant_method):
         if isinstance(quant_method, vLLMFP8w8a8QuantizationMethod):
             self.quant_method = quant_method
             if self.quant_method is not None:
-                self.quant_method.is_moe = True
+                self.quant_method.is_moe = True
 
     def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group):
+
         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=input_tensor,
             router_logits=router_logits,
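
`FusedMoE.select_experts` comes from vLLM; roughly, it scores every expert per token and keeps the top-k. A minimal pure-PyTorch sketch of the plain (non-grouped) softmax top-k path, not vLLM's actual implementation:

```python
import torch

def select_experts_sketch(router_logits: torch.Tensor, top_k: int, renormalize: bool):
    # (num_tokens, num_experts) logits -> per-token routing probabilities
    probs = torch.softmax(router_logits, dim=-1)
    # keep the k highest-scoring experts for each token
    topk_weights, topk_ids = torch.topk(probs, top_k, dim=-1)
    if renormalize:
        # rescale the kept weights so they sum to 1 per token
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids
```
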
@@ -95,27 +99,90 @@ def _fuse(self):
             delattr(self, "experts_up_projs")
             delattr(self, "experts_gate_projs")
 
+
+    def _load_hf_weights_etp(self, weights):
+        world_size_ = get_world_size()
+        assert self.n_routed_experts % world_size_ == 0
+        n_expert_ep = self.n_routed_experts // world_size_
+
+        # TP to EP conversion: this rank loads only its own contiguous block of experts
+        expert_gate_up_proj_last = None
+        expert_down_proj_last = None
+
+        for i_experts_ep in range(n_expert_ep):
+            expert_up_proj = None
+            expert_gate_proj = None
+            expert_gate_up_proj = None
+            expert_down_proj = None
+            i_experts = i_experts_ep + n_expert_ep * self.tp_rank_
+
+            if f"{self.weight_prefix}.{i_experts}.up_proj.weight" in weights:
+                expert_up_proj = weights[f"{self.weight_prefix}.{i_experts}.up_proj.weight"]
+
+            if f"{self.weight_prefix}.{i_experts}.gate_proj.weight" in weights:
+                expert_gate_proj = weights[f"{self.weight_prefix}.{i_experts}.gate_proj.weight"]
+
+            if expert_gate_proj is not None and expert_up_proj is not None:
+                expert_gate_up_proj = torch.cat([expert_gate_proj, expert_up_proj], dim=0)
+                # experts_gate_projs doubles as per-expert scratch storage for the fused gate_up shards
+                self.experts_gate_projs[i_experts_ep] = expert_gate_up_proj
+                expert_gate_up_proj_last = expert_gate_up_proj
+
+            if f"{self.weight_prefix}.{i_experts}.down_proj.weight" in weights:
+                expert_down_proj = weights[f"{self.weight_prefix}.{i_experts}.down_proj.weight"]
+                # experts_up_projs doubles as per-expert scratch storage for the down_proj shards
+                self.experts_up_projs[i_experts_ep] = expert_down_proj
+                expert_down_proj_last = expert_down_proj
+
+        with self.lock:
+            if expert_gate_up_proj_last is not None:
+                # pack the shards into one tensor, even if some experts have not arrived yet
+                if self.expert_gate_up_proj_etp is None:
+                    self.expert_gate_up_proj_etp = torch.zeros(
+                        (n_expert_ep,) + expert_gate_up_proj_last.shape,
+                        dtype=expert_gate_up_proj_last.dtype,
+                    ).cuda(self.tp_rank_)
+
+                for i_experts_ep in range(n_expert_ep):
+                    if self.experts_gate_projs[i_experts_ep] is not None:
+                        self.expert_gate_up_proj_etp[i_experts_ep, :] = self.experts_gate_projs[i_experts_ep]
+
+            if expert_down_proj_last is not None:
+                # pack the shards into one tensor, even if some experts have not arrived yet
+                if self.expert_down_proj_etp is None:
+                    self.expert_down_proj_etp = torch.zeros(
+                        (n_expert_ep,) + expert_down_proj_last.shape,
+                        dtype=expert_down_proj_last.dtype,
+                    ).cuda(self.tp_rank_)
+
+                for i_experts_ep in range(n_expert_ep):
+                    if self.experts_up_projs[i_experts_ep] is not None:
+                        self.expert_down_proj_etp[i_experts_ep, :] = self.experts_up_projs[i_experts_ep]
+
+
     def load_hf_weights(self, weights):
-        for i_experts in range(self.n_routed_experts):
-            w1_weight = f"{self.weight_prefix}.{i_experts}.{self.w1_weight_name}.weight"
-            w2_weight = f"{self.weight_prefix}.{i_experts}.{self.w2_weight_name}.weight"
-            w3_weight = f"{self.weight_prefix}.{i_experts}.{self.w3_weight_name}.weight"
-
-            if w1_weight in weights:
-                self.experts_gate_projs[i_experts] = weights[w1_weight][
-                    self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), :
-                ]
-            if w3_weight in weights:
-                self.experts_up_projs[i_experts] = weights[w3_weight][
-                    self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), :
-                ]
-
-            if w2_weight in weights:
-                self.w2_list[i_experts] = weights[w2_weight][
-                    :, self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1)
-                ]
-
-        self._fuse()
+        if os.environ.get("ETP_MODE_ENABLED") == "true":
+            self._load_hf_weights_etp(weights)
+        else:
+            for i_experts in range(self.n_routed_experts):
+                w1_weight = f"{self.weight_prefix}.{i_experts}.{self.w1_weight_name}.weight"
+                w2_weight = f"{self.weight_prefix}.{i_experts}.{self.w2_weight_name}.weight"
+                w3_weight = f"{self.weight_prefix}.{i_experts}.{self.w3_weight_name}.weight"
+
+                if w1_weight in weights:
+                    self.experts_gate_projs[i_experts] = weights[w1_weight][
+                        self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), :
+                    ]
+                if w3_weight in weights:
+                    self.experts_up_projs[i_experts] = weights[w3_weight][
+                        self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), :
+                    ]
+
+                if w2_weight in weights:
+                    self.w2_list[i_experts] = weights[w2_weight][
+                        :, self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1)
+                    ]
+
+            self._fuse()
 
     def _cuda(self, cpu_tensor):
         if self.tp_rank_ is None:
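
The ownership rule in `_load_hf_weights_etp` is a plain block partition: rank `r` owns global experts `[r * n_expert_ep, (r + 1) * n_expert_ep)`, via `i_experts = i_experts_ep + n_expert_ep * self.tp_rank_`. A quick worked example with hypothetical sizes:

```python
n_routed_experts, world_size = 8, 4
n_expert_ep = n_routed_experts // world_size  # 2 local experts per rank

for tp_rank in range(world_size):
    owned = [i_ep + n_expert_ep * tp_rank for i_ep in range(n_expert_ep)]
    print(tp_rank, owned)
# 0 [0, 1]
# 1 [2, 3]
# 2 [4, 5]
# 3 [6, 7]
```
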
@@ -124,4 +191,7 @@ def _cuda(self, cpu_tensor):
         return cpu_tensor.contiguous().to(self.data_type_).cuda(self.tp_rank_)
 
     def verify_load(self):
-        return self.w1 is not None and self.w2 is not None
+        if os.environ.get("ETP_MODE_ENABLED") == "true":
+            return True
+        else:
+            return self.w1 is not None and self.w2 is not None
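
Both `load_hf_weights` and `verify_load` branch on the `ETP_MODE_ENABLED` environment variable, compared against the literal lowercase string `"true"`. Assuming it is read at weight-loading time, opting a process into ETP mode would look like:

```python
import os

# must be set in each worker process before weights are loaded;
# any value other than the exact string "true" leaves the TP path active
os.environ["ETP_MODE_ENABLED"] = "true"
```

Note that in ETP mode `verify_load` returns `True` unconditionally, so missing expert weights are not caught at load time.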