From 01bf54eaa41b1b5c3d433321fd7a697411a71e3f Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Wed, 23 Oct 2024 22:09:45 -0400 Subject: [PATCH] perf: reduce overhead from getting cudaStream ptr --- bitsandbytes/functional.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 681e06f08..1907cd0f0 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -441,8 +441,9 @@ def is_on_gpu(tensors: Iterable[torch.Tensor]): return on_gpu -def get_tensor_stream(tensor: Tensor) -> torch.cuda.Stream: - return torch.cuda.current_stream(tensor.device) +def get_tensor_stream(tensor: Tensor) -> ct.c_void_p: + # We use the raw stream for performance reasons. + return ct.c_void_p(torch._C._cuda_getCurrentRawStream(tensor.device.index)) def get_ptr(A: Optional[Tensor]) -> Optional[ct.c_void_p]: