Merge branch 'NVIDIA:main' into tp_rs_bf16
erhoo82 authored Nov 6, 2024
2 parents c4dcd95 + 095b27d commit aad8899
Showing 40 changed files with 5,095 additions and 1,184 deletions.
1 change: 1 addition & 0 deletions qa/L1_pytorch_distributed_unittest/test.sh
@@ -10,4 +10,5 @@ pip install pytest==8.2.1
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py
pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py
44 changes: 26 additions & 18 deletions tests/jax/test_distributed_fused_attn.py
@@ -133,7 +133,6 @@ def test_self_attn(
seqlen,
hidden,
None, # no window
False, # not context parallel
):
pytest.skip(f"No FusedAttn backend found")

@@ -268,7 +267,6 @@ def test_cross_attn(
seqlen,
hidden,
None, # no window
False, # not context parallel
):
pytest.skip(f"No FusedAttn backend found")

@@ -425,22 +423,32 @@ def test_contex_parallel_self_attn(
num_kv_heads = num_head // kv_groups
scaling_factor = 1.0 / np.sqrt(num_head)

if not is_fused_attn_kernel_available(
dtype,
dtype,
qkv_layout,
attn_bias_type,
attn_mask_type,
dropout_prob,
num_head,
num_kv_heads,
seqlen,
seqlen,
hidden,
None, # no window
cp_size > 1,
):
pytest.skip(f"No FusedAttn backend found")
def check_has_backend_for_mask(mask_type):
return is_fused_attn_kernel_available(
dtype,
dtype,
qkv_layout,
attn_bias_type,
mask_type,
dropout_prob,
num_head,
num_kv_heads,
seqlen,
seqlen,
hidden,
None,
) # no SWA for CP

# For causal masking we depend on having bottom right support also.
# The API does not check this and instead we rely on lower level checks to raise
# an exception if the step backend is not supported. This was a deliberate API
# decision to keep the CP size or flag out of the function.
has_backend = check_has_backend_for_mask(attn_mask_type)
if cp_size > 1 and attn_mask_type == AttnMaskType.CAUSAL_MASK:
has_backend &= check_has_backend_for_mask(AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK)

if not has_backend:
pytest.skip(f"No FusedAttn backend found {cp_size=} {attn_mask_type=}.")

if dp_size > 1 and batch % dp_size != 0:
pytest.skip(f"Skipping {batch=} not a multiple of {dp_size=}")
