Merge branch 'lwawrzyniak/launch-array-interfaces-v2' into 'main'
Pass external arrays to kernels, faster Torch interop (v2)

See merge request omniverse/warp!625
nvlukasz committed Jul 17, 2024
2 parents 01d9bb2 + 0f22ff5 commit 7f2376d
Showing 9 changed files with 792 additions and 53 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -29,6 +29,8 @@
- Update `wp.matmul()` CPU fallback to use dtype explicitly in `np.matmul()` call
- Fix ShapeInstancer `__new__()` method (missing instance return and `*args` parameter)
- Add support for PEP 563's `from __future__ import annotations`.
- Allow passing external arrays/tensors to Warp kernels directly via `__cuda_array_interface__` and `__array_interface__`
- Add faster Torch interop path using `return_ctype` argument to `wp.from_torch()`

## [1.2.2] - 2024-07-04

126 changes: 122 additions & 4 deletions docs/modules/interoperability.rst
@@ -1,7 +1,57 @@
Interoperability
================

Warp can interop with other Python-based frameworks such as NumPy through standard interface protocols.
Warp can interoperate with other Python-based frameworks such as NumPy through standard interface protocols.

Warp supports passing external arrays to kernels directly, as long as they implement the ``__array__``, ``__array_interface__``, or ``__cuda_array_interface__`` protocols. This works with many common frameworks, such as NumPy, CuPy, and PyTorch.

For example, we can use NumPy arrays directly when launching Warp kernels on the CPU:

.. code:: python

    import numpy as np
    import warp as wp

    n = 1024  # array length

    @wp.kernel
    def saxpy(x: wp.array(dtype=float), y: wp.array(dtype=float), a: float):
        i = wp.tid()
        y[i] = a * x[i] + y[i]

    x = np.arange(n, dtype=np.float32)
    y = np.ones(n, dtype=np.float32)

    wp.launch(saxpy, dim=n, inputs=[x, y, 1.0], device="cpu")

Likewise, we can use CuPy arrays on a CUDA device:

.. code:: python

    import cupy as cp

    with cp.cuda.Device(0):
        x = cp.arange(n, dtype=cp.float32)
        y = cp.ones(n, dtype=cp.float32)

    wp.launch(saxpy, dim=n, inputs=[x, y, 1.0], device="cuda:0")

Note that with CUDA arrays, it's important to ensure that the device on which the arrays reside is the same as the device on which the kernel is launched.
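
One simple way to keep them in sync, sketched here using the CuPy arrays from the example above, is to derive the launch device from the array itself:

.. code:: python

    # Sketch: CuPy exposes the device that owns the array, so the launch
    # device can be derived from the data instead of being hard-coded.
    wp.launch(saxpy, dim=n, inputs=[x, y, 1.0], device=f"cuda:{x.device.id}")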

PyTorch supports both CPU and GPU tensors, and either kind can be passed to Warp kernels on the appropriate device.

.. code:: python

    import random

    import torch

    if random.choice([False, True]):
        device = "cpu"
    else:
        device = "cuda:0"

    x = torch.arange(n, dtype=torch.float32, device=device)
    y = torch.ones(n, dtype=torch.float32, device=device)

    wp.launch(saxpy, dim=n, inputs=[x, y, 1.0], device=device)

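When the tensors already exist, the launch device can also be derived from the data instead of being tracked separately. A minimal sketch, assuming ``wp.device_from_torch()`` (the counterpart of the ``wp.device_to_torch()`` helper used in the benchmark below) is available:

.. code:: python

    # Sketch: map the tensor's own device to a Warp device so the kernel
    # always launches where the tensors live.
    launch_device = wp.device_from_torch(x.device)

    wp.launch(saxpy, dim=n, inputs=[x, y, 1.0], device=launch_device)
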
NumPy
-----
@@ -247,12 +297,80 @@ Note that if Warp code is wrapped in a torch.autograd.function that gets called
exclude that function from compiler optimizations. If your script uses ``torch.compile()``, we recommend using PyTorch version 2.3.0+,
which has improvements that address this scenario.

Performance Notes
^^^^^^^^^^^^^^^^^

The ``wp.from_torch()`` function creates a Warp array object that shares data with a PyTorch tensor. Although this function does not copy the data, there is always some CPU overhead during the conversion. If these conversions happen frequently, the overall program performance may suffer. As a general rule, it's good to avoid repeated conversions of the same tensor. Instead of:

.. code:: python

    x_t = torch.arange(n, dtype=torch.float32, device=device)
    y_t = torch.ones(n, dtype=torch.float32, device=device)

    for i in range(10):
        x_w = wp.from_torch(x_t)
        y_w = wp.from_torch(y_t)
        wp.launch(saxpy, dim=n, inputs=[x_w, y_w, 1.0], device=device)

Try converting the arrays only once and reusing them:

.. code:: python

    x_t = torch.arange(n, dtype=torch.float32, device=device)
    y_t = torch.ones(n, dtype=torch.float32, device=device)

    x_w = wp.from_torch(x_t)
    y_w = wp.from_torch(y_t)

    for i in range(10):
        wp.launch(saxpy, dim=n, inputs=[x_w, y_w, 1.0], device=device)

If reusing arrays is not possible (e.g., a new PyTorch tensor is constructed on every iteration), passing ``return_ctype=True`` to ``wp.from_torch()`` should yield faster performance. Setting this argument to True avoids constructing a ``wp.array`` object and instead returns a low-level array descriptor. This descriptor is a simple C structure that can be passed to Warp kernels instead of a ``wp.array``, but cannot be used in other places that require a ``wp.array``.

.. code:: python

    for n in range(1, 10):
        # get Torch tensors for this iteration
        x_t = torch.arange(n, dtype=torch.float32, device=device)
        y_t = torch.ones(n, dtype=torch.float32, device=device)

        # get Warp array descriptors
        x_ctype = wp.from_torch(x_t, return_ctype=True)
        y_ctype = wp.from_torch(y_t, return_ctype=True)

        wp.launch(saxpy, dim=n, inputs=[x_ctype, y_ctype, 1.0], device=device)

An alternative approach is to pass the PyTorch tensors to Warp kernels directly. This avoids constructing temporary Warp arrays by leveraging standard array interfaces (like ``__cuda_array_interface__``) supported by both PyTorch and Warp. The main advantage of this approach is convenience, since there is no need to call any conversion functions. The main limitation is that it does not handle gradients, because gradient information is not included in the standard array interfaces. This technique is therefore most suitable for algorithms that do not involve differentiation.

.. code:: python

    x = torch.arange(n, dtype=torch.float32, device=device)
    y = torch.ones(n, dtype=torch.float32, device=device)

    for i in range(10):
        wp.launch(saxpy, dim=n, inputs=[x, y, 1.0], device=device)

The relative performance of these approaches can be measured with the included benchmark script:

.. code:: shell

    python -m warp.examples.benchmarks.benchmark_interop_torch

Sample output:

.. code::

    5095 ms from_torch(...)
    2113 ms from_torch(..., return_ctype=True)
    2950 ms direct from torch

The default ``wp.from_torch()`` conversion is the slowest. Passing ``return_ctype=True`` is the fastest, because it skips creating temporary Warp array objects. Passing PyTorch tensors to Warp kernels directly falls somewhere in between. It skips creating temporary Warp arrays, but accessing the ``__cuda_array_interface__`` attributes of PyTorch tensors adds overhead because they are initialized on-demand.


CuPy/Numba
----------

Warp GPU arrays support the ``__cuda_array_interface__`` protocol for sharing data with other Python GPU frameworks.
Currently this is one-directional, so that Warp arrays can be used as input to any framework that also supports the
``__cuda_array_interface__`` protocol, but not the other way around.
Warp GPU arrays support the ``__cuda_array_interface__`` protocol for sharing data with other Python GPU frameworks. This allows frameworks like CuPy to use Warp GPU arrays directly.

Likewise, Warp arrays can be created from any object that exposes the ``__cuda_array_interface__``. Such objects can also be passed to Warp kernels directly without creating a Warp array object.
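
As an illustration, the following minimal sketch shows zero-copy sharing in both directions; it reuses the ``saxpy`` kernel defined earlier and assumes a single CUDA device:

.. code:: python

    import cupy as cp

    n = 1024

    # Warp -> CuPy: cp.asarray() consumes __cuda_array_interface__ and
    # returns a zero-copy view of the Warp array's device memory.
    y_wp = wp.zeros(n, dtype=wp.float32, device="cuda:0")
    y_cp = cp.asarray(y_wp)

    # CuPy -> Warp: the CuPy array is passed to the kernel directly,
    # without constructing an intermediate Warp array.
    x_cp = cp.arange(n, dtype=cp.float32)
    wp.launch(saxpy, dim=n, inputs=[x_cp, y_cp, 2.0], device="cuda:0")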

.. _jax-interop:

49 changes: 41 additions & 8 deletions warp/context.py
@@ -3316,17 +3316,20 @@ def load_dll(self, dll_path):
        return dll

    def get_device(self, ident: Devicelike = None) -> Device:
        if isinstance(ident, Device):
        # special cases
        if type(ident) is Device:
            return ident
        elif ident is None:
            return self.default_device
        elif isinstance(ident, str):
            if ident == "cuda":
                return self.get_current_cuda_device()
            else:
                return self.device_map[ident]
        else:
            raise RuntimeError(f"Unable to resolve device from argument of type {type(ident)}")

        # string lookup
        device = self.device_map.get(ident)
        if device is not None:
            return device
        elif ident == "cuda":
            return self.get_current_cuda_device()

        raise ValueError(f"Invalid device identifier: {ident}")

    def set_default_device(self, ident: Devicelike):
        self.default_device = self.get_device(ident)
@@ -4343,6 +4346,10 @@ def pack_arg(kernel, arg_type, arg_name, value, device, adjoint=False):
            # allow for NULL arrays
            return arg_type.__ctype__()

        elif isinstance(value, warp.types.array_t):
            # accept array descriptors verbatim
            return value

        else:
            # check for array type
            # - in forward passes, array types have to match
@@ -4353,6 +4360,32 @@ def pack_arg(kernel, arg_type, arg_name, value, device, adjoint=False):
                array_matches = type(value) is type(arg_type)

            if not array_matches:
                # if a regular Warp array is required, try converting from __cuda_array_interface__ or __array_interface__
                if isinstance(arg_type, warp.array):
                    if device.is_cuda:
                        # check for __cuda_array_interface__
                        try:
                            interface = value.__cuda_array_interface__
                        except AttributeError:
                            pass
                        else:
                            return warp.types.array_ctype_from_interface(interface, dtype=arg_type.dtype, owner=value)
                    else:
                        # check for __array_interface__
                        try:
                            interface = value.__array_interface__
                        except AttributeError:
                            pass
                        else:
                            return warp.types.array_ctype_from_interface(interface, dtype=arg_type.dtype, owner=value)
                        # check for __array__() method, e.g. Torch CPU tensors
                        try:
                            interface = value.__array__().__array_interface__
                        except AttributeError:
                            pass
                        else:
                            return warp.types.array_ctype_from_interface(interface, dtype=arg_type.dtype, owner=value)

                adj = "adjoint " if adjoint else ""
                raise RuntimeError(
                    f"Error launching kernel '{kernel.key}', {adj}argument '{arg_name}' expects an array of type {type(arg_type)}, but passed value has type {type(value)}."
158 changes: 158 additions & 0 deletions warp/examples/benchmarks/benchmark_interop_torch.py
@@ -0,0 +1,158 @@
# Copyright (c) 2024 NVIDIA CORPORATION. All rights reserved.
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import time

import torch

import warp as wp


def create_simple_kernel(dtype):
    def simple_kernel(
        a: wp.array(dtype=dtype),
        b: wp.array(dtype=dtype),
        c: wp.array(dtype=dtype),
        d: wp.array(dtype=dtype),
        e: wp.array(dtype=dtype),
    ):
        pass

    return wp.Kernel(simple_kernel)


def test_from_torch(kernel, num_iters, array_size, device, warp_dtype=None):
    warp_device = wp.get_device(device)
    torch_device = wp.device_to_torch(warp_device)

    if hasattr(warp_dtype, "_shape_"):
        torch_shape = (array_size, *warp_dtype._shape_)
        torch_dtype = wp.dtype_to_torch(warp_dtype._wp_scalar_type_)
    else:
        torch_shape = (array_size,)
        torch_dtype = torch.float32 if warp_dtype is None else wp.dtype_to_torch(warp_dtype)

    _a = torch.zeros(torch_shape, dtype=torch_dtype, device=torch_device)
    _b = torch.zeros(torch_shape, dtype=torch_dtype, device=torch_device)
    _c = torch.zeros(torch_shape, dtype=torch_dtype, device=torch_device)
    _d = torch.zeros(torch_shape, dtype=torch_dtype, device=torch_device)
    _e = torch.zeros(torch_shape, dtype=torch_dtype, device=torch_device)

    wp.synchronize()

    # profiler = Profiler(interval=0.000001)
    # profiler.start()

    t1 = time.time_ns()

    for _ in range(num_iters):
        a = wp.from_torch(_a, dtype=warp_dtype)
        b = wp.from_torch(_b, dtype=warp_dtype)
        c = wp.from_torch(_c, dtype=warp_dtype)
        d = wp.from_torch(_d, dtype=warp_dtype)
        e = wp.from_torch(_e, dtype=warp_dtype)
        wp.launch(kernel, dim=array_size, inputs=[a, b, c, d, e])

    t2 = time.time_ns()
    print(f"{(t2 - t1) / 1_000_000 :8.0f} ms from_torch(...)")

    # profiler.stop()
    # profiler.print()


def test_array_ctype_from_torch(kernel, num_iters, array_size, device, warp_dtype=None):
    warp_device = wp.get_device(device)
    torch_device = wp.device_to_torch(warp_device)

    if hasattr(warp_dtype, "_shape_"):
        torch_shape = (array_size, *warp_dtype._shape_)
        torch_dtype = wp.dtype_to_torch(warp_dtype._wp_scalar_type_)
    else:
        torch_shape = (array_size,)
        torch_dtype = torch.float32 if warp_dtype is None else wp.dtype_to_torch(warp_dtype)

    _a = torch.zeros(torch_shape, dtype=torch_dtype, device=torch_device)
    _b = torch.zeros(torch_shape, dtype=torch_dtype, device=torch_device)
    _c = torch.zeros(torch_shape, dtype=torch_dtype, device=torch_device)
    _d = torch.zeros(torch_shape, dtype=torch_dtype, device=torch_device)
    _e = torch.zeros(torch_shape, dtype=torch_dtype, device=torch_device)

    wp.synchronize()

    # profiler = Profiler(interval=0.000001)
    # profiler.start()

    t1 = time.time_ns()

    for _ in range(num_iters):
        a = wp.from_torch(_a, dtype=warp_dtype, return_ctype=True)
        b = wp.from_torch(_b, dtype=warp_dtype, return_ctype=True)
        c = wp.from_torch(_c, dtype=warp_dtype, return_ctype=True)
        d = wp.from_torch(_d, dtype=warp_dtype, return_ctype=True)
        e = wp.from_torch(_e, dtype=warp_dtype, return_ctype=True)
        wp.launch(kernel, dim=array_size, inputs=[a, b, c, d, e])

    t2 = time.time_ns()
    print(f"{(t2 - t1) / 1_000_000 :8.0f} ms from_torch(..., return_ctype=True)")

    # profiler.stop()
    # profiler.print()


def test_direct_from_torch(kernel, num_iters, array_size, device, warp_dtype=None):
    warp_device = wp.get_device(device)
    torch_device = wp.device_to_torch(warp_device)

    if hasattr(warp_dtype, "_shape_"):
        torch_shape = (array_size, *warp_dtype._shape_)
        torch_dtype = wp.dtype_to_torch(warp_dtype._wp_scalar_type_)
    else:
        torch_shape = (array_size,)
        torch_dtype = torch.float32 if warp_dtype is None else wp.dtype_to_torch(warp_dtype)

    _a = torch.zeros(torch_shape, dtype=torch_dtype, device=torch_device)
    _b = torch.zeros(torch_shape, dtype=torch_dtype, device=torch_device)
    _c = torch.zeros(torch_shape, dtype=torch_dtype, device=torch_device)
    _d = torch.zeros(torch_shape, dtype=torch_dtype, device=torch_device)
    _e = torch.zeros(torch_shape, dtype=torch_dtype, device=torch_device)

    wp.synchronize()

    # profiler = Profiler(interval=0.000001)
    # profiler.start()

    t1 = time.time_ns()

    for _ in range(num_iters):
        wp.launch(kernel, dim=array_size, inputs=[_a, _b, _c, _d, _e])

    t2 = time.time_ns()
    print(f"{(t2 - t1) / 1_000_000 :8.0f} ms direct from torch")

    # profiler.stop()
    # profiler.print()


wp.init()

params = [
    # (warp_dtype arg, kernel)
    (None, create_simple_kernel(wp.float32)),
    (wp.float32, create_simple_kernel(wp.float32)),
    (wp.vec3f, create_simple_kernel(wp.vec3f)),
    (wp.mat22f, create_simple_kernel(wp.mat22f)),
]

wp.load_module()

num_iters = 100000

for warp_dtype, kernel in params:
    print(f"\ndtype={wp.context.type_str(warp_dtype)}")
    test_from_torch(kernel, num_iters, 10, "cuda:0", warp_dtype=warp_dtype)
    test_array_ctype_from_torch(kernel, num_iters, 10, "cuda:0", warp_dtype=warp_dtype)
    test_direct_from_torch(kernel, num_iters, 10, "cuda:0", warp_dtype=warp_dtype)