diff --git a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py
index 2e01bc6e4..c1935ad89 100644
--- a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py
@@ -13,6 +13,7 @@
 from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd
 from functools import partial
 from lightllm.utils.log_utils import init_logger
+from lightllm.utils.dist_utils import get_global_world_size
 
 logger = init_logger(__name__)
 
@@ -27,6 +28,7 @@ def __init__(self, layer_num, network_config, mode=[]):
         )
         self.num_experts_per_tok = network_config["num_experts_per_tok"]
         self.norm_topk_prob = network_config["norm_topk_prob"]
+        self.n_shared_experts = network_config.get("n_shared_experts", None)
         super().__init__(layer_num, network_config, mode)
         self.head_dim_ = network_config["head_dim"]
         self.tp_k_head_num_ = max(self.tp_k_head_num_, 1)
@@ -120,3 +122,274 @@ def _moe_ffn_edp(
 
         ep_output = ep_output.view(token_num, hidden_dim)
         return ep_output
+
+    def overlap_tpsp_token_forward(
+        self,
+        input_embdings: torch.Tensor,
+        input_embdings1: torch.Tensor,
+        infer_state: LlamaInferStateInfo,
+        infer_state1: LlamaInferStateInfo,
+        layer_weight: Qwen3MOETransformerLayerWeight,
+    ):
+        if not self.is_moe:
+            return super().overlap_tpsp_token_forward(
+                input_embdings, input_embdings1, infer_state, infer_state1, layer_weight
+            )
+        # 0 attention
+        _0_input1 = self._att_norm(input_embdings, infer_state, layer_weight)
+        _0_cache_kv = self._pre_cache_kv(infer_state, layer_weight)
+        _0_q, _0_cache_kv = self._tpsp_get_qkv(_0_input1, _0_cache_kv, infer_state, layer_weight)
+        _0_input1 = None
+        self._post_cache_kv(_0_cache_kv, infer_state, layer_weight)
+        _0_o = self._token_attention_kernel(_0_q, infer_state, layer_weight)
+        _0_q = None
+        _0_o = self._tpsp_get_o(_0_o, infer_state, layer_weight)
+        input_embdings.add_(_0_o.view(-1, self.embed_dim_))
+        _0_o = None
+        _0_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
+        _0_router_logits = layer_weight.moe_gate.mm(_0_input1)
+        # 1 hook
+        if getattr(infer_state1, "hook", None) is not None:
+            infer_state1.hook()
+            infer_state1.hook = None
+
+        # 0 shared expert
+        if self.n_shared_experts is not None:
+            _0_shared_output = LlamaTransformerLayerInfer._ffn(self, _0_input1, infer_state, layer_weight)
+
+        # 0 dispatch
+        (
+            _0_recv_x,
+            _0_masked_m,
+            _0_topk_idx,
+            _0_topk_weight,
+            _0_handle,
+            _0_hook,
+        ) = layer_weight.experts.low_latency_dispatch(_0_input1, _0_router_logits)
+        infer_state.hook = _0_hook
+
+        # 1 attention
+        _1_input1 = self._att_norm(input_embdings1, infer_state1, layer_weight)
+        _1_cache_kv = self._pre_cache_kv(infer_state1, layer_weight)
+        _1_q, _1_cache_kv = self._tpsp_get_qkv(_1_input1, _1_cache_kv, infer_state1, layer_weight)
+        _1_input1 = None
+        self._post_cache_kv(_1_cache_kv, infer_state1, layer_weight)
+        _1_o = self._token_attention_kernel(_1_q, infer_state1, layer_weight)
+        _1_q = None
+        _1_o = self._tpsp_get_o(_1_o, infer_state1, layer_weight)
+        input_embdings1.add_(_1_o.view(-1, self.embed_dim_))
+        _1_o = None
+        _1_input1 = self._ffn_norm(input_embdings1, infer_state1, layer_weight)
+        # 1 gate and dispatch
+
+        _1_router_logits = layer_weight.moe_gate.mm(_1_input1)
+        # 0 hook
+        if getattr(infer_state, "hook", None) is not None:
+            infer_state.hook()
+            infer_state.hook = None
+
+        # 1 shared expert
+        if self.n_shared_experts is not None:
+            _1_shared_output = LlamaTransformerLayerInfer._ffn(self, _1_input1, infer_state1, layer_weight)
+
+        # 1 dispatch
+        (
+            _1_recv_x,
+            _1_masked_m,
+            _1_topk_idx,
+            _1_topk_weight,
+            _1_handle,
+            _1_hook,
+        ) = layer_weight.experts.low_latency_dispatch(_1_input1, _1_router_logits)
+        infer_state1.hook = _1_hook
+
+        # 0 moe calc
+        expected_m = triton.cdiv(
+            input_embdings.shape[0] * get_global_world_size() * self.num_experts_per_tok, self.n_routed_experts
+        )
+        _0_moe_out = layer_weight.experts.masked_group_gemm(_0_recv_x, _0_masked_m, input_embdings.dtype, expected_m)
+
+        # 1 hook
+        if getattr(infer_state1, "hook", None) is not None:
+            infer_state1.hook()
+            infer_state1.hook = None
+
+        # 0 combine
+        _0_ffn_out, _0_hook = layer_weight.experts.low_latency_combine(
+            _0_moe_out, _0_topk_idx, _0_topk_weight, _0_handle
+        )
+
+        infer_state.hook = _0_hook
+
+        # 1 moe calc
+        _1_moe_out = layer_weight.experts.masked_group_gemm(_1_recv_x, _1_masked_m, input_embdings1.dtype, expected_m)
+
+        # 0 hook
+        if getattr(infer_state, "hook", None) is not None:
+            infer_state.hook()
+            # _0_ffn_out *= self.routed_scaling_factor
+            if self.n_shared_experts is not None:
+                _0_ffn_out.add_(_0_shared_output)
+            input_embdings.add_(_0_ffn_out.view(-1, self.embed_dim_))
+            infer_state.hook = None
+
+        # 1 combine
+        _1_ffn_out, _1_hook = layer_weight.experts.low_latency_combine(
+            _1_moe_out, _1_topk_idx, _1_topk_weight, _1_handle
+        )
+
+        def _1_hook_post():
+            _1_hook()
+            nonlocal _1_ffn_out
+            # _1_ffn_out *= self.routed_scaling_factor
+            if self.n_shared_experts is not None:
+                _1_ffn_out.add_(_1_shared_output)
+            input_embdings1.add_(_1_ffn_out.view(-1, self.embed_dim_))
+            return
+
+        infer_state1.hook = _1_hook_post
+
+        return input_embdings, input_embdings1
+
+    def overlap_tpsp_context_forward(
+        self,
+        input_embdings: torch.Tensor,
+        input_embdings1: torch.Tensor,
+        infer_state: LlamaInferStateInfo,
+        infer_state1: LlamaInferStateInfo,
+        layer_weight: Qwen3MOETransformerLayerWeight,
+    ):
+        if not self.is_moe:
+            return super().overlap_tpsp_context_forward(
+                input_embdings, input_embdings1, infer_state, infer_state1, layer_weight
+            )
+        # 0 attention
+        _0_input1 = self._att_norm(input_embdings, infer_state, layer_weight)
+        _0_cache_kv = self._pre_cache_kv(infer_state, layer_weight)
+        _0_q, _0_cache_kv = self._tpsp_get_qkv(_0_input1, _0_cache_kv, infer_state, layer_weight)
+        _0_input1 = None
+        self._post_cache_kv(_0_cache_kv, infer_state, layer_weight)
+        _0_o = self._context_attention_kernel(_0_q, _0_cache_kv, infer_state, layer_weight)
+        _0_q = None
+        _0_o = self._tpsp_get_o(_0_o, infer_state, layer_weight)
+        input_embdings.add_(_0_o.view(-1, self.embed_dim_))
+        _0_o = None
+        _0_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
+        _0_router_logits = layer_weight.moe_gate.mm(_0_input1)
+
+        # wait last 1 combine
+        if getattr(infer_state1, "hook", None) is not None:
+            infer_state1.hook()
+            infer_state1.hook = None
+
+        _0_topk_weight, _0_topk_idx, _0_qinput_tensor = layer_weight.experts.select_experts_and_quant_input(
+            _0_input1, _0_router_logits
+        )
+        from deep_ep import Buffer
+
+        _0_overlap_event = Buffer.capture()
+
+        # 1 attention
+        _1_input1 = self._att_norm(input_embdings1, infer_state1, layer_weight)
+        _1_cache_kv = self._pre_cache_kv(infer_state1, layer_weight)
+        _1_q, _1_cache_kv = self._tpsp_get_qkv(_1_input1, _1_cache_kv, infer_state1, layer_weight)
+        _1_input1 = None
+        self._post_cache_kv(_1_cache_kv, infer_state1, layer_weight)
+        _1_o = self._context_attention_kernel(_1_q, _1_cache_kv, infer_state1, layer_weight)
+        _1_q = None
+        _1_o = self._tpsp_get_o(_1_o, infer_state1, layer_weight)
+        input_embdings1.add_(_1_o.view(-1, self.embed_dim_))
+        _1_o = None
+        _1_input1 = self._ffn_norm(input_embdings1, infer_state1, layer_weight)
+        # 1 gate and dispatch
+
+        _1_router_logits = layer_weight.moe_gate.mm(_1_input1)
+
+        # 0 dispatch execute
+        (
+            _0_recv_x,
+            _0_recv_topk_idx,
+            _0_recv_topk_weight,
+            _0_num_recv_tokens_per_expert_list,
+            _0_handle,
+            _0_hook,
+        ) = layer_weight.experts.dispatch(_0_qinput_tensor, _0_topk_idx, _0_topk_weight, overlap_event=_0_overlap_event)
+        infer_state.hook = _0_hook
+
+        # wait 0 dispatch
+        if getattr(infer_state, "hook", None) is not None:
+            infer_state.hook()
+            infer_state.hook = None
+
+        _1_topk_weight, _1_topk_idx, _1_qinput_tensor = layer_weight.experts.select_experts_and_quant_input(
+            _1_input1, _1_router_logits
+        )
+
+        _1_overlap_event = Buffer.capture()
+
+        # 0 shared expert
+        if self.n_shared_experts is not None:
+            _0_shared_output = LlamaTransformerLayerInfer._ffn(self, _0_input1, infer_state, layer_weight)
+
+        # 1 shared expert
+        if self.n_shared_experts is not None:
+            _1_shared_output = LlamaTransformerLayerInfer._ffn(self, _1_input1, infer_state1, layer_weight)
+
+        # 0 moe calc
+        _0_moe_out = layer_weight.experts.prefilled_group_gemm(
+            _0_num_recv_tokens_per_expert_list, _0_recv_x, _0_recv_topk_idx, _0_recv_topk_weight
+        )
+
+        # 1 dispatch execute
+        (
+            _1_recv_x,
+            _1_recv_topk_idx,
+            _1_recv_topk_weight,
+            _1_num_recv_tokens_per_expert_list,
+            _1_handle,
+            _1_hook,
+        ) = layer_weight.experts.dispatch(_1_qinput_tensor, _1_topk_idx, _1_topk_weight, overlap_event=_1_overlap_event)
+        infer_state1.hook = _1_hook
+
+        # wait 1 dispatch
+        if getattr(infer_state1, "hook", None) is not None:
+            infer_state1.hook()
+            infer_state1.hook = None
+
+        _0_combine_event = Buffer.capture()
+        # 0 combine execute
+        _0_ffn_out, _0_hook = layer_weight.experts.combine(_0_moe_out, _0_handle, _0_combine_event)
+        infer_state.hook = _0_hook
+
+        # 1 moe calc
+        _1_moe_out = layer_weight.experts.prefilled_group_gemm(
+            _1_num_recv_tokens_per_expert_list, _1_recv_x, _1_recv_topk_idx, _1_recv_topk_weight
+        )
+
+        # wait 0 combine
+        if getattr(infer_state, "hook", None) is not None:
+            infer_state.hook()
+            infer_state.hook = None
+
+        _1_combine_event = Buffer.capture()
+
+        # _0_ffn_out *= self.routed_scaling_factor
+        if self.n_shared_experts is not None:
+            _0_ffn_out.add_(_0_shared_output)
+        input_embdings.add_(_0_ffn_out.view(-1, self.embed_dim_))
+
+        # 1 combine execute
+        _1_ffn_out, _1_hook = layer_weight.experts.combine(_1_moe_out, _1_handle, _1_combine_event)
+
+        def _1_hook_post():
+            _1_hook()
+            nonlocal _1_ffn_out
+            # _1_ffn_out *= self.routed_scaling_factor
+            if self.n_shared_experts is not None:
+                _1_ffn_out.add_(_1_shared_output)
+            input_embdings1.add_(_1_ffn_out.view(-1, self.embed_dim_))
+            return
+
+        infer_state1.hook = _1_hook_post
+
+        return input_embdings, input_embdings1
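
Both methods above follow the same hook-based two-micro-batch schedule: while micro-batch 0's all-to-all (dispatch or combine) is in flight, micro-batch 1's attention, gating, or expert GEMM runs on the GPU, and each side calls the other's hook before touching the received tokens. Below is a minimal, self-contained sketch of that schedule; the names (fake_async, overlapped_moe_step) and the toy arithmetic are hypothetical stand-ins, not LightLLM or DeepEP APIs.

from typing import Callable, Tuple


def fake_async(x: float) -> Tuple[float, Callable[[], None]]:
    # Stand-in for an async collective (dispatch/combine): returns the
    # "received" data plus a hook the caller invokes later to wait on it.
    return x, lambda: None


def overlapped_moe_step(x0: float, x1: float) -> Tuple[float, float]:
    h0 = x0 * 2.0                      # micro-batch 0: attention + gate (compute)
    recv0, hook0 = fake_async(h0)      # micro-batch 0: dispatch (comm, async)

    h1 = x1 * 2.0                      # micro-batch 1 compute overlaps 0's dispatch
    hook0()                            # wait until micro-batch 0 tokens have arrived
    recv1, hook1 = fake_async(h1)      # micro-batch 1: dispatch (comm, async)

    out0 = recv0 + 1.0                 # micro-batch 0 expert GEMM overlaps 1's dispatch
    hook1()
    out0, hook0 = fake_async(out0)     # micro-batch 0: combine (comm, async)

    out1 = recv1 + 1.0                 # micro-batch 1 expert GEMM overlaps 0's combine
    hook0()
    out1, hook1 = fake_async(out1)     # micro-batch 1: combine
    hook1()
    return x0 + out0, x1 + out1        # residual adds, as in the real layer


print(overlapped_moe_step(1.0, 3.0))   # -> (4.0, 10.0)
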
must be even number" group_size = 2 init_distributed_env(model_kvargs) dist_group_manager.create_groups(group_size=group_size)