Commit 228e3e7

[feature]moe etp done, without group greed

1 parent 413f330 · commit 228e3e7

3 files changed: +166 −25 lines changed
lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight.py

Lines changed: 92 additions & 22 deletions
@@ -3,6 +3,7 @@
 from lightllm.utils.dist_utils import get_world_size, get_rank
 import threading
 from lightllm.common.quantization import vLLMFP8w8a8QuantizationMethod
+import os
 
 try:
     HAS_VLLM = True
@@ -28,6 +29,8 @@ def __init__(
         self.tp_rank_ = get_rank()
         self.experts_up_projs = [None] * self.n_routed_experts
         self.experts_gate_projs = [None] * self.n_routed_experts
+        self.expert_gate_up_proj_etp = None
+        self.expert_down_proj_etp = None
         self.w2_list = [None] * self.n_routed_experts
         self.quant_method = None
         self.lock = threading.Lock()
@@ -36,9 +39,10 @@ def set_quant_method(self, quant_method):
         if isinstance(quant_method, vLLMFP8w8a8QuantizationMethod):
             self.quant_method = quant_method
             if self.quant_method is not None:
-                self.quant_method.is_moe = True
+                self.quant_method.is_moe = True
 
     def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group):
+
         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=input_tensor,
             router_logits=router_logits,
@@ -95,27 +99,90 @@ def _fuse(self):
                 delattr(self, "experts_up_projs")
                 delattr(self, "experts_gate_projs")
 
+    def _load_hf_weights_etp(self, weights):
+        world_size_ = get_world_size()
+        assert self.n_routed_experts % world_size_ == 0
+        n_expert_ep = self.n_routed_experts // world_size_
+
+        # tp to ep: each rank loads a contiguous block of whole experts
+        expert_gate_up_proj_last = None
+        expert_down_proj_last = None
+
+        for i_experts_ep in range(n_expert_ep):
+            expert_up_proj = None
+            expert_gate_proj = None
+            expert_gate_up_proj = None
+            expert_down_proj = None
+            i_experts = i_experts_ep + n_expert_ep * self.tp_rank_
+
+            if f"{self.weight_prefix}.{i_experts}.up_proj.weight" in weights:
+                expert_up_proj = weights[f"{self.weight_prefix}.{i_experts}.up_proj.weight"]
+                # self.experts_up_proj[i_experts] = expert_up_proj
+
+            if f"{self.weight_prefix}.{i_experts}.gate_proj.weight" in weights:
+                expert_gate_proj = weights[f"{self.weight_prefix}.{i_experts}.gate_proj.weight"]
+                # self.experts_gate_proj[i_experts] = expert_gate_proj
+
+            if expert_gate_proj is not None and expert_up_proj is not None:
+                expert_gate_up_proj = torch.cat([expert_gate_proj, expert_up_proj], dim=0)
+                self.experts_gate_projs[i_experts_ep] = expert_gate_up_proj  # self._cuda(expert_gate_up_proj)
+                expert_gate_up_proj_last = expert_gate_up_proj
+
+            if f"{self.weight_prefix}.{i_experts}.down_proj.weight" in weights:
+                expert_down_proj = weights[f"{self.weight_prefix}.{i_experts}.down_proj.weight"]
+                self.experts_up_projs[i_experts_ep] = expert_down_proj  # self._cuda(expert_down_proj)
+                expert_down_proj_last = expert_down_proj
+
+        with self.lock:
+            if expert_gate_up_proj_last is not None:
+                # pack into one tensor, even if some experts are missing from this weight shard
+                if self.expert_gate_up_proj_etp is None:
+                    self.expert_gate_up_proj_etp = torch.zeros(
+                        (n_expert_ep,) + expert_gate_up_proj_last.shape, dtype=expert_gate_up_proj_last.dtype
+                    ).cuda(self.tp_rank_)
+
+                for i_experts_ep in range(n_expert_ep):
+                    if self.experts_gate_projs[i_experts_ep] is not None:
+                        self.expert_gate_up_proj_etp[i_experts_ep, :] = self.experts_gate_projs[i_experts_ep]
+
+            if expert_down_proj_last is not None:
+                # pack into one tensor, even if some experts are missing from this weight shard
+                if self.expert_down_proj_etp is None:
+                    self.expert_down_proj_etp = torch.zeros(
+                        (n_expert_ep,) + expert_down_proj_last.shape, dtype=expert_down_proj_last.dtype
+                    ).cuda(self.tp_rank_)
+
+                for i_experts_ep in range(n_expert_ep):
+                    if self.experts_up_projs[i_experts_ep] is not None:
+                        self.expert_down_proj_etp[i_experts_ep, :] = self.experts_up_projs[i_experts_ep]
+
     def load_hf_weights(self, weights):
-        for i_experts in range(self.n_routed_experts):
-            w1_weight = f"{self.weight_prefix}.{i_experts}.{self.w1_weight_name}.weight"
-            w2_weight = f"{self.weight_prefix}.{i_experts}.{self.w2_weight_name}.weight"
-            w3_weight = f"{self.weight_prefix}.{i_experts}.{self.w3_weight_name}.weight"
-
-            if w1_weight in weights:
-                self.experts_gate_projs[i_experts] = weights[w1_weight][
-                    self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), :
-                ]
-            if w3_weight in weights:
-                self.experts_up_projs[i_experts] = weights[w3_weight][
-                    self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), :
-                ]
-
-            if w2_weight in weights:
-                self.w2_list[i_experts] = weights[w2_weight][
-                    :, self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1)
-                ]
-
-        self._fuse()
+        if os.environ.get("ETP_MODE_ENABLED") == "true":
+            self._load_hf_weights_etp(weights)
+        else:
+            for i_experts in range(self.n_routed_experts):
+                w1_weight = f"{self.weight_prefix}.{i_experts}.{self.w1_weight_name}.weight"
+                w2_weight = f"{self.weight_prefix}.{i_experts}.{self.w2_weight_name}.weight"
+                w3_weight = f"{self.weight_prefix}.{i_experts}.{self.w3_weight_name}.weight"
+
+                if w1_weight in weights:
+                    self.experts_gate_projs[i_experts] = weights[w1_weight][
+                        self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), :
+                    ]
+                if w3_weight in weights:
+                    self.experts_up_projs[i_experts] = weights[w3_weight][
+                        self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), :
+                    ]
+
+                if w2_weight in weights:
+                    self.w2_list[i_experts] = weights[w2_weight][
+                        :, self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1)
+                    ]
+
+            self._fuse()
 
     def _cuda(self, cpu_tensor):
         if self.tp_rank_ is None:
@@ -124,4 +191,7 @@ def _cuda(self, cpu_tensor):
         return cpu_tensor.contiguous().to(self.data_type_).cuda(self.tp_rank_)
 
     def verify_load(self):
-        return self.w1 is not None and self.w2 is not None
+        if os.environ.get("ETP_MODE_ENABLED") == "true":
+            return True
+        else:
+            return self.w1 is not None and self.w2 is not None
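
Note on the loading scheme above: instead of slicing every expert's intermediate dimension across tensor-parallel ranks (the split_inter_size path kept in the else branch), _load_hf_weights_etp assigns each rank a contiguous block of whole experts via i_experts = i_experts_ep + n_expert_ep * self.tp_rank_. A minimal, self-contained sketch of that index arithmetic; the expert count and world size below are illustrative assumptions, not values from this commit:

    # Sketch of the expert-to-rank mapping used by _load_hf_weights_etp.
    # n_routed_experts and world_size are made-up example values.
    n_routed_experts = 64
    world_size = 8
    assert n_routed_experts % world_size == 0
    n_expert_ep = n_routed_experts // world_size  # experts owned by each rank

    for tp_rank in range(world_size):
        owned = [i_experts_ep + n_expert_ep * tp_rank for i_experts_ep in range(n_expert_ep)]
        print(f"rank {tp_rank} holds experts {owned[0]}..{owned[-1]}")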

lightllm/common/deepseek2_mem_manager.py

Lines changed: 7 additions & 1 deletion
@@ -1,4 +1,5 @@
 import torch
+import os
 
 from .mem_manager import MemoryManager
 from typing import List
@@ -10,7 +11,12 @@ def get_cell_size(self):
 
     def _init_buffers(self, size, dtype, head_num, head_dim, layer_num):
         self.kv_buffer = torch.empty((layer_num, size, head_num, head_dim), dtype=dtype, device="cuda")
-
+        # todo: etp and edp should share the same work buffer here;
+        # it can also serve as a scratch buffer for any kernel that does not need its contents preserved
+        if os.environ.get("ETP_MODE_ENABLED") == "true":
+            self.work_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.bfloat16, device="cuda")
+            self.work_buffer.share_memory_()
+
     def alloc_kv_move_buffer(self, max_req_total_len):
         self.kv_move_buffer = torch.empty(
             (1, max_req_total_len + 8, self.head_num, self.head_dim), dtype=self.dtype, device="cuda"
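
Note on the buffer size above: the ETP work buffer is 1024 * 1024 * 1024 bfloat16 elements, which at 2 bytes per element reserves 2 GiB of GPU memory per process. A quick check of that arithmetic:

    # Size of the ETP work buffer allocated in _init_buffers:
    # 2**30 bfloat16 elements * 2 bytes each = 2 GiB.
    num_elements = 1024 * 1024 * 1024
    bytes_per_bf16 = 2
    print(num_elements * bytes_per_bf16 / 2**30, "GiB")  # -> 2.0 GiB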

lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py

Lines changed: 67 additions & 2 deletions
@@ -18,7 +18,7 @@
 from lightllm.models.llama.infer_struct import LlamaInferStateInfo
 from functools import partial
 from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale
-
+import os
 
 class Deepseek2TransformerLayerInfer(LlamaTransformerLayerInfer):
     def __init__(
@@ -30,6 +30,9 @@ def __init__(
         self.qk_rope_head_dim = network_config["qk_rope_head_dim"]
         self.q_lora_rank = network_config["q_lora_rank"]
         self.kv_lora_rank = network_config["kv_lora_rank"]
+
+        self.n_routed_experts = network_config["n_routed_experts"]
+
         self.is_moe = (
             network_config["n_routed_experts"] is not None
             and layer_num >= network_config["first_k_dense_replace"]
@@ -64,7 +67,10 @@ def _bind_attention(self):
         )
         self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
         if self.is_moe:
-            self._ffn = partial(Deepseek2TransformerLayerInfer._moe_ffn, self)
+            if os.environ.get("ETP_MODE_ENABLED") == "true":
+                self._ffn = partial(Deepseek2TransformerLayerInfer._moe_ffn_etp, self)
+            else:
+                self._ffn = partial(Deepseek2TransformerLayerInfer._moe_ffn, self)
         else:
             self._ffn = partial(LlamaTransformerLayerInfer._ffn, self)
 
@@ -196,6 +202,7 @@ def _copy_kv_to_mem_cache_normal(self, buffer, mem_index, mem_manager):
     def _moe_ffn(
         self, input, infer_state: LlamaInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight
    ) -> torch.Tensor:
+
         hidden_states = input.view(-1, self.embed_dim_)
         num_tokens, hidden_dim = hidden_states.shape
 
@@ -219,3 +226,61 @@ def _moe_ffn(
             hidden_states.add_(shared_output)
 
         return hidden_states.view(num_tokens, hidden_dim)
+
+    def _moe_ffn_etp(
+        self, input, infer_state: LlamaInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight
+    ) -> torch.Tensor:
+        world_size_ = self.world_size_
+        num_local_experts = self.n_shared_experts // world_size_
+        local_expert_offset = self.tp_rank_ * num_local_experts
+        num_experts_per_token = self.num_experts_per_tok
+        num_experts = self.n_routed_experts
+        num_expert_groups = self.n_group
+        num_groups_per_token = self.topk_group
+        gating_scaling_factor = self.routed_scaling_factor
+        gating_normalize_prob = self.norm_topk_prob
+        rank_self = self.tp_rank_
+
+        hidden_states = input.view(-1, self.embed_dim_)
+        num_tokens, hidden_dim = hidden_states.shape
+
+        final_hidden_states = torch.empty(
+            num_tokens, hidden_dim, device=hidden_states.device, dtype=hidden_states.dtype
+        )
+
+        # router_logits_len = hidden_states.shape[0] * layer_weight.moe_gate.shape[1]
+        router_logits = layer_weight.moe_gate.mm(hidden_states)
+
+        # some parameters are not supported yet:
+        # assert gating_normalize_prob is False
+        # assert num_expert_groups <= 1
+
+        import lightllm_moe_etp_kernel
+
+        lightllm_moe_etp_kernel.moe_fused_all(
+            router_logits.contiguous(),
+            hidden_states.contiguous(),
+            layer_weight.gate_up_proj.weight.contiguous(),  # transpose
+            layer_weight.down_proj.weight.contiguous(),  # transpose
+            layer_weight.experts.expert_gate_up_proj_etp.contiguous(),
+            layer_weight.experts.expert_down_proj_etp.contiguous(),
+            infer_state.mem_manager.work_buffer.contiguous(),
+            infer_state.mem_manager.work_buffer.nelement(),
+            final_hidden_states.contiguous(),
+            rank_self,
+            gating_scaling_factor,
+            num_experts,
+            num_experts_per_token,
+            num_tokens,
+            world_size_,
+            True,
+            hidden_dim,
+            layer_weight.gate_up_proj.weight.size(1) // 2,
+            layer_weight.experts.expert_gate_up_proj_etp.size(1) // 2,
+            self.n_shared_experts is not None,
+        )
+
+        router_logits = None
+
+        return final_hidden_states.view(num_tokens, hidden_dim)
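
Note on enabling the new path: all three files key off the same ETP_MODE_ENABLED environment variable, which is read during weight loading, memory-manager buffer init, and _bind_attention. A minimal sketch of toggling it; the surrounding launch code is assumed, not part of this commit:

    import os

    # Must be set before the model is constructed, since the flag is read at
    # load/init time; any value other than "true" keeps the default TP path.
    os.environ["ETP_MODE_ENABLED"] = "true"
    # ...then start the lightllm server / build the DeepSeek-V2 model as usual...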
