156 changes: 119 additions & 37 deletions cuda_core/cuda/core/experimental/_kernel_arg_handler.pyx
@@ -17,6 +17,7 @@ import numpy

from cuda.core.experimental._memory import Buffer
from cuda.core.experimental._utils.cuda_utils import driver
from cuda.bindings cimport cydriver


ctypedef cpp_complex.complex[float] cpp_single_complex
@@ -128,67 +129,123 @@ cdef inline int prepare_ctypes_arg(
vector.vector[void*]& data_addresses,
arg,
const size_t idx) except -1:
if isinstance(arg, ctypes_bool):
cdef object arg_type = type(arg)
if arg_type is ctypes_bool:
return prepare_arg[cpp_bool](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_int8):
elif arg_type is ctypes_int8:
return prepare_arg[int8_t](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_int16):
elif arg_type is ctypes_int16:
return prepare_arg[int16_t](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_int32):
elif arg_type is ctypes_int32:
return prepare_arg[int32_t](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_int64):
elif arg_type is ctypes_int64:
return prepare_arg[int64_t](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_uint8):
elif arg_type is ctypes_uint8:
return prepare_arg[uint8_t](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_uint16):
elif arg_type is ctypes_uint16:
return prepare_arg[uint16_t](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_uint32):
elif arg_type is ctypes_uint32:
return prepare_arg[uint32_t](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_uint64):
elif arg_type is ctypes_uint64:
return prepare_arg[uint64_t](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_float):
elif arg_type is ctypes_float:
return prepare_arg[float](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_double):
elif arg_type is ctypes_double:
return prepare_arg[double](data, data_addresses, arg.value, idx)
else:
return 1
# If no exact type matches, fall back to the slower `isinstance` checks
if isinstance(arg, ctypes_bool):
return prepare_arg[cpp_bool](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_int8):
return prepare_arg[int8_t](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_int16):
return prepare_arg[int16_t](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_int32):
return prepare_arg[int32_t](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_int64):
return prepare_arg[int64_t](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_uint8):
return prepare_arg[uint8_t](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_uint16):
return prepare_arg[uint16_t](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_uint32):
return prepare_arg[uint32_t](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_uint64):
return prepare_arg[uint64_t](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_float):
return prepare_arg[float](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_double):
return prepare_arg[double](data, data_addresses, arg.value, idx)
else:
return 1


cdef inline int prepare_numpy_arg(
vector.vector[void*]& data,
vector.vector[void*]& data_addresses,
arg,
const size_t idx) except -1:
if isinstance(arg, numpy_bool):
cdef object arg_type = type(arg)
if arg_type is numpy_bool:
return prepare_arg[cpp_bool](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_int8):
elif arg_type is numpy_int8:
return prepare_arg[int8_t](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_int16):
elif arg_type is numpy_int16:
return prepare_arg[int16_t](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_int32):
elif arg_type is numpy_int32:
return prepare_arg[int32_t](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_int64):
elif arg_type is numpy_int64:
return prepare_arg[int64_t](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_uint8):
elif arg_type is numpy_uint8:
return prepare_arg[uint8_t](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_uint16):
elif arg_type is numpy_uint16:
return prepare_arg[uint16_t](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_uint32):
elif arg_type is numpy_uint32:
return prepare_arg[uint32_t](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_uint64):
elif arg_type is numpy_uint64:
return prepare_arg[uint64_t](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_float16):
elif arg_type is numpy_float16:
return prepare_arg[__half_raw](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_float32):
elif arg_type is numpy_float32:
return prepare_arg[float](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_float64):
elif arg_type is numpy_float64:
return prepare_arg[double](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_complex64):
elif arg_type is numpy_complex64:
return prepare_arg[cpp_single_complex](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_complex128):
elif arg_type is numpy_complex128:
return prepare_arg[cpp_double_complex](data, data_addresses, arg, idx)
else:
return 1
# If no exact type matches, fall back to the slower `isinstance` checks
Review comment (Member):
Note: this block might add penalty to ctypes and graph conditional handle dispatch, since they are tried after NumPy. I dunno how much this costs, though.

if isinstance(arg, numpy_bool):
return prepare_arg[cpp_bool](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_int8):
return prepare_arg[int8_t](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_int16):
return prepare_arg[int16_t](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_int32):
return prepare_arg[int32_t](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_int64):
return prepare_arg[int64_t](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_uint8):
return prepare_arg[uint8_t](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_uint16):
return prepare_arg[uint16_t](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_uint32):
return prepare_arg[uint32_t](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_uint64):
return prepare_arg[uint64_t](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_float16):
return prepare_arg[__half_raw](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_float32):
return prepare_arg[float](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_float64):
return prepare_arg[double](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_complex64):
return prepare_arg[cpp_single_complex](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_complex128):
return prepare_arg[cpp_double_complex](data, data_addresses, arg, idx)
else:
return 1


cdef class ParamHolder:
@@ -207,44 +264,69 @@ cdef class ParamHolder:
cdef size_t n_args = len(kernel_args)
cdef size_t i
cdef int not_prepared
cdef object arg_type
self.data = vector.vector[voidptr](n_args, nullptr)
self.data_addresses = vector.vector[voidptr](n_args)
for i, arg in enumerate(kernel_args):
if isinstance(arg, Buffer):
arg_type = type(arg)
if arg_type is Buffer:
# we need the address of where the actual buffer address is stored
if isinstance(arg.handle, int):
if type(arg.handle) is int:
# see note below on handling int arguments
prepare_arg[intptr_t](self.data, self.data_addresses, arg.handle, i)
continue
else:
# it's a CUdeviceptr:
self.data_addresses[i] = <void*><intptr_t>(arg.handle.getPtr())
continue
elif isinstance(arg, int):
elif arg_type is bool:
prepare_arg[cpp_bool](self.data, self.data_addresses, arg, i)
continue
elif arg_type is int:
# Here's the dilemma: We want to have a fast path to pass in Python
# integers as pointer addresses, but one could also (mistakenly) pass
# it with the intention of passing a scalar integer. It's a mistake
# because a Python int is ambiguous (arbitrary width). Our judgement
# call here is to treat it as a pointer address, without any warning!
prepare_arg[intptr_t](self.data, self.data_addresses, arg, i)
continue
elif isinstance(arg, float):
elif arg_type is float:
prepare_arg[double](self.data, self.data_addresses, arg, i)
continue
elif isinstance(arg, complex):
elif arg_type is complex:
prepare_arg[cpp_double_complex](self.data, self.data_addresses, arg, i)
continue
elif isinstance(arg, bool):
prepare_arg[cpp_bool](self.data, self.data_addresses, arg, i)
continue

not_prepared = prepare_numpy_arg(self.data, self.data_addresses, arg, i)
if not_prepared:
not_prepared = prepare_ctypes_arg(self.data, self.data_addresses, arg, i)
if not_prepared:
# TODO: revisit this treatment if we decide to cythonize cuda.core
if isinstance(arg, driver.CUgraphConditionalHandle):
prepare_arg[intptr_t](self.data, self.data_addresses, <intptr_t>int(arg), i)
if arg_type is driver.CUgraphConditionalHandle:
prepare_arg[cydriver.CUgraphConditionalHandle](self.data, self.data_addresses, <intptr_t>int(arg), i)
continue
# If no exact type matches, fall back to the slower `isinstance` checks
elif isinstance(arg, Buffer):
if isinstance(arg.handle, int):
prepare_arg[intptr_t](self.data, self.data_addresses, arg.handle, i)
continue
else:
self.data_addresses[i] = <void*><intptr_t>(arg.handle.getPtr())
continue
elif isinstance(arg, bool):
prepare_arg[cpp_bool](self.data, self.data_addresses, arg, i)
continue
elif isinstance(arg, int):
prepare_arg[intptr_t](self.data, self.data_addresses, arg, i)
continue
elif isinstance(arg, float):
prepare_arg[double](self.data, self.data_addresses, arg, i)
continue
elif isinstance(arg, complex):
prepare_arg[cpp_double_complex](self.data, self.data_addresses, arg, i)
continue
elif isinstance(arg, driver.CUgraphConditionalHandle):
prepare_arg[cydriver.CUgraphConditionalHandle](self.data, self.data_addresses, arg, i)
continue
# TODO: support ctypes/numpy struct
raise TypeError("the argument is of unsupported type: " + str(type(arg)))
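
The "dilemma" comment in ParamHolder above has a practical consequence for callers: a plain Python int is always packed as a pointer address, so scalar integers have to be wrapped in a fixed-width type. A minimal caller-side sketch (illustrative values and a hypothetical kernel_args tuple, not actual cuda.core calls):

    import ctypes
    import numpy as np

    dev_ptr = 0x7F0000000000       # a plain Python int is packed as intptr_t, i.e. a pointer address
    n = np.int32(1024)             # a scalar integer needs a fixed-width NumPy scalar...
    flag = ctypes.c_int32(1)       # ...or a ctypes instance, so the handler knows its width

    kernel_args = (dev_ptr, n, flag)   # what one would hand to a kernel launch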
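The review comment above asks how much the isinstance fallback ordering costs. A standalone micro-benchmark sketch of the two dispatch styles this diff distinguishes (plain Python with illustrative function names, independent of this module); the exact numbers will vary by machine:

    import timeit
    import numpy as np

    x = np.float32(1.0)

    def exact_dispatch(arg):
        # fast path: exact type identity check, no subclass/MRO walk
        return type(arg) is np.float32

    def fallback_dispatch(arg):
        # fallback path: honors subclasses but is slower
        return isinstance(arg, np.float32)

    print("type() is: ", timeit.timeit(lambda: exact_dispatch(x), number=1_000_000))
    print("isinstance:", timeit.timeit(lambda: fallback_dispatch(x), number=1_000_000))
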
39 changes: 39 additions & 0 deletions cuda_core/docs/source/release/0.5.x-notes.rst
@@ -0,0 +1,39 @@
.. SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
.. SPDX-License-Identifier: Apache-2.0
.. currentmodule:: cuda.core.experimental

``cuda.core`` 0.5.x Release Notes
=================================


Highlights
----------

None.


Breaking Changes
----------------

- Python ``bool`` objects are now converted to the C++ ``bool`` type when passed as kernel
arguments. Previously, they were converted to ``int``. This brings them in line
with ``ctypes.c_bool`` and ``numpy.bool_``.
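
A minimal illustration of the width difference this note describes, using plain ctypes and independent of cuda.core (on common platforms C++ ``bool`` is 1 byte, while the previous conversion widened Python bools to a 4-byte ``int``):

    import ctypes

    old_packing = ctypes.c_int(True)    # previous behavior: bool passed to the kernel as int
    new_packing = ctypes.c_bool(True)   # new behavior: matches C++ bool, ctypes.c_bool, numpy.bool_

    print(ctypes.sizeof(old_packing), ctypes.sizeof(new_packing))   # typically prints: 4 1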
Review comment (Member) on lines +19 to +21:
Q: I think this no longer applies? Should we just add another entry to mention the perf improvement



New features
------------

None.


New examples
------------

None.


Fixes and enhancements
----------------------

None.
39 changes: 27 additions & 12 deletions cuda_core/tests/test_graph.py
@@ -2,6 +2,8 @@
#
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE

import ctypes

import numpy as np
import pytest

@@ -35,19 +37,28 @@ def _common_kernels():
return mod


def _common_kernels_conditional():
def _common_kernels_conditional(cond_type):
if cond_type in (bool, np.bool_, ctypes.c_bool):
cond_type_str = "bool"
elif cond_type is int:
cond_type_str = "unsigned int"
else:
raise ValueError("Unsupported cond_type")

code = """
extern "C" __device__ __cudart_builtin__ void CUDARTAPI cudaGraphSetConditional(cudaGraphConditionalHandle handle,
unsigned int value);
__global__ void empty_kernel() {}
__global__ void add_one(int *a) { *a += 1; }
__global__ void set_handle(cudaGraphConditionalHandle handle, int value) { cudaGraphSetConditional(handle, value); }
__global__ void set_handle(cudaGraphConditionalHandle handle, $cond_type_str value) {
cudaGraphSetConditional(handle, value);
}
__global__ void loop_kernel(cudaGraphConditionalHandle handle)
{
static int count = 10;
cudaGraphSetConditional(handle, --count ? 1 : 0);
}
"""
""".replace("$cond_type_str", cond_type_str)
arch = "".join(f"{i}" for i in Device().compute_capability)
program_options = ProgramOptions(std="c++17", arch=f"sm_{arch}")
prog = Program(code, code_type="c++", options=program_options)
@@ -216,10 +227,12 @@ def test_graph_capture_errors(init_cuda):
gb.end_building().complete()


@pytest.mark.parametrize("condition_value", [True, False])
@pytest.mark.parametrize(
"condition_value", [True, False, ctypes.c_bool(True), ctypes.c_bool(False), np.bool_(True), np.bool_(False), 1, 0]
)
@pytest.mark.skipif(tuple(int(i) for i in np.__version__.split(".")[:2]) < (2, 1), reason="need numpy 2.1.0+")
def test_graph_conditional_if(init_cuda, condition_value):
mod = _common_kernels_conditional()
mod = _common_kernels_conditional(type(condition_value))
add_one = mod.get_kernel("add_one")
set_handle = mod.get_kernel("set_handle")

@@ -278,10 +291,12 @@ def test_graph_conditional_if(init_cuda, condition_value):
b.close()


@pytest.mark.parametrize("condition_value", [True, False])
@pytest.mark.parametrize(
"condition_value", [True, False, ctypes.c_bool(True), ctypes.c_bool(False), np.bool_(True), np.bool_(False), 1, 0]
)
@pytest.mark.skipif(tuple(int(i) for i in np.__version__.split(".")[:2]) < (2, 1), reason="need numpy 2.1.0+")
def test_graph_conditional_if_else(init_cuda, condition_value):
mod = _common_kernels_conditional()
mod = _common_kernels_conditional(type(condition_value))
add_one = mod.get_kernel("add_one")
set_handle = mod.get_kernel("set_handle")

@@ -353,7 +368,7 @@ def test_graph_conditional_if_else(init_cuda, condition_value):
@pytest.mark.parametrize("condition_value", [0, 1, 2, 3])
@pytest.mark.skipif(tuple(int(i) for i in np.__version__.split(".")[:2]) < (2, 1), reason="need numpy 2.1.0+")
def test_graph_conditional_switch(init_cuda, condition_value):
mod = _common_kernels_conditional()
mod = _common_kernels_conditional(type(condition_value))
add_one = mod.get_kernel("add_one")
set_handle = mod.get_kernel("set_handle")

@@ -441,10 +456,10 @@ def test_graph_conditional_switch(init_cuda, condition_value):
b.close()


@pytest.mark.parametrize("condition_value", [True, False])
@pytest.mark.parametrize("condition_value", [True, False, 1, 0])
@pytest.mark.skipif(tuple(int(i) for i in np.__version__.split(".")[:2]) < (2, 1), reason="need numpy 2.1.0+")
def test_graph_conditional_while(init_cuda, condition_value):
mod = _common_kernels_conditional()
mod = _common_kernels_conditional(type(condition_value))
add_one = mod.get_kernel("add_one")
loop_kernel = mod.get_kernel("loop_kernel")
empty_kernel = mod.get_kernel("empty_kernel")
@@ -545,7 +560,7 @@ def test_graph_child_graph(init_cuda):

@pytest.mark.skipif(tuple(int(i) for i in np.__version__.split(".")[:2]) < (2, 1), reason="need numpy 2.1.0+")
def test_graph_update(init_cuda):
mod = _common_kernels_conditional()
mod = _common_kernels_conditional(int)
add_one = mod.get_kernel("add_one")

# Allocate memory
@@ -668,7 +683,7 @@ def test_graph_stream_lifetime(init_cuda):


def test_graph_dot_print_options(init_cuda, tmp_path):
mod = _common_kernels_conditional()
mod = _common_kernels_conditional(bool)
set_handle = mod.get_kernel("set_handle")
empty_kernel = mod.get_kernel("empty_kernel")
