Commit de93f40

[CUDA] Lean Attention (microsoft#22352)
### Description

Add [Lean Attention](https://arxiv.org/abs/2405.10480) and its integration with the MultiHeadAttention operator for LLMs on GPU.

Lean Attention speeds up self-attention in the token-generation (decode) phase of decoder-only transformer models, especially at long context lengths.

- [x] Initial implementation of Lean Attention (by Srikant Bharadwaj)
- [x] Integration with the MultiHeadAttention operator
- [x] Add parity tests
- [x] Add benchmark

#### Implementation Details

1. Lean Attention is enabled in the build for Linux and disabled for Windows.
2. Lean Attention is disabled by default. Enable it through the CUDA provider option `sdpa_kernel`, or with the environment variable `ORT_ENABLE_LEAN_ATTENTION=1` (see the Python sketch at the end of this description).
3. It only works for token generation (sequence_length == 1 and past_sequence_length > 0).
4. Like flash attention, it only works on Ampere or newer GPUs.

We can revisit items 1 and 2 after comparing with the DecoderMaskedMultiHeadAttention and XQA kernels.

#### Benchmark

```
cd onnxruntime/test/python/transformers
/bin/bash benchmark_mha.sh lean
```

Example outputs on H100:

Note that past and present do not share a buffer for MHA yet, so the TFLOPS numbers are low. The relative ratios will change once buffer sharing is enabled, but we expect the ordering (kernel A faster than kernel B) to stay the same.

Common settings `sequence_length=1; causal=True; attn_bias=None; cuda_graph=False` are not shown in the table below.

batch_size | past_sequence_length | num_heads | head_size | average_latency | tflops | kernel
-- | -- | -- | -- | -- | -- | --
1 | 512 | 16 | 64 | 0.000059 | 0.0178 | ort:flash
1 | 512 | 16 | 64 | 0.000068 | 0.0155 | ort:efficient
1 | 512 | 16 | 64 | 0.000065 | 0.0161 | ort:math
1 | 512 | 16 | 64 | 0.000060 | 0.0176 | ort:lean
1 | 512 | 32 | 128 | 0.000062 | 0.0674 | ort:flash
1 | 512 | 32 | 128 | 0.000064 | 0.0661 | ort:efficient
1 | 512 | 32 | 128 | 0.000067 | 0.0625 | ort:math
1 | 512 | 32 | 128 | 0.000062 | 0.0678 | ort:lean
1 | 1024 | 16 | 64 | 0.000061 | 0.0345 | ort:flash
1 | 1024 | 16 | 64 | 0.000086 | 0.0244 | ort:efficient
1 | 1024 | 16 | 64 | 0.000065 | 0.0322 | ort:math
1 | 1024 | 16 | 64 | 0.000063 | 0.0332 | ort:lean
1 | 1024 | 32 | 128 | 0.000075 | 0.1125 | ort:flash
1 | 1024 | 32 | 128 | 0.000088 | 0.0951 | ort:efficient
1 | 1024 | 32 | 128 | 0.000079 | 0.1068 | ort:math
1 | 1024 | 32 | 128 | 0.000072 | 0.1171 | ort:lean
1 | 2048 | 16 | 64 | 0.000069 | 0.0606 | ort:flash
1 | 2048 | 16 | 64 | 0.000125 | 0.0336 | ort:efficient
1 | 2048 | 16 | 64 | 0.000064 | 0.0655 | ort:lean
1 | 2048 | 32 | 128 | 0.000098 | 0.1720 | ort:flash
1 | 2048 | 32 | 128 | 0.000132 | 0.1270 | ort:efficient
1 | 2048 | 32 | 128 | 0.000092 | 0.1828 | ort:lean
1 | 4096 | 16 | 64 | 0.000076 | 0.1097 | ort:flash
1 | 4096 | 16 | 64 | 0.000207 | 0.0406 | ort:efficient
1 | 4096 | 16 | 64 | 0.000069 | 0.1209 | ort:lean
1 | 4096 | 32 | 128 | 0.000140 | 0.2394 | ort:flash
1 | 4096 | 32 | 128 | 0.000213 | 0.1575 | ort:efficient
1 | 4096 | 32 | 128 | 0.000139 | 0.2419 | ort:lean
1 | 8192 | 16 | 64 | 0.000104 | 0.1609 | ort:flash
1 | 8192 | 16 | 64 | 0.000392 | 0.0428 | ort:efficient
1 | 8192 | 16 | 64 | 0.000093 | 0.1809 | ort:lean
1 | 8192 | 32 | 128 | 0.000212 | 0.3160 | ort:flash
1 | 8192 | 32 | 128 | 0.000360 | 0.1866 | ort:efficient
1 | 8192 | 32 | 128 | 0.000212 | 0.3162 | ort:lean
1 | 16384 | 16 | 64 | 0.000139 | 0.2410 | ort:flash
1 | 16384 | 16 | 64 | 0.000731 | 0.0459 | ort:efficient
1 | 16384 | 16 | 64 | 0.000136 | 0.2465 | ort:lean
1 | 16384 | 32 | 128 | 0.000361 | 0.3722 | ort:flash
1 | 16384 | 32 | 128 | 0.000667 | 0.2014 | ort:efficient
1 | 16384 | 32 | 128 | 0.000357 | 0.3765 | ort:lean
1 | 32768 | 16 | 64 | 0.000210 | 0.3194 | ort:flash
1 | 32768 | 16 | 64 | 0.001428 | 0.0470 | ort:efficient
1 | 32768 | 16 | 64 | 0.000209 | 0.3211 | ort:lean
1 | 32768 | 32 | 128 | 0.000659 | 0.4074 | ort:flash
1 | 32768 | 32 | 128 | 0.001270 | 0.2114 | ort:efficient
1 | 32768 | 32 | 128 | 0.000651 | 0.4123 | ort:lean
1 | 65536 | 16 | 64 | 0.000355 | 0.3785 | ort:flash
1 | 65536 | 16 | 64 | 0.002736 | 0.0491 | ort:efficient
1 | 65536 | 16 | 64 | 0.000349 | 0.3845 | ort:lean
1 | 65536 | 32 | 128 | 0.001251 | 0.4290 | ort:flash
1 | 65536 | 32 | 128 | 0.002480 | 0.2165 | ort:efficient
1 | 65536 | 32 | 128 | 0.001239 | 0.4333 | ort:lean
4 | 512 | 16 | 64 | 0.000063 | 0.0665 | ort:flash
4 | 512 | 16 | 64 | 0.000069 | 0.0607 | ort:efficient
4 | 512 | 16 | 64 | 0.000066 | 0.0634 | ort:math
4 | 512 | 16 | 64 | 0.000062 | 0.0674 | ort:lean
4 | 512 | 32 | 128 | 0.000100 | 0.1677 | ort:flash
4 | 512 | 32 | 128 | 0.000099 | 0.1703 | ort:efficient
4 | 512 | 32 | 128 | 0.000108 | 0.1557 | ort:math
4 | 512 | 32 | 128 | 0.000092 | 0.1818 | ort:lean
4 | 1024 | 16 | 64 | 0.000077 | 0.1094 | ort:flash
4 | 1024 | 16 | 64 | 0.000099 | 0.0850 | ort:efficient
4 | 1024 | 16 | 64 | 0.000081 | 0.1038 | ort:math
4 | 1024 | 16 | 64 | 0.000072 | 0.1161 | ort:lean
4 | 1024 | 32 | 128 | 0.000143 | 0.2343 | ort:flash
4 | 1024 | 32 | 128 | 0.000137 | 0.2447 | ort:efficient
4 | 1024 | 32 | 128 | 0.000150 | 0.2245 | ort:math
4 | 1024 | 32 | 128 | 0.000135 | 0.2496 | ort:lean
4 | 2048 | 16 | 64 | 0.000096 | 0.1757 | ort:flash
4 | 2048 | 16 | 64 | 0.000156 | 0.1078 | ort:efficient
4 | 2048 | 16 | 64 | 0.000089 | 0.1892 | ort:lean
4 | 2048 | 32 | 128 | 0.000223 | 0.3010 | ort:flash
4 | 2048 | 32 | 128 | 0.000217 | 0.3101 | ort:efficient
4 | 2048 | 32 | 128 | 0.000209 | 0.3209 | ort:lean
4 | 4096 | 16 | 64 | 0.000137 | 0.2448 | ort:flash
4 | 4096 | 16 | 64 | 0.000256 | 0.1312 | ort:efficient
4 | 4096 | 16 | 64 | 0.000133 | 0.2530 | ort:lean
4 | 4096 | 32 | 128 | 0.000389 | 0.3450 | ort:flash
4 | 4096 | 32 | 128 | 0.000376 | 0.3574 | ort:efficient
4 | 4096 | 32 | 128 | 0.000354 | 0.3794 | ort:lean
4 | 8192 | 16 | 64 | 0.000210 | 0.3198 | ort:flash
4 | 8192 | 16 | 64 | 0.000453 | 0.1480 | ort:efficient
4 | 8192 | 16 | 64 | 0.000206 | 0.3260 | ort:lean
4 | 8192 | 32 | 128 | 0.000725 | 0.3705 | ort:flash
4 | 8192 | 32 | 128 | 0.000693 | 0.3874 | ort:efficient
4 | 8192 | 32 | 128 | 0.000653 | 0.4114 | ort:lean
4 | 16384 | 16 | 64 | 0.000355 | 0.3782 | ort:flash
4 | 16384 | 16 | 64 | 0.000849 | 0.1581 | ort:efficient
4 | 16384 | 16 | 64 | 0.000346 | 0.3874 | ort:lean
4 | 16384 | 32 | 128 | 0.001395 | 0.3848 | ort:flash
4 | 16384 | 32 | 128 | 0.001337 | 0.4017 | ort:efficient
4 | 16384 | 32 | 128 | 0.001252 | 0.4288 | ort:lean
4 | 32768 | 16 | 64 | 0.000647 | 0.4146 | ort:flash
4 | 32768 | 16 | 64 | 0.001649 | 0.1628 | ort:efficient
4 | 32768 | 16 | 64 | 0.000639 | 0.4204 | ort:lean
4 | 32768 | 32 | 128 | 0.002721 | 0.3947 | ort:flash
4 | 32768 | 32 | 128 | 0.002601 | 0.4128 | ort:efficient
4 | 32768 | 32 | 128 | 0.002434 | 0.4411 | ort:lean
4 | 65536 | 16 | 64 | 0.001231 | 0.4361 | ort:flash
4 | 65536 | 16 | 64 | 0.003238 | 0.1658 | ort:efficient
4 | 65536 | 16 | 64 | 0.001217 | 0.4412 | ort:lean
4 | 65536 | 32 | 128 | 0.005357 | 0.4009 | ort:flash
4 | 65536 | 32 | 128 | 0.005118 | 0.4196 | ort:efficient
4 | 65536 | 32 | 128 | 0.004781 | 0.4492 | ort:lean
16 | 512 | 16 | 64 | 0.000098 | 0.1724 | ort:flash
16 | 512 | 16 | 64 | 0.000104 | 0.1616 | ort:efficient
16 | 512 | 16 | 64 | 0.000118 | 0.1420 | ort:math
16 | 512 | 16 | 64 | 0.000087 | 0.1926 | ort:lean
16 | 512 | 32 | 128 | 0.000220 | 0.3062 | ort:flash
16 | 512 | 32 | 128 | 0.000208 | 0.3237 | ort:efficient
16 | 512 | 32 | 128 | 0.000237 | 0.2838 | ort:math
16 | 512 | 32 | 128 | 0.000209 | 0.3216 | ort:lean
16 | 1024 | 16 | 64 | 0.000136 | 0.2465 | ort:flash
16 | 1024 | 16 | 64 | 0.000150 | 0.2235 | ort:efficient
16 | 1024 | 16 | 64 | 0.000148 | 0.2266 | ort:math
16 | 1024 | 16 | 64 | 0.000129 | 0.2611 | ort:lean
16 | 1024 | 32 | 128 | 0.000367 | 0.3663 | ort:flash
16 | 1024 | 32 | 128 | 0.000351 | 0.3829 | ort:efficient
16 | 1024 | 32 | 128 | 0.000400 | 0.3357 | ort:math
16 | 1024 | 32 | 128 | 0.000349 | 0.3853 | ort:lean
16 | 2048 | 16 | 64 | 0.000209 | 0.3206 | ort:flash
16 | 2048 | 16 | 64 | 0.000243 | 0.2762 | ort:efficient
16 | 2048 | 16 | 64 | 0.000201 | 0.3338 | ort:lean
16 | 2048 | 32 | 128 | 0.000671 | 0.4002 | ort:flash
16 | 2048 | 32 | 128 | 0.000645 | 0.4163 | ort:efficient
16 | 2048 | 32 | 128 | 0.000642 | 0.4185 | ort:lean
16 | 4096 | 16 | 64 | 0.000360 | 0.3732 | ort:flash
16 | 4096 | 16 | 64 | 0.000425 | 0.3162 | ort:efficient
16 | 4096 | 16 | 64 | 0.000341 | 0.3933 | ort:lean
16 | 4096 | 32 | 128 | 0.001292 | 0.4156 | ort:flash
16 | 4096 | 32 | 128 | 0.001251 | 0.4291 | ort:efficient
16 | 4096 | 32 | 128 | 0.001241 | 0.4327 | ort:lean
16 | 8192 | 16 | 64 | 0.000666 | 0.4030 | ort:flash
16 | 8192 | 16 | 64 | 0.000804 | 0.3339 | ort:efficient
16 | 8192 | 16 | 64 | 0.000627 | 0.4283 | ort:lean
16 | 8192 | 32 | 128 | 0.002541 | 0.4226 | ort:flash
16 | 8192 | 32 | 128 | 0.002454 | 0.4376 | ort:efficient
16 | 8192 | 32 | 128 | 0.002438 | 0.4405 | ort:lean
16 | 16384 | 16 | 64 | 0.001292 | 0.4156 | ort:flash
16 | 16384 | 16 | 64 | 0.001571 | 0.3417 | ort:efficient
16 | 16384 | 16 | 64 | 0.001217 | 0.4411 | ort:lean
16 | 16384 | 32 | 128 | 0.005042 | 0.4260 | ort:flash
16 | 16384 | 32 | 128 | 0.004859 | 0.4420 | ort:efficient
16 | 16384 | 32 | 128 | 0.004827 | 0.4449 | ort:lean
16 | 32768 | 16 | 64 | 0.002537 | 0.4233 | ort:flash
16 | 32768 | 16 | 64 | 0.003103 | 0.3461 | ort:efficient
16 | 32768 | 16 | 64 | 0.002385 | 0.4501 | ort:lean
16 | 32768 | 32 | 128 | 0.009961 | 0.4312 | ort:flash
16 | 32768 | 32 | 128 | 0.009605 | 0.4472 | ort:efficient
16 | 32768 | 32 | 128 | 0.009524 | 0.4510 | ort:lean
16 | 65536 | 16 | 64 | 0.005019 | 0.4279 | ort:flash
16 | 65536 | 16 | 64 | 0.006133 | 0.3502 | ort:efficient
16 | 65536 | 16 | 64 | 0.004703 | 0.4566 | ort:lean
16 | 65536 | 32 | 128 | 0.019746 | 0.4350 | ort:flash
16 | 65536 | 32 | 128 | 0.019027 | 0.4515 | ort:efficient
16 | 65536 | 32 | 128 | 0.018864 | 0.4554 | ort:lean
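
As a usage note for item 2 above, here is a minimal Python sketch of enabling Lean Attention for a session. The model path is a placeholder; `sdpa_kernel` is the CUDA provider option mentioned above, and 256 is the `LEAN_ATTENTION` backend value added in this change.

```python
import os
import onnxruntime as ort

# Option A: environment variable, read when attention kernel options are initialized.
os.environ["ORT_ENABLE_LEAN_ATTENTION"] = "1"

# Option B: select the backend explicitly via the CUDA EP provider option.
providers = [
    ("CUDAExecutionProvider", {"sdpa_kernel": 256}),  # 256 == AttentionBackend::LEAN_ATTENTION
    "CPUExecutionProvider",
]
session = ort.InferenceSession("decoder_model.onnx", providers=providers)  # placeholder model path
```

Lean Attention is only taken at runtime when the conditions in items 3 and 4 hold; otherwise one of the other enabled kernels is used.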
1 parent 87e8a5d commit de93f40

27 files changed, +3578 −68 lines changed

cmake/CMakeLists.txt

Lines changed: 17 additions & 0 deletions
@@ -106,6 +106,7 @@ option(onnxruntime_USE_LLVM "Build TVM with LLVM" OFF)
 option(onnxruntime_USE_VSINPU "Build with VSINPU support" OFF)
 
 cmake_dependent_option(onnxruntime_USE_FLASH_ATTENTION "Build flash attention kernel for scaled dot product attention" ON "onnxruntime_USE_CUDA" OFF)
+cmake_dependent_option(onnxruntime_USE_LEAN_ATTENTION "Build lean attention kernel for scaled dot product attention" ON "onnxruntime_USE_CUDA; NOT WIN32" OFF)
 option(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON)
 
 option(onnxruntime_BUILD_FOR_NATIVE_MACHINE "Enable this option for turning on optimization specific to this machine" OFF)

@@ -751,21 +752,30 @@ if (onnxruntime_USE_CUDA)
 
   if (onnxruntime_DISABLE_CONTRIB_OPS)
     set(onnxruntime_USE_FLASH_ATTENTION OFF)
+    set(onnxruntime_USE_LEAN_ATTENTION OFF)
     set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF)
   endif()
+
   if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.6)
     message( STATUS "Turn off flash attention since CUDA compiler version < 11.6")
     set(onnxruntime_USE_FLASH_ATTENTION OFF)
+    set(onnxruntime_USE_LEAN_ATTENTION OFF)
     set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF)
   elseif(WIN32 AND CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 12)
     message( STATUS "Flash-Attention unsupported in Windows with CUDA compiler version < 12.0")
     set(onnxruntime_USE_FLASH_ATTENTION OFF)
   endif()
+
   if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.4)
     message( FATAL_ERROR "Failed build due to CUDA compiler version < 11.4")
   endif()
+  if (WIN32)
+    message( STATUS "Lean Attention unsupported in Windows")
+    set(onnxruntime_USE_LEAN_ATTENTION OFF)
+  endif()
 else()
   set(onnxruntime_USE_FLASH_ATTENTION OFF)
+  set(onnxruntime_USE_LEAN_ATTENTION OFF)
   set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF)
 endif()
 

@@ -779,6 +789,13 @@ if (onnxruntime_USE_CUDA)
     list(APPEND ORT_PROVIDER_FLAGS -DUSE_FLASH_ATTENTION=1)
     list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_FLASH_ATTENTION=1)
   endif()
+
+  if (onnxruntime_USE_LEAN_ATTENTION)
+    message( STATUS "Enable lean attention for CUDA EP")
+    list(APPEND ORT_PROVIDER_FLAGS -DUSE_LEAN_ATTENTION=1)
+    list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_LEAN_ATTENTION=1)
+  endif()
+
   if (onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION)
     message( STATUS "Enable memory efficient attention for CUDA EP")
     list(APPEND ORT_PROVIDER_FLAGS -DUSE_MEMORY_EFFICIENT_ATTENTION=1)

onnxruntime/contrib_ops/cpu/bert/attention_common.h

Lines changed: 8 additions & 2 deletions
@@ -48,6 +48,7 @@ enum AttentionKernelType {
   AttentionKernel_CutlassMemoryEfficientAttention,
   AttentionKernel_FlashAttention,
   AttentionKernel_CudnnFlashAttention,
+  AttentionKernel_LeanAttention,
   AttentionKernel_Default
 };
 

@@ -65,7 +66,6 @@ struct AttentionParameters {
   int v_hidden_size;  // hidden size of V
   int v_head_size;    // hidden size per head of V
   int num_heads;
-  int num_splits;
   int rotary_embedding;
   bool is_unidirectional;
   bool past_present_share_buffer;

@@ -208,10 +208,13 @@ enum class AttentionBackend : int {
   CUDNN_FLASH_ATTENTION = 8,  // reserved for cuDNN flash attention.
   MATH = 16,  // unfused kernel cannot be disabled right now.
 
-  // The following kernels might be deprecated in the future.
+  // The following TRT kernels might be deprecated in the future.
   TRT_FLASH_ATTENTION = 32,
   TRT_CROSS_ATTENTION = 64,
   TRT_CAUSAL_ATTENTION = 128,
+
+  // Experimental kernels
+  LEAN_ATTENTION = 256,
 };
 
 // Environment variable to enable debug information of attention kernel to be printed. Default is 0 (disabled).

@@ -239,6 +242,9 @@ constexpr const char* kDisableMemoryEfficientAttention = "ORT_DISABLE_MEMORY_EFF
 // Environment variable to enable or disable flash attention. Default is 0 (enabled).
 constexpr const char* kDisableFlashAttention = "ORT_DISABLE_FLASH_ATTENTION";
 
+// Environment variable to enable or disable lean attention. Default is 0 (disabled).
+constexpr const char* kEnableLeanAttention = "ORT_ENABLE_LEAN_ATTENTION";
+
 // Minimum sequence length to perfer memory efficient attention when data type is float32
 constexpr const char* kMinSeqLenForEfficientAttentionFp32 = "ORT_MIN_SEQ_LEN_EFFICIENT_ATTENTION_FP32";
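
The `AttentionBackend` values are bit flags: the CUDA provider option `sdpa_kernel` takes an integer whose set bits select which kernels may be used. Below is a small Python mirror of the enum for building that value; the values for `CUDNN_FLASH_ATTENTION`, `MATH`, the TRT kernels, and `LEAN_ATTENTION` come from the hunk above, while `FLASH_ATTENTION = 1`, `EFFICIENT_ATTENTION = 2`, and `TRT_FUSED_ATTENTION = 4` are not visible in this diff and are assumptions.

```python
from enum import IntFlag

class AttentionBackend(IntFlag):
    FLASH_ATTENTION = 1          # assumed, not shown in this hunk
    EFFICIENT_ATTENTION = 2      # assumed, not shown in this hunk
    TRT_FUSED_ATTENTION = 4      # assumed, not shown in this hunk
    CUDNN_FLASH_ATTENTION = 8
    MATH = 16
    TRT_FLASH_ATTENTION = 32
    TRT_CROSS_ATTENTION = 64
    TRT_CAUSAL_ATTENTION = 128
    LEAN_ATTENTION = 256         # experimental, added in this commit

# Example: allow lean attention for decode and flash attention as a fallback.
sdpa_kernel_value = int(AttentionBackend.LEAN_ATTENTION | AttentionBackend.FLASH_ATTENTION)  # 257
```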

onnxruntime/contrib_ops/cuda/bert/attention.cc

Lines changed: 15 additions & 3 deletions
@@ -102,6 +102,9 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
   const int sm = device_prop.major * 10 + device_prop.minor;
   const bool is_mask_1d_seq_len = parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN;
 
+  typedef typename ToCudaType<T>::MappedType CudaT;
+  AttentionData<CudaT> data;
+
 #if USE_FLASH_ATTENTION
   bool use_flash_attention = !disable_flash_attention_ &&
                              (nullptr == attention_bias) &&

@@ -118,21 +121,26 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
     use_flash_attention = false;
   }
   // Allocate buffers
+  size_t softmax_lse_bytes = 0;
   size_t softmax_lse_accum_bytes = 0;
   size_t out_accum_bytes = 0;
   if (use_flash_attention) {
+    softmax_lse_bytes = onnxruntime::flash::get_softmax_lse_size(sequence_length, batch_size, parameters.num_heads);
+
     using namespace std;
     auto [num_splits, slse_accum_bytes, o_accum_bytes] = onnxruntime::flash::get_num_splits_and_buffer_sizes(
         parameters.batch_size, parameters.sequence_length, parameters.total_sequence_length, parameters.num_heads,
         parameters.head_size, device_prop.multiProcessorCount);
-    parameters.num_splits = static_cast<int>(num_splits);
+    data.num_splits = static_cast<int>(num_splits);
     softmax_lse_accum_bytes = slse_accum_bytes;
     out_accum_bytes = o_accum_bytes;
   }
+  auto softmax_lse_buffer = GetScratchBuffer<void>(softmax_lse_bytes, context->GetComputeStream());
   auto softmax_lse_accum_buffer = GetScratchBuffer<void>(softmax_lse_accum_bytes, context->GetComputeStream());
   auto out_accum_buffer = GetScratchBuffer<void>(out_accum_bytes, context->GetComputeStream());
 #else
   constexpr bool use_flash_attention = false;
+  auto softmax_lse_buffer = GetScratchBuffer<void>(0, context->GetComputeStream());
   auto softmax_lse_accum_buffer = GetScratchBuffer<void>(0, context->GetComputeStream());  // nullptr
   auto out_accum_buffer = GetScratchBuffer<void>(0, context->GetComputeStream());  // nullptr
 #endif

@@ -247,6 +255,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
   constexpr size_t element_size = sizeof(T);
   constexpr bool use_fused_cross_attention = false;
   constexpr bool use_cudnn_flash_attention = false;
+  constexpr bool use_lean_attention = false;
   size_t workSpaceSize = GetAttentionWorkspaceSize(element_size,
                                                    parameters.batch_size,
                                                    parameters.num_heads,

@@ -257,14 +266,13 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
                                                    parameters.total_sequence_length,
                                                    fused_runner,
                                                    use_flash_attention,
+                                                   use_lean_attention,
                                                    use_fused_cross_attention,
                                                    use_memory_efficient_attention,
                                                    use_cudnn_flash_attention,
                                                    false);
   IAllocatorUniquePtr<void> work_space = IAllocator::MakeUniquePtr<void>(allocator, workSpaceSize, false, context->GetComputeStream());
 
-  typedef typename ToCudaType<T>::MappedType CudaT;
-  AttentionData<CudaT> data;
   data.gemm_buffer = reinterpret_cast<CudaT*>(gemm_buffer.get());
   if (nullptr != bias) {
     data.bias = reinterpret_cast<const CudaT*>(bias->Data<T>());

@@ -289,6 +297,10 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
   data.fused_runner = reinterpret_cast<void*>(fused_runner);
   data.use_flash_attention = use_flash_attention;
   data.use_memory_efficient_attention = use_memory_efficient_attention;
+  if (softmax_lse_buffer != nullptr) {
+    data.softmax_lse = reinterpret_cast<CudaT*>(softmax_lse_buffer.get());
+  }
+
   if (softmax_lse_accum_buffer != nullptr) {
     data.softmax_lse_accum = reinterpret_cast<CudaT*>(softmax_lse_accum_buffer.get());
   }

onnxruntime/contrib_ops/cuda/bert/attention_impl.cu

Lines changed: 93 additions & 3 deletions
@@ -39,6 +39,7 @@ limitations under the License.
 #include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h"
 #include "contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.h"
 #include "contrib_ops/cuda/bert/flash_attention/flash_api.h"
+#include "contrib_ops/cuda/bert/lean_attention/lean_api.h"
 #include "contrib_ops/cuda/bert/attention_impl.h"
 
 using namespace onnxruntime::cuda;

@@ -108,6 +109,7 @@ size_t GetAttentionWorkspaceSize(
     size_t total_sequence_length,
     void* fused_runner,
     bool use_flash_attention,
+    bool use_lean_attention,
     bool use_fused_cross_attention,
     bool use_memory_efficient_attention,
     bool use_cudnn_flash_attention,

@@ -119,12 +121,20 @@
 
 #if USE_FLASH_ATTENTION
   if (use_flash_attention) {
-    return qkv_bytes + onnxruntime::flash::get_softmax_lse_size(sequence_length, batch_size, num_heads);
+    return qkv_bytes;
   }
 #else
   ORT_UNUSED_PARAMETER(use_flash_attention);
 #endif
 
+#if USE_LEAN_ATTENTION
+  if (use_lean_attention) {
+    return qkv_bytes;
+  }
+#else
+  ORT_UNUSED_PARAMETER(use_lean_attention);
+#endif
+
 #if USE_MEMORY_EFFICIENT_ATTENTION
   if (use_memory_efficient_attention) {
     size_t fmha_buffer_bytes = 0;

@@ -301,10 +311,10 @@ Status FlashAttention(
 
   constexpr bool is_bf16 = false;
   ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd(
-      device_prop, stream, data.q, data.k, data.v, data.output, reinterpret_cast<void*>(data.scratch),
+      device_prop, stream, data.q, data.k, data.v, data.output, reinterpret_cast<void*>(data.softmax_lse),
       parameters.batch_size, parameters.num_heads, parameters.num_heads, parameters.head_size,
       parameters.sequence_length, parameters.total_sequence_length, scale, 0.0, parameters.is_unidirectional, is_bf16,
-      false, parameters.num_splits, reinterpret_cast<void*>(data.softmax_lse_accum),
+      false, data.num_splits, reinterpret_cast<void*>(data.softmax_lse_accum),
       reinterpret_cast<void*>(data.out_accum), data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH));
 
   return Status::OK();

@@ -326,6 +336,81 @@
 }
 #endif
 
+#if USE_LEAN_ATTENTION
+template <typename T>
+Status LeanAttention(
+    const cudaDeviceProp& device_prop,
+    cudaStream_t stream,
+    contrib::AttentionParameters& parameters,
+    AttentionData<T>& data,
+    float scale) {
+  assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH ||
+         data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH);
+  assert(nullptr == data.mask_index);
+  assert(nullptr == data.attention_bias);
+  assert(parameters.head_size == parameters.v_head_size);
+
+  constexpr bool is_bf16 = false;
+
+  ORT_RETURN_IF_ERROR(onnxruntime::lean::mha_fwd_kvcache(
+      device_prop, stream,
+      data.q,
+      data.k,  // k_cache
+      data.v,  // v_cache
+      nullptr,  // new_k (we have appended new_k to k_cache)
+      nullptr,  // new_v (we have appended new_v to k_cache)
+      data.output,
+      reinterpret_cast<void*>(data.softmax_lse),
+      nullptr,  // seqlens_k
+      nullptr,  // cos_cache
+      nullptr,  // sin_cache
+      nullptr,  // block_table
+      parameters.batch_size,
+      parameters.num_heads,
+      parameters.num_heads,  // num_heads_k
+      parameters.head_size,
+      parameters.sequence_length,  // seqlen_q
+      parameters.total_sequence_length,  // seqlen_k
+      0,  // seqlen_k_new
+      0,  // rotary_dim
+      scale,  // softmax_scale
+      parameters.is_unidirectional,
+      is_bf16,
+      false,  // past_bsnh
+      data.num_splits,
+      data.grid_dim_z,
+      data.max_tiles_per_tb,
+      data.high_load_tbs,
+      data.tiles_per_head,
+      reinterpret_cast<void*>(data.softmax_lse_accum),
+      reinterpret_cast<void*>(data.out_accum),
+      data.lean_sync_flag,
+      -1,  // local_window_size
+      false,  // is_rotary_interleaved
+      false  // is_packed_qkv
+      ));
+
+  return Status::OK();
+}
+
+template <>
+Status LeanAttention(
+    const cudaDeviceProp& device_prop,
+    cudaStream_t stream,
+    contrib::AttentionParameters& parameters,
+    AttentionData<float>& data,
+    float scale) {
+  ORT_UNUSED_PARAMETER(device_prop);
+  ORT_UNUSED_PARAMETER(stream);
+  ORT_UNUSED_PARAMETER(parameters);
+  ORT_UNUSED_PARAMETER(data);
+  ORT_UNUSED_PARAMETER(scale);
+  return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, "lean attention does not support float tensor");
+}
+#endif
+
+
+
 template <typename T>
 Status CudnnFlashAttention(
     cudnnHandle_t cudnn_handle,

@@ -641,6 +726,11 @@ Status QkvToContext(
   // For raw attention mask, the scalar 1/sqrt(H) is moved to combine with softmax computation.
   const float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast<float>(qk_head_size))
                                                : parameters.scale;
+#if USE_LEAN_ATTENTION
+  if (data.use_lean_attention) {
+    return LeanAttention(device_prop, stream, parameters, data, scale);
+  }
+#endif
 
 #if USE_FLASH_ATTENTION
   if (data.use_flash_attention) {

onnxruntime/contrib_ops/cuda/bert/attention_impl.h

Lines changed: 15 additions & 0 deletions
@@ -53,6 +53,7 @@ size_t GetAttentionWorkspaceSize(
     size_t total_sequence_length,
     void* fused_runner,
     bool use_flash_attention,
+    bool use_lean_attention,
     bool use_fused_cross_attention,
     bool use_memory_efficient_attention,
     bool use_cudnn_flash_attention,

@@ -102,6 +103,19 @@ struct AttentionData {
   T* softmax_lse_accum = nullptr;
   T* out_accum = nullptr;
 
+  // Flash Atttention and Lean Attention
+  int num_splits;
+
+  // Lean Attention
+  bool use_lean_attention = false;
+#if USE_LEAN_ATTENTION
+  int grid_dim_z = 0;
+  int max_tiles_per_tb = 0;
+  int high_load_tbs = 0;
+  int tiles_per_head = 0;
+  int* lean_sync_flag = nullptr;
+#endif
+
   // For Debugging
   size_t workspace_bytes = 0;
   bool allow_debug_info = false;

@@ -115,6 +129,7 @@ struct AttentionData {
 
   void PrintDebugInfo() const {
     std::cout << "flash=" << use_flash_attention
+              << ", lean=" << use_lean_attention
               << ", efficient=" << use_memory_efficient_attention
               << ", fused_runner=" << (fused_runner != nullptr)
               << ", fused_cross=" << (fused_cross_attention_kernel != nullptr)

onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc

Lines changed: 17 additions & 0 deletions
@@ -17,6 +17,9 @@ namespace onnxruntime {
 void AttentionKernelOptions::Initialize(int value, bool use_build_flag, bool check_cudnn_version) {
   if (value > 0) {
     use_flash_attention_ = (value & static_cast<int>(AttentionBackend::FLASH_ATTENTION)) > 0;
+#if USE_LEAN_ATTENTION
+    use_lean_attention_ = (value & static_cast<int>(AttentionBackend::LEAN_ATTENTION)) > 0;
+#endif
     use_efficient_attention_ = (value & static_cast<int>(AttentionBackend::EFFICIENT_ATTENTION)) > 0;
     use_trt_fused_attention_ = (value & static_cast<int>(AttentionBackend::TRT_FUSED_ATTENTION)) > 0;
     use_cudnn_flash_attention_ = (value & static_cast<int>(AttentionBackend::CUDNN_FLASH_ATTENTION)) > 0;

@@ -26,6 +29,9 @@ void AttentionKernelOptions::Initialize(int value, bool use_build_flag, bool che
     use_trt_causal_attention_ = (value & static_cast<int>(AttentionBackend::TRT_CAUSAL_ATTENTION)) > 0;
   } else {
     use_flash_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableFlashAttention, false);
+#if USE_LEAN_ATTENTION
+    use_lean_attention_ = ParseEnvironmentVariableWithDefault<bool>(kEnableLeanAttention, false);
+#endif
     use_efficient_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableMemoryEfficientAttention, false);
     use_trt_fused_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableFusedSelfAttention, false);
     use_cudnn_flash_attention_ = ParseEnvironmentVariableWithDefault<bool>(kEnableCudnnFlashAttention, false);

@@ -61,6 +67,10 @@ void AttentionKernelOptions::Initialize(int value, bool use_build_flag, bool che
   use_flash_attention_ = false;
 #endif
 
+#ifndef USE_LEAN_ATTENTION
+  use_lean_attention_ = false;
+#endif
+
 #ifndef USE_MEMORY_EFFICIENT_ATTENTION
   use_efficient_attention_ = false;
 #endif

@@ -81,6 +91,9 @@ void AttentionKernelOptions::Print() const {
   std::stringstream sstream;
   sstream << "AttentionKernelOptions:";
   sstream << " FLASH_ATTENTION=" << int(use_flash_attention_);
+#if USE_LEAN_ATTENTION
+  sstream << " LEAN_ATTENTION=" << int(use_lean_attention_);
+#endif
   sstream << " EFFICIENT_ATTENTION=" << int(use_efficient_attention_);
   sstream << " TRT_FUSED_ATTENTION=" << int(use_trt_fused_attention_);
   sstream << " CUDNN_FLASH_ATTENTION=" << int(use_cudnn_flash_attention_);

@@ -131,6 +144,10 @@ void AttentionKernelDebugInfo::Print(const char* operator_name,
   sstream << " SdpaKernel=";
   if (use_flash_attention.has_value() && use_flash_attention.value()) {
     sstream << "FLASH_ATTENTION";
+#if USE_LEAN_ATTENTION
+  } else if (use_lean_attention.has_value() && use_lean_attention.value()) {
+    sstream << "LEAN_ATTENTION";
+#endif
   } else if (use_efficient_attention.has_value() && use_efficient_attention.value()) {
     sstream << "EFFICIENT_ATTENTION";
   } else if (use_trt_fused_attention.has_value() && use_trt_fused_attention.value()) {
