diff --git a/Transformers/llama/llama3_implementation.py b/Transformers/llama/llama3_implementation.py
index b1a4bef..6dfd30b 100644
--- a/Transformers/llama/llama3_implementation.py
+++ b/Transformers/llama/llama3_implementation.py
@@ -251,16 +251,16 @@ def forward(self, tokens: torch.Tensor, start_pos: int):
         mask = None
         if seq_len > 1:
-            mask = torch.full((seq_len, seq_len), float("-inf"), device=tokens.device)
+            mask_for_kvcache = torch.full((seq_len, seq_len), float("-inf"), device=tokens.device)
 
-            mask = torch.triu(mask, diagonal=1)
+            mask_for_kvcache = torch.triu(mask_for_kvcache, diagonal=1)
 
             # When performing key-value caching, we compute the attention scores
             # only for the new sequence. Thus, the matrix of scores is of size
             # (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
             # j > cache_len + i, since row i corresponds to token cache_len + i.
             mask = torch.hstack(
-                [torch.zeros((seq_len, start_pos), device=tokens.device), mask]
+                [torch.zeros((seq_len, start_pos), device=tokens.device), mask_for_kvcache]
             ).type_as(h)
 
         for layer in self.layers:
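For reference, here is a minimal standalone sketch of the mask this hunk builds, run outside the model so the shape is easy to inspect. The concrete values (`seq_len = 3`, `start_pos = 2`) are illustrative assumptions, not part of the diff, and the final `.type_as(h)` cast is omitted because there is no hidden-state tensor `h` in this snippet.

```python
import torch

# Illustrative sizes (assumed, not from the diff): 3 new tokens, 2 cached tokens.
seq_len, start_pos = 3, 2

# Upper-triangular -inf mask over the new tokens only (causal mask).
mask_for_kvcache = torch.full((seq_len, seq_len), float("-inf"))
mask_for_kvcache = torch.triu(mask_for_kvcache, diagonal=1)

# Prepend zero columns for the cached positions: every new token may attend to
# the full cache, so only entries (i, j) with j > start_pos + i remain masked.
mask = torch.hstack([torch.zeros((seq_len, start_pos)), mask_for_kvcache])

print(mask)
# tensor([[0., 0., 0., -inf, -inf],
#         [0., 0., 0.,   0., -inf],
#         [0., 0., 0.,   0.,   0.]])
```

The printed matrix has shape `(seq_len, start_pos + seq_len)`, matching the comment in the diff: row `i` scores token `cache_len + i` against all cached plus new positions, and only future positions are set to `-inf`.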