Skip to content

Commit

Permalink
perf: reduce overhead from getting cudaStream ptr
Browse files Browse the repository at this point in the history
  • Loading branch information
matthewdouglas committed Oct 24, 2024
1 parent dfc4668 commit 01bf54e
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions bitsandbytes/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,8 +441,9 @@ def is_on_gpu(tensors: Iterable[torch.Tensor]):
return on_gpu


def get_tensor_stream(tensor: Tensor) -> torch.cuda.Stream:
    """Return the current CUDA stream for the device that holds ``tensor``.

    Args:
        tensor: Tensor whose ``device`` attribute selects the stream to look up.

    Returns:
        Whatever ``torch.cuda.current_stream`` yields for that device.
    """
    device = tensor.device
    return torch.cuda.current_stream(device)
def get_tensor_stream(tensor: Tensor) -> ct.c_void_p:
    """Return the current raw CUDA stream for ``tensor``'s device as a pointer.

    Args:
        tensor: Tensor whose ``device.index`` selects the device to query.

    Returns:
        A ``ctypes.c_void_p`` wrapping the value returned by
        ``torch._C._cuda_getCurrentRawStream`` for that device index.
    """
    device_index = tensor.device.index
    # The raw stream is used for performance reasons.
    raw_stream = torch._C._cuda_getCurrentRawStream(device_index)
    return ct.c_void_p(raw_stream)


def get_ptr(A: Optional[Tensor]) -> Optional[ct.c_void_p]:
Expand Down

0 comments on commit 01bf54e

Please sign in to comment.