Skip to content
Open
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 117 additions & 36 deletions cuda_core/cuda/core/experimental/_kernel_arg_handler.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -128,67 +128,123 @@ cdef inline int prepare_ctypes_arg(
vector.vector[void*]& data_addresses,
arg,
const size_t idx) except -1:
if isinstance(arg, ctypes_bool):
cdef object arg_type = type(arg)
if arg_type is ctypes_bool:
return prepare_arg[cpp_bool](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_int8):
elif arg_type is ctypes_int8:
return prepare_arg[int8_t](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_int16):
elif arg_type is ctypes_int16:
return prepare_arg[int16_t](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_int32):
elif arg_type is ctypes_int32:
return prepare_arg[int32_t](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_int64):
elif arg_type is ctypes_int64:
return prepare_arg[int64_t](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_uint8):
elif arg_type is ctypes_uint8:
return prepare_arg[uint8_t](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_uint16):
elif arg_type is ctypes_uint16:
return prepare_arg[uint16_t](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_uint32):
elif arg_type is ctypes_uint32:
return prepare_arg[uint32_t](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_uint64):
elif arg_type is ctypes_uint64:
return prepare_arg[uint64_t](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_float):
elif arg_type is ctypes_float:
return prepare_arg[float](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_double):
elif arg_type is ctypes_double:
return prepare_arg[double](data, data_addresses, arg.value, idx)
else:
return 1
# If no exact types are found, fallback to slower `isinstance` check
if isinstance(arg, ctypes_bool):
return prepare_arg[cpp_bool](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_int8):
return prepare_arg[int8_t](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_int16):
return prepare_arg[int16_t](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_int32):
return prepare_arg[int32_t](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_int64):
return prepare_arg[int64_t](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_uint8):
return prepare_arg[uint8_t](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_uint16):
return prepare_arg[uint16_t](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_uint32):
return prepare_arg[uint32_t](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_uint64):
return prepare_arg[uint64_t](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_float):
return prepare_arg[float](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_double):
return prepare_arg[double](data, data_addresses, arg.value, idx)
else:
return 1


cdef inline int prepare_numpy_arg(
vector.vector[void*]& data,
vector.vector[void*]& data_addresses,
arg,
const size_t idx) except -1:
if isinstance(arg, numpy_bool):
cdef object arg_type = type(arg)
if arg_type is numpy_bool:
return prepare_arg[cpp_bool](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_int8):
elif arg_type is numpy_int8:
return prepare_arg[int8_t](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_int16):
elif arg_type is numpy_int16:
return prepare_arg[int16_t](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_int32):
elif arg_type is numpy_int32:
return prepare_arg[int32_t](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_int64):
elif arg_type is numpy_int64:
return prepare_arg[int64_t](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_uint8):
elif arg_type is numpy_uint8:
return prepare_arg[uint8_t](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_uint16):
elif arg_type is numpy_uint16:
return prepare_arg[uint16_t](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_uint32):
elif arg_type is numpy_uint32:
return prepare_arg[uint32_t](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_uint64):
elif arg_type is numpy_uint64:
return prepare_arg[uint64_t](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_float16):
elif arg_type is numpy_float16:
return prepare_arg[__half_raw](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_float32):
elif arg_type is numpy_float32:
return prepare_arg[float](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_float64):
elif arg_type is numpy_float64:
return prepare_arg[double](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_complex64):
elif arg_type is numpy_complex64:
return prepare_arg[cpp_single_complex](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_complex128):
elif arg_type is numpy_complex128:
return prepare_arg[cpp_double_complex](data, data_addresses, arg, idx)
else:
return 1
# If no exact types are found, fallback to slower `isinstance` check
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: this block might add a penalty to ctypes and graph conditional handle dispatch, since they are tried after NumPy. I don't know how much this costs, though.

if isinstance(arg, numpy_bool):
return prepare_arg[cpp_bool](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_int8):
return prepare_arg[int8_t](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_int16):
return prepare_arg[int16_t](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_int32):
return prepare_arg[int32_t](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_int64):
return prepare_arg[int64_t](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_uint8):
return prepare_arg[uint8_t](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_uint16):
return prepare_arg[uint16_t](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_uint32):
return prepare_arg[uint32_t](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_uint64):
return prepare_arg[uint64_t](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_float16):
return prepare_arg[__half_raw](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_float32):
return prepare_arg[float](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_float64):
return prepare_arg[double](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_complex64):
return prepare_arg[cpp_single_complex](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_complex128):
return prepare_arg[cpp_double_complex](data, data_addresses, arg, idx)
else:
return 1


cdef class ParamHolder:
Expand All @@ -207,43 +263,68 @@ cdef class ParamHolder:
cdef size_t n_args = len(kernel_args)
cdef size_t i
cdef int not_prepared
cdef object arg_type
self.data = vector.vector[voidptr](n_args, nullptr)
self.data_addresses = vector.vector[voidptr](n_args)
for i, arg in enumerate(kernel_args):
if isinstance(arg, Buffer):
arg_type = type(arg)
if arg_type is Buffer:
# we need the address of where the actual buffer address is stored
if isinstance(arg.handle, int):
if type(arg.handle) is int:
# see note below on handling int arguments
prepare_arg[intptr_t](self.data, self.data_addresses, arg.handle, i)
continue
else:
# it's a CUdeviceptr:
self.data_addresses[i] = <void*><intptr_t>(arg.handle.getPtr())
continue
elif isinstance(arg, int):
elif arg_type is bool:
prepare_arg[int32_t](self.data, self.data_addresses, int(arg), i)
continue
elif arg_type is int:
# Here's the dilemma: We want to have a fast path to pass in Python
# integers as pointer addresses, but one could also (mistakenly) pass
# it with the intention of passing a scalar integer. It's a mistake
# because a Python int is ambiguous (arbitrary width). Our judgement
# call here is to treat it as a pointer address, without any warning!
prepare_arg[intptr_t](self.data, self.data_addresses, arg, i)
continue
elif isinstance(arg, float):
elif arg_type is float:
prepare_arg[double](self.data, self.data_addresses, arg, i)
continue
elif isinstance(arg, complex):
elif arg_type is complex:
prepare_arg[cpp_double_complex](self.data, self.data_addresses, arg, i)
continue
elif isinstance(arg, bool):
prepare_arg[cpp_bool](self.data, self.data_addresses, arg, i)
continue

not_prepared = prepare_numpy_arg(self.data, self.data_addresses, arg, i)
if not_prepared:
not_prepared = prepare_ctypes_arg(self.data, self.data_addresses, arg, i)
if not_prepared:
# TODO: revisit this treatment if we decide to cythonize cuda.core
if isinstance(arg, driver.CUgraphConditionalHandle):
if arg_type is driver.CUgraphConditionalHandle:
prepare_arg[intptr_t](self.data, self.data_addresses, <intptr_t>int(arg), i)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like the header defines it as a typedef of `unsigned long long`; let's be explicit:

Suggested change
prepare_arg[intptr_t](self.data, self.data_addresses, <intptr_t>int(arg), i)
prepare_arg[unsigned long long](self.data, self.data_addresses, arg, i)

Alternatively, I think this should work too:

Suggested change
prepare_arg[intptr_t](self.data, self.data_addresses, <intptr_t>int(arg), i)
prepare_arg[cydriver.CUgraphConditionalHandle](self.data, self.data_addresses, arg, i)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The first option doesn't work because the Cython parser can't handle unsigned long long in that place. It would have to be some sort of typedef to unsigned long long. But I like the second option a lot better anyway -- seems more explicit about what it needs to match.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

> The first option doesn't work because the Cython parser can't handle unsigned long long in that place.

Ah yes. `unsigned long long` would need to be added to the list under the `supported_type` fused type. But I agree the second option is better.

continue
# If no exact types are found, fallback to slower `isinstance` check
elif isinstance(arg, Buffer):
if isinstance(arg.handle, int):
prepare_arg[intptr_t](self.data, self.data_addresses, arg.handle, i)
continue
else:
self.data_addresses[i] = <void*><intptr_t>(arg.handle.getPtr())
continue
elif isinstance(arg, bool):
prepare_arg[int32_t](self.data, self.data_addresses, int(arg), i)
continue
elif isinstance(arg, int):
prepare_arg[intptr_t](self.data, self.data_addresses, arg, i)
continue
elif isinstance(arg, float):
prepare_arg[double](self.data, self.data_addresses, arg, i)
continue
elif isinstance(arg, complex):
prepare_arg[cpp_double_complex](self.data, self.data_addresses, arg, i)
continue
elif isinstance(arg, driver.CUgraphConditionalHandle):
prepare_arg[intptr_t](self.data, self.data_addresses, <intptr_t>int(arg), i)
continue
# TODO: support ctypes/numpy struct
Expand Down
Loading