Skip to content

Commit

Permalink
[ascend] Optimise diopiFlashAttentionVarLen (DeepLink-org#1266)
Browse files Browse the repository at this point in the history
* support big attentionMask

* diopiFlashAttentionVarLenBackward supports changing sparseMode
  • Loading branch information
hellozmz authored Jun 20, 2024
1 parent 8896484 commit 390d7dd
Showing 1 changed file with 13 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,16 @@ diopiError_t diopiFlashAttentionVarLen(diopiContextHandle_t ctx, diopiTensorHand
}
}

int64_t sparseMode = 0;
at::Tensor attentionMaskAt = at::Tensor();
if (isCausal) {
// According to the Huawei documentation, when the attentionMask would be larger than 2048 x 2048, sparseMode can be set to 2 to reduce memory usage:
// https://www.hiascend.com/document/detail/zh/Pytorch/60RC1/apiref/apilist/ptaoplist_000742.html
if (maxSeqLenQ > 2048 && maxSeqLenKV > 2048) {
maxSeqLenQ = 2048;
maxSeqLenKV = 2048;
sparseMode = 2;
}
attentionMaskAt = npu_preparation::apply_tensor_without_format({maxSeqLenQ, maxSeqLenKV}, qAt.options().dtype(at::kBool));
EXEC_NPU_CMD(aclnnInplaceOne, attentionMaskAt);
int64_t diagonal = 1;
Expand All @@ -77,7 +85,6 @@ diopiError_t diopiFlashAttentionVarLen(diopiContextHandle_t ctx, diopiTensorHand
int64_t preTokens = kAt.size(0);
int64_t nextTokens = 0;
int64_t innerPrecise = 0;
int64_t sparseMode = 0;

at::Tensor softmaxMaxAt;
at::Tensor softmaxSumAt;
Expand Down Expand Up @@ -167,6 +174,11 @@ diopiError_t diopiFlashAttentionVarLenBackward(diopiContextHandle_t ctx, diopiTe
int64_t nextTokens = 0;
int64_t innerPrecise = 0;
int64_t sparseMode = 0;
// According to the Huawei documentation, when the attentionMask would be larger than 2048 x 2048, sparseMode can be set to 2 to reduce memory usage:
// https://www.hiascend.com/document/detail/zh/Pytorch/60RC1/apiref/apilist/ptaoplist_000742.html
if (maxSeqLenQ > 2048 && maxSeqLenKV > 2048 && attentionMaskAt.defined() && attentionMaskAt.size(0) == 2048 && attentionMaskAt.size(1) == 2048) {
sparseMode = 2;
}

EXEC_NPU_NO_FORMAT_CHECK_CMD(aclnnFlashAttentionUnpaddingScoreGrad,
qAt,
Expand Down

0 comments on commit 390d7dd

Please sign in to comment.