support qwen3moe overlap mode #974
base: main
Conversation
Summary of Changes
Hello @sufubao, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request significantly enhances the Qwen3-MoE model's inference capabilities by introducing an 'overlap mode'. This mode optimizes performance by executing attention and MoE FFN computations concurrently across different inputs, effectively reducing overall latency. The changes also include necessary configuration updates for shared experts and a corresponding adjustment to the benchmarking suite to properly test this new, more efficient inference pathway.
Highlights
- New Overlap Inference Mode for Qwen3-MoE: Introduced `overlap_tpsp_token_forward` and `overlap_tpsp_context_forward` methods in the Qwen3-MoE transformer layer. These methods enable overlapping computation stages for token and context generation, respectively, aiming to improve inference performance by hiding latency in Mixture-of-Experts (MoE) operations.
- Support for Shared Experts: Added `n_shared_experts` as a configurable parameter to the Qwen3-MoE transformer layer. This allows the new overlap mode to properly handle models that incorporate shared experts within their MoE architecture.
- Benchmark Update for Overlap Mode: Modified the static inference benchmark script to validate that all specified batch sizes are even when overlap modes are enabled. This ensures compatibility with the dual-stream processing nature of the new overlap inference implementation (see the sketch below).
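For context, the overlap path pipelines two micro-batches against each other, which is why only even batch sizes are accepted. A minimal illustration of the split (hypothetical helper, not code from this PR):

```python
# Hypothetical helper showing why overlap mode needs even batch sizes:
# the batch is divided into two equal micro-batches whose attention and
# MoE dispatch/combine stages are interleaved with each other.
def split_into_micro_batches(requests):
    assert len(requests) % 2 == 0, "overlap mode requires an even batch size"
    half = len(requests) // 2
    return requests[:half], requests[half:]
```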
Using Gemini Code Assist
The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.
Invoking Gemini
You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e. a comment on the pull request page) using either `/gemini <command>` or `@gemini-code-assist <command>`. Below is a summary of the supported commands.
| Feature | Command | Description |
|---|---|---|
| Code Review | `/gemini review` | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | `/gemini summary` | Provides a summary of the current pull request in its current state. |
| Comment | `@gemini-code-assist` | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | `/gemini help` | Displays a list of available commands. |
Customization
To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a `.gemini/` folder in the base of the repository. Detailed instructions can be found here.
Limitations & Feedback
Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.
You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.
Footnotes
1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.
Code Review
The pull request introduces an overlap mode for `qwen3_moe` models. The new methods in `transformer_layer_infer.py` need refactoring to reduce complexity and code duplication. Also, the local import of `deep_ep` needs to be addressed, and leftover comments/code should be removed. Finally, the assertion message in `model_infer.py` should be clarified.
_1_input1 = None
self._post_cache_kv(_1_cache_kv, infer_state1, layer_weight)
_1_o = self._context_attention_kernel(_1_q, _1_cache_kv, infer_state1, layer_weight)
_1_q = None
_1_o = self._tpsp_get_o(_1_o, infer_state1, layer_weight)
input_embdings1.add_(_1_o.view(-1, self.embed_dim_))
_1_o = None
_1_input1 = self._ffn_norm(input_embdings1, infer_state1, layer_weight)
# to do gate and dispatch

_1_router_logits = layer_weight.moe_gate.mm(_1_input1)

# 0 dispatch execute
(
    _0_recv_x,
    _0_recv_topk_idx,
    _0_recv_topk_weight,
    _0_num_recv_tokens_per_expert_list,
    _0_handle,
    _0_hook,
) = layer_weight.experts.dispatch(_0_qinput_tensor, _0_topk_idx, _0_topk_weight, overlap_event=_0_overlap_event)
infer_state.hook = _0_hook

# wait 0 dispatch
if getattr(infer_state, "hook", None) is not None:
    infer_state.hook()
    infer_state.hook = None

_1_topk_weight, _1_topk_idx, _1_qinput_tensor = layer_weight.experts.select_experts_and_quant_input(
    _1_input1, _1_router_logits
)

_1_overlap_event = Buffer.capture()

# 0 shared expert
if self.n_shared_experts is not None:
    _0_shared_output = LlamaTransformerLayerInfer._ffn(self, _0_input1, infer_state, layer_weight)

# 1 shared expert
if self.n_shared_experts is not None:
    _1_shared_output = LlamaTransformerLayerInfer._ffn(self, _1_input1, infer_state1, layer_weight)

# 0 moe calc
_0_moe_out = layer_weight.experts.prefilled_group_gemm(
    _0_num_recv_tokens_per_expert_list, _0_recv_x, _0_recv_topk_idx, _0_recv_topk_weight
)

# 1 dispatch execute
(
    _1_recv_x,
    _1_recv_topk_idx,
    _1_recv_topk_weight,
    _1_num_recv_tokens_per_expert_list,
    _1_handle,
    _1_hook,
) = layer_weight.experts.dispatch(_1_qinput_tensor, _1_topk_idx, _1_topk_weight, overlap_event=_1_overlap_event)
infer_state1.hook = _1_hook

# wait 1 dispatch
if getattr(infer_state1, "hook", None) is not None:
    infer_state1.hook()
    infer_state1.hook = None

_0_combine_event = Buffer.capture()
# 0 combine execute
_0_ffn_out, _0_hook = layer_weight.experts.combine(_0_moe_out, _0_handle, _0_combine_event)
infer_state.hook = _0_hook

# 1 moe calc
_1_moe_out = layer_weight.experts.prefilled_group_gemm(
    _1_num_recv_tokens_per_expert_list, _1_recv_x, _1_recv_topk_idx, _1_recv_topk_weight
)

# wait 0 combine
if getattr(infer_state, "hook", None) is not None:
    infer_state.hook()
    infer_state.hook = None

_1_combine_event = Buffer.capture()

# _0_ffn_out *= self.routed_scaling_factor
if self.n_shared_experts is not None:
    _0_ffn_out.add_(_0_shared_output)
input_embdings.add_(_0_ffn_out.view(-1, self.embed_dim_))

# 1 combine execute
_1_ffn_out, _1_hook = layer_weight.experts.combine(_1_moe_out, _1_handle, _1_combine_event)

def _1_hook_post():
    _1_hook()
    nonlocal _1_ffn_out
    # _1_ffn_out *= self.routed_scaling_factor
    if self.n_shared_experts is not None:
        _1_ffn_out.add_(_1_shared_output)
    input_embdings1.add_(_1_ffn_out.view(-1, self.embed_dim_))
    return

infer_state1.hook = _1_hook_post

return input_embdings, input_embdings1
The functions `overlap_tpsp_token_forward` and `overlap_tpsp_context_forward` are lengthy and contain duplicated logic. Refactor to improve readability and maintainability.
Consider these steps:
- Extract common logic into helper methods.
- Create a generic overlap function that accepts attention kernels as parameters.
- Encapsulate hook management in a helper function like `_execute_hook_if_exists(infer_state)` (see the sketch below).
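A minimal sketch of that helper, based on the repeated wait pattern in the snippet above (illustrative only; `_execute_hook_if_exists` is the reviewer's suggested name, not code from this PR):

```python
def _execute_hook_if_exists(infer_state):
    # Run and clear a pending dispatch/combine hook, if one was registered.
    hook = getattr(infer_state, "hook", None)
    if hook is not None:
        hook()
        infer_state.hook = None
```

Each of the `# wait 0 dispatch`, `# wait 1 dispatch`, and `# wait 0 combine` blocks above would then collapse into a single call.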
_0_topk_weight, _0_topk_idx, _0_qinput_tensor = layer_weight.experts.select_experts_and_quant_input(
    _0_input1, _0_router_logits
)
from deep_ep import Buffer
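Per the review summary, the local `from deep_ep import Buffer` above should be addressed. One possible approach is a guarded module-level import (a sketch assuming `deep_ep` is an optional dependency; not the PR's actual fix):

```python
# Module-level import so the per-call forward path does not re-import deep_ep;
# guarded because deep_ep is assumed optional and only needed for overlap mode.
try:
    from deep_ep import Buffer
except ImportError:
    Buffer = None  # overlap mode should then raise a clear error if selected
```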
input_embdings1.add_(_1_o.view(-1, self.embed_dim_))
_1_o = None
_1_input1 = self._ffn_norm(input_embdings1, infer_state1, layer_weight)
# to do gate and dispatch
# 0 hook
if getattr(infer_state, "hook", None) is not None:
    infer_state.hook()
# _0_ffn_out *= self.routed_scaling_factor
for bs in batch_size:
    assert bs % 2 == 0, "batch size must be even number"
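The review summary asks for this assertion message to be clarified; one possible wording (illustrative, not the PR's final text):

```python
for bs in batch_size:
    assert bs % 2 == 0, (
        f"overlap mode processes two micro-batches per step, "
        f"so every batch size must be even; got {bs}"
    )
```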
No description provided.