refactor: Rename the internal variable for clarity
YeonwooSung committed Aug 4, 2024
1 parent 91bffac commit f56e88a
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions Transformers/llama/llama3_implementation.py
```diff
@@ -251,16 +251,16 @@ def forward(self, tokens: torch.Tensor, start_pos: int):
 
         mask = None
         if seq_len > 1:
-            mask = torch.full((seq_len, seq_len), float("-inf"), device=tokens.device)
+            mask_for_kvcache = torch.full((seq_len, seq_len), float("-inf"), device=tokens.device)
 
-            mask = torch.triu(mask, diagonal=1)
+            mask_for_kvcache = torch.triu(mask_for_kvcache, diagonal=1)
 
             # When performing key-value caching, we compute the attention scores
             # only for the new sequence. Thus, the matrix of scores is of size
             # (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
             # j > cache_len + i, since row i corresponds to token cache_len + i.
             mask = torch.hstack(
-                [torch.zeros((seq_len, start_pos), device=tokens.device), mask]
+                [torch.zeros((seq_len, start_pos), device=tokens.device), mask_for_kvcache]
             ).type_as(h)
 
         for layer in self.layers:
```
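For context, here is a minimal standalone sketch of the masking logic this commit touches: an upper-triangular `-inf` mask over the new tokens, prepended with zeros so each new token can attend to the full KV cache. The values `seq_len = 3` and `start_pos = 2` are illustrative assumptions, not taken from the commit.

```python
import torch

# Illustrative values (not from the commit): decode seq_len = 3 new tokens
# on top of start_pos = 2 already-cached tokens.
seq_len, start_pos = 3, 2
device = torch.device("cpu")

# Causal mask over the new tokens only: -inf strictly above the diagonal.
mask_for_kvcache = torch.full((seq_len, seq_len), float("-inf"), device=device)
mask_for_kvcache = torch.triu(mask_for_kvcache, diagonal=1)

# Prepend zeros for the start_pos cached positions: every new token may
# attend to the whole cache, so the final mask is (seq_len, start_pos + seq_len).
mask = torch.hstack(
    [torch.zeros((seq_len, start_pos), device=device), mask_for_kvcache]
)

print(mask)
# tensor([[0., 0., 0., -inf, -inf],
#         [0., 0., 0., 0., -inf],
#         [0., 0., 0., 0., 0.]])
```

Entry (i, j) is masked exactly when j > start_pos + i, matching the comment in the diff: row i corresponds to token cache_len + i.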
