
Question about fp8 version of context_flashattention_nopad.py #479

@changyuanzhangchina

Description

context_flashattention_nopad_fp16_fp8.txt

We have implemented an fp8 version of context_flashattention_nopad.py (attached above). Following the fused-attention tutorial at https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html, the V layout needs to be changed for fp8 to get good performance. However, the current result is not correct. Could you help us figure out what is wrong?
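
For context, the layout change we apply to V (in context_attention_fwd_fp8 below) keeps the shape [num_tokens, num_heads, head_size] but makes the token dimension the stride-1 one, following the tutorial. A minimal sketch of the stride change, with made-up sizes:

import torch

# Hypothetical sizes, only to illustrate the stride change.
num_tokens, num_heads, head_size = 6, 2, 4
v = torch.randn(num_tokens, num_heads, head_size, dtype=torch.float16)
print(v.shape, v.stride())      # torch.Size([6, 2, 4]) (8, 4, 1)

# The same permute/contiguous/permute round trip as in the host code below.
v_t = v.permute(2, 1, 0).contiguous().permute(2, 1, 0)
print(v_t.shape, v_t.stride())  # torch.Size([6, 2, 4]) (1, 6, 12)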

import torch
import triton
import triton.language as tl
# current_platform (used below for the device-capability check) is provided by
# the surrounding project; its import is omitted here.


@triton.jit
def _fwd_kernel_fp8(
    Q,
    K,
    V,
    B_Loc,
    sm_scale,
    B_Start_Loc,
    B_Seqlen,
    B_Ctxlen,
    Out,
    stride_b_loc_b,
    stride_b_loc_s,
    stride_qbs,
    stride_qh,
    stride_qd,
    stride_kbs,
    stride_kh,
    stride_kd,
    stride_vbs,
    stride_vh,
    stride_vd,
    stride_obs,
    stride_oh,
    stride_od,
    num_queries_per_kv: int,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,  # head size
    BLOCK_DMODEL_PADDED: tl.constexpr,  # head size padded to a power of 2
    BLOCK_N: tl.constexpr,
    SLIDING_WINDOW: tl.constexpr,
):
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)
    start_m = tl.program_id(2)

    cur_kv_head = cur_head // num_queries_per_kv

    cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)  # seq len of the current batch
    cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)  # start index of the current batch
    cur_batch_query_len = cur_batch_seq_len - cur_batch_ctx_len

    # start position inside of the query
    # generally, N goes over kv, while M goes over query_len
    block_start_loc = BLOCK_M * start_m

    # initialize offsets
    # [N]; starts at 0
    offs_n = tl.arange(0, BLOCK_N)
    # [D]; starts at 0
    offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)
    # [M]; starts at current position in query
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # [M,D]
    off_q = (
        (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
        cur_head * stride_qh + offs_d[None, :] * stride_qd)

    dim_mask = tl.where(
        offs_d < BLOCK_DMODEL, 1,
        0).to(tl.int1)  # [D]

    # ??? should the load mask also include dim_mask[None, :] here?
    q = tl.load(Q + off_q,
                mask=(offs_m[:, None] < cur_batch_query_len),
                other=0.0)  # [M,D]

    # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")  # [M]
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)  # [M]
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED],
                   dtype=tl.float32)  # [M,D]

    # whether V is fp8
    v_fp8 = V.dtype.element_ty == tl.float8e5

    off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
             offs_d[:, None] * stride_kd)

    # for the V layout, see
    # https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html
    if v_fp8:
        off_v = (offs_n[None, :] * stride_vbs + cur_kv_head * stride_vh +
                 offs_d[:, None] * stride_vd)
    else:
        off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
                 offs_d[None, :] * stride_vd)

    k_ptrs = K + off_k
    v_ptrs = V + off_v

    # block_mask is 0 when we're already past the current query length
    block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0)
    block_end_loc = tl.minimum((start_m + 1) * BLOCK_M, cur_batch_seq_len)

    # compute query against itself (with causal mask)
    for start_n in range(0, block_mask * block_end_loc, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        k = tl.load(k_ptrs +
                    (cur_batch_in_all_start_index + start_n) * stride_kbs,
                    mask=((start_n + offs_n[None, :]) < block_end_loc),
                    other=0.0)

        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)
        qk *= sm_scale
        # apply causal mask
        qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
                      float("-inf"))
        if SLIDING_WINDOW > 0:
            qk = tl.where(
                offs_m[:, None] -
                (start_n + offs_n[None, :]) < SLIDING_WINDOW, qk, -10000)

        # -- compute m_ij, p, l_ij
        m_ij = tl.max(qk, 1)
        p = tl.exp(qk - m_ij[:, None])
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i
        m_i_new = tl.maximum(m_i, m_ij)
        alpha = tl.exp(m_i - m_i_new)
        beta = tl.exp(m_ij - m_i_new)
        l_i_new = alpha * l_i + beta * l_ij
        # -- update output accumulator --
        # scale p
        p_scale = beta / l_i_new
        p = p * p_scale[:, None]
        # scale acc
        acc_scale = l_i / l_i_new * alpha
        acc_scale = tl.where(offs_m >= start_n, acc_scale, 1.0)
        acc = acc * acc_scale[:, None]
        # update acc
        # for the V layout, see the fused-attention tutorial link above
        if v_fp8:
            v = tl.load(v_ptrs +
                        (cur_batch_in_all_start_index + start_n) * stride_vbs,
                        mask=((start_n + offs_n[None, :]) < block_end_loc),
                        other=0.0)
        else:
            v = tl.load(v_ptrs +
                        (cur_batch_in_all_start_index + start_n) * stride_vbs,
                        mask=((start_n + offs_n[:, None]) < block_end_loc),
                        other=0.0)

        p = p.to(v.dtype)
        acc += tl.dot(p, v)
        # update m_i and l_i
        l_i = l_i_new
        m_i = m_i_new
    # initialize pointers to output
    off_o = (
        (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
        cur_head * stride_oh + offs_d[None, :] * stride_od)
    out_ptrs = Out + off_o
    tl.store(out_ptrs,
             acc.to(tl.float16),
             mask=(offs_m[:, None] < cur_batch_query_len))

    return
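
To make our question clearer, this is the tile orientation each off_v formula produces, emulated with plain torch advanced indexing (hypothetical block sizes, a single KV head, not Triton code):

import torch

# Hypothetical tile sizes, only to show the two orientations.
BLOCK_N, BLOCK_DMODEL_PADDED = 4, 8
num_tokens, head_size = 16, 8
v = torch.randn(num_tokens, 1, head_size)  # [num_tokens, num_kv_heads=1, head_size]

offs_n = torch.arange(BLOCK_N)
offs_d = torch.arange(BLOCK_DMODEL_PADDED)

# fp16 branch: off_v = offs_n[:, None] * stride_vbs + ... + offs_d[None, :] * stride_vd
tile_fp16 = v[offs_n[:, None], 0, offs_d[None, :]]  # [BLOCK_N, BLOCK_DMODEL_PADDED]
# fp8 branch:  off_v = offs_n[None, :] * stride_vbs + ... + offs_d[:, None] * stride_vd
tile_fp8 = v[offs_n[None, :], 0, offs_d[:, None]]   # [BLOCK_DMODEL_PADDED, BLOCK_N]

print(tile_fp16.shape, tile_fp8.shape)  # torch.Size([4, 8]) torch.Size([8, 4])

If BLOCK_N and BLOCK_DMODEL_PADDED happen to be equal (e.g. head size 128 with BLOCK = 128), tl.dot(p, v) compiles in both branches, and we are not sure whether the fp8 branch accumulates what we expect.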

@torch.inference_mode()
def context_attention_fwd_fp8(q,
                              k,
                              v,
                              o,
                              b_loc,
                              b_start_loc,
                              b_seq_len,
                              b_ctx_len,
                              max_input_len,
                              alibi_slopes=None,
                              sliding_window=None):

    cap = current_platform.get_device_capability()
    BLOCK = 128 if cap[0] >= 8 else 64

    # need to reduce num. blocks when using fp32
    # due to increased use of GPU shared memory
    if q.dtype is torch.float32:
        BLOCK = BLOCK // 2

    # shape constraints on head_size
    Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
    assert Lq == Lk and Lk == Lv
    # round up Lk to a power of 2 - this is required for Triton block size
    Lk_padded = triton.next_power_of_2(Lk)
    # print("Lk Lk_padded", Lk, Lk_padded)

    sm_scale = 1.0 / (Lq**0.5)
    # batch size and number of query heads
    batch, head = b_seq_len.shape[0], q.shape[1]
    num_queries_per_kv = q.shape[1] // k.shape[1]

    # (batch, num_query_heads, num query blocks)
    grid = (batch, head, triton.cdiv(max_input_len, BLOCK))

    # 0 means "disable"
    if sliding_window is None or sliding_window <= 0:
        sliding_window = 0

    num_warps = 8 if Lk <= 64 else 8

    # cast q and k to fp8 (e5m2)
    q = q.to(torch.float8_e5m2)
    k = k.to(torch.float8_e5m2)
    # change the V layout: the shape stays [num_tokens, num_heads, head_size],
    # but the strides change so that the token dimension becomes contiguous
    v = v.permute(2, 1, 0).contiguous()
    v = v.permute(2, 1, 0)
    v = v.to(torch.float8_e5m2)

    print("v.shape", v.shape)
    print("v.stride", v.stride(0), v.stride(1), v.stride(2))

    _fwd_kernel_fp8[grid](
        q,
        k,
        v,
        b_loc,
        sm_scale,
        b_start_loc,
        b_seq_len,
        b_ctx_len,
        o,
        b_loc.stride(0),
        b_loc.stride(1),
        q.stride(0),
        q.stride(1),
        q.stride(2),
        k.stride(0),
        k.stride(1),
        k.stride(2),
        v.stride(0),
        v.stride(1),
        v.stride(2),
        o.stride(0),
        o.stride(1),
        o.stride(2),
        num_queries_per_kv=num_queries_per_kv,
        BLOCK_M=BLOCK,
        BLOCK_DMODEL=Lk,
        BLOCK_DMODEL_PADDED=Lk_padded,
        BLOCK_N=BLOCK,
        SLIDING_WINDOW=sliding_window,
        num_warps=num_warps,
        num_stages=1,
    )
    return
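
For completeness, this is roughly how we exercise the kernel; the single-sequence setup, the sizes, and the naive_causal_attention reference below are placeholders made up for illustration:

import torch

# Hypothetical test: one sequence, no cached context, made-up sizes.
num_tokens, num_heads, head_size = 256, 8, 128
max_input_len = num_tokens

q = torch.randn(num_tokens, num_heads, head_size, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)
o = torch.empty_like(q)

b_loc = torch.zeros(1, max_input_len, dtype=torch.int32, device="cuda")  # unused by the pasted kernel
b_start_loc = torch.tensor([0], dtype=torch.int32, device="cuda")
b_seq_len = torch.tensor([num_tokens], dtype=torch.int32, device="cuda")
b_ctx_len = torch.tensor([0], dtype=torch.int32, device="cuda")

context_attention_fwd_fp8(q, k, v, o, b_loc, b_start_loc, b_seq_len,
                          b_ctx_len, max_input_len)

# Naive causal attention in fp32 as a correctness reference.
def naive_causal_attention(q, k, v, sm_scale):
    scores = torch.einsum("qhd,khd->hqk", q.float(), k.float()) * sm_scale
    causal = torch.ones(num_tokens, num_tokens, dtype=torch.bool, device=q.device).tril()
    probs = scores.masked_fill(~causal, float("-inf")).softmax(dim=-1)
    return torch.einsum("hqk,khd->qhd", probs, v.float())

ref = naive_causal_attention(q, k, v, 1.0 / head_size**0.5)
print((o.float() - ref).abs().max())

Since q, k, and v are quantized to e5m2 inside the wrapper we only expect rough agreement with the reference, but what we currently get is not correct even at that level.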
