input_pos_maxp1 as torch tensor
Andrei-Aksionov committed Dec 31, 2024
1 parent 3702b03 commit 1b1b592
Showing 3 changed files with 8 additions and 8 deletions.
litgpt/generate/base.py (3 additions, 3 deletions)
@@ -77,7 +77,7 @@ def next_token(
     model: GPT,
     input_pos: torch.Tensor,
     x: torch.Tensor,
-    input_pos_maxp1: Optional[int] = None,
+    input_pos_maxp1: Optional[torch.Tensor] = None,
     **sample_kwargs: Dict[str, Any],
 ) -> torch.Tensor:
     logits = model(x, input_pos, input_pos_maxp1=input_pos_maxp1)
@@ -174,7 +174,7 @@ def generate_fn(
     token = prompt
     prefill_token = True
     input_pos = torch.arange(0, prompt_size, device=device, dtype=torch.int64)
-    input_pos_maxp1 = prompt_size
+    input_pos_maxp1 = torch.tensor(prompt_size, device=device)
     for current_idx in range(max_returned_tokens - prompt_size):
 
         # Generate the token
@@ -222,7 +222,7 @@ def generate_fn(
             input_pos = torch.tensor([prompt_size], device=device, dtype=torch.int64)
         else:
             input_pos.add_(1)
-        input_pos_maxp1 += 1
+        input_pos_maxp1.add_(1)
 
     # Yield any remaining tokens
     if yielded_idx < len(tokens):
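Taken together, the base.py changes keep `input_pos_maxp1` as a 0-dim tensor on the same device as `input_pos`, bumped in place with `add_(1)` instead of rebinding a Python int each step. The commit message does not state the motivation, but keeping the value on-device avoids the kind of per-step host-side rebinding that tends to interfere with CUDA-graph capture and `torch.compile`. A minimal sketch of the resulting loop bookkeeping, with stand-in sizes and the `next_token` call elided since it needs a real model:

```python
import torch

# Stand-in values; the real loop lives in litgpt/generate/base.py (generate_fn).
device = torch.device("cpu")  # "cuda" in the real use case
prompt_size = 8
max_new_tokens = 4

input_pos = torch.arange(0, prompt_size, device=device, dtype=torch.int64)
input_pos_maxp1 = torch.tensor(prompt_size, device=device)  # was: a Python int

prefill_token = True
for _ in range(max_new_tokens):
    # token = next_token(model, input_pos, token, input_pos_maxp1=input_pos_maxp1)
    if prefill_token:
        prefill_token = False
        input_pos = torch.tensor([prompt_size], device=device, dtype=torch.int64)
    else:
        input_pos.add_(1)
    input_pos_maxp1.add_(1)  # in-place; replaces the old `input_pos_maxp1 += 1`

print(input_pos, input_pos_maxp1)  # tensor([11]) tensor(12)
```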
litgpt/model.py (3 additions, 3 deletions)
@@ -87,7 +87,7 @@ def forward(
         self,
         idx: torch.Tensor,
         input_pos: Optional[torch.Tensor] = None,
-        input_pos_maxp1: Optional[int] = None,
+        input_pos_maxp1: Optional[torch.Tensor] = None,
         lm_head_chunk_size: int = 0,
     ) -> Union[torch.Tensor, List[torch.Tensor]]:
         """
@@ -283,7 +283,7 @@ def forward(
         sin: torch.Tensor,
         mask: Optional[torch.Tensor] = None,
         input_pos: Optional[torch.Tensor] = None,
-        input_pos_maxp1: Optional[int] = None,
+        input_pos_maxp1: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """
         Non-parallel residual       Parallel residual
@@ -351,7 +351,7 @@ def forward(
         sin: torch.Tensor,
         mask: Optional[torch.Tensor] = None,
         input_pos: Optional[torch.Tensor] = None,
-        input_pos_maxp1: Optional[int] = None,
+        input_pos_maxp1: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         # Notation:
         # - B | batch size
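The model.py side is signature-only: `GPT.forward`, `Block.forward`, and `CausalSelfAttention.forward` now annotate `input_pos_maxp1` as an optional tensor. The use site is not shown in this diff; in LitGPT the value bounds how much of the pre-allocated KV cache attention has to read, roughly as in the following sketch (shapes and cache names are stand-ins, not LitGPT internals):

```python
import torch

# During decoding the KV cache is pre-allocated for max_seq_length positions,
# but only the first `input_pos_maxp1` slots are filled, so attention can be
# restricted to that prefix.
B, n_head, max_seq_length, head_size = 1, 4, 64, 16
k_cache = torch.randn(B, n_head, max_seq_length, head_size)
v_cache = torch.randn(B, n_head, max_seq_length, head_size)

input_pos_maxp1 = torch.tensor(9)  # positions 0..8 are filled so far

# A 0-dim integer tensor is a valid slice bound, so the value can stay a
# tensor end to end instead of round-tripping through a Python int.
k = k_cache[:, :, :input_pos_maxp1]
v = v_cache[:, :, :input_pos_maxp1]
print(k.shape, v.shape)  # torch.Size([1, 4, 9, 16]) for both
```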
tests/test_model.py (2 additions, 2 deletions)
@@ -914,7 +914,7 @@ def test_against_original_salamandra(model_name, device, dtype):
     ours_y = ours_model(x)
     theirs_y = theirs_model(x)["logits"].to(dtype)  # HF converts logits to float
     torch.testing.assert_close(ours_y, theirs_y)
-
+
 @torch.inference_mode()
 @pytest.mark.parametrize("model_name", ("SmolLM2-135M", "SmolLM2-360M", "SmolLM2-1.7B"))
@@ -1380,7 +1380,7 @@ def test_forward_with_without_input_pos_maxp1():
     model.set_kv_cache(batch_size)
     idx = torch.randint(0, config.padded_vocab_size, (1, 10))
     input_pos = torch.arange(1, 11)
-    input_pos_maxp1 = 11
+    input_pos_maxp1 = torch.tensor(11)
     logits_with_maxp1 = model(idx, input_pos, input_pos_maxp1=input_pos_maxp1)
     logits_no_maxp1 = model(idx, input_pos)
     torch.testing.assert_close(logits_with_maxp1, logits_no_maxp1)
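The updated test passes `torch.tensor(11)` where a plain `11` went before and still expects logits identical to the no-`input_pos_maxp1` path. A quick standalone check, not part of the commit, that a 0-dim tensor is a drop-in for the int in the two operations the code relies on:

```python
import torch

cache = torch.arange(20)
maxp1_int = 11
maxp1_tensor = torch.tensor(11)

# Slicing: both bounds select the same prefix.
assert torch.equal(cache[:maxp1_int], cache[:maxp1_tensor])

# Incrementing: add_ mutates in place and keeps the device, unlike `+= 1`
# on a Python int, which rebinds the name to a new object.
maxp1_tensor.add_(1)
assert maxp1_tensor.item() == maxp1_int + 1
```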
