Make get_tensor_stream() return the current raw stream handle, obtained from
torch._C._cuda_getCurrentRawStream(tensor.device.index) and wrapped in
ct.c_void_p, instead of the torch.cuda.Stream object returned by
torch.cuda.current_stream(tensor.device).

diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 681e06f08..1907cd0f0 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -441,8 +441,9 @@ def is_on_gpu(tensors: Iterable[torch.Tensor]):
     return on_gpu
 
 
-def get_tensor_stream(tensor: Tensor) -> torch.cuda.Stream:
-    return torch.cuda.current_stream(tensor.device)
+def get_tensor_stream(tensor: Tensor) -> ct.c_void_p:
+    # We use the raw stream for performance reasons.
+    return ct.c_void_p(torch._C._cuda_getCurrentRawStream(tensor.device.index))
 
 
 def get_ptr(A: Optional[Tensor]) -> Optional[ct.c_void_p]: