Skip to content

Commit 2053351

Browse files
authored
deepseek overflow fix (#349)
1 parent a264693 commit 2053351

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

vllm/model_executor/models/deepseek_v2.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -147,11 +147,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
147147
shared_output = self.shared_experts(hidden_states)
148148
# router_logits: (num_tokens, n_experts)
149149
router_logits, _ = self.gate(hidden_states)
150-
final_hidden_states = self.experts(
151-
hidden_states=hidden_states,
152-
router_logits=router_logits) * self.routed_scaling_factor
150+
final_hidden_states = self.experts(hidden_states=hidden_states,
151+
router_logits=router_logits)
153152
if shared_output is not None:
154-
final_hidden_states = final_hidden_states + shared_output
153+
final_hidden_states = final_hidden_states + shared_output \
154+
* (1. / self.routed_scaling_factor)
155155
if self.tp_size > 1:
156156
final_hidden_states = tensor_model_parallel_all_reduce(
157157
final_hidden_states)
@@ -375,6 +375,7 @@ def __init__(
375375
eps=config.rms_norm_eps)
376376
self.post_attention_layernorm = RMSNorm(config.hidden_size,
377377
eps=config.rms_norm_eps)
378+
self.routed_scaling_factor = config.routed_scaling_factor
378379

379380
def forward(
380381
self,
@@ -399,9 +400,14 @@ def forward(
399400
)
400401

401402
# Fully Connected
403+
if isinstance(self.mlp, DeepseekV2MoE):
404+
hidden_states *= 1. / self.mlp.routed_scaling_factor
402405
hidden_states, residual = self.post_attention_layernorm(
403406
hidden_states, residual)
404407
hidden_states = self.mlp(hidden_states)
408+
if isinstance(self.mlp, DeepseekV2MLP):
409+
hidden_states *= 1. / self.routed_scaling_factor
410+
residual *= 1. / self.routed_scaling_factor
405411
return hidden_states, residual
406412

407413

0 commit comments

Comments
 (0)