10 changes: 7 additions & 3 deletions tensorrt_llm/_torch/pyexecutor/sampler.py
@@ -1366,6 +1366,7 @@ def _select_generated_logits(
    req_num_generation_steps: torch.Tensor,
    num_context_logits_prefix_sum: list[int],
    generation_requests_total_steps: int,
+    num_logits_to_keep: int,
) -> torch.Tensor:
    # raw_logits should contain only the generated logits.
    # If return context logits is requested, select only the generated logits.
@@ -1394,9 +1395,10 @@ def _select_generated_logits(
        req_num_steps_fictitious_cuda = req_num_generation_steps_cuda[
            : (len(scheduled_requests.context_requests) + 1)
        ].clone()
-        req_num_steps_fictitious_cuda[-1] = generation_requests_total_steps
-        next_context_req_offsets_cuda[-1] = (
-            next_context_req_offsets_cuda[-2] + req_num_steps_fictitious_cuda[-1]
+        req_num_steps_fictitious_cuda[-1].fill_(generation_requests_total_steps)
+        next_context_req_offsets_cuda[-1].copy_(
+            next_context_req_offsets_cuda[-2] + req_num_steps_fictitious_cuda[-1],
+            non_blocking=True,
        )
    else:
        req_num_steps_fictitious_cuda = req_num_generation_steps_cuda[
@@ -1412,6 +1414,7 @@ def _select_generated_logits(
    indices_to_keep_cuda = torch_multi_arange(
        starts=(next_context_req_offsets_cuda - req_num_steps_fictitious_cuda),
        ends=next_context_req_offsets_cuda,
+        output_length=num_logits_to_keep,
    )

    raw_logits_cuda = raw_logits_cuda[indices_to_keep_cuda]
@@ -1455,6 +1458,7 @@ def _process_requests(
            if scheduled_requests.generation_requests
            else 0
        ),
+        num_logits_to_keep=sum_steps,
    )

    # Handle embedding bias
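The sampler changes remove two host-synchronization points: scalar item assignment into a CUDA tensor is replaced by fill_() / copy_(..., non_blocking=True), which keep the update on-device, and the precomputed sum_steps total is threaded through as num_logits_to_keep so that torch_multi_arange can size its output without reading anything back from the device. A minimal sketch of the assignment pattern, with illustrative tensors rather than code from this PR:

import torch

offsets = torch.arange(4, dtype=torch.int64, device="cuda")

# "offsets[-1] = 7" would wrap the Python int in a tensor and can issue a
# blocking host-to-device copy; fill_() bakes the scalar into an
# asynchronous kernel launch instead.
offsets[-1].fill_(7)

# copy_() between device tensors is likewise queued on the current stream
# without blocking the host.
offsets[-1].copy_(offsets[-2] + offsets[-1], non_blocking=True)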
50 changes: 38 additions & 12 deletions tensorrt_llm/_torch/pyexecutor/sampling_utils.py
@@ -343,25 +343,43 @@ def sample_grouped_strategies(
    )


+class _AcceptSyncCompute:
+    pass
+
+
+ACCEPT_SYNC_COMPUTE = _AcceptSyncCompute()
+
+
# Inspired by https://github.com/pytorch/pytorch/issues/80577; note also the
# suggestion to consider torch.nested.
def torch_multi_arange(
    ends: torch.Tensor,
    *,
+    output_length: int | _AcceptSyncCompute,
    starts: Optional[torch.Tensor] = None,
    steps: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Efficiently compute torch.cat([torch.arange(b, e, d) for b, e, d in zip(starts, ends, steps)]).

    Starts, ends, steps need to share dtype and shape. Invalid ranges like range(1, 2, -1) are
    silently discarded. 'steps' defaults to 1 and 'starts' defaults to 0.
+
+    Provide 'output_length' to avoid a synchronization when using device tensors, or pass
+    `ACCEPT_SYNC_COMPUTE` to explicitly accept a possible device sync (for device tensors,
+    or when the tensors are known to reside on the host).
    """
    if steps is not None:
        assert ends.dtype == steps.dtype
        assert ends.shape == steps.shape
+        assert ends.device == steps.device
    if starts is not None:
        assert ends.dtype == starts.dtype
        assert ends.shape == starts.shape
+        assert ends.device == starts.device
+    output_length_arg = None if isinstance(output_length, _AcceptSyncCompute) else output_length
+
+    if ends.numel() == 0:
+        return ends.clone()

    # This algorithm combines torch.repeat_interleave() and torch.cumsum() to
    # construct the result.
@@ -378,29 +396,37 @@ def torch_multi_arange(
        repeats = repeats.clone()
        repeats -= starts
    if steps is not None:
-        repeats = (repeats + steps - 1).div(steps, rounding_mode="floor")
-    repeats = repeats.clip(0)  # ignore invalid ranges
+        repeats *= steps.sign()
+        steps_abs = steps.abs()
+        repeats = (repeats + steps_abs - 1).div(steps_abs, rounding_mode="floor")
+    repeats = repeats.clip(min=0)  # ignore invalid ranges
    range_ends = repeats - 1  # last element in each range
    if steps is not None:
        range_ends *= steps
    if starts is not None:
        range_ends += starts
    prev_range_ends = range_ends.roll(1)  # last element in preceding range (or 0)
-    prev_range_ends[0] = 0
-    ones = (
-        torch.tensor(1, dtype=ends.dtype, pin_memory=True)
-        .to(device=ends.device, non_blocking=True)
-        .broadcast_to(ends.shape)
-    )
+    prev_range_ends[0].fill_(0)
+    ones = torch.ones((), dtype=ends.dtype, device=ends.device)
+    zeros = torch.zeros((), dtype=ends.dtype, device=ends.device)
    if steps is None:
-        steps = ones
+        steps = ones.broadcast_to(ends.shape)
    jumps = -prev_range_ends  # delta from one range to the next
    if starts is not None:
        jumps += starts
+    # NB: Apply a correction for empty ranges.
+    jumps_corrections = torch.where(repeats == 0, jumps, zeros).cumsum(0, dtype=ends.dtype)
+    jumps += jumps_corrections
    seq = torch.cat((jumps.unsqueeze(-1), steps.unsqueeze(-1)), dim=1).view(-1)
    #
    # 2. Construct output via torch.repeat_interleave() and torch.cumsum()
-    seq_repeats = torch.cat((ones.unsqueeze(-1), (repeats - 1).unsqueeze(-1)), dim=1).view(-1)
-    seq = seq.repeat_interleave(seq_repeats)
-    seq = seq.cumsum(0)
+    # NB: For a resulting empty range, repeats - 1 == -1. In this case, we
+    # should set the repeats for both the delta and the increment to 0 instead.
+    jump_repeats = torch.where(repeats == 0, zeros, ones)
+    step_repeats = torch.where(repeats == 0, zeros, repeats - 1)
+    seq_repeats = torch.cat((jump_repeats.unsqueeze(-1), step_repeats.unsqueeze(-1)), dim=1).view(
+        -1
+    )
+    seq = seq.repeat_interleave(seq_repeats, output_size=output_length_arg)
+    seq = seq.cumsum(0, dtype=ends.dtype)
    return seq
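For reference, a small usage sketch of the extended torch_multi_arange API; the values are illustrative rather than taken from this PR, and the trailing comments trace how the repeat_interleave/cumsum construction produces the result:

import torch

from tensorrt_llm._torch.pyexecutor.sampling_utils import (
    ACCEPT_SYNC_COMPUTE, torch_multi_arange)

starts = torch.tensor([0, 5, 3], dtype=torch.int64)
ends = torch.tensor([3, 9, 3], dtype=torch.int64)

# Reference semantics; the empty range (3, 3) contributes nothing:
expected = torch.cat(
    [torch.arange(b, e) for b, e in zip(starts.tolist(), ends.tolist())])
# -> [0, 1, 2, 5, 6, 7, 8]

# Host tensors cannot trigger a CUDA sync, so accept it explicitly:
out = torch_multi_arange(ends, starts=starts, output_length=ACCEPT_SYNC_COMPUTE)
assert torch.equal(out, expected)

# For device tensors, pass the known total length; repeat_interleave()
# then receives output_size and never reads the repeat counts back:
if torch.cuda.is_available():
    out_cuda = torch_multi_arange(
        ends.cuda(), starts=starts.cuda(), output_length=7)

# Internally, each range i becomes a (jump_i, step_i) pair, where jump_i
# is the delta from the previous range's last element to this range's
# first element:
#   pairs:       (0, 1), (3, 1), (j2, 1)   # j2 belongs to the empty range
#   repetitions: (1, 2), (1, 3), (0, 0)    # empty range is dropped
# repeat_interleave -> [0, 1, 1, 3, 1, 1, 1]
# cumsum            -> [0, 1, 2, 5, 6, 7, 8]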
2 changes: 2 additions & 0 deletions tests/integration/test_lists/test-db/l0_a10.yml
@@ -15,6 +15,8 @@ l0_a10:
  tests:
  # ------------- PyTorch tests ---------------
  - unittest/_torch/sampler/test_torch_sampler.py
+  - unittest/_torch/sampler/test_torch_multi_arange.py
+  - unittest/utils/test_util.py
  - unittest/_torch/modeling/test_modeling_mistral.py
  - unittest/_torch/modeling/test_modeling_pixtral.py
  # NOTE: this is a CPU-only test, but we do not have a dedicated job for this (and therefore no
143 changes: 143 additions & 0 deletions tests/unittest/_torch/sampler/test_torch_multi_arange.py
@@ -0,0 +1,143 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from contextlib import nullcontext
from itertools import product
from typing import Iterable, Optional

import numpy as np
import pytest
import torch
from utils.util import assert_no_cuda_sync, force_ampere

from tensorrt_llm._torch.pyexecutor.sampling_utils import (ACCEPT_SYNC_COMPUTE,
                                                           torch_multi_arange)

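# Each case is (starts, ends, steps, expected); starts=None and steps=None
# exercise the documented defaults of 0 and 1, respectively.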
BASE_CASES = [
    (None, [], None, []),
    ([], [], None, []),
    (None, [], [], []),
    ([], [], [], []),
    (None, [1], None, [0]),
    (None, [-1], None, []),
    (None, [3], None, [0, 1, 2]),
    (None, [-3], None, []),
    ([-5], [-3], None, [-5, -4]),
    ([-5], [-2], [2], [-5, -3]),
    ([-5], [-1], [2], [-5, -3]),
    ([-5], [-3], [3], [-5]),
    ([-3], [-5], None, []),
    ([-3], [-5], [-1], [-3, -4]),
    ([-3], [-5], [-3], [-3]),
    ([-3], [-5], [1], []),
    ([-5], [-3], [-2], []),
    ([-3], [2], None, [-3, -2, -1, 0, 1]),
    ([-3], [2], [2], [-3, -1, 1]),
    ([-3], [3], [2], [-3, -1, 1]),
    ([2], [5], None, [2, 3, 4]),
    ([2], [5], [2], [2, 4]),
    ([2], [6], [2], [2, 4]),
]


def _build_multi_arange_case() -> tuple[Iterable, Iterable, Iterable, Iterable]:
    gen = np.random.default_rng(seed=42)
    cases = [
        BASE_CASES[i] for i in gen.choice(len(BASE_CASES), 128)
        if len(BASE_CASES[i][3]) > 0
    ]
    starts = [
        val for case in cases
        for val in (case[0] if case[0] is not None else [0] * len(case[1]))
    ]
    ends = [val for case in cases for val in case[1]]
    steps = [
        val for case in cases
        for val in (case[2] if case[2] is not None else [1] * len(case[1]))
    ]
    expected = [val for case in cases for val in case[3]]
    return starts, ends, steps, expected


@force_ampere
@pytest.mark.parametrize(
    "device, allow_sync, dtype, starts, ends, steps, expected",
    [
        pytest.param(device, allow_sync, dtype, starts, ends, steps, expected)
        for (dtype,
             (starts, ends, steps, expected), device, allow_sync) in product(
                 [
                     torch.int32,
                     torch.int64,
                 ],
                 BASE_CASES + [_build_multi_arange_case()],
                 [
                     "cpu",
                     "cuda",
                 ],
                 [False, True],
             ) if device == "cuda" or allow_sync
    ],
)
def test_torch_multi_arange(
    device: str,
    allow_sync: bool,
    dtype: torch.dtype,
    starts: Optional[Iterable],
    ends: Iterable,
    steps: Optional[Iterable],
    expected: Iterable,
):
    torch_device = torch.device(device)

    def _make_tensor(data: Iterable) -> torch.Tensor:
        return torch.tensor(data, device=torch_device, dtype=dtype)

    def _maybe_make_tensor(data: Optional[Iterable]) -> Optional[torch.Tensor]:
        if data is None:
            return None
        return _make_tensor(data)

    starts_tensor = _maybe_make_tensor(starts)
    ends_tensor = _make_tensor(ends)
    steps_tensor = _maybe_make_tensor(steps)
    expected_tensor = _make_tensor(expected)

    extra_args = {}
    extra_args["output_length"] = ACCEPT_SYNC_COMPUTE
    if device != "cpu":
        # Pre-allocate a large chunk of memory, because the PyTorch caching
        # memory allocator may otherwise sync.
        buf = torch.ones((2**30, ), device=device)
        del buf
        if not allow_sync:
            extra_args["output_length"] = expected_tensor.numel()
        # Warmup to avoid syncs due to lazy loading of kernels
        _ = torch_multi_arange(
            ends_tensor,
            starts=starts_tensor,
            steps=steps_tensor,
            **extra_args,
        )

    with torch.cuda.Stream():
        with assert_no_cuda_sync() if not allow_sync else nullcontext():
            result = torch_multi_arange(
                ends_tensor,
                starts=starts_tensor,
                steps=steps_tensor,
                **extra_args,
            )

    torch.testing.assert_close(result, expected_tensor)
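The assert_no_cuda_sync helper comes from the shared test utilities (whose unittest/utils/test_util.py coverage is added to the test list above); its implementation is not part of this diff. A plausible minimal equivalent, assuming torch.cuda.set_sync_debug_mode underneath (an assumption, not the repository's actual helper):

from contextlib import contextmanager

import torch


@contextmanager
def assert_no_cuda_sync_sketch():
    # "error" makes implicitly synchronizing CUDA calls (e.g. .item() or
    # pageable host copies) raise a RuntimeError instead of silently blocking.
    previous = torch.cuda.get_sync_debug_mode()
    torch.cuda.set_sync_debug_mode("error")
    try:
        yield
    finally:
        torch.cuda.set_sync_debug_mode(previous)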