@@ -30,7 +30,7 @@ def __init__(
         self.experts_up_projs = [None] * self.n_routed_experts
         self.experts_gate_projs = [None] * self.n_routed_experts
         self.expert_gate_up_proj_etp = None
-        self.expert_down_proj_etp = None
+        self.expert_down_proj_etp = None
         self.w2_list = [None] * self.n_routed_experts
         self.quant_method = None
         self.lock = threading.Lock()
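These fields keep one slot per routed expert plus lazily built fused ETP buffers, and the threading.Lock() serializes the packaging step when several loader threads deliver shards. A minimal self-contained sketch of that pattern (class and method names here are illustrative, not the repo's API):

import threading

import torch


class ExpertSlots:
    # Illustrative only: per-expert slots filled by loader threads,
    # with a lock guarding the lazily built fused buffer.
    def __init__(self, n_experts):
        self.slots = [None] * n_experts   # one slot per routed expert
        self.fused = None                 # packed tensor, built lazily
        self.lock = threading.Lock()

    def put(self, idx, weight):
        self.slots[idx] = weight          # distinct indices: no data race
        with self.lock:                   # packaging must be serialized
            if self.fused is None and all(s is not None for s in self.slots):
                self.fused = torch.stack(self.slots, dim=0)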
@@ -39,7 +39,7 @@ def set_quant_method(self, quant_method):
         if isinstance(quant_method, vLLMFP8w8a8QuantizationMethod):
             self.quant_method = quant_method
             if self.quant_method is not None:
-                self.quant_method.is_moe = True
+                self.quant_method.is_moe = True

    def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group):

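The hunk above flags the quantization method as MoE-aware only on the vLLM FP8 w8a8 path; any other quant method leaves self.quant_method as None. A small stand-alone sketch of that isinstance-gated flagging (both classes are stand-ins for illustration):

class QuantizationMethod:  # stand-in base class
    is_moe = False


class vLLMFP8w8a8QuantizationMethod(QuantizationMethod):  # stand-in
    pass


class MoELayer:
    quant_method = None

    def set_quant_method(self, quant_method):
        # only the FP8 w8a8 path gets the MoE flag; others are ignored
        if isinstance(quant_method, vLLMFP8w8a8QuantizationMethod):
            self.quant_method = quant_method
            if self.quant_method is not None:
                self.quant_method.is_moe = True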
@@ -99,65 +99,64 @@ def _fuse(self):
             delattr(self, "experts_up_projs")
             delattr(self, "experts_gate_projs")

-
     def _load_hf_weights_etp(self, weights):
         world_size_ = get_world_size()
         assert self.n_routed_experts % world_size_ == 0
         n_expert_ep = self.n_routed_experts // world_size_

-        #tp to ep here
+        # tp to ep here
         expert_gate_up_proj_last = None
         expert_down_proj_last = None
-
+
         for i_experts_ep in range(n_expert_ep):
             expert_up_proj = None
             expert_gate_proj = None
             expert_gate_up_proj = None
             expert_down_proj = None
-            i_experts = i_experts_ep + n_expert_ep * self.tp_rank_
+            i_experts = i_experts_ep + n_expert_ep * self.tp_rank_

             if f"{self.weight_prefix}.{i_experts}.up_proj.weight" in weights:
                 expert_up_proj = weights[f"{self.weight_prefix}.{i_experts}.up_proj.weight"]
-
-                #self.experts_up_proj[i_experts] = expert_up_proj
+
+                # self.experts_up_proj[i_experts] = expert_up_proj

             if f"{self.weight_prefix}.{i_experts}.gate_proj.weight" in weights:
                 expert_gate_proj = weights[f"{self.weight_prefix}.{i_experts}.gate_proj.weight"]
-                #self.experts_gate_proj[i_experts] = expert_gate_proj
+                # self.experts_gate_proj[i_experts] = expert_gate_proj

             if expert_gate_proj is not None and expert_up_proj is not None:
                 expert_gate_up_proj = torch.cat([expert_gate_proj, expert_up_proj], dim=0)
-                self.experts_gate_projs[i_experts_ep] = expert_gate_up_proj  #self._cuda(expert_gate_up_proj)
+                self.experts_gate_projs[i_experts_ep] = expert_gate_up_proj  # self._cuda(expert_gate_up_proj)
                 expert_gate_up_proj_last = expert_gate_up_proj
-
+
             if f"{self.weight_prefix}.{i_experts}.down_proj.weight" in weights:
                 expert_down_proj = weights[f"{self.weight_prefix}.{i_experts}.down_proj.weight"]
-                self.experts_up_projs[i_experts_ep] = expert_down_proj  #self._cuda(expert_down_proj)
+                self.experts_up_projs[i_experts_ep] = expert_down_proj  # self._cuda(expert_down_proj)
                 expert_down_proj_last = expert_down_proj

         with self.lock:
             if expert_gate_up_proj_last is not None:
-                #package, if there is broken experts
+                # package, if there is broken experts
+
+                if self.expert_gate_up_proj_etp is None:
+                    self.expert_gate_up_proj_etp = torch.zeros(
+                        (n_expert_ep,) + expert_gate_up_proj_last.shape, dtype=expert_gate_up_proj_last.dtype
+                    ).cuda(self.tp_rank_)

-                if self.expert_gate_up_proj_etp is None:
-                    self.expert_gate_up_proj_etp = torch.zeros((n_expert_ep,) + expert_gate_up_proj_last.shape,
-                                                               dtype=expert_gate_up_proj_last.dtype).cuda(self.tp_rank_)
-
                 for i_experts_ep in range(n_expert_ep):
                     if self.experts_gate_projs[i_experts_ep] is not None:
-                        self.expert_gate_up_proj_etp[i_experts_ep,:] = self.experts_gate_projs[i_experts_ep]
-
+                        self.expert_gate_up_proj_etp[i_experts_ep, :] = self.experts_gate_projs[i_experts_ep]

             if expert_down_proj_last is not None:
-                #package, if there is broken experts
-                if self.expert_down_proj_etp is None:
-                    self.expert_down_proj_etp = torch.zeros((n_expert_ep,) + expert_down_proj_last.shape,
-                                                            dtype=expert_down_proj_last.dtype).cuda(self.tp_rank_)
-
+                # package, if there is broken experts
+                if self.expert_down_proj_etp is None:
+                    self.expert_down_proj_etp = torch.zeros(
+                        (n_expert_ep,) + expert_down_proj_last.shape, dtype=expert_down_proj_last.dtype
+                    ).cuda(self.tp_rank_)
+
                 for i_experts_ep in range(n_expert_ep):
                     if self.experts_up_projs[i_experts_ep] is not None:
-                        self.expert_down_proj_etp[i_experts_ep,:] = self.experts_up_projs[i_experts_ep]
-
+                        self.expert_down_proj_etp[i_experts_ep, :] = self.experts_up_projs[i_experts_ep]

     def load_hf_weights(self, weights):
         if os.environ.get("ETP_MODE_ENABLED") == "true":
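_load_hf_weights_etp remaps a tensor-parallel launch onto expert parallelism: each rank owns a contiguous block of n_expert_ep experts, and each expert's gate and up projections are concatenated along the output dimension so one GEMM later produces both halves. For example, with n_routed_experts = 8 and world_size_ = 4, rank 2 loads global experts 4 and 5. A sketch with illustrative shapes (hidden size 16, intermediate size 32; variable names are for illustration only):

import torch

n_routed_experts, world_size, tp_rank = 8, 4, 2
assert n_routed_experts % world_size == 0
n_expert_ep = n_routed_experts // world_size          # 2 experts per rank

for i_experts_ep in range(n_expert_ep):
    # local slot -> global expert id; rank 2 owns experts 4 and 5
    i_experts = i_experts_ep + n_expert_ep * tp_rank
    print(i_experts)                                  # prints 4, then 5

# fuse gate and up along dim 0, as the loop above does with torch.cat
w_gate = torch.randn(32, 16)
w_up = torch.randn(32, 16)
w_gate_up = torch.cat([w_gate, w_up], dim=0)          # shape (64, 16)
gate_half, up_half = w_gate_up.split(32, dim=0)       # halves recoverable at use time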
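The "package, if there is broken experts" step preallocates a zero-initialized buffer of shape (n_expert_ep,) + expert_shape in the expert's dtype (placed on the rank's GPU via .cuda(self.tp_rank_) in the diff) and copies in whichever experts this shard actually received, so any expert whose weights are missing simply stays zero. A CPU-only sketch of the same packaging, with hypothetical names:

import torch

n_expert_ep = 2
slots = [torch.ones(4, 3), None]                    # second expert missing

template = next(s for s in slots if s is not None)  # shape/dtype template
packed = torch.zeros((n_expert_ep,) + template.shape, dtype=template.dtype)

for i, s in enumerate(slots):
    if s is not None:
        packed[i, :] = s                            # missing experts remain zero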