Skip to content

Commit d8bfe04

Browse files
committed
feat: fully async _select_generated_logits with tests
Signed-off-by: ixlmar <[email protected]>
1 parent 3a5845e commit d8bfe04

File tree

7 files changed

+471
-17
lines changed

7 files changed

+471
-17
lines changed

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1366,6 +1366,7 @@ def _select_generated_logits(
13661366
req_num_generation_steps: torch.Tensor,
13671367
num_context_logits_prefix_sum: list[int],
13681368
generation_requests_total_steps: int,
1369+
num_logits_to_keep: int,
13691370
) -> torch.Tensor:
13701371
# raw_logits should contain only the generated logits.
13711372
# If return context logits is requested, select only the generated logits.
@@ -1394,9 +1395,10 @@ def _select_generated_logits(
13941395
req_num_steps_fictitious_cuda = req_num_generation_steps_cuda[
13951396
: (len(scheduled_requests.context_requests) + 1)
13961397
].clone()
1397-
req_num_steps_fictitious_cuda[-1] = generation_requests_total_steps
1398-
next_context_req_offsets_cuda[-1] = (
1399-
next_context_req_offsets_cuda[-2] + req_num_steps_fictitious_cuda[-1]
1398+
req_num_steps_fictitious_cuda[-1].fill_(generation_requests_total_steps)
1399+
next_context_req_offsets_cuda[-1].copy_(
1400+
next_context_req_offsets_cuda[-2] + req_num_steps_fictitious_cuda[-1],
1401+
non_blocking=True,
14001402
)
14011403
else:
14021404
req_num_steps_fictitious_cuda = req_num_generation_steps_cuda[
@@ -1412,6 +1414,7 @@ def _select_generated_logits(
14121414
indices_to_keep_cuda = torch_multi_arange(
14131415
starts=(next_context_req_offsets_cuda - req_num_steps_fictitious_cuda),
14141416
ends=next_context_req_offsets_cuda,
1417+
output_length=num_logits_to_keep,
14151418
)
14161419

14171420
raw_logits_cuda = raw_logits_cuda[indices_to_keep_cuda]
@@ -1455,6 +1458,7 @@ def _process_requests(
14551458
if scheduled_requests.generation_requests
14561459
else 0
14571460
),
1461+
num_logits_to_keep=sum_steps,
14581462
)
14591463

14601464
# Handle embedding bias

tensorrt_llm/_torch/pyexecutor/sampling_utils.py

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -350,18 +350,28 @@ def torch_multi_arange(
350350
*,
351351
starts: Optional[torch.Tensor] = None,
352352
steps: Optional[torch.Tensor] = None,
353+
output_length: Optional[int] = None,
353354
) -> torch.Tensor:
354355
"""Efficiently compute torch.cat([torch.arange(b, e, d) for b, e, d in zip(starts, ends, steps)]).
355356
356357
Starts, ends, steps need to share dtype and shape. Invalid ranges like range(1, 2, -1) are
357358
silently discarded. 'steps' defaults to 1 and 'starts' defaults to 0.
359+
360+
Provide 'output_length' to avoid synchronization when using device tensors.
358361
"""
359362
if steps is not None:
360363
assert ends.dtype == steps.dtype
361364
assert ends.shape == steps.shape
365+
assert ends.device == steps.device
362366
if starts is not None:
363367
assert ends.dtype == starts.dtype
364368
assert ends.shape == starts.shape
369+
assert ends.device == starts.device
370+
if ends.device != torch.device("cpu") and output_length is None:
371+
raise ValueError("Device tensors require 'output_length'")
372+
373+
if ends.numel() == 0:
374+
return ends.clone()
365375

366376
# This algorithm combines torch.repeat_interleaved() and torch.cumsum() to
367377
# construct the result.
@@ -378,29 +388,37 @@ def torch_multi_arange(
378388
repeats = repeats.clone()
379389
repeats -= starts
380390
if steps is not None:
381-
repeats = (repeats + steps - 1).div(steps, rounding_mode="floor")
382-
repeats = repeats.clip(0) # ignore invalid ranges
391+
repeats *= steps.sign()
392+
steps_abs = steps.abs()
393+
repeats = (repeats + steps_abs - 1).div(steps_abs, rounding_mode="floor")
394+
repeats = repeats.clip(min=0) # ignore invalid ranges
383395
range_ends = repeats - 1 # last element in each range
384396
if steps is not None:
385397
range_ends *= steps
386398
if starts is not None:
387399
range_ends += starts
388400
prev_range_ends = range_ends.roll(1) # last element in preceding range (or 0)
389-
prev_range_ends[0] = 0
390-
ones = (
391-
torch.tensor(1, dtype=ends.dtype, pin_memory=True)
392-
.to(device=ends.device, non_blocking=True)
393-
.broadcast_to(ends.shape)
394-
)
401+
prev_range_ends[0].fill_(0)
402+
ones = torch.ones((), dtype=ends.dtype, device=ends.device)
403+
zeros = torch.zeros((), dtype=ends.dtype, device=ends.device)
395404
if steps is None:
396-
steps = ones
405+
steps = ones.broadcast_to(ends.shape)
397406
jumps = -prev_range_ends # delta from one range to the next
398407
if starts is not None:
399408
jumps += starts
409+
# NB: Apply correction for empty ranges
410+
jumps_corrections = torch.where(repeats == 0, jumps, zeros).cumsum(0, dtype=ends.dtype)
411+
jumps += jumps_corrections
400412
seq = torch.cat((jumps.unsqueeze(-1), steps.unsqueeze(-1)), dim=1).view(-1)
401413
#
402414
# 2. Construct output via torch.repeat_interleave() and torch.cumsum()
403-
seq_repeats = torch.cat((ones.unsqueeze(-1), (repeats - 1).unsqueeze(-1)), dim=1).view(-1)
404-
seq = seq.repeat_interleave(seq_repeats)
405-
seq = seq.cumsum(0)
415+
# NB: For a resulting empty range, repeats - 1 == -1. In this case, we
416+
# should set repeats for delta and increment both to 0 instead.
417+
jump_repeats = torch.where(repeats == 0, zeros, ones)
418+
step_repeats = torch.where(repeats == 0, zeros, repeats - 1)
419+
seq_repeats = torch.cat((jump_repeats.unsqueeze(-1), step_repeats.unsqueeze(-1)), dim=1).view(
420+
-1
421+
)
422+
seq = seq.repeat_interleave(seq_repeats, output_size=output_length)
423+
seq = seq.cumsum(0, dtype=ends.dtype)
406424
return seq

tests/integration/test_lists/test-db/l0_a10.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ l0_a10:
1515
tests:
1616
# ------------- PyTorch tests ---------------
1717
- unittest/_torch/sampler/test_torch_sampler.py
18+
- unittest/_torch/sampler/test_torch_multi_arange.py
19+
- unittest/utils/test_util.py
1820
- unittest/_torch/modeling/test_modeling_mistral.py
1921
- unittest/_torch/modeling/test_modeling_pixtral.py
2022
# NOTE: this is a CPU-only test, but we do not have a dedicated job for this (and therefore no
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from itertools import product
16+
from typing import Iterable, Optional
17+
18+
import numpy as np
19+
import pytest
20+
import torch
21+
from utils.util import assert_no_cuda_sync, force_ampere
22+
23+
from tensorrt_llm._torch.pyexecutor.sampling_utils import torch_multi_arange
24+
25+
BASE_CASES = [
26+
(None, [], None, []),
27+
([], [], None, []),
28+
(None, [], [], []),
29+
([], [], [], []),
30+
(None, [1], None, [0]),
31+
(None, [-1], None, []),
32+
(None, [3], None, [0, 1, 2]),
33+
(None, [-3], None, []),
34+
([-5], [-3], None, [-5, -4]),
35+
([-5], [-2], [2], [-5, -3]),
36+
([-5], [-1], [2], [-5, -3]),
37+
([-5], [-3], [3], [-5]),
38+
([-3], [-5], None, []),
39+
([-3], [-5], [-1], [-3, -4]),
40+
([-3], [-5], [-3], [-3]),
41+
([-3], [-5], [1], []),
42+
([-5], [-3], [-2], []),
43+
([-3], [2], None, [-3, -2, -1, 0, 1]),
44+
([-3], [2], [2], [-3, -1, 1]),
45+
([-3], [3], [2], [-3, -1, 1]),
46+
([2], [5], None, [2, 3, 4]),
47+
([2], [5], [2], [2, 4]),
48+
([2], [6], [2], [2, 4]),
49+
]
50+
51+
52+
def _build_multi_arange_case() -> tuple[Iterable, Iterable, Iterable, Iterable]:
53+
gen = np.random.default_rng(seed=42)
54+
cases = [
55+
BASE_CASES[i] for i in gen.choice(len(BASE_CASES), 128)
56+
if len(BASE_CASES[i][3]) > 0
57+
]
58+
starts = [
59+
val for case in cases
60+
for val in (case[0] if case[0] is not None else [0] * len(case[1]))
61+
]
62+
ends = [val for case in cases for val in case[1]]
63+
steps = [
64+
val for case in cases
65+
for val in (case[2] if case[2] is not None else [1] * len(case[1]))
66+
]
67+
expected = [val for case in cases for val in case[3]]
68+
return starts, ends, steps, expected
69+
70+
71+
@force_ampere
72+
@pytest.mark.parametrize(
73+
"device, dtype, starts, ends, steps, expected",
74+
[
75+
pytest.param(device, dtype, starts, ends, steps, expected)
76+
for (dtype, (starts, ends, steps, expected), device) in product(
77+
[
78+
torch.int32,
79+
torch.int64,
80+
],
81+
BASE_CASES + [_build_multi_arange_case()],
82+
[
83+
"cpu",
84+
"cuda",
85+
],
86+
)
87+
],
88+
)
89+
def test_torch_multi_arange(
90+
device: str,
91+
dtype: torch.dtype,
92+
starts: Optional[Iterable],
93+
ends: Iterable,
94+
steps: Optional[Iterable],
95+
expected: Iterable,
96+
):
97+
torch_device = torch.device(device)
98+
99+
def _make_tensor(data: Iterable) -> torch.Tensor:
100+
return torch.tensor(data, device=torch_device, dtype=dtype)
101+
102+
def _maybe_make_tensor(data: Optional[Iterable]) -> Optional[torch.Tensor]:
103+
if data is None:
104+
return None
105+
return _make_tensor(data)
106+
107+
starts_tensor = _maybe_make_tensor(starts)
108+
ends_tensor = _make_tensor(ends)
109+
steps_tensor = _maybe_make_tensor(steps)
110+
expected_tensor = _make_tensor(expected)
111+
112+
extra_args = {}
113+
if device != "cpu":
114+
# Pre-allocates a large chunk of memory, because PyTorch caching memory allocator
115+
# can sync otherwise.
116+
buf = torch.ones((2**30, ), device=device)
117+
del buf
118+
extra_args["output_length"] = expected_tensor.numel()
119+
# Warmup to avoid syncs due to lazy loading of kernels
120+
_ = torch_multi_arange(ends_tensor,
121+
starts=starts_tensor,
122+
steps=steps_tensor,
123+
**extra_args)
124+
125+
with torch.cuda.Stream():
126+
with assert_no_cuda_sync():
127+
result = torch_multi_arange(ends_tensor,
128+
starts=starts_tensor,
129+
steps=steps_tensor,
130+
**extra_args)
131+
132+
torch.testing.assert_close(result, expected_tensor)

0 commit comments

Comments
 (0)