Skip to content

Commit 01bf54e

Browse files
perf: reduce overhead from getting cudaStream ptr
1 parent dfc4668 commit 01bf54e

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

bitsandbytes/functional.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -441,8 +441,9 @@ def is_on_gpu(tensors: Iterable[torch.Tensor]):
441 441
return on_gpu
442 442

443 443

444-
def get_tensor_stream(tensor: Tensor) -> torch.cuda.Stream:
445-
return torch.cuda.current_stream(tensor.device)
444+
def get_tensor_stream(tensor: Tensor) -> ct.c_void_p:
445+
# We use the raw stream for performance reasons.
446+
return ct.c_void_p(torch._C._cuda_getCurrentRawStream(tensor.device.index))
446 447

447 448

448 449
def get_ptr(A: Optional[Tensor]) -> Optional[ct.c_void_p]:

0 commit comments

Comments
 (0)