diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py
index 4b588829118..276aa977003 100644
--- a/tensorrt_llm/_torch/pyexecutor/sampler.py
+++ b/tensorrt_llm/_torch/pyexecutor/sampler.py
@@ -1366,6 +1366,7 @@ def _select_generated_logits(
         req_num_generation_steps: torch.Tensor,
         num_context_logits_prefix_sum: list[int],
         generation_requests_total_steps: int,
+        num_logits_to_keep: int,
     ) -> torch.Tensor:
         # raw_logits should contain only the generated logits.
         # If return context logits is requested, select only the generated logits.
@@ -1394,9 +1395,10 @@
             req_num_steps_fictitious_cuda = req_num_generation_steps_cuda[
                 : (len(scheduled_requests.context_requests) + 1)
             ].clone()
-            req_num_steps_fictitious_cuda[-1] = generation_requests_total_steps
-            next_context_req_offsets_cuda[-1] = (
-                next_context_req_offsets_cuda[-2] + req_num_steps_fictitious_cuda[-1]
+            req_num_steps_fictitious_cuda[-1].fill_(generation_requests_total_steps)
+            next_context_req_offsets_cuda[-1].copy_(
+                next_context_req_offsets_cuda[-2] + req_num_steps_fictitious_cuda[-1],
+                non_blocking=True,
             )
         else:
             req_num_steps_fictitious_cuda = req_num_generation_steps_cuda[
@@ -1412,6 +1414,7 @@
         indices_to_keep_cuda = torch_multi_arange(
             starts=(next_context_req_offsets_cuda - req_num_steps_fictitious_cuda),
             ends=next_context_req_offsets_cuda,
+            output_length=num_logits_to_keep,
         )
 
         raw_logits_cuda = raw_logits_cuda[indices_to_keep_cuda]
@@ -1455,6 +1458,7 @@ def _process_requests(
                 if scheduled_requests.generation_requests
                 else 0
             ),
+            num_logits_to_keep=sum_steps,
         )
 
         # Handle embedding bias
diff --git a/tensorrt_llm/_torch/pyexecutor/sampling_utils.py b/tensorrt_llm/_torch/pyexecutor/sampling_utils.py
index df893c90238..0ea3494aa1c 100644
--- a/tensorrt_llm/_torch/pyexecutor/sampling_utils.py
+++ b/tensorrt_llm/_torch/pyexecutor/sampling_utils.py
@@ -343,11 +343,19 @@ def sample_grouped_strategies(
     )
 
 
+class _AcceptSyncCompute:
+    pass
+
+
+ACCEPT_SYNC_COMPUTE = _AcceptSyncCompute()
+
+
 # Inspired by https://github.com/pytorch/pytorch/issues/80577; note also the
 # suggestion to consider torch.nested.
 def torch_multi_arange(
     ends: torch.Tensor,
     *,
+    output_length: int | _AcceptSyncCompute,
     starts: Optional[torch.Tensor] = None,
     steps: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
@@ -355,13 +363,23 @@ def torch_multi_arange(
     """
     Starts, ends, steps need to share dtype and shape. Invalid ranges like range(1, 2, -1)
     are silently discarded. 'steps' defaults to 1 and 'starts' defaults to 0.
+
+    Provide 'output_length' to avoid a device synchronization when using device tensors. Pass
+    `ACCEPT_SYNC_COMPUTE` to explicitly accept the possibility of a device sync instead, or when
+    the tensors are known to reside on the host.
     """
     if steps is not None:
         assert ends.dtype == steps.dtype
         assert ends.shape == steps.shape
+        assert ends.device == steps.device
     if starts is not None:
         assert ends.dtype == starts.dtype
         assert ends.shape == starts.shape
+        assert ends.device == starts.device
+    output_length_arg = None if isinstance(output_length, _AcceptSyncCompute) else output_length
+
+    if ends.numel() == 0:
+        return ends.clone()
 
     # This algorithm combines torch.repeat_interleaved() and torch.cumsum() to
     # construct the result.
@@ -378,29 +396,37 @@ def torch_multi_arange(
         repeats = repeats.clone()
         repeats -= starts
     if steps is not None:
-        repeats = (repeats + steps - 1).div(steps, rounding_mode="floor")
-    repeats = repeats.clip(0)  # ignore invalid ranges
+        repeats *= steps.sign()
+        steps_abs = steps.abs()
+        repeats = (repeats + steps_abs - 1).div(steps_abs, rounding_mode="floor")
+    repeats = repeats.clip(min=0)  # ignore invalid ranges
     range_ends = repeats - 1  # last element in each range
     if steps is not None:
         range_ends *= steps
     if starts is not None:
         range_ends += starts
     prev_range_ends = range_ends.roll(1)  # last element in preceding range (or 0)
-    prev_range_ends[0] = 0
-    ones = (
-        torch.tensor(1, dtype=ends.dtype, pin_memory=True)
-        .to(device=ends.device, non_blocking=True)
-        .broadcast_to(ends.shape)
-    )
+    prev_range_ends[0].fill_(0)
+    ones = torch.ones((), dtype=ends.dtype, device=ends.device)
+    zeros = torch.zeros((), dtype=ends.dtype, device=ends.device)
     if steps is None:
-        steps = ones
+        steps = ones.broadcast_to(ends.shape)
     jumps = -prev_range_ends  # delta from one range to the next
     if starts is not None:
         jumps += starts
+    # NB: Apply correction for empty ranges
+    jumps_corrections = torch.where(repeats == 0, jumps, zeros).cumsum(0, dtype=ends.dtype)
+    jumps += jumps_corrections
     seq = torch.cat((jumps.unsqueeze(-1), steps.unsqueeze(-1)), dim=1).view(-1)
     #
     # 2. Construct output via torch.repeat_interleave() and torch.cumsum()
-    seq_repeats = torch.cat((ones.unsqueeze(-1), (repeats - 1).unsqueeze(-1)), dim=1).view(-1)
-    seq = seq.repeat_interleave(seq_repeats)
-    seq = seq.cumsum(0)
+    # NB: For an empty range, repeats - 1 == -1. In that case, the repeat counts for
+    # both the jump and the step increments must be set to 0 instead.
+    jump_repeats = torch.where(repeats == 0, zeros, ones)
+    step_repeats = torch.where(repeats == 0, zeros, repeats - 1)
+    seq_repeats = torch.cat((jump_repeats.unsqueeze(-1), step_repeats.unsqueeze(-1)), dim=1).view(
+        -1
+    )
+    seq = seq.repeat_interleave(seq_repeats, output_size=output_length_arg)
+    seq = seq.cumsum(0, dtype=ends.dtype)
     return seq
diff --git a/tests/integration/test_lists/test-db/l0_a10.yml b/tests/integration/test_lists/test-db/l0_a10.yml
index 1803124f18b..fc53c420f23 100644
--- a/tests/integration/test_lists/test-db/l0_a10.yml
+++ b/tests/integration/test_lists/test-db/l0_a10.yml
@@ -15,6 +15,8 @@ l0_a10:
   tests:
   # ------------- PyTorch tests ---------------
  - unittest/_torch/sampler/test_torch_sampler.py
+  - unittest/_torch/sampler/test_torch_multi_arange.py
+  - unittest/utils/test_util.py
  - unittest/_torch/modeling/test_modeling_mistral.py
  - unittest/_torch/modeling/test_modeling_pixtral.py
   # NOTE: this is a CPU-only test, but we do not have a dedicated job for this (and therefore no
diff --git a/tests/unittest/_torch/sampler/test_torch_multi_arange.py b/tests/unittest/_torch/sampler/test_torch_multi_arange.py
new file mode 100644
index 00000000000..a05e059b6b5
--- /dev/null
+++ b/tests/unittest/_torch/sampler/test_torch_multi_arange.py
@@ -0,0 +1,143 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from contextlib import nullcontext
+from itertools import product
+from typing import Iterable, Optional
+
+import numpy as np
+import pytest
+import torch
+from utils.util import assert_no_cuda_sync, force_ampere
+
+from tensorrt_llm._torch.pyexecutor.sampling_utils import (ACCEPT_SYNC_COMPUTE,
+                                                           torch_multi_arange)
+
+BASE_CASES = [
+    (None, [], None, []),
+    ([], [], None, []),
+    (None, [], [], []),
+    ([], [], [], []),
+    (None, [1], None, [0]),
+    (None, [-1], None, []),
+    (None, [3], None, [0, 1, 2]),
+    (None, [-3], None, []),
+    ([-5], [-3], None, [-5, -4]),
+    ([-5], [-2], [2], [-5, -3]),
+    ([-5], [-1], [2], [-5, -3]),
+    ([-5], [-3], [3], [-5]),
+    ([-3], [-5], None, []),
+    ([-3], [-5], [-1], [-3, -4]),
+    ([-3], [-5], [-3], [-3]),
+    ([-3], [-5], [1], []),
+    ([-5], [-3], [-2], []),
+    ([-3], [2], None, [-3, -2, -1, 0, 1]),
+    ([-3], [2], [2], [-3, -1, 1]),
+    ([-3], [3], [2], [-3, -1, 1]),
+    ([2], [5], None, [2, 3, 4]),
+    ([2], [5], [2], [2, 4]),
+    ([2], [6], [2], [2, 4]),
+]
+
+
+def _build_multi_arange_case() -> tuple[Iterable, Iterable, Iterable, Iterable]:
+    gen = np.random.default_rng(seed=42)
+    cases = [
+        BASE_CASES[i] for i in gen.choice(len(BASE_CASES), 128)
+        if len(BASE_CASES[i][3]) > 0
+    ]
+    starts = [
+        val for case in cases
+        for val in (case[0] if case[0] is not None else [0] * len(case[1]))
+    ]
+    ends = [val for case in cases for val in case[1]]
+    steps = [
+        val for case in cases
+        for val in (case[2] if case[2] is not None else [1] * len(case[1]))
+    ]
+    expected = [val for case in cases for val in case[3]]
+    return starts, ends, steps, expected
+
+
+@force_ampere
+@pytest.mark.parametrize(
+    "device, allow_sync, dtype, starts, ends, steps, expected",
+    [
+        pytest.param(device, allow_sync, dtype, starts, ends, steps, expected)
+        for (dtype,
+             (starts, ends, steps, expected), device, allow_sync) in product(
+                 [
+                     torch.int32,
+                     torch.int64,
+                 ],
+                 BASE_CASES + [_build_multi_arange_case()],
+                 [
+                     "cpu",
+                     "cuda",
+                 ],
+                 [False, True],
+             ) if device == "cuda" or allow_sync
+    ],
+)
+def test_torch_multi_arange(
+    device: str,
+    allow_sync: bool,
+    dtype: torch.dtype,
+    starts: Optional[Iterable],
+    ends: Iterable,
+    steps: Optional[Iterable],
+    expected: Iterable,
+):
+    torch_device = torch.device(device)
+
+    def _make_tensor(data: Iterable) -> torch.Tensor:
+        return torch.tensor(data, device=torch_device, dtype=dtype)
+
+    def _maybe_make_tensor(data: Optional[Iterable]) -> Optional[torch.Tensor]:
+        if data is None:
+            return None
+        return _make_tensor(data)
+
+    starts_tensor = _maybe_make_tensor(starts)
+    ends_tensor = _make_tensor(ends)
+    steps_tensor = _maybe_make_tensor(steps)
+    expected_tensor = _make_tensor(expected)
+
+    extra_args = {}
+    extra_args["output_length"] = ACCEPT_SYNC_COMPUTE
+    if device != "cpu":
+        # Pre-allocate a large chunk of memory, because the PyTorch caching memory allocator
+        # can sync otherwise.
+        buf = torch.ones((2**30, ), device=device)
+        del buf
+    if not allow_sync:
+        extra_args["output_length"] = expected_tensor.numel()
+    # Warmup to avoid syncs due to lazy loading of kernels
+    _ = torch_multi_arange(
+        ends_tensor,
+        starts=starts_tensor,
+        steps=steps_tensor,
+        **extra_args,
+    )
+
+    with torch.cuda.Stream():
+        with assert_no_cuda_sync() if not allow_sync else nullcontext():
+            result = torch_multi_arange(
+                ends_tensor,
+                starts=starts_tensor,
+                steps=steps_tensor,
+                **extra_args,
+            )
+
+    torch.testing.assert_close(result, expected_tensor)
diff --git a/tests/unittest/_torch/sampler/test_torch_sampler.py b/tests/unittest/_torch/sampler/test_torch_sampler.py
index bc720685e9c..111f465455f 100644
--- a/tests/unittest/_torch/sampler/test_torch_sampler.py
+++ b/tests/unittest/_torch/sampler/test_torch_sampler.py
@@ -12,15 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from contextlib import contextmanager
+from dataclasses import dataclass
 from itertools import product
-from typing import Optional, cast
+from typing import Callable, Generator, Optional, cast
 
 import pytest
-from utils.util import force_ampere
+import torch
+from utils.util import assert_no_cuda_sync, force_ampere
 
 from tensorrt_llm._torch.pyexecutor.sampler import (
     GREEDY,
     LlmRequest,
+    ScheduledRequests,
     TorchSampler,
     _request_strategy,
 )
@@ -315,3 +319,172 @@ def test_should_provide_draft_probs_consistency(
     is_greedy = strat is GREEDY
 
     assert torch_sampler.should_provide_draft_probs(request) == (not is_greedy)
+
+
+@force_ampere
+@pytest.mark.parametrize(
+    "draft_len, with_ctx, with_gen",
+    [
+        pytest.param(draft_len, with_ctx, with_gen)
+        for (draft_len, with_ctx, with_gen) in product(
+            [0, 3],
+            [False, True],
+            [False, True],
+        )
+        if with_ctx or with_gen
+    ],
+)
+def test_select_generated_logits(draft_len: int, with_ctx: bool, with_gen: bool):
+    # Currently only checks that this works and does not sync
+
+    device = torch.device("cuda")
+
+    @contextmanager
+    def _test_runner() -> Generator[Callable[[], None], None, None]:
+        class ContextRequestMock:
+            def __init__(self, return_context_logits: bool):
+                self._return_context_logits = return_context_logits
+
+            @property
+            def py_return_context_logits(self) -> bool:
+                return self._return_context_logits
+
+        class GenRequestMock:
+            pass
+
+        class ScheduledRequestsMock:
+            @property
+            def context_requests(self) -> list[LlmRequest]:
+                return (
+                    [
+                        # NB: One request with py_return_context_logits is enough
+                        # to trigger tested code.
+                        cast(LlmRequest, ContextRequestMock(True)),
+                        cast(LlmRequest, ContextRequestMock(False)),
+                        cast(LlmRequest, ContextRequestMock(True)),
+                    ]
+                    if with_ctx
+                    else []
+                )
+
+            @property
+            def generation_requests(self) -> list[LlmRequest]:
+                # NB: Currently this list is not inspected, UUT only checks that this
+                # is not empty.
+                return (
+                    [
+                        cast(LlmRequest, GenRequestMock()),
+                        cast(LlmRequest, GenRequestMock()),
+                    ]
+                    if with_gen
+                    else []
+                )
+
+        vocab_size = 12
+
+        num_context_logits_prefix_sum = [
+            0,
+            *(
+                [
+                    100 + 1,  # context req. 1 (assume context len. 100)
+                    (100 + 1) + (0 + 1),  # context req. 2 (not returning context)
+                    (100 + 1) + (0 + 1) + (50 + 1),  # context req. 3 (assume context len. 50)
+                ]
+                if with_ctx
+                else []
+            ),
+        ]
+        draft_len_req1 = draft_len
+        draft_len_req2 = draft_len + 1  # test with different draft lens
+        req_num_generation_steps = [
+            *(
+                [
+                    1,  # context req. 1
+                    1,  # context req. 2
+                    1,  # context req. 3
+                ]
+                if with_ctx
+                else []
+            ),
+            *(
+                [
+                    draft_len_req1 + 1,  # gen. req. 1
+                    draft_len_req2 + 1,  # gen. req. 2
+                ]
+                if with_gen
+                else []
+            ),
+        ]
+        req_num_generation_steps_tensor = torch.tensor(req_num_generation_steps, dtype=torch.int32)
+        num_logits_to_keep = cast(int, req_num_generation_steps_tensor.sum().item())
+        generation_requests_total_steps = (draft_len_req1 + 1) + (
+            draft_len_req2 + 1
+        )  # cf. req_num_generation_steps
+
+        num_total_steps = num_context_logits_prefix_sum[-1] + generation_requests_total_steps
+        all_logits = torch.empty((num_total_steps, vocab_size))
+
+        for i in range(all_logits.size(0)):
+            all_logits[i, :] = torch.arange(i, i + vocab_size)
+
+        all_logits_cuda = all_logits.to(device=device)
+
+        expected_logit_indices = []
+        if with_ctx:
+            expected_logit_indices += [
+                100,  # gen logits from context req. 1
+                101,  # gen logits from context req. 2
+                152,  # gen logits from context req. 3
+            ]
+        if with_gen:
+            gen_logit_offset = num_context_logits_prefix_sum[-1]
+            expected_logit_indices += [
+                *range(
+                    gen_logit_offset, gen_logit_offset + draft_len_req1 + 1
+                ),  # gen logits from gen. req. 1
+                *range(
+                    gen_logit_offset + draft_len_req1 + 1,
+                    gen_logit_offset + generation_requests_total_steps,
+                ),  # gen logits from gen. req. 2
+            ]
+
+        @dataclass
+        class UutResult:
+            selected_logits: torch.Tensor
+
+        @dataclass
+        class UutResultWrapper:
+            result: Optional[UutResult] = None
+
+        res = UutResultWrapper()
+
+        def _uut(res=res):
+            selected_logits = TorchSampler._select_generated_logits(
+                cast(ScheduledRequests, ScheduledRequestsMock()),
+                all_logits_cuda,
+                req_num_generation_steps=req_num_generation_steps_tensor,
+                num_context_logits_prefix_sum=num_context_logits_prefix_sum,
+                generation_requests_total_steps=generation_requests_total_steps,
+                num_logits_to_keep=num_logits_to_keep,
+            )
+            res.result = UutResult(selected_logits=selected_logits)
+
+        yield _uut
+
+        # Check logits
+        assert res.result is not None
+        selected_logits = res.result.selected_logits
+        torch.testing.assert_close(selected_logits.to("cpu"), all_logits[expected_logit_indices])
+
+    with _test_runner() as uut:
+        # Pre-allocate a large chunk of memory, because the PyTorch caching memory allocator
+        # can sync otherwise.
+        buf = torch.ones((2**30,), device=device)
+        del buf
+        # Warmup to avoid syncs due to lazy loading of kernels
+        uut()
+
+    with torch.cuda.Stream():
+        with _test_runner() as uut:
+            with assert_no_cuda_sync():
+                uut()
diff --git a/tests/unittest/utils/test_util.py b/tests/unittest/utils/test_util.py
new file mode 100644
index 00000000000..0793d6ec656
--- /dev/null
+++ b/tests/unittest/utils/test_util.py
@@ -0,0 +1,66 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from contextlib import nullcontext
+
+import pytest
+import torch
+
+from .util import (DeviceSleepCtl, assert_no_cuda_sync, device_sleep,
+                   force_ampere)
+
+
+@force_ampere
+@pytest.mark.parametrize(
+    "cancel",
+    [False, True],
+)
+def test_device_sleep(cancel: bool):
+    torch.cuda.synchronize()
+    start_event = torch.cuda.Event(enable_timing=True)
+    end_event = torch.cuda.Event(enable_timing=True)
+    sleep_ctl = DeviceSleepCtl()
+    sleep_time = 0.3
+
+    start_event.record()
+    device_sleep(sleep_time, ctl=sleep_ctl, spin_s=0.01)
+    end_event.record()
+
+    if cancel:
+        sleep_ctl.cancel()
+    end_event.synchronize()
+    # NB: torch.cuda.Event.elapsed_time returns millis
+    elapsed_time = start_event.elapsed_time(end_event) / 1000
+    if cancel:
+        assert elapsed_time < sleep_time
+    else:
+        assert elapsed_time >= sleep_time
+
+
+@force_ampere
+@pytest.mark.parametrize(
+    "uut_syncs",
+    [False, True],
+)
+def test_assert_no_cuda_sync(uut_syncs: bool):
+
+    def _uut():
+        if uut_syncs:
+            torch.cuda.synchronize()
+
+    ctx = pytest.raises(AssertionError, match="sync code should return quickly"
+                        ) if uut_syncs else nullcontext()
+    with ctx:
+        with assert_no_cuda_sync(sync_timeout_s=0.2):
+            _uut()
diff --git a/tests/unittest/utils/util.py b/tests/unittest/utils/util.py
index babed7f1012..76408d66614 100644
--- a/tests/unittest/utils/util.py
+++ b/tests/unittest/utils/util.py
@@ -1,6 +1,23 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
 import os
+import time
 import unittest
 from contextlib import contextmanager
+from dataclasses import dataclass
 from difflib import SequenceMatcher
 from pathlib import Path
 from typing import Any, Generator
@@ -19,6 +36,7 @@
 from parameterized import parameterized
 
 import tensorrt_llm
+from tensorrt_llm._torch.hostfunc import hostfunc
 from tensorrt_llm._utils import (mpi_disabled, torch_dtype_to_trt,
                                  trt_dtype_to_torch)
 from tensorrt_llm.llmapi.utils import get_total_gpu_memory
@@ -451,3 +469,49 @@ def check_accuracy(a, b, atol, rtol, percent):
 
 skip_ray = pytest.mark.skipif(
     mpi_disabled(), reason="This test is skipped for Ray orchestrator.")
+
+
+@dataclass
+class DeviceSleepCtl:
+    _cancellation_requested: bool = False
+
+    @property
+    def cancellation_requested(self):
+        return self._cancellation_requested
+
+    def cancel(self):
+        self._cancellation_requested = True
+
+
+@hostfunc
+def device_sleep(duration_s: float,
+                 *,
+                 ctl: DeviceSleepCtl,
+                 spin_s: float = 0.1):
+    spin_iters = math.ceil(duration_s / spin_s)
+    for _ in range(spin_iters):
+        if ctl.cancellation_requested:
+            break
+        time.sleep(spin_s)
+
+
+@contextmanager
+def assert_no_cuda_sync(
+        sync_timeout_s: float = 5) -> Generator[None, None, None]:
+    """Check that the enclosed code does not stream-synchronize."""
+
+    sleep_finished_event = torch.cuda.Event()
+    scope_finished_event = torch.cuda.Event()
+
+    torch.cuda.synchronize()
+    sleep_ctl = DeviceSleepCtl()
+    device_sleep(sync_timeout_s, ctl=sleep_ctl)
+    sleep_finished_event.record()
+    yield None
+    scope_finished_event.record()
+
+    assert not sleep_finished_event.query(
+    ), """sync code should return quickly"""
+
+    sleep_ctl.cancel()
+    scope_finished_event.synchronize()
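
For reference, a minimal usage sketch of torch_multi_arange with the output_length parameter added above. The tensor values are illustrative rather than taken from the test cases; the expected result simply follows the function's contract of concatenating the per-element ranges.

import torch

from tensorrt_llm._torch.pyexecutor.sampling_utils import ACCEPT_SYNC_COMPUTE, torch_multi_arange

# Two ranges: range(0, 3) and range(9, 12), i.e. [0, 1, 2] followed by [9, 10, 11].
starts = torch.tensor([0, 9], dtype=torch.int64)
ends = torch.tensor([3, 12], dtype=torch.int64)

# Host tensors (or callers willing to pay a possible device sync) opt in explicitly.
out = torch_multi_arange(ends, starts=starts, output_length=ACCEPT_SYNC_COMPUTE)
assert out.tolist() == [0, 1, 2, 9, 10, 11]

# With CUDA tensors, passing the known total length keeps the call free of syncs,
# since repeat_interleave() no longer needs to inspect the repeat counts on the host.
out_cuda = torch_multi_arange(ends.cuda(), starts=starts.cuda(), output_length=6)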
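
Similarly, a rough sketch of how the new assert_no_cuda_sync helper from tests/unittest/utils/util.py is meant to be used. The device work inside the scope is illustrative; as in the tests above, a warm-up pass (and pre-allocated memory) may be needed so that lazy kernel loading or the caching allocator does not synchronize on its own.

import torch

from utils.util import assert_no_cuda_sync

x = torch.randn(1024, device="cuda")
_ = x * 2  # warm-up: load kernels and populate the caching allocator first

with torch.cuda.Stream():
    with assert_no_cuda_sync(sync_timeout_s=1.0):
        y = x * 2  # purely asynchronous device work: the check passes
        # Calling y.cpu() or torch.cuda.synchronize() here would block on the
        # pending device_sleep() and fail with "sync code should return quickly".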