From 561a323766840633766ed81693075c672f15c653 Mon Sep 17 00:00:00 2001
From: Michael Goldfarb
Date: Sun, 12 Jan 2025 22:32:15 +0000
Subject: [PATCH] Consolidate the distributed fused attention tests to use
 shared input generation and execution logic.

Signed-off-by: Michael Goldfarb

---
 tests/jax/distributed_test_base.py       |  29 +-
 tests/jax/test_distributed_fused_attn.py | 444 +++++------------------
 tests/jax/test_fused_attn.py             | 242 ++++++++++--
 3 files changed, 323 insertions(+), 392 deletions(-)

diff --git a/tests/jax/distributed_test_base.py b/tests/jax/distributed_test_base.py
index c2d7039a53..d0ace8263f 100644
--- a/tests/jax/distributed_test_base.py
+++ b/tests/jax/distributed_test_base.py
@@ -18,14 +18,22 @@ def generate_configs():
     configs = []
     if is_devices_enough(2):
-        configs.append([2, (2,), "dp", MeshResource(dp_resource="dp")])
-        configs.append([2, (2,), "tp", MeshResource(tp_resource="tp")])
+        configs.append(
+            pytest.param(2, (2,), ("dp",), MeshResource(dp_resource="dp"), id="n2_dp2_tp1")
+        )
+        configs.append(
+            pytest.param(2, (2,), ("tp",), MeshResource(tp_resource="tp"), id="n2_dp1_tp2")
+        )
     if is_devices_enough(4):
-        TP_size = 2
-        DP_size = 2
         configs.append(
-            [4, (DP_size, TP_size), ("dp", "tp"), MeshResource(dp_resource="dp", tp_resource="tp")]
+            pytest.param(
+                4,
+                (2, 2),
+                ("dp", "tp"),
+                MeshResource(dp_resource="dp", tp_resource="tp"),
+                id=f"n4_dp2_tp2",
+            )
         )
 
     return configs
@@ -33,7 +41,8 @@ def generate_configs():
 
 def generate_context_parallel_configs():
     configs = []
-
+    mr = MeshResource(dp_resource="dp", cp_resource="cp", tp_resource="tp")
+    axes = ("dp", "cp", "tp")
     DP_sizes = (1, 2)
     CP_sizes = (1, 2, 4, 8)
     TP_sizes = (1, 2)
@@ -41,13 +50,7 @@
         ndev = cp * tp * dp
         if is_devices_enough(ndev):
             configs.append(
-                pytest.param(
-                    ndev,
-                    (dp, cp, tp),
-                    ("dp", "cp", "tp"),
-                    MeshResource(dp_resource="dp", cp_resource="cp", tp_resource="tp"),
-                    id=f"n{ndev}_dp{dp}_cp{cp}_tp{tp}",
-                )
+                pytest.param(ndev, (dp, cp, tp), axes, mr, id=f"n{ndev}_dp{dp}_cp{cp}_tp{tp}")
             )
     return configs
 
diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py
index 5a41911691..2e15dd4d5d 100644
--- a/tests/jax/test_distributed_fused_attn.py
+++ b/tests/jax/test_distributed_fused_attn.py
@@ -37,8 +37,7 @@
 )
 from transformer_engine.jax.sharding import MeshResource
 
-# We will use the golden reference model from our non distributed attention test fixture.
-from test_fused_attn import general_dot_product_attention, make_mask +from test_fused_attn import FusedAttnRunner, BiasShape, general_dot_product_attention, make_mask DTYPES = [jnp.float16, jnp.bfloat16] @@ -49,7 +48,7 @@ def generate_collectives_count_ref( self, mesh_shape, mesh_axes, mesh_resource, with_bias, shape, dtype ): jax_dtype = jax.dtypes.canonicalize_dtype(dtype) - _, seqlen, _, heads, _ = shape + _, seqlen, heads, _ = shape is_dp_enabled = mesh_resource.dp_resource is not None tp_size = 1 if mesh_resource.tp_resource is not None: @@ -62,45 +61,28 @@ def generate_collectives_count_ref( # for loss and dbias return generate_collectives_count(allreduce=allreduce_total_bytes, allgather=0, other=0) - def generate_inputs(self, shape, mesh_resource, with_bias, attn_mask_type, dtype): - batch, seqlen, _, heads, _ = shape - - qkv = random.normal(random.PRNGKey(1124), shape, dtype=dtype) - - bias = ( - random.normal(random.PRNGKey(1125), (1, heads, seqlen, seqlen), dtype) - if with_bias - else None - ) - - mask = None - if attn_mask_type == AttnMaskType.PADDING_MASK: - mask = make_causal_mask(batch, seqlen) - elif attn_mask_type == AttnMaskType.CAUSAL_MASK: - mask = make_self_mask(batch, seqlen) - - qkv_pspec = PartitionSpec( - mesh_resource.dp_resource, None, None, mesh_resource.tp_resource, None - ) - bias_pspec = ( - PartitionSpec(None, mesh_resource.tp_resource, None, None) if with_bias else None - ) - mask_pspec = ( - PartitionSpec(mesh_resource.dp_resource, None, None, None) - if attn_mask_type != AttnMaskType.NO_MASK - else None - ) - - return (qkv, bias, mask), (qkv_pspec, bias_pspec, mask_pspec) - @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) - @pytest.mark.parametrize("data_shape", [[32, 512, 3, 12, 64], [32, 1024, 3, 16, 128]]) @pytest.mark.parametrize( - "attn_bias_type", - [AttnBiasType.NO_BIAS, AttnBiasType.PRE_SCALE_BIAS, AttnBiasType.POST_SCALE_BIAS], + "data_shape", + [ + pytest.param((32, 512, 12, 64), id="32-512-12-64"), + pytest.param((32, 1024, 16, 128), id="32-1024-16-128"), + ], ) @pytest.mark.parametrize( - "attn_mask_type", [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK] + "attn_bias_type, bias_shape", + [ + pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"), + pytest.param(AttnBiasType.PRE_SCALE_BIAS, BiasShape._1HSS, id="PRE_SCALE_BIAS-1HSS"), + pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"), + ], + ) + @pytest.mark.parametrize( + "attn_mask_type", + [ + pytest.param(AttnMaskType.PADDING_MASK, id="PADDING_MASK"), + pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL_MASK"), + ], ) @pytest.mark.parametrize("dtype", DTYPES) def test_self_attn( @@ -111,14 +93,14 @@ def test_self_attn( mesh_resource, data_shape, attn_bias_type, + bias_shape, attn_mask_type, dtype, ): dropout_prob = 0.0 is_training = True - scaling_factor = 1.0 - _, seqlen, _, num_head, hidden = data_shape + batch, seqlen, num_head, hidden = data_shape if not is_fused_attn_kernel_available( dtype, @@ -136,74 +118,36 @@ def test_self_attn( ): pytest.skip(f"No FusedAttn backend found") - def target_func(qkv, bias, mask): - return jnp.mean( - fused_attn( - (qkv,), - bias, - mask, - None, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - qkv_layout=QKVLayout.BS3HD, - scaling_factor=scaling_factor, - dropout_probability=dropout_prob, - is_training=is_training, - ) - ) - - def ref_func(qkv, bias, mask): - query, key, value = jnp.split(qkv, [1, 2], axis=-3) - query = jnp.squeeze(query) - key 
= jnp.squeeze(key) - value = jnp.squeeze(value) - - output = dot_product_attention( - query, - key, - value, - bias=bias, - mask=mask, - deterministic=is_training, - dropout_rate=dropout_prob, - dropout_rng=None, - dtype=jnp.float32, - ) - - return jnp.mean(output).astype(dtype) - - with_bias = attn_bias_type != AttnBiasType.NO_BIAS - (qkv, bias, mask), (qkv_pspec, bias_pspec, mask_pspec) = self.generate_inputs( - data_shape, mesh_resource, with_bias, attn_mask_type, dtype + col_ref = self.generate_collectives_count_ref( + mesh_shape, + mesh_axes, + mesh_resource, + attn_bias_type != AttnBiasType.NO_BIAS, + data_shape, + dtype, ) - collective_count_ref = self.generate_collectives_count_ref( - mesh_shape, mesh_axes, mesh_resource, with_bias, data_shape, dtype + runner = FusedAttnRunner( + batch, + seqlen, + seqlen, + num_head, + num_head, + hidden, + attn_bias_type, + attn_mask_type, + dropout_prob, + dtype, + is_training, + QKVLayout.BS3HD, + bias_shape, + None, + number_of_devices=device_count, + mesh_shape=mesh_shape, + mesh_axes=mesh_axes, + mesh_resource=mesh_resource, + coll_count_ref=col_ref, ) - devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) - mesh = Mesh(devices, mesh_axes) - with mesh, fp8_autocast(mesh_resource=mesh_resource): - qkv_ = jax.device_put(qkv, NamedSharding(mesh, qkv_pspec)) - bias_ = ( - jax.device_put(bias, NamedSharding(mesh, bias_pspec)) if bias is not None else bias - ) - mask_ = ( - jax.device_put(mask, NamedSharding(mesh, mask_pspec)) if mask is not None else mask - ) - - grad_args = (0, 1) if with_bias else (0,) - out_grad_shardings = (qkv_pspec, bias_pspec) if with_bias else (qkv_pspec,) - - compare_ops( - target_func, - ref_func, - [qkv_, bias_, mask_], - collective_count_ref, - grad_args=grad_args, - metric_fwd_dtype=dtype, - metric_bwd_dtype=dtype, - in_shardings=(qkv_pspec, bias_pspec, mask_pspec), - out_shardings=(None, out_grad_shardings), - ) + runner.test_backward() class TestDistributedCrossAttn: @@ -213,31 +157,6 @@ def generate_collectives_count_ref(self): all_reduce_loss_bytes = 4 # 1 * FP32 return generate_collectives_count(allreduce=all_reduce_loss_bytes, allgather=0, other=0) - def generate_inputs(self, shape, mesh_resource, attn_mask_type, dtype): - batch, seqlen, heads, hidden = shape - - q = random.normal(random.PRNGKey(1124), shape, dtype=dtype) - kv = random.normal(random.PRNGKey(1125), (batch, seqlen, 2, heads, hidden), dtype=dtype) - - mask = None - if attn_mask_type == AttnMaskType.PADDING_MASK: - mask = make_causal_mask(batch, seqlen) - elif attn_mask_type == AttnMaskType.CAUSAL_MASK: - mask = make_self_mask(batch, seqlen) - - q_pspec = PartitionSpec(mesh_resource.dp_resource, None, mesh_resource.tp_resource, None) - - kv_pspec = PartitionSpec( - mesh_resource.dp_resource, None, None, mesh_resource.tp_resource, None - ) - mask_pspec = ( - PartitionSpec(mesh_resource.dp_resource, None, None, None) - if attn_mask_type != AttnMaskType.NO_MASK - else None - ) - - return (q, kv, mask), (q_pspec, kv_pspec, mask_pspec) - @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) @pytest.mark.parametrize("data_shape", [[32, 128, 12, 64], [32, 512, 16, 64]]) @pytest.mark.parametrize( @@ -248,11 +167,11 @@ def test_cross_attn( self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, attn_mask_type, dtype ): attn_bias_type = AttnBiasType.NO_BIAS + bias_shape = None dropout_prob = 0.0 is_training = True - scaling_factor = 1.0 - _, seqlen, num_head, hidden = data_shape + 
batch, seqlen, num_head, hidden = data_shape if not is_fused_attn_kernel_available( dtype, @@ -270,67 +189,29 @@ def test_cross_attn( ): pytest.skip(f"No FusedAttn backend found") - def target_func(q, kv, mask): - return jnp.mean( - fused_attn( - (q, kv), - None, - mask, - None, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - qkv_layout=QKVLayout.BSHD_BS2HD, - scaling_factor=scaling_factor, - dropout_probability=dropout_prob, - is_training=is_training, - ), - dtype=jnp.float32, - ) - - def ref_func(query, kv, mask): - key, value = jnp.split(kv, [1], axis=-3) - query = jnp.squeeze(query) - key = jnp.squeeze(key) - value = jnp.squeeze(value) - - output = dot_product_attention( - query, - key, - value, - bias=None, - mask=mask, - deterministic=is_training, - dropout_rate=dropout_prob, - dropout_rng=None, - dtype=jnp.float32, - ) - - return jnp.mean(output, dtype=jnp.float32) - - (q, kv, mask), (q_pspec, kv_pspec, mask_pspec) = self.generate_inputs( - data_shape, mesh_resource, attn_mask_type, dtype + col_ref = self.generate_collectives_count_ref() + runner = FusedAttnRunner( + batch, + seqlen, + seqlen, + num_head, + num_head, + hidden, + attn_bias_type, + attn_mask_type, + dropout_prob, + dtype, + is_training, + QKVLayout.BSHD_BS2HD, + bias_shape, + None, + number_of_devices=device_count, + mesh_shape=mesh_shape, + mesh_axes=mesh_axes, + mesh_resource=mesh_resource, + coll_count_ref=col_ref, ) - collective_count_ref = self.generate_collectives_count_ref() - devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) - mesh = Mesh(devices, mesh_axes) - with mesh, fp8_autocast(mesh_resource=mesh_resource): - q_ = jax.device_put(q, NamedSharding(mesh, q_pspec)) - kv_ = jax.device_put(kv, NamedSharding(mesh, kv_pspec)) - mask_ = ( - jax.device_put(mask, NamedSharding(mesh, mask_pspec)) if mask is not None else mask - ) - - compare_ops( - target_func, - ref_func, - [q_, kv_, mask_], - collective_count_ref, - grad_args=(0, 1), - metric_fwd_dtype=dtype, - metric_bwd_dtype=dtype, - in_shardings=(q_pspec, kv_pspec, mask_pspec), - out_shardings=(None, (q_pspec, kv_pspec)), - ) + runner.test_backward() @pytest.mark.parametrize( @@ -366,41 +247,6 @@ def ref_func(query, kv, mask): ) class TestDistributedContextParallelSelfAttn: - def generate_inputs(self, shape, kv_groups: int, attn_mask_type: AttnMaskType, dtype): - batch, seqlen, heads, hidden = shape - kv_shape = (batch, seqlen, heads // kv_groups, hidden) - qkey, kkey, vkey = random.split(random.PRNGKey(1124), 3) - q = random.normal(qkey, shape, dtype=dtype) - k = random.normal(kkey, (batch, seqlen, heads // kv_groups, hidden), dtype=dtype) - v = random.normal(vkey, (batch, seqlen, heads // kv_groups, hidden), dtype=dtype) - - def gen_valid(bs, max_seqlen, pad_ratio): - pad_len = int(max_seqlen * pad_ratio) - valid_len = max_seqlen - pad_len - tokens = jnp.concatenate([jnp.ones((bs, valid_len)), jnp.zeros((bs, pad_len))], axis=-1) - return tokens, jnp.logical_not(tokens) - - from test_fused_attn import make_mask - - q_idx, _ = gen_valid(batch, seqlen, 0.0) - kv_idx, _ = gen_valid(batch, seqlen, 0.0) - mask = make_mask(q_idx, kv_idx, None, None, attn_mask_type) - - return q, k, v, mask - - def qkv_to_layout(self, q, k, v, qkv_layout): - qkv_args = () - match qkv_layout: - case QKVLayout.BSHD_BS2HD: - k, v = map(partial(jnp.expand_dims, axis=-3), [k, v]) - kv = jnp.concatenate((k, v), axis=-3) - qkv_args = (q, kv) - case QKVLayout.BSHD_BSHD_BSHD: - qkv_args = (q, k, v) - case _: - raise ValueError(f"Unsupported 
{qkv_layout=}") - return qkv_args - def impl_test_context_parallel_attn( self, device_count, @@ -416,6 +262,7 @@ def impl_test_context_parallel_attn( cp_strategy, ): attn_bias_type = AttnBiasType.NO_BIAS + bias_shape = None dropout_prob = 0.0 is_training = True dp_size, cp_size, tp_size = mesh_shape @@ -431,6 +278,29 @@ def impl_test_context_parallel_attn( num_kv_heads = num_head // kv_groups scaling_factor = 1.0 / np.sqrt(num_head) + runner = FusedAttnRunner( + batch, + seqlen, + seqlen, + num_head, + num_kv_heads, + hidden, + attn_bias_type, + attn_mask_type, + dropout_prob, + dtype, + is_training, + qkv_layout, + bias_shape, + None, + number_of_devices=device_count, + mesh_shape=mesh_shape, + mesh_axes=mesh_axes, + mesh_resource=mesh_resource, + cp_strategy=cp_strategy, + cp_load_balanced=load_balanced, + ) + def check_has_backend_for_mask(mask_type): return is_fused_attn_kernel_available( dtype, @@ -465,123 +335,7 @@ def check_has_backend_for_mask(mask_type): if num_head % kv_groups != 0 or (num_head // kv_groups) % tp_size != 0: pytest.skip(f"Skipping {kv_groups=} not multiple of {data_shape=} or {tp_size=}") - def target_func(q, k, v, mask): - return fused_attn( - self.qkv_to_layout(q, k, v, qkv_layout), - None, # bias - mask, - None, # seed - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - qkv_layout=qkv_layout, - scaling_factor=scaling_factor, - dropout_probability=dropout_prob, - is_training=is_training, - context_parallel_strategy=cp_strategy, - context_parallel_causal_load_balanced=load_balanced, - context_parallel_axis="cp", - ).astype(dtype) - - def ref_func(q, k, v, mask): - output = general_dot_product_attention( - q, - k, - v, - bias=None, - mask=mask, - deterministic=not is_training, - scale_factor=scaling_factor, - dropout_rate=dropout_prob, - dropout_rng=None, - dtype=jnp.float32, - ) - return output.astype(dtype) - - def grad_func(func, *args, **kwargs): - # Gradient is small, use a gradient multiplier to amplify the gradient - _, max_seq_len, num_heads, _ = data_shape - gradient_multiplier = max_seq_len * num_heads - if attn_mask_type.is_causal(): - gradient_multiplier /= 10 - ret_valid = func(*args, **kwargs) - return (jnp.mean(ret_valid, dtype=jnp.float32) * gradient_multiplier).astype(dtype) - - q, k, v, mask = self.generate_inputs(data_shape, kv_groups, attn_mask_type, dtype) - - diff_argnums = (0, 1, 2) - - # Single GPU (reference) - ref_func_jit = jax.jit( - jax.value_and_grad( - lambda q, k, v, mask: grad_func(ref_func, q, k, v, mask), argnums=diff_argnums - ) - ) - ref_fwd, ref_grads = ref_func_jit(q, k, v, mask) - - # Multi GPU (function under test) - devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) - mesh = Mesh(devices, mesh_axes) - with mesh, fp8_autocast(mesh_resource=mesh_resource, enabled=False): - qkv_ps = PartitionSpec( - mesh_resource.dp_resource, - mesh_resource.cp_resource, - mesh_resource.tp_resource, - None, - ) - qkv_sharding = NamedSharding(mesh, qkv_ps) - - mask_ps = PartitionSpec( - mesh_resource.dp_resource, None, mesh_resource.cp_resource, None - ) - mask_sharding = NamedSharding(mesh, mask_ps) - - reorder = partial( - reorder_causal_load_balancing, cp_size=cp_size, tensor_format=qkv_format - ) - inverse_reorder = partial( - inverse_reorder_causal_load_balancing, cp_size=cp_size, tensor_format=qkv_format - ) - - if load_balanced: - q, k, v = jax.tree.map(reorder, (q, k, v)) - - q_, k_, v_ = map(partial(jax.device_put, device=qkv_sharding), [q, k, v]) - mask_ = jax.device_put(mask, 
device=mask_sharding) - - target_func_jit = jax.jit( - jax.value_and_grad( - lambda q, k, v, mask: grad_func(target_func, q, k, v, mask), - argnums=diff_argnums, - ), - in_shardings=[qkv_sharding, qkv_sharding, qkv_sharding, mask_sharding], - out_shardings=(None, (qkv_sharding, qkv_sharding, qkv_sharding)), - ) - - target_fwd, target_grads = target_func_jit(q_, k_, v_, mask_) - - if load_balanced: - target_dq, target_dk, target_dv = jax.tree.map(inverse_reorder, target_grads[0:3]) - target_grads = (target_dq, target_dk, target_dv, *target_grads[3:]) - - has_diffs = False - - print_debug_tensor_stats("target", target_fwd) - print_debug_tensor_stats("ref", ref_fwd) - print_debug_tensor_stats("diff", jnp.abs(target_fwd - ref_fwd)) - assert_allclose(target_fwd, ref_fwd, dtype=dtype) - - for i in range(len(target_grads)): - if ref_grads[i] is None or target_grads[i] is None: - # expect both none if one is - assert target_grads[i] is None and ref_grads[i] is None - else: - print_debug_tensor_stats(f"target_grad[{i}]", target_grads[i]) - print_debug_tensor_stats(f"ref_grad[{i}]", ref_grads[i]) - print_debug_tensor_stats( - f"diff_grad[{i}]", jnp.abs(target_grads[i] - ref_grads[i]) - ) - - assert_allclose(target_grads[i], ref_grads[i], dtype=dtype) + runner.test_backward() def test_context_parallel_allgather_attn( self, diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 5cbbec7b04..710ae1946d 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -3,10 +3,10 @@ # See LICENSE for license information. """Tests for fused attention""" from enum import Enum -from dataclasses import dataclass +from dataclasses import dataclass, field from functools import partial from math import sqrt -from typing import Tuple, Optional +from typing import Tuple, Optional, Dict import random import jax @@ -19,16 +19,22 @@ from flax.linen.dtypes import promote_dtype from jax import Array from jax import value_and_grad, jit +from jax.sharding import Mesh, NamedSharding, PartitionSpec from jax.typing import ArrayLike, DTypeLike +from transformer_engine.jax import fp8_autocast +from transformer_engine.jax.sharding import MeshResource from transformer_engine.jax.attention import ( AttnBiasType, AttnMaskType, QKVLayout, QKVFormat, + reorder_causal_load_balancing, + inverse_reorder_causal_load_balancing, fused_attn, fused_attn_thd, make_swa_mask, + CPStrategy, ) from transformer_engine.jax.cpp_extensions import FusedAttnHelper from transformer_engine.transformer_engine_jax import ( @@ -36,7 +42,8 @@ get_cudnn_version, ) -from utils import assert_allclose +from distributed_test_base import assert_equal_collectives +from utils import assert_allclose, print_debug_tensor_stats @pytest.fixture(autouse=True, scope="module") @@ -304,6 +311,19 @@ class FusedAttnRunner: bias_shape: BiasShape window_size: Optional[Tuple[int, int]] = None + # Specifies sharding resources for distributed tests + number_of_devices: int = 1 + mesh_shape: tuple[int, ...] = (1, 1, 1) + mesh_axes: tuple[str, ...] = ("dp", "cp", "tp") + mesh_resource: MeshResource = field(default_factory=partial(MeshResource, "dp", "cp", "tp")) + + # Context parallel aux arguments + cp_strategy: CPStrategy = CPStrategy.DEFAULT + cp_load_balanced: bool = True + + # dictionary of expected collective comm bytes + coll_count_ref: Optional[Dict[str, int]] = None + # See https://docs.nvidia.com/deeplearning/cudnn/latest/release-notes.html#cudnn-9-4-0 for known issue # generating zero-length ragged tensors. 
This setting adjusts the test to avoid the zero-length cases.
     def _get_max_segments_per_sequence(self):
@@ -362,6 +382,14 @@ def _check_configs(self):
 
     def _setup_inputs(self):
         self._check_configs()
+
+        # Create a mesh for distributed tests
+        self.devices = np.asarray(jax.devices()[: self.number_of_devices]).reshape(*self.mesh_shape)
+        self.mesh = Mesh(self.devices, self.mesh_axes)
+        self.dp_size = self.mesh.shape.get(self.mesh_resource.dp_resource, 1)
+        self.cp_size = self.mesh.shape.get(self.mesh_resource.cp_resource, 1)
+        self.tp_size = self.mesh.shape.get(self.mesh_resource.tp_resource, 1)
+
         key = jax.random.PRNGKey(0)
         q_key, k_key, v_key, bias_key, dropout_key = jax.random.split(key, 5)
@@ -527,6 +555,66 @@ def generate_random_segment_ids(
         self.dropout_rng = dropout_key if self.dropout_prob > 0 else None
         self.scaling_factor = 1.0 / sqrt(self.head_dim)
 
+        # Setup distributed sharding specs
+        # Setup shardings for distributed tests
+        self.qkvo_psec = PartitionSpec(
+            self.mesh_resource.dp_resource,
+            self.mesh_resource.cp_resource,
+            self.mesh_resource.tp_resource,
+            None,
+        )
+        self.qkvo_sharding = NamedSharding(self.mesh, self.qkvo_psec)
+
+        self.mask_pspec = PartitionSpec(
+            self.mesh_resource.dp_resource, None, self.mesh_resource.cp_resource, None
+        )
+        self.mask_sharding = NamedSharding(self.mesh, self.mask_pspec)
+
+        if self.bias_shape == BiasShape._1HSS:
+            self.bias_pspec = PartitionSpec(
+                None, self.mesh_resource.tp_resource, self.mesh_resource.cp_resource, None
+            )
+        elif self.bias_shape == BiasShape._B1SS:
+            self.bias_pspec = PartitionSpec(
+                self.mesh_resource.dp_resource, None, self.mesh_resource.cp_resource, None
+            )
+        elif self.bias_shape == BiasShape._11SS:
+            self.bias_pspec = PartitionSpec(None, None, self.mesh_resource.cp_resource, None)
+        else:
+            self.bias_pspec = PartitionSpec()
+        self.bias_sharding = NamedSharding(self.mesh, self.bias_pspec)
+
+        self.dropout_rng_pspec = PartitionSpec(
+            None,
+        )
+        self.dropout_rng_sharding = NamedSharding(self.mesh, self.dropout_rng_pspec)
+
+        self.logit_scale_pspec = PartitionSpec(None, None, self.mesh_resource.cp_resource, None)
+        self.logit_scale_sharding = NamedSharding(self.mesh, self.logit_scale_pspec)
+
+        # [batch][max_segments_per_batch]
+        # TODO(mgoldfarb-nvidia): Will need to handle CP cases of replicated or distributed length/offset.
+        self.seq_length_offset_pspec = PartitionSpec(self.mesh_resource.dp_resource, None)
+        self.seq_length_offset_sharding = NamedSharding(self.mesh, self.seq_length_offset_pspec)
+
+        # Softmax aux sharding
+
+        if self.cp_size > 1 and self.cp_load_balanced:
+            self.cp_reorder_fn = partial(
+                reorder_causal_load_balancing,
+                cp_size=self.cp_size,
+                tensor_format=self.qkv_layout.get_qkv_format(),
+            )
+            self.cp_inverse_reorder_fn = partial(
+                inverse_reorder_causal_load_balancing,
+                cp_size=self.cp_size,
+                tensor_format=self.qkv_layout.get_qkv_format(),
+            )
+        else:
+            # No-ops when CP is disabled or load balancing is not used
+            self.cp_reorder_fn = lambda x: x
+            self.cp_inverse_reorder_fn = lambda x: x
+
     def test_forward(self):
         """
         Test forward without JIT
         """
         self._setup_inputs()
 
         args = [self.q, self.k, self.v, self.bias, self.mask, self.dropout_rng]
+
         customcall_args = [
-            self.q,
-            self.k,
-            self.v,
-            self.bias,
-            self.mask_for_customcall,
-            self.seqlens_q,
-            self.seqlens_kv,
-            self.offsets_q,
-            self.offsets_kv,
-            self.dropout_rng,
+            # Put test data onto each GPU for distributed runs.
+            # TODO(mgoldfarb-nvidia): We will need to add reordering for bias, mask and
+            # THD params once we support those features on CP.
+            jax.device_put(self.cp_reorder_fn(self.q), self.qkvo_sharding),
+            jax.device_put(self.cp_reorder_fn(self.k), self.qkvo_sharding),
+            jax.device_put(self.cp_reorder_fn(self.v), self.qkvo_sharding),
+            jax.device_put(self.bias, self.bias_sharding),
+            jax.device_put(self.mask_for_customcall, self.mask_sharding),
+            jax.device_put(self.seqlens_q, self.seq_length_offset_sharding),
+            jax.device_put(self.seqlens_kv, self.seq_length_offset_sharding),
+            jax.device_put(self.offsets_q, self.seq_length_offset_sharding),
+            jax.device_put(self.offsets_kv, self.seq_length_offset_sharding),
+            jax.device_put(self.dropout_rng, self.dropout_rng_sharding),
         ]
         kwargs = {
             "attn_bias_type": self.attn_bias_type,
@@ -555,10 +647,31 @@
             "qkv_layout": self.qkv_layout,
             "max_segments_per_seq": self._get_max_segments_per_sequence(),
             "window_size": self.window_size,
+            "context_parallel_strategy": self.cp_strategy,
+            "context_parallel_causal_load_balanced": self.cp_load_balanced,
         }
 
-        # Convert the outputs to float32 for the elementwise comparison
-        primitive_out = customcall_fused_dpa(*customcall_args, **kwargs)
+        customcall_fused_dpa_jit = jit(
+            partial(customcall_fused_dpa, **kwargs),
+            static_argnames=kwargs.keys(),
+            in_shardings=[
+                self.qkvo_sharding,
+                self.qkvo_sharding,
+                self.qkvo_sharding,
+                self.bias_sharding,
+                self.mask_sharding,
+                self.seq_length_offset_sharding,
+                self.seq_length_offset_sharding,
+                self.seq_length_offset_sharding,
+                self.seq_length_offset_sharding,
+                self.dropout_rng_sharding,
+            ],
+        )
+
+        with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource):
+            primitive_out = customcall_fused_dpa_jit(*customcall_args)
+            primitive_out = self.cp_inverse_reorder_fn(primitive_out)
+
         reference_out = jax_dpa(*args, **kwargs)
 
         if self.is_training and self.dropout_prob > 0.0:
@@ -571,9 +684,19 @@
         assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype)
         assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)
 
+        if self.coll_count_ref is not None:
+            with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource):
+                target_hlo = (
+                    customcall_fused_dpa_jit.lower(*customcall_args, **kwargs).compile().as_text()
+                )
+                assert_equal_collectives(target_hlo, self.coll_count_ref)
+
     def test_backward(self):
         """
-        Test value_and_grad with JIT, which includes both forward and backward
+        Test value_and_grad with JIT, which includes both forward and backward.
+
+        If coll_count_ref is not None then the HLO of the backward function
+        will be examined for the expected comms.
         """
 
         self._setup_inputs()
 
@@ -587,20 +710,24 @@ def grad_func(func, *args, **kwargs):
             ret_valid = jnp.where(
                 self.pad_q[..., jnp.newaxis, jnp.newaxis], 0, func(*args, **kwargs)
            )
-            return (jnp.mean(ret_valid, dtype=jnp.float32) * gradient_multiplier).astype(self.dtype)
+            return (
+                jnp.mean(ret_valid.astype(jnp.float32), dtype=jnp.float32) * gradient_multiplier
+            ).astype(self.dtype)
 
         args = [self.q, self.k, self.v, self.bias, self.mask, self.dropout_rng]
         customcall_args = [
-            self.q,
-            self.k,
-            self.v,
-            self.bias,
-            self.mask_for_customcall,
-            self.seqlens_q,
-            self.seqlens_kv,
-            self.offsets_q,
-            self.offsets_kv,
-            self.dropout_rng,
+            # TODO(mgoldfarb-nvidia): We will need to add reordering for bias, mask and
+            # THD params once we support those features on CP.
+ jax.device_put(self.cp_reorder_fn(self.q), self.qkvo_sharding), + jax.device_put(self.cp_reorder_fn(self.k), self.qkvo_sharding), + jax.device_put(self.cp_reorder_fn(self.v), self.qkvo_sharding), + jax.device_put(self.bias, self.bias_sharding), + jax.device_put(self.mask_for_customcall, self.mask_sharding), + jax.device_put(self.seqlens_q, self.seq_length_offset_sharding), + jax.device_put(self.seqlens_kv, self.seq_length_offset_sharding), + jax.device_put(self.offsets_q, self.seq_length_offset_sharding), + jax.device_put(self.offsets_kv, self.seq_length_offset_sharding), + jax.device_put(self.dropout_rng, self.dropout_rng_sharding), ] kwargs = { "attn_bias_type": self.attn_bias_type, @@ -611,10 +738,22 @@ def grad_func(func, *args, **kwargs): "qkv_layout": self.qkv_layout, "max_segments_per_seq": self._get_max_segments_per_sequence(), "window_size": self.window_size, + "context_parallel_strategy": self.cp_strategy, + "context_parallel_causal_load_balanced": self.cp_load_balanced, } # We can compute dBias only for the [1, h, s, s] layout - arg_nums = (0, 1, 2, 3) if self.bias_shape == BiasShape._1HSS else (0, 1, 2) + if self.bias_shape == BiasShape._1HSS: + arg_nums = (0, 1, 2, 3) + grad_shardings = ( + self.qkvo_sharding, + self.qkvo_sharding, + self.qkvo_sharding, + self.bias_sharding, + ) + else: + arg_nums = (0, 1, 2) + grad_shardings = (self.qkvo_sharding, self.qkvo_sharding, self.qkvo_sharding) # Use FP16/BF16 to sum the results may cause overflow, use FP32 for the summation jitted_primitive = jit( @@ -623,7 +762,20 @@ def grad_func(func, *args, **kwargs): customcall_fused_dpa, q, k, v, bias, *args, **kwargs ), arg_nums, - ) + ), + in_shardings=( + self.qkvo_sharding, + self.qkvo_sharding, + self.qkvo_sharding, + self.bias_sharding, + self.mask_sharding, + self.seq_length_offset_sharding, + self.seq_length_offset_sharding, + self.seq_length_offset_sharding, + self.seq_length_offset_sharding, + self.dropout_rng_sharding, + ), + out_shardings=(None, grad_shardings), ) jitted_reference = jit( value_and_grad( @@ -632,20 +784,31 @@ def grad_func(func, *args, **kwargs): ) ) - primitive_out, primitive_dgrad = jitted_primitive(*customcall_args) + with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource): + primitive_out, primitive_dgrad = jitted_primitive(*customcall_args) + reference_out, reference_dgrad = jitted_reference(*args) # Skip elementwise comparison when dropout enabled if self.dropout_prob > 0.0: return + print_debug_tensor_stats(f"primitive_out", primitive_out) + print_debug_tensor_stats(f"reference_grad_valid", reference_out) + print_debug_tensor_stats(f"diff_grad", jnp.abs(primitive_out - reference_out)) assert_allclose(primitive_out, reference_out, dtype=self.dtype) - def check_dqkv(primitive, reference, pad): + def check_dqkv(primitive, reference, pad, idx): primitive_valid, primitive_invalid, reference_valid, reference_invalid = ( _split_valid_and_invalid(primitive, reference, pad) ) + print_debug_tensor_stats(f"primitive_grad_valid[{idx}]", primitive_valid[idx]) + print_debug_tensor_stats(f"reference_grad_valid[{idx}]", reference_valid[idx]) + print_debug_tensor_stats( + f"diff_grad[{idx}]", jnp.abs(primitive_valid[idx] - reference_valid[idx]) + ) + assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype) assert_allclose(primitive_invalid, reference_invalid, dtype=self.dtype) assert_allclose(primitive_valid, reference_valid, dtype=self.dtype) @@ -653,11 +816,17 @@ def check_dqkv(primitive, reference, pad): primitive_dq, primitive_dk, 
primitive_dv = primitive_dgrad[:3] reference_dq, reference_dk, reference_dv = reference_dgrad[:3] - check_dqkv(primitive_dq, reference_dq, self.pad_q) - check_dqkv(primitive_dk, reference_dk, self.pad_kv) - check_dqkv(primitive_dv, reference_dv, self.pad_kv) + primitive_dq = self.cp_inverse_reorder_fn(primitive_dq) + primitive_dk = self.cp_inverse_reorder_fn(primitive_dk) + primitive_dv = self.cp_inverse_reorder_fn(primitive_dv) + + check_dqkv(primitive_dq, reference_dq, self.pad_q, 0) + check_dqkv(primitive_dk, reference_dk, self.pad_kv, 1) + check_dqkv(primitive_dv, reference_dv, self.pad_kv, 2) if self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape == BiasShape._1HSS: + # TODO(mgoldfarb-nvidia): Inverse reorder bias once supported by a CP implementation. + primitive_dbias = primitive_dgrad[3] reference_dbias = reference_dgrad[3] @@ -685,6 +854,11 @@ def check_dqkv(primitive, reference, pad): dtype=self.dtype, ) + if self.coll_count_ref is not None: + with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource): + target_hlo = jitted_primitive.lower(*customcall_args).compile().as_text() + assert_equal_collectives(target_hlo, self.coll_count_ref) + @pytest.mark.parametrize( "attn_mask_type",