Made the examples in doc runnable
mehdiataei committed Sep 16, 2024
1 parent 26e7ae5 commit c3dca56
Showing 1 changed file with 127 additions and 12 deletions.
139 changes: 127 additions & 12 deletions docs/modules/interoperability.rst
Here's an example of how to use `shard_map` with a Warp kernel:

.. code-block:: python
import warp as wp
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
from jax.experimental.multihost_utils import process_allgather as allgather
from jax.experimental.shard_map import shard_map
from warp.jax_experimental import jax_kernel
import numpy as np

# Initialize JAX distributed environment
jax.distributed.initialize()
num_gpus = jax.device_count()

def print_on_process_0(*args, **kwargs):
    if jax.process_index() == 0:
        print(*args, **kwargs)

print_on_process_0(f"Running on {num_gpus} GPU(s)")
@wp.kernel
def multiply_by_two_kernel(
    a_in: wp.array(dtype=wp.float32),
    a_out: wp.array(dtype=wp.float32),
):
    index = wp.tid()
    a_out[index] = a_in[index] * 2.0
jax_warp_multiply = jax_kernel(multiply_by_two_kernel)

def warp_multiply(x):
    result = jax_warp_multiply(x)
    return result
# a_in here is the full sharded array with shape (M,)
# The output will also be a sharded array with shape (M,)
def warp_distributed_operator(a_in):
    def _sharded_operator(a_in):
        # Inside the sharded operator, a_in is a local shard on each device
        # If we have N devices and input size M, each shard has shape (M/N,)
        # warp_multiply applies the Warp kernel to the local shard
        result = warp_multiply(a_in)[0]
        # result has the same shape as the input shard (M/N,)
        return result
    # shard_map distributes the computation across devices
    return shard_map(
        _sharded_operator,
        mesh=jax.sharding.Mesh(np.array(jax.devices()), "x"),
        in_specs=(P("x"),),  # Input is sharded along the 'x' axis
        out_specs=P("x"),  # Output is also sharded along the 'x' axis
        check_rep=False,
    )(a_in)
print_on_process_0("Test distributed multiplication using JAX + Warp")
devices = jax.devices()
mesh = jax.sharding.Mesh(np.array(devices), "x")
sharding_spec = jax.sharding.NamedSharding(mesh, P("x"))
input_size = num_gpus * 5 # 5 elements per device
single_device_arrays = jnp.arange(input_size, dtype=jnp.float32)
# Define the shape of the input array based on the total input size
shape = (input_size,)
# Create a list of arrays by distributing the single_device_arrays across the available devices
# Each device will receive a portion of the input data
arrays = [
    jax.device_put(single_device_arrays[index], d)  # Place each shard on its corresponding device
    for d, index in sharding_spec.addressable_devices_indices_map(shape).items()
]
# Combine the individual device arrays into a single sharded array
sharded_array = jax.make_array_from_single_device_arrays(shape, sharding_spec, arrays)
# sharded_array has shape (input_size,) but is distributed across devices
print_on_process_0(f"Input array: {allgather(sharded_array)}")
# warp_result has the same shape and sharding as sharded_array
warp_result = warp_distributed_operator(sharded_array)
# allgather collects results from all devices, resulting in a full array of shape (input_size,)
print_on_process_0("Warp Output:", allgather(warp_result))
In this example, `shard_map` is used to distribute the computation across available devices. The input array `a_in` is sharded along the 'x' axis, and each device processes its local shard. The Warp kernel `multiply_by_two_kernel` is applied to each shard, and the results are combined to form the final output.

This approach allows for efficient parallel processing of large arrays, as each device works on a portion of the data simultaneously.
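To confirm the layout described above, one option is to inspect the result's addressable shards. The following is a minimal sketch, not part of the original example, that assumes the `warp_result` variable from the code above and a multi-GPU run:

.. code-block:: python

    # Each process sees only its own addressable shards; with N devices and an
    # input of size M, every local shard should have shape (M/N,).
    for shard in warp_result.addressable_shards:
        print(f"process {jax.process_index()}: device={shard.device}, local shape={shard.data.shape}")

    # Quick visual check of the layout (works for 1D and 2D arrays)
    jax.debug.visualize_array_sharding(warp_result)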
In some cases, particularly for matrix operations, it's necessary to specify the launch dimensions for Warp kernels:

.. code-block:: python
import warp as wp
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
from jax.experimental.multihost_utils import process_allgather as allgather
from jax.experimental.shard_map import shard_map
from warp.jax_experimental import jax_kernel
import numpy as np
jax.distributed.initialize()
num_gpus = jax.device_count()

def print_on_process_0(*args, **kwargs):
    if jax.process_index() == 0:
        print(*args, **kwargs)

print_on_process_0(f"Running on {num_gpus} GPU(s)")
@wp.kernel
def matmul_kernel(
    a: wp.array2d(dtype=wp.float32),
    b: wp.array2d(dtype=wp.float32),
    c: wp.array2d(dtype=wp.float32),
):
    # a: (M/num_gpus, K), b: (K, N), c: (M/num_gpus, N)
    i, j = wp.tid()
    M = a.shape[0]  # M/num_gpus
    K = a.shape[1]  # K
    N = b.shape[1]  # N
    if i < M and j < N:
        s = wp.float32(0.0)
        for k in range(K):
            s += a[i, k] * b[k, j]
        c[i, j] = s
# Specify launch dimensions based on the number of GPUs
def create_jax_warp_matmul(M, N):
    # M: total rows, N: total columns
    num_gpus = jax.device_count()
    block_size_m = M // num_gpus  # Rows per GPU
    block_size_n = N  # All columns
    return jax_kernel(matmul_kernel, launch_dims=(block_size_m, block_size_n))
def warp_distributed_matmul(a, b):
    # a: (M, K) sharded across GPUs, b: (K, N) replicated
    M, K = a.shape
    _, N = b.shape
    jax_warp_matmul = create_jax_warp_matmul(M, N)

    def _sharded_operator(a_shard, b):
        # a_shard: (M/num_gpus, K), b: (K, N)
        return jax_warp_matmul(a_shard, b)[0]  # Result: (M/num_gpus, N)

    return shard_map(
        _sharded_operator,
        mesh=jax.sharding.Mesh(np.array(jax.devices()), "x"),
        in_specs=(P("x", None), P(None, None)),  # a sharded in first dim, b replicated
        out_specs=P("x", None),  # Output sharded in first dim
        check_rep=False,
    )(a, b)
print_on_process_0("Test distributed matrix multiplication using JAX + Warp")
# Define matrix dimensions
M = 8 * num_gpus # Scale M with the number of devices
K, N = 4, 6
# Create input matrices
a = jnp.arange(M * K, dtype=jnp.float32).reshape(M, K) # Shape: (M, K)
b = jnp.arange(K * N, dtype=jnp.float32).reshape(K, N) # Shape: (K, N)
devices = jax.devices()
mesh = jax.sharding.Mesh(np.array(devices), "x")
sharding_spec_a = jax.sharding.NamedSharding(mesh, P("x", None))
sharding_spec_b = jax.sharding.NamedSharding(mesh, P(None, None))
# Shard matrix A and replicate matrix B
sharded_a = jax.device_put(a, sharding_spec_a) # Sharded shape: (M/num_gpus, K) per device
replicated_b = jax.device_put(b, sharding_spec_b) # Replicated shape: (K, N) on all devices
print_on_process_0(f"Input matrix A:\n{allgather(sharded_a)}") # Shape: (M, K)
print_on_process_0(f"Input matrix B:\n{allgather(replicated_b)}") # Shape: (K, N)
warp_result = warp_distributed_matmul(sharded_a, replicated_b) # Sharded result: (M/num_gpus, N) per device
print_on_process_0("Warp Output:")
# Use allgather to collect results from all devices
print_on_process_0(allgather(warp_result)) # Shape: (M, N)
jax_result = jnp.matmul(a, b) # Shape: (M, N)
print_on_process_0("JAX Output:")
print_on_process_0(jax_result)
expected_shape = (M, N)
print_on_process_0(f"Expected shape: {expected_shape}")
print_on_process_0(f"Warp output shape: {warp_result.shape}") # Should be (M/num_gpus, N) on each device
print_on_process_0(f"JAX output shape: {jax_result.shape}") # Should be (M, N)
allclose = jnp.allclose(allgather(warp_result), jax_result, atol=1e-5)
print_on_process_0(f"Allclose: {allclose}")
In this example, we create a function `create_jax_warp_matmul` that calculates the launch dimensions based on the number of available GPUs. We use `jax.device_count()` to get the global number of GPUs and divide the `M` dimension (rows) of the matrix by this number. This ensures that each GPU processes an equal portion of the input matrix A. The `N` dimension (columns) remains unchanged as we're not sharding in that direction.

Note that the launch dimensions are set to match the shape of the matrix portion on each GPU. The `block_size_m` is calculated by dividing the total number of rows by the number of GPUs, while `block_size_n` is set to the full width of the output matrix.
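As a concrete sanity check, here is a small arithmetic sketch (not part of the original example) assuming a hypothetical 4-GPU run with the dimensions used above:

.. code-block:: python

    # Hypothetical 4-GPU run with the dimensions from this example
    num_gpus = 4
    M = 8 * num_gpus  # 32 total rows
    K, N = 4, 6

    block_size_m = M // num_gpus  # 32 // 4 = 8 rows per GPU
    block_size_n = N              # full output width of 6 columns

    # launch_dims=(8, 6): each GPU launches one thread per element of its
    # local (M/num_gpus, N) output shard.
    assert (block_size_m, block_size_n) == (8, 6)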