[FEAT] Update MLA to use triton attention and include MI300X MoE config file #479

Merged
45 changes: 34 additions & 11 deletions vllm/attention/backends/mla/utils.py
@@ -14,6 +14,7 @@
AttentionMetadata,
MLAAttentionImpl, T)
from vllm.attention.backends.utils import get_flash_attn_version
from vllm.attention.ops.triton_flash_attention import triton_attention
from vllm.distributed import (get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -163,6 +164,11 @@ def __init__(
kv_b_proj: ColumnParallelLinear,
o_proj: RowParallelLinear,
) -> None:
if envs.VLLM_USE_TRITON_FLASH_ATTN:
assert not alibi_slopes, "Triton MLA doesn't support alibi now!"
assert not envs.VLLM_USE_ROCM_FP8_FLASH_ATTN, (
"Triton MLA doesn't support FP8 flash attention now!")

self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
@@ -498,17 +504,34 @@ def _forward_prefill_flash(
v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
value=0)

attn_output = self.flash_attn_varlen_func(
q=q,
k=k,
v=v_padded,
cu_seqlens_q=seq_start_loc,
cu_seqlens_k=seq_start_loc,
max_seqlen_q=max_prefill_seq_len,
max_seqlen_k=max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
)
if envs.VLLM_USE_TRITON_FLASH_ATTN:
attn_output, _ = triton_attention(
q,
k,
v_padded,
None,
seq_start_loc,
seq_start_loc,
max_prefill_seq_len,
max_prefill_seq_len,
True,
self.scale,
None, # attn_mask is None unless applying ALiBi mask
None, # fp8 scales need additional work to integrate
)
else:
attn_output = self.flash_attn_varlen_func(
q=q,
k=k,
v=v_padded,
cu_seqlens_q=seq_start_loc,
cu_seqlens_k=seq_start_loc,
max_seqlen_q=max_prefill_seq_len,
max_seqlen_k=max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
)

attn_output = attn_output\
.view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\
.reshape(-1, self.num_heads * v.shape[-1])
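The new prefill path is gated on `VLLM_USE_TRITON_FLASH_ATTN` and, per the added `__init__` checks, rejects ALiBi slopes and ROCm FP8 attention. Both branches see the same padded inputs: because MLA's value head dimension is smaller than its query/key head dimension, V is zero-padded up to `q.shape[-1]` before the kernel and the padding is sliced off afterwards (the Triton branch also discards the kernel's secondary return value). A minimal, self-contained sketch of that pad-then-slice step, with shapes chosen for illustration rather than taken from the PR:

```python
import torch

# Hypothetical MLA-style shapes: V heads are narrower than Q/K heads.
num_tokens, num_heads = 8, 16
qk_head_dim, v_head_dim = 192, 128

q = torch.randn(num_tokens, num_heads, qk_head_dim)
v = torch.randn(num_tokens, num_heads, v_head_dim)

# Pad V's last dim so the attention kernel sees equal head sizes for Q/K/V.
v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], value=0)
assert v_padded.shape[-1] == qk_head_dim

# The kernel output has qk_head_dim per head; drop the zero columns and
# flatten heads, exactly as the diff does after either attention call.
attn_output = torch.randn(num_tokens, num_heads, qk_head_dim)  # stand-in for kernel output
attn_output = (attn_output
               .view(-1, num_heads, q.shape[-1])[..., :v.shape[-1]]
               .reshape(-1, num_heads * v.shape[-1]))
assert attn_output.shape == (num_tokens, num_heads * v_head_dim)
```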
14 changes: 12 additions & 2 deletions vllm/attention/ops/triton_decode_attention.py
@@ -179,6 +179,8 @@ def _decode_att_m_fwd(
logit_cap,
):
BLOCK = 64
if is_hip_:
BLOCK = 8
NUM_KV_SPLITS = num_kv_splits
Lk = k_buffer.shape[-1]
Lv = v_buffer.shape[-1]
@@ -189,6 +191,12 @@
kv_group_num = q.shape[1] // k_buffer.shape[-2]

num_warps = 4 if kv_group_num == 1 else 2
if kv_group_num == 1:
num_warps = 4
else:
num_warps = 2
if is_hip_:
num_warps = 1

BLOCK_DMODEL = triton.next_power_of_2(Lk)
BLOCK_DV = triton.next_power_of_2(Lv)
@@ -417,15 +425,17 @@ def _decode_grouped_att_m_fwd(
NUM_KV_SPLITS,
)

num_stages = 2
extra_kargs = {}
if is_hip_:
# https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
# https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
extra_kargs = {
"waves_per_eu": 4,
"waves_per_eu": 1,
"matrix_instr_nonkdim": 16,
"kpack": 2
}
num_stages = 1

_fwd_grouped_kernel_stage1[grid](
q,
Expand Down Expand Up @@ -456,7 +466,7 @@ def _decode_grouped_att_m_fwd(
PAGE_SIZE=page_size,
logit_cap=logit_cap,
num_warps=4,
num_stages=2,
num_stages=num_stages,
Lk=Lk,
Lv=Lv,
**extra_kargs,
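The decode-attention changes are ROCm-specific tuning: `_decode_att_m_fwd` drops BLOCK from 64 to 8 and num_warps to 1 on HIP, while `_decode_grouped_att_m_fwd` lowers waves_per_eu from 4 to 1 and num_stages from 2 to 1. A small sketch of this platform-gated parameter selection, collapsing both kernels' settings into one helper for illustration; the HIP check via `torch.version.hip` is an assumption standing in for the module's own `is_hip_` flag:

```python
import torch

def running_on_hip() -> bool:
    # torch exposes a HIP version string only on ROCm builds.
    return torch.version.hip is not None

def decode_kernel_params(kv_group_num: int):
    """Illustrative launch parameters mirroring the values in the diff."""
    block = 64
    num_warps = 4 if kv_group_num == 1 else 2
    num_stages = 2
    extra_kargs = {}
    if running_on_hip():
        # MI300-class GPUs: smaller tiles, a single warp, one pipeline stage,
        # and one wave per EU for these memory-bound decode kernels.
        block = 8
        num_warps = 1
        num_stages = 1
        extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2}
    return block, num_warps, num_stages, extra_kargs
```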
128 changes: 128 additions & 0 deletions (new MI300X fused MoE config file, JSON; path not shown in this view)
@@ -0,0 +1,128 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"2": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 4,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 4,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 4,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 4,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 4,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"512": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
}
}
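This new JSON is the MI300X tuning table for the fused MoE kernel: keys are token batch sizes (M) and values are the Triton tile and launch parameters to use at that size. A rough sketch of how such a table is consumed, using a nearest-key lookup over a two-entry subset copied from above; the lookup is illustrative, not vLLM's actual loader:

```python
# Two entries copied from the JSON above; the lookup picks the key closest to
# the runtime batch size M.
CONFIGS = {
    "1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 256,
          "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 2, "waves_per_eu": 0},
    "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128,
           "GROUP_SIZE_M": 1, "num_warps": 2, "num_stages": 2, "waves_per_eu": 0},
}

def select_config(m: int) -> dict:
    best = min(CONFIGS, key=lambda k: abs(int(k) - m))
    return CONFIGS[best]

print(select_config(48))  # with only two entries, 48 maps to the "64" config
```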
53 changes: 27 additions & 26 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -45,19 +45,19 @@ def fused_moe_kernel_gptq_awq(
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
# (A has M rows).
stride_am,
stride_ak,
stride_be,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
stride_bse,
stride_bsk,
stride_bsn,
stride_bze,
stride_bzk,
stride_bzn,
stride_am: tl.int64,
stride_ak: tl.int64,
stride_be: tl.int64,
stride_bk: tl.int64,
stride_bn: tl.int64,
stride_cm: tl.int64,
stride_cn: tl.int64,
stride_bse: tl.int64,
stride_bsk: tl.int64,
stride_bsn: tl.int64,
stride_bze: tl.int64,
stride_bzk: tl.int64,
stride_bzn: tl.int64,
block_k_diviable: tl.constexpr,
group_size: tl.constexpr,
# Meta-parameters
@@ -245,18 +245,18 @@ def fused_moe_kernel(
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
# (A has M rows).
stride_am,
stride_ak,
stride_be,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
stride_asm,
stride_ask,
stride_bse,
stride_bsk,
stride_bsn,
stride_am: tl.int64,
stride_ak: tl.int64,
stride_be: tl.int64,
stride_bk: tl.int64,
stride_bn: tl.int64,
stride_cm: tl.int64,
stride_cn: tl.int64,
stride_asm: tl.int64,
stride_ask: tl.int64,
stride_bse: tl.int64,
stride_bsk: tl.int64,
stride_bsn: tl.int64,
# Block size for block-wise quantization
group_n: tl.constexpr,
group_k: tl.constexpr,
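Annotating the stride arguments as `tl.int64` forces 64-bit pointer arithmetic inside the kernels; without it, Triton may specialize these scalars to 32 bits, and offsets such as the per-expert or per-row base addresses can overflow once the weight tensors are large enough. A tiny standalone kernel (not from the PR) showing the same annotation style; it assumes a GPU-enabled Triton install:

```python
import torch
import triton
import triton.language as tl

@triton.jit
def scale_rows_kernel(x_ptr, out_ptr,
                      stride_m: tl.int64, stride_n: tl.int64,
                      N, BLOCK_N: tl.constexpr):
    # 64-bit strides keep `pid * stride_m` from wrapping on very large tensors.
    pid = tl.program_id(0)
    cols = tl.arange(0, BLOCK_N)
    mask = cols < N
    row = tl.load(x_ptr + pid * stride_m + cols * stride_n, mask=mask)
    tl.store(out_ptr + pid * stride_m + cols * stride_n, row * 2.0, mask=mask)

x = torch.randn(4, 8, device="cuda")
out = torch.empty_like(x)
scale_rows_kernel[(x.shape[0], )](x, out, x.stride(0), x.stride(1),
                                  x.shape[1], BLOCK_N=8)
```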
@@ -1220,7 +1220,8 @@ def fused_experts_impl(hidden_states: torch.Tensor,
# so the cache size and config are already set correctly and
# do not need to be adjusted.
intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
intermediate_cache2 = intermediate_cache2[:tokens_in_chunk]
intermediate_cache2 = intermediate_cache2[:tokens_in_chunk *
topk_ids.shape[1]]
intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
config = get_config_func(tokens_in_chunk)

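The last hunk fixes per-chunk slicing of `intermediate_cache2`: unlike caches 1 and 3, which are indexed per token, cache 2 is flattened over tokens times top-k experts, so the chunk slice must scale by `topk_ids.shape[1]`. A shape-only sketch of why, with sizes made up for illustration:

```python
import torch

num_tokens, top_k, hidden, inter = 10, 2, 16, 32
tokens_in_chunk = 4

intermediate_cache1 = torch.empty(num_tokens, top_k, inter)
intermediate_cache2 = torch.empty(num_tokens * top_k, inter // 2)  # flattened over tokens * top_k
intermediate_cache3 = torch.empty(num_tokens, top_k, hidden)

# Per-chunk views: caches 1 and 3 slice by tokens, cache 2 by tokens * top_k.
c1 = intermediate_cache1[:tokens_in_chunk]
c2 = intermediate_cache2[:tokens_in_chunk * top_k]
c3 = intermediate_cache3[:tokens_in_chunk]
assert c2.shape[0] == c1.shape[0] * c1.shape[1]
```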