diff --git a/litgpt/model.py b/litgpt/model.py index db6aebe790..ab88128806 100644 --- a/litgpt/model.py +++ b/litgpt/model.py @@ -566,8 +566,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) x = x.view(-1, C) # (B*T, C) router = self.gate(x) # (B*T, n_expert) + router = F.softmax(router, dim=1, dtype=torch.float) probs, indices = torch.topk(router, self.config.n_expert_per_token) # (B*T, n_expert_per_token) - probs = probs.softmax(dim=1, dtype=torch.float).to(dtype=x.dtype) + probs /= probs.sum(dim=1, keepdim=True) + probs = probs.to(dtype=x.dtype) masks = indices.unsqueeze(-1) == torch.arange(self.config.n_expert, device=x.device) masks = masks.permute(2, 0, 1) # (n_expert, B*T, n_expert_per_token) y = torch.zeros_like(x) # (B*T, C)