818 changes: 818 additions & 0 deletions benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton_sep.py

Large diffs are not rendered by default.

14 changes: 12 additions & 2 deletions python/sglang/srt/layers/attention/flashattention_backend.py
@@ -855,14 +855,24 @@ def forward_extend(
)
else:
# MHA for extend part of sequence without attending prefix kv cache
cu_seqlens_k = (
metadata.cu_seqlens_q
if not forward_batch.mha_one_shot
else metadata.cu_seqlens_k
)
max_seqlen_k = (
metadata.max_seq_len_q
if not forward_batch.mha_one_shot
else metadata.max_seq_len_k
)
output = flash_attn_varlen_func(
q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
cu_seqlens_q=metadata.cu_seqlens_q,
cu_seqlens_k=metadata.cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=metadata.max_seq_len_q,
max_seqlen_k=metadata.max_seq_len_q,
max_seqlen_k=max_seqlen_k,
softmax_scale=layer.scaling,
causal=True,
return_softmax_lse=forward_batch.mha_return_lse,
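For context, a minimal sketch of what the selection above changes, under the assumption that `mha_one_shot` means the extend tokens attend to the whole cached sequence (prefix + extend) in a single varlen call. The query-side cumulative lengths always come from the extend lengths; only the key side switches between mirroring the query side and covering the full sequences. Tensor names and sizes below are illustrative, not from the PR.

import torch
import torch.nn.functional as F

# Hypothetical batch: two requests with prefix lengths [4, 2] and extend lengths [3, 5].
extend_lens = torch.tensor([3, 5], dtype=torch.int32)
seq_lens = torch.tensor([4 + 3, 2 + 5], dtype=torch.int32)  # prefix + extend

# The query side is always built from the extend lengths.
cu_seqlens_q = F.pad(torch.cumsum(extend_lens, dim=0, dtype=torch.int32), (1, 0))

mha_one_shot = True
if mha_one_shot:
    # One-shot: keys cover prefix + extend, so use the full sequence lengths.
    cu_seqlens_k = F.pad(torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0))
    max_seqlen_k = int(seq_lens.max())
else:
    # Extend-only: keys are exactly the new tokens, mirroring the query side.
    cu_seqlens_k = cu_seqlens_q
    max_seqlen_k = int(extend_lens.max())

print(cu_seqlens_q.tolist())  # [0, 3, 8]
print(cu_seqlens_k.tolist())  # [0, 7, 14] in the one-shot case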
28 changes: 18 additions & 10 deletions python/sglang/srt/layers/attention/flashinfer_mla_backend.py
Expand Up @@ -80,6 +80,7 @@ def __init__(

# Buffers and wrappers
self.qo_indptr = attn_backend.qo_indptr
self.kv_indptr = attn_backend.kv_indptr
self.workspace_buffer = attn_backend.workspace_buffer
self.fmha_backend = attn_backend.fmha_backend

@@ -130,9 +131,14 @@ def update_wrapper(
)
# ragged prefill
if not disable_flashinfer_ragged:
kv_indptr = (
qo_indptr
if not forward_batch.mha_one_shot
else self.kv_indptr[: bs + 1]
)
self.ragged_wrapper.begin_forward(
qo_indptr=qo_indptr,
kv_indptr=qo_indptr,
kv_indptr=kv_indptr,
num_qo_heads=self.num_local_heads,
num_kv_heads=self.num_local_heads,
head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
@@ -154,7 +160,7 @@ def forward(
chunk_idx = forward_batch.prefix_chunk_idx
assert chunk_idx >= 0
wrapper = self.chunk_ragged_wrappers[chunk_idx]
o1, s1 = wrapper.forward_return_lse(
o = wrapper.forward_return_lse(
q.view(-1, layer.tp_q_head_num, layer.head_dim),
k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
v.view(-1, layer.tp_v_head_num, layer.v_head_dim).to(q.dtype),
Expand All @@ -163,16 +169,20 @@ def forward(
logits_soft_cap=logits_soft_cap,
)
else:
o1, s1 = self.ragged_wrapper.forward_return_lse(
forward = (
self.ragged_wrapper.forward_return_lse
if forward_batch.mha_return_lse
else self.ragged_wrapper.forward
)
o = forward(
q.view(-1, layer.tp_q_head_num, layer.head_dim),
k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
v.view(-1, layer.tp_v_head_num, layer.v_head_dim).to(q.dtype),
causal=True,
sm_scale=layer.scaling,
logits_soft_cap=logits_soft_cap,
)

return o1, s1
return o


class FlashInferMLAAttnBackend(AttentionBackend):
@@ -510,15 +520,13 @@ def forward_extend(
q_rope: Optional[torch.Tensor] = None,
k_rope: Optional[torch.Tensor] = None,
):
if (
forward_batch.attn_attend_prefix_cache is not None
and forward_batch.mha_return_lse
if forward_batch.attn_attend_prefix_cache is not None and any(
forward_batch.extend_prefix_lens_cpu
): # MHA Chunk
assert self.enable_chunk_kv
assert q_rope is None
assert k_rope is None
o1, s1 = self.mha_chunk_kv_cache.forward(q, k, v, layer, forward_batch)
return o1, s1
return self.mha_chunk_kv_cache.forward(q, k, v, layer, forward_batch)

cache_loc = forward_batch.out_cache_loc
logits_soft_cap = layer.logit_cap
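As a side note on why `forward_return_lse` is kept on the chunked path: returning the log-sum-exp lets the caller merge attention computed over the prefix chunks with attention computed over the extend tokens, while the one-shot path attends to everything at once and can skip the merge. A minimal sketch of the standard LSE merge (not the PR's code, which relies on the backend's own merge kernel):

import torch

def merge_attention_states(o1, lse1, o2, lse2):
    # o1, o2:     [num_tokens, num_heads, head_dim] partial attention outputs
    # lse1, lse2: [num_tokens, num_heads] log-sum-exp of the corresponding scores
    lse = torch.logaddexp(lse1, lse2)          # normalizer over the union of both key sets
    w1 = torch.exp(lse1 - lse).unsqueeze(-1)   # weight of the first partial result
    w2 = torch.exp(lse2 - lse).unsqueeze(-1)
    return o1 * w1 + o2 * w2, lse

For softmax attention, merging the (output, lse) pairs of two disjoint key sets this way is numerically equivalent to a single pass over their concatenation, which is what the one-shot path computes directly.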
78 changes: 78 additions & 0 deletions python/sglang/srt/layers/attention/utils.py
@@ -1,3 +1,4 @@
import torch
import triton
import triton.language as tl

@@ -97,3 +98,80 @@ def create_flashmla_kv_indices_triton(
data // PAGED_SIZE,
mask=mask_out,
)


@triton.jit
def concat_and_cast_mha_k_kernel(
k_ptr,
k_nope_ptr,
k_rope_ptr,
head_cnt: tl.constexpr,
k_stride0: tl.constexpr,
k_stride1: tl.constexpr,
nope_stride0: tl.constexpr,
nope_stride1: tl.constexpr,
rope_stride0: tl.constexpr,
nope_dim: tl.constexpr,
rope_dim: tl.constexpr,
):
pid_loc = tl.program_id(0)
head_range = tl.arange(0, head_cnt)

k_head_ptr = k_ptr + pid_loc * k_stride0 + head_range[:, None] * k_stride1

nope_offs = tl.arange(0, nope_dim)

src_nope_ptr = (
k_nope_ptr
+ pid_loc * nope_stride0
+ head_range[:, None] * nope_stride1
+ nope_offs[None, :]
)
dst_nope_ptr = k_head_ptr + nope_offs[None, :]

src_nope = tl.load(src_nope_ptr)
tl.store(dst_nope_ptr, src_nope)

rope_offs = tl.arange(0, rope_dim)
src_rope_ptr = k_rope_ptr + pid_loc * rope_stride0 + rope_offs[None, :]
dst_rope_ptr = k_head_ptr + nope_dim + rope_offs[None, :]
src_rope = tl.load(src_rope_ptr)
tl.store(dst_rope_ptr, src_rope)


def concat_and_cast_mha_k_triton(
k: torch.Tensor,
k_nope: torch.Tensor,
k_rope: torch.Tensor,
):
# The source data type will be implicitly converted to the target data type.
assert (
len(k.shape) == 3 and len(k_nope.shape) == 3 and len(k_rope.shape) == 3
), f"shape should be 3d, but got {k.shape=}, {k_nope.shape=}, {k_rope.shape=}"
assert (
k.shape[0] == k_nope.shape[0] and k.shape[0] == k_rope.shape[0]
), f"invalid shape, got {k.shape=}, {k_nope.shape=}, {k_rope.shape=}"
assert (
k.shape[1] == k_nope.shape[1] and 1 == k_rope.shape[1]
), f"invalid shape, got {k.shape=}, {k_nope.shape=}, {k_rope.shape=}"
assert (
k.shape[-1] == k_nope.shape[-1] + k_rope.shape[-1]
), f"invalid shape, got {k.shape=}, {k_nope.shape=}, {k_rope.shape=}"

nope_dim = k_nope.shape[-1]
rope_dim = k_rope.shape[-1]
grid = (k.shape[0],)

concat_and_cast_mha_k_kernel[grid](
k,
k_nope,
k_rope,
k.shape[1],
k.stride(0),
k.stride(1),
k_nope.stride(0),
k_nope.stride(1),
k_rope.stride(0),
nope_dim,
rope_dim,
)
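A usage sketch for the helper above, following the shapes its assertions require (sizes and tensor names are illustrative): k_nope is per-head, k_rope has a head dimension of 1 and is shared across heads, and the result is written into k in k's dtype.

import torch

from sglang.srt.layers.attention.utils import concat_and_cast_mha_k_triton

# Illustrative sizes: 8 tokens, 16 heads, 128-d nope part, 64-d rope part.
num_tokens, num_heads, nope_dim, rope_dim = 8, 16, 128, 64

k_nope = torch.randn(num_tokens, num_heads, nope_dim, device="cuda", dtype=torch.float32)
k_rope = torch.randn(num_tokens, 1, rope_dim, device="cuda", dtype=torch.float32)
k = torch.empty(num_tokens, num_heads, nope_dim + rope_dim, device="cuda", dtype=torch.bfloat16)

concat_and_cast_mha_k_triton(k, k_nope, k_rope)

# Plain-PyTorch reference: broadcast the shared rope part over the heads,
# concatenate along the last dim, and cast to k's dtype.
ref = torch.cat([k_nope, k_rope.expand(-1, num_heads, -1)], dim=-1).to(k.dtype)
torch.testing.assert_close(k, ref)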
@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 5
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 3
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
}
}
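The keys in the config above are token counts (M) and the values are Triton launch parameters for the fused MoE kernel. A minimal sketch of how such a table might be consumed; the nearest-key rule and the file path are assumptions, not taken from sglang's actual loader.

import json

def load_moe_config(path: str) -> dict[int, dict]:
    # JSON keys are strings; convert them back to integer token counts.
    with open(path) as f:
        return {int(m): cfg for m, cfg in json.load(f).items()}

def pick_config(configs: dict[int, dict], num_tokens: int) -> dict:
    # Assumed lookup rule: take the tuned entry whose token count is closest
    # to the actual number of tokens in the batch.
    best_m = min(configs, key=lambda m: abs(m - num_tokens))
    return configs[best_m]

# Example (hypothetical path): a batch of 300 tokens picks the entry tuned for M=256.
# cfg = pick_config(load_moe_config("fused_moe_tuning_config.json"), 300)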