Commit 1c2abf7

[fixed] fix pre-commit
1 parent 0c67427 commit 1c2abf7

File tree

3 files changed: +50, -52 lines changed

lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight.py

Lines changed: 25 additions & 26 deletions
@@ -30,7 +30,7 @@ def __init__(
         self.experts_up_projs = [None] * self.n_routed_experts
         self.experts_gate_projs = [None] * self.n_routed_experts
         self.expert_gate_up_proj_etp = None
-        self.expert_down_proj_etp = None
+        self.expert_down_proj_etp = None
         self.w2_list = [None] * self.n_routed_experts
         self.quant_method = None
         self.lock = threading.Lock()
@@ -39,7 +39,7 @@ def set_quant_method(self, quant_method):
         if isinstance(quant_method, vLLMFP8w8a8QuantizationMethod):
             self.quant_method = quant_method
             if self.quant_method is not None:
-                self.quant_method.is_moe = True
+                self.quant_method.is_moe = True

     def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group):

@@ -99,65 +99,64 @@ def _fuse(self):
             delattr(self, "experts_up_projs")
             delattr(self, "experts_gate_projs")

-
     def _load_hf_weights_etp(self, weights):
         world_size_ = get_world_size()
         assert self.n_routed_experts % world_size_ == 0
         n_expert_ep = self.n_routed_experts // world_size_

-        #tp to ep here
+        # tp to ep here
         expert_gate_up_proj_last = None
         expert_down_proj_last = None
-
+
         for i_experts_ep in range(n_expert_ep):
             expert_up_proj = None
             expert_gate_proj = None
             expert_gate_up_proj = None
             expert_down_proj = None
-            i_experts = i_experts_ep + n_expert_ep*self.tp_rank_
+            i_experts = i_experts_ep + n_expert_ep * self.tp_rank_

             if f"{self.weight_prefix}.{i_experts}.up_proj.weight" in weights:
                 expert_up_proj = weights[f"{self.weight_prefix}.{i_experts}.up_proj.weight"]
-
-                #self.experts_up_proj[i_experts] = expert_up_proj
+
+                # self.experts_up_proj[i_experts] = expert_up_proj

             if f"{self.weight_prefix}.{i_experts}.gate_proj.weight" in weights:
                 expert_gate_proj = weights[f"{self.weight_prefix}.{i_experts}.gate_proj.weight"]
-                #self.experts_gate_proj[i_experts] = expert_gate_proj
+                # self.experts_gate_proj[i_experts] = expert_gate_proj

             if expert_gate_proj is not None and expert_up_proj is not None:
                 expert_gate_up_proj = torch.cat([expert_gate_proj, expert_up_proj], dim=0)
-                self.experts_gate_projs[i_experts_ep] = expert_gate_up_proj #self._cuda(expert_gate_up_proj)
+                self.experts_gate_projs[i_experts_ep] = expert_gate_up_proj  # self._cuda(expert_gate_up_proj)
                 expert_gate_up_proj_last = expert_gate_up_proj
-
+
             if f"{self.weight_prefix}.{i_experts}.down_proj.weight" in weights:
                 expert_down_proj = weights[f"{self.weight_prefix}.{i_experts}.down_proj.weight"]
-                self.experts_up_projs[i_experts_ep] = expert_down_proj #self._cuda(expert_down_proj)
+                self.experts_up_projs[i_experts_ep] = expert_down_proj  # self._cuda(expert_down_proj)
                 expert_down_proj_last = expert_down_proj

         with self.lock:
             if expert_gate_up_proj_last is not None:
-                #package, if there is broken experts
+                # package, if there is broken experts
+
+                if self.expert_gate_up_proj_etp is None:
+                    self.expert_gate_up_proj_etp = torch.zeros(
+                        (n_expert_ep,) + expert_gate_up_proj_last.shape, dtype=expert_gate_up_proj_last.dtype
+                    ).cuda(self.tp_rank_)

-                if self.expert_gate_up_proj_etp is None:
-                    self.expert_gate_up_proj_etp = torch.zeros( (n_expert_ep,) + expert_gate_up_proj_last.shape,
-                                                                dtype=expert_gate_up_proj_last.dtype).cuda(self.tp_rank_)
-
                 for i_experts_ep in range(n_expert_ep):
                     if self.experts_gate_projs[i_experts_ep] is not None:
-                        self.expert_gate_up_proj_etp[i_experts_ep,:] = self.experts_gate_projs[i_experts_ep]
-
+                        self.expert_gate_up_proj_etp[i_experts_ep, :] = self.experts_gate_projs[i_experts_ep]

             if expert_down_proj_last is not None:
-                #package, if there is broken experts
-                if self.expert_down_proj_etp is None:
-                    self.expert_down_proj_etp = torch.zeros( (n_expert_ep,) + expert_down_proj_last.shape,
-                                                             dtype=expert_down_proj_last.dtype).cuda(self.tp_rank_)
-
+                # package, if there is broken experts
+                if self.expert_down_proj_etp is None:
+                    self.expert_down_proj_etp = torch.zeros(
+                        (n_expert_ep,) + expert_down_proj_last.shape, dtype=expert_down_proj_last.dtype
+                    ).cuda(self.tp_rank_)
+
                 for i_experts_ep in range(n_expert_ep):
                     if self.experts_up_projs[i_experts_ep] is not None:
-                        self.expert_down_proj_etp[i_experts_ep,:] = self.experts_up_projs[i_experts_ep]
-
+                        self.expert_down_proj_etp[i_experts_ep, :] = self.experts_up_projs[i_experts_ep]

     def load_hf_weights(self, weights):
         if os.environ.get("ETP_MODE_ENABLED") == "true":
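
Note: the `_load_hf_weights_etp` path above gives each rank a contiguous block of experts, with the global expert index computed as `i_experts = i_experts_ep + n_expert_ep * self.tp_rank_`, and concatenates each expert's gate and up projections along dim 0 before packing them into one `(n_expert_ep, ...)` tensor. The following standalone sketch only illustrates that index mapping and packing with toy dimensions and hypothetical names (`weights`, `rank`, `packed`); it is not the repository code.

import torch

# Toy dimensions; real values come from the model config.
n_routed_experts, world_size, rank = 8, 2, 1
hidden, inter = 16, 32
assert n_routed_experts % world_size == 0
n_expert_ep = n_routed_experts // world_size  # experts owned by this rank

# Fake per-expert weights keyed like the HF checkpoint entries.
weights = {f"experts.{i}.gate_proj.weight": torch.randn(inter, hidden) for i in range(n_routed_experts)}
weights.update({f"experts.{i}.up_proj.weight": torch.randn(inter, hidden) for i in range(n_routed_experts)})

packed = None
for i_experts_ep in range(n_expert_ep):
    # Global expert index owned by this rank: a contiguous block per rank.
    i_experts = i_experts_ep + n_expert_ep * rank
    gate = weights[f"experts.{i_experts}.gate_proj.weight"]
    up = weights[f"experts.{i_experts}.up_proj.weight"]
    gate_up = torch.cat([gate, up], dim=0)  # (2 * inter, hidden), as in the diff
    if packed is None:
        packed = torch.zeros((n_expert_ep,) + gate_up.shape, dtype=gate_up.dtype)
    packed[i_experts_ep, :] = gate_up

print(packed.shape)  # torch.Size([4, 64, 16]); rank 1 holds experts 4..7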

lightllm/common/deepseek2_mem_manager.py

Lines changed: 4 additions & 4 deletions
@@ -11,12 +11,12 @@ def get_cell_size(self):

     def _init_buffers(self, size, dtype, head_num, head_dim, layer_num):
         self.kv_buffer = torch.empty((layer_num, size, head_num, head_dim), dtype=dtype, device="cuda")
-        #todo, etp or edp use the same work buffer here
-        #also it can be used for any kernels for work buffer witout save info only
+        # todo, etp or edp use the same work buffer here
+        # also it can be used for any kernels for work buffer witout save info only
         if os.environ.get("ETP_MODE_ENABLED") == "true":
-            self.work_buffer = torch.empty(1024*1024*1024,dtype=torch.bfloat16, device="cuda")
+            self.work_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.bfloat16, device="cuda")
             self.work_buffer.share_memory_()
-
+
     def alloc_kv_move_buffer(self, max_req_total_len):
         self.kv_move_buffer = torch.empty(
             (1, max_req_total_len + 8, self.head_num, self.head_dim), dtype=self.dtype, device="cuda"
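
Note: `torch.empty(1024 * 1024 * 1024, dtype=torch.bfloat16)` above reserves 2**30 bfloat16 elements, roughly 2 GiB of device memory, that downstream kernels can reuse as scratch space. The sketch below is only an assumed illustration of how typed views could be carved out of such a flat workspace; the `carve` helper is hypothetical and not part of the repository.

import torch

def carve(work_buffer: torch.Tensor, shapes):
    """Slice non-overlapping scratch views out of one flat 1-D buffer."""
    views, offset = [], 0
    for shape in shapes:
        n = 1
        for d in shape:
            n *= d
        assert offset + n <= work_buffer.nelement(), "workspace too small"
        views.append(work_buffer[offset : offset + n].view(shape))
        offset += n
    return views

# Small CPU stand-in for the 2**30-element CUDA work buffer in the diff.
buf = torch.empty(1 << 20, dtype=torch.bfloat16)
logits_scratch, hidden_scratch = carve(buf, [(128, 64), (128, 512)])
print(logits_scratch.shape, hidden_scratch.shape)
print("buffer bytes:", buf.nelement() * buf.element_size())  # 2 bytes per bfloat16 element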

lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py

Lines changed: 21 additions & 22 deletions
@@ -403,42 +403,42 @@ def _moe_ffn_etp(
         self, input, infer_state: LlamaInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight
     ) -> torch.Tensor:
         world_size_ = self.world_size_
-        num_local_experts = self.n_shared_experts // world_size_
-        local_expert_offset = self.tp_rank_ * num_local_experts
+        # num_local_experts = self.n_shared_experts // world_size_
+        # local_expert_offset = self.tp_rank_ * num_local_experts
         num_experts_per_token = self.num_experts_per_tok
         num_experts = self.n_routed_experts
-        num_expert_groups = self.n_group
-        num_groups_per_token = self.topk_group
+        # num_expert_groups = self.n_group
+        # num_groups_per_token = self.topk_group
         gating_scaling_factor = self.routed_scaling_factor
-        gating_normalize_prob = self.norm_topk_prob
+        # gating_normalize_prob = self.norm_topk_prob
         rank_self = self.tp_rank_

         hidden_states = input.view(-1, self.embed_dim_)
         num_tokens, hidden_dim = hidden_states.shape

-        final_hidden_states = torch.empty(num_tokens,hidden_dim,device=hidden_states.device,
-                                          dtype = hidden_states.dtype )
+        final_hidden_states = torch.empty(
+            num_tokens, hidden_dim, device=hidden_states.device, dtype=hidden_states.dtype
+        )

-        #router_logits_len = hidden_states.shape[0]*layer_weight.moe_gate.shape[1]
+        # router_logits_len = hidden_states.shape[0]*layer_weight.moe_gate.shape[1]
         router_logits = layer_weight.moe_gate.mm(hidden_states)

-        #now some parameter is not supported yet
-        #assert gating_normalize_prob is False
-        #assert num_expert_groups<=1
-
+        # now some parameter is not supported yet
+        # assert gating_normalize_prob is False
+        # assert num_expert_groups<=1

-
         import lightllm_moe_etp_kernel
+
         lightllm_moe_etp_kernel.moe_fused_all(
             router_logits.contiguous(),
             hidden_states.contiguous(),
-            layer_weight.gate_up_proj.weight.contiguous(), #transpose
-            layer_weight.down_proj.weight.contiguous(), #transpose
+            layer_weight.gate_up_proj.weight.contiguous(),  # transpose
+            layer_weight.down_proj.weight.contiguous(),  # transpose
             layer_weight.experts.expert_gate_up_proj_etp.contiguous(),
-            layer_weight.experts.expert_down_proj_etp.contiguous(),
-            infer_state.mem_manager.work_buffer.contiguous(),
+            layer_weight.experts.expert_down_proj_etp.contiguous(),
+            infer_state.mem_manager.work_buffer.contiguous(),
             infer_state.mem_manager.work_buffer.nelement(),
-            final_hidden_states.contiguous(),
+            final_hidden_states.contiguous(),
             rank_self,
             gating_scaling_factor,
             num_experts,
@@ -447,12 +447,11 @@ def _moe_ffn_etp(
             world_size_,
             True,
             hidden_dim,
-            layer_weight.gate_up_proj.weight.size(1)//2,
-            layer_weight.experts.expert_gate_up_proj_etp.size(1)//2,
-            self.n_shared_experts is not None
+            layer_weight.gate_up_proj.weight.size(1) // 2,
+            layer_weight.experts.expert_gate_up_proj_etp.size(1) // 2,
+            self.n_shared_experts is not None,
         )

         router_logits = None

         return final_hidden_states.view(num_tokens, hidden_dim)
-
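
Note: the call above goes through the custom `lightllm_moe_etp_kernel` extension, so its exact behavior is not visible in this diff. As a rough reference only, the per-expert math that such a fused MoE kernel combines with routing is a gate/up/down (SwiGLU) FFN, where the `size(...) // 2` arguments correspond to splitting the concatenated gate+up weight in half. The plain-PyTorch sketch below assumes that weight layout and toy shapes; it is an illustration, not the kernel.

import torch
import torch.nn.functional as F

hidden_dim, inter_dim, num_tokens = 16, 32, 4
hidden_states = torch.randn(num_tokens, hidden_dim)

# Fused gate+up weight: gate rows stacked on top of up rows, as in the loader above.
gate_up = torch.randn(2 * inter_dim, hidden_dim)
down = torch.randn(hidden_dim, inter_dim)

gate_w, up_w = gate_up.split(gate_up.size(0) // 2, dim=0)  # the "// 2" split
activated = F.silu(hidden_states @ gate_w.t()) * (hidden_states @ up_w.t())
out = activated @ down.t()
print(out.shape)  # torch.Size([4, 16])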
