[Ascend] Wx/fix_varlen_flash_attention_on_ascend (#1158)
* fix varlen flash attention bug
POI-WX authored Apr 30, 2024
1 parent 132a28c commit b869ad3
Showing 5 changed files with 64 additions and 10 deletions.
3 changes: 2 additions & 1 deletion diopi_test/python/configs/diopi_configs.py
@@ -8985,7 +8985,8 @@
p_dropout=[0, 0, 0, 0],
is_causal=[True, True, False, True],
softmax_scale=[None, 0.0883, None, 0.125],
- max_seqlen=[32, 32, 128, 64],
+ max_seqlen_q=[32, 32, 128, 64],
+ max_seqlen_kv=[32, 32, 128, 64],
cu_seqlens_q=[[0, 32], [0, 16, 48, 64], [0, 32, 64, 128, 256], [0, 16, 48, 64, 128]],
cu_seqlens_kv=[[0, 32], [0, 16, 48, 64], [0, 32, 64, 128, 256], [0, 16, 48, 64, 128]],
),
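For reference, each max_seqlen_q / max_seqlen_kv value in this config is simply the longest single sequence implied by the corresponding cumulative-length list. A minimal sketch (plain Python, not part of the test suite) that reproduces the values above:

# The per-case max sequence length is the largest gap between consecutive
# cumulative offsets in cu_seqlens.
cu_seqlens_q = [[0, 32], [0, 16, 48, 64], [0, 32, 64, 128, 256], [0, 16, 48, 64, 128]]

def max_seqlen(cu_seqlens):
    return max(b - a for a, b in zip(cu_seqlens[:-1], cu_seqlens[1:]))

print([max_seqlen(cu) for cu in cu_seqlens_q])  # -> [32, 32, 128, 64]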
3 changes: 2 additions & 1 deletion diopi_test/python/conformance/customized_test.py
@@ -477,10 +477,11 @@ def flash_attention_v3(q, k, v, p_dropout, softmax_scale, is_causal):
return output

def flash_attention_varlen(
- q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen, p_dropout, softmax_scale, is_causal
+ q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, p_dropout, softmax_scale, is_causal
):
# Currently, only equality between cu_seqlens_q and cu_seqlens_kv is supported here
cu_seqlens = cu_seqlens_q
max_seqlen = max_seqlen_q
# In order to compare the accuracy with the baseline value, dropout is not used during testing.
batch_size = len(cu_seqlens) - 1
_, head_num, head_dim = q.size()
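For context on what the test's reference path computes: a variable-length attention reference slices q, k and v by cu_seqlens and runs ordinary softmax attention on each segment; max_seqlen_q / max_seqlen_kv are only needed by fused kernels (for mask sizing and tiling), not by a naive loop. A rough PyTorch sketch of that idea, assuming equal q/kv lengths as the comment above notes; this is an illustration, not the repository's reference code:

import math
import torch

def varlen_attention_reference(q, k, v, cu_seqlens, softmax_scale=None, is_causal=False):
    # q, k, v: [total_tokens, head_num, head_dim]; cu_seqlens: [batch_size + 1].
    head_dim = q.shape[-1]
    scale = softmax_scale if softmax_scale else 1.0 / math.sqrt(head_dim)
    out = torch.empty_like(q)
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        qi = q[start:end].transpose(0, 1)   # [head_num, seq, head_dim]
        ki = k[start:end].transpose(0, 1)
        vi = v[start:end].transpose(0, 1)
        scores = torch.matmul(qi, ki.transpose(-1, -2)) * scale
        if is_causal:
            seq = end - start
            causal = torch.ones(seq, seq, dtype=torch.bool, device=q.device).triu(1)
            scores = scores.masked_fill(causal, float("-inf"))
        probs = torch.softmax(scores, dim=-1)
        out[start:end] = torch.matmul(probs, vi).transpose(0, 1)
    return out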
50 changes: 49 additions & 1 deletion diopi_test/python/conformance/diopi_functions.py
@@ -5457,7 +5457,7 @@ def flash_attention_v3_backward(q, k, v, out, grad_outputs, p_dropout, softmax_s
check_returncode(ret)
return {'q': grad_q, 'k': grad_k, 'v': grad_v}

- def flash_attention_varlen(q, k, v, max_seqlen, cu_seqlens_q, cu_seqlens_kv, p_dropout, softmax_scale, is_causal):
+ def flash_attention_varlen(q, k, v, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, p_dropout, softmax_scale, is_causal):
call = "diopiFlashAttentionVarLen"
func = check_function(call)
q_size = list(q.size().data)
@@ -5500,11 +5500,59 @@ def flash_attention_varlen(q, k, v, max_seqlen, cu_seqlens_q, cu_seqlens_kv, p_d
v,
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
p_dropout,
softmax_scale,
is_causal,
)
check_returncode(ret)
GLOBAL_STATE["flash_attention_varlen_attention_mask"] = attention_mask
GLOBAL_STATE["flash_attention_varlen_dropout_mask"] = dropout_mask
GLOBAL_STATE["flash_attention_varlen_softmax_max"] = softmax_max
GLOBAL_STATE["flash_attention_varlen_softmax_sum"] = softmax_sum
GLOBAL_STATE["flash_attention_varlen_softmax_out"] = softmax_out
return out

def flash_attention_varlen_backward(q, k, v, out, grad_outputs, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, p_dropout, softmax_scale, is_causal):
call = "diopiFlashAttentionVarLenBackward"
func = check_function(call)
assert p_dropout >=0 and p_dropout <=1, "The p_dropout value must be in range of [0, 1]"
head_dim = q.shape().data[-1]
softmax_scale = 1.0 / math.sqrt(head_dim) if not softmax_scale else softmax_scale
cu_seqlens_q = Sizes(cu_seqlens_q[1:])
cu_seqlens_kv = Sizes(cu_seqlens_kv[1:])
grad_q = raw_like(q)
grad_k = raw_like(k)
grad_v = raw_like(v)
attention_mask = GLOBAL_STATE.pop("flash_attention_varlen_attention_mask")
dropout_mask = GLOBAL_STATE.pop("flash_attention_varlen_dropout_mask")
softmax_max = GLOBAL_STATE.pop("flash_attention_varlen_softmax_max")
softmax_sum = GLOBAL_STATE.pop("flash_attention_varlen_softmax_sum")
softmax_out = GLOBAL_STATE.pop("flash_attention_varlen_softmax_out")
ret = func(
q.context(),
grad_q,
grad_k,
grad_v,
grad_outputs[0],
q,
k,
v,
cu_seqlens_q,
cu_seqlens_kv,
out,
attention_mask,
dropout_mask,
softmax_max,
softmax_sum,
softmax_out,
max_seqlen_q,
max_seqlen_kv,
p_dropout,
softmax_scale,
)
check_returncode(ret)
return out

def scaled_masked_softmax(input, mask, scale, fixed_triu_mask):
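The new Python wrappers above follow the same hand-off pattern as the other flash-attention wrappers in this file: the forward call stores the kernel's intermediates (attention/dropout masks and the softmax statistics) in GLOBAL_STATE, and the backward call pops them again. A stripped-down illustration of that pattern (the toy operation and key name are hypothetical, not the actual wrappers):

# GLOBAL_STATE is a plain module-level dict, as in diopi_functions.py.
GLOBAL_STATE = {}

def toy_forward(x):
    out = [v * 2.0 for v in x]
    # Stash whatever the backward pass will need under a unique key.
    GLOBAL_STATE["toy_saved_input"] = x
    return out

def toy_backward(grad_out):
    # pop() both retrieves and clears the entry, so no state leaks between cases.
    _x = GLOBAL_STATE.pop("toy_saved_input")
    return [g * 2.0 for g in grad_out]  # d(out)/dx = 2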
@@ -22,8 +22,8 @@ const int64_t uInt8BitNumber = 8;
diopiError_t diopiFlashAttentionVarLen(diopiContextHandle_t ctx, diopiTensorHandle_t attentionOut, diopiTensorHandle_t* attentionMask,
diopiTensorHandle_t* dropoutMask, diopiTensorHandle_t* softmaxMax, diopiTensorHandle_t* softmaxSum,
diopiTensorHandle_t* softmaxOut, diopiGeneratorHandle_t gen, diopiConstTensorHandle_t q, diopiConstTensorHandle_t k,
- diopiConstTensorHandle_t v, diopiSize_t cumSeqQ, diopiSize_t cumSeqKV, double pDropout, double softmaxScale,
- bool isCausal) {
+ diopiConstTensorHandle_t v, diopiSize_t cumSeqQ, diopiSize_t cumSeqKV, int64_t maxSeqLenQ, int64_t maxSeqLenKV,
+ double pDropout, double softmaxScale, bool isCausal) {
BEGIN_CALL_ACL_OP(q, k, v, cumSeqQ, cumSeqKV, gen, attentionOut);

DIOPI_CHECK(qAt.dim() == 3, "The shapes of the input query should be 3-dimensional");
@@ -68,7 +68,7 @@ diopiError_t diopiFlashAttentionVarLen(diopiContextHandle_t ctx, diopiTensorHand

at::Tensor attentionMaskAt = at::Tensor();
if (isCausal) {
- attentionMaskAt = npu_preparation::apply_tensor_without_format({t, t}, qAt.options().dtype(at::kBool));
+ attentionMaskAt = npu_preparation::apply_tensor_without_format({maxSeqLenQ, maxSeqLenKV}, qAt.options().dtype(at::kBool));
EXEC_NPU_CMD(aclnnInplaceOne, attentionMaskAt);
int64_t diagonal = 1;
EXEC_NPU_CMD(aclnnInplaceTriu, attentionMaskAt, diagonal);
@@ -131,7 +131,7 @@ diopiError_t diopiFlashAttentionVarLenBackward(diopiContextHandle_t ctx, diopiTe
diopiConstTensorHandle_t v, diopiSize_t cumSeqQ, diopiSize_t cumSeqKV, diopiConstTensorHandle_t attentionOut,
diopiConstTensorHandle_t attentionMask, diopiConstTensorHandle_t dropoutMask,
diopiConstTensorHandle_t softmaxMax, diopiConstTensorHandle_t softmaxSum, diopiConstTensorHandle_t softmaxOut,
- double pDropout, double softmaxScale) {
+ int64_t maxSeqLenQ, int64_t maxSeqLenKV, double pDropout, double softmaxScale) {
BEGIN_CALL_ACL_OP(q, k, v, cumSeqQ, cumSeqKV, attentionOut, softmaxMax, softmaxSum, softmaxOut, gradQ, gradK, gradV, gradOut);

at::Tensor dropoutMaskAt;
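The core of the Ascend-side fix is visible above: the causal attention mask is now allocated with shape {maxSeqLenQ, maxSeqLenKV} instead of the previous fixed {t, t} shape, then filled with ones and upper-triangulated. The aclnnInplaceOne + aclnnInplaceTriu(diagonal=1) sequence corresponds to the following small PyTorch sketch (illustration only; the NPU path uses ACL ops, not torch):

import torch

def build_causal_mask(max_seqlen_q, max_seqlen_kv):
    # True marks positions to mask out (a query attending to a later key),
    # mirroring aclnnInplaceOne followed by aclnnInplaceTriu with diagonal=1.
    mask = torch.ones(max_seqlen_q, max_seqlen_kv, dtype=torch.bool)
    return torch.triu(mask, diagonal=1)

print(build_causal_mask(4, 6).int())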
10 changes: 7 additions & 3 deletions proto/include/diopi/functions_ext.h
@@ -328,6 +328,8 @@ DIOPI_API diopiError_t diopiFlashAttentionV3Backward(diopiContextHandle_t ctx, d
* @param[in] v Value tensor. shape = [total_v, head_num, head_dim], where total_v = total number of value tokens in the batch. type = [bfloat16, float16].
* @param[in] cum_seq_q The cumulative sequence lengths of the sequences in the batch for query. shape = [batch_size+1].
* @param[in] cum_seq_kv The cumulative sequence lengths of the sequences in the batch for key and value. shape = [batch_size+1].
* @param[in] max_seqlen_q Maximum sequence length for query.
* @param[in] max_seqlen_kv Maximum sequence length for key and value.
* @param[in] p_dropout Dropout probability.
* @param[in] softmax_scale The scaling of qk^T before applying softmax. By default, softmax\_scale=\frac{1}{\sqrt{d_k}}
* @param[in] is_causal Whether to apply causal attention mask.
@@ -343,7 +345,7 @@ DIOPI_API diopiError_t diopiFlashAttentionVarLen(diopi
diopiTensorHandle_t* dropout_mask, diopiTensorHandle_t* softmax_max, diopiTensorHandle_t* softmax_sum,
diopiTensorHandle_t* softmax_out, diopiGeneratorHandle_t gen, diopiConstTensorHandle_t q,
diopiConstTensorHandle_t k, diopiConstTensorHandle_t v, diopiSize_t cum_seq_q, diopiSize_t cum_seq_kv,
- double p_dropout, double softmax_scale, bool is_causal);
+ int64_t max_seqlen_q, int64_t max_seqlen_kv, double p_dropout, double softmax_scale, bool is_causal);

/**
* @brief Compute the backward pass for the variable length version of Flash Attention.
@@ -360,6 +362,8 @@ DIOPI_API diopiError_t diopiFlashAttentionVarLen(diopi
* @param[in] softmax_max Tensor representing the intermediate calculation result of softmax op from the forward pass. type = [float32].
* @param[in] softmax_sum Tensor representing the intermediate calculation result of softmax op from the forward pass. type = [float32].
* @param[in] softmax_out Tensor representing the intermediate calculation result of softmax op from the forward pass. type =[float32].
* @param[in] max_seqlen_q Maximum sequence length for query.
* @param[in] max_seqlen_kv Maximum sequence length for key and value.
* @param[in] p_dropout Dropout probability.
* @param[in] softmax_scale The scaling of qk^T before applying softmax. By default, softmax\_scale=\frac{1}{\sqrt{d_k}}
* @param[out] grad_q The gradient of the query tensor. shape = [total_q, head_num, head_dim], where total_q = total number of query tokens in the batch. type =
@@ -374,8 +378,8 @@ DIOPI_API diopiError_t diopiFlashAttentionVarLenBackward(diopiContextHandle_t ct
diopiConstTensorHandle_t k, diopiConstTensorHandle_t v, diopiSize_t cum_seq_q, diopiSize_t cum_seq_kv,
diopiConstTensorHandle_t attention_out, diopiConstTensorHandle_t attention_mask,
diopiConstTensorHandle_t dropout_mask, diopiConstTensorHandle_t softmax_max,
- diopiConstTensorHandle_t softmax_sum, diopiConstTensorHandle_t softmax_out, double p_dropout,
- double softmax_scale);
+ diopiConstTensorHandle_t softmax_sum, diopiConstTensorHandle_t softmax_out, int64_t max_seqlen_q,
+ int64_t max_seqlen_kv, double p_dropout, double softmax_scale);

// This interface is temporarily designed for ascend, please do not use it with other devices.
/**
