diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 5fee1b7191..abd8f33ccd 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -11,7 +11,7 @@ Fixes # (issue) - [ ] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) - [ ] Infra/Build change -- [ ] Code refractor +- [ ] Code refactoring ## Changes diff --git a/setup.py b/setup.py index 16e988aa88..643dd7a908 100644 --- a/setup.py +++ b/setup.py @@ -64,6 +64,9 @@ def setup_common_extension() -> CMakeExtension: ), "MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1" cmake_flags.append("-DNVTE_UB_WITH_MPI=ON") + if bool(int(os.getenv("NVTE_BUILD_ACTIVATION_WITH_FAST_MATH", "0"))): + cmake_flags.append("-DNVTE_BUILD_ACTIVATION_WITH_FAST_MATH=ON") + # Project directory root root_path = Path(__file__).resolve().parent diff --git a/tests/jax/distributed_test_base.py b/tests/jax/distributed_test_base.py index c2d7039a53..d0ace8263f 100644 --- a/tests/jax/distributed_test_base.py +++ b/tests/jax/distributed_test_base.py @@ -18,14 +18,22 @@ def generate_configs(): configs = [] if is_devices_enough(2): - configs.append([2, (2,), "dp", MeshResource(dp_resource="dp")]) - configs.append([2, (2,), "tp", MeshResource(tp_resource="tp")]) + configs.append( + pytest.param(2, (2,), ("dp",), MeshResource(dp_resource="dp"), id="n2_dp2_tp1") + ) + configs.append( + pytest.param(2, (2,), ("tp",), MeshResource(tp_resource="tp"), id="n2_dp1_tp2") + ) if is_devices_enough(4): - TP_size = 2 - DP_size = 2 configs.append( - [4, (DP_size, TP_size), ("dp", "tp"), MeshResource(dp_resource="dp", tp_resource="tp")] + pytest.param( + 4, + (2, 2), + ("dp", "tp"), + MeshResource(dp_resource="dp", tp_resource="tp"), + id=f"n4_dp2_tp2", + ) ) return configs @@ -33,7 +41,8 @@ def generate_configs(): def generate_context_parallel_configs(): configs = [] - + mr = MeshResource(dp_resource="dp", cp_resource="cp", tp_resource="tp") + axes = ("dp", "cp", "tp") DP_sizes = (1, 2) CP_sizes = (1, 2, 4, 8) TP_sizes = (1, 2) @@ -41,13 +50,7 @@ def generate_context_parallel_configs(): ndev = cp * tp * dp if is_devices_enough(ndev): configs.append( - pytest.param( - ndev, - (dp, cp, tp), - ("dp", "cp", "tp"), - MeshResource(dp_resource="dp", cp_resource="cp", tp_resource="tp"), - id=f"n{ndev}_dp{dp}_cp{cp}_tp{tp}", - ) + pytest.param(ndev, (dp, cp, tp), axes, mr, id=f"n{ndev}_dp{dp}_cp{cp}_tp{tp}") ) return configs diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index 5a41911691..2e15dd4d5d 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -37,8 +37,7 @@ ) from transformer_engine.jax.sharding import MeshResource -# We will use the golden reference model from our non distributed attention test fixture. 
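# Illustrative sketch (not part of this patch): the pytest.param(id=...) pattern used above in
# generate_configs()/generate_context_parallel_configs(). Explicit ids such as "n4_dp2_tp2" make
# failing mesh configurations self-describing in the pytest report instead of positional tuples.
# is_devices_enough() and MeshResource come from the surrounding test utilities; the device
# check below is a simplified stand-in.
import jax
import pytest


def _enough_devices(n: int) -> bool:
    # Stand-in for is_devices_enough(): true when at least n local devices are visible.
    return len(jax.devices()) >= n


def sketch_generate_configs():
    configs = []
    for dp, tp in ((2, 1), (1, 2), (2, 2)):
        ndev = dp * tp
        if _enough_devices(ndev):
            configs.append(
                pytest.param(ndev, (dp, tp), ("dp", "tp"), id=f"n{ndev}_dp{dp}_tp{tp}")
            )
    return configs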
-from test_fused_attn import general_dot_product_attention, make_mask +from test_fused_attn import FusedAttnRunner, BiasShape, general_dot_product_attention, make_mask DTYPES = [jnp.float16, jnp.bfloat16] @@ -49,7 +48,7 @@ def generate_collectives_count_ref( self, mesh_shape, mesh_axes, mesh_resource, with_bias, shape, dtype ): jax_dtype = jax.dtypes.canonicalize_dtype(dtype) - _, seqlen, _, heads, _ = shape + _, seqlen, heads, _ = shape is_dp_enabled = mesh_resource.dp_resource is not None tp_size = 1 if mesh_resource.tp_resource is not None: @@ -62,45 +61,28 @@ def generate_collectives_count_ref( # for loss and dbias return generate_collectives_count(allreduce=allreduce_total_bytes, allgather=0, other=0) - def generate_inputs(self, shape, mesh_resource, with_bias, attn_mask_type, dtype): - batch, seqlen, _, heads, _ = shape - - qkv = random.normal(random.PRNGKey(1124), shape, dtype=dtype) - - bias = ( - random.normal(random.PRNGKey(1125), (1, heads, seqlen, seqlen), dtype) - if with_bias - else None - ) - - mask = None - if attn_mask_type == AttnMaskType.PADDING_MASK: - mask = make_causal_mask(batch, seqlen) - elif attn_mask_type == AttnMaskType.CAUSAL_MASK: - mask = make_self_mask(batch, seqlen) - - qkv_pspec = PartitionSpec( - mesh_resource.dp_resource, None, None, mesh_resource.tp_resource, None - ) - bias_pspec = ( - PartitionSpec(None, mesh_resource.tp_resource, None, None) if with_bias else None - ) - mask_pspec = ( - PartitionSpec(mesh_resource.dp_resource, None, None, None) - if attn_mask_type != AttnMaskType.NO_MASK - else None - ) - - return (qkv, bias, mask), (qkv_pspec, bias_pspec, mask_pspec) - @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) - @pytest.mark.parametrize("data_shape", [[32, 512, 3, 12, 64], [32, 1024, 3, 16, 128]]) @pytest.mark.parametrize( - "attn_bias_type", - [AttnBiasType.NO_BIAS, AttnBiasType.PRE_SCALE_BIAS, AttnBiasType.POST_SCALE_BIAS], + "data_shape", + [ + pytest.param((32, 512, 12, 64), id="32-512-12-64"), + pytest.param((32, 1024, 16, 128), id="32-1024-16-128"), + ], ) @pytest.mark.parametrize( - "attn_mask_type", [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK] + "attn_bias_type, bias_shape", + [ + pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"), + pytest.param(AttnBiasType.PRE_SCALE_BIAS, BiasShape._1HSS, id="PRE_SCALE_BIAS-1HSS"), + pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"), + ], + ) + @pytest.mark.parametrize( + "attn_mask_type", + [ + pytest.param(AttnMaskType.PADDING_MASK, id="PADDING_MASK"), + pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL_MASK"), + ], ) @pytest.mark.parametrize("dtype", DTYPES) def test_self_attn( @@ -111,14 +93,14 @@ def test_self_attn( mesh_resource, data_shape, attn_bias_type, + bias_shape, attn_mask_type, dtype, ): dropout_prob = 0.0 is_training = True - scaling_factor = 1.0 - _, seqlen, _, num_head, hidden = data_shape + batch, seqlen, num_head, hidden = data_shape if not is_fused_attn_kernel_available( dtype, @@ -136,74 +118,36 @@ def test_self_attn( ): pytest.skip(f"No FusedAttn backend found") - def target_func(qkv, bias, mask): - return jnp.mean( - fused_attn( - (qkv,), - bias, - mask, - None, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - qkv_layout=QKVLayout.BS3HD, - scaling_factor=scaling_factor, - dropout_probability=dropout_prob, - is_training=is_training, - ) - ) - - def ref_func(qkv, bias, mask): - query, key, value = jnp.split(qkv, [1, 2], axis=-3) - query = jnp.squeeze(query) - key 
= jnp.squeeze(key) - value = jnp.squeeze(value) - - output = dot_product_attention( - query, - key, - value, - bias=bias, - mask=mask, - deterministic=is_training, - dropout_rate=dropout_prob, - dropout_rng=None, - dtype=jnp.float32, - ) - - return jnp.mean(output).astype(dtype) - - with_bias = attn_bias_type != AttnBiasType.NO_BIAS - (qkv, bias, mask), (qkv_pspec, bias_pspec, mask_pspec) = self.generate_inputs( - data_shape, mesh_resource, with_bias, attn_mask_type, dtype + col_ref = self.generate_collectives_count_ref( + mesh_shape, + mesh_axes, + mesh_resource, + attn_bias_type != AttnBiasType.NO_BIAS, + data_shape, + dtype, ) - collective_count_ref = self.generate_collectives_count_ref( - mesh_shape, mesh_axes, mesh_resource, with_bias, data_shape, dtype + runner = FusedAttnRunner( + batch, + seqlen, + seqlen, + num_head, + num_head, + hidden, + attn_bias_type, + attn_mask_type, + dropout_prob, + dtype, + is_training, + QKVLayout.BS3HD, + bias_shape, + None, + number_of_devices=device_count, + mesh_shape=mesh_shape, + mesh_axes=mesh_axes, + mesh_resource=mesh_resource, + coll_count_ref=col_ref, ) - devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) - mesh = Mesh(devices, mesh_axes) - with mesh, fp8_autocast(mesh_resource=mesh_resource): - qkv_ = jax.device_put(qkv, NamedSharding(mesh, qkv_pspec)) - bias_ = ( - jax.device_put(bias, NamedSharding(mesh, bias_pspec)) if bias is not None else bias - ) - mask_ = ( - jax.device_put(mask, NamedSharding(mesh, mask_pspec)) if mask is not None else mask - ) - - grad_args = (0, 1) if with_bias else (0,) - out_grad_shardings = (qkv_pspec, bias_pspec) if with_bias else (qkv_pspec,) - - compare_ops( - target_func, - ref_func, - [qkv_, bias_, mask_], - collective_count_ref, - grad_args=grad_args, - metric_fwd_dtype=dtype, - metric_bwd_dtype=dtype, - in_shardings=(qkv_pspec, bias_pspec, mask_pspec), - out_shardings=(None, out_grad_shardings), - ) + runner.test_backward() class TestDistributedCrossAttn: @@ -213,31 +157,6 @@ def generate_collectives_count_ref(self): all_reduce_loss_bytes = 4 # 1 * FP32 return generate_collectives_count(allreduce=all_reduce_loss_bytes, allgather=0, other=0) - def generate_inputs(self, shape, mesh_resource, attn_mask_type, dtype): - batch, seqlen, heads, hidden = shape - - q = random.normal(random.PRNGKey(1124), shape, dtype=dtype) - kv = random.normal(random.PRNGKey(1125), (batch, seqlen, 2, heads, hidden), dtype=dtype) - - mask = None - if attn_mask_type == AttnMaskType.PADDING_MASK: - mask = make_causal_mask(batch, seqlen) - elif attn_mask_type == AttnMaskType.CAUSAL_MASK: - mask = make_self_mask(batch, seqlen) - - q_pspec = PartitionSpec(mesh_resource.dp_resource, None, mesh_resource.tp_resource, None) - - kv_pspec = PartitionSpec( - mesh_resource.dp_resource, None, None, mesh_resource.tp_resource, None - ) - mask_pspec = ( - PartitionSpec(mesh_resource.dp_resource, None, None, None) - if attn_mask_type != AttnMaskType.NO_MASK - else None - ) - - return (q, kv, mask), (q_pspec, kv_pspec, mask_pspec) - @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) @pytest.mark.parametrize("data_shape", [[32, 128, 12, 64], [32, 512, 16, 64]]) @pytest.mark.parametrize( @@ -248,11 +167,11 @@ def test_cross_attn( self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, attn_mask_type, dtype ): attn_bias_type = AttnBiasType.NO_BIAS + bias_shape = None dropout_prob = 0.0 is_training = True - scaling_factor = 1.0 - _, seqlen, num_head, hidden = data_shape + 
batch, seqlen, num_head, hidden = data_shape if not is_fused_attn_kernel_available( dtype, @@ -270,67 +189,29 @@ def test_cross_attn( ): pytest.skip(f"No FusedAttn backend found") - def target_func(q, kv, mask): - return jnp.mean( - fused_attn( - (q, kv), - None, - mask, - None, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - qkv_layout=QKVLayout.BSHD_BS2HD, - scaling_factor=scaling_factor, - dropout_probability=dropout_prob, - is_training=is_training, - ), - dtype=jnp.float32, - ) - - def ref_func(query, kv, mask): - key, value = jnp.split(kv, [1], axis=-3) - query = jnp.squeeze(query) - key = jnp.squeeze(key) - value = jnp.squeeze(value) - - output = dot_product_attention( - query, - key, - value, - bias=None, - mask=mask, - deterministic=is_training, - dropout_rate=dropout_prob, - dropout_rng=None, - dtype=jnp.float32, - ) - - return jnp.mean(output, dtype=jnp.float32) - - (q, kv, mask), (q_pspec, kv_pspec, mask_pspec) = self.generate_inputs( - data_shape, mesh_resource, attn_mask_type, dtype + col_ref = self.generate_collectives_count_ref() + runner = FusedAttnRunner( + batch, + seqlen, + seqlen, + num_head, + num_head, + hidden, + attn_bias_type, + attn_mask_type, + dropout_prob, + dtype, + is_training, + QKVLayout.BSHD_BS2HD, + bias_shape, + None, + number_of_devices=device_count, + mesh_shape=mesh_shape, + mesh_axes=mesh_axes, + mesh_resource=mesh_resource, + coll_count_ref=col_ref, ) - collective_count_ref = self.generate_collectives_count_ref() - devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) - mesh = Mesh(devices, mesh_axes) - with mesh, fp8_autocast(mesh_resource=mesh_resource): - q_ = jax.device_put(q, NamedSharding(mesh, q_pspec)) - kv_ = jax.device_put(kv, NamedSharding(mesh, kv_pspec)) - mask_ = ( - jax.device_put(mask, NamedSharding(mesh, mask_pspec)) if mask is not None else mask - ) - - compare_ops( - target_func, - ref_func, - [q_, kv_, mask_], - collective_count_ref, - grad_args=(0, 1), - metric_fwd_dtype=dtype, - metric_bwd_dtype=dtype, - in_shardings=(q_pspec, kv_pspec, mask_pspec), - out_shardings=(None, (q_pspec, kv_pspec)), - ) + runner.test_backward() @pytest.mark.parametrize( @@ -366,41 +247,6 @@ def ref_func(query, kv, mask): ) class TestDistributedContextParallelSelfAttn: - def generate_inputs(self, shape, kv_groups: int, attn_mask_type: AttnMaskType, dtype): - batch, seqlen, heads, hidden = shape - kv_shape = (batch, seqlen, heads // kv_groups, hidden) - qkey, kkey, vkey = random.split(random.PRNGKey(1124), 3) - q = random.normal(qkey, shape, dtype=dtype) - k = random.normal(kkey, (batch, seqlen, heads // kv_groups, hidden), dtype=dtype) - v = random.normal(vkey, (batch, seqlen, heads // kv_groups, hidden), dtype=dtype) - - def gen_valid(bs, max_seqlen, pad_ratio): - pad_len = int(max_seqlen * pad_ratio) - valid_len = max_seqlen - pad_len - tokens = jnp.concatenate([jnp.ones((bs, valid_len)), jnp.zeros((bs, pad_len))], axis=-1) - return tokens, jnp.logical_not(tokens) - - from test_fused_attn import make_mask - - q_idx, _ = gen_valid(batch, seqlen, 0.0) - kv_idx, _ = gen_valid(batch, seqlen, 0.0) - mask = make_mask(q_idx, kv_idx, None, None, attn_mask_type) - - return q, k, v, mask - - def qkv_to_layout(self, q, k, v, qkv_layout): - qkv_args = () - match qkv_layout: - case QKVLayout.BSHD_BS2HD: - k, v = map(partial(jnp.expand_dims, axis=-3), [k, v]) - kv = jnp.concatenate((k, v), axis=-3) - qkv_args = (q, kv) - case QKVLayout.BSHD_BSHD_BSHD: - qkv_args = (q, k, v) - case _: - raise ValueError(f"Unsupported 
{qkv_layout=}") - return qkv_args - def impl_test_context_parallel_attn( self, device_count, @@ -416,6 +262,7 @@ def impl_test_context_parallel_attn( cp_strategy, ): attn_bias_type = AttnBiasType.NO_BIAS + bias_shape = None dropout_prob = 0.0 is_training = True dp_size, cp_size, tp_size = mesh_shape @@ -431,6 +278,29 @@ def impl_test_context_parallel_attn( num_kv_heads = num_head // kv_groups scaling_factor = 1.0 / np.sqrt(num_head) + runner = FusedAttnRunner( + batch, + seqlen, + seqlen, + num_head, + num_kv_heads, + hidden, + attn_bias_type, + attn_mask_type, + dropout_prob, + dtype, + is_training, + qkv_layout, + bias_shape, + None, + number_of_devices=device_count, + mesh_shape=mesh_shape, + mesh_axes=mesh_axes, + mesh_resource=mesh_resource, + cp_strategy=cp_strategy, + cp_load_balanced=load_balanced, + ) + def check_has_backend_for_mask(mask_type): return is_fused_attn_kernel_available( dtype, @@ -465,123 +335,7 @@ def check_has_backend_for_mask(mask_type): if num_head % kv_groups != 0 or (num_head // kv_groups) % tp_size != 0: pytest.skip(f"Skipping {kv_groups=} not multiple of {data_shape=} or {tp_size=}") - def target_func(q, k, v, mask): - return fused_attn( - self.qkv_to_layout(q, k, v, qkv_layout), - None, # bias - mask, - None, # seed - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - qkv_layout=qkv_layout, - scaling_factor=scaling_factor, - dropout_probability=dropout_prob, - is_training=is_training, - context_parallel_strategy=cp_strategy, - context_parallel_causal_load_balanced=load_balanced, - context_parallel_axis="cp", - ).astype(dtype) - - def ref_func(q, k, v, mask): - output = general_dot_product_attention( - q, - k, - v, - bias=None, - mask=mask, - deterministic=not is_training, - scale_factor=scaling_factor, - dropout_rate=dropout_prob, - dropout_rng=None, - dtype=jnp.float32, - ) - return output.astype(dtype) - - def grad_func(func, *args, **kwargs): - # Gradient is small, use a gradient multiplier to amplify the gradient - _, max_seq_len, num_heads, _ = data_shape - gradient_multiplier = max_seq_len * num_heads - if attn_mask_type.is_causal(): - gradient_multiplier /= 10 - ret_valid = func(*args, **kwargs) - return (jnp.mean(ret_valid, dtype=jnp.float32) * gradient_multiplier).astype(dtype) - - q, k, v, mask = self.generate_inputs(data_shape, kv_groups, attn_mask_type, dtype) - - diff_argnums = (0, 1, 2) - - # Single GPU (reference) - ref_func_jit = jax.jit( - jax.value_and_grad( - lambda q, k, v, mask: grad_func(ref_func, q, k, v, mask), argnums=diff_argnums - ) - ) - ref_fwd, ref_grads = ref_func_jit(q, k, v, mask) - - # Multi GPU (function under test) - devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) - mesh = Mesh(devices, mesh_axes) - with mesh, fp8_autocast(mesh_resource=mesh_resource, enabled=False): - qkv_ps = PartitionSpec( - mesh_resource.dp_resource, - mesh_resource.cp_resource, - mesh_resource.tp_resource, - None, - ) - qkv_sharding = NamedSharding(mesh, qkv_ps) - - mask_ps = PartitionSpec( - mesh_resource.dp_resource, None, mesh_resource.cp_resource, None - ) - mask_sharding = NamedSharding(mesh, mask_ps) - - reorder = partial( - reorder_causal_load_balancing, cp_size=cp_size, tensor_format=qkv_format - ) - inverse_reorder = partial( - inverse_reorder_causal_load_balancing, cp_size=cp_size, tensor_format=qkv_format - ) - - if load_balanced: - q, k, v = jax.tree.map(reorder, (q, k, v)) - - q_, k_, v_ = map(partial(jax.device_put, device=qkv_sharding), [q, k, v]) - mask_ = jax.device_put(mask, 
device=mask_sharding) - - target_func_jit = jax.jit( - jax.value_and_grad( - lambda q, k, v, mask: grad_func(target_func, q, k, v, mask), - argnums=diff_argnums, - ), - in_shardings=[qkv_sharding, qkv_sharding, qkv_sharding, mask_sharding], - out_shardings=(None, (qkv_sharding, qkv_sharding, qkv_sharding)), - ) - - target_fwd, target_grads = target_func_jit(q_, k_, v_, mask_) - - if load_balanced: - target_dq, target_dk, target_dv = jax.tree.map(inverse_reorder, target_grads[0:3]) - target_grads = (target_dq, target_dk, target_dv, *target_grads[3:]) - - has_diffs = False - - print_debug_tensor_stats("target", target_fwd) - print_debug_tensor_stats("ref", ref_fwd) - print_debug_tensor_stats("diff", jnp.abs(target_fwd - ref_fwd)) - assert_allclose(target_fwd, ref_fwd, dtype=dtype) - - for i in range(len(target_grads)): - if ref_grads[i] is None or target_grads[i] is None: - # expect both none if one is - assert target_grads[i] is None and ref_grads[i] is None - else: - print_debug_tensor_stats(f"target_grad[{i}]", target_grads[i]) - print_debug_tensor_stats(f"ref_grad[{i}]", ref_grads[i]) - print_debug_tensor_stats( - f"diff_grad[{i}]", jnp.abs(target_grads[i] - ref_grads[i]) - ) - - assert_allclose(target_grads[i], ref_grads[i], dtype=dtype) + runner.test_backward() def test_context_parallel_allgather_attn( self, diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 5cbbec7b04..710ae1946d 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -3,10 +3,10 @@ # See LICENSE for license information. """Tests for fused attention""" from enum import Enum -from dataclasses import dataclass +from dataclasses import dataclass, field from functools import partial from math import sqrt -from typing import Tuple, Optional +from typing import Tuple, Optional, Dict import random import jax @@ -19,16 +19,22 @@ from flax.linen.dtypes import promote_dtype from jax import Array from jax import value_and_grad, jit +from jax.sharding import Mesh, NamedSharding, PartitionSpec from jax.typing import ArrayLike, DTypeLike +from transformer_engine.jax import fp8_autocast +from transformer_engine.jax.sharding import MeshResource from transformer_engine.jax.attention import ( AttnBiasType, AttnMaskType, QKVLayout, QKVFormat, + reorder_causal_load_balancing, + inverse_reorder_causal_load_balancing, fused_attn, fused_attn_thd, make_swa_mask, + CPStrategy, ) from transformer_engine.jax.cpp_extensions import FusedAttnHelper from transformer_engine.transformer_engine_jax import ( @@ -36,7 +42,8 @@ get_cudnn_version, ) -from utils import assert_allclose +from distributed_test_base import assert_equal_collectives +from utils import assert_allclose, print_debug_tensor_stats @pytest.fixture(autouse=True, scope="module") @@ -304,6 +311,19 @@ class FusedAttnRunner: bias_shape: BiasShape window_size: Optional[Tuple[int, int]] = None + # Specifies sharding resources for distributed tests + number_of_devices: int = 1 + mesh_shape: tuple[int, ...] = (1, 1, 1) + mesh_axes: tuple[str, ...] = ("dp", "cp", "tp") + mesh_resource: MeshResource = field(default_factory=partial(MeshResource, "dp", "cp", "tp")) + + # Context parallel aux arguments + cp_strategy: CPStrategy = CPStrategy.DEFAULT + cp_load_balanced: bool = True + + # dictionary of expected collective comm bytes + coll_count_ref: Optional[Dict[str, int]] = None + # See https://docs.nvidia.com/deeplearning/cudnn/latest/release-notes.html#cudnn-9-4-0 for known issue # generating zero-length ragged tensors. 
This setting adjusts the test to avoid the zero-length cases. def _get_max_segments_per_sequence(self): @@ -362,6 +382,14 @@ def _check_configs(self): def _setup_inputs(self): self._check_configs() + + # Create a mesh for distributed tests + self.devices = np.asarray(jax.devices()[: self.number_of_devices]).reshape(*self.mesh_shape) + self.mesh = Mesh(self.devices, self.mesh_axes) + self.dp_size = self.mesh.shape.get(self.mesh_resource.dp_resource, 1) + self.cp_size = self.mesh.shape.get(self.mesh_resource.cp_resource, 1) + self.tp_size = self.mesh.shape.get(self.mesh_resource.tp_resource, 1) + key = jax.random.PRNGKey(0) q_key, k_key, v_key, bias_key, dropout_key = jax.random.split(key, 5) @@ -527,6 +555,66 @@ def generate_random_segment_ids( self.dropout_rng = dropout_key if self.dropout_prob > 0 else None self.scaling_factor = 1.0 / sqrt(self.head_dim) + # Setup distributed sharding specs + # Setup shardings for distributed tests + self.qkvo_psec = PartitionSpec( + self.mesh_resource.dp_resource, + self.mesh_resource.cp_resource, + self.mesh_resource.tp_resource, + None, + ) + self.qkvo_sharding = NamedSharding(self.mesh, self.qkvo_psec) + + self.mask_pspec = PartitionSpec( + self.mesh_resource.dp_resource, None, self.mesh_resource.cp_resource, None + ) + self.mask_sharding = NamedSharding(self.mesh, self.mask_pspec) + + if self.bias_shape == BiasShape._1HSS: + self.bias_pspec = PartitionSpec( + None, self.mesh_resource.tp_resource, self.mesh_resource.cp_resource, None + ) + elif self.bias_shape == BiasShape._B1SS: + self.bias_pspec = PartitionSpec( + self.mesh_resource.dp_resource, None, self.mesh_resource.cp_resource, None + ) + elif self.bias_shape == BiasShape._11SS: + self.bias_pspec = PartitionSpec(None, None, self.mesh_resource.cp_resource, None) + else: + self.bias_pspec = PartitionSpec() + self.bias_sharding = NamedSharding(self.mesh, self.bias_pspec) + + self.dropout_rng_pspec = PartitionSpec( + None, + ) + self.dropout_rng_sharding = NamedSharding(self.mesh, self.dropout_rng_pspec) + + self.logit_scale_pspec = PartitionSpec(None, None, self.mesh_resource.cp_resource, None) + self.logit_scale_sharding = NamedSharding(self.mesh, self.logit_scale_pspec) + + # [batch][max_segments_per_batch] + # TODO(mgoldfarb-nvidia): Will need to handle CP cases of replicated or distributed length/offset. + self.seq_length_offset_pspec = PartitionSpec(self.mesh_resource.dp_resource, None) + self.seq_length_offset_sharding = NamedSharding(self.mesh, self.seq_length_offset_pspec) + + # Softmax aux sharding + + if self.cp_size > 1 and self.cp_load_balanced: + self.cp_reorder_fn = partial( + reorder_causal_load_balancing, + cp_size=self.cp_size, + tensor_format=self.qkv_layout.get_qkv_format(), + ) + self.cp_inverse_reorder_fn = partial( + inverse_reorder_causal_load_balancing, + cp_size=self.cp_size, + tensor_format=self.qkv_layout.get_qkv_format(), + ) + else: + # no-ops for non cp or non load balanced + self.cp_reorder_fn = lambda x: x + self.cp_inverse_reorder_fn = lambda x: x + def test_forward(self): """ Test forward without JIT @@ -534,17 +622,21 @@ def test_forward(self): self._setup_inputs() args = [self.q, self.k, self.v, self.bias, self.mask, self.dropout_rng] + customcall_args = [ - self.q, - self.k, - self.v, - self.bias, - self.mask_for_customcall, - self.seqlens_q, - self.seqlens_kv, - self.offsets_q, - self.offsets_kv, - self.dropout_rng, + # Put test data onto each GPU for distributed. 
+ # TODO(mgoldfarb-nvidia): We will need to add reordering for bias, mas and + # THD params once we support those features on CP. + jax.device_put(self.cp_reorder_fn(self.q), self.qkvo_sharding), + jax.device_put(self.cp_reorder_fn(self.k), self.qkvo_sharding), + jax.device_put(self.cp_reorder_fn(self.v), self.qkvo_sharding), + jax.device_put(self.bias, self.bias_sharding), + jax.device_put(self.mask_for_customcall, self.mask_sharding), + jax.device_put(self.seqlens_q, self.seq_length_offset_sharding), + jax.device_put(self.seqlens_kv, self.seq_length_offset_sharding), + jax.device_put(self.offsets_q, self.seq_length_offset_sharding), + jax.device_put(self.offsets_kv, self.seq_length_offset_sharding), + jax.device_put(self.dropout_rng, self.dropout_rng_sharding), ] kwargs = { "attn_bias_type": self.attn_bias_type, @@ -555,10 +647,31 @@ def test_forward(self): "qkv_layout": self.qkv_layout, "max_segments_per_seq": self._get_max_segments_per_sequence(), "window_size": self.window_size, + "context_parallel_strategy": self.cp_strategy, + "context_parallel_causal_load_balanced": self.cp_load_balanced, } - # Convert the outputs to float32 for the elementwise comparison - primitive_out = customcall_fused_dpa(*customcall_args, **kwargs) + customcall_fused_dpa_jit = jit( + partial(customcall_fused_dpa, **kwargs), + static_argnames=kwargs.keys(), + in_shardings=[ + self.qkvo_sharding, + self.qkvo_sharding, + self.qkvo_sharding, + self.bias_sharding, + self.mask_sharding, + self.seq_length_offset_sharding, + self.seq_length_offset_sharding, + self.seq_length_offset_sharding, + self.seq_length_offset_sharding, + self.dropout_rng_sharding, + ], + ) + + with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource): + primitive_out = customcall_fused_dpa_jit(*customcall_args) + primitive_out = self.cp_inverse_reorder_fn(primitive_out) + reference_out = jax_dpa(*args, **kwargs) if self.is_training and self.dropout_prob > 0.0: @@ -571,9 +684,19 @@ def test_forward(self): assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype) assert_allclose(primitive_valid, reference_valid, dtype=self.dtype) + if self.coll_count_ref is not None: + with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource): + target_hlo = ( + customcall_fused_dpa_jit.lower(*customcall_args, **kwargs).compile().as_text() + ) + assert_equal_collectives(target_hlo, self.coll_count_ref) + def test_backward(self): """ - Test value_and_grad with JIT, which includes both forward and backward + Test value_and_grad with JIT, which includes both forward and backward. + + If coll_count_ref is not None then the HLO of the backwrds function + HLO will be examined for the expected comms. """ self._setup_inputs() @@ -587,20 +710,24 @@ def grad_func(func, *args, **kwargs): ret_valid = jnp.where( self.pad_q[..., jnp.newaxis, jnp.newaxis], 0, func(*args, **kwargs) ) - return (jnp.mean(ret_valid, dtype=jnp.float32) * gradient_multiplier).astype(self.dtype) + return ( + jnp.mean(ret_valid.astype(jnp.float32), dtype=jnp.float32) * gradient_multiplier + ).astype(self.dtype) args = [self.q, self.k, self.v, self.bias, self.mask, self.dropout_rng] customcall_args = [ - self.q, - self.k, - self.v, - self.bias, - self.mask_for_customcall, - self.seqlens_q, - self.seqlens_kv, - self.offsets_q, - self.offsets_kv, - self.dropout_rng, + # TODO(mgoldfarb-nvidia): We will need to add reordering for bias, mas and + # THD params once we support those features on CP. 
+ jax.device_put(self.cp_reorder_fn(self.q), self.qkvo_sharding), + jax.device_put(self.cp_reorder_fn(self.k), self.qkvo_sharding), + jax.device_put(self.cp_reorder_fn(self.v), self.qkvo_sharding), + jax.device_put(self.bias, self.bias_sharding), + jax.device_put(self.mask_for_customcall, self.mask_sharding), + jax.device_put(self.seqlens_q, self.seq_length_offset_sharding), + jax.device_put(self.seqlens_kv, self.seq_length_offset_sharding), + jax.device_put(self.offsets_q, self.seq_length_offset_sharding), + jax.device_put(self.offsets_kv, self.seq_length_offset_sharding), + jax.device_put(self.dropout_rng, self.dropout_rng_sharding), ] kwargs = { "attn_bias_type": self.attn_bias_type, @@ -611,10 +738,22 @@ def grad_func(func, *args, **kwargs): "qkv_layout": self.qkv_layout, "max_segments_per_seq": self._get_max_segments_per_sequence(), "window_size": self.window_size, + "context_parallel_strategy": self.cp_strategy, + "context_parallel_causal_load_balanced": self.cp_load_balanced, } # We can compute dBias only for the [1, h, s, s] layout - arg_nums = (0, 1, 2, 3) if self.bias_shape == BiasShape._1HSS else (0, 1, 2) + if self.bias_shape == BiasShape._1HSS: + arg_nums = (0, 1, 2, 3) + grad_shardings = ( + self.qkvo_sharding, + self.qkvo_sharding, + self.qkvo_sharding, + self.bias_sharding, + ) + else: + arg_nums = (0, 1, 2) + grad_shardings = (self.qkvo_sharding, self.qkvo_sharding, self.qkvo_sharding) # Use FP16/BF16 to sum the results may cause overflow, use FP32 for the summation jitted_primitive = jit( @@ -623,7 +762,20 @@ def grad_func(func, *args, **kwargs): customcall_fused_dpa, q, k, v, bias, *args, **kwargs ), arg_nums, - ) + ), + in_shardings=( + self.qkvo_sharding, + self.qkvo_sharding, + self.qkvo_sharding, + self.bias_sharding, + self.mask_sharding, + self.seq_length_offset_sharding, + self.seq_length_offset_sharding, + self.seq_length_offset_sharding, + self.seq_length_offset_sharding, + self.dropout_rng_sharding, + ), + out_shardings=(None, grad_shardings), ) jitted_reference = jit( value_and_grad( @@ -632,20 +784,31 @@ def grad_func(func, *args, **kwargs): ) ) - primitive_out, primitive_dgrad = jitted_primitive(*customcall_args) + with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource): + primitive_out, primitive_dgrad = jitted_primitive(*customcall_args) + reference_out, reference_dgrad = jitted_reference(*args) # Skip elementwise comparison when dropout enabled if self.dropout_prob > 0.0: return + print_debug_tensor_stats(f"primitive_out", primitive_out) + print_debug_tensor_stats(f"reference_grad_valid", reference_out) + print_debug_tensor_stats(f"diff_grad", jnp.abs(primitive_out - reference_out)) assert_allclose(primitive_out, reference_out, dtype=self.dtype) - def check_dqkv(primitive, reference, pad): + def check_dqkv(primitive, reference, pad, idx): primitive_valid, primitive_invalid, reference_valid, reference_invalid = ( _split_valid_and_invalid(primitive, reference, pad) ) + print_debug_tensor_stats(f"primitive_grad_valid[{idx}]", primitive_valid[idx]) + print_debug_tensor_stats(f"reference_grad_valid[{idx}]", reference_valid[idx]) + print_debug_tensor_stats( + f"diff_grad[{idx}]", jnp.abs(primitive_valid[idx] - reference_valid[idx]) + ) + assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype) assert_allclose(primitive_invalid, reference_invalid, dtype=self.dtype) assert_allclose(primitive_valid, reference_valid, dtype=self.dtype) @@ -653,11 +816,17 @@ def check_dqkv(primitive, reference, pad): primitive_dq, primitive_dk, 
primitive_dv = primitive_dgrad[:3] reference_dq, reference_dk, reference_dv = reference_dgrad[:3] - check_dqkv(primitive_dq, reference_dq, self.pad_q) - check_dqkv(primitive_dk, reference_dk, self.pad_kv) - check_dqkv(primitive_dv, reference_dv, self.pad_kv) + primitive_dq = self.cp_inverse_reorder_fn(primitive_dq) + primitive_dk = self.cp_inverse_reorder_fn(primitive_dk) + primitive_dv = self.cp_inverse_reorder_fn(primitive_dv) + + check_dqkv(primitive_dq, reference_dq, self.pad_q, 0) + check_dqkv(primitive_dk, reference_dk, self.pad_kv, 1) + check_dqkv(primitive_dv, reference_dv, self.pad_kv, 2) if self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape == BiasShape._1HSS: + # TODO(mgoldfarb-nvidia): Inverse reorder bias once supported by a CP implementation. + primitive_dbias = primitive_dgrad[3] reference_dbias = reference_dgrad[3] @@ -685,6 +854,11 @@ def check_dqkv(primitive, reference, pad): dtype=self.dtype, ) + if self.coll_count_ref is not None: + with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource): + target_hlo = jitted_primitive.lower(*customcall_args).compile().as_text() + assert_equal_collectives(target_hlo, self.coll_count_ref) + @pytest.mark.parametrize( "attn_mask_type", diff --git a/tests/pytorch/distributed/run_layer_with_overlap.py b/tests/pytorch/distributed/run_layer_with_overlap.py index e49174c24f..5a67bd616a 100644 --- a/tests/pytorch/distributed/run_layer_with_overlap.py +++ b/tests/pytorch/distributed/run_layer_with_overlap.py @@ -51,15 +51,23 @@ def _get_layer_args(config, tp_group, tp_size, reference=False): kwargs["ub_overlap_ag"] = not reference if config.layer_type is te.Linear: - input_shape[2] = hidden_size // tp_size - args.append(hidden_size) - kwargs["parallel_mode"] = "row" - kwargs["ub_overlap_rs"] = not reference - kwargs["ub_name"] = "proj" + if config.linear_parallel_mode == "row": + input_shape[2] = hidden_size // tp_size + args.append(hidden_size) + kwargs["ub_overlap_rs"] = not reference + elif config.linear_parallel_mode == "column": + input_shape[0] = config.seq_length // tp_size + args.append(3 * hidden_size) + kwargs["ub_overlap_rs"] = config.overlap_rs_dgrad and not reference + kwargs["ub_bulk_dgrad"] = not config.overlap_rs_dgrad and not reference + kwargs["ub_bulk_wgrad"] = not config.overlap_rs_dgrad and not reference + kwargs["parallel_mode"] = config.linear_parallel_mode + kwargs["ub_name"] = "proj" if config.linear_parallel_mode == "row" else "qkv" else: input_shape[0] = config.seq_length // tp_size - kwargs["ub_bulk_wgrad"] = not reference - kwargs["ub_bulk_dgrad"] = not reference + kwargs["ub_overlap_rs_dgrad"] = config.overlap_rs_dgrad and not reference + kwargs["ub_bulk_wgrad"] = not config.overlap_rs_dgrad and not reference + kwargs["ub_bulk_dgrad"] = not config.overlap_rs_dgrad and not reference if config.layer_type is te.LayerNormLinear: args.append(3 * hidden_size) kwargs["parallel_mode"] = "column" @@ -125,6 +133,19 @@ def _parse_args(argv=None, namespace=None): parser.add_argument( "--use-cuda-graphs", action="store_true", default=False, help="Use CUDA Graphs." 
) + parser.add_argument( + "--linear-parallel-mode", + type=str.lower, + default="row", + choices=["row", "column"], + help="Parallel mode for te.Linear.", + ) + parser.add_argument( + "--overlap-rs-dgrad", + action="store_true", + default=False, + help="Overlap reduce-scatter with DGRAD in the backward pass instead of bulk overlaps.", + ) parser.add_argument( "--debug", action="store_true", @@ -230,12 +251,19 @@ def dist_print(msg, src=None, end="\n", debug=False, error=False): dist_print(f"Initialized default NCCL process group with {WORLD_SIZE} GPUs") # Intialize userbuffers + ub_cfgs = None + if opts.overlap_rs_dgrad: + ub_cfgs = { + "proj_dgrad": {"method": "ring_exchange"}, + "qkv_dgrad": {"method": "ring_exchange"}, + } te.module.base.initialize_ub( [opts.seq_length * opts.batch_size, opts.num_heads * opts.head_dim], WORLD_SIZE, use_fp8=opts.fp8, dtype=torch.bfloat16, bootstrap_backend=opts.bootstrap_backend, + ub_cfgs=ub_cfgs, ) # Initialize the Transformer Engine layer with overlap @@ -314,27 +342,29 @@ def run_fwd_bwd(model, x): ref_grads.append(ref_param.grad) # Make sure we have the same number of gradients - numerics_failed = torch.tensor([0], dtype=torch.uint8, device="cuda") + num_grads_failed = torch.tensor([0], dtype=torch.uint8, device="cuda") if len(test_grads) != len(ref_grads): - numerics_failed[0] = 1 + num_grads_failed[0] = 1 numerics_info = ( "NUMERICAL CHECK FAILED: Incorrect number of gradients, " + f"expected {len(ref_grads)} but got {len(test_grads)}." ) dist_print(numerics_info, src=WORLD_RANK, error=True) - dist.all_reduce(numerics_failed, dist.ReduceOp.MAX, nccl_world) + dist.all_reduce(num_grads_failed, dist.ReduceOp.MAX, nccl_world) # Now validate accuracy - if not bool(numerics_failed.item()): + numerics_failed = torch.zeros(len(test_grads), dtype=torch.uint8, device="cuda") + if not bool(num_grads_failed.item()): for i, (test_g, ref_g) in enumerate(zip(test_grads, ref_grads)): rtol = 0.125 if opts.fp8 else 0.025 atol = 0.0625 if opts.fp8 else 0.00125 grad_failed, grad_info = _compare_tensors(names[i], test_g, ref_g, rtol, atol) dist_print(grad_info, src=WORLD_RANK, error=grad_failed) - numerics_failed[0] = int(grad_failed) - dist.all_reduce(numerics_failed, dist.ReduceOp.MAX, nccl_world) - if bool(numerics_failed.item()): - break + numerics_failed[i] = int(grad_failed) + return_code = torch.max(numerics_failed) + dist.all_reduce(return_code, dist.ReduceOp.MAX, nccl_world) + else: + return_code = num_grads_failed te.module.base.destroy_ub() dist_print("Destroying Userbuffers objects...", debug=True) @@ -344,7 +374,7 @@ def run_fwd_bwd(model, x): if opts.debug and WORLD_RANK == 0: print("Exiting...\n", end="", flush=True) - return numerics_failed[0].item() + return return_code.item() if __name__ == "__main__": diff --git a/tests/pytorch/distributed/test_comm_gemm_overlap.py b/tests/pytorch/distributed/test_comm_gemm_overlap.py index 240e396534..c285da7fbd 100644 --- a/tests/pytorch/distributed/test_comm_gemm_overlap.py +++ b/tests/pytorch/distributed/test_comm_gemm_overlap.py @@ -21,8 +21,10 @@ BATCH_SIZE: int = 2 NUM_HEADS: int = 12 HEAD_DIM: int = 64 + +# NOTE: te.Linear is intentionally omitted here and manually added later for testing both +# row and column parallel layouts. 
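# Illustrative sketch (not part of this patch): the per-gradient failure aggregation used by the
# reworked numerics check in run_layer_with_overlap.py above. Reducing the flags with MAX means a
# gradient is reported as failed if *any* rank saw a mismatch, and every gradient is compared
# instead of stopping at the first failure. Assumes the process group is already initialized;
# names below are stand-ins.
import torch
import torch.distributed as dist


def aggregate_failures(local_failures: torch.Tensor, group=None) -> int:
    # local_failures: uint8 CUDA tensor with one flag per compared gradient (1 = mismatch).
    return_code = torch.max(local_failures)
    dist.all_reduce(return_code, dist.ReduceOp.MAX, group)
    return int(return_code.item())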
TE_LAYERS = [ - te.Linear, te.LayerNormLinear, te.LayerNormMLP, te.MultiheadAttention, @@ -86,7 +88,7 @@ def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, fp8_in, fp8_out, aggreg raise AssertionError(result.stderr.decode()) -def _run_layer_with_overlap(layer_type, fp8, fp8_init): +def _run_layer_with_overlap(layer_type, linear_parallel_mode, fp8, fp8_init): test_path = TEST_ROOT / "run_layer_with_overlap.py" test_cmd = LAUNCH_CMD + [ str(test_path), @@ -97,6 +99,8 @@ def _run_layer_with_overlap(layer_type, fp8, fp8_init): f"--head-dim={HEAD_DIM}", f"--layer-type={layer_type}", ] + if layer_type == te.Linear.__name__: + test_cmd.append(f"--linear-parallel-mode={linear_parallel_mode}") if fp8: if not fp8_available: @@ -245,9 +249,15 @@ def test_bulk_overlaps(comm_type, fp8, connections): @pytest.mark.parametrize( - "layer_type", - [layer.__name__ for layer in TE_LAYERS], - ids=[(" " + layer.__name__ + " ") for layer in TE_LAYERS], + "layer_type,linear_parallel_mode", + ( + [(te.Linear.__name__, "row"), (te.Linear.__name__, "column")] + + list(zip([layer.__name__ for layer in TE_LAYERS], [None for _ in range(len(TE_LAYERS))])) + ), + ids=( + [f" {te.Linear.__name__} (row-parallel) ", f" {te.Linear.__name__} (column-parallel) "] + + [(" " + layer.__name__ + " ") for layer in TE_LAYERS] + ), ) @pytest.mark.parametrize( "fp8,fp8_init", @@ -262,8 +272,8 @@ def test_bulk_overlaps(comm_type, fp8, connections): " FP8 GEMM - FP8 PARAMS ", ], ) -def test_layers_with_overlap(layer_type, fp8, fp8_init): +def test_layers_with_overlap(layer_type, linear_parallel_mode, fp8, fp8_init): """ Test Transformer Engine layers with comm+GEMM overlap. """ - _run_layer_with_overlap(layer_type, fp8, fp8_init) + _run_layer_with_overlap(layer_type, linear_parallel_mode, fp8, fp8_init) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 3efe116105..3afddcc48d 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -147,6 +147,14 @@ set_source_files_properties(fused_softmax/scaled_masked_softmax.cu fused_softmax/scaled_aligned_causal_masked_softmax.cu PROPERTIES COMPILE_OPTIONS "--use_fast_math") +option(NVTE_BUILD_ACTIVATION_WITH_FAST_MATH "Compile activation kernels with --use_fast_math option" OFF) +if (NVTE_BUILD_ACTIVATION_WITH_FAST_MATH) + set_source_files_properties(activation/gelu.cu + activation/relu.cu + activation/swiglu.cu + PROPERTIES + COMPILE_OPTIONS "--use_fast_math") +endif() set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3") diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 5fd4dd2fc9..5893c4ea3c 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -3,6 +3,8 @@ # See LICENSE for license information. 
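# Illustrative sketch (not part of this patch): how the NVTE_BUILD_ACTIVATION_WITH_FAST_MATH
# switch travels into the build. setup.py (earlier in this diff) turns the environment variable
# into a CMake cache option, and CMakeLists.txt then attaches --use_fast_math only to the
# activation kernel sources. A hypothetical invocation would be
#   NVTE_BUILD_ACTIVATION_WITH_FAST_MATH=1 pip install .
# assuming the variable is visible to the build step.
import os


def activation_fast_math_cmake_flags() -> list:
    # Mirrors the setup.py logic above: unset or "0" keeps fast math off (the default).
    flags = []
    if bool(int(os.getenv("NVTE_BUILD_ACTIVATION_WITH_FAST_MATH", "0"))):
        flags.append("-DNVTE_BUILD_ACTIVATION_WITH_FAST_MATH=ON")
    return flags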
"""Linear API""" +from functools import reduce +from operator import mul as multiply_op from typing import Any, Callable, Dict, Optional, Tuple, Union import torch @@ -43,7 +45,7 @@ fp8_cast_transpose_fused, cast_to_fp8, ) -from ..constants import GemmParallelModes, dist_group_type +from ..constants import GemmParallelModes, dist_group_type, TE_DType from ..jit import no_torch_dynamo from ..graph import is_graph_capturing from ..float8_tensor import Float8Tensor @@ -80,8 +82,12 @@ def forward( activation_dtype: torch.dtype, parallel_mode: Union[str, None], is_grad_enabled: bool, - ub_overlap_rs: bool, - ub_overlap_ag: bool, + ub_overlap_rs_fprop: bool, + ub_overlap_ag_dgrad: bool, + ub_overlap_ag_fprop: bool, + ub_overlap_rs_dgrad: bool, + ub_bulk_dgrad: bool, + ub_bulk_wgrad: bool, ub_name: str, fp8_output: bool, fsdp_group: Union[dist_group_type, None], @@ -99,7 +105,8 @@ def forward( assert_dim_for_fp8_exec(weight) tp_world_size = get_distributed_world_size(tp_group) - ub_overlap_rs = False if tp_world_size == 1 else ub_overlap_rs + ub_overlap_ag_fprop = False if tp_world_size == 1 else ub_overlap_ag_fprop + ub_overlap_rs_fprop = False if tp_world_size == 1 else ub_overlap_rs_fprop # Cast input to expected dtype inputmat = cast_if_needed(inputmat, activation_dtype) @@ -150,10 +157,11 @@ def forward( inputmat_scale_inv.fill_(inputmat_scale_inv.item()) # Column Parallel Linear - if parallel_mode == "column" and sequence_parallel: + if parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop: inputmat_total, _ = gather_along_first_dim(inputmat, tp_group) else: inputmat_total = inputmat + if fp8: bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype bias = cast_if_needed(bias, bias_dtype) if use_bias else bias @@ -165,75 +173,92 @@ def forward( assert isinstance(weight_fp8, Float8Tensor) if fp8_output: - proj_out_index, meta_tensor, proj_out_tetype, proj_out_pttype = ( + out_index, meta_tensor, out_tedtype, out_pttype = ( tex.FP8FwdTensors.GEMM1_OUTPUT, fp8_meta["scaling_fwd"], fp8_dtype_forward, torch.uint8, ) else: - proj_out_index, meta_tensor, proj_out_tetype, proj_out_pttype = ( + out_index, meta_tensor, out_tedtype, out_pttype = ( None, None, None, activation_dtype, ) + ub_obj = None ub_algo = None rs_out = None - if ub_overlap_rs: - ub_obj_projout = get_ub(ub_name + "_fprop") - out = ub_obj_projout.get_ubuf_output(1) + inputmat_data = ( + inputmat_total._data if isinstance(inputmat_total, Float8Tensor) else inputmat_total + ) + if ub_overlap_rs_fprop: + ub_obj = get_ub(ub_name + "_fprop") + out = ub_obj.get_ubuf_output(1) dim_size = list(inputmat_total.size()) dim_size[0] = dim_size[0] // tp_world_size dim_size[1] = out_features rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device) - if ub_obj_projout.is_p2p_overlap(): - if ub_obj_projout.is_atomic_gemm(): + if ub_obj.is_p2p_overlap(): + if ub_obj.is_atomic_gemm(): ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P else: ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P else: - if ub_obj_projout.is_atomic_gemm(): + if ub_obj.is_atomic_gemm(): ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_RS else: ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS - if ub_obj_projout.is_fp8_ubuf(): - proj_out_index = tex.FP8FwdTensors.GEMM1_OUTPUT + if ub_obj.is_fp8_ubuf(): + out_index = tex.FP8FwdTensors.GEMM1_OUTPUT meta_tensor = fp8_meta["scaling_fwd"] - proj_out_tetype = fp8_dtype_forward - proj_out_pttype = torch.uint8 - 
ub_obj_projout.set_ubuf_scale_inv(meta_tensor.scale_inv[proj_out_index]) + out_tedtype = fp8_dtype_forward + out_pttype = torch.uint8 + ub_obj.set_ubuf_scale_inv(meta_tensor.scale_inv[out_index]) + + elif ub_overlap_ag_fprop: + ub_obj = get_ub(ub_name + "_fprop") + assert ub_obj.is_fp8_ubuf(), "AG overlap with FP8 GEMM requires FP8 buffer." + ub_obj.copy_input_to_ubuf(inputmat_data, True) + ub_obj.set_ubuf_scale_inv(inputmat_scale_inv) + if ub_obj.is_atomic_gemm(): + ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P + else: + ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P + out_tedtype = TE_DType[activation_dtype] + out_pttype = activation_dtype + dim_size = list(inputmat_total.size()) + dim_size[0] *= tp_size + dim_size[1] = out_features + out = torch.empty(dim_size, dtype=out_pttype, device=inputmat_total.device) + else: dim_size = list(inputmat_total.size()) dim_size[1] = out_features - out = torch.empty(dim_size, dtype=proj_out_pttype, device=inputmat_total.device) + out = torch.empty(dim_size, dtype=out_pttype, device=inputmat_total.device) _ = fp8_gemm( weight_fp8._data, weight_fp8._scale_inv, 0, weight_fp8._fp8_dtype, - ( - inputmat_total._data - if isinstance(inputmat_total, Float8Tensor) - else inputmat_total - ), + inputmat_data, inputmat_scale_inv, 0, fp8_dtype_forward, - proj_out_pttype, + out_pttype, get_workspace(), bias=bias, use_bias=use_bias, use_split_accumulator=_2X_ACC_FPROP, out=out, - ub_algo=ub_algo if ub_overlap_rs else None, - ub=ub_obj_projout if ub_overlap_rs else None, - extra_output_tensor=rs_out if ub_overlap_rs else None, - out_index=proj_out_index, + ub_algo=ub_algo, + ub=ub_obj, + extra_output_tensor=rs_out, + out_index=out_index, fp8_meta_tensor=meta_tensor, - D_dtype=proj_out_tetype, + D_dtype=out_tedtype, ) if fp8_output: out = Float8Tensor( @@ -261,17 +286,30 @@ def forward( -amin, amax ).float() - if ub_overlap_rs: - ub_obj_projout = get_ub(ub_name + "_fprop") - out = ub_obj_projout.get_ubuf_output(1) + ub_obj = None + ub_algo = None + rs_out = None + if ub_overlap_rs_fprop: + ub_obj = get_ub(ub_name + "_fprop") + out = ub_obj.get_ubuf_output(1) dim_size = list(inputmat_total.size()) - dim_size[0] = dim_size[0] // get_distributed_world_size(tp_group) + dim_size[0] = dim_size[0] // tp_world_size dim_size[1] = out_features rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device) - if ub_obj_projout.is_p2p_overlap(): + if ub_obj.is_p2p_overlap(): ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P else: ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS + + elif ub_overlap_ag_fprop: + ub_obj = get_ub(ub_name + "_fprop") + ub_obj.copy_input_to_ubuf(inputmat_total, True) + dim_size = list(inputmat_total.size()) + dim_size[0] *= tp_size # all-gathered sequence length + dim_size[1] = out_features + out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device) + ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P + else: dim_size = list(inputmat_total.size()) dim_size[1] = out_features @@ -285,9 +323,9 @@ def forward( bias=bias, use_bias=use_bias, out=out, - ub_algo=ub_algo if ub_overlap_rs else None, - ub=ub_obj_projout if ub_overlap_rs else None, - extra_output_tensor=rs_out if ub_overlap_rs else None, + ub_algo=ub_algo, + ub=ub_obj, + extra_output_tensor=rs_out, ) if is_grad_enabled: @@ -343,7 +381,10 @@ def forward( ctx.inp_shape = inp_shape ctx.parallel_mode = parallel_mode ctx.tp_group = tp_group - ctx.ub_overlap_ag = ub_overlap_ag + ctx.ub_overlap_ag = ub_overlap_ag_dgrad + ctx.ub_overlap_rs_dgrad = 
ub_overlap_rs_dgrad + ctx.ub_bulk_dgrad = ub_bulk_dgrad + ctx.ub_bulk_wgrad = ub_bulk_wgrad ctx.ub_name = ub_name ctx.tp_size = tp_size ctx.requires_dgrad = inp.requires_grad @@ -356,12 +397,13 @@ def forward( FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module # Row Parallel Linear - if ub_overlap_rs: - out = rs_out - elif parallel_mode == "row" and sequence_parallel: - out, _ = reduce_scatter_along_first_dim(out, tp_group) - elif parallel_mode == "row" and tensor_parallel: - out, _ = allreduce(out, tp_group) + if parallel_mode == "row": + if ub_overlap_rs_fprop: + out = rs_out + elif sequence_parallel: + out, _ = reduce_scatter_along_first_dim(out, tp_group) + elif tensor_parallel: + out, _ = allreduce(out, tp_group) # [*, in_features] -> [*, out_features] except first dimension changes for SP return out.view(-1, *inp_shape[1:-1], out_features) @@ -401,15 +443,68 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], tp_world_size = get_distributed_world_size(ctx.tp_group) ctx.ub_overlap_ag = False if tp_world_size == 1 else ctx.ub_overlap_ag - ub_algo = None + ctx.ub_overlap_rs_dgrad = False if tp_world_size == 1 else ctx.ub_overlap_rs_dgrad + ctx.ub_bulk_dgrad = False if tp_world_size == 1 else ctx.ub_bulk_dgrad + ctx.ub_bulk_wgrad = False if tp_world_size == 1 else ctx.ub_bulk_wgrad + + ctx.ub_obj_gradout = None + ub_obj_wgrad = None + ub_algo_wgrad = None + ub_algo_dgrad = None + rs_out = None + dgrad = None + dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]] if ctx.ub_overlap_ag: - dim_size = list(grad_output.size()) - dim_size[0] = dim_size[0] * tp_world_size + # Overlap grad_output all-gather with dgrad compute ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad") if ctx.ub_obj_gradout.is_atomic_gemm(): - ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P + ub_algo_dgrad = tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P else: - ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P + ub_algo_dgrad = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P + dgrad = torch.empty( + dgrad_shape, dtype=ctx.activation_dtype, device=grad_output.device + ) + + elif ctx.ub_overlap_rs_dgrad: + # Overlap dgrad reduce-scatter with dgrad compute + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad") + dgrad = ctx.ub_obj_gradout.get_ubuf_output(1) + if ctx.ub_obj_gradout.is_p2p_overlap(): + if ctx.ub_obj_gradout.is_atomic_gemm(): + ub_algo_dgrad = tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P + else: + ub_algo_dgrad = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P + else: + if ctx.ub_obj_gradout.is_atomic_gemm(): + ub_algo_dgrad = tex.CommOverlapAlgo.ATOMIC_GEMM_RS + else: + ub_algo_dgrad = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS + rs_out = torch.empty( + dgrad_shape, dtype=ctx.activation_dtype, device=grad_output.device + ) + ctx.ub_bulk_dgrad = False + ctx.ub_bulk_wgrad = False + + else: + if ctx.ub_bulk_dgrad: + # Overlap inputmat all-gather with dgrad compute + ub_algo_dgrad = tex.CommOverlapAlgo.BULK_OVERLAP_AG + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad") + inputmat_data = ( + inputmat._data if isinstance(inputmat, Float8Tensor) else inputmat + ) + ctx.ub_obj_gradout.copy_input_to_ubuf(inputmat_data, True) + inputmat_ubuf = ctx.ub_obj_gradout.get_ubuf_output(1) + if isinstance(inputmat, Float8Tensor): + inputmat._data = inputmat_ubuf + else: + inputmat = inputmat_ubuf + + if ctx.ub_bulk_wgrad: + # Overlap dgrad reduce-scatter with wgrad compute + ub_algo_wgrad = tex.CommOverlapAlgo.BULK_OVERLAP_RS + ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad") + dgrad = 
ub_obj_wgrad.get_ubuf_output(1) ( grad_output, @@ -420,13 +515,17 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ctx, grad_output, ctx.parallel_mode == "row" ) - # Column Parallel Linear - # Overlap input AG with dgrad + # Overlap inputmat AG with dgrad via NCCL async comms (no TP overlap via Userbuffers) inputmat_total = None inputmat_t_total = None - handle = None - if weight.requires_grad and ctx.parallel_mode == "column" and ctx.sequence_parallel: - inputmat_total, handle = gather_along_first_dim( + inputmat_gather_handle = None + if ( + weight.requires_grad + and ctx.parallel_mode == "column" + and ctx.sequence_parallel + and not ctx.ub_bulk_dgrad + ): + inputmat_total, inputmat_gather_handle = gather_along_first_dim( inputmat, ctx.tp_group, async_op=ctx.requires_dgrad ) else: @@ -444,15 +543,20 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False) + output_dtype = ctx.activation_dtype if ctx.requires_dgrad: if ctx.fp8: - if ctx.is_input_fp8: + if ctx.is_input_fp8 or ( + ctx.ub_overlap_rs_dgrad and ctx.ub_obj_gradout.is_fp8_ubuf() + ): out_index, meta_tensor, output_te_dtype, output_dtype = ( tex.FP8BwdTensors.GRAD_INPUT1, ctx.fp8_meta["scaling_bwd"], fp8_dtype_backward, torch.uint8, ) + if ctx.ub_overlap_rs_dgrad and ctx.ub_obj_gradout.is_fp8_ubuf(): + ctx.ub_obj_gradout.set_ubuf_scale_inv(meta_tensor.scale_inv[out_index]) else: out_index, meta_tensor, output_te_dtype, output_dtype = ( None, @@ -460,7 +564,15 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], None, ctx.activation_dtype, ) - dgrad, _ = fp8_gemm( + + if dgrad is None: + if ctx.parallel_mode == "column" and ctx.sequence_parallel: + dgrad_shape[0] = dgrad_shape[0] * tp_world_size + dgrad = torch.empty(dgrad_shape, dtype=output_dtype, device=grad_output.device) + + if ctx.requires_dgrad: + if ctx.fp8: + _ = fp8_gemm( weight_fp8.transpose_2d(), weight_fp8._scale_inv, 0, @@ -472,13 +584,18 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], output_dtype, get_workspace(), use_split_accumulator=_2X_ACC_DGRAD, - ub_algo=ub_algo if ctx.ub_overlap_ag else None, - ub=ctx.ub_obj_gradout if ctx.ub_overlap_ag else None, + ub_algo=ub_algo_dgrad, + ub=ctx.ub_obj_gradout, + out=dgrad, out_index=out_index, fp8_meta_tensor=meta_tensor, D_dtype=output_te_dtype, + extra_output_tensor=rs_out, ) - if output_dtype == torch.uint8: + + if ctx.ub_overlap_rs_dgrad: + dgrad = rs_out + elif output_dtype == torch.uint8: dgrad = Float8Tensor( data=dgrad, fp8_meta=ctx.fp8_meta, @@ -488,30 +605,34 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], dtype=ctx.activation_dtype, ) else: - dgrad, _, _ = gemm( + _ = gemm( weight, grad_output, ctx.activation_dtype, get_workspace(), layout="NN", grad=True, - ub_algo=( - tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P - if ctx.ub_overlap_ag - else None - ), - ub=ctx.ub_obj_gradout if ctx.ub_overlap_ag else None, + ub_algo=ub_algo_dgrad, + ub=ctx.ub_obj_gradout, + out=dgrad, + extra_output_tensor=rs_out, ) - # Overlap dgrad-RS/AR with wgrad - if ctx.parallel_mode == "column" and ctx.sequence_parallel: - if handle is not None: - handle.wait() - dgrad, handle = reduce_scatter_along_first_dim( + if ctx.ub_overlap_rs_dgrad: + dgrad = rs_out + + if inputmat_gather_handle is not None: + 
inputmat_gather_handle.wait() + + # Overlap dgrad RS/AR with wgrad via NCCL async comms (no TP overlap via Userbuffers) + dgrad_reduce_handle = None + if ctx.requires_dgrad and ctx.parallel_mode == "column": + if ctx.sequence_parallel and not (ctx.ub_overlap_rs_dgrad or ctx.ub_bulk_wgrad): + dgrad, dgrad_reduce_handle = reduce_scatter_along_first_dim( dgrad, ctx.tp_group, async_op=True ) - elif ctx.parallel_mode == "column" and ctx.tensor_parallel: - dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True) + elif ctx.tensor_parallel and not ctx.sequence_parallel: + dgrad, dgrad_reduce_handle = allreduce(dgrad, ctx.tp_group, async_op=True) wgrad = None if weight.requires_grad: @@ -548,6 +669,8 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], accumulate=accumulate_wgrad_into_param_main_grad, out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, use_split_accumulator=_2X_ACC_WGRAD, + ub=ub_obj_wgrad, + ub_algo=ub_algo_wgrad, ) else: wgrad, _, _ = gemm( @@ -559,6 +682,8 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], grad=True, accumulate=accumulate_wgrad_into_param_main_grad, out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, + ub=ub_obj_wgrad, + ub_algo=ub_algo_wgrad, ) else: # WGRAD @@ -572,15 +697,20 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], use_bias=ctx.use_bias, accumulate=accumulate_wgrad_into_param_main_grad, out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, + ub=ub_obj_wgrad, + ub_algo=ub_algo_wgrad, ) + if ctx.ub_bulk_wgrad: + dgrad = ub_obj_wgrad.get_ubuf_output(0) + # Deallocate input tensor clear_tensor_data(inputmat_total) clear_tensor_data(inputmat_t_total) - # Column Parallel Linear - if ctx.parallel_mode == "column" and ctx.tensor_parallel and handle is not None: - handle.wait() + # Wait for dgrad reduce-scatter or all-reduce + if dgrad_reduce_handle is not None: + dgrad_reduce_handle.wait() if not ctx.use_bias: grad_bias = None @@ -634,8 +764,12 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], None, # activation_dtype None, # parallel_mode None, # is_grad_enabled - None, # ub_overlap_rs - None, # ub_overlap_ag + None, # ub_overlap_rs_fprop + None, # ub_overlap_ag_dgrad + None, # ub_overlap_ag_fprop + None, # ub_overlap_rs_dgrad + None, # ub_bulk_dgrad + None, # ub_bulk_wgrad None, # ub_name None, # fp8_output None, # fsdp_group @@ -729,8 +863,10 @@ def __init__( parallel_mode: Optional[str] = None, parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None, device: Union[torch.device, str] = "cuda", - ub_overlap_rs: bool = False, ub_overlap_ag: bool = False, + ub_overlap_rs: bool = False, + ub_bulk_dgrad: bool = False, + ub_bulk_wgrad: bool = False, ub_name: Optional[str] = None, ) -> None: super().__init__() @@ -742,13 +878,6 @@ def __init__( self.use_bias = bias self.return_bias = return_bias self.apply_bias = bias and not return_bias - self.ub_overlap_rs = ub_overlap_rs - self.ub_overlap_ag = ub_overlap_ag - if ub_overlap_rs or ub_overlap_ag: - assert ub_name is not None, "Userbuffer name [string] is not set." - self.ub_name = ub_name - self.get_rng_state_tracker = get_rng_state_tracker - self.rng_tracker_name = rng_tracker_name if device == "meta": assert parameters_split is None, "Cannot split module parameters on 'meta' device." 
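# Illustrative sketch (not part of this patch): a simplified view of how _Linear.backward() above
# selects a Userbuffers overlap for the DGRAD and WGRAD GEMMs. The enum members stand in for
# tex.CommOverlapAlgo values; the real code also distinguishes p2p vs. pipelined and atomic-GEMM
# variants and handles the FP8 buffer scale.
from enum import Enum, auto


class OverlapKind(Enum):
    AG_GRAD_OUTPUT = auto()  # all-gather grad_output overlapped with the DGRAD GEMM (row parallel)
    RS_DGRAD = auto()        # reduce-scatter of dgrad overlapped with the DGRAD GEMM
    BULK_AG_INPUT = auto()   # bulk all-gather of inputmat overlapped with the DGRAD GEMM
    BULK_RS_DGRAD = auto()   # bulk reduce-scatter of dgrad overlapped with the WGRAD GEMM
    NONE = auto()


def pick_backward_overlaps(ub_overlap_ag, ub_overlap_rs_dgrad, ub_bulk_dgrad, ub_bulk_wgrad):
    # Mirrors the precedence above: AG overlap first, then RS+DGRAD (which disables the bulk
    # overlaps), otherwise the bulk DGRAD/WGRAD pair.
    if ub_overlap_ag:
        return OverlapKind.AG_GRAD_OUTPUT, OverlapKind.NONE
    if ub_overlap_rs_dgrad:
        return OverlapKind.RS_DGRAD, OverlapKind.NONE
    dgrad_overlap = OverlapKind.BULK_AG_INPUT if ub_bulk_dgrad else OverlapKind.NONE
    wgrad_overlap = OverlapKind.BULK_RS_DGRAD if ub_bulk_wgrad else OverlapKind.NONE
    return dgrad_overlap, wgrad_overlap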
@@ -773,6 +902,45 @@ def __init__( self.sequence_parallel = (self.tp_size > 1) and sequence_parallel + # Column parallel TP overlap options + self.ub_overlap_ag_fprop = parallel_mode == "column" and sequence_parallel and ub_overlap_ag + self.ub_overlap_rs_dgrad = parallel_mode == "column" and sequence_parallel and ub_overlap_rs + self.ub_bulk_dgrad = parallel_mode == "column" and sequence_parallel and ub_bulk_dgrad + self.ub_bulk_wgrad = parallel_mode == "column" and sequence_parallel and ub_bulk_wgrad + if self.ub_overlap_rs_dgrad: + self.ub_bulk_dgrad = False + self.ub_bulk_wgrad = False + + # Row parallel TP overlap options + self.ub_overlap_rs_fprop = parallel_mode == "row" and sequence_parallel and ub_overlap_rs + self.ub_overlap_ag_dgrad = parallel_mode == "row" and sequence_parallel and ub_overlap_ag + + if any( + [ + self.ub_overlap_rs_fprop, + self.ub_overlap_ag_dgrad, + self.ub_overlap_ag_fprop, + self.ub_overlap_rs_dgrad, + self.ub_bulk_dgrad, + self.ub_bulk_wgrad, + ] + ): + assert ub_name is not None, f"Comm+GEMM overlap layer '{ub_name}' is not initialized." + self.ub_name = ub_name + + assert not ( + self.ub_overlap_rs_fprop and self.ub_overlap_ag_fprop + ), "Cannot enable AG+GEMM and GEMM+RS overlaps at the same time." + assert not ( + self.ub_overlap_rs_dgrad and self.ub_bulk_dgrad + ), "Cannot enable DGRAD+RS and bulk DGRAD overlaps at the same time." + assert not ( + self.ub_overlap_ag_dgrad and (self.ub_overlap_rs_dgrad or self.ub_bulk_dgrad) + ), "Cannot enable AG+DGRAD and DGRAD+RS or bulk DGRAD overlaps at the same time." + + self.get_rng_state_tracker = get_rng_state_tracker + self.rng_tracker_name = rng_tracker_name + # Initialize params in FP8 with_fp8_params = FP8GlobalStateManager.with_fp8_parameters() @@ -1017,8 +1185,12 @@ def forward( self.activation_dtype, self.parallel_mode, torch.is_grad_enabled(), - self.ub_overlap_rs, - self.ub_overlap_ag, + self.ub_overlap_rs_fprop, + self.ub_overlap_ag_dgrad, + self.ub_overlap_ag_fprop, + self.ub_overlap_rs_dgrad, + self.ub_bulk_dgrad, + self.ub_bulk_wgrad, self.ub_name, fp8_output, self.fsdp_group,