From 79772cc9e7950e2519f3306f6a5419c395180d2d Mon Sep 17 00:00:00 2001 From: Mehdi Ataei Date: Mon, 9 Sep 2024 16:31:12 -0400 Subject: [PATCH 1/4] Warp-JAX multi-GPU interoperability and added custom launch dimension --- docs/modules/interoperability.rst | 113 +++++++++++++++++++++++++++++- warp/jax_experimental.py | 42 +++++++---- warp/tests/test_jax.py | 58 +++++++++++++++ 3 files changed, 197 insertions(+), 16 deletions(-) diff --git a/docs/modules/interoperability.rst b/docs/modules/interoperability.rst index eb2217c18..26d989c59 100644 --- a/docs/modules/interoperability.rst +++ b/docs/modules/interoperability.rst @@ -418,7 +418,6 @@ Since this is an experimental feature, there are some limitations: - Kernel launch dimensions are inferred from the shape of the first argument. - Input arguments are followed by output arguments in the Warp kernel definition. - There must be at least one input argument and at least one output argument. - - Output shapes must match the launch dimensions (i.e., output shapes must match the shape of the first argument). - All arrays must be contiguous. - Only the CUDA backend is supported. @@ -462,6 +461,118 @@ Here is an example of an operation with three inputs and two outputs:: print(x) print(y) +Using shardmap for distributed computation +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Warp can be used in conjunction with JAX's `shard_map `_ to perform distributed multi-GPU computations. + +To achieve this, the JAX distributed environment must be initialized (see `Distributed Arrays and Automatic Parallelization `_ for more details): + +.. code-block:: python + + import jax + jax.distributed.initialize() + +This initialization must be called at the beginning of your program, before any other JAX operations. + +Here's an example of how to use `shard_map` with a Warp kernel: + +.. code-block:: python + + import warp as wp + import jax + import jax.numpy as jnp + from jax.sharding import PartitionSpec as P + from jax.experimental.shard_map import shard_map + from warp.jax_experimental import jax_kernel + + # Initialize JAX distributed environment + jax.distributed.initialize() + + @wp.kernel + def multiply_by_two_kernel( + a_in: wp.array(dtype=wp.float32), + a_out: wp.array(dtype=wp.float32), + ): + index = wp.tid() + a_out[index] = a_in[index] * 2.0 + + jax_warp_multiply = jax_kernel(multiply_by_two_kernel) + + def warp_distributed_operator(a_in): + def _sharded_operator(a_in): + return jax_warp_multiply(a_in)[0] + + return shard_map( + _sharded_operator, + mesh=jax.sharding.Mesh(np.array(jax.devices()), "x"), + in_specs=(P("x"),), + out_specs=P("x"), + check_rep=False, + )(a_in) + +In this example, `shard_map` is used to distribute the computation across available devices. The input array `a_in` is sharded along the 'x' axis, and each device processes its local shard. The Warp kernel `multiply_by_two_kernel` is applied to each shard, and the results are combined to form the final output. + +This approach allows for efficient parallel processing of large arrays, as each device works on a portion of the data simultaneously. + +To run this program on multiple GPUs, you can use `mpirun` with the following command: + +.. code-block:: bash + + mpirun -np python .py + + +Specifying launch dimensions for matrix operations +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +In some cases, particularly for matrix operations, it's necessary to specify the launch dimensions for Warp kernels. This is because the default behavior of inferring dimensions from the first argument may not always be suitable for matrix operations. Here's an example of a distributed matrix multiplication using Warp and JAX: + +.. code-block:: python + + @wp.kernel + def matmul_kernel( + a: wp.array2d(dtype=wp.float32), + b: wp.array2d(dtype=wp.float32), + c: wp.array2d(dtype=wp.float32), + ): + i, j = wp.tid() + M, K = a.shape + N = b.shape[1] + if i < M and j < N: + s = wp.float32(0.0) + for k in range(K): + s += a[i, k] * b[k, j] + c[i, j] = s + + # Specify launch dimensions based on the number of GPUs + def create_jax_warp_matmul(M, N): + num_gpus = jax.device_count() + block_size_m = M // num_gpus + block_size_n = N + return jax_kernel(matmul_kernel, launch_dims=(block_size_m, block_size_n)) + + def warp_distributed_matmul(a, b): + M, K = a.shape + _, N = b.shape + jax_warp_matmul = create_jax_warp_matmul(M, N) + + def _sharded_operator(a_shard, b): + return jax_warp_matmul(a_shard, b) + + return shard_map( + _sharded_operator, + mesh=jax.sharding.Mesh(np.array(jax.devices()), "x"), + in_specs=(P("x", None), P(None, None)), + out_specs=P("x", None), + check_rep=False, + )(a, b) + +In this example, we create a function `create_jax_warp_matmul` that calculates the launch dimensions based on the number of available GPUs. We use `jax.device_count()` to get the number of GPUs and divide the `M` dimension (rows) of the matrix by this number. This ensures that each GPU processes an equal portion of the input matrix A. The `N` dimension (columns) remains unchanged as we're not sharding in that direction. + +Note that the launch dimensions are set to match the shape of the matrix portion on each GPU. The `block_size_m` is calculated by dividing the total number of rows by the number of GPUs, while `block_size_n` is set to the full width of the output matrix. + +Note that this is a naive implementation of matrix multiplication for the sake of this illustration, and there are many optimizations that can be made to improve performance. + .. _DLPack: DLPack diff --git a/warp/jax_experimental.py b/warp/jax_experimental.py index c3e3d0727..432a44aee 100644 --- a/warp/jax_experimental.py +++ b/warp/jax_experimental.py @@ -21,17 +21,21 @@ _registered_kernel_to_id = {} -def jax_kernel(wp_kernel): +def jax_kernel(wp_kernel, launch_dims=None): """Create a Jax primitive from a Warp kernel. NOTE: This is an experimental feature under development. + Args: + wp_kernel: The Warp kernel to be wrapped. + launch_dims: Optional. Specify the kernel launch dimensions. If None, + dimensions are inferred from the shape of the first argument. + Current limitations: - All kernel arguments must be arrays. - - Kernel launch dimensions are inferred from the shape of the first argument. + - If launch_dims is not provided, kernel launch dimensions are inferred from the shape of the first argument. - Input arguments are followed by output arguments in the Warp kernel definition. - There must be at least one input argument and at least one output argument. - - Output shapes must match the launch dimensions (i.e., output shapes must match the shape of the first argument). - All arrays must be contiguous. - Only the CUDA backend is supported. """ @@ -47,7 +51,7 @@ def jax_kernel(wp_kernel): id = _registered_kernel_to_id[wp_kernel] def bind(*args): - return _jax_warp_p.bind(*args, kernel=id) + return _jax_warp_p.bind(*args, kernel=id, launch_dims=launch_dims) return bind @@ -106,7 +110,7 @@ def _get_jax_device(): device = jax.config.jax_default_device # if default device is not set, use first device if device is None: - device = jax.devices()[0] + device = jax.local_devices()[0] return device @@ -223,12 +227,17 @@ def base_type_is_compatible(warp_type, jax_ir_type): raise TypeError(f"Invalid or unsupported data type: {jax_ir_type}") # Abstract evaluation. - def jax_warp_abstract(*args, kernel=None): + def jax_warp_abstract(*args, kernel=None, launch_dims=None): wp_kernel = _registered_kernels[kernel] # All the extra arguments to the warp kernel are outputs. warp_outputs = [o.type for o in wp_kernel.adj.args[len(args) :]] - # TODO. Let's just use the first input dimension to infer the output's dimensions. - dims = strip_vecmat_dimensions(wp_kernel.adj.args[0], list(args[0].shape)) + + if launch_dims is None: + # Use the first input dimension to infer the output's dimensions if launch_dims is not provided + dims = strip_vecmat_dimensions(wp_kernel.adj.args[0], list(args[0].shape)) + else: + dims = launch_dims + jax_outputs = [] for o in warp_outputs: shape = list(dims) + list(get_vecmat_shape(o)) @@ -260,7 +269,7 @@ def jax_warp_abstract(*args, kernel=None): def default_layout(shape): return range(len(shape) - 1, -1, -1) - def warp_call_lowering(ctx, *args, kernel=None): + def warp_call_lowering(ctx, *args, kernel=None, launch_dims=None): if not kernel: raise Exception("Unknown kernel id " + str(kernel)) wp_kernel = _registered_kernels[kernel] @@ -272,12 +281,15 @@ def warp_call_lowering(ctx, *args, kernel=None): if not module.load(device): raise Exception("Could not load kernel on device") - # Infer dimensions from the first input. - warp_arg0 = wp_kernel.adj.args[0] - actual_shape0 = ir.RankedTensorType(args[0].type).shape - dims = strip_vecmat_dimensions(warp_arg0, actual_shape0) - warp_dims = collapse_into_leading_dimension(warp_arg0, dims) - + if launch_dims is None: + # Infer dimensions from the first input. + warp_arg0 = wp_kernel.adj.args[0] + actual_shape0 = ir.RankedTensorType(args[0].type).shape + dims = strip_vecmat_dimensions(warp_arg0, actual_shape0) + warp_dims = collapse_into_leading_dimension(warp_arg0, dims) + else: + dims = launch_dims + warp_dims = launch_dims # Figure out the types and shapes of the input arrays. arg_strings = [] operand_layouts = [] diff --git a/warp/tests/test_jax.py b/warp/tests/test_jax.py index 7cb24db3b..d00d03aed 100644 --- a/warp/tests/test_jax.py +++ b/warp/tests/test_jax.py @@ -246,6 +246,60 @@ def f(): assert_np_equal(result_y, expected_y) +@unittest.skipUnless(_jax_version() >= (0, 4, 25), "Jax version too old") +def test_jax_kernel_launch_dims(test, device): + import jax.numpy as jp + + from warp.jax_experimental import jax_kernel + + n = 64 + m = 32 + + # Test with 1D launch dims + @wp.kernel + def add_one_kernel(x: wp.array(dtype=float), y: wp.array(dtype=float)): + tid = wp.tid() + y[tid] = x[tid] + 1.0 + + jax_add_one = jax_kernel( + add_one_kernel, launch_dims=(n - 2,) + ) # Intentionally not the same as the first dimension of the input + + @jax.jit + def f_1d(): + x = jp.arange(n, dtype=jp.float32) + return jax_add_one(x) + + # Test with 2D launch dims + @wp.kernel + def add_one_2d_kernel(x: wp.array2d(dtype=float), y: wp.array2d(dtype=float)): + i, j = wp.tid() + y[i, j] = x[i, j] + 1.0 + + jax_add_one_2d = jax_kernel( + add_one_2d_kernel, launch_dims=(n - 2, m - 2) + ) # Intentionally not the same as the first dimension of the input + + @jax.jit + def f_2d(): + x = jp.zeros((n, m), dtype=jp.float32) + 3.0 + return jax_add_one_2d(x) + + # run on the given device + with jax.default_device(wp.device_to_jax(device)): + y_1d = f_1d() + y_2d = f_2d() + + result_1d = np.asarray(y_1d).reshape((n - 2,)) + expected_1d = np.arange(n - 2, dtype=np.float32) + 1.0 + + result_2d = np.asarray(y_2d).reshape((n - 2, m - 2)) + expected_2d = np.full((n - 2, m - 2), 4.0, dtype=np.float32) + + assert_np_equal(result_1d, expected_1d) + assert_np_equal(result_2d, expected_2d) + + class TestJax(unittest.TestCase): pass @@ -296,6 +350,10 @@ class TestJax(unittest.TestCase): TestJax, "test_jax_kernel_multiarg", test_jax_kernel_multiarg, devices=jax_compatible_cuda_devices ) + add_function_test( + TestJax, "test_jax_kernel_launch_dims", test_jax_kernel_launch_dims, devices=jax_compatible_cuda_devices + ) + except Exception as e: print(f"Skipping Jax tests due to exception: {e}") From 4b1b11daad815663c28edc709669736323ed166d Mon Sep 17 00:00:00 2001 From: Mehdi Ataei Date: Tue, 10 Sep 2024 10:16:18 -0400 Subject: [PATCH 2/4] Update docs/modules/interoperability.rst MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Frédéric Bastien --- docs/modules/interoperability.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/modules/interoperability.rst b/docs/modules/interoperability.rst index 26d989c59..71bfd8f60 100644 --- a/docs/modules/interoperability.rst +++ b/docs/modules/interoperability.rst @@ -567,7 +567,7 @@ In some cases, particularly for matrix operations, it's necessary to specify the check_rep=False, )(a, b) -In this example, we create a function `create_jax_warp_matmul` that calculates the launch dimensions based on the number of available GPUs. We use `jax.device_count()` to get the number of GPUs and divide the `M` dimension (rows) of the matrix by this number. This ensures that each GPU processes an equal portion of the input matrix A. The `N` dimension (columns) remains unchanged as we're not sharding in that direction. +In this example, we create a function `create_jax_warp_matmul` that calculates the launch dimensions based on the number of available GPUs. We use `jax.device_count()` to get the global number of GPUs and divide the `M` dimension (rows) of the matrix by this number. This ensures that each GPU processes an equal portion of the input matrix A. The `N` dimension (columns) remains unchanged as we're not sharding in that direction. Note that the launch dimensions are set to match the shape of the matrix portion on each GPU. The `block_size_m` is calculated by dividing the total number of rows by the number of GPUs, while `block_size_n` is set to the full width of the output matrix. From 637ca2ec4d59aad9348aec9e85b5705775e85429 Mon Sep 17 00:00:00 2001 From: Mehdi Ataei Date: Tue, 10 Sep 2024 10:23:38 -0400 Subject: [PATCH 3/4] Added openmpi installation guide + instructions about the output shape when launch dims are set --- docs/modules/interoperability.rst | 2 +- warp/jax_experimental.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/modules/interoperability.rst b/docs/modules/interoperability.rst index 26d989c59..78b0d4b51 100644 --- a/docs/modules/interoperability.rst +++ b/docs/modules/interoperability.rst @@ -515,7 +515,7 @@ In this example, `shard_map` is used to distribute the computation across availa This approach allows for efficient parallel processing of large arrays, as each device works on a portion of the data simultaneously. -To run this program on multiple GPUs, you can use `mpirun` with the following command: +To run this program on multiple GPUs, you must have OpenMPI installed. You can consult the `OpenMPI installation guide `_ for instructions on how to install it. Once OpenMPI is installed, you can use `mpirun` with the following command: .. code-block:: bash diff --git a/warp/jax_experimental.py b/warp/jax_experimental.py index 432a44aee..8e78ab26a 100644 --- a/warp/jax_experimental.py +++ b/warp/jax_experimental.py @@ -30,6 +30,7 @@ def jax_kernel(wp_kernel, launch_dims=None): wp_kernel: The Warp kernel to be wrapped. launch_dims: Optional. Specify the kernel launch dimensions. If None, dimensions are inferred from the shape of the first argument. + This option when set will specify the output dimensions. Current limitations: - All kernel arguments must be arrays. From c3dca564ef7a6cee5ac60c52dac77481b877daa0 Mon Sep 17 00:00:00 2001 From: Mehdi Ataei Date: Sun, 15 Sep 2024 20:08:01 -0400 Subject: [PATCH 4/4] Made the examples in doc runnable --- docs/modules/interoperability.rst | 139 +++++++++++++++++++++++++++--- 1 file changed, 127 insertions(+), 12 deletions(-) diff --git a/docs/modules/interoperability.rst b/docs/modules/interoperability.rst index c35e0d599..800d4e796 100644 --- a/docs/modules/interoperability.rst +++ b/docs/modules/interoperability.rst @@ -483,11 +483,20 @@ Here's an example of how to use `shard_map` with a Warp kernel: import jax import jax.numpy as jnp from jax.sharding import PartitionSpec as P + from jax.experimental.multihost_utils import process_allgather as allgather from jax.experimental.shard_map import shard_map from warp.jax_experimental import jax_kernel + import numpy as np # Initialize JAX distributed environment jax.distributed.initialize() + num_gpus = jax.device_count() + + def print_on_process_0(*args, **kwargs): + if jax.process_index() == 0: + print(*args, **kwargs) + + print_on_process_0(f"Running on {num_gpus} GPU(s)") @wp.kernel def multiply_by_two_kernel( @@ -499,18 +508,63 @@ Here's an example of how to use `shard_map` with a Warp kernel: jax_warp_multiply = jax_kernel(multiply_by_two_kernel) + def warp_multiply(x): + result = jax_warp_multiply(x) + return result + + # a_in here is the full sharded array with shape (M,) + # The output will also be a sharded array with shape (M,) def warp_distributed_operator(a_in): def _sharded_operator(a_in): - return jax_warp_multiply(a_in)[0] - + # Inside the sharded operator, a_in is a local shard on each device + # If we have N devices and input size M, each shard has shape (M/N,) + + # warp_multiply applies the Warp kernel to the local shard + result = warp_multiply(a_in)[0] + + # result has the same shape as the input shard (M/N,) + return result + + # shard_map distributes the computation across devices return shard_map( _sharded_operator, mesh=jax.sharding.Mesh(np.array(jax.devices()), "x"), - in_specs=(P("x"),), - out_specs=P("x"), + in_specs=(P("x"),), # Input is sharded along the 'x' axis + out_specs=P("x"), # Output is also sharded along the 'x' axis check_rep=False, )(a_in) + print_on_process_0("Test distributed multiplication using JAX + Warp") + + devices = jax.devices() + mesh = jax.sharding.Mesh(np.array(devices), "x") + sharding_spec = jax.sharding.NamedSharding(mesh, P("x")) + + input_size = num_gpus * 5 # 5 elements per device + single_device_arrays = jnp.arange(input_size, dtype=jnp.float32) + + # Define the shape of the input array based on the total input size + shape = (input_size,) + + # Create a list of arrays by distributing the single_device_arrays across the available devices + # Each device will receive a portion of the input data + arrays = [ + jax.device_put(single_device_arrays[index], d) # Place each element on the corresponding device + for d, index in sharding_spec.addressable_devices_indices_map(shape).items() + ] + + # Combine the individual device arrays into a single sharded array + sharded_array = jax.make_array_from_single_device_arrays(shape, sharding_spec, arrays) + + # sharded_array has shape (input_size,) but is distributed across devices + print_on_process_0(f"Input array: {allgather(sharded_array)}") + + # warp_result has the same shape and sharding as sharded_array + warp_result = warp_distributed_operator(sharded_array) + + # allgather collects results from all devices, resulting in a full array of shape (input_size,) + print_on_process_0("Warp Output:", allgather(warp_result)) + In this example, `shard_map` is used to distribute the computation across available devices. The input array `a_in` is sharded along the 'x' axis, and each device processes its local shard. The Warp kernel `multiply_by_two_kernel` is applied to each shard, and the results are combined to form the final output. This approach allows for efficient parallel processing of large arrays, as each device works on a portion of the data simultaneously. @@ -529,15 +583,35 @@ In some cases, particularly for matrix operations, it's necessary to specify the .. code-block:: python + import warp as wp + import jax + import jax.numpy as jnp + from jax.sharding import PartitionSpec as P + from jax.experimental.multihost_utils import process_allgather as allgather + from jax.experimental.shard_map import shard_map + from warp.jax_experimental import jax_kernel + import numpy as np + + jax.distributed.initialize() + num_gpus = jax.device_count() + + def print_on_process_0(*args, **kwargs): + if jax.process_index() == 0: + print(*args, **kwargs) + + print_on_process_0(f"Running on {num_gpus} GPU(s)") + @wp.kernel def matmul_kernel( a: wp.array2d(dtype=wp.float32), b: wp.array2d(dtype=wp.float32), c: wp.array2d(dtype=wp.float32), ): + # a: (M/num_gpus, K), b: (K, N), c: (M/num_gpus, N) i, j = wp.tid() - M, K = a.shape - N = b.shape[1] + M = a.shape[0] # M/num_gpus + K = a.shape[1] # K + N = b.shape[1] # N if i < M and j < N: s = wp.float32(0.0) for k in range(K): @@ -546,27 +620,68 @@ In some cases, particularly for matrix operations, it's necessary to specify the # Specify launch dimensions based on the number of GPUs def create_jax_warp_matmul(M, N): - num_gpus = jax.device_count() - block_size_m = M // num_gpus - block_size_n = N + # M: total rows, N: total columns + block_size_m = M // num_gpus # Rows per GPU + block_size_n = N # All columns return jax_kernel(matmul_kernel, launch_dims=(block_size_m, block_size_n)) def warp_distributed_matmul(a, b): + # a: (M, K) sharded across GPUs, b: (K, N) replicated M, K = a.shape _, N = b.shape jax_warp_matmul = create_jax_warp_matmul(M, N) def _sharded_operator(a_shard, b): - return jax_warp_matmul(a_shard, b) + # a_shard: (M/num_gpus, K), b: (K, N) + return jax_warp_matmul(a_shard, b)[0] # Result: (M/num_gpus, N) return shard_map( _sharded_operator, mesh=jax.sharding.Mesh(np.array(jax.devices()), "x"), - in_specs=(P("x", None), P(None, None)), - out_specs=P("x", None), + in_specs=(P("x", None), P(None, None)), # a sharded in first dim, b replicated + out_specs=P("x", None), # Output sharded in first dim check_rep=False, )(a, b) + print_on_process_0("Test distributed matrix multiplication using JAX + Warp") + + # Define matrix dimensions + M = 8 * num_gpus # Scale M with the number of devices + K, N = 4, 6 + + # Create input matrices + a = jnp.arange(M * K, dtype=jnp.float32).reshape(M, K) # Shape: (M, K) + b = jnp.arange(K * N, dtype=jnp.float32).reshape(K, N) # Shape: (K, N) + + devices = jax.devices() + mesh = jax.sharding.Mesh(np.array(devices), "x") + sharding_spec_a = jax.sharding.NamedSharding(mesh, P("x", None)) + sharding_spec_b = jax.sharding.NamedSharding(mesh, P(None, None)) + + # Shard matrix A and replicate matrix B + sharded_a = jax.device_put(a, sharding_spec_a) # Sharded shape: (M/num_gpus, K) per device + replicated_b = jax.device_put(b, sharding_spec_b) # Replicated shape: (K, N) on all devices + + print_on_process_0(f"Input matrix A:\n{allgather(sharded_a)}") # Shape: (M, K) + print_on_process_0(f"Input matrix B:\n{allgather(replicated_b)}") # Shape: (K, N) + + warp_result = warp_distributed_matmul(sharded_a, replicated_b) # Sharded result: (M/num_gpus, N) per device + print_on_process_0("Warp Output:") + # Use allgather to collect results from all devices + print_on_process_0(allgather(warp_result)) # Shape: (M, N) + + jax_result = jnp.matmul(a, b) # Shape: (M, N) + print_on_process_0("JAX Output:") + print_on_process_0(jax_result) + + expected_shape = (M, N) + print_on_process_0(f"Expected shape: {expected_shape}") + print_on_process_0(f"Warp output shape: {warp_result.shape}") # Should be (M/num_gpus, N) on each device + print_on_process_0(f"JAX output shape: {jax_result.shape}") # Should be (M, N) + + allclose = jnp.allclose(allgather(warp_result), jax_result, atol=1e-5) + print_on_process_0(f"Allclose: {allclose}") + In this example, we create a function `create_jax_warp_matmul` that calculates the launch dimensions based on the number of available GPUs. We use `jax.device_count()` to get the global number of GPUs and divide the `M` dimension (rows) of the matrix by this number. This ensures that each GPU processes an equal portion of the input matrix A. The `N` dimension (columns) remains unchanged as we're not sharding in that direction. Note that the launch dimensions are set to match the shape of the matrix portion on each GPU. The `block_size_m` is calculated by dividing the total number of rows by the number of GPUs, while `block_size_n` is set to the full width of the output matrix.