Add Balance Loss to MoE Example for Enhanced Expert Load Distribution (Issue #1300) #1311
@skydoorkai @adamantboy @hxdtest
LGTM
Codecov Report
All modified and coverable lines are covered by tests ✅
Additional details and impacted files
@@           Coverage Diff           @@
##           master    #1311   +/-   ##
=======================================
  Coverage   80.51%   80.52%
=======================================
  Files         222      222
  Lines       20698    20707    +9
=======================================
+ Hits        16666    16674    +8
- Misses       4032     4033    +1
@@ -162,7 +182,8 @@ def forward(self, hidden_states):
         if self.shared_experts is not None and not self.use_expert_parallelism:
             hidden_states = hidden_states + self.shared_experts(identify)

-        return hidden_states
+        # Return the auxiliary loss along with the hidden states
+        return hidden_states, aux_loss
This example is based on the transformers Llama model.
The Llama MLP definition returns only hidden_states, so returning a tuple won't work.
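One possible way around this, sketched below, is to keep the MLP's return type unchanged, stash the auxiliary loss on the module, and collect it after the forward pass. This is only an illustration: `ToySparseMLP`, `last_aux_loss`, and `collect_aux_loss` are hypothetical names that do not exist in atorch or transformers, and the auxiliary term here is a stand-in rather than a real balance loss.

```python
import torch
from torch import nn


class ToySparseMLP(nn.Module):
    """Hypothetical stand-in for an MoE MLP that must return a single tensor."""

    def __init__(self, hidden_size: int, num_experts: int):
        super().__init__()
        self.router = nn.Linear(hidden_size, num_experts, bias=False)
        self.proj = nn.Linear(hidden_size, hidden_size)
        self.last_aux_loss = None  # side channel, refreshed on every forward

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        router_probs = torch.softmax(self.router(hidden_states), dim=-1)
        # Stand-in auxiliary term; a real balance loss would be computed here.
        self.last_aux_loss = (router_probs ** 2).sum(dim=-1).mean()
        # Same return type as a dense Llama-style MLP: hidden states only.
        return self.proj(hidden_states)


def collect_aux_loss(model: nn.Module) -> torch.Tensor:
    """Sum the aux losses stashed by every ToySparseMLP inside the model."""
    losses = [
        m.last_aux_loss
        for m in model.modules()
        if isinstance(m, ToySparseMLP) and m.last_aux_loss is not None
    ]
    if not losses:
        return torch.zeros(())
    return torch.stack(losses).sum()
```

With this shape, the training loop would call `collect_aux_loss(model)` after the forward pass and add the weighted result to the main loss, while the surrounding decoder layer keeps receiving a plain tensor.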
        return router_probs, router_logits, topk_experts_index, aux_loss

    def _compute_auxiliary_loss(self, router_probs):
Need to implement a real aux loss that works in this example, not a placeholder.
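For context, if `router_probs` is a softmax over experts (an assumption about this router), its overall mean is always `1 / num_experts`, so a mean-based placeholder gives no balancing signal. A commonly used alternative is a Switch-Transformer-style load-balancing loss. The sketch below is a minimal version under that softmax assumption; the function name `load_balancing_loss` and the tensor shapes are illustrative, not the PR's actual code.

```python
import torch
import torch.nn.functional as F


def load_balancing_loss(
    router_probs: torch.Tensor, topk_experts_index: torch.Tensor
) -> torch.Tensor:
    """Switch-Transformer-style balance loss (illustrative sketch).

    router_probs:       (num_tokens, num_experts), softmax over experts (assumed).
    topk_experts_index: (num_tokens, top_k), int64 ids of the selected experts.
    """
    num_experts = router_probs.shape[-1]
    # One-hot dispatch mask with shape (num_tokens, top_k, num_experts).
    dispatch = F.one_hot(topk_experts_index, num_classes=num_experts).float()
    # f_i: fraction of routing assignments sent to expert i (no gradient).
    tokens_per_expert = dispatch.mean(dim=(0, 1))
    # P_i: mean router probability for expert i (carries the gradient).
    prob_per_expert = router_probs.mean(dim=0)
    # 1.0 for perfectly uniform routing; larger when tokens pile onto few experts.
    return num_experts * torch.sum(tokens_per_expert * prob_per_expert)
```

The gradient flows only through the mean probabilities, since the dispatch fractions come from a discrete top-k selection; the product of the two terms is what pushes probability mass away from overloaded experts.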
Add Balance Loss to MoE Example for Enhanced Expert Load Distribution (Issue #1300)
What changes were proposed in this pull request?
This pull request adds a balance loss mechanism to the Mixture-of-Experts (MoE) example in the atorch codebase. Specifically, an auxiliary loss has been added to the TopNRouter class to balance load across experts, with the aim of improving model performance and efficiency. Key modifications include:
Router Updates:
- TopNRouter class: computes and returns an auxiliary loss based on router probabilities, helping distribute tokens more evenly.
- _compute_auxiliary_loss() method: calculates this auxiliary loss. It currently uses the mean of router probabilities as a placeholder, which can be customized based on specific balancing requirements.
MoE Layer Enhancements:
- _SparseMLP class: incorporates the auxiliary loss from the router and propagates it back, supporting load balancing in the MoE layer.
Training Loop Modifications:
Auxiliary Loss Weight Configurability:
- --aux_loss_weight: a command-line argument that lets users adjust the weight of the auxiliary loss, so the loss function can be tuned to model requirements.
Why are the changes needed?
These changes address issue #1300 by introducing an auxiliary balance loss mechanism to the MoE example, which aims to improve the distribution of workload across experts. In multi-expert architectures, load imbalances can lead to inefficiencies and underutilized resources, ultimately impacting model performance. The proposed auxiliary loss provides a straightforward way to mitigate these imbalances, enhancing both training efficiency and overall model effectiveness.
Does this PR introduce any user-facing change?
Yes. This PR introduces a new command-line argument, --aux_loss_weight, which sets the weight of the auxiliary loss. It defaults to 0.01 and can be configured according to specific model or training needs.
How was this patch tested?
The patch was tested through the following steps:
- Checked that the auxiliary loss is computed in the TopNRouter class and is correctly incorporated into the MoE layer.
- Checked that the --aux_loss_weight argument correctly adjusts the auxiliary loss weight in different runs.
These changes are expected to contribute to improved expert load balancing, benefiting users who require scalable and efficient MoE models.
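As a rough illustration of the --aux_loss_weight argument described above, the flag could be wired up along these lines. The flag name and its 0.01 default come from the PR description; `build_parser` and `total_loss` are hypothetical helpers, and adding the weighted auxiliary loss to the main loss is an assumption about how the training loop combines them.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="MoE example (illustrative)")
    parser.add_argument(
        "--aux_loss_weight",
        type=float,
        default=0.01,  # default stated in the PR description
        help="Weight applied to the auxiliary balance loss.",
    )
    return parser


def total_loss(lm_loss: float, aux_loss: float, aux_loss_weight: float) -> float:
    """Hypothetical combination: main loss plus the weighted auxiliary loss."""
    return lm_loss + aux_loss_weight * aux_loss


# Parse an explicit argument list so the sketch runs without a real command line.
args = build_parser().parse_args(["--aux_loss_weight", "0.05"])
```

Setting the weight to 0 would disable the auxiliary term under this scheme, which makes it easy to compare runs with and without balancing.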