Merged
8 changes: 7 additions & 1 deletion cuda_core/cuda/core/experimental/_kernel_arg_handler.pyx
@@ -212,7 +212,13 @@ cdef class ParamHolder:
         for i, arg in enumerate(kernel_args):
             if isinstance(arg, Buffer):
                 # we need the address of where the actual buffer address is stored
-                self.data_addresses[i] = <void*><intptr_t>(arg.handle.getPtr())
+                if isinstance(arg.handle, int):
@shwina (Contributor, Author) commented on Jun 18, 2025:

Can we stomach the cost of an isinstance check here?

  • One alternative is to use a try/except, where entering the try block is cheap but entering the except block is expensive.

  • Another alternative, which would eliminate the need for any changes to the kernel arg handling logic here:

    • Introduce a new type HostPtr that wraps an integer representing a pointer and exposes a getPtr() method to retrieve it (a sketch follows below).
    • Expand the return type of Buffer.handle to DevicePointerT | HostPtr.
    • Change LegacyPinnedMemoryResource to return a buffer whose handle is a HostPtr.
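
A minimal sketch of that HostPtr idea, purely for illustration (the class does not exist in cuda.core, and the ctypes-based storage is just one way to satisfy the getPtr() contract):

import ctypes

class HostPtr:
    """Hypothetical wrapper around a host pointer held as a plain Python int."""

    def __init__(self, ptr: int):
        # Keep the value in stable storage so getPtr() can behave like
        # driver.CUdeviceptr.getPtr(): it returns the address of the location
        # where the pointer value is stored, which is what the kernel arg
        # handler writes into data_addresses.
        self._storage = ctypes.c_void_p(ptr)

    def getPtr(self) -> int:
        return ctypes.addressof(self._storage)

LegacyPinnedMemoryResource.allocate() would then return a Buffer whose handle is a HostPtr, and the existing arg.handle.getPtr() path would keep working unchanged.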

@leofang (Member) commented on Jun 18, 2025:

I think isinstance in Cython is cheap and what you have here is good. I don't want to introduce more types than needed, partly because we want MR providers to focus on the MR properties (is_host_accessible, etc.), which is nicer for programmatic checks. I actually think that Buffer.handle should be of Any type so as to not get in the way of the MR providers. From both the CUDA and cccl-rt perspectives they should all be void*. We don't want to encode the memory-space information as part of the type.
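
To illustrate the point about programmatic checks, a small consumer-side sketch that dispatches on the buffer's memory-space properties rather than on the handle's type (the helper name is made up):

def describe_memory_space(buf):
    # Buffer exposes is_device_accessible / is_host_accessible, so callers can
    # branch on where the memory lives without inspecting buf.handle at all.
    if buf.is_device_accessible and buf.is_host_accessible:
        return "host- and device-accessible (e.g. pinned)"
    if buf.is_device_accessible:
        return "device-only"
    return "host-only"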

@shwina (Contributor, Author) replied:

> I actually think that Buffer.handle should be of Any type so as to not get in the way of the MR providers.

If we did type it as Any, how would _kernel_arg_handler know how to grab the pointer from underneath the Buffer?

@leofang (Member) replied:

Well Python does not care about type annotations, right? 🙂

@shwina (Contributor, Author) commented on Jun 18, 2025:

My concern wasn't so much about the type annotation, but more that the kernel handler won't know what to do with a Buffer whose .handle is any arbitrary type.

Prior to this PR it could only handle the case when .handle is a CUdeviceptr, or something that has a .getPtr() method.

if isinstance(arg, Buffer):
    # we need the address of where the actual buffer address is stored
    self.data_addresses[i] = <void*><intptr_t>(arg.handle.getPtr())

This PR adds the ability to handle int.

Technically, .handle is also allowed to be None:

DevicePointerT = Union[driver.CUdeviceptr, int, None]

@leofang (Member) replied:

Ahh, I see, you meant the mini dispatcher here needs to enumerate all possible types.

Let me think about it. What you have is good and a generic treatment can follow later.

Most likely, with #564 we could rewrite the dispatcher to look like this:

if isinstance(arg, Buffer):
    prepare_arg[intptr_t](self.data, self.data_addresses, get_cuda_native_handle(arg.handle), i)

On the MR provider side, we just need them to implement a protocol

class IsHandleT(Protocol):
    def __int__(self) -> int: ...

if they are not using generic cuda.bindings or Python types. (FWIW we already have IsStreamT.) So maybe eventually Buffer.handle can be typed as

DevicePointerT = Optional[Union[IsHandleT, int]] 
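
A rough sketch of how the provider side and the dispatcher side could meet in the middle, assuming handles only need to be reducible to an integer address (the provider class and helper below are illustrative, not existing APIs):

from typing import Optional, Protocol, Union

class IsHandleT(Protocol):
    def __int__(self) -> int: ...

DevicePointerT = Optional[Union[IsHandleT, int]]

class PoolAllocationHandle:
    # Hypothetical handle type from a third-party memory resource provider.
    def __init__(self, addr: int):
        self._addr = addr

    def __int__(self) -> int:
        return self._addr

def to_address(handle: DevicePointerT) -> int:
    # What a generic dispatcher (or the #564 helper) would ultimately do:
    # reduce any conforming handle to a plain integer address.
    return 0 if handle is None else int(handle)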

@shwina (Contributor, Author) commented on Jun 19, 2025:

It seems like a reasonable approach and I agree it would simplify the handling here. A couple of comments:

  • Perhaps we should rename DevicePointerT to just PointerT? In the case of pinned memory, for instance, it doesn't actually represent a device pointer AFAIU.
  • If we use the protocol as written, then Union[IsHandleT, int] is equivalent to just IsHandleT (the int type implements __int__). The protocol would also allow types like float or bool.
    • I feel like this discussion has been had before, but it might be worth considering a protocol with a __cuda_handle__() method or something, rather than __int__() (see the sketch below).
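
For comparison, a sketch of that more explicit protocol; __cuda_handle__ is hypothetical and not an existing cuda.core or cuda.bindings API:

from typing import Protocol, runtime_checkable

@runtime_checkable
class HasCudaHandle(Protocol):
    def __cuda_handle__(self) -> int: ...

# Unlike __int__, types must opt in explicitly, so float and bool (which both
# implement __int__) no longer satisfy the protocol by accident, and
# Union[HasCudaHandle, int] stays a meaningful distinction.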

+                    # see note below on handling int arguments
+                    prepare_arg[intptr_t](self.data, self.data_addresses, arg.handle, i)
+                    continue
+                else:
+                    # it's a CUdeviceptr:
+                    self.data_addresses[i] = <void*><intptr_t>(arg.handle.getPtr())
                 continue
             elif isinstance(arg, int):
                 # Here's the dilemma: We want to have a fast path to pass in Python
163 changes: 163 additions & 0 deletions cuda_core/examples/memory_ops.py
@@ -0,0 +1,163 @@
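# Demonstrates cuda.core memory resources: allocating device and pinned host
# memory, viewing both allocations through CuPy arrays, passing the Buffer
# objects directly to a kernel launch, and copying data between the two
# memory spaces.
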
import cupy as cp

from cuda.core.experimental import (
    Device,
    DeviceMemoryResource,
    LaunchConfig,
    LegacyPinnedMemoryResource,
    Program,
    ProgramOptions,
    launch,
)

# Kernel for memory operations
code = """
extern "C"
__global__ void memory_ops(float* device_data,
                           float* pinned_data,
                           size_t N) {
    const unsigned int tid = threadIdx.x + blockIdx.x * blockDim.x;
    if (tid < N) {
        // Access device memory
        device_data[tid] = device_data[tid] + 1.0f;

        // Access pinned memory (zero-copy from GPU)
        pinned_data[tid] = pinned_data[tid] * 3.0f;
    }
}
"""

dev = Device()
dev.set_current()
stream = dev.create_stream()

# Compile kernel
arch = "".join(f"{i}" for i in dev.compute_capability)
program_options = ProgramOptions(std="c++17", arch=f"sm_{arch}")
prog = Program(code, code_type="c++", options=program_options)
mod = prog.compile("cubin")
kernel = mod.get_kernel("memory_ops")

# Create different memory resources
device_mr = DeviceMemoryResource(dev.device_id)
pinned_mr = LegacyPinnedMemoryResource()

# Allocate different types of memory
size = 1024
dtype = cp.float32
element_size = dtype().itemsize
total_size = size * element_size

# 1. Device Memory (GPU-only)
device_buffer = device_mr.allocate(total_size, stream=stream)
device_array = cp.ndarray(
    size, dtype=dtype,
    memptr=cp.cuda.MemoryPointer(
        cp.cuda.UnownedMemory(int(device_buffer.handle), device_buffer.size, device_buffer), 0
    )
)

# 2. Pinned Memory (CPU memory, GPU accessible)
pinned_buffer = pinned_mr.allocate(total_size, stream=stream)
pinned_array = cp.ndarray(
    size, dtype=dtype,
    memptr=cp.cuda.MemoryPointer(
        cp.cuda.UnownedMemory(int(pinned_buffer.handle), pinned_buffer.size, pinned_buffer), 0
    )
)

# Initialize data
rng = cp.random.default_rng()
device_array[:] = rng.random(size, dtype=dtype)
pinned_array[:] = rng.random(size, dtype=dtype)

# Store original values for verification
device_original = device_array.copy()
pinned_original = pinned_array.copy()

# Sync before kernel launch
dev.sync()

# Launch kernel
block = 256
grid = (size + block - 1) // block
config = LaunchConfig(grid=grid, block=block)

launch(stream, config, kernel,
       device_buffer, pinned_buffer, cp.uint64(size))
stream.sync()

# Verify kernel operations
assert cp.allclose(device_array, device_original + 1.0), "Device memory operation failed"
assert cp.allclose(pinned_array, pinned_original * 3.0), "Pinned memory operation failed"

# Demonstrate buffer copying operations
print("Memory buffer properties:")
print(f"Device buffer - Device accessible: {device_buffer.is_device_accessible}")
print(f"Pinned buffer - Device accessible: {pinned_buffer.is_device_accessible}")

# Assert memory properties
assert device_buffer.is_device_accessible, "Device buffer should be device accessible"
assert not device_buffer.is_host_accessible, "Device buffer should not be host accessible"
assert pinned_buffer.is_device_accessible, "Pinned buffer should be device accessible"
assert pinned_buffer.is_host_accessible, "Pinned buffer should be host accessible"

# Copy data between different memory types
print("\nCopying data between memory types...")

# Copy from device to pinned memory
device_buffer.copy_to(pinned_buffer, stream=stream)
stream.sync()

# Verify the copy operation
assert cp.allclose(pinned_array, device_array), "Device to pinned copy failed"

# Create a new device buffer and copy from pinned
new_device_buffer = device_mr.allocate(total_size, stream=stream)
new_device_array = cp.ndarray(
    size, dtype=dtype,
    memptr=cp.cuda.MemoryPointer(
        cp.cuda.UnownedMemory(int(new_device_buffer.handle), new_device_buffer.size, new_device_buffer), 0
    )
)

pinned_buffer.copy_to(new_device_buffer, stream=stream)
stream.sync()

# Verify the copy operation
assert cp.allclose(new_device_array, pinned_array), "Pinned to device copy failed"

# Demonstrate DLPack integration
print("\nDLPack device information:")
print(f"Device buffer DLPack device: {device_buffer.__dlpack_device__()}")
print(f"Pinned buffer DLPack device: {pinned_buffer.__dlpack_device__()}")

# Assert DLPack device types
from cuda.core.experimental._memory import DLDeviceType

device_dlpack = device_buffer.__dlpack_device__()
pinned_dlpack = pinned_buffer.__dlpack_device__()

assert device_dlpack[0] == DLDeviceType.kDLCUDA, "Device buffer should have CUDA device type"
assert pinned_dlpack[0] == DLDeviceType.kDLCUDAHost, "Pinned buffer should have CUDA host device type"

# Test buffer size properties
assert device_buffer.size == total_size, f"Device buffer size mismatch: expected {total_size}, got {device_buffer.size}"
assert pinned_buffer.size == total_size, f"Pinned buffer size mismatch: expected {total_size}, got {pinned_buffer.size}"
assert new_device_buffer.size == total_size, f"New device buffer size mismatch: expected {total_size}, got {new_device_buffer.size}"

# Test memory resource properties
assert device_buffer.memory_resource == device_mr, "Device buffer should use device memory resource"
assert pinned_buffer.memory_resource == pinned_mr, "Pinned buffer should use pinned memory resource"
assert new_device_buffer.memory_resource == device_mr, "New device buffer should use device memory resource"

# Clean up
device_buffer.close(stream)
pinned_buffer.close(stream)
new_device_buffer.close(stream)
stream.close()

# Verify buffers are properly closed
assert device_buffer.handle == 0, "Device buffer should be closed"
assert pinned_buffer.handle == 0, "Pinned buffer should be closed"
assert new_device_buffer.handle == 0, "New device buffer should be closed"

print("Memory management example completed!")