support qwen3moe overlap mode #974
base: main
@@ -13,6 +13,7 @@
from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd
from functools import partial
from lightllm.utils.log_utils import init_logger
from lightllm.utils.dist_utils import get_global_world_size

logger = init_logger(__name__)

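The get_global_world_size import added here is used further down, in overlap_tpsp_token_forward, to size the masked grouped GEMM: expected_m = ceil(tokens_in_microbatch * world_size * num_experts_per_tok / n_routed_experts), roughly the number of tokens each routed expert should expect. A small worked example with assumed values (the numbers below are illustrative, not taken from the PR):

import math

tokens_in_microbatch = 64   # input_embdings.shape[0] for one micro-batch, assumed
world_size = 8              # value of get_global_world_size(), assumed
num_experts_per_tok = 8     # top-k routing fan-out, assumed
n_routed_experts = 128      # total number of routed experts, assumed

# triton.cdiv(a, b) is ceiling division; math.ceil mirrors it here
expected_m = math.ceil(tokens_in_microbatch * world_size * num_experts_per_tok / n_routed_experts)
print(expected_m)  # 32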
@@ -27,6 +28,7 @@ def __init__(self, layer_num, network_config, mode=[]):
        )
        self.num_experts_per_tok = network_config["num_experts_per_tok"]
        self.norm_topk_prob = network_config["norm_topk_prob"]
        self.n_shared_experts = network_config.get("n_shared_experts", None)
        super().__init__(layer_num, network_config, mode)
        self.head_dim_ = network_config["head_dim"]
        self.tp_k_head_num_ = max(self.tp_k_head_num_, 1)
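Because n_shared_experts is read with a .get(..., None) default, the shared-expert branches in the overlap methods below only run when the model config actually defines that key. A minimal sketch of the gating (the config values are made up; only the key lookup mirrors the diff):

network_config = {"num_experts_per_tok": 8, "norm_topk_prob": True}  # hypothetical values

n_shared_experts = network_config.get("n_shared_experts", None)
if n_shared_experts is not None:
    print("run the shared-expert FFN alongside the routed experts")
else:
    print("no shared experts configured; the shared-expert branches are skipped")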
@@ -120,3 +122,274 @@ def _moe_ffn_edp(

        ep_output = ep_output.view(token_num, hidden_dim)
        return ep_output

    def overlap_tpsp_token_forward(
        self,
        input_embdings: torch.Tensor,
        input_embdings1: torch.Tensor,
        infer_state: LlamaInferStateInfo,
        infer_state1: LlamaInferStateInfo,
        layer_weight: Qwen3MOETransformerLayerWeight,
    ):
        if not self.is_moe:
            return super().overlap_tpsp_token_forward(
                input_embdings, input_embdings1, infer_state, infer_state1, layer_weight
            )
        # 0 attention
        _0_input1 = self._att_norm(input_embdings, infer_state, layer_weight)
        _0_cache_kv = self._pre_cache_kv(infer_state, layer_weight)
        _0_q, _0_cache_kv = self._tpsp_get_qkv(_0_input1, _0_cache_kv, infer_state, layer_weight)
        _0_input1 = None
        self._post_cache_kv(_0_cache_kv, infer_state, layer_weight)
        _0_o = self._token_attention_kernel(_0_q, infer_state, layer_weight)
        _0_q = None
        _0_o = self._tpsp_get_o(_0_o, infer_state, layer_weight)
        input_embdings.add_(_0_o.view(-1, self.embed_dim_))
        _0_o = None
        _0_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
        _0_router_logits = layer_weight.moe_gate.mm(_0_input1)
        # 1 hook
        if getattr(infer_state1, "hook", None) is not None:
            infer_state1.hook()
            infer_state1.hook = None

        # 0 shared expert
        if self.n_shared_experts is not None:
            _0_shared_output = LlamaTransformerLayerInfer._ffn(self, _0_input1, infer_state, layer_weight)

        # 0 dispatch
        (
            _0_recv_x,
            _0_masked_m,
            _0_topk_idx,
            _0_topk_weight,
            _0_handle,
            _0_hook,
        ) = layer_weight.experts.low_latency_dispatch(_0_input1, _0_router_logits)
        infer_state.hook = _0_hook

        # 1 attention
        _1_input1 = self._att_norm(input_embdings1, infer_state1, layer_weight)
        _1_cache_kv = self._pre_cache_kv(infer_state1, layer_weight)
        _1_q, _1_cache_kv = self._tpsp_get_qkv(_1_input1, _1_cache_kv, infer_state1, layer_weight)
        _1_input1 = None
        self._post_cache_kv(_1_cache_kv, infer_state1, layer_weight)
        _1_o = self._token_attention_kernel(_1_q, infer_state1, layer_weight)
        _1_q = None
        _1_o = self._tpsp_get_o(_1_o, infer_state1, layer_weight)
        input_embdings1.add_(_1_o.view(-1, self.embed_dim_))
        _1_o = None
        _1_input1 = self._ffn_norm(input_embdings1, infer_state1, layer_weight)
        # to do gate and disptatch

        _1_router_logits = layer_weight.moe_gate.mm(_1_input1)
        # 0 hook
        if getattr(infer_state, "hook", None) is not None:
            infer_state.hook()
            infer_state.hook = None

        # 1 shared expert
        if self.n_shared_experts is not None:
            _1_shared_output = LlamaTransformerLayerInfer._ffn(self, _1_input1, infer_state1, layer_weight)

        # 1 dispatch
        (
            _1_recv_x,
            _1_masked_m,
            _1_topk_idx,
            _1_topk_weight,
            _1_handle,
            _1_hook,
        ) = layer_weight.experts.low_latency_dispatch(_1_input1, _1_router_logits)
        infer_state1.hook = _1_hook

        # moe calu
        expected_m = triton.cdiv(
            input_embdings.shape[0] * get_global_world_size() * self.num_experts_per_tok, self.n_routed_experts
        )
        _0_moe_out = layer_weight.experts.masked_group_gemm(_0_recv_x, _0_masked_m, input_embdings.dtype, expected_m)

        # 1 hook
        if getattr(infer_state1, "hook", None) is not None:
            infer_state1.hook()
            infer_state1.hook = None

        # 0 combine
        _0_ffn_out, _0_hook = layer_weight.experts.low_latency_combine(
            _0_moe_out, _0_topk_idx, _0_topk_weight, _0_handle
        )

        infer_state.hook = _0_hook

        # to do moe caclue
        _1_moe_out = layer_weight.experts.masked_group_gemm(_1_recv_x, _1_masked_m, input_embdings1.dtype, expected_m)

        # 0 hook
        if getattr(infer_state, "hook", None) is not None:
            infer_state.hook()
            # _0_ffn_out *= self.routed_scaling_factor
            if self.n_shared_experts is not None:
                _0_ffn_out.add_(_0_shared_output)
            input_embdings.add_(_0_ffn_out.view(-1, self.embed_dim_))
            infer_state.hook = None

        # 1 combine
        _1_ffn_out, _1_hook = layer_weight.experts.low_latency_combine(
            _1_moe_out, _1_topk_idx, _1_topk_weight, _1_handle
        )

        def _1_hook_post():
            _1_hook()
            nonlocal _1_ffn_out
            # _1_ffn_out *= self.routed_scaling_factor
            if self.n_shared_experts is not None:
                _1_ffn_out.add_(_1_shared_output)
            input_embdings1.add_(_1_ffn_out.view(-1, self.embed_dim_))
            return

        infer_state1.hook = _1_hook_post

        return input_embdings, input_embdings1

    def overlap_tpsp_context_forward(
        self,
        input_embdings: torch.Tensor,
        input_embdings1: torch.Tensor,
        infer_state: LlamaInferStateInfo,
        infer_state1: LlamaInferStateInfo,
        layer_weight: Qwen3MOETransformerLayerWeight,
    ):
        if not self.is_moe:
            return super().overlap_tpsp_context_forward(
                input_embdings, input_embdings1, infer_state, infer_state1, layer_weight
            )
        # 0 attention
        _0_input1 = self._att_norm(input_embdings, infer_state, layer_weight)
        _0_cache_kv = self._pre_cache_kv(infer_state, layer_weight)
        _0_q, _0_cache_kv = self._tpsp_get_qkv(_0_input1, _0_cache_kv, infer_state, layer_weight)
        _0_input1 = None
        self._post_cache_kv(_0_cache_kv, infer_state, layer_weight)
        _0_o = self._context_attention_kernel(_0_q, _0_cache_kv, infer_state, layer_weight)
        _0_q = None
        _0_o = self._tpsp_get_o(_0_o, infer_state, layer_weight)
        input_embdings.add_(_0_o.view(-1, self.embed_dim_))
        _0_o = None
        _0_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
        _0_router_logits = layer_weight.moe_gate.mm(_0_input1)

        # wait last 1 combine
        if getattr(infer_state1, "hook", None) is not None:
            infer_state1.hook()
            infer_state1.hook = None

        _0_topk_weight, _0_topk_idx, _0_qinput_tensor = layer_weight.experts.select_experts_and_quant_input(
            _0_input1, _0_router_logits
        )
        from deep_ep import Buffer

        _0_overlap_event = Buffer.capture()

        # 1 attention
        _1_input1 = self._att_norm(input_embdings1, infer_state1, layer_weight)
        _1_cache_kv = self._pre_cache_kv(infer_state1, layer_weight)
        _1_q, _1_cache_kv = self._tpsp_get_qkv(_1_input1, _1_cache_kv, infer_state1, layer_weight)
        _1_input1 = None
        self._post_cache_kv(_1_cache_kv, infer_state1, layer_weight)
        _1_o = self._context_attention_kernel(_1_q, _1_cache_kv, infer_state1, layer_weight)
        _1_q = None
        _1_o = self._tpsp_get_o(_1_o, infer_state1, layer_weight)
        input_embdings1.add_(_1_o.view(-1, self.embed_dim_))
        _1_o = None
        _1_input1 = self._ffn_norm(input_embdings1, infer_state1, layer_weight)
        # to do gate and disptatch

        _1_router_logits = layer_weight.moe_gate.mm(_1_input1)

        # 0 dispatch execute
        (
            _0_recv_x,
            _0_recv_topk_idx,
            _0_recv_topk_weight,
            _0_num_recv_tokens_per_expert_list,
            _0_handle,
            _0_hook,
        ) = layer_weight.experts.dispatch(_0_qinput_tensor, _0_topk_idx, _0_topk_weight, overlap_event=_0_overlap_event)
        infer_state.hook = _0_hook

        # wait 0 dispatch
        if getattr(infer_state, "hook", None) is not None:
            infer_state.hook()
            infer_state.hook = None

        _1_topk_weight, _1_topk_idx, _1_qinput_tensor = layer_weight.experts.select_experts_and_quant_input(
            _1_input1, _1_router_logits
        )

        _1_overlap_event = Buffer.capture()

        # 0 shared expert
        if self.n_shared_experts is not None:
            _0_shared_output = LlamaTransformerLayerInfer._ffn(self, _0_input1, infer_state, layer_weight)

        # 1 shared expert
        if self.n_shared_experts is not None:
            _1_shared_output = LlamaTransformerLayerInfer._ffn(self, _1_input1, infer_state1, layer_weight)

        # 0 moe calu
        _0_moe_out = layer_weight.experts.prefilled_group_gemm(
            _0_num_recv_tokens_per_expert_list, _0_recv_x, _0_recv_topk_idx, _0_recv_topk_weight
        )

        # 1 dispatch execute
        (
            _1_recv_x,
            _1_recv_topk_idx,
            _1_recv_topk_weight,
            _1_num_recv_tokens_per_expert_list,
            _1_handle,
            _1_hook,
        ) = layer_weight.experts.dispatch(_1_qinput_tensor, _1_topk_idx, _1_topk_weight, overlap_event=_1_overlap_event)
        infer_state1.hook = _1_hook

        # wait 1 dispatch
        if getattr(infer_state1, "hook", None) is not None:
            infer_state1.hook()
            infer_state1.hook = None

        _0_combine_event = Buffer.capture()
        # 0 combine execute
        _0_ffn_out, _0_hook = layer_weight.experts.combine(_0_moe_out, _0_handle, _0_combine_event)
        infer_state.hook = _0_hook

        # 1 moe calc
        _1_moe_out = layer_weight.experts.prefilled_group_gemm(
            _1_num_recv_tokens_per_expert_list, _1_recv_x, _1_recv_topk_idx, _1_recv_topk_weight
        )

        # wait 0 combine
        if getattr(infer_state, "hook", None) is not None:
            infer_state.hook()
            infer_state.hook = None

        _1_combine_event = Buffer.capture()

        # _0_ffn_out *= self.routed_scaling_factor
        if self.n_shared_experts is not None:
            _0_ffn_out.add_(_0_shared_output)
        input_embdings.add_(_0_ffn_out.view(-1, self.embed_dim_))

        # 1 combine execute
        _1_ffn_out, _1_hook = layer_weight.experts.combine(_1_moe_out, _1_handle, _1_combine_event)

        def _1_hook_post():
            _1_hook()
            nonlocal _1_ffn_out
            # _1_ffn_out *= self.routed_scaling_factor
            if self.n_shared_experts is not None:
                _1_ffn_out.add_(_1_shared_output)
            input_embdings1.add_(_1_ffn_out.view(-1, self.embed_dim_))
            return

        infer_state1.hook = _1_hook_post

        return input_embdings, input_embdings1
Comment on lines +126 to +395: The functions overlap_tpsp_token_forward and overlap_tpsp_context_forward duplicate a large amount of logic between them. Consider these steps:
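One way to act on that suggestion (a sketch only: the helper name and signature below are hypothetical, though every call it makes already appears in the diff above) is to factor the per-micro-batch attention-plus-gating sequence, which is repeated four times across the two methods, into a single helper and parameterize the attention kernel:

    def _tpsp_attn_and_gate(self, input_embdings, infer_state, layer_weight, use_context_kernel):
        # Hypothetical helper: norm -> QKV -> attention -> output projection -> residual add,
        # then FFN norm and MoE gating for one micro-batch, mirroring the repeated blocks above.
        att_in = self._att_norm(input_embdings, infer_state, layer_weight)
        cache_kv = self._pre_cache_kv(infer_state, layer_weight)
        q, cache_kv = self._tpsp_get_qkv(att_in, cache_kv, infer_state, layer_weight)
        self._post_cache_kv(cache_kv, infer_state, layer_weight)
        if use_context_kernel:
            o = self._context_attention_kernel(q, cache_kv, infer_state, layer_weight)
        else:
            o = self._token_attention_kernel(q, infer_state, layer_weight)
        o = self._tpsp_get_o(o, infer_state, layer_weight)
        input_embdings.add_(o.view(-1, self.embed_dim_))
        ffn_in = self._ffn_norm(input_embdings, infer_state, layer_weight)
        router_logits = layer_weight.moe_gate.mm(ffn_in)
        return ffn_in, router_logits

Each overlap method could then call this helper once per micro-batch and keep only the dispatch/combine scheduling inline, which is where the token and context paths genuinely differ.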
@@ -369,7 +369,8 @@ def tppart_model_infer(args, model_kvargs, batch_size, input_len, output_len, an
     enable_decode_overlap = args.enable_decode_microbatch_overlap
     group_size = 1
     if enable_decode_overlap or args.enable_prefill_microbatch_overlap:
-        assert batch_size % 2 == 0, "batch size must be even number"
+        for bs in batch_size:
+            assert bs % 2 == 0, "batch size must be even number"
         group_size = 2
     init_distributed_env(model_kvargs)
     dist_group_manager.create_groups(group_size=group_size)
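For context on this test change: when microbatch overlap is enabled the runner uses group_size = 2, that is, each batch is split into two micro-batches, so every batch size in the sweep has to be even. A tiny standalone illustration (the list values are made up, not from the PR):

batch_sizes = [2, 8, 16]  # hypothetical sweep values
for bs in batch_sizes:
    # each batch is split into two equal micro-batches, so it must be even
    assert bs % 2 == 0, "batch size must be even number"
    print(f"batch {bs} -> two micro-batches of {bs // 2}")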
Remove this comment, as it seems to be a leftover from development and contains a typo ('disptatch' should be 'dispatch').