
Conversation


@dragondream-chen dragondream-chen commented Sep 2, 2025

Purpose

As part of the EPLB (Expert Load Balancing) feature, this PR optimizes how the expert load is updated on each forward pass. The current approach uses scatter_add_ on the topk_ids results. When DeepEP Low-Latency or PPLX is used on the CUDA platform, the expert loads can be read directly from expert_tokens_meta.expert_num_tokens, which removes redundant computation of the expert load.
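
For illustration, a minimal sketch of the two update paths. Tensor names (expert_load_view, topk_ids, expert_num_tokens) follow the PR description, but the function boundaries and the local_start argument are assumptions for this sketch, not the actual vLLM interfaces.

import torch

def update_expert_load_scatter(expert_load_view: torch.Tensor,
                               topk_ids: torch.Tensor) -> None:
    # Current path: one scatter_add_ over batch_size * topk entries.
    # expert_load_view: [num_physical_experts], topk_ids: [batch_size, topk]
    flat_ids = topk_ids.flatten().long()
    expert_load_view.scatter_add_(
        0, flat_ids, torch.ones_like(flat_ids, dtype=expert_load_view.dtype))

def update_expert_load_from_meta(expert_load_view: torch.Tensor,
                                 expert_num_tokens: torch.Tensor,
                                 local_start: int) -> None:
    # Proposed path: the dispatch kernel (DeepEP Low-Latency / PPLX) already
    # counted tokens per local expert, so add that vector into the local slice.
    local_end = local_start + expert_num_tokens.numel()
    expert_load_view[local_start:local_end] += expert_num_tokens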

Test Plan

  1. Test expert load update
    Since kernels such as DeepEP Low-Latency or PPLX change the inference path, intermediate values cannot be aligned exactly. We therefore verify that the expert load update works correctly by comparing the per-layer load imbalance.
    We add the following code to vllm/distributed/eplb/eplb_state.py for data collection.
if global_expert_load is None:
    physical_expert_load_window = self.expert_load_window.clone()
    global_physical_load_window = physical_expert_load_window.sum(dim=0)
    all_reduce(global_physical_load_window, group=ep_group)
if is_main_rank:
    global_num_experts = 96
    ep_size = ep_group.size()
    all_rank_node = []
    for ep_r in range(ep_size):
        base_experts = global_num_experts // ep_size
        remainder = global_num_experts % ep_size
        if ep_r < remainder:
            local_num_experts = base_experts + 1
        else:
            local_num_experts = base_experts
        # Build a 0/1 mask over the global physical experts,
        # marking the experts local to rank ep_r
        expert_map = torch.zeros(global_num_experts, device=global_physical_load_window.device, dtype=global_physical_load_window.dtype)
        start_idx = ep_r * base_experts + min(ep_r, remainder)
        expert_map[start_idx:start_idx + local_num_experts] = 1

        # [layers, phy_num] * [phy_num,] -> per-layer load handled by rank ep_r
        local_load = (global_physical_load_window * expert_map).sum(1).unsqueeze(1)
        all_rank_node.append(local_load)
    all_ranks = torch.cat(all_rank_node, dim=1).float()  # [layers, ep_size]: [26, 8]
    max_ranks, _ = torch.max(all_ranks, dim=1)  # [layers]
    mean_ranks = torch.mean(all_ranks, dim=1)  # [layers]
    imbalance = torch.div(max_ranks, mean_ranks)  # [layers]
    logger.debug(f" | imbalance : {imbalance.cpu().tolist()}")
    for i in range(len(imbalance)):
        logger.debug(f" | imbalance layer {i} : {imbalance[i]} ")
  2. Test performance
    Measure the average time for a single update of the expert load.

Test Result

[figure: per-layer load imbalance curves] The blue curve shows the state before the modification and the red curve the state after it. The graph shows that the degree of imbalance is essentially unchanged.

The average time to update expert_load_view for one layer per batch:
Before modification: 0.058 ms
After modification: 0.015 ms

Co-author

Co-authored-by: Skywalker-EP [email protected]


mergify bot commented Sep 2, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @dragondream-chen.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Sep 2, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a performance optimization for the Expert Parallelism Load Balancer (EPLB) by avoiding scatter_add_ for expert load calculation when a more direct method is available. The core idea is sound and should improve performance. However, the implementation introduces code duplication and some brittle logic that could be improved for better long-term maintainability. My review focuses on refactoring these areas to make the code more robust and easier to maintain.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a performance optimization for EPLB by allowing expert load updates to bypass the scatter_add_ operation when using specific modular kernels. The changes are logical and well-contained. However, I've identified a critical bug in the implementation that would prevent this optimization from ever being activated. Additionally, there are a couple of high-severity maintainability issues related to code duplication and a local import that should be addressed.


github-actions bot commented Sep 2, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, covering a small, essential subset of CI tests to quickly catch errors.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@dragondream-chen
Author

Hi @simon-mo and @khluu,

I just submitted my first PR to vLLM. I don't have permission to unblock additional CI tests on Buildkite (only fastcheck runs by default). Could you help add me to the vLLM Buildkite org so I can trigger full CI?

Thanks for your help!


mergify bot commented Sep 4, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @dragondream-chen.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

Member

abmfy commented Sep 16, 2025

Since #22167, we have been collecting the global physical expert loads, so the overhead of masking out non-local experts has already been removed.

In addition, please see #24573, where the unnecessary local mask was removed. After that change, the only remaining operation for expert load collection is a scatter_add_, which is essentially the same as what this PR proposes.

@robertgshaw2-redhat robertgshaw2-redhat changed the title [Perf] EPLB optimize export_load_view update [EPLB]: Optimize export_load_view update Sep 16, 2025
@dragondream-chen
Author

Since #22167, we have been collecting the global physical expert loads, so the overhead of masking out non-local experts has already been removed.

In addition, please see #24573, where the unnecessary local mask was removed. After that change, the only remaining operation for expert load collection is a scatter_add_, which is essentially the same as what this PR proposes.

PR #24573 does not conflict with our optimization of the expert_load_view update. This PR aligns with SGLang's approach and addresses the TODO for expert_load_view.
We directly update expert_load_view using the expert_num_tokens obtained from FusedMoEPrepareAndFinalize. This avoids extra processing of topk_ids, and the dimension handled by scatter_add_ shrinks from batch_size x expert_num to physical_expert_num.
Further discussion on this would be greatly appreciated!


mergify bot commented Sep 18, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @dragondream-chen.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Sep 18, 2025
Member

abmfy commented Oct 6, 2025

Sorry for the delay.

I see your point. You're right that previously we needed to perform scatter_add_ for batch_size * topk times, whereas now it's just a simple tensor addition over num_local_experts, which should indeed be faster.

To proceed, I think we should do the following:

  1. Make it more general. This PR currently hardcodes DeepEPHTPrepareAndFinalize to use this simplified path, but in fact DeepEPLLPrepareAndFinalize, BatchedPrepareAndFinalize, and PplxPrepareAndFinalize all provide expert token count metadata. We can add a method in the ABC mk.FusedMoEPrepareAndFinalize that indicates whether a prepare-and-finalize kernel returns expert token count metadata. By default, it can return False, and we can override it to return True in the aforementioned classes.

  2. Simplify token count addition. The token count in expert_tokens_metadata maps to local experts, so we can perform a direct tensor addition instead of scatter_add_. Specifically, use integer indexing like:

    expert_load_view[physical_expert_start:physical_expert_end] += expert_tokens_metadata.expert_num_tokens
    • The multiple abstraction layers (FusedMoEQuantMethod → FusedMoEModularKernel) make it tricky to pass information directly. If you can find a better way to pass start/end indexes directly, that would be preferable to inferring them from expert_map every time.
    • Avoid using torch.arange for indexing here, as that effectively becomes a scatter_add_ again and would be slower than simple indexing.
  3. Benchmarking. Provide benchmark results. Make sure to lower eplb_config.step_interval and test on large datasets like ShareGPT (not just a few prompts) to properly trigger EPLB, e.g., run for 100 steps.

  4. Accuracy test. Test GSM8k before and after this PR to verify correctness.

Thanks again for your contribution!
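
A rough sketch of how suggestions 1 and 2 above could fit together; the method name supports_expert_tokens_meta and the standalone helper are placeholders for this sketch, not the actual vLLM interfaces.

from abc import ABC

import torch

class FusedMoEPrepareAndFinalize(ABC):
    # Suggestion 1: a hook indicating whether this kernel returns
    # expert token count metadata; defaults to False.
    def supports_expert_tokens_meta(self) -> bool:
        return False

class PplxPrepareAndFinalize(FusedMoEPrepareAndFinalize):
    # Kernels that return expert token counts override the hook.
    def supports_expert_tokens_meta(self) -> bool:
        return True

def add_local_expert_load(expert_load_view: torch.Tensor,
                          expert_num_tokens: torch.Tensor,
                          physical_expert_start: int,
                          physical_expert_end: int) -> None:
    # Suggestion 2: the metadata maps to local experts, so a slice
    # addition replaces the per-token scatter_add_.
    expert_load_view[physical_expert_start:physical_expert_end] += (
        expert_num_tokens)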

@dragondream-chen
Author

Thanks for your thoughts and suggestions! Next, we'll refactor this section: we'll add a method to the ABC mk.FusedMoEPrepareAndFinalize for the different kernels, compare the expert_load_view update approach with scatter_add_, and ensure accuracy and performance through benchmarking and accuracy tests.


mergify bot commented Oct 12, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @dragondream-chen.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the kv-connector label Oct 17, 2025
@mergify mergify bot removed the tpu Related to Google TPUs label Oct 17, 2025
Signed-off-by: daishixun <[email protected]>
Signed-off-by: daishixun <[email protected]>
Member

hmellor commented Oct 17, 2025

In the future, please take more care with git and force pushes. Almost all maintainers are now subscribed to this PR because of it.

Also, please:

  • Fix pre-commit
    uv pip install pre-commit
    pre-commit install
    pre-commit run -a
  • Merge from main to fix the doc issue

Mercykid-bash and others added 2 commits October 24, 2025 15:29
Signed-off-by: Mercykid-bash <[email protected]>
@Mercykid-bash

For the token count addition simplification, we've replaced scatter_add_ with direct tensor addition. Here, physical_expert_start and physical_expert_end are now derived using the rank ID and local expert count, avoiding inference from expert_map or torch.arange usage.

Initial tests have been completed to verify correctness. We'll follow up with the Benchmarking and Accuracy test experiments shortly.
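
For reference, a minimal sketch of that derivation, assuming physical experts are laid out contiguously per rank (the helper name is illustrative):

def local_expert_range(ep_rank: int,
                       num_local_experts: int) -> tuple[int, int]:
    # Contiguous per-rank layout: rank r owns experts
    # [r * num_local_experts, (r + 1) * num_local_experts).
    physical_expert_start = ep_rank * num_local_experts
    physical_expert_end = physical_expert_start + num_local_experts
    return physical_expert_start, physical_expert_end

# Usage with the direct addition:
# start, end = local_expert_range(ep_rank, num_local_experts)
# expert_load_view[start:end] += expert_tokens_meta.expert_num_tokens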


dsxsteven commented Oct 31, 2025

Sorry for the delay.

I see your point. You're right that previously we needed to perform scatter_add_ for batch_size * topk times, whereas now it's just a simple tensor addition over num_local_experts, which should indeed be faster.

To proceed, I think we should do the following:

  1. Make it more general. This PR currently hardcodes DeepEPHTPrepareAndFinalize to use this simplified path, but in fact DeepEPLLPrepareAndFinalize, BatchedPrepareAndFinalize, and PplxPrepareAndFinalize all provide expert token count metadata. We can add a method in the ABC mk.FusedMoEPrepareAndFinalize that indicates whether a prepare-and-finalize kernel returns expert token count metadata. By default, it can return False, and we can override it to return True in the aforementioned classes.

  2. Simplify token count addition. The token count in expert_tokens_metadata maps to local experts, so we can perform a direct tensor addition instead of scatter_add_. Specifically, use integer indexing like:

    expert_load_view[physical_expert_start:physical_expert_end] += expert_tokens_metadata.expert_num_tokens
    • The multiple abstraction layers (FusedMoEQuantMethod → FusedMoEModularKernel) make it tricky to pass information directly. If you can find a better way to pass start/end indexes directly, that would be preferable to inferring them from expert_map every time.
    • Avoid using torch.arange for indexing here, as that effectively becomes a scatter_add_ again and would be slower than simple indexing.
  3. Benchmarking. Provide benchmark results. Make sure to lower eplb_config.step_interval and test on large datasets like ShareGPT (not just a few prompts) to properly trigger EPLB, e.g., run for 100 steps.

  4. Accuracy test. Test GSM8k before and after this PR to verify correctness.

Thanks again for your contribution!

Hi, I used the DeepSeek-V2-Lite model and the GSM8K dataset for the accuracy test. However, after enabling PPLX there were accuracy issues: both before and after this PR, the accuracy was only around 2%. Below are the results from my AISbench test.

Accuracy before PR: [screenshot of AISbench results]

Accuracy after PR: [screenshot of AISbench results]

@dsxsteven
Copy link

dsxsteven commented Nov 4, 2025

I used the GSM8K dataset for performance testing and lowered "window_size": 1000, "step_interval": 300 to trigger EPLB.
As you can see, with this pull request applied, both TTFT and TPOT decreased, resulting in a performance gain.

Performance before PR: [benchmark screenshot]

Performance after PR: [benchmark screenshot]

@Skywalker-EP

@tlrmchlsmth There has been no reply on this PR for a long time. Could you please take a look when you have time?

Member

abmfy commented Nov 11, 2025

Hi, @Skywalker-EP — could you please share the current status? Does the accuracy issue still persist?


dsxsteven commented Nov 11, 2025

Hi, @Skywalker-EP — could you please share the current status? Does the accuracy issue still persist?

Hi, the accuracy issue does still persist.

I pulled the latest code (without our PR), enabled the DeepEP backend, and tested the accuracy of the Qwen3 model. However, the accuracy is still problematic (it is 96% without DeepEP), and curl returns only garbled characters.

