diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
index 695cc2200fe7a..187f1bb37edc5 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
@@ -417,10 +417,10 @@ Status QkvToContext(
   const bool past_present_share_buffer = parameters.past_present_share_buffer;
   const float mask_filter_value = parameters.mask_filter_value;
   void* fused_runner = data.fused_runner;
-  bool use_memory_efficient_attention = data.use_memory_efficient_attention;

   // At most one fused kernel is enabled.
-  assert(int(use_memory_efficient_attention) + int(fused_runner != nullptr) + int(data.fused_cross_attention_kernel != nullptr) <= 1);
+  assert(int(data.use_memory_efficient_attention) + int(fused_runner != nullptr) +
+             int(data.fused_cross_attention_kernel != nullptr) <= 1);

   const int batches = batch_size * num_heads;
   const int size_per_batch_q = sequence_length * qk_head_size;
@@ -469,7 +469,7 @@ Status QkvToContext(
     assert(data.fused_cross_attention_kernel == nullptr);
     assert(!use_fused_kernel);
     assert(data.gemm_buffer != nullptr);
-    assert(!use_memory_efficient_attention);
+    assert(!data.use_memory_efficient_attention);

     if (data.present != data.past) {
       // For easy testing. Production should better avoid this path.
@@ -564,7 +564,7 @@ Status QkvToContext(
   }

 #if USE_FLASH_ATTENTION
-  if (use_memory_efficient_attention) {
+  if (data.use_memory_efficient_attention) {
     // We only enable fused cross attention when there is no key padding mask.
     // Otherwise, key have effective batch size 2 * batch_size, which is different from batch_size of query.
     assert(data.mask_index == nullptr);