Skip to content
Merged
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
a0f25af
initial
brandon-b-miller Aug 15, 2025
5322eef
tests
brandon-b-miller Aug 16, 2025
251f4e9
refactor
brandon-b-miller Aug 16, 2025
505cd4d
small changes
brandon-b-miller Aug 18, 2025
b861723
__cuda_stream__
brandon-b-miller Aug 18, 2025
b53f9ca
Merge branch 'main' into cuda-core-streams
brandon-b-miller Aug 20, 2025
2181748
accommodate ctypes bindings
brandon-b-miller Aug 20, 2025
46863d3
clean
brandon-b-miller Aug 20, 2025
2082063
more pacifying ctypes bindings
brandon-b-miller Aug 20, 2025
ec5841c
fix
brandon-b-miller Aug 20, 2025
2e45f6d
Merge branch 'main' into cuda-core-streams
brandon-b-miller Aug 25, 2025
4fcf9d1
renaming
brandon-b-miller Aug 25, 2025
220c2e3
address reviews
brandon-b-miller Aug 25, 2025
f3b07c0
Update numba_cuda/numba/cuda/cudadrv/driver.py
brandon-b-miller Aug 25, 2025
387ba84
merge/resolve
brandon-b-miller Oct 7, 2025
20440ab
address some reviews
brandon-b-miller Oct 7, 2025
1a00d67
Merge branch 'main' into cuda-core-streams
brandon-b-miller Oct 13, 2025
f0ff9d5
fix ctypes tests
brandon-b-miller Oct 13, 2025
1b59b5c
addressing old comments
brandon-b-miller Oct 13, 2025
6f8ddb3
merge/resolve
brandon-b-miller Oct 14, 2025
9ab36e7
merge/resolve
brandon-b-miller Oct 15, 2025
d1ad577
small fix
brandon-b-miller Oct 15, 2025
b7b56eb
small fix
brandon-b-miller Oct 15, 2025
c3e10af
Merge branch 'main' into cuda-core-streams
brandon-b-miller Oct 24, 2025
9b301a8
Merge branch 'main' into cuda-core-streams
brandon-b-miller Oct 27, 2025
324a48a
USE_NV_BINDING
brandon-b-miller Oct 27, 2025
7df62ce
events
brandon-b-miller Oct 27, 2025
f859466
skip event tests on sim
brandon-b-miller Oct 27, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 59 additions & 35 deletions numba_cuda/numba/cuda/cudadrv/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,12 @@
ObjectCode,
)

from cuda.bindings.utils import get_cuda_native_handle
from cuda.core.experimental import (
Stream as ExperimentalStream,
)


# There is no definition of the default stream in the Nvidia bindings (nor
# is there at the C/C++ level), so we define it here so we don't need to
# use a magic number 0 in places where we want the default stream.
Expand Down Expand Up @@ -2064,6 +2070,11 @@ def __int__(self):
# The default stream's handle.value is 0, which gives `None`
return self.handle.value or drvapi.CU_STREAM_DEFAULT

def __cuda_stream__(self):
    """Return the ``(version, handle)`` pair for the ``__cuda_stream__``
    stream interchange protocol.

    The protocol version is 0.  The handle is the raw stream pointer
    value; the default stream is reported as ``drvapi.CU_STREAM_DEFAULT``
    because ``self.handle.value`` is ``None`` when the underlying
    pointer value is 0.
    """
    if not self.handle.value:
        return (0, drvapi.CU_STREAM_DEFAULT)
    return (0, self.handle.value)

def __repr__(self):
default_streams = {
drvapi.CU_STREAM_DEFAULT: "<Default CUDA stream on %s>",
Expand Down Expand Up @@ -3080,17 +3091,14 @@ def host_to_device(dst, src, size, stream=0):
it should not be changed until the operation which can be asynchronous
completes.
"""
varargs = []
fn = driver.cuMemcpyHtoD
args = (device_pointer(dst), host_pointer(src, readonly=True), size)

if stream:
assert isinstance(stream, Stream)
fn = driver.cuMemcpyHtoDAsync
handle = stream.handle.value
varargs.append(handle)
else:
fn = driver.cuMemcpyHtoD
args += (_stream_handle(stream),)

fn(device_pointer(dst), host_pointer(src, readonly=True), size, *varargs)
fn(*args)


def device_to_host(dst, src, size, stream=0):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As mentioned below (or above), the stream semantics are changed, which probably has a bigger impact on this method, because the copy is now asynchronous, and to access src on the host a stream synchronization is needed.

Expand All @@ -3099,61 +3107,52 @@ def device_to_host(dst, src, size, stream=0):
it should not be changed until the operation which can be asynchronous
completes.
"""
varargs = []
fn = driver.cuMemcpyDtoH
args = (host_pointer(dst), device_pointer(src), size)

if stream:
assert isinstance(stream, Stream)
fn = driver.cuMemcpyDtoHAsync
handle = stream.handle.value
varargs.append(handle)
else:
fn = driver.cuMemcpyDtoH
args += (_stream_handle(stream),)

fn(host_pointer(dst), device_pointer(src), size, *varargs)
fn(*args)


def device_to_device(dst, src, size, stream=0):
    """
    Copy ``size`` bytes from one device buffer to another.

    NOTE: The underlying data pointer from the device buffer is used and
    it should not be changed until the operation, which can be
    asynchronous, completes.

    dst: destination device memory
    src: source device memory
    size: number of bytes to copy
    stream: 0 (synchronous copy) or a CUDA stream (asynchronous copy)
    """
    # Default to the synchronous driver entry point; switch to the
    # async variant only when a real stream is supplied.
    fn = driver.cuMemcpyDtoD
    args = (device_pointer(dst), device_pointer(src), size)

    if stream:
        fn = driver.cuMemcpyDtoDAsync
        # _stream_handle accepts numba-cuda and cuda.core stream objects.
        args += (_stream_handle(stream),)

    fn(*args)


def device_memset(dst, val, size, stream=0):
"""Memset on the device.
If stream is not zero, asynchronous mode is used.
"""
Memset on the device.
If stream is 0, the call is synchronous.
If stream is a Stream object, asynchronous mode is used.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a bug (or change of behavior) here and elsewhere. stream can be a Stream object from either numba-cuda or cuda.core, but still holds 0 (the default stream) under the hood. However, the call now becomes asynchronous (with respect to the host) instead of synchronous. Just wanted to call it out in case it was not the intention.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a really good catch. As a follow up to this, is the output here as expected, where dev is a cuda.core.experimental.Device for whom set_current() has been called? Should it not be (0, 0)?

>>> dev.default_stream.__cuda_stream__()
(0, 1)

I ask hoping there's a reliable way of detecting this situation based on the passed object.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After a while searching around the codebase I concluded this was at least the original intention, though these are really only used for the deprecated device array API:

        If a CUDA ``stream`` is given, then the transfer will be made
        asynchronously as part as the given stream.  Otherwise, the transfer is
        synchronous: the function returns after the copy is finished.

So AFAICT this PR maintains the above behavior just with a new stream object. Ultimately though I'm not sure we should spend too much time thinking about it as these will be removed and users performing these types of memory transfers should use either cupy for a nice array API or cuda.bindings for full control of things like synchronization behavior.


dst: device memory
val: byte value to be written
size: number of byte to be written
stream: a CUDA stream
size: number of bytes to be written
stream: 0 (synchronous) or a CUDA stream
"""
ptr = device_pointer(dst)

varargs = []
fn = driver.cuMemsetD8
args = (device_pointer(dst), val, size)

if stream:
assert isinstance(stream, Stream)
fn = driver.cuMemsetD8Async
handle = stream.handle.value
varargs.append(handle)
else:
fn = driver.cuMemsetD8
args += (_stream_handle(stream),)

try:
fn(ptr, val, size, *varargs)
fn(*args)
except CudaAPIError as e:
invalid = binding.CUresult.CUDA_ERROR_INVALID_VALUE
if (
Expand Down Expand Up @@ -3226,3 +3225,28 @@ def inspect_obj_content(objpath: str):
code_types.add(match.group(1))

return code_types


def _stream_handle(stream):
"""
Obtain the appropriate handle for various types of
acceptable stream objects. Acceptable types are
int (0 for default stream), Stream, ExperimentalStream
Comment on lines +3233 to +3234
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the docstring outdated? int is currently not allowed

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only for the special value 0 I believe.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we consider deprecating allowing passing 0 as a Stream? The "default stream" is ambiguous in Python since PTDS is normally a host compile-time concept. We have an environment variable for controlling it in cuda.bindings / cuda.core: CUDA_PYTHON_CUDA_PER_THREAD_DEFAULT_STREAM which I think should be generally used.

It would be great if we could introduce a deprecation warning in some form to passing 0 as a Stream in user facing APIs.

Copy link
Contributor Author

@brandon-b-miller brandon-b-miller Oct 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From the user perspective we're deprecating the apis fully in #546, so those should be gone entirely. But we should do a sweep and make sure we're being explicit with all our usages of streams internally.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Outside of the DeviceNDArray class, I think streams are accepted when launching kernels and using the Event APIs as well where we should properly handle there as well?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

launching is tested as part of this PR, events added in 7df62ce though.

"""

if stream == 0:
return stream
allowed = (Stream, ExperimentalStream)
if not isinstance(stream, allowed):
raise TypeError(
"Expected a Stream object or 0, got %s" % type(stream).__name__
)
elif hasattr(stream, "__cuda_stream__"):
ver, ptr = stream.__cuda_stream__()
assert ver == 0
if isinstance(ptr, binding.CUstream):
return get_cuda_native_handle(ptr)
else:
return ptr
else:
raise TypeError("Invalid Stream")
2 changes: 1 addition & 1 deletion numba_cuda/numba/cuda/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ def launch(self, args, griddim, blockdim, stream=0, sharedmem=0):
for t, v in zip(self.argument_types, args):
self._prepare_args(t, v, stream, retr, kernelargs)

stream_handle = stream and stream.handle.value or 0
stream_handle = driver._stream_handle(stream)

# Invoke kernel
driver.launch_kernel(
Expand Down
63 changes: 63 additions & 0 deletions numba_cuda/numba/cuda/tests/cudadrv/test_cuda_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,14 @@
driver,
launch_kernel,
)

from numba import cuda
from numba.cuda.cudadrv import devices, driver as _driver
from numba.cuda.testing import unittest, CUDATestCase
from numba.cuda.testing import skip_on_cudasim
import contextlib

from cuda.core.experimental import Device

ptx1 = """
.version 1.4
Expand Down Expand Up @@ -152,6 +156,65 @@ def test_cuda_driver_stream_operations(self):
for i, v in enumerate(array):
self.assertEqual(i, v)

def test_cuda_core_stream_operations(self):
    """Exercise host-to-device copy, kernel launch, and device-to-host
    copy on a cuda.core stream, synchronizing the stream once all work
    has been enqueued."""
    mod = self.context.create_module_ptx(self.ptx)
    func = mod.get_function("_Z10helloworldPi")
    host_array = (c_int * 100)()

    device = Device()
    device.set_current()
    core_stream = device.create_stream()

    try:
        buf = self.context.memalloc(sizeof(host_array))
        host_to_device(buf, host_array, sizeof(host_array), stream=core_stream)

        launch_kernel(
            func.handle,  # Kernel
            1,
            1,
            1,  # gx, gy, gz
            100,
            1,
            1,  # bx, by, bz
            0,  # dynamic shared mem
            core_stream.handle,  # stream
            [buf.device_ctypes_pointer],
        )

        device_to_host(host_array, buf, sizeof(host_array), stream=core_stream)
    finally:
        # Ensure all enqueued work has completed before checking results.
        core_stream.sync()

    for expected, actual in enumerate(host_array):
        self.assertEqual(expected, actual)

def test_cuda_core_stream_launch_user_facing(self):
    """Launch a numba-compiled kernel on a cuda.core stream through the
    user-facing API and verify the results."""
    n = 100

    @cuda.jit
    def fill_with_index(arr):
        tid = cuda.grid(1)
        if tid < len(arr):
            arr[tid] = tid

    device = Device()
    device.set_current()
    core_stream = device.create_stream()

    d_arr = cuda.to_device([0] * n, stream=core_stream)
    core_stream.sync()

    fill_with_index[1, n, core_stream](d_arr)
    core_stream.sync()

    h_arr = d_arr.copy_to_host(stream=core_stream)
    for expected, actual in enumerate(h_arr):
        self.assertEqual(expected, actual)

def test_cuda_driver_default_stream(self):
# Test properties of the default stream
ds = self.context.get_default_stream()
Expand Down