@@ -147,11 +147,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
147
147
shared_output = self .shared_experts (hidden_states )
148
148
# router_logits: (num_tokens, n_experts)
149
149
router_logits , _ = self .gate (hidden_states )
150
- final_hidden_states = self .experts (
151
- hidden_states = hidden_states ,
152
- router_logits = router_logits ) * self .routed_scaling_factor
150
+ final_hidden_states = self .experts (hidden_states = hidden_states ,
151
+ router_logits = router_logits )
153
152
if shared_output is not None :
154
- final_hidden_states = final_hidden_states + shared_output
153
+ final_hidden_states = final_hidden_states + shared_output \
154
+ * (1. / self .routed_scaling_factor )
155
155
if self .tp_size > 1 :
156
156
final_hidden_states = tensor_model_parallel_all_reduce (
157
157
final_hidden_states )
@@ -375,6 +375,7 @@ def __init__(
375
375
eps = config .rms_norm_eps )
376
376
self .post_attention_layernorm = RMSNorm (config .hidden_size ,
377
377
eps = config .rms_norm_eps )
378
+ self .routed_scaling_factor = config .routed_scaling_factor
378
379
379
380
def forward (
380
381
self ,
@@ -399,9 +400,14 @@ def forward(
399
400
)
400
401
401
402
# Fully Connected
403
+ if isinstance (self .mlp , DeepseekV2MoE ):
404
+ hidden_states *= 1. / self .mlp .routed_scaling_factor
402
405
hidden_states , residual = self .post_attention_layernorm (
403
406
hidden_states , residual )
404
407
hidden_states = self .mlp (hidden_states )
408
+ if isinstance (self .mlp , DeepseekV2MLP ):
409
+ hidden_states *= 1. / self .routed_scaling_factor
410
+ residual *= 1. / self .routed_scaling_factor
405
411
return hidden_states , residual
406
412
407
413
0 commit comments