Warp-JAX multi-GPU interoperability and added custom launch dimension #310

Merged
merged 5 commits on Sep 17, 2024
Changes from 1 commit
113 changes: 112 additions & 1 deletion docs/modules/interoperability.rst
@@ -418,7 +418,6 @@ Since this is an experimental feature, there are some limitations:
- Kernel launch dimensions are inferred from the shape of the first argument.
- Input arguments are followed by output arguments in the Warp kernel definition.
- There must be at least one input argument and at least one output argument.
- Output shapes must match the launch dimensions (i.e., output shapes must match the shape of the first argument).
- All arrays must be contiguous.
- Only the CUDA backend is supported.

@@ -462,6 +461,118 @@ Here is an example of an operation with three inputs and two outputs::
    print(x)
    print(y)

Using shard_map for distributed computation
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Warp can be used in conjunction with JAX's `shard_map <https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html>`_ to perform distributed multi-GPU computations.

To achieve this, the JAX distributed environment must be initialized (see `Distributed Arrays and Automatic Parallelization <https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html>`_ for more details):

.. code-block:: python

    import jax
    jax.distributed.initialize()

This initialization must be called at the beginning of your program, before any other JAX operations.
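
For example, right after the ``jax.distributed.initialize()`` call, each process can report what it sees. This is a small diagnostic sketch, not part of the original example; the exact counts depend on how the program is launched:

.. code-block:: python

    # Runs after jax.distributed.initialize(); each process reports its rank
    # and the devices visible to it.
    print(
        f"process {jax.process_index()} of {jax.process_count()}: "
        f"{jax.local_device_count()} local device(s), "
        f"{jax.device_count()} device(s) in total"
    )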

Here is an example of how to use ``shard_map`` with a Warp kernel:

.. code-block:: python

    import warp as wp
    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.sharding import PartitionSpec as P
    from jax.experimental.shard_map import shard_map
    from warp.jax_experimental import jax_kernel

    # Initialize JAX distributed environment
    jax.distributed.initialize()

    @wp.kernel
    def multiply_by_two_kernel(
        a_in: wp.array(dtype=wp.float32),
        a_out: wp.array(dtype=wp.float32),
    ):
        index = wp.tid()
        a_out[index] = a_in[index] * 2.0

    jax_warp_multiply = jax_kernel(multiply_by_two_kernel)

    def warp_distributed_operator(a_in):
        def _sharded_operator(a_in):
            return jax_warp_multiply(a_in)[0]

        return shard_map(
            _sharded_operator,
            mesh=jax.sharding.Mesh(np.array(jax.devices()), "x"),
            in_specs=(P("x"),),
            out_specs=P("x"),
            check_rep=False,
        )(a_in)

In this example, ``shard_map`` is used to distribute the computation across the available devices. The input array ``a_in`` is sharded along the ``x`` axis, and each device processes its local shard. The Warp kernel ``multiply_by_two_kernel`` is applied to each shard, and the results are combined to form the final output.

This approach allows for efficient parallel processing of large arrays, as each device works on a portion of the data simultaneously.
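
As a rough sketch of how each process might drive this operator (this driver is not part of the original example; it assumes the imports and definitions above are in scope and uses ``jax.experimental.multihost_utils`` to assemble the per-process data into one globally sharded array):

.. code-block:: python

    from jax.experimental import multihost_utils
    from jax.sharding import Mesh

    mesh = Mesh(np.array(jax.devices()), "x")

    # Each process creates only its local slice of the input ...
    local_a = jnp.arange(32, dtype=jnp.float32) + 32.0 * jax.process_index()

    # ... and the slices are combined into one global array sharded along "x".
    global_a = multihost_utils.host_local_array_to_global_array(local_a, mesh, P("x"))

    out = jax.jit(warp_distributed_operator)(global_a)

    # Convert the result back to a per-process local view before printing.
    local_out = multihost_utils.global_array_to_host_local_array(out, mesh, P("x"))
    print(f"process {jax.process_index()}: {local_out[:4]}")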

To run this program on multiple GPUs, you can use ``mpirun`` with the following command:

.. code-block:: bash

    mpirun -np <NUM_OF_GPUS> python <filename>.py
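
For instance, to launch the script above on four GPUs (``example_shardmap.py`` is a purely hypothetical filename):

.. code-block:: bash

    mpirun -np 4 python example_shardmap.py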


Specifying launch dimensions for matrix operations
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In some cases, particularly for matrix operations, it is necessary to specify the launch dimensions for Warp kernels explicitly, because the default behavior of inferring them from the shape of the first argument is not always suitable. Here is an example of a distributed matrix multiplication using Warp and JAX:

.. code-block:: python

    @wp.kernel
    def matmul_kernel(
        a: wp.array2d(dtype=wp.float32),
        b: wp.array2d(dtype=wp.float32),
        c: wp.array2d(dtype=wp.float32),
    ):
        i, j = wp.tid()
        M, K = a.shape
        N = b.shape[1]
        if i < M and j < N:
            s = wp.float32(0.0)
            for k in range(K):
                s += a[i, k] * b[k, j]
            c[i, j] = s

    # Specify launch dimensions based on the number of GPUs
    def create_jax_warp_matmul(M, N):
        num_gpus = jax.device_count()
        block_size_m = M // num_gpus
        block_size_n = N
        return jax_kernel(matmul_kernel, launch_dims=(block_size_m, block_size_n))

    def warp_distributed_matmul(a, b):
        M, K = a.shape
        _, N = b.shape
        jax_warp_matmul = create_jax_warp_matmul(M, N)

        def _sharded_operator(a_shard, b):
            return jax_warp_matmul(a_shard, b)

        return shard_map(
            _sharded_operator,
            mesh=jax.sharding.Mesh(np.array(jax.devices()), "x"),
            in_specs=(P("x", None), P(None, None)),
            out_specs=P("x", None),
            check_rep=False,
        )(a, b)

In this example, we create a function ``create_jax_warp_matmul`` that calculates the launch dimensions based on the number of available GPUs. We use ``jax.device_count()`` to get the number of GPUs and divide the ``M`` dimension (rows) of the matrix by this number, which ensures that each GPU processes an equal portion of the input matrix ``a``. The ``N`` dimension (columns) remains unchanged because we are not sharding in that direction.

Note that the launch dimensions are set to match the shape of the matrix portion on each GPU. The ``block_size_m`` is calculated by dividing the total number of rows by the number of GPUs, while ``block_size_n`` is set to the full width of the output matrix.
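
As a concrete illustration of this arithmetic (the sizes below are invented for the sketch and are not part of the example):

.. code-block:: python

    # Hypothetical sizes: 4 GPUs, global matrices a (1024 x 256) and b (256 x 512).
    M, K, N = 1024, 256, 512
    num_gpus = 4

    block_size_m = M // num_gpus  # 256 rows of a (and of c) per device
    block_size_n = N  # each device computes all 512 columns
    launch_dims = (block_size_m, block_size_n)  # the (256, 512) launch grid on each device

    # Under shard_map, each device then sees a (256, 256) shard of a, the full
    # (256, 512) b, and writes its own (256, 512) slice of the output c.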

This is a naive implementation of matrix multiplication, shown only for illustration; many optimizations can be made to improve its performance.

.. _DLPack:

DLPack
42 changes: 27 additions & 15 deletions warp/jax_experimental.py
@@ -21,17 +21,21 @@
_registered_kernel_to_id = {}


def jax_kernel(wp_kernel):
def jax_kernel(wp_kernel, launch_dims=None):
    """Create a Jax primitive from a Warp kernel.

    NOTE: This is an experimental feature under development.

    Args:
        wp_kernel: The Warp kernel to be wrapped.
        launch_dims: Optional. Specify the kernel launch dimensions. If None,
            dimensions are inferred from the shape of the first argument.

    Current limitations:
      - All kernel arguments must be arrays.
      - Kernel launch dimensions are inferred from the shape of the first argument.
      - If launch_dims is not provided, kernel launch dimensions are inferred from the shape of the first argument.
      - Input arguments are followed by output arguments in the Warp kernel definition.
      - There must be at least one input argument and at least one output argument.
      - Output shapes must match the launch dimensions (i.e., output shapes must match the shape of the first argument).
      - All arrays must be contiguous.
      - Only the CUDA backend is supported.
    """
@@ -47,7 +51,7 @@ def jax_kernel(wp_kernel):
    id = _registered_kernel_to_id[wp_kernel]

    def bind(*args):
        return _jax_warp_p.bind(*args, kernel=id)
        return _jax_warp_p.bind(*args, kernel=id, launch_dims=launch_dims)

    return bind

@@ -106,7 +110,7 @@ def _get_jax_device():
    device = jax.config.jax_default_device
    # if default device is not set, use first device
    if device is None:
        device = jax.devices()[0]
        device = jax.local_devices()[0]
    return device


@@ -223,12 +227,17 @@ def base_type_is_compatible(warp_type, jax_ir_type):
    raise TypeError(f"Invalid or unsupported data type: {jax_ir_type}")

# Abstract evaluation.
def jax_warp_abstract(*args, kernel=None):
def jax_warp_abstract(*args, kernel=None, launch_dims=None):
    wp_kernel = _registered_kernels[kernel]
    # All the extra arguments to the warp kernel are outputs.
    warp_outputs = [o.type for o in wp_kernel.adj.args[len(args) :]]
    # TODO. Let's just use the first input dimension to infer the output's dimensions.
    dims = strip_vecmat_dimensions(wp_kernel.adj.args[0], list(args[0].shape))

    if launch_dims is None:
        # Use the first input dimension to infer the output's dimensions if launch_dims is not provided
        dims = strip_vecmat_dimensions(wp_kernel.adj.args[0], list(args[0].shape))
    else:
        dims = launch_dims

    jax_outputs = []
    for o in warp_outputs:
        shape = list(dims) + list(get_vecmat_shape(o))
@@ -260,7 +269,7 @@ def jax_warp_abstract(*args, kernel=None):
def default_layout(shape):
    return range(len(shape) - 1, -1, -1)

def warp_call_lowering(ctx, *args, kernel=None):
def warp_call_lowering(ctx, *args, kernel=None, launch_dims=None):
    if not kernel:
        raise Exception("Unknown kernel id " + str(kernel))
    wp_kernel = _registered_kernels[kernel]
@@ -272,12 +281,15 @@ def warp_call_lowering(ctx, *args, kernel=None):
    if not module.load(device):
        raise Exception("Could not load kernel on device")

    # Infer dimensions from the first input.
    warp_arg0 = wp_kernel.adj.args[0]
    actual_shape0 = ir.RankedTensorType(args[0].type).shape
    dims = strip_vecmat_dimensions(warp_arg0, actual_shape0)
    warp_dims = collapse_into_leading_dimension(warp_arg0, dims)

    if launch_dims is None:
        # Infer dimensions from the first input.
        warp_arg0 = wp_kernel.adj.args[0]
        actual_shape0 = ir.RankedTensorType(args[0].type).shape
        dims = strip_vecmat_dimensions(warp_arg0, actual_shape0)
        warp_dims = collapse_into_leading_dimension(warp_arg0, dims)
    else:
        dims = launch_dims
        warp_dims = launch_dims
    # Figure out the types and shapes of the input arrays.
    arg_strings = []
    operand_layouts = []
58 changes: 58 additions & 0 deletions warp/tests/test_jax.py
@@ -246,6 +246,60 @@ def f():
    assert_np_equal(result_y, expected_y)


@unittest.skipUnless(_jax_version() >= (0, 4, 25), "Jax version too old")
def test_jax_kernel_launch_dims(test, device):
    import jax.numpy as jp

    from warp.jax_experimental import jax_kernel

    n = 64
    m = 32

    # Test with 1D launch dims
    @wp.kernel
    def add_one_kernel(x: wp.array(dtype=float), y: wp.array(dtype=float)):
        tid = wp.tid()
        y[tid] = x[tid] + 1.0

    jax_add_one = jax_kernel(
        add_one_kernel, launch_dims=(n - 2,)
    )  # Intentionally not the same as the first dimension of the input

    @jax.jit
    def f_1d():
        x = jp.arange(n, dtype=jp.float32)
        return jax_add_one(x)

    # Test with 2D launch dims
    @wp.kernel
    def add_one_2d_kernel(x: wp.array2d(dtype=float), y: wp.array2d(dtype=float)):
        i, j = wp.tid()
        y[i, j] = x[i, j] + 1.0

    jax_add_one_2d = jax_kernel(
        add_one_2d_kernel, launch_dims=(n - 2, m - 2)
    )  # Intentionally not the same as the first dimension of the input

    @jax.jit
    def f_2d():
        x = jp.zeros((n, m), dtype=jp.float32) + 3.0
        return jax_add_one_2d(x)

    # run on the given device
    with jax.default_device(wp.device_to_jax(device)):
        y_1d = f_1d()
        y_2d = f_2d()

    result_1d = np.asarray(y_1d).reshape((n - 2,))
    expected_1d = np.arange(n - 2, dtype=np.float32) + 1.0

    result_2d = np.asarray(y_2d).reshape((n - 2, m - 2))
    expected_2d = np.full((n - 2, m - 2), 4.0, dtype=np.float32)

    assert_np_equal(result_1d, expected_1d)
    assert_np_equal(result_2d, expected_2d)


class TestJax(unittest.TestCase):
    pass

@@ -296,6 +350,10 @@ class TestJax(unittest.TestCase):
        TestJax, "test_jax_kernel_multiarg", test_jax_kernel_multiarg, devices=jax_compatible_cuda_devices
    )

    add_function_test(
        TestJax, "test_jax_kernel_launch_dims", test_jax_kernel_launch_dims, devices=jax_compatible_cuda_devices
    )

except Exception as e:
    print(f"Skipping Jax tests due to exception: {e}")
