support qwen3moe overlap mode #974

Open · wants to merge 1 commit into main
273 changes: 273 additions & 0 deletions lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py
@@ -13,6 +13,7 @@
from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd
from functools import partial
from lightllm.utils.log_utils import init_logger
from lightllm.utils.dist_utils import get_global_world_size

logger = init_logger(__name__)

@@ -27,6 +28,7 @@ def __init__(self, layer_num, network_config, mode=[]):
)
self.num_experts_per_tok = network_config["num_experts_per_tok"]
self.norm_topk_prob = network_config["norm_topk_prob"]
self.n_shared_experts = network_config.get("n_shared_experts", None)
super().__init__(layer_num, network_config, mode)
self.head_dim_ = network_config["head_dim"]
self.tp_k_head_num_ = max(self.tp_k_head_num_, 1)
@@ -120,3 +122,274 @@ def _moe_ffn_edp(

ep_output = ep_output.view(token_num, hidden_dim)
return ep_output

def overlap_tpsp_token_forward(
self,
input_embdings: torch.Tensor,
input_embdings1: torch.Tensor,
infer_state: LlamaInferStateInfo,
infer_state1: LlamaInferStateInfo,
layer_weight: Qwen3MOETransformerLayerWeight,
):
if not self.is_moe:
return super().overlap_tpsp_token_forward(
input_embdings, input_embdings1, infer_state, infer_state1, layer_weight
)
# 0 attention
_0_input1 = self._att_norm(input_embdings, infer_state, layer_weight)
_0_cache_kv = self._pre_cache_kv(infer_state, layer_weight)
_0_q, _0_cache_kv = self._tpsp_get_qkv(_0_input1, _0_cache_kv, infer_state, layer_weight)
_0_input1 = None
self._post_cache_kv(_0_cache_kv, infer_state, layer_weight)
_0_o = self._token_attention_kernel(_0_q, infer_state, layer_weight)
_0_q = None
_0_o = self._tpsp_get_o(_0_o, infer_state, layer_weight)
input_embdings.add_(_0_o.view(-1, self.embed_dim_))
_0_o = None
_0_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
_0_router_logits = layer_weight.moe_gate.mm(_0_input1)
# 1 hook
if getattr(infer_state1, "hook", None) is not None:
infer_state1.hook()
infer_state1.hook = None

# 0 shared expert
if self.n_shared_experts is not None:
_0_shared_output = LlamaTransformerLayerInfer._ffn(self, _0_input1, infer_state, layer_weight)

# 0 dispatch
(
_0_recv_x,
_0_masked_m,
_0_topk_idx,
_0_topk_weight,
_0_handle,
_0_hook,
) = layer_weight.experts.low_latency_dispatch(_0_input1, _0_router_logits)
infer_state.hook = _0_hook

# 1 attention
_1_input1 = self._att_norm(input_embdings1, infer_state1, layer_weight)
_1_cache_kv = self._pre_cache_kv(infer_state1, layer_weight)
_1_q, _1_cache_kv = self._tpsp_get_qkv(_1_input1, _1_cache_kv, infer_state1, layer_weight)
_1_input1 = None
self._post_cache_kv(_1_cache_kv, infer_state1, layer_weight)
_1_o = self._token_attention_kernel(_1_q, infer_state1, layer_weight)
_1_q = None
_1_o = self._tpsp_get_o(_1_o, infer_state1, layer_weight)
input_embdings1.add_(_1_o.view(-1, self.embed_dim_))
_1_o = None
_1_input1 = self._ffn_norm(input_embdings1, infer_state1, layer_weight)
# to do gate and disptatch

medium

Remove this comment, as it seems to be a leftover from development and contains a typo ('disptatch' should be 'dispatch').


_1_router_logits = layer_weight.moe_gate.mm(_1_input1)
# 0 hook
if getattr(infer_state, "hook", None) is not None:
infer_state.hook()
infer_state.hook = None

# 1 shared expert
if self.n_shared_experts is not None:
_1_shared_output = LlamaTransformerLayerInfer._ffn(self, _1_input1, infer_state1, layer_weight)

# 1 dispatch
(
_1_recv_x,
_1_masked_m,
_1_topk_idx,
_1_topk_weight,
_1_handle,
_1_hook,
) = layer_weight.experts.low_latency_dispatch(_1_input1, _1_router_logits)
infer_state1.hook = _1_hook

# 0 moe calc
expected_m = triton.cdiv(
input_embdings.shape[0] * get_global_world_size() * self.num_experts_per_tok, self.n_routed_experts
)
_0_moe_out = layer_weight.experts.masked_group_gemm(_0_recv_x, _0_masked_m, input_embdings.dtype, expected_m)

# 1 hook
if getattr(infer_state1, "hook", None) is not None:
infer_state1.hook()
infer_state1.hook = None

# 0 combine
_0_ffn_out, _0_hook = layer_weight.experts.low_latency_combine(
_0_moe_out, _0_topk_idx, _0_topk_weight, _0_handle
)

infer_state.hook = _0_hook

# 1 moe calc
_1_moe_out = layer_weight.experts.masked_group_gemm(_1_recv_x, _1_masked_m, input_embdings1.dtype, expected_m)

# 0 hook
if getattr(infer_state, "hook", None) is not None:
infer_state.hook()
# _0_ffn_out *= self.routed_scaling_factor

medium

This line is commented out. Either uncomment it if routed_scaling_factor should be applied, or remove the line if it's obsolete.

if self.n_shared_experts is not None:
_0_ffn_out.add_(_0_shared_output)
input_embdings.add_(_0_ffn_out.view(-1, self.embed_dim_))
infer_state.hook = None

# 1 combine
_1_ffn_out, _1_hook = layer_weight.experts.low_latency_combine(
_1_moe_out, _1_topk_idx, _1_topk_weight, _1_handle
)

def _1_hook_post():
_1_hook()
nonlocal _1_ffn_out
# _1_ffn_out *= self.routed_scaling_factor
if self.n_shared_experts is not None:
_1_ffn_out.add_(_1_shared_output)
input_embdings1.add_(_1_ffn_out.view(-1, self.embed_dim_))
return

infer_state1.hook = _1_hook_post

return input_embdings, input_embdings1

def overlap_tpsp_context_forward(
self,
input_embdings: torch.Tensor,
input_embdings1: torch.Tensor,
infer_state: LlamaInferStateInfo,
infer_state1: LlamaInferStateInfo,
layer_weight: Qwen3MOETransformerLayerWeight,
):
if not self.is_moe:
return super().overlap_tpsp_context_forward(
input_embdings, input_embdings1, infer_state, infer_state1, layer_weight
)
# 0 attention
_0_input1 = self._att_norm(input_embdings, infer_state, layer_weight)
_0_cache_kv = self._pre_cache_kv(infer_state, layer_weight)
_0_q, _0_cache_kv = self._tpsp_get_qkv(_0_input1, _0_cache_kv, infer_state, layer_weight)
_0_input1 = None
self._post_cache_kv(_0_cache_kv, infer_state, layer_weight)
_0_o = self._context_attention_kernel(_0_q, _0_cache_kv, infer_state, layer_weight)
_0_q = None
_0_o = self._tpsp_get_o(_0_o, infer_state, layer_weight)
input_embdings.add_(_0_o.view(-1, self.embed_dim_))
_0_o = None
_0_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
_0_router_logits = layer_weight.moe_gate.mm(_0_input1)

# wait last 1 combine
if getattr(infer_state1, "hook", None) is not None:
infer_state1.hook()
infer_state1.hook = None

_0_topk_weight, _0_topk_idx, _0_qinput_tensor = layer_weight.experts.select_experts_and_quant_input(
_0_input1, _0_router_logits
)
from deep_ep import Buffer

high

Move this local import of deep_ep to the top of the file. If deep_ep is optional, wrap the import in a try...except ImportError block.
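A minimal sketch of the guarded module-level import suggested here, assuming deep_ep is an optional dependency (the HAS_DEEP_EP flag is illustrative and not part of this PR):

try:
    from deep_ep import Buffer
    HAS_DEEP_EP = True
except ImportError:
    # deep_ep is only required for the expert-parallel overlap path.
    Buffer = None
    HAS_DEEP_EP = False

Call sites that need Buffer could then check HAS_DEEP_EP and raise a clear error instead of hitting an ImportError in the middle of the forward pass.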


_0_overlap_event = Buffer.capture()

# 1 attention
_1_input1 = self._att_norm(input_embdings1, infer_state1, layer_weight)
_1_cache_kv = self._pre_cache_kv(infer_state1, layer_weight)
_1_q, _1_cache_kv = self._tpsp_get_qkv(_1_input1, _1_cache_kv, infer_state1, layer_weight)
_1_input1 = None
self._post_cache_kv(_1_cache_kv, infer_state1, layer_weight)
_1_o = self._context_attention_kernel(_1_q, _1_cache_kv, infer_state1, layer_weight)
_1_q = None
_1_o = self._tpsp_get_o(_1_o, infer_state1, layer_weight)
input_embdings1.add_(_1_o.view(-1, self.embed_dim_))
_1_o = None
_1_input1 = self._ffn_norm(input_embdings1, infer_state1, layer_weight)
# to do gate and dispatch

_1_router_logits = layer_weight.moe_gate.mm(_1_input1)

# 0 dispatch execute
(
_0_recv_x,
_0_recv_topk_idx,
_0_recv_topk_weight,
_0_num_recv_tokens_per_expert_list,
_0_handle,
_0_hook,
) = layer_weight.experts.dispatch(_0_qinput_tensor, _0_topk_idx, _0_topk_weight, overlap_event=_0_overlap_event)
infer_state.hook = _0_hook

# wait 0 dispatch
if getattr(infer_state, "hook", None) is not None:
infer_state.hook()
infer_state.hook = None

_1_topk_weight, _1_topk_idx, _1_qinput_tensor = layer_weight.experts.select_experts_and_quant_input(
_1_input1, _1_router_logits
)

_1_overlap_event = Buffer.capture()

# 0 shared expert
if self.n_shared_experts is not None:
_0_shared_output = LlamaTransformerLayerInfer._ffn(self, _0_input1, infer_state, layer_weight)

# 1 shared expert
if self.n_shared_experts is not None:
_1_shared_output = LlamaTransformerLayerInfer._ffn(self, _1_input1, infer_state1, layer_weight)

# 0 moe calc
_0_moe_out = layer_weight.experts.prefilled_group_gemm(
_0_num_recv_tokens_per_expert_list, _0_recv_x, _0_recv_topk_idx, _0_recv_topk_weight
)

# 1 dispatch execute
(
_1_recv_x,
_1_recv_topk_idx,
_1_recv_topk_weight,
_1_num_recv_tokens_per_expert_list,
_1_handle,
_1_hook,
) = layer_weight.experts.dispatch(_1_qinput_tensor, _1_topk_idx, _1_topk_weight, overlap_event=_1_overlap_event)
infer_state1.hook = _1_hook

# wait 1 dispatch
if getattr(infer_state1, "hook", None) is not None:
infer_state1.hook()
infer_state1.hook = None

_0_combine_event = Buffer.capture()
# 0 combine execute
_0_ffn_out, _0_hook = layer_weight.experts.combine(_0_moe_out, _0_handle, _0_combine_event)
infer_state.hook = _0_hook

# 1 moe calc
_1_moe_out = layer_weight.experts.prefilled_group_gemm(
_1_num_recv_tokens_per_expert_list, _1_recv_x, _1_recv_topk_idx, _1_recv_topk_weight
)

# wait 0 combine
if getattr(infer_state, "hook", None) is not None:
infer_state.hook()
infer_state.hook = None

_1_combine_event = Buffer.capture()

# _0_ffn_out *= self.routed_scaling_factor
if self.n_shared_experts is not None:
_0_ffn_out.add_(_0_shared_output)
input_embdings.add_(_0_ffn_out.view(-1, self.embed_dim_))

# 1 combine execute
_1_ffn_out, _1_hook = layer_weight.experts.combine(_1_moe_out, _1_handle, _1_combine_event)

def _1_hook_post():
_1_hook()
nonlocal _1_ffn_out
# _1_ffn_out *= self.routed_scaling_factor
if self.n_shared_experts is not None:
_1_ffn_out.add_(_1_shared_output)
input_embdings1.add_(_1_ffn_out.view(-1, self.embed_dim_))
return

infer_state1.hook = _1_hook_post

return input_embdings, input_embdings1
Comment on lines +126 to +395

high

The functions overlap_tpsp_token_forward and overlap_tpsp_context_forward are lengthy and contain duplicated logic. Refactor to improve readability and maintainability.

Consider these steps:

  • Extract common logic into helper methods.
  • Create a generic overlap function that accepts attention kernels as parameters.
  • Encapsulate hook management in a helper function like _execute_hook_if_exists(infer_state); see the sketch below.
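
A minimal sketch of the hook helper described in the last bullet, assuming it lives on the same layer-infer class (the name _execute_hook_if_exists is the reviewer's suggestion, not code in this PR):

def _execute_hook_if_exists(self, infer_state):
    # Run and clear a pending communication hook left behind by a previous
    # dispatch/combine call, if one is attached to this micro-batch state.
    hook = getattr(infer_state, "hook", None)
    if hook is not None:
        hook()
        infer_state.hook = None

Each repeated "if getattr(infer_state, "hook", None) is not None:" block in both forward functions would then collapse to a single self._execute_hook_if_exists(infer_state) call.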

3 changes: 2 additions & 1 deletion test/benchmark/static_inference/model_infer.py
@@ -369,7 +369,8 @@ def tppart_model_infer(args, model_kvargs, batch_size, input_len, output_len, an
enable_decode_overlap = args.enable_decode_microbatch_overlap
group_size = 1
if enable_decode_overlap or args.enable_prefill_microbatch_overlap:
assert batch_size % 2 == 0, "batch size must be even number"
for bs in batch_size:
assert bs % 2 == 0, "batch size must be even number"
Comment on lines +372 to +373

medium

The assertion message is unclear. Clarify it to indicate which batch size is not an even number, for example:

assert bs % 2 == 0, f"Batch size {bs} must be an even number"

group_size = 2
init_distributed_env(model_kvargs)
dist_group_manager.create_groups(group_size=group_size)