From d93e6533b7bb9ba32bd5c4f8cc976ec37b838de1 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Mon, 22 Aug 2022 10:18:52 -0700 Subject: [PATCH] Format bert or transformers code (#12646) (1) Modify some lines to fit line length limit 120 (2) Adjust parameter order of LaunchAttentionKernel (3) Format code with Clang-Format in VS Code (4) Fix spelling errors --- onnxruntime/contrib_ops/cpu/bert/attention.cc | 120 +++-- .../contrib_ops/cpu/bert/attention_base.h | 11 +- .../contrib_ops/cpu/bert/attention_cpu_base.h | 61 ++- .../contrib_ops/cpu/bert/attention_helper.h | 11 +- .../contrib_ops/cpu/bert/bias_gelu_helper.h | 2 +- .../contrib_ops/cpu/bert/embed_layer_norm.cc | 113 ++--- .../cpu/bert/embed_layer_norm_helper.cc | 2 +- .../contrib_ops/cpu/bert/ngram_repeat_block.h | 14 +- .../cpu/transformers/generation_shared.h | 1 - .../cpu/transformers/greedy_search.cc | 58 +-- .../cpu/transformers/greedy_search_impl_gpt.h | 28 +- .../transformers/greedy_search_parameters.cc | 3 +- .../cpu/transformers/subgraph_t5_decoder.h | 4 +- .../cpu/transformers/subgraph_t5_encoder.h | 4 +- .../contrib_ops/cuda/bert/attention.cc | 34 +- .../contrib_ops/cuda/bert/attention_concat.cu | 91 ++-- .../contrib_ops/cuda/bert/attention_impl.cu | 420 ++++++++++-------- .../contrib_ops/cuda/bert/attention_impl.h | 82 ++-- .../contrib_ops/cuda/bert/attention_softmax.h | 197 ++++++-- .../cuda/bert/attention_transpose.cu | 2 - .../cuda/bert/decoder_attention.cc | 25 +- .../contrib_ops/cuda/bert/embed_layer_norm.cc | 6 +- .../cuda/bert/embed_layer_norm_impl.cu | 10 +- .../cuda/bert/embed_layer_norm_impl.h | 34 +- .../contrib_ops/cuda/bert/fast_gelu.cc | 2 +- onnxruntime/contrib_ops/cuda/bert/fast_gelu.h | 2 +- .../contrib_ops/cuda/bert/fast_gelu_impl.cu | 25 +- .../contrib_ops/cuda/bert/fast_gelu_impl.h | 3 +- .../contrib_ops/cuda/bert/layer_norm.cuh | 16 +- .../cuda/bert/longformer_attention_impl.cu | 4 +- .../cuda/bert/longformer_attention_impl.h | 3 +- .../cuda/bert/longformer_attention_softmax.cu | 23 +- .../cuda/bert/ngram_repeat_block.cc | 2 +- .../cuda/bert/ngram_repeat_block.h | 1 + .../cuda/bert/ngram_repeat_block_impl.cu | 4 +- .../contrib_ops/cuda/bert/skip_layer_norm.cc | 2 +- .../cuda/bert/skip_layer_norm_impl.cu | 41 +- .../cuda/bert/skip_layer_norm_impl.h | 20 +- onnxruntime/contrib_ops/cuda/layer_norm.cc | 6 +- .../contrib_ops/cuda/layer_norm_impl.cu | 8 +- .../quantization/attention_quantization.cc | 21 +- .../transformers/generation_device_helper.cc | 24 +- .../contrib_ops/rocm/bert/attention.cc | 25 +- .../contrib_ops/rocm/bert/attention_impl.cu | 367 ++++++++------- .../contrib_ops/rocm/bert/attention_impl.h | 197 ++++---- .../contrib_ops/rocm/bert/attention_softmax.h | 187 ++++++-- 46 files changed, 1371 insertions(+), 945 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/attention.cc b/onnxruntime/contrib_ops/cpu/bert/attention.cc index 1234fb019c930..7af0d9d7be2e8 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/attention.cc @@ -117,7 +117,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, } if (hidden_size % num_heads_ != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "hidden_size should be divisiable by num_heads."); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "hidden_size should be divisible by num_heads."); } } else { int qkv_sizes = 0; @@ -129,12 +129,13 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, if (qkv_hidden_sizes_[0] != qkv_hidden_sizes_[1]) { return 
ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "qkv_hidden_sizes first element should be same as the second"); + "qkv_hidden_sizes first element should be same as the second"); } for (size_t i = 0; i < qkv_hidden_sizes_.size(); i++) { if (qkv_hidden_sizes_[i] % num_heads_ != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "hidden_size should be divisiable by num_heads:", qkv_hidden_sizes_[i]); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "hidden_size should be divisible by num_heads:", qkv_hidden_sizes_[i]); } qkv_sizes += static_cast(qkv_hidden_sizes_[i]); @@ -164,13 +165,16 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'past' dimension 0 shall have length of 2"); } if (static_cast(past_dims[1]) != batch_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'past' dimension 1 shall have same length as dimension 0 of input 0"); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Inputs 'past' dimension 1 shall have same length as dimension 0 of input 0"); } if (static_cast(past_dims[2]) != num_heads_) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'past' dimension 2 shall have length of num_heads", num_heads_); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Inputs 'past' dimension 2 shall have length of num_heads", num_heads_); } if (static_cast(past_dims[4]) != hidden_size / num_heads_) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'past' dimension 2 shall have length of ", hidden_size / num_heads_); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Inputs 'past' dimension 2 shall have length of ", hidden_size / num_heads_); } past_sequence_length = static_cast(past_dims[3]); } @@ -179,31 +183,50 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, const auto& mask_dims = mask_index->Shape().GetDims(); if (mask_dims.size() == 1) { if (static_cast(mask_dims[0]) != batch_size && static_cast(mask_dims[0]) != 2 * batch_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'mask_index' with 1D data shall have length of batch_size or 2 * batch_size"); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Inputs 'mask_index' with 1D data shall have length of batch_size or 2 * batch_size"); } } else if (mask_dims.size() == 2) { - if (static_cast(mask_dims[0]) != batch_size || static_cast(mask_dims[1]) != past_sequence_length + sequence_length) { + if (static_cast(mask_dims[0]) != batch_size || + static_cast(mask_dims[1]) != past_sequence_length + sequence_length) { // Add operator supports broadcasting. Here we handle a case with only one element in the 2nd dimension. - if ((static_cast(mask_dims[0]) == batch_size || static_cast(mask_dims[0]) == 1) && static_cast(mask_dims[1]) == 1) { - // Mask will have same value after propogation, which has same effect as no mask. + if ((static_cast(mask_dims[0]) == batch_size || static_cast(mask_dims[0]) == 1) && + static_cast(mask_dims[1]) == 1) { + // Mask will have same value after propagation, which has same effect as no mask. 
mask_index = nullptr; } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'mask_index' with 2D data shall have shape batch_size x (past_sequence_length + sequence_length)"); + return ORT_MAKE_STATUS( + ONNXRUNTIME, INVALID_ARGUMENT, + "Inputs 'mask_index' with 2D data shall have shape " + "batch_size x (past_sequence_length + sequence_length)"); } } } else if (mask_dims.size() == 3) { - if (static_cast(mask_dims[0]) != batch_size || mask_dims[1] != sequence_length || static_cast(mask_dims[2]) != past_sequence_length + sequence_length) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'mask_index' with 3D data shall have shape batch_size x sequence_length x (past_sequence_length + sequence_length)"); + if (static_cast(mask_dims[0]) != batch_size || + mask_dims[1] != sequence_length || + static_cast(mask_dims[2]) != past_sequence_length + sequence_length) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, INVALID_ARGUMENT, + "Inputs 'mask_index' with 3D data shall have shape " + "batch_size x sequence_length x (past_sequence_length + sequence_length)"); } } else if (mask_dims.size() == 4) { - if (static_cast(mask_dims[0]) != batch_size || mask_dims[1] != 1 || mask_dims[2] != mask_dims[3] || mask_dims[2] < static_cast(past_sequence_length) + sequence_length) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'mask_index' with 4D data shall have shape batch_size x 1 x max_sequence_length x max_sequence_length)"); + if (static_cast(mask_dims[0]) != batch_size || + mask_dims[1] != 1 || + mask_dims[2] != mask_dims[3] || + mask_dims[2] < static_cast(past_sequence_length) + sequence_length) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, INVALID_ARGUMENT, + "Inputs 'mask_index' with 4D data shall have shape " + "batch_size x 1 x max_sequence_length x max_sequence_length)"); } if (is_unidirectional_ == true) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'mask_index' with 4D data shall have is_unidirectional_ set to false"); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Inputs 'mask_index' with 4D data shall have is_unidirectional_ set to false"); } } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'mask_index' is expected to have 1, 2, 3 or 4 dimensions, got ", + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'mask_index' is expected to have 1, 2, 3 or 4 dimensions, got ", mask_dims.size()); } } @@ -212,24 +235,29 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, const auto& extra_add_qk_dims = extra_add_qk->Shape().GetDims(); if (extra_add_qk_dims.size() != 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'extra_add_qk' is expected to have 4 dimensions, got ", + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'extra_add_qk' is expected to have 4 dimensions, got ", extra_add_qk_dims.size()); } if (extra_add_qk_dims[0] != batch_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'extra_add_qk' dimension 0 should be same as batch_size, got ", + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'extra_add_qk' dimension 0 should be same as batch_size, got ", extra_add_qk_dims[0]); } if (extra_add_qk_dims[1] != num_heads_) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'extra_add_qk' dimension 1 should be same as number of heads, got ", + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'extra_add_qk' dimension 1 should be same as number of heads, got ", extra_add_qk_dims[1]); } if 
(extra_add_qk_dims[2] != sequence_length) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'extra_add_qk' dimension 2 should be same as sequence_length, got ", + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'extra_add_qk' dimension 2 should be same as sequence_length, got ", extra_add_qk_dims[2]); } if (extra_add_qk_dims[3] != sequence_length) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'extra_add_qk' dimension 3 should be same as sequence_length, got ", + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'extra_add_qk' dimension 3 should be same as sequence_length, got ", extra_add_qk_dims[3]); } } @@ -322,7 +350,6 @@ template Status Attention::PrePack(const Tensor& weights, int input_idx, AllocatorPtr alloc, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { - /* The PrePack() massages the weights to speed up Compute(), there is an option to * use shared prepacked weights in which case prepacked_weights parameter would be non-null. * @@ -375,9 +402,14 @@ Status Attention::PrePack(const Tensor& weights, int input_idx, AllocatorPtr const size_t qkv_head_size[3] = {q_hidden_size / num_heads_, k_hidden_size / num_heads_, v_hidden_size / num_heads_}; const size_t weight_matrix_col_size = q_hidden_size + k_hidden_size + v_hidden_size; - if (!IsPackWeightsSuccessful(0, alloc, qkv_head_size[0], input_hidden_size, weights_data, weight_matrix_col_size, prepacked_weights) || - !IsPackWeightsSuccessful(1, alloc, qkv_head_size[1], input_hidden_size, weights_data + (num_heads_ * qkv_head_size[0]), weight_matrix_col_size, prepacked_weights) || - !IsPackWeightsSuccessful(2, alloc, qkv_head_size[2], input_hidden_size, weights_data + (num_heads_ * (qkv_head_size[0] + qkv_head_size[1])), weight_matrix_col_size, prepacked_weights)) { + if (!IsPackWeightsSuccessful(0, alloc, qkv_head_size[0], input_hidden_size, + weights_data, weight_matrix_col_size, prepacked_weights) || + !IsPackWeightsSuccessful(1, alloc, qkv_head_size[1], input_hidden_size, + weights_data + (num_heads_ * qkv_head_size[0]), + weight_matrix_col_size, prepacked_weights) || + !IsPackWeightsSuccessful(2, alloc, qkv_head_size[2], input_hidden_size, + weights_data + (num_heads_ * (qkv_head_size[0] + qkv_head_size[1])), + weight_matrix_col_size, prepacked_weights)) { if (prepacked_weights == nullptr) { FreePackedWeights(packed_weights_, qkv_hidden_sizes_.size()); } @@ -469,7 +501,8 @@ Status Attention::Compute(OpKernelContext* context) const { // gemm_data(BS, NT) = input(BS, D) x weights(D, NT) + bias(NT) // D (input_hidden_size) is hidden dimension of input, where D could be larger than any of the hidden_sizes // (NH) when model is pruned. 
T = H1 + H2 + H3, where H1, H2, H3 are head sizes of Q, K, V respectively - auto gemm_data = allocator->Alloc(SafeInt(batch_size) * sequence_length * (q_hidden_size + k_hidden_size + v_hidden_size) * element_size); + int qkv_hidden_size = (q_hidden_size + k_hidden_size + v_hidden_size); + auto gemm_data = allocator->Alloc(SafeInt(batch_size) * sequence_length * qkv_hidden_size * element_size); BufferUniquePtr gemm_buffer(gemm_data, BufferDeleter(std::move(allocator))); auto Q = reinterpret_cast(gemm_data); @@ -523,12 +556,13 @@ Status Attention::Compute(OpKernelContext* context) const { // C: QKV[qkv_index] (BxNxSxT) (B.N.)S x T S x H if (is_prepack_) { uint8_t* packed_weight; - packed_weight = static_cast(packed_weights_[qkv_index].get()) + packed_weights_size_[qkv_index] * (weights_offset / head_size); + packed_weight = static_cast(packed_weights_[qkv_index].get()) + + packed_weights_size_[qkv_index] * (weights_offset / head_size); MlasGemm( CblasNoTrans, // TransA = no sequence_length, // M = S - head_size, // N = H + head_size, // N = H input_hidden_size, // K = D 1.0f, // alpha input_data + input_offset, // A @@ -540,20 +574,20 @@ Status Attention::Compute(OpKernelContext* context) const { nullptr); // use single-thread } else { math::GemmEx( - CblasNoTrans, // TransA = no - CblasNoTrans, // TransB = no - sequence_length, // M = S - head_size, // N = H - input_hidden_size, // K = D - 1.0f, // alpha - input_data + input_offset, // A - input_hidden_size, // lda = D - weights_data + weights_offset, // B - q_hidden_size + k_hidden_size + v_hidden_size,// ldb = NH1 + NH2 + NH3 - 1.0f, // beta - qkv_dest + qkv_offset, // C - head_size, // ldc - nullptr // use single-thread + CblasNoTrans, // TransA = no + CblasNoTrans, // TransB = no + sequence_length, // M = S + head_size, // N = H + input_hidden_size, // K = D + 1.0f, // alpha + input_data + input_offset, // A + input_hidden_size, // lda = D + weights_data + weights_offset, // B + q_hidden_size + k_hidden_size + v_hidden_size, // ldb = NH1 + NH2 + NH3 + 1.0f, // beta + qkv_dest + qkv_offset, // C + head_size, // ldc + nullptr // use single-thread ); } } diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_base.h index 0855cb33773da..feee7e9d6d0fd 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_base.h @@ -3,6 +3,7 @@ #pragma once +#include #include "core/common/common.h" #include "core/framework/op_kernel.h" @@ -17,7 +18,7 @@ class AttentionBase { const TensorShape& bias_shape, const Tensor*& mask_index, // For dummy mask with shape (1, 1) or (batch_size, 1), it will be updated to nullptr. const Tensor* past, - const Tensor *extra_add_qk, + const Tensor* extra_add_qk, const int max_threads_per_block) const; Tensor* GetPresent(OpKernelContext* context, @@ -45,11 +46,11 @@ class AttentionBase { const TensorShape& bias_shape, const Tensor*& mask_index, // For dummy mask with shape (1, 1) or (batch_size, 1), it will be updated to nullptr. const Tensor* past, - const Tensor *extra_add_qk) const; + const Tensor* extra_add_qk) const; - int num_heads_; // number of attention heads - bool is_unidirectional_; // whether every token can only attend to previous tokens. - std::vector qkv_hidden_sizes_; // Q, K, V path hidden layer sizes + int num_heads_; // number of attention heads + bool is_unidirectional_; // whether every token can only attend to previous tokens. 
+ std::vector qkv_hidden_sizes_; // Q, K, V path hidden layer sizes }; } // namespace contrib diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h index 69e2017601dd7..1d08638a7c0f3 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h @@ -9,12 +9,7 @@ #include "core/common/common.h" #include "core/common/safeint.h" #include "core/framework/op_kernel.h" -//TODO: fix the warnings -#if defined(_MSC_VER) && !defined(__clang__) -#pragma warning(push) -// Chance of arithmetic overflow could be reduced -#pragma warning(disable : 26451) -#endif + namespace onnxruntime { namespace contrib { @@ -51,8 +46,8 @@ class AttentionCPUBase : public AttentionBase { // I. attention_probs(B, N, S, S*) = 1/sqrt(H) x Q(B, N, S, H) x K'(B, N, S*, H -> B, N, H, S*) + // 1 x mask_data(B, N, S, S*) // II.attention_probs(B, N, S, S*) = Softmax(attention_probs) - size_t attention_probs_bytes = SafeInt(batch_size) * num_heads_ * sequence_length * all_sequence_length * sizeof(T); - auto attention_probs = allocator->Alloc(attention_probs_bytes); + size_t bytes = SafeInt(batch_size) * num_heads_ * sequence_length * all_sequence_length * sizeof(T); + auto attention_probs = allocator->Alloc(bytes); BufferUniquePtr scratch_buffer(attention_probs, BufferDeleter(allocator)); bool has_unidirectional = (is_unidirectional_ && sequence_length > 1); @@ -66,7 +61,9 @@ class AttentionCPUBase : public AttentionBase { BufferUniquePtr mask_data_buffer(mask_data, BufferDeleter(allocator)); const int32_t* mask_index_data = mask_index != nullptr ? mask_index->Data() : nullptr; - gsl::span mask_index_dims = mask_index != nullptr ? mask_index->Shape().GetDims() : gsl::span{}; + gsl::span mask_index_dims = mask_index != nullptr + ? mask_index->Shape().GetDims() + : gsl::span{}; const T* past_data = past != nullptr ? past->Data() : nullptr; T* present_data = present != nullptr ? present->MutableData() : nullptr; @@ -77,7 +74,8 @@ class AttentionCPUBase : public AttentionBase { ComputeAttentionProbs(static_cast(attention_probs), Q, K, mask_index_data, mask_index_dims, static_cast(mask_data), has_unidirectional, - batch_size, sequence_length, past_sequence_length, qk_head_size == 0 ? v_head_size : qk_head_size, + batch_size, sequence_length, past_sequence_length, + qk_head_size == 0 ? v_head_size : qk_head_size, past_data, present_data, tp, extra_add_qk_data); // Compute the attentionScore * Value. It does: out_tmp(B, N, S, H) = attention_probs(B, N, S, S*) x V(B, N, S*, H) @@ -85,7 +83,8 @@ class AttentionCPUBase : public AttentionBase { allocator->Alloc(SafeInt(batch_size) * num_heads_ * sequence_length * v_head_size * sizeof(T)); BufferUniquePtr out_tmp_buffer(out_tmp_data, BufferDeleter(std::move(allocator))); - ComputeVxAttentionScore(output->MutableData(), static_cast(out_tmp_data), static_cast(attention_probs), V, + ComputeVxAttentionScore(output->MutableData(), static_cast(out_tmp_data), + static_cast(attention_probs), V, batch_size, sequence_length, past_sequence_length, v_head_size, v_hidden_size, past_data, present_data, tp); @@ -98,21 +97,21 @@ class AttentionCPUBase : public AttentionBase { // 1 x mask_data(B, N, S, S*) // II.attention_probs(B, N, S, S*) = Softmax(attention_probs) template - void ComputeAttentionProbs(T* attention_probs, // output buffer for the attention probs. Its size is BxNxSxS* - const T* Q, // Q data. Its size is BxNxSxH - const T* K, // k data. 
Its size is BxNxSxH - const int32_t* mask_index, // mask index. nullptr if no mask or its size is B - gsl::span mask_index_dims, // mask index shape - T* mask_data, // buffer for mask data. It is nullptr if mask_index is nullptr and not unidirectional, otherwise its shape is BxSxS* - bool has_unidirectional, // has unidirectional mask - int batch_size, // batch size of self-attention - int sequence_length, // sequence length of self-attention - int past_sequence_length, // sequence length of past state - int head_size, // head size of self-attention - const T* past, // past state - T* present, // present state - ThreadPool* tp, // thread pool - const T* extra_add_qk_data // extra add matrix with shape BxNxSxS* + void ComputeAttentionProbs(T* attention_probs, // output buffer with size BxNxSxS* + const T* Q, // Q data. Its size is BxNxSxH + const T* K, // k data. Its size is BxNxSxH + const int32_t* mask_index, // mask index. nullptr if no mask. + gsl::span mask_index_dims, // mask index shape + T* mask_data, // buffer for mask data. + bool has_unidirectional, // has unidirectional mask + int batch_size, // batch size of self-attention + int sequence_length, // sequence length of self-attention + int past_sequence_length, // sequence length of past state + int head_size, // head size of self-attention + const T* past, // past state + T* present, // present state + ThreadPool* tp, // thread pool + const T* extra_add_qk_data // extra add matrix with shape BxNxSxS* ) const { const int all_sequence_length = past_sequence_length + sequence_length; // S* = S' + S const size_t past_chunk_length = static_cast(past_sequence_length) * head_size; // S' x H @@ -120,10 +119,13 @@ class AttentionCPUBase : public AttentionBase { const size_t present_chunk_length = past_chunk_length + input_chunk_length; // S* x H { + // mask_data is nullptr when mask_index is nullptr and not unidirectional, otherwise its shape is BxSxS* if (mask_data != nullptr) { - PrepareMask(mask_index, mask_index_dims, mask_data, has_unidirectional, batch_size, sequence_length, past_sequence_length); + PrepareMask(mask_index, mask_index_dims, mask_data, + has_unidirectional, batch_size, sequence_length, past_sequence_length); } else { // no any mask - memset(attention_probs, 0, static_cast(batch_size) * num_heads_ * sequence_length * all_sequence_length * sizeof(T)); + size_t bytes = static_cast(batch_size) * num_heads_ * sequence_length * all_sequence_length * sizeof(T); + memset(attention_probs, 0, bytes); } const int loop_len = batch_size * num_heads_; @@ -247,6 +249,3 @@ class AttentionCPUBase : public AttentionBase { } // namespace contrib } // namespace onnxruntime -#if defined(_MSC_VER) && !defined(__clang__) -#pragma warning(pop) -#endif diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h index 6e21d82d6a972..9b05f30ce8f8d 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h @@ -99,7 +99,9 @@ void PrepareMask(const int32_t* mask_index, } bool is_raw_attention_mask = (nullptr != mask_index && mask_index_dims.size() == 2); - bool has_mask_start_position = (nullptr != mask_index && mask_index_dims.size() == 1 && static_cast(mask_index_dims.at(0)) == 2 * batch_size); + bool has_mask_start_position = (nullptr != mask_index && + mask_index_dims.size() == 1 && + static_cast(mask_index_dims.at(0)) == 2 * batch_size); for (int b_i = 0; b_i < batch_size; b_i++) { // TODO: mask_index can be used in 
softmax to save some calculation. @@ -150,7 +152,12 @@ void PrepareMask(const int32_t* mask_index, // Concatenate a past state chunk S'xH with input state chunk SxH into present state chunk S*xH // Returns a pointer to the start of present state chunk. template -T* ConcatStateChunk(const T* past, const T* chunk, T* present, size_t past_chunk_length, size_t present_chunk_length, std::ptrdiff_t i) { +T* ConcatStateChunk(const T* past, + const T* chunk, + T* present, + size_t past_chunk_length, + size_t present_chunk_length, + std::ptrdiff_t i) { T* start = present + i * present_chunk_length; T* p = start; diff --git a/onnxruntime/contrib_ops/cpu/bert/bias_gelu_helper.h b/onnxruntime/contrib_ops/cpu/bert/bias_gelu_helper.h index 2500dd0f3592e..86552f888833e 100644 --- a/onnxruntime/contrib_ops/cpu/bert/bias_gelu_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/bias_gelu_helper.h @@ -12,6 +12,6 @@ namespace bias_gelu_helper { Status CheckInputs(const OpKernelContext* context); -} +} // namespace bias_gelu_helper } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/bert/embed_layer_norm.cc b/onnxruntime/contrib_ops/cpu/bert/embed_layer_norm.cc index 44b972e89749a..fdc5573361ad9 100644 --- a/onnxruntime/contrib_ops/cpu/bert/embed_layer_norm.cc +++ b/onnxruntime/contrib_ops/cpu/bert/embed_layer_norm.cc @@ -38,7 +38,6 @@ EmbedLayerNorm::EmbedLayerNorm(const OpKernelInfo& op_kernel_info) : EmbedLayerNormBase(op_kernel_info) { } - template Status EmbedLayerNorm::Compute(OpKernelContext* context) const { ORT_RETURN_IF_ERROR(embed_layer_norm::CheckInputs(context)); @@ -49,8 +48,8 @@ Status EmbedLayerNorm::Compute(OpKernelContext* context) const { const Tensor* segment_embedding = context->Input(4); // optional. nullptr if it's distill-bert const Tensor* gamma = context->Input(5); const Tensor* beta = context->Input(6); - const Tensor* mask = context->Input(7); // optional. nullptr if not provided - const Tensor* position_ids = context->Input(8); // optional. nullptr if not provided + const Tensor* mask = context->Input(7); // optional. nullptr if not provided + const Tensor* position_ids = context->Input(8); // optional. nullptr if not provided const auto& input_dims = input_ids->Shape().GetDims(); int64_t hidden_size = word_embedding->Shape()[1]; @@ -86,58 +85,62 @@ Status EmbedLayerNorm::Compute(OpKernelContext* context) const { std::atomic_bool failed{false}; int n = batch_size * sequence_length; - concurrency::ThreadPool::TryBatchParallelFor(context->GetOperatorThreadPool(), n, [=, &failed](ptrdiff_t index) { - int word_col_index = input_ids_data[index]; - if (word_col_index < 0 || word_col_index >= word_embedding_length) { - failed.store(true, std::memory_order_release); - return; - } - int position_col_index = (position_ids_data == nullptr) ? 
index % sequence_length : position_ids_data[index]; - if (position_col_index >= position_embedding_length) { - failed.store(true, std::memory_order_release); - return; - } - int segment_col_index = 0; - if (nullptr != segment_ids_data) { - segment_col_index = segment_ids_data[index]; - if (segment_col_index < 0 || segment_col_index >= segment_embedding_length) { - failed.store(true, std::memory_order_release); - return; - } - } - - T* y = output_data + index * hidden_size; - T* y1 = nullptr; - if (embedding_sum_data != nullptr) { - y1 = embedding_sum_data + index * hidden_size; - } - const T* input_word_embedding = word_embedding_data + word_col_index * hidden_size; - const T* input_position_embedding = position_embedding_data + position_col_index * hidden_size; - const T* input_segment_embedding = (nullptr == segment_embedding_data) ? nullptr : segment_embedding_data + segment_col_index * hidden_size; - - T sum = static_cast(0); - for (int i = 0; i < hidden_size; i++) { - T subtotal = input_word_embedding[i] + input_position_embedding[i]; - if (nullptr != segment_embedding_data) - subtotal += input_segment_embedding[i]; - y[i] = subtotal; - if (y1 != nullptr) { - y1[i] = subtotal; - } - sum += subtotal; - } - T mean = sum / hidden_size; - sum = 0; - for (int i = 0; i < hidden_size; i++) { - T a = y[i] - mean; - y[i] = a; - sum += a * a; - } - T e = sqrt(sum / hidden_size + static_cast(epsilon())); - for (int i = 0; i < hidden_size; i++) { - y[i] = y[i] / e * gamma_data[i] + beta_data[i]; - } - }, 0); + concurrency::ThreadPool::TryBatchParallelFor( + context->GetOperatorThreadPool(), n, [=, &failed](ptrdiff_t index) { + int word_col_index = input_ids_data[index]; + if (word_col_index < 0 || word_col_index >= word_embedding_length) { + failed.store(true, std::memory_order_release); + return; + } + int position_col_index = (position_ids_data == nullptr) ? index % sequence_length : position_ids_data[index]; + if (position_col_index >= position_embedding_length) { + failed.store(true, std::memory_order_release); + return; + } + int segment_col_index = 0; + if (nullptr != segment_ids_data) { + segment_col_index = segment_ids_data[index]; + if (segment_col_index < 0 || segment_col_index >= segment_embedding_length) { + failed.store(true, std::memory_order_release); + return; + } + } + + T* y = output_data + index * hidden_size; + T* y1 = nullptr; + if (embedding_sum_data != nullptr) { + y1 = embedding_sum_data + index * hidden_size; + } + const T* input_word_embedding = word_embedding_data + word_col_index * hidden_size; + const T* input_position_embedding = position_embedding_data + position_col_index * hidden_size; + const T* input_segment_embedding = (nullptr == segment_embedding_data) + ? 
nullptr + : segment_embedding_data + segment_col_index * hidden_size; + + T sum = static_cast(0); + for (int i = 0; i < hidden_size; i++) { + T subtotal = input_word_embedding[i] + input_position_embedding[i]; + if (nullptr != segment_embedding_data) + subtotal += input_segment_embedding[i]; + y[i] = subtotal; + if (y1 != nullptr) { + y1[i] = subtotal; + } + sum += subtotal; + } + T mean = sum / hidden_size; + sum = 0; + for (int i = 0; i < hidden_size; i++) { + T a = y[i] - mean; + y[i] = a; + sum += a * a; + } + T e = sqrt(sum / hidden_size + static_cast(epsilon())); + for (int i = 0; i < hidden_size; i++) { + y[i] = y[i] / e * gamma_data[i] + beta_data[i]; + } + }, + 0); if (failed.load(std::memory_order_acquire)) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "input index out of range"); diff --git a/onnxruntime/contrib_ops/cpu/bert/embed_layer_norm_helper.cc b/onnxruntime/contrib_ops/cpu/bert/embed_layer_norm_helper.cc index e5c49a6c0d5fb..b937912af3250 100644 --- a/onnxruntime/contrib_ops/cpu/bert/embed_layer_norm_helper.cc +++ b/onnxruntime/contrib_ops/cpu/bert/embed_layer_norm_helper.cc @@ -23,7 +23,7 @@ Status CheckInputs(const OpKernelContext* context, bool quantizedVersion) { const Tensor* mask = context->Input(7); // optional. nullptr if not provided if (!quantizedVersion) { - const Tensor* position_ids = context->Input(8); // optional. nullptr if not provided + const Tensor* position_ids = context->Input(8); // optional. nullptr if not provided if (nullptr != position_ids && input_ids->Shape() != position_ids->Shape()) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, diff --git a/onnxruntime/contrib_ops/cpu/bert/ngram_repeat_block.h b/onnxruntime/contrib_ops/cpu/bert/ngram_repeat_block.h index 509a35fc33b26..8c1f75e20a698 100644 --- a/onnxruntime/contrib_ops/cpu/bert/ngram_repeat_block.h +++ b/onnxruntime/contrib_ops/cpu/bert/ngram_repeat_block.h @@ -68,16 +68,16 @@ class NGramRepeatBlock : public OpKernel { concurrency::ThreadPool* tp = context->GetOperatorThreadPool(); concurrency::ThreadPool::TryParallelFor( - tp, batch_size, static_cast(cur_len * ngram_size_), - [&lambda](ptrdiff_t first, ptrdiff_t last) { - for (auto b = static_cast(first), end = static_cast(last); b < end; ++b) { - lambda(b); - } - } - ); + tp, batch_size, static_cast(cur_len * ngram_size_), + [&lambda](ptrdiff_t first, ptrdiff_t last) { + for (auto b = static_cast(first), end = static_cast(last); b < end; ++b) { + lambda(b); + } + }); return Status::OK(); } + private: int64_t ngram_size_; }; diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h index ce272b4886057..1f1b63485c365 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h @@ -51,7 +51,6 @@ struct IGreedySearchState { gsl::span sequence_lengths; // shape (batch_size) gsl::span next_positions; // shape (batch_size, num_beams). Next position value for position_ids. 
gsl::span eos_meet; // shape (batch_size) - gsl::span next_token_scores; // shape (batch_size, vocab_size) gsl::span next_tokens; // shape (batch_size) }; diff --git a/onnxruntime/contrib_ops/cpu/transformers/greedy_search.cc b/onnxruntime/contrib_ops/cpu/transformers/greedy_search.cc index 440d2eed75334..088c59fac2722 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/greedy_search.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/greedy_search.cc @@ -127,43 +127,43 @@ Status GreedySearch::Compute(OpKernelContext* ctx) const { // make a copy since we will update the parameters based on inputs later GreedySearchParameters parameters = parameters_; -if (parameters_.model_type == 0) { // GPT-2 + if (parameters_.model_type == 0) { // GPT-2 // Subgraph has constraint that the output is either float or float16 if (!gpt_subgraph_->IsOutputFloat16()) { GreedySearchGpt impl{ - *ctx_internal, - *decoder_session_state, - *gpt_subgraph_, - thread_pool, - cuda_stream_, - dumper_, - parameters, - GenerationCpuDeviceHelper::CreateGptInputs, - add_to_feeds_func_ ? add_to_feeds_func_ : GenerationCpuDeviceHelper::AddToFeeds, - topk_func_ ? topk_func_ : GenerationCpuDeviceHelper::TopK, - process_logits_func_ ? process_logits_func_ : GenerationCpuDeviceHelper::GreedySearchProcessLogits, - init_greedy_state_func_ ? init_greedy_state_func_ : GenerationCpuDeviceHelper::InitGreedyState, - device_copy_func_ ? device_copy_func_ : GenerationCpuDeviceHelper::DeviceCopy, - update_gpt_feeds_func_ ? update_gpt_feeds_func_ : GenerationCpuDeviceHelper::UpdateGptFeeds}; + *ctx_internal, + *decoder_session_state, + *gpt_subgraph_, + thread_pool, + cuda_stream_, + dumper_, + parameters, + GenerationCpuDeviceHelper::CreateGptInputs, + add_to_feeds_func_ ? add_to_feeds_func_ : GenerationCpuDeviceHelper::AddToFeeds, + topk_func_ ? topk_func_ : GenerationCpuDeviceHelper::TopK, + process_logits_func_ ? process_logits_func_ : GenerationCpuDeviceHelper::GreedySearchProcessLogits, + init_greedy_state_func_ ? init_greedy_state_func_ : GenerationCpuDeviceHelper::InitGreedyState, + device_copy_func_ ? device_copy_func_ : GenerationCpuDeviceHelper::DeviceCopy, + update_gpt_feeds_func_ ? update_gpt_feeds_func_ : GenerationCpuDeviceHelper::UpdateGptFeeds}; ORT_RETURN_IF_ERROR(impl.Initialize()); return impl.Execute(*decoder_feeds_fetches_manager_); } else { GreedySearchGpt impl{ - *ctx_internal, - *decoder_session_state, - *gpt_subgraph_, - thread_pool, - cuda_stream_, - dumper_, - parameters, - GenerationCpuDeviceHelper::CreateGptInputs, - add_to_feeds_func_ ? add_to_feeds_func_ : GenerationCpuDeviceHelper::AddToFeeds, - topk_func_ ? topk_func_ : GenerationCpuDeviceHelper::TopK, - process_logits_fp16_func_, - init_greedy_state_fp16_func_, - device_copy_func_, - update_gpt_feeds_fp16_func_}; + *ctx_internal, + *decoder_session_state, + *gpt_subgraph_, + thread_pool, + cuda_stream_, + dumper_, + parameters, + GenerationCpuDeviceHelper::CreateGptInputs, + add_to_feeds_func_ ? add_to_feeds_func_ : GenerationCpuDeviceHelper::AddToFeeds, + topk_func_ ? 
topk_func_ : GenerationCpuDeviceHelper::TopK, + process_logits_fp16_func_, + init_greedy_state_fp16_func_, + device_copy_func_, + update_gpt_feeds_fp16_func_}; ORT_RETURN_IF_ERROR(impl.Initialize()); return impl.Execute(*decoder_feeds_fetches_manager_); diff --git a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h index 3c17c190435c1..a98df8c7e2210 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h +++ b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h @@ -16,19 +16,19 @@ template class GreedySearchGpt : public GreedySearchBase { public: GreedySearchGpt(OpKernelContextInternal& context, - const SessionState& decoder_session_state, - GptSubgraph& gpt_subgraph, - concurrency::ThreadPool* thread_pool, - void* cuda_stream, - IConsoleDumper* cuda_dumper, - GreedySearchParameters& params, - const GenerationDeviceHelper::CreateGptInputsFunc& create_inputs_func, - const GenerationDeviceHelper::AddToFeedsFunc& add_to_feeds_func, - const GenerationDeviceHelper::TopkFunc& topk_func, - const GenerationDeviceHelper::GreedySearchProcessLogitsFunc& process_logits_func, - const GenerationDeviceHelper::InitGreedyStateFunc& init_greedy_state_func, - const GenerationDeviceHelper::DeviceCopyFunc& device_copy_func, - const GenerationDeviceHelper::UpdateGptFeedsFunc& update_feeds_func) + const SessionState& decoder_session_state, + GptSubgraph& gpt_subgraph, + concurrency::ThreadPool* thread_pool, + void* cuda_stream, + IConsoleDumper* cuda_dumper, + GreedySearchParameters& params, + const GenerationDeviceHelper::CreateGptInputsFunc& create_inputs_func, + const GenerationDeviceHelper::AddToFeedsFunc& add_to_feeds_func, + const GenerationDeviceHelper::TopkFunc& topk_func, + const GenerationDeviceHelper::GreedySearchProcessLogitsFunc& process_logits_func, + const GenerationDeviceHelper::InitGreedyStateFunc& init_greedy_state_func, + const GenerationDeviceHelper::DeviceCopyFunc& device_copy_func, + const GenerationDeviceHelper::UpdateGptFeedsFunc& update_feeds_func) : GreedySearchBase(context, decoder_session_state, thread_pool, @@ -225,7 +225,7 @@ Status GreedySearchGpt::Execute(const FeedsFetchesManager& feeds_fetches_mana // Copy the sequences to output gsl::span output = output_sequences->MutableDataAsSpan(); for (int batch_id = 0; batch_id < parameters->batch_size; ++batch_id) { - auto batch_output = output.subspan(batch_id * parameters->max_length, parameters->max_length); + auto batch_output = output.subspan(batch_id * parameters->max_length, parameters->max_length); gsl::span sequence_source = greedy_state.sequences.GetSequence(batch_id); gsl::copy(sequence_source, batch_output); } diff --git a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_parameters.cc b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_parameters.cc index 5496511de561c..fe42335201715 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_parameters.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_parameters.cc @@ -37,8 +37,7 @@ void GreedySearchParameters::ParseFromInputs(OpKernelContext* context) { num_beams = static_cast(1); auto* repetition_penalty_tensor = context->Input(3); - repetition_penalty = repetition_penalty_tensor ? - static_cast(*repetition_penalty_tensor->Data()) : 1.0f; + repetition_penalty = repetition_penalty_tensor ? 
static_cast(*repetition_penalty_tensor->Data()) : 1.0f; ORT_ENFORCE(repetition_penalty > 0.0f, "repetition_penalty shall be greater than 0, got ", repetition_penalty); } diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h index 18edfc3740b83..cdda69d04db9a 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h @@ -17,8 +17,8 @@ class T5DecoderSubgraph : public Subgraph { const std::string& attribute_name, const GraphViewer& subgraph_in) : Subgraph(node_in, attribute_name, subgraph_in), has_hidden_state_(false) { - first_present_output_index_ = 1; - } + first_present_output_index_ = 1; + } // Create inputs for first inference of decoder subgraph. Status CreateInitialFeeds( diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.h b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.h index 0651c5d946138..7d2e44f952106 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.h +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.h @@ -16,8 +16,8 @@ class T5EncoderSubgraph : public Subgraph { const onnxruntime::Node& node_in, const std::string& attribute_name, const GraphViewer& subgraph_in) : Subgraph(node_in, attribute_name, subgraph_in) { - first_present_output_index_ = 2; - } + first_present_output_index_ = 2; + } // Create inputs for first inference of subgraph. Status CreateInitialFeeds( diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index 580e6b22e34c9..2be250fc0d328 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -41,7 +41,13 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { const Tensor* extra_add_qk = context->Input(5); auto& device_prop = GetDeviceProp(); - ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(), weights->Shape(), bias->Shape(), mask_index, past, extra_add_qk, device_prop.maxThreadsPerBlock)); + ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(), + weights->Shape(), + bias->Shape(), + mask_index, + past, + extra_add_qk, + device_prop.maxThreadsPerBlock)); // input shape (batch_size, sequence_length, input_hidden_size) const auto& shape = input->Shape(); @@ -92,26 +98,32 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { reinterpret_cast(input->Data()), k, &one, reinterpret_cast(gemm_buffer.get()), n, device_prop)); - size_t workSpaceSize = GetAttentionWorkspaceSize(element_size, batch_size, num_heads_, head_size, sequence_length, past_sequence_length); - auto temp_buffer = GetScratchBuffer(workSpaceSize); + size_t workSpaceSize = GetAttentionWorkspaceSize(element_size, + batch_size, + num_heads_, + head_size, + sequence_length, + past_sequence_length); + + auto work_space = GetScratchBuffer(workSpaceSize); if (!LaunchAttentionKernel( device_prop, Stream(), - reinterpret_cast(gemm_buffer.get()), - nullptr == mask_index ? nullptr : mask_index->Data(), - nullptr == mask_index ? gsl::span() : mask_index->Shape().GetDims(), - output->MutableData(), + cublas, + element_size, batch_size, sequence_length, num_heads_, head_size, - temp_buffer.get(), - cublas, - element_size, - is_unidirectional_, past_sequence_length, + is_unidirectional_, + reinterpret_cast(gemm_buffer.get()), + nullptr == mask_index ? nullptr : mask_index->Data(), + nullptr == mask_index ? 
gsl::span() : mask_index->Shape().GetDims(), nullptr == past ? nullptr : past->Data(), nullptr == extra_add_qk ? nullptr : extra_add_qk->Data(), + work_space.get(), + output->MutableData(), nullptr == present ? nullptr : present->MutableData())) { // Get last error to reset it to cudaSuccess. CUDA_CALL(cudaGetLastError()); diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_concat.cu b/onnxruntime/contrib_ops/cuda/bert/attention_concat.cu index 3f9ed50e8995a..06d20c91274e0 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_concat.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_concat.cu @@ -92,7 +92,6 @@ __global__ void ConcatTensorToTensorLarge(const int tensor_add_sequence_length, } } - bool LaunchConcatTensorToTensor(cudaStream_t stream, const int all_sequence_length, const int sequence_length, @@ -109,10 +108,17 @@ bool LaunchConcatTensorToTensor(cudaStream_t stream, const int H = head_size / 2; if (H * num_heads <= max_threads_per_block) { const dim3 block(H, num_heads, 1); - ConcatTensorToTensor<<>>(sequence_length, reinterpret_cast(tensor_in), reinterpret_cast(tensor_add), reinterpret_cast(tensor_out)); + ConcatTensorToTensor<<>>(sequence_length, + reinterpret_cast(tensor_in), + reinterpret_cast(tensor_add), + reinterpret_cast(tensor_out)); } else { const dim3 block(max_threads_per_block / num_heads, num_heads, 1); - ConcatTensorToTensorLarge<<>>(sequence_length, H, reinterpret_cast(tensor_in), reinterpret_cast(tensor_add), reinterpret_cast(tensor_out)); + ConcatTensorToTensorLarge<<>>(sequence_length, + H, + reinterpret_cast(tensor_in), + reinterpret_cast(tensor_add), + reinterpret_cast(tensor_out)); } } else { if (head_size * num_heads <= max_threads_per_block) { @@ -120,9 +126,12 @@ bool LaunchConcatTensorToTensor(cudaStream_t stream, ConcatTensorToTensor<<>>(sequence_length, tensor_in, tensor_add, tensor_out); } else { const dim3 block(max_threads_per_block / num_heads, num_heads, 1); - ConcatTensorToTensorLarge<<>>(sequence_length, head_size, tensor_in, tensor_add, tensor_out); + ConcatTensorToTensorLarge<<>>(sequence_length, + head_size, + tensor_in, + tensor_add, + tensor_out); } - } return CUDA_CALL(cudaPeekAtLastError()); } @@ -143,19 +152,33 @@ bool LaunchConcatTensorToTensor(cudaStream_t stream, const int H = head_size / 4; if (H * num_heads <= max_threads_per_block) { const dim3 block(H, num_heads, 1); - ConcatTensorToTensor<<>>(sequence_length, reinterpret_cast(tensor_in), reinterpret_cast(tensor_add), reinterpret_cast(tensor_out)); + ConcatTensorToTensor<<>>(sequence_length, + reinterpret_cast(tensor_in), + reinterpret_cast(tensor_add), + reinterpret_cast(tensor_out)); } else { const dim3 block(max_threads_per_block / num_heads, num_heads, 1); - ConcatTensorToTensorLarge<<>>(sequence_length, H, reinterpret_cast(tensor_in), reinterpret_cast(tensor_add), reinterpret_cast(tensor_out)); + ConcatTensorToTensorLarge<<>>(sequence_length, + H, + reinterpret_cast(tensor_in), + reinterpret_cast(tensor_add), + reinterpret_cast(tensor_out)); } } else if (0 == (head_size & 1)) { const int H = head_size / 2; if (H * num_heads <= max_threads_per_block) { const dim3 block(H, num_heads, 1); - ConcatTensorToTensor<<>>(sequence_length, reinterpret_cast(tensor_in), reinterpret_cast(tensor_add), reinterpret_cast(tensor_out)); + ConcatTensorToTensor<<>>(sequence_length, + reinterpret_cast(tensor_in), + reinterpret_cast(tensor_add), + reinterpret_cast(tensor_out)); } else { const dim3 block(max_threads_per_block / num_heads, num_heads, 1); - 
ConcatTensorToTensorLarge<<>>(sequence_length, H, reinterpret_cast(tensor_in), reinterpret_cast(tensor_add), reinterpret_cast(tensor_out)); + ConcatTensorToTensorLarge<<>>(sequence_length, + H, + reinterpret_cast(tensor_in), + reinterpret_cast(tensor_add), + reinterpret_cast(tensor_out)); } } else { // this should be an "odd" case. probably not worth catching it in the half2 kernel. if (head_size * num_heads <= max_threads_per_block) { @@ -163,7 +186,11 @@ bool LaunchConcatTensorToTensor(cudaStream_t stream, ConcatTensorToTensor<<>>(sequence_length, tensor_in, tensor_add, tensor_out); } else { const dim3 block(max_threads_per_block / num_heads, num_heads, 1); - ConcatTensorToTensorLarge<<>>(sequence_length, head_size, tensor_in, tensor_add, tensor_out); + ConcatTensorToTensorLarge<<>>(sequence_length, + head_size, + tensor_in, + tensor_add, + tensor_out); } } return CUDA_CALL(cudaPeekAtLastError()); @@ -180,17 +207,17 @@ bool LaunchConcatPastToPresent(cudaStream_t stream, const float* k_v, float* present) { return LaunchConcatTensorToTensor( - stream, - all_sequence_length, - sequence_length, - batch_size, - head_size, - num_heads, - max_threads_per_block, - 2, - past, - k_v, - present); + stream, + all_sequence_length, + sequence_length, + batch_size, + head_size, + num_heads, + max_threads_per_block, + 2, + past, + k_v, + present); } bool LaunchConcatPastToPresent(cudaStream_t stream, @@ -204,17 +231,17 @@ bool LaunchConcatPastToPresent(cudaStream_t stream, const half* k_v, half* present) { return LaunchConcatTensorToTensor( - stream, - all_sequence_length, - sequence_length, - batch_size, - head_size, - num_heads, - max_threads_per_block, - 2, - past, - k_v, - present); + stream, + all_sequence_length, + sequence_length, + batch_size, + head_size, + num_heads, + max_threads_per_block, + 2, + past, + k_v, + present); } } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index 6343a6bad1b81..ff8f4e45ff305 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -46,11 +46,11 @@ static size_t AlignTo(size_t a, size_t b) { } size_t GetAttentionScratchSize( - size_t element_size, - size_t batch_size, - size_t num_heads, - size_t sequence_length, - size_t all_sequence_length) { + size_t element_size, + size_t batch_size, + size_t num_heads, + size_t sequence_length, + size_t all_sequence_length) { const size_t bytes = element_size * batch_size * num_heads * sequence_length * all_sequence_length; const size_t alignment = 256; @@ -66,18 +66,34 @@ size_t GetAttentionWorkspaceSize( size_t sequence_length, size_t past_sequence_length) { size_t qkv_size = element_size * 3 * batch_size * sequence_length * num_heads * head_size; - return qkv_size + 2 * GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length, past_sequence_length + sequence_length); + return qkv_size + 2 * GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length, + past_sequence_length + sequence_length); } template bool QkvToContext( - const cudaDeviceProp& prop, cublasHandle_t& cublas, cudaStream_t stream, - const int batch_size, const int sequence_length, const int num_heads, const int head_size, const size_t element_size, - const T* input, T* output, T* workspace, - const int* mask_index, gsl::span mask_index_dims, - bool is_unidirectional, int past_sequence_length, const T* past, const T* extra_add_qk, T* present, bool 
use_persistent_softmax) { + const cudaDeviceProp& prop, + cublasHandle_t& cublas, + cudaStream_t stream, + const int batch_size, + const int sequence_length, + const int num_heads, + const int head_size, + const size_t element_size, + const T* input, + T* output, + T* workspace, + const int* mask_index, + gsl::span mask_index_dims, + bool is_unidirectional, + int past_sequence_length, + const T* past, + const T* extra_add_qk, + T* present, + bool use_persistent_softmax) { const int all_sequence_length = past_sequence_length + sequence_length; - const size_t bytes = GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length, all_sequence_length); + const size_t bytes = GetAttentionScratchSize(element_size, batch_size, num_heads, + sequence_length, all_sequence_length); T* scratch1 = workspace; T* scratch2 = scratch1 + (bytes / element_size); T* scratch3 = scratch2 + (bytes / element_size); @@ -85,7 +101,8 @@ bool QkvToContext( const int max_threads_per_block = prop.maxThreadsPerBlock; // input should be BxSx3xNxH => scratch3: 3xBxNxSxH - if (!LaunchTransQkv(stream, 3, sequence_length, batch_size, head_size, num_heads, max_threads_per_block, false, input, scratch3)) { + if (!LaunchTransQkv(stream, 3, sequence_length, batch_size, head_size, num_heads, + max_threads_per_block, false, input, scratch3)) { return false; } @@ -105,7 +122,8 @@ bool QkvToContext( // past_v (BxNxS'xH) + v (BxNxSxH) => present_v (BxNxS*xH) const int present_size_per_batch = all_sequence_length * head_size; if (nullptr != present) { - if (!LaunchConcatPastToPresent(stream, all_sequence_length, sequence_length, batch_size, head_size, num_heads, max_threads_per_block, past, k, present)) { + if (!LaunchConcatPastToPresent(stream, all_sequence_length, sequence_length, batch_size, head_size, num_heads, + max_threads_per_block, past, k, present)) { return false; } @@ -128,19 +146,23 @@ bool QkvToContext( float alpha = use_raw_attention_mask ? one : rsqrt_head_size; if (!CUBLAS_CALL(cublasGemmStridedBatchedHelper( - cublas, CUBLAS_OP_T, CUBLAS_OP_N, all_sequence_length, sequence_length, head_size, &alpha, k, head_size, present_size_per_batch, - q, head_size, size_per_batch, &zero, scratch1, all_sequence_length, temp_matrix_size, batches, prop))) { + cublas, CUBLAS_OP_T, CUBLAS_OP_N, + all_sequence_length, sequence_length, head_size, + &alpha, k, head_size, present_size_per_batch, + q, head_size, size_per_batch, + &zero, scratch1, all_sequence_length, temp_matrix_size, batches, prop))) { return false; } // apply softmax and store result P to scratch2: BxNxSxS* if (use_raw_attention_mask) { // 2d, 3d or 4d attention mask const int mask_dimension = static_cast(mask_index_dims.size()); - const int64_t max_sequence_length = mask_dimension == 4 ? mask_index_dims.at(3) : 0; + const int max_sequence_length = mask_dimension == 4 ? static_cast(mask_index_dims.at(3)) : 0; - T* persistent_softmax_workspace = scratch1; // replace Q*K' in place with masked score if persistent softmax is selected. - if (!ComputeSoftmaxWithRawMask(stream, all_sequence_length, sequence_length, batch_size, num_heads, mask_index, nullptr, extra_add_qk, scratch1, scratch2, - is_unidirectional, rsqrt_head_size, mask_dimension, static_cast(max_sequence_length), + T* persistent_softmax_workspace = scratch1; // replace Q*K' in place with masked score for persistent softmax. 
+ if (!ComputeSoftmaxWithRawMask(stream, all_sequence_length, sequence_length, batch_size, num_heads, + mask_index, nullptr, extra_add_qk, scratch1, scratch2, + is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, use_persistent_softmax, persistent_softmax_workspace)) { return false; } @@ -148,94 +170,109 @@ bool QkvToContext( ORT_ENFORCE(mask_index_dims.size() == 1); // mask_index has 1D shape: either (batch_size) or (2*batch_size). Only the later one has start postions. const int* mask_start = (mask_index_dims.at(0) > batch_size) ? mask_index + batch_size : nullptr; - if (!ComputeSoftmaxWithMask1D(stream, all_sequence_length, sequence_length, batch_size, num_heads, mask_index, mask_start, extra_add_qk, scratch1, scratch2, is_unidirectional)) { + if (!ComputeSoftmaxWithMask1D(stream, all_sequence_length, sequence_length, batch_size, num_heads, + mask_index, mask_start, extra_add_qk, scratch1, scratch2, is_unidirectional)) { return false; } } else { // no mask - if (!ComputeSoftmax(stream, all_sequence_length, sequence_length, batch_size, num_heads, extra_add_qk, scratch1, scratch2, is_unidirectional)) { + if (!ComputeSoftmax(stream, all_sequence_length, sequence_length, batch_size, num_heads, extra_add_qk, + scratch1, scratch2, is_unidirectional)) { return false; } } // compute P*V (as V*P), and store in scratch3: BxNxSxH if (!CUBLAS_CALL(cublasGemmStridedBatchedHelper( - cublas, CUBLAS_OP_N, CUBLAS_OP_N, head_size, sequence_length, all_sequence_length, &one, v, head_size, present_size_per_batch, - scratch2, all_sequence_length, temp_matrix_size, &zero, scratch3, head_size, size_per_batch, batches, prop))) { + cublas, CUBLAS_OP_N, CUBLAS_OP_N, + head_size, sequence_length, all_sequence_length, + &one, v, head_size, present_size_per_batch, + scratch2, all_sequence_length, temp_matrix_size, + &zero, scratch3, head_size, size_per_batch, batches, prop))) { return false; } // scratch3 is BxNxSxH, transpose to output BxSxNxH - return LaunchTransCtx(stream, sequence_length, batch_size, head_size, num_heads, max_threads_per_block, false, scratch3, output); + return LaunchTransCtx(stream, sequence_length, batch_size, head_size, num_heads, + max_threads_per_block, false, scratch3, output); } bool LaunchAttentionKernel( const cudaDeviceProp& prop, cudaStream_t stream, - const void* input, - const int* mask_index, - gsl::span mask_index_dims, - void* output, - const int batch_size, - const int sequence_length, - const int num_heads, - const int head_size, - void* workspace, cublasHandle_t& cublas, const size_t element_size, - bool is_unidirectional, + int batch_size, + int sequence_length, + int num_heads, + int head_size, int past_sequence_length, + bool is_unidirectional, + const void* input, + const int* mask_index, + gsl::span mask_index_dims, const void* past, const void* extra_add_qk, + void* workspace, + void* output, void* present) { - - // For testing, environment variable ORT_TRANSFORMER_OPTIONS=1 could enable persistent softmax + // For testing, environment variable ORT_TRANSFORMER_OPTIONS=1 could enable persistent softmax used in Torch. 
const TransformerOptions* options = TransformerOptions::GetInstance(); bool use_persistent_softmax = options->IsPrecisionMode() && !options->DisablePersistentSoftmax(); if (element_size == 2) { - return QkvToContext(prop, cublas, stream, - batch_size, sequence_length, num_heads, head_size, element_size, - reinterpret_cast(input), reinterpret_cast(output), reinterpret_cast(workspace), - mask_index, mask_index_dims, is_unidirectional, - past_sequence_length, reinterpret_cast(past), reinterpret_cast(extra_add_qk), - reinterpret_cast(present), use_persistent_softmax); + return QkvToContext(prop, cublas, stream, batch_size, sequence_length, num_heads, head_size, element_size, + reinterpret_cast(input), + reinterpret_cast(output), + reinterpret_cast(workspace), + mask_index, + mask_index_dims, + is_unidirectional, + past_sequence_length, + reinterpret_cast(past), + reinterpret_cast(extra_add_qk), + reinterpret_cast(present), + use_persistent_softmax); } else { - return QkvToContext(prop, cublas, stream, - batch_size, sequence_length, num_heads, head_size, element_size, - reinterpret_cast(input), reinterpret_cast(output), reinterpret_cast(workspace), - mask_index, mask_index_dims, is_unidirectional, - past_sequence_length, reinterpret_cast(past), reinterpret_cast(extra_add_qk), - reinterpret_cast(present), use_persistent_softmax); + return QkvToContext(prop, cublas, stream, batch_size, sequence_length, num_heads, head_size, element_size, + reinterpret_cast(input), + reinterpret_cast(output), + reinterpret_cast(workspace), + mask_index, + mask_index_dims, + is_unidirectional, + past_sequence_length, + reinterpret_cast(past), + reinterpret_cast(extra_add_qk), + reinterpret_cast(present), + use_persistent_softmax); } } - template bool DecoderQkvToContext( - const cudaDeviceProp& prop, - cudaStream_t stream, - cublasHandle_t& cublas, - const size_t element_size, - const int batch_size, - const int sequence_length, - const int kv_sequence_length, - const int num_heads, - const int head_size, - const bool static_kv, - const bool use_past, - const bool has_layer_state, - const bool has_key_padding_mask, - const T* gemm_query_buffer, - const T* gemm_kv_buffer, - const bool* key_padding_mask, - const T* key_cache, - const T* value_cache, - T* qkv_buffer, - T* workspace_buffer, - T* output, - T* new_key_cache, - T* new_value_cache) -{ + const cudaDeviceProp& prop, + cudaStream_t stream, + cublasHandle_t& cublas, + const size_t element_size, + const int batch_size, + const int sequence_length, + const int kv_sequence_length, + const int num_heads, + const int head_size, + const bool static_kv, + const bool use_past, + const bool has_layer_state, + const bool has_key_padding_mask, + const T* gemm_query_buffer, + const T* gemm_kv_buffer, + const bool* key_padding_mask, + const T* key_cache, + const T* value_cache, + T* qkv_buffer, + T* workspace_buffer, + T* output, + T* new_key_cache, + T* new_value_cache) { const int max_threads_per_block = prop.maxThreadsPerBlock; const int BN = batch_size * num_heads; const int BHN = BN * head_size; @@ -246,8 +283,9 @@ bool DecoderQkvToContext( T* temp_qkv_buffer = workspace_buffer; const T* q = qkv_buffer; - //transpose q and copy them to qkv_buffer - if (!LaunchTransQkv(stream, 1, sequence_length, batch_size, head_size, num_heads, max_threads_per_block, true, gemm_query_buffer, qkv_buffer)) { + // transpose q and copy them to qkv_buffer + if (!LaunchTransQkv(stream, 1, sequence_length, batch_size, head_size, num_heads, + max_threads_per_block, true, gemm_query_buffer, 
qkv_buffer)) { return false; } @@ -255,30 +293,41 @@ bool DecoderQkvToContext( const T* v = qkv_buffer + v_buffer_offset; if (!has_layer_state || !use_past) { if (!static_kv) { - //transpose kv and copy them to qkv_buffer - if (!LaunchTransQkv(stream, 2, sequence_length, batch_size, head_size, num_heads, max_threads_per_block, true, gemm_kv_buffer, qkv_buffer + k_buffer_offset)) { + // transpose kv and copy them to qkv_buffer + if (!LaunchTransQkv(stream, 2, sequence_length, batch_size, head_size, num_heads, + max_threads_per_block, true, gemm_kv_buffer, qkv_buffer + k_buffer_offset)) { return false; } } else { - //transpose kv and copy them to qkv_buffer - if (!LaunchTransQkv(stream, 2, kv_sequence_length, batch_size, head_size, num_heads, max_threads_per_block, true, gemm_kv_buffer, qkv_buffer + k_buffer_offset)) { + // transpose kv and copy them to qkv_buffer + if (!LaunchTransQkv(stream, 2, kv_sequence_length, batch_size, head_size, num_heads, + max_threads_per_block, true, gemm_kv_buffer, qkv_buffer + k_buffer_offset)) { return false; } } } else { if (!static_kv) { - //transpose kv and copy them to temp_buffer - if (!LaunchTransQkv(stream, 2, sequence_length, batch_size, head_size, num_heads, max_threads_per_block, true, gemm_kv_buffer, temp_qkv_buffer)) { + // transpose kv and copy them to temp_buffer + if (!LaunchTransQkv(stream, 2, sequence_length, batch_size, head_size, num_heads, + max_threads_per_block, true, gemm_kv_buffer, temp_qkv_buffer)) { return false; } // concat cache-k with k and copy to qkv_buffer - if (nullptr != key_cache && !LaunchConcatTensorToTensor(stream, kv_sequence_length, sequence_length, batch_size, head_size, num_heads, - max_threads_per_block, 1, key_cache, temp_qkv_buffer, qkv_buffer + k_buffer_offset)) { + if (nullptr != key_cache && !LaunchConcatTensorToTensor(stream, kv_sequence_length, + sequence_length, batch_size, head_size, num_heads, + max_threads_per_block, 1, + key_cache, + temp_qkv_buffer, + qkv_buffer + k_buffer_offset)) { return false; } // concat cache-v with v and copy to qkv_buffer - if (nullptr != value_cache && !LaunchConcatTensorToTensor(stream, kv_sequence_length, sequence_length, batch_size, head_size, num_heads, - max_threads_per_block, 1, value_cache, temp_qkv_buffer + k_buffer_offset, qkv_buffer + v_buffer_offset)) { + if (nullptr != value_cache && !LaunchConcatTensorToTensor(stream, kv_sequence_length, + sequence_length, batch_size, head_size, num_heads, + max_threads_per_block, 1, + value_cache, + temp_qkv_buffer + k_buffer_offset, + qkv_buffer + v_buffer_offset)) { return false; } } @@ -286,11 +335,15 @@ bool DecoderQkvToContext( if (has_layer_state) { if (use_past && static_kv) { - CHECK_CUDA(cudaMemcpyAsync(new_key_cache, key_cache, kv_sequence_length * BHN * sizeof(T), cudaMemcpyDeviceToDevice, stream)); - CHECK_CUDA(cudaMemcpyAsync(new_value_cache, value_cache, kv_sequence_length * BHN * sizeof(T), cudaMemcpyDeviceToDevice, stream)); + CHECK_CUDA(cudaMemcpyAsync(new_key_cache, key_cache, kv_sequence_length * BHN * sizeof(T), + cudaMemcpyDeviceToDevice, stream)); + CHECK_CUDA(cudaMemcpyAsync(new_value_cache, value_cache, kv_sequence_length * BHN * sizeof(T), + cudaMemcpyDeviceToDevice, stream)); } else { - CHECK_CUDA(cudaMemcpyAsync(new_key_cache, k, kv_sequence_length * BHN * sizeof(T), cudaMemcpyDeviceToDevice, stream)); - CHECK_CUDA(cudaMemcpyAsync(new_value_cache, v, kv_sequence_length * BHN * sizeof(T), cudaMemcpyDeviceToDevice, stream)); + CHECK_CUDA(cudaMemcpyAsync(new_key_cache, k, kv_sequence_length * BHN * 
sizeof(T), + cudaMemcpyDeviceToDevice, stream)); + CHECK_CUDA(cudaMemcpyAsync(new_value_cache, v, kv_sequence_length * BHN * sizeof(T), + cudaMemcpyDeviceToDevice, stream)); } } @@ -313,25 +366,38 @@ bool DecoderQkvToContext( const int strideB = sequence_length * head_size; if (use_past && static_kv) { if (!CUBLAS_CALL(cublasGemmStridedBatchedHelper( - cublas, CUBLAS_OP_T, CUBLAS_OP_N, kv_sequence_length, sequence_length, head_size, &alpha, key_cache, head_size, strideA, - q, head_size, strideB, &zero, scratch1, kv_sequence_length, temp_matrix_size, BN, prop))) { + cublas, CUBLAS_OP_T, CUBLAS_OP_N, + kv_sequence_length, sequence_length, head_size, + &alpha, key_cache, head_size, strideA, + q, head_size, strideB, + &zero, scratch1, kv_sequence_length, temp_matrix_size, BN, prop))) { return false; } } else { if (!CUBLAS_CALL(cublasGemmStridedBatchedHelper( - cublas, CUBLAS_OP_T, CUBLAS_OP_N, kv_sequence_length, sequence_length, head_size, &alpha, k, head_size, strideA, - q, head_size, strideB, &zero, scratch1, kv_sequence_length, temp_matrix_size, BN, prop))) { + cublas, CUBLAS_OP_T, CUBLAS_OP_N, + kv_sequence_length, sequence_length, head_size, + &alpha, k, head_size, strideA, + q, head_size, strideB, + &zero, scratch1, kv_sequence_length, temp_matrix_size, BN, prop))) { return false; } } + constexpr bool is_unidirectional = false; + const T* add_before_softmax = nullptr; if (has_key_padding_mask) { - if (!ComputeSoftmaxWithRawMask(stream, kv_sequence_length, sequence_length, batch_size, num_heads, nullptr, key_padding_mask, nullptr, scratch1, scratch2, - false, 1, 2, static_cast(0), false, nullptr)) { + constexpr int mask_dimension = 2; + constexpr int max_sequence_length = 0; + if (!ComputeSoftmaxWithRawMask(stream, kv_sequence_length, sequence_length, batch_size, num_heads, + nullptr, key_padding_mask, add_before_softmax, scratch1, scratch2, + is_unidirectional, 1.0f, mask_dimension, max_sequence_length, + false, nullptr)) { return false; } } else { - if (!ComputeSoftmax(stream, kv_sequence_length, sequence_length, batch_size, num_heads, nullptr, scratch1, scratch2, false)) { + if (!ComputeSoftmax(stream, kv_sequence_length, sequence_length, batch_size, num_heads, + add_before_softmax, scratch1, scratch2, is_unidirectional)) { return false; } } @@ -339,101 +405,103 @@ bool DecoderQkvToContext( // compute P*V (as V*P), and store in scratch3: BxNxSxH if (use_past && static_kv) { if (!CUBLAS_CALL(cublasGemmStridedBatchedHelper( - cublas, CUBLAS_OP_N, CUBLAS_OP_N, head_size, sequence_length, kv_sequence_length, &one, value_cache, head_size, strideA, - scratch2, kv_sequence_length, temp_matrix_size, &zero, scratch3, head_size, strideB, BN, prop))) { + cublas, CUBLAS_OP_N, CUBLAS_OP_N, + head_size, sequence_length, kv_sequence_length, + &one, value_cache, head_size, strideA, + scratch2, kv_sequence_length, temp_matrix_size, + &zero, scratch3, head_size, strideB, BN, prop))) { return false; } } else { if (!CUBLAS_CALL(cublasGemmStridedBatchedHelper( - cublas, CUBLAS_OP_N, CUBLAS_OP_N, head_size, sequence_length, kv_sequence_length, &one, v, head_size, strideA, - scratch2, kv_sequence_length, temp_matrix_size, &zero, scratch3, head_size, strideB, BN, prop))) { + cublas, CUBLAS_OP_N, CUBLAS_OP_N, + head_size, sequence_length, kv_sequence_length, + &one, v, head_size, strideA, + scratch2, kv_sequence_length, temp_matrix_size, + &zero, scratch3, head_size, strideB, BN, prop))) { return false; } } // scratch3 is BxNxSxH, transpose to output SxBxNxH - return LaunchTransCtx(stream, sequence_length, 
batch_size, head_size, num_heads, max_threads_per_block, true, scratch3, output); + return LaunchTransCtx(stream, sequence_length, batch_size, head_size, num_heads, + max_threads_per_block, true, scratch3, output); } - bool LaunchDecoderAttentionKernel( - const cudaDeviceProp& prop, - cudaStream_t stream, - cublasHandle_t& cublas, - const size_t element_size, - const int batch_size, - const int sequence_length, - const int kv_sequence_length, - const int num_heads, - const int head_size, - const bool static_kv, - const bool use_past, - const bool has_layer_state, - const bool has_key_padding_mask, - const void* gemm_query_buffer, - const void* gemm_kv_buffer, - const bool* key_padding_mask, - const void* key_cache, - const void* value_cache, - void* qkv_buffer, - void* workspace_buffer, - void* output, - void* new_key_cache, - void* new_value_cache) -{ - + const cudaDeviceProp& prop, + cudaStream_t stream, + cublasHandle_t& cublas, + const size_t element_size, + const int batch_size, + const int sequence_length, + const int kv_sequence_length, + const int num_heads, + const int head_size, + const bool static_kv, + const bool use_past, + const bool has_layer_state, + const bool has_key_padding_mask, + const void* gemm_query_buffer, + const void* gemm_kv_buffer, + const bool* key_padding_mask, + const void* key_cache, + const void* value_cache, + void* qkv_buffer, + void* workspace_buffer, + void* output, + void* new_key_cache, + void* new_value_cache) { if (element_size == 2) { return DecoderQkvToContext( - prop, - stream, - cublas, - element_size, - batch_size, - sequence_length, - kv_sequence_length, - num_heads, - head_size, - static_kv, - use_past, - has_layer_state, - has_key_padding_mask, - reinterpret_cast(gemm_query_buffer), - reinterpret_cast(gemm_kv_buffer), - key_padding_mask, - reinterpret_cast(key_cache), - reinterpret_cast(value_cache), - reinterpret_cast(qkv_buffer), - reinterpret_cast(workspace_buffer), - reinterpret_cast(output), - reinterpret_cast(new_key_cache), - reinterpret_cast(new_value_cache) - ); + prop, + stream, + cublas, + element_size, + batch_size, + sequence_length, + kv_sequence_length, + num_heads, + head_size, + static_kv, + use_past, + has_layer_state, + has_key_padding_mask, + reinterpret_cast(gemm_query_buffer), + reinterpret_cast(gemm_kv_buffer), + key_padding_mask, + reinterpret_cast(key_cache), + reinterpret_cast(value_cache), + reinterpret_cast(qkv_buffer), + reinterpret_cast(workspace_buffer), + reinterpret_cast(output), + reinterpret_cast(new_key_cache), + reinterpret_cast(new_value_cache)); } else { return DecoderQkvToContext( - prop, - stream, - cublas, - element_size, - batch_size, - sequence_length, - kv_sequence_length, - num_heads, - head_size, - static_kv, - use_past, - has_layer_state, - has_key_padding_mask, - reinterpret_cast(gemm_query_buffer), - reinterpret_cast(gemm_kv_buffer), - key_padding_mask, - reinterpret_cast(key_cache), - reinterpret_cast(value_cache), - reinterpret_cast(qkv_buffer), - reinterpret_cast(workspace_buffer), - reinterpret_cast(output), - reinterpret_cast(new_key_cache), - reinterpret_cast(new_value_cache) - ); + prop, + stream, + cublas, + element_size, + batch_size, + sequence_length, + kv_sequence_length, + num_heads, + head_size, + static_kv, + use_past, + has_layer_state, + has_key_padding_mask, + reinterpret_cast(gemm_query_buffer), + reinterpret_cast(gemm_kv_buffer), + key_padding_mask, + reinterpret_cast(key_cache), + reinterpret_cast(value_cache), + reinterpret_cast(qkv_buffer), + 
reinterpret_cast(workspace_buffer), + reinterpret_cast(output), + reinterpret_cast(new_key_cache), + reinterpret_cast(new_value_cache)); } } diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h index e515f323af962..09532670cef65 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h @@ -26,50 +26,50 @@ size_t GetAttentionWorkspaceSize( size_t past_sequence_length); bool LaunchAttentionKernel( - const cudaDeviceProp& prop, // Device Properties - cudaStream_t stream, // cuda stream - const void* input, // Input tensor - const int* mask_index, // Attention mask raw data or index (end position of each sequence, or end positions and start positions). NULL means no mask. - gsl::span mask_index_dims, // Mask index shape - void* output, // Output tensor - int batch_size, // Batch size (B) - int sequence_length, // Sequence length (S) - int num_heads, // Number of attention heads (N) - int head_size, // Hidden layer size per head (H) - void* workspace, // Temporary buffer - cublasHandle_t& cublas, // Cublas handle - const size_t element_size, // Element size of input tensor - bool is_unidirectional, // Whether there is unidirecitonal mask. - int past_sequence_length, // Sequence length in past state - const void* past, // Past state input - const void* extra_add_qk, // Additional Add - void* present // Present state output + const cudaDeviceProp& prop, // Device Properties + cudaStream_t stream, // cuda stream + cublasHandle_t& cublas, // Cublas handle + const size_t element_size, // Element size of input tensor + int batch_size, // Batch size (B) + int sequence_length, // Sequence length (S) + int num_heads, // Number of attention heads (N) + int head_size, // Hidden layer size per head (H) + int past_sequence_length, // Sequence length in past state + bool is_unidirectional, // Whether there is unidirecitonal mask. + const void* input, // Input tensor + const int* mask_index, // Attention mask raw data or index. NULL means no mask. 
+ gsl::span mask_index_dims, // Mask index shape + const void* past, // Past state input + const void* extra_add_qk, // Additional Add + void* workspace, // Temporary buffer + void* output, // Output tensor + void* present // Present state output ); bool LaunchDecoderAttentionKernel( - const cudaDeviceProp& prop, // Device Properties - cudaStream_t stream, // Cuda stream - cublasHandle_t& cublas, // Cublas handle - const size_t element_size, // Element size of input tensor - const int batch_size, // Batch size (B) - const int sequence_length, // Sequence length (S) - const int kv_sequence_length, // Key/Value/Cache sequence length - const int num_heads, // Number of attention heads (N) - const int head_size, // Hidden layer size per head (H) - const bool static_kv, // Whether cross attention or not - const bool use_past, // Whether use cache or not - const bool has_layer_state, // Whether output cache or not - const bool has_key_padding_mask, // Whether use key_padding_mask or not - const void* gemm_query_buffer, // Query buffer - const void* gemm_kv_buffer, // Key and value buffer - const bool* key_padding_mask, // Key padding mask - const void* key_cache, // Input key cache - const void* value_cache, // Input value cache - void* qkv_buffer, // Temporary buffer - void* workspace_buffer, // Temporary buffer - void* output, // Output tensor - void* new_key_cache, // New_key_cache tensor - void* new_value_cache // New_value_cache tensor + const cudaDeviceProp& prop, // Device Properties + cudaStream_t stream, // Cuda stream + cublasHandle_t& cublas, // Cublas handle + const size_t element_size, // Element size of input tensor + const int batch_size, // Batch size (B) + const int sequence_length, // Sequence length (S) + const int kv_sequence_length, // Key/Value/Cache sequence length + const int num_heads, // Number of attention heads (N) + const int head_size, // Hidden layer size per head (H) + const bool static_kv, // Whether cross attention or not + const bool use_past, // Whether use cache or not + const bool has_layer_state, // Whether output cache or not + const bool has_key_padding_mask, // Whether use key_padding_mask or not + const void* gemm_query_buffer, // Query buffer + const void* gemm_kv_buffer, // Key and value buffer + const bool* key_padding_mask, // Key padding mask + const void* key_cache, // Input key cache + const void* value_cache, // Input value cache + void* qkv_buffer, // Temporary buffer + void* workspace_buffer, // Temporary buffer + void* output, // Output tensor + void* new_key_cache, // New_key_cache tensor + void* new_value_cache // New_value_cache tensor ); bool LaunchTransCtx(cudaStream_t stream, diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_softmax.h b/onnxruntime/contrib_ops/cuda/bert/attention_softmax.h index d5f908b9220a6..671e26c885442 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_softmax.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_softmax.h @@ -50,6 +50,8 @@ __device__ inline void Softmax(const int all_sequence_length, float thread_data_max(-CUDART_INF_F); + const bool no_add = (add_before_softmax == nullptr); + // e^x is represented as infinity if x is large enough, like 100.f. // Infinity divided by Infinity is a NAN. Thus, softmax gets a NAN if one or more item are large enough. 
// a math transform as below is leveraged to get a stable softmax: @@ -58,7 +60,7 @@ __device__ inline void Softmax(const int all_sequence_length, for (int i = threadIdx.x; i < valid_end; i += TPB) { if (i >= valid_start) { const int index = offset + i; - float input_at_idx = add_before_softmax == nullptr ? float(input[index]) : float(input[index] + add_before_softmax[index]); + float input_at_idx = no_add ? float(input[index]) : float(input[index] + add_before_softmax[index]); if (thread_data_max < input_at_idx) { thread_data_max = input_at_idx; } @@ -77,7 +79,7 @@ __device__ inline void Softmax(const int all_sequence_length, for (int i = threadIdx.x; i < valid_end; i += TPB) { if (i >= valid_start) { const int index = offset + i; - float val = add_before_softmax == nullptr ? input[index] : input[index] + add_before_softmax[index]; + float val = no_add ? input[index] : input[index] + add_before_softmax[index]; thread_data_sum += expf(val - max_block); } } @@ -90,7 +92,7 @@ __device__ inline void Softmax(const int all_sequence_length, for (int i = threadIdx.x; i < all_sequence_length; i += TPB) { const int index = offset + i; - float input_at_idx = add_before_softmax == nullptr ? float(input[index]) : float(input[index] + add_before_softmax[index]); + float input_at_idx = no_add ? float(input[index]) : float(input[index] + add_before_softmax[index]); const float val = (i >= valid_start && i < valid_end) ? expf(input_at_idx - max_block) * sum_reverse_block : 0.f; output[index] = T(val); } @@ -122,7 +124,8 @@ __device__ inline void SoftmaxSmall(const int all_sequence_length, if (is_unidirectional) { int end_unid = all_sequence_length - sequence_length + (blockIdx.x % sequence_length) + 1; if (end_unid <= valid_start) { - // In this situation, mask of [0, end_unid) and [valid_start, valid_end) has -10000, and [end_unid, valid_start) and [valid_end, all_seq_len) has -20000. + // In this situation, mask of [0, end_unid) and [valid_start, valid_end) has -10000, + // and [end_unid, valid_start) and [valid_end, all_seq_len) has -20000. // So [0, end_unid) will also have value after softmax. is_valid = threadIdx.x < end_unid; } else { @@ -136,7 +139,8 @@ __device__ inline void SoftmaxSmall(const int all_sequence_length, // Infinity divided by Infinity is a NAN. Thus, softmax gets a NAN if one or more item are large enough. // a math transform as below is leveraged to get a stable softmax: // e^xi/(e^x1 + ...e^xn) = e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max)) - float input_data = add_before_softmax == nullptr ? float(input[index]) : float(input[index] + add_before_softmax[index]); + const bool no_add = (add_before_softmax == nullptr); + float input_data = no_add ? float(input[index]) : float(input[index] + add_before_softmax[index]); float thread_data_max = is_valid ? input_data : float(-CUDART_INF_F); const auto max = BlockReduce(tmp_storage).Reduce(thread_data_max, cub::Max(), end); @@ -197,7 +201,7 @@ __device__ inline void SoftmaxWithRawMaskSmall(const int all_sequence_length, const int sequence_index = blockIdx.x % sequence_length; if (is_unidirectional) { - int from_index = all_sequence_length - sequence_length + sequence_index; // offset of from token in all sequence length. + int from_index = all_sequence_length - sequence_length + sequence_index; // offset in all sequence length. 
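Written out, the max-subtraction transform referenced in the comments above is the standard identity

    softmax(x)_i = \frac{e^{x_i}}{\sum_{j=1}^{n} e^{x_j}}
                 = \frac{e^{x_i - m}}{\sum_{j=1}^{n} e^{x_j - m}},
    \qquad m = \max_{1 \le j \le n} x_j,

so every exponent is at most zero and no single term overflows to infinity. This is why the kernels here first block-reduce the maximum, then accumulate the shifted exponentials, and finally scale each term by the reciprocal of their sum.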
if (threadIdx.x > from_index) { thread_data = -10000.0f; } @@ -210,7 +214,8 @@ __device__ inline void SoftmaxWithRawMaskSmall(const int all_sequence_length, } else if (mask_dimension == 3) { mask_offset = (batch_index * sequence_length + sequence_index) * all_sequence_length + threadIdx.x; } else if (mask_dimension == 4) { - mask_offset = (batch_index * max_sequence_length + all_sequence_length - sequence_length + sequence_index) * max_sequence_length + threadIdx.x; + int from_index = all_sequence_length - sequence_length + sequence_index; + mask_offset = (batch_index * max_sequence_length + from_index) * max_sequence_length + threadIdx.x; } if (nullptr == key_padding_mask) { @@ -255,41 +260,59 @@ __device__ inline void SoftmaxWithRawMaskSmall(const int all_sequence_length, } template -__global__ void SoftmaxKernelSmall(const int all_sequence_length, const int sequence_length, const T* add_before_softmax, const T* input, T* output, bool is_unidirectional) { - SoftmaxSmall(all_sequence_length, sequence_length, all_sequence_length, 0, add_before_softmax, input, output, is_unidirectional); +__global__ void SoftmaxKernelSmall(const int all_sequence_length, + const int sequence_length, + const T* add_before_softmax, + const T* input, + T* output, + bool is_unidirectional) { + SoftmaxSmall(all_sequence_length, sequence_length, all_sequence_length, 0, + add_before_softmax, input, output, is_unidirectional); } template -__global__ void SoftmaxKernel(const int all_sequence_length, const int sequence_length, const T* add_before_softmax, const T* input, T* output) { - Softmax(all_sequence_length, sequence_length, all_sequence_length, 0, add_before_softmax, input, output); +__global__ void SoftmaxKernel(const int all_sequence_length, + const int sequence_length, + const T* add_before_softmax, + const T* input, + T* output) { + Softmax(all_sequence_length, sequence_length, all_sequence_length, 0, + add_before_softmax, input, output); } template -bool ComputeSoftmax( - cudaStream_t stream, const int all_sequence_length, const int sequence_length, const int batch_size, const int num_heads, - const T* add_before_softmax, const T* input, T* output, bool is_unidirectional) { +bool ComputeSoftmax(cudaStream_t stream, const int all_sequence_length, const int sequence_length, + const int batch_size, const int num_heads, + const T* add_before_softmax, const T* input, T* output, bool is_unidirectional) { const dim3 grid(sequence_length * num_heads, batch_size, 1); if (all_sequence_length <= 32) { const int blockSize = 32; - SoftmaxKernelSmall<<>>(all_sequence_length, sequence_length, add_before_softmax, input, output, is_unidirectional); + SoftmaxKernelSmall<<>>( + all_sequence_length, sequence_length, add_before_softmax, input, output, is_unidirectional); } else if (all_sequence_length <= 64) { const int blockSize = 64; - SoftmaxKernelSmall<<>>(all_sequence_length, sequence_length, add_before_softmax, input, output, is_unidirectional); + SoftmaxKernelSmall<<>>( + all_sequence_length, sequence_length, add_before_softmax, input, output, is_unidirectional); } else if (all_sequence_length <= 128) { const int blockSize = 128; - SoftmaxKernelSmall<<>>(all_sequence_length, sequence_length, add_before_softmax, input, output, is_unidirectional); + SoftmaxKernelSmall<<>>( + all_sequence_length, sequence_length, add_before_softmax, input, output, is_unidirectional); } else if (all_sequence_length <= 256) { const int blockSize = 256; - SoftmaxKernelSmall<<>>(all_sequence_length, sequence_length, add_before_softmax, input, 
output, is_unidirectional); + SoftmaxKernelSmall<<>>( + all_sequence_length, sequence_length, add_before_softmax, input, output, is_unidirectional); } else if (all_sequence_length <= 512) { const int blockSize = 512; - SoftmaxKernelSmall<<>>(all_sequence_length, sequence_length, add_before_softmax, input, output, is_unidirectional); + SoftmaxKernelSmall<<>>( + all_sequence_length, sequence_length, add_before_softmax, input, output, is_unidirectional); } else if (all_sequence_length <= 1024) { const int blockSize = 1024; - SoftmaxKernelSmall<<>>(all_sequence_length, sequence_length, add_before_softmax, input, output, is_unidirectional); + SoftmaxKernelSmall<<>>( + all_sequence_length, sequence_length, add_before_softmax, input, output, is_unidirectional); } else if (!is_unidirectional) { const int blockSize = 1024; - SoftmaxKernel<<>>(all_sequence_length, sequence_length, add_before_softmax, input, output); + SoftmaxKernel<<>>( + all_sequence_length, sequence_length, add_before_softmax, input, output); } else { ORT_THROW("Attention CUDA operator does not support total sequence length > 1024."); } @@ -298,7 +321,14 @@ bool ComputeSoftmax( } template -__global__ void MaskedSoftmaxKernelSmall(const int all_sequence_length, const int sequence_length, const int* mask_end, const int* mask_start, const T* add_before_softmax, const T* input, T* output, bool is_unidirectional) { +__global__ void MaskedSoftmaxKernelSmall(const int all_sequence_length, + const int sequence_length, + const int* mask_end, + const int* mask_start, + const T* add_before_softmax, + const T* input, + T* output, + bool is_unidirectional) { __shared__ int start_position; __shared__ int end_position; @@ -315,11 +345,17 @@ __global__ void MaskedSoftmaxKernelSmall(const int all_sequence_length, const in } __syncthreads(); - SoftmaxSmall(all_sequence_length, sequence_length, end_position, start_position, add_before_softmax, input, output, is_unidirectional); + SoftmaxSmall(all_sequence_length, sequence_length, end_position, start_position, + add_before_softmax, input, output, is_unidirectional); } template -__global__ void MaskedSoftmaxKernel(const int all_sequence_length, const int sequence_length, const int* mask_end, const int* mask_start, const T* add_before_softmax, const T* input, T* output) { +__global__ void MaskedSoftmaxKernel(const int all_sequence_length, + const int sequence_length, + const int* mask_end, + const int* mask_start, + const T* add_before_softmax, + const T* input, T* output) { __shared__ int start_position; __shared__ int end_position; @@ -336,49 +372,79 @@ __global__ void MaskedSoftmaxKernel(const int all_sequence_length, const int seq } __syncthreads(); - Softmax(all_sequence_length, sequence_length, end_position, start_position, add_before_softmax, input, output); + Softmax(all_sequence_length, sequence_length, end_position, start_position, + add_before_softmax, input, output); } template -__global__ void SoftmaxWithRawMaskSmallKernel(const int all_sequence_length, const int sequence_length, const int* attention_mask, const bool* key_padding_mask, const T* add_before_softmax, const T* input, - T* output, const bool is_unidirectional, const float rsqrt_head_size, const int mask_dimension, const int max_sequence_length, +__global__ void SoftmaxWithRawMaskSmallKernel(const int all_sequence_length, + const int sequence_length, + const int* attention_mask, + const bool* key_padding_mask, + const T* add_before_softmax, + const T* input, + T* output, + const bool is_unidirectional, + const float 
rsqrt_head_size, + const int mask_dimension, + const int max_sequence_length, const bool skip_softmax) { - SoftmaxWithRawMaskSmall(all_sequence_length, sequence_length, attention_mask, key_padding_mask, add_before_softmax, input, output, is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, skip_softmax); + SoftmaxWithRawMaskSmall( + all_sequence_length, sequence_length, + attention_mask, key_padding_mask, add_before_softmax, input, output, + is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, + skip_softmax); } template -bool ComputeSoftmaxWithMask1D(cudaStream_t stream, const int all_sequence_length, const int sequence_length, const int batch_size, const int num_heads, - const int* mask_index, const int* mask_start, const T* add_before_softmax, const T* input, T* output, const bool is_unidirectional) { +bool ComputeSoftmaxWithMask1D(cudaStream_t stream, + const int all_sequence_length, + const int sequence_length, + const int batch_size, + const int num_heads, + const int* mask_index, + const int* mask_start, + const T* add_before_softmax, + const T* input, + T* output, + const bool is_unidirectional) { const dim3 grid(sequence_length * num_heads, batch_size, 1); if (all_sequence_length <= 32) { const int blockSize = 32; MaskedSoftmaxKernelSmall - <<>>(all_sequence_length, sequence_length, mask_index, mask_start, add_before_softmax, input, output, is_unidirectional); + <<>>(all_sequence_length, sequence_length, mask_index, mask_start, + add_before_softmax, input, output, is_unidirectional); } else if (all_sequence_length <= 64) { const int blockSize = 64; MaskedSoftmaxKernelSmall - <<>>(all_sequence_length, sequence_length, mask_index, mask_start, add_before_softmax, input, output, is_unidirectional); + <<>>(all_sequence_length, sequence_length, mask_index, mask_start, + add_before_softmax, input, output, is_unidirectional); } else if (all_sequence_length <= 128) { const int blockSize = 128; MaskedSoftmaxKernelSmall - <<>>(all_sequence_length, sequence_length, mask_index, mask_start, add_before_softmax, input, output, is_unidirectional); + <<>>(all_sequence_length, sequence_length, mask_index, mask_start, + add_before_softmax, input, output, is_unidirectional); } else if (all_sequence_length <= 256) { const int blockSize = 256; MaskedSoftmaxKernelSmall - <<>>(all_sequence_length, sequence_length, mask_index, mask_start, add_before_softmax, input, output, is_unidirectional); + <<>>(all_sequence_length, sequence_length, mask_index, mask_start, + add_before_softmax, input, output, is_unidirectional); } else if (all_sequence_length <= 512) { const int blockSize = 512; MaskedSoftmaxKernelSmall - <<>>(all_sequence_length, sequence_length, mask_index, mask_start, add_before_softmax, input, output, is_unidirectional); + <<>>(all_sequence_length, sequence_length, mask_index, mask_start, + add_before_softmax, input, output, is_unidirectional); } else if (all_sequence_length <= 1024) { const int blockSize = 1024; MaskedSoftmaxKernelSmall - <<>>(all_sequence_length, sequence_length, mask_index, mask_start, add_before_softmax, input, output, is_unidirectional); + <<>>(all_sequence_length, sequence_length, mask_index, mask_start, + add_before_softmax, input, output, is_unidirectional); } else if (!is_unidirectional) { const int blockSize = 1024; MaskedSoftmaxKernel - <<>>(all_sequence_length, sequence_length, mask_index, mask_start, add_before_softmax, input, output); + <<>>(all_sequence_length, sequence_length, mask_index, mask_start, + add_before_softmax, 
input, output); } else { ORT_THROW("Attention CUDA operator does not support total sequence length > 1024."); } @@ -387,48 +453,83 @@ bool ComputeSoftmaxWithMask1D(cudaStream_t stream, const int all_sequence_length } template -bool ComputeSoftmaxWithRawMask(cudaStream_t stream, const int all_sequence_length, const int sequence_length, const int batch_size, const int num_heads, - const int* attention_mask, const bool* key_padding_mask, const T* add_before_softmax, const T* input, T* output, const bool is_unidirectional, - const float rsqrt_head_size, const int mask_dimension, const int max_sequence_length, const bool use_persistent_softmax, T* persistent_softmax_workspace) { +bool ComputeSoftmaxWithRawMask(cudaStream_t stream, + const int all_sequence_length, + const int sequence_length, + const int batch_size, + const int num_heads, + const int* attention_mask, + const bool* key_padding_mask, + const T* add_before_softmax, + const T* input, + T* output, + const bool is_unidirectional, + const float rsqrt_head_size, + const int mask_dimension, + const int max_sequence_length, + const bool use_persistent_softmax, + T* persistent_softmax_workspace) { const dim3 grid(sequence_length * num_heads, batch_size, 1); T* out = use_persistent_softmax ? persistent_softmax_workspace : output; if (all_sequence_length <= 32) { const int blockSize = 32; SoftmaxWithRawMaskSmallKernel - <<>>(all_sequence_length, sequence_length, attention_mask, key_padding_mask, add_before_softmax, input, out, is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, use_persistent_softmax); + <<>>(all_sequence_length, sequence_length, + attention_mask, key_padding_mask, add_before_softmax, input, out, + is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, + use_persistent_softmax); } else if (all_sequence_length <= 64) { const int blockSize = 64; SoftmaxWithRawMaskSmallKernel - <<>>(all_sequence_length, sequence_length, attention_mask, key_padding_mask, add_before_softmax, input, out, is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, use_persistent_softmax); + <<>>(all_sequence_length, sequence_length, + attention_mask, key_padding_mask, add_before_softmax, input, out, + is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, + use_persistent_softmax); } else if (all_sequence_length <= 128) { const int blockSize = 128; SoftmaxWithRawMaskSmallKernel - <<>>(all_sequence_length, sequence_length, attention_mask, key_padding_mask, add_before_softmax, input, out, is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, use_persistent_softmax); + <<>>(all_sequence_length, sequence_length, + attention_mask, key_padding_mask, add_before_softmax, input, out, + is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, + use_persistent_softmax); } else if (all_sequence_length <= 256) { const int blockSize = 256; SoftmaxWithRawMaskSmallKernel - <<>>(all_sequence_length, sequence_length, attention_mask, key_padding_mask, add_before_softmax, input, out, is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, use_persistent_softmax); + <<>>(all_sequence_length, sequence_length, + attention_mask, key_padding_mask, add_before_softmax, input, out, + is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, + use_persistent_softmax); } else if (all_sequence_length <= 512) { const int blockSize = 512; SoftmaxWithRawMaskSmallKernel - <<>>(all_sequence_length, sequence_length, attention_mask, key_padding_mask, 
add_before_softmax, input, out, is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, use_persistent_softmax); + <<>>(all_sequence_length, sequence_length, + attention_mask, key_padding_mask, add_before_softmax, input, out, + is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, + use_persistent_softmax); } else if (all_sequence_length <= 1024) { const int blockSize = 1024; SoftmaxWithRawMaskSmallKernel - <<>>(all_sequence_length, sequence_length, attention_mask, key_padding_mask, add_before_softmax, input, out, is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, use_persistent_softmax); + <<>>(all_sequence_length, sequence_length, + attention_mask, key_padding_mask, add_before_softmax, input, out, + is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, + use_persistent_softmax); } else { ORT_THROW("Attention CUDA operator does not support total sequence length > 1024."); } if (use_persistent_softmax) { - dispatch_warpwise_softmax_forward(stream, output, persistent_softmax_workspace, all_sequence_length, all_sequence_length, batch_size * num_heads * sequence_length); + dispatch_warpwise_softmax_forward(stream, + output, + persistent_softmax_workspace, + all_sequence_length, + all_sequence_length, + batch_size * num_heads * sequence_length); } return CUDA_CALL(cudaPeekAtLastError()); } - } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_transpose.cu b/onnxruntime/contrib_ops/cuda/bert/attention_transpose.cu index ce203ac8d8d70..1a4fd80b78d12 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_transpose.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_transpose.cu @@ -252,7 +252,6 @@ bool LaunchTransQkv(cudaStream_t stream, const int matrix_num, const dim3 block(max_threads_per_block / num_heads, num_heads, 1); TransposeQKVLarge<<>>(head_size, reversed_bs, input, output); } - } return CUDA_CALL(cudaPeekAtLastError()); } @@ -295,7 +294,6 @@ bool LaunchTransQkv(cudaStream_t stream, const int matrix_num, return CUDA_CALL(cudaPeekAtLastError()); } - } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc b/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc index 411f3ff6f2730..95e2caefc79f5 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc @@ -104,7 +104,8 @@ Status CheckInputs(const TensorShape& query_shape, const auto& kp_mask_dims = key_padding_mask->Shape().GetDims(); if (kp_mask_dims.size() != 2) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key_padding_mask' is expected to have 2 dimension, got ", + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'key_padding_mask' is expected to have 2 dimension, got ", kp_mask_dims.size()); } @@ -123,7 +124,8 @@ Status CheckInputs(const TensorShape& query_shape, } if (kp_mask_dims[1] != key_length) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "key_padding_mask shall have same sequence length as generated key"); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "key_padding_mask shall have same sequence length as generated key"); } } @@ -188,10 +190,14 @@ Status DecoderAttention::ComputeInternal(OpKernelContext* context) const { // Copy static_kv, use_past and has_layer_state to CPU auto pinned_buffer = AllocateBufferOnCPUPinned(4 * sizeof(bool)); bool* kernel_state_pinned = 
reinterpret_cast(pinned_buffer.get()); - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(kernel_state_pinned, static_kv->Data(), sizeof(bool), cudaMemcpyDeviceToHost, stream)); - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(kernel_state_pinned + 1, use_past->Data(), sizeof(bool), cudaMemcpyDeviceToHost, stream)); - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(kernel_state_pinned + 2, has_layer_state->Data(), sizeof(bool), cudaMemcpyDeviceToHost, stream)); - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(kernel_state_pinned + 3, has_key_padding_mask->Data(), sizeof(bool), cudaMemcpyDeviceToHost, stream)); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(kernel_state_pinned, static_kv->Data(), sizeof(bool), + cudaMemcpyDeviceToHost, stream)); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(kernel_state_pinned + 1, use_past->Data(), sizeof(bool), + cudaMemcpyDeviceToHost, stream)); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(kernel_state_pinned + 2, has_layer_state->Data(), sizeof(bool), + cudaMemcpyDeviceToHost, stream)); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(kernel_state_pinned + 3, has_key_padding_mask->Data(), sizeof(bool), + cudaMemcpyDeviceToHost, stream)); // Create an event to make sure the async copy is finished before reading the data. AutoDestoryCudaEvent new_event; @@ -342,8 +348,11 @@ Status DecoderAttention::ComputeInternal(OpKernelContext* context) const { } } - auto qkv_buffer_p = GetScratchBuffer(batch_size * (sequence_length + 2 * kv_sequence_length) * hidden_size * element_size); - auto workspace_p = GetScratchBuffer(2 * batch_size * sequence_length * num_heads_ * element_size * (2 * head_size + kv_sequence_length)); + size_t bytes = element_size * batch_size * (sequence_length + 2 * kv_sequence_length) * hidden_size; + auto qkv_buffer_p = GetScratchBuffer(bytes); + + bytes = element_size * 2 * batch_size * sequence_length * num_heads_ * (2 * head_size + kv_sequence_length); + auto workspace_p = GetScratchBuffer(bytes); Tensor* output(context->Output(0, query_shape)); TensorShape new_cache_shape({batch_size, num_heads_, kv_sequence_length, head_size}); diff --git a/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm.cc b/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm.cc index 1dfca4f459995..4e279f389f0fa 100644 --- a/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm.cc +++ b/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm.cc @@ -43,8 +43,8 @@ Status EmbedLayerNorm::ComputeInternal(OpKernelContext* context) const { const Tensor* segment_embedding = context->Input(4); // optional. nullptr if it's distill-bert const Tensor* gamma = context->Input(5); const Tensor* beta = context->Input(6); - const Tensor* mask = context->Input(7); // optional. nullptr if not provided - const Tensor* position_ids = context->Input(8); // optional. nullptr if not provided + const Tensor* mask = context->Input(7); // optional. nullptr if not provided + const Tensor* position_ids = context->Input(8); // optional. 
nullptr if not provided const auto& input_dims = input_ids->Shape().GetDims(); int64_t hidden_size = word_embedding->Shape()[1]; @@ -88,6 +88,6 @@ Status EmbedLayerNorm::ComputeInternal(OpKernelContext* context) const { return Status::OK(); } -} //namespace cuda +} // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.cu b/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.cu index f718bc643e2b5..27ccca6c245d3 100644 --- a/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.cu @@ -85,7 +85,11 @@ __global__ void MaskIndexKernel(int sequence_length, const int* mask, int* mask_ } } -inline bool ComputeMaskIndex(cudaStream_t stream, const int sequence_length, const int batch_size, const int* mask, int* mask_index) { +inline bool ComputeMaskIndex(cudaStream_t stream, + const int sequence_length, + const int batch_size, + const int* mask, + int* mask_index) { // Mask idx is of length batch_size and assumes the valid region is contiguous starting // from the beginning of the sequence @@ -178,7 +182,9 @@ bool EmbedSkipLayerNorm( const dim3 block(tpb, 1, 1); EmbedLayerNormKernel - <<>>(hidden_size, input_ids, segment_ids, beta, gamma, word_embedding, position_embedding, segment_embedding, epsilon, output, embedding_sum, position_ids); + <<>>(hidden_size, input_ids, segment_ids, beta, gamma, + word_embedding, position_embedding, segment_embedding, + epsilon, output, embedding_sum, position_ids); return CUDA_CALL(cudaPeekAtLastError()); } diff --git a/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.h b/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.h index 2358bef875943..64f8be8fd3f1d 100644 --- a/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.h @@ -7,23 +7,23 @@ namespace contrib { namespace cuda { bool LaunchEmbedLayerNormKernel(cudaStream_t stream, - void* output, // output tensor - void* mask_index, // output mask index - const int* input_ids, // input word IDs - const int* segment_ids, // input segment IDs - const int* input_mask, // input mask - const void* gamma, // weight for layer normalization - const void* beta, // bias for layer normalization - const void* word_embedding, // weights for word embeddings - const void* position_embedding, // weights for position embeddings - const void* segment_embedding, // weights for segment (like sentence) embeddings - float epsilon, // epsilon for layer normalization - const int hidden_size, // hidden size (that is head_size * num_heads) - int batch_size, // batch size - int sequence_length, // sequence length - const size_t element_size, // size of element in output tensor. 2 for half, 4 for float. 
- void* embedding_sum, // Optional output of sum of embeddings - const int* position_ids); // Optional input of position ids + void* output, // output tensor + void* mask_index, // output mask index + const int* input_ids, // input word IDs + const int* segment_ids, // input segment IDs + const int* input_mask, // input mask + const void* gamma, // weight for layer normalization + const void* beta, // bias for layer normalization + const void* word_embedding, // weights for word embeddings + const void* position_embedding, // weights for position embeddings + const void* segment_embedding, // weights for segment (like sentence) embeddings + float epsilon, // epsilon for layer normalization + const int hidden_size, // hidden size (that is head_size * num_heads) + int batch_size, // batch size + int sequence_length, // sequence length + const size_t element_size, // size of output element: 2 for half, 4 for float. + void* embedding_sum, // Optional output of sum of embeddings + const int* position_ids); // Optional input of position ids } // namespace cuda } // namespace contrib diff --git a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc index 3b7b47bed3416..1b795d7682235 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc +++ b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc @@ -65,6 +65,6 @@ Status FastGelu::ComputeInternal(OpKernelContext* context) const { return Status::OK(); } -} //namespace cuda +} // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.h b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.h index 8155ece5b1c38..3e642a70afef5 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.h +++ b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.h @@ -16,7 +16,7 @@ class FastGelu final : public CudaKernel { public: FastGelu(const OpKernelInfo& op_kernel_info); Status ComputeInternal(OpKernelContext* ctx) const override; - + private: bool use_half2_; }; diff --git a/onnxruntime/contrib_ops/cuda/bert/fast_gelu_impl.cu b/onnxruntime/contrib_ops/cuda/bert/fast_gelu_impl.cu index 8bfb7972d4103..ebcf61dff486b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fast_gelu_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/fast_gelu_impl.cu @@ -1,7 +1,7 @@ /* The implementation of this file is based on gelu plugin in TensorRT demo: https://github.com/NVIDIA/TensorRT/tree/release/5.1/demo/BERT/ - + Copyright 2019 NVIDIA Corporation Licensed under the Apache License, Version 2.0 (the "License"); @@ -40,7 +40,8 @@ constexpr float B = 0.7978845608028654f; // sqrt(2.0/M_PI) constexpr float C = 0.035677408136300125f; // 0.044715 * sqrt(2.0/M_PI) template -__global__ void FastGeluKernel(const T a, const T b, const T c, int input_length, int bias_length, const T* input, const T* bias, T* output) { +__global__ void FastGeluKernel(const T a, const T b, const T c, int input_length, int bias_length, + const T* input, const T* bias, T* output) { const int idx = blockIdx.x * TPB + threadIdx.x; if (idx < input_length) { @@ -52,7 +53,8 @@ __global__ void FastGeluKernel(const T a, const T b, const T c, int input_length } template -__global__ void FastGeluKernel2(const half2 a, const half2 b, const half2 c, int input_length, int bias_length, const half2* input, const half2* bias, half2* output) { +__global__ void FastGeluKernel2(const half2 a, const half2 b, const half2 c, int input_length, int bias_length, + const half2* input, const half2* bias, half2* output) { const int idx = blockIdx.x * TPB + 
threadIdx.x; if (idx < input_length) { const half2 x = input[idx]; @@ -63,16 +65,19 @@ __global__ void FastGeluKernel2(const half2 a, const half2 b, const half2 c, int } template <> -bool LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int input_length, int bias_length, const float* input, const float* bias, float* output, bool /*use_half2*/) { +bool LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int input_length, int bias_length, + const float* input, const float* bias, float* output, bool /*use_half2*/) { constexpr int blockSize = 256; const int gridSize = (input_length + blockSize - 1) / blockSize; - FastGeluKernel<<>>(A, B, C, input_length, bias_length, input, bias, output); + FastGeluKernel<<>>(A, B, C, input_length, bias_length, + input, bias, output); return CUDA_CALL(cudaPeekAtLastError()); } template <> -bool LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int input_length, int bias_length, const half* input, const half* bias, half* output, bool use_half2) { +bool LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int input_length, int bias_length, + const half* input, const half* bias, half* output, bool use_half2) { constexpr int blockSize = 256; if (use_half2 && 0 == (bias_length & 1) && prop.major >= 7) { const int n = input_length / 2; @@ -83,10 +88,12 @@ bool LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int i const half2* input2 = reinterpret_cast(input); const half2* bias2 = reinterpret_cast(bias); half2* output2 = reinterpret_cast(output); - FastGeluKernel2<<>>(A2, B2, C2, n, bias_length / 2, input2, bias2, output2); + FastGeluKernel2<<>>(A2, B2, C2, n, bias_length / 2, + input2, bias2, output2); } else { const int gridSize = (input_length + blockSize - 1) / blockSize; - FastGeluKernel<<>>(A, B, C, input_length, bias_length, input, bias, output); + FastGeluKernel<<>>(A, B, C, input_length, bias_length, + input, bias, output); } return CUDA_CALL(cudaPeekAtLastError()); @@ -98,7 +105,7 @@ bool LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int i constexpr int blockSize = 256; // remove nv_bfloat162 implementation for now to fix build issue - // we can decide whether to add it back if there's perf concern + // we can decide whether to add it back if there's perf concern const int gridSize = (input_length + blockSize - 1) / blockSize; FastGeluKernel <<>>(A, B, C, input_length, bias_length, input, bias, output); diff --git a/onnxruntime/contrib_ops/cuda/bert/fast_gelu_impl.h b/onnxruntime/contrib_ops/cuda/bert/fast_gelu_impl.h index 9618c52bb3247..ed17eb0f194d2 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fast_gelu_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/fast_gelu_impl.h @@ -8,7 +8,8 @@ namespace contrib { namespace cuda { template -bool LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int input_length, int bias_length, const T* input, const T* bias, T* output, bool use_half2); +bool LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int input_length, int bias_length, + const T* input, const T* bias, T* output, bool use_half2); } // namespace cuda } // namespace contrib diff --git a/onnxruntime/contrib_ops/cuda/bert/layer_norm.cuh b/onnxruntime/contrib_ops/cuda/bert/layer_norm.cuh index 6d29e67e66ee6..0f1dbe5ab6b36 100644 --- a/onnxruntime/contrib_ops/cuda/bert/layer_norm.cuh +++ b/onnxruntime/contrib_ops/cuda/bert/layer_norm.cuh @@ -61,18 +61,21 @@ __device__ inline half2 AddHalf2(const half2 a, 
const half2 b) { } struct KeyValuePairSum { - __device__ inline cub::KeyValuePair operator()(const cub::KeyValuePair& a, const cub::KeyValuePair& b) { + __device__ inline cub::KeyValuePair operator()(const cub::KeyValuePair& a, + const cub::KeyValuePair& b) { return cub::KeyValuePair(a.key + b.key, a.value + b.value); } - __device__ inline cub::KeyValuePair operator()(const cub::KeyValuePair& a, const cub::KeyValuePair& b) { + __device__ inline cub::KeyValuePair operator()(const cub::KeyValuePair& a, + const cub::KeyValuePair& b) { const half2 a2 = __halves2half2(a.key, a.value); const half2 b2 = __halves2half2(b.key, b.value); const half2 res = AddHalf2(a2, b2); return cub::KeyValuePair(__low2half(res), __high2half(res)); } - __device__ inline cub::KeyValuePair operator()(const cub::KeyValuePair& a, const cub::KeyValuePair& b) { + __device__ inline cub::KeyValuePair operator()(const cub::KeyValuePair& a, + const cub::KeyValuePair& b) { return cub::KeyValuePair(AddHalf2(a.key, b.key), AddHalf2(a.value, b.value)); } }; @@ -139,10 +142,11 @@ __device__ inline void LayerNormSmall(const T* input_v, const cub::KeyValuePair< __syncthreads(); if (ILP * threadIdx.x < ld) { - #pragma unroll +#pragma unroll for (int i = 0; i < ILP; i++) { - output_v[i] = (beta != nullptr) ? gamma_v[i] * (input_v[i] - mu) * rsigma + beta_v[i] : - gamma_v[i] * (input_v[i] - mu) * rsigma; + output_v[i] = (beta != nullptr) + ? gamma_v[i] * (input_v[i] - mu) * rsigma + beta_v[i] + : gamma_v[i] * (input_v[i] - mu) * rsigma; } *(reinterpret_cast(&output[idx])) = *output_val; } diff --git a/onnxruntime/contrib_ops/cuda/bert/longformer_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/longformer_attention_impl.cu index de50d8cabebaa..3c9be2d6bf394 100644 --- a/onnxruntime/contrib_ops/cuda/bert/longformer_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/longformer_attention_impl.cu @@ -63,7 +63,7 @@ namespace cuda { // [scratch1: BxNxSxS] [scratch2: BxNxSxS] static size_t Align(size_t a) { - const size_t alignment = 128; // Align on a 16-byte boundary to avoid "misaligned address" error. + const size_t alignment = 128; // Align on a 16-byte boundary to avoid "misaligned address" error. 
return CeilDiv(a, alignment) * alignment; } @@ -88,7 +88,7 @@ size_t GetLongformerSoftmaxWorkspaceSize( size_t scratch2_size = GetScratch2Size(); return Align(scratch1_size + scratch2_size); } else { - return static_cast(2) * GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length, sequence_length); + return 2 * GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length, sequence_length); } } diff --git a/onnxruntime/contrib_ops/cuda/bert/longformer_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/longformer_attention_impl.h index d58eaa07ac7c9..57ea44de351fc 100644 --- a/onnxruntime/contrib_ops/cuda/bert/longformer_attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/longformer_attention_impl.h @@ -45,8 +45,7 @@ bool LaunchLongformerAttentionKernel( const size_t element_size, // Element size of input tensor, bool disable_compact_memory, // Disable compact memory kernel bool use_merged_qkv_weights, - bool use_half4 -); + bool use_half4); } // namespace cuda } // namespace contrib diff --git a/onnxruntime/contrib_ops/cuda/bert/longformer_attention_softmax.cu b/onnxruntime/contrib_ops/cuda/bert/longformer_attention_softmax.cu index 96f386ac98723..8b7b66c96e54e 100644 --- a/onnxruntime/contrib_ops/cuda/bert/longformer_attention_softmax.cu +++ b/onnxruntime/contrib_ops/cuda/bert/longformer_attention_softmax.cu @@ -309,10 +309,11 @@ bool LaunchLongformerSoftmaxSimpleKernel( } else { // sequence_length > 2 * w for (int i = 0; i < batch_size; ++i) { for (int j = 0; j < num_heads; ++j) { - const void* q_head = reinterpret_cast(q) + \ + const void* q_head = reinterpret_cast(q) + (i * x_offset + j * sequence_length * head_size + w * head_size) * element_size; - const void* k_head = reinterpret_cast(k) + (i * x_offset + j * sequence_length * head_size) * element_size; - void* qk_head = reinterpret_cast(scratch1) + \ + const void* k_head = reinterpret_cast(k) + + (i * x_offset + j * sequence_length * head_size) * element_size; + void* qk_head = reinterpret_cast(scratch1) + (i * y_offset + j * sequence_length * sequence_length + w * sequence_length) * element_size; int count = (sequence_length - 2 * w) / w; CHECK(cublasGemmStridedBatchedEx(cublas, @@ -367,7 +368,7 @@ bool LaunchLongformerSoftmaxSimpleKernel( const void* q_head = reinterpret_cast(q) + (last_block * w * head_size) * element_size; const void* k_head = reinterpret_cast(k) + ((last_block - 1) * w * head_size) * element_size; - void* qk_head = reinterpret_cast(scratch1) + \ + void* qk_head = reinterpret_cast(scratch1) + (last_block * w * sequence_length + (last_block - 1) * w) * element_size; CHECK(cublasGemmStridedBatchedEx(cublas, CUBLAS_OP_T, @@ -426,7 +427,7 @@ bool LaunchLongformerSoftmaxSimpleKernel( resultType, algo)); - const void* global_q_batch = reinterpret_cast(global_q) + \ + const void* global_q_batch = reinterpret_cast(global_q) + (i * num_heads * sequence_length * head_size) * element_size; const void* global_k_batch = reinterpret_cast(global_k) + (i * x_offset) * element_size; int strideB_global = sequence_length * head_size; @@ -517,11 +518,11 @@ bool LaunchLongformerSoftmaxSimpleKernel( } else { // sequence_length > 2 * w for (int i = 0; i < batch_size; ++i) { for (int j = 0; j < num_heads; ++j) { - const void* v_head = reinterpret_cast(v) + \ + const void* v_head = reinterpret_cast(v) + (i * x_offset + j * head_size * sequence_length) * element_size; - const void* prob_head = reinterpret_cast(softmax_out) + \ - (i * y_offset + j * sequence_length * sequence_length + w * 
sequence_length) * element_size; - void* out_head = reinterpret_cast(output) + \ + size_t offset = (i * y_offset + j * sequence_length * sequence_length + w * sequence_length) * element_size; + const void* prob_head = reinterpret_cast(softmax_out) + offset; + void* out_head = reinterpret_cast(output) + (i * x_offset + j * head_size * sequence_length + w * head_size) * element_size; int count = (sequence_length - 2 * w) / w; CHECK(cublasGemmStridedBatchedEx(cublas, @@ -575,7 +576,7 @@ bool LaunchLongformerSoftmaxSimpleKernel( algo)); const void* v_head = reinterpret_cast(v) + (last_block - 1) * w * head_size * element_size; - const void* prob_head = reinterpret_cast(softmax_out) + \ + const void* prob_head = reinterpret_cast(softmax_out) + (sequence_length * last_block * w + (last_block - 1) * w) * element_size; void* out_head = reinterpret_cast(output) + last_block * w * head_size * element_size; @@ -609,7 +610,7 @@ bool LaunchLongformerSoftmaxSimpleKernel( int glob_longdim_mm = (last_block - 1) * w; const void* v_head = reinterpret_cast(v) + (i * x_offset) * element_size; - const void* prob_head = reinterpret_cast(softmax_out) + \ + const void* prob_head = reinterpret_cast(softmax_out) + (i * y_offset + 2 * w * sequence_length) * element_size; void* out_head = reinterpret_cast(output) + (i * x_offset + 2 * w * head_size) * element_size; diff --git a/onnxruntime/contrib_ops/cuda/bert/ngram_repeat_block.cc b/onnxruntime/contrib_ops/cuda/bert/ngram_repeat_block.cc index 53fb3ffa9d923..b9a8613c852a3 100644 --- a/onnxruntime/contrib_ops/cuda/bert/ngram_repeat_block.cc +++ b/onnxruntime/contrib_ops/cuda/bert/ngram_repeat_block.cc @@ -66,6 +66,6 @@ Status NGramRepeatBlock::ComputeInternal(OpKernelContext* context) const { return Status::OK(); } -} //namespace cuda +} // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/ngram_repeat_block.h b/onnxruntime/contrib_ops/cuda/bert/ngram_repeat_block.h index eb219916ba590..139447d37b7b4 100644 --- a/onnxruntime/contrib_ops/cuda/bert/ngram_repeat_block.h +++ b/onnxruntime/contrib_ops/cuda/bert/ngram_repeat_block.h @@ -15,6 +15,7 @@ class NGramRepeatBlock final : public CudaKernel { public: NGramRepeatBlock(const OpKernelInfo& op_kernel_info); Status ComputeInternal(OpKernelContext* ctx) const override; + private: int64_t ngram_size_; }; diff --git a/onnxruntime/contrib_ops/cuda/bert/ngram_repeat_block_impl.cu b/onnxruntime/contrib_ops/cuda/bert/ngram_repeat_block_impl.cu index 887f32721c5bc..8a04ede231a27 100644 --- a/onnxruntime/contrib_ops/cuda/bert/ngram_repeat_block_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/ngram_repeat_block_impl.cu @@ -33,8 +33,8 @@ __global__ void banRepeatedTokens(const int64_t* __restrict__ tokens, extern __shared__ int64_t tokens_shm[]; tokens_shm[col] = tokens[start]; if (col == blockDim.x - 1) { - for (int i=1; i::ComputeInternal(OpKernelContext* ctx) const { return Status::OK(); } -} //namespace cuda +} // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu index 49bebfec0752f..6766a4908c3e4 100644 --- a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu @@ -35,15 +35,15 @@ namespace onnxruntime { namespace contrib { namespace cuda { -template +template T maybe2half(float x); -template<> +template <> float maybe2half(float x) { return x; } 
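For orientation, the SkipLayerNorm kernels in this file fuse the residual add with layer normalization: each row of the output is LayerNorm(input + skip + bias) over the hidden dimension ld. A scalar CPU reference of that computation might look like the sketch below; the helper name and the placement of epsilon inside the square root are assumptions for illustration, not code from this patch.

// Scalar reference for one row of SkipLayerNorm: out = LayerNorm(input + skip + bias).
// Hypothetical helper; epsilon-inside-sqrt is an assumption.
#include <cmath>
#include <cstddef>

void SkipLayerNormRowReference(float* output, const float* input, const float* skip,
                               const float* gamma, const float* beta, const float* bias,
                               float epsilon, std::size_t ld) {
  float mean = 0.0f, sq_mean = 0.0f;
  for (std::size_t i = 0; i < ld; ++i) {
    float x = input[i] + skip[i] + (bias != nullptr ? bias[i] : 0.0f);
    output[i] = x;  // stash the summed value; normalized in the second pass
    mean += x / static_cast<float>(ld);
    sq_mean += x * x / static_cast<float>(ld);
  }
  float variance = sq_mean - mean * mean;
  float rsigma = 1.0f / std::sqrt(variance + epsilon);
  for (std::size_t i = 0; i < ld; ++i) {
    float norm = (output[i] - mean) * rsigma;
    output[i] = norm * gamma[i] + (beta != nullptr ? beta[i] : 0.0f);
  }
}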
-template<> +template <> half maybe2half(float x) { return __float2half_rn(x); } @@ -98,9 +98,9 @@ __global__ void SkipLayerNormKernelSmall( if (ILP * threadIdx.x < ld) { T rldval_sum = T(0.f); T rldvalsq_sum = T(0.f); - #pragma unroll +#pragma unroll for (int i = 0; i < ILP; i++) { - input_v[i] += hasBias ? skip_v[i] + bias_v[i]: skip_v[i]; + input_v[i] += hasBias ? skip_v[i] + bias_v[i] : skip_v[i]; const T rldval = rld * input_v[i]; rldval_sum += rldval; rldvalsq_sum += rldval * input_v[i]; @@ -115,7 +115,6 @@ bool LaunchSkipLayerNormKernel( cudaStream_t stream, T* output, const T* input, const T* skip, const T* gamma, const T* beta, const T* bias, float epsilon, const int ld, const int element_count, size_t element_size) { - // this must be true because n is the total size of the tensor assert(element_count % ld == 0); bool hasBias = (bias == nullptr) ? false : true; @@ -125,37 +124,37 @@ bool LaunchSkipLayerNormKernel( constexpr int block_size = 32; SkipLayerNormKernelSmall <<>>(ld, input, skip, beta, gamma, bias, - maybe2half(epsilon), output, hasBias); + maybe2half(epsilon), output, hasBias); } else if (ld <= 64) { constexpr int block_size = 64 / 2; SkipLayerNormKernelSmall <<>>(ld, input, skip, beta, gamma, bias, - maybe2half(epsilon), output, hasBias); + maybe2half(epsilon), output, hasBias); } else if (ld <= 128) { constexpr int block_size = 128 / 4; SkipLayerNormKernelSmall <<>>(ld, input, skip, beta, gamma, bias, - maybe2half(epsilon), output, hasBias); + maybe2half(epsilon), output, hasBias); } else if (ld <= 384) { constexpr int block_size = 384 / 4; SkipLayerNormKernelSmall <<>>(ld, input, skip, beta, gamma, bias, - maybe2half(epsilon), output, hasBias); + maybe2half(epsilon), output, hasBias); } else if (ld <= 768) { constexpr int block_size = 768 / 4; SkipLayerNormKernelSmall <<>>(ld, input, skip, beta, gamma, bias, - maybe2half(epsilon), output, hasBias); + maybe2half(epsilon), output, hasBias); } else if (ld <= 1024) { constexpr int block_size = 1024 / 4; SkipLayerNormKernelSmall <<>>(ld, input, skip, beta, gamma, bias, - maybe2half(epsilon), output, hasBias); + maybe2half(epsilon), output, hasBias); } else { constexpr int block_size = 256; SkipLayerNormKernel <<>>(ld, input, skip, beta, gamma, bias, - maybe2half(epsilon), output); + maybe2half(epsilon), output); } } else { const int grid_size = element_count / ld; @@ -163,27 +162,27 @@ bool LaunchSkipLayerNormKernel( constexpr int block_size = 32; SkipLayerNormKernelSmall <<>>(ld, input, skip, beta, gamma, bias, - maybe2half(epsilon), output, hasBias); + maybe2half(epsilon), output, hasBias); } else if (ld <= 64) { constexpr int block_size = 64; SkipLayerNormKernelSmall <<>>(ld, input, skip, beta, gamma, bias, - maybe2half(epsilon), output, hasBias); + maybe2half(epsilon), output, hasBias); } else if (ld <= 128) { constexpr int block_size = 128; SkipLayerNormKernelSmall <<>>(ld, input, skip, beta, gamma, bias, - maybe2half(epsilon), output, hasBias); + maybe2half(epsilon), output, hasBias); } else if (ld == 384) { constexpr int block_size = 384; SkipLayerNormKernelSmall <<>>(ld, input, skip, beta, gamma, bias, - maybe2half(epsilon), output, hasBias); + maybe2half(epsilon), output, hasBias); } else { constexpr int block_size = 256; SkipLayerNormKernel <<>>(ld, input, skip, beta, gamma, bias, - maybe2half(epsilon), output); + maybe2half(epsilon), output); } } return CUDA_CALL(cudaPeekAtLastError()); @@ -195,9 +194,9 @@ template bool LaunchSkipLayerNormKernel(cudaStream_t stream, float* outpu const int element_count, 
size_t element_size); template bool LaunchSkipLayerNormKernel(cudaStream_t stream, half* output, const half* input, - const half* skip, const half* gamma, const half* beta, - const half* bias, float epsilon, const int ld, - const int element_count, size_t element_size); + const half* skip, const half* gamma, const half* beta, + const half* bias, float epsilon, const int ld, + const int element_count, size_t element_size); } // namespace cuda } // namespace contrib diff --git a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.h b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.h index ace02b63fddf5..c9f11d456f3da 100644 --- a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.h @@ -10,19 +10,17 @@ namespace cuda { template bool LaunchSkipLayerNormKernel( cudaStream_t stream, - T* output, // output tensor - const T* input, // input tensor - const T* skip, // skip tensor - const T* gamma, // Layer normalization gamma tensor - const T* beta, // Layer normalization beta tensor - const T* bias, // Layer normalization beta tensor + T* output, // output tensor + const T* input, // input tensor + const T* skip, // skip tensor + const T* gamma, // Layer normalization gamma tensor + const T* beta, // Layer normalization beta tensor + const T* bias, // Layer normalization beta tensor float epsilon, // Layer normalization epsilon - int hidden_size, // hidden size, it is the leading dimension (ld) - int element_count, // number of elements in input tensor - size_t element_size -); + int hidden_size, // hidden size, it is the leading dimension (ld) + int element_count, // number of elements in input tensor + size_t element_size); } // namespace cuda } // namespace contrib } // namespace onnxruntime - diff --git a/onnxruntime/contrib_ops/cuda/layer_norm.cc b/onnxruntime/contrib_ops/cuda/layer_norm.cc index f1b70a295acee..6a276830018f8 100644 --- a/onnxruntime/contrib_ops/cuda/layer_norm.cc +++ b/onnxruntime/contrib_ops/cuda/layer_norm.cc @@ -43,7 +43,7 @@ Status LayerNorm::ComputeInternal(OpKernelContext* ctx) con typedef typename ToCudaType::MappedType CudaT; typedef typename ToCudaType::MappedType CudaU; typedef typename ToCudaType::MappedType CudaV; - //Inputs + // Inputs const Tensor* X = ctx->Input(0); const Tensor* scale = ctx->Input(1); const Tensor* bias = ctx->Input(2); @@ -64,7 +64,7 @@ Status LayerNorm::ComputeInternal(OpKernelContext* ctx) con Tensor* Y = ctx->Output(0, x_shape); auto Y_data = reinterpret_cast(Y->MutableData()); - //Mean and variance + // Mean and variance std::vector mean_inv_std_var_dim; for (int i = 0; i < static_cast(x_shape.NumDimensions()); ++i) { if (i < axis) { @@ -98,6 +98,6 @@ Status LayerNorm::ComputeInternal(OpKernelContext* ctx) con return Status::OK(); } -} //namespace cuda +} // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/layer_norm_impl.cu b/onnxruntime/contrib_ops/cuda/layer_norm_impl.cu index 748f51953965f..b1ebb0b3a5eba 100644 --- a/onnxruntime/contrib_ops/cuda/layer_norm_impl.cu +++ b/onnxruntime/contrib_ops/cuda/layer_norm_impl.cu @@ -114,8 +114,8 @@ __device__ void cuWelfordMuSigma2( U curr = static_cast(lvals[l]); cuWelfordOnlineSum(curr, mu, sigma2, count); } - // intra-warp reductions - #pragma unroll +// intra-warp reductions +#pragma unroll for (int stride = GPU_WARP_SIZE / 2; stride > 0; stride /= 2) { U muB = WARP_SHFL_DOWN(mu, stride); U countB = WARP_SHFL_DOWN(count, stride); @@ -209,8 +209,8 @@ __device__ 
void cuWelfordMuSigma2( float curr = static_cast(lvals[l]); cuWelfordOnlineSum(curr, mu, sigma2, count); } - // intra-warp reductions - #pragma unroll +// intra-warp reductions +#pragma unroll for (int stride = GPU_WARP_SIZE / 2; stride > 0; stride /= 2) { float muB = WARP_SHFL_DOWN(mu, stride); float countB = WARP_SHFL_DOWN(count, stride); diff --git a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc index 1336aa7aac08c..dec76be78e23a 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc @@ -164,25 +164,26 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { Tensor* present_tensor = GetPresent(context, past_tensor, batch_size, head_size, sequence_length, past_sequence_length); size_t workSpaceSize = GetAttentionWorkspaceSize(element_size, batch_size, num_heads_, head_size, sequence_length, past_sequence_length); - auto temp_buffer = GetScratchBuffer(workSpaceSize); + + auto work_space = GetScratchBuffer(workSpaceSize); if (!LaunchAttentionKernel( GetDeviceProp(), Stream(), - reinterpret_cast(gemm_buffer.get()), - nullptr == mask_index ? nullptr : mask_index->Data(), - nullptr == mask_index ? gsl::span() : mask_index->Shape().GetDims(), - output->MutableData(), + cublas, + element_size, batch_size, sequence_length, num_heads_, head_size, - temp_buffer.get(), - cublas, - element_size, - is_unidirectional_, past_sequence_length, + is_unidirectional_, + reinterpret_cast(gemm_buffer.get()), + nullptr == mask_index ? nullptr : mask_index->Data(), + nullptr == mask_index ? gsl::span() : mask_index->Shape().GetDims(), nullptr == past_tensor ? nullptr : past_tensor->Data(), - nullptr, // TODO: support add_qk in quantized attention + nullptr, // TODO: support add_qk in quantized attention + work_space.get(), + output->MutableData(), nullptr == present_tensor ? nullptr : present_tensor->MutableData())) { // Get last error to reset it to cudaSuccess. 
CUDA_CALL(cudaGetLastError()); diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index 0432b1a3450a7..eb0c3b528345f 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -409,16 +409,16 @@ Status ProcessLogits(const OrtValue& logits, // template Status GreedySearchProcessLogits( - const OrtValue& logits, // logits output of subgraph - transformers::IGreedySearchState* greedy_state, // state - transformers::ISequences* sequences, // sequences - AllocatorPtr& allocator, // default allocator - onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only) - transformers::ILogitsProcessorList* logits_processors, // logits processors - const transformers::IBeamSearchParameters* parameters, // parameters - int step, // iteration counter - void* stream, // cuda stream (for CUDA only) - const transformers::IConsoleDumper* dumper) { // tensor dumper + const OrtValue& logits, // logits output of subgraph + transformers::IGreedySearchState* greedy_state, // state + transformers::ISequences* sequences, // sequences + AllocatorPtr& allocator, // default allocator + onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only) + transformers::ILogitsProcessorList* logits_processors, // logits processors + const transformers::IBeamSearchParameters* parameters, // parameters + int step, // iteration counter + void* stream, // cuda stream (for CUDA only) + const transformers::IConsoleDumper* dumper) { // tensor dumper ORT_UNUSED_PARAMETER(logits_processors); #ifndef DEBUG_GENERATION @@ -520,8 +520,8 @@ Status GreedySearchProcessLogits( *topk_scores, *topk_indices)); #ifdef DEBUG_GENERATION - dumper->Print("topk_scores", *(topk_scores.get())); - dumper->Print("topk_indices", *(topk_indices.get())); + dumper->Print("topk_scores", topk_scores); + dumper->Print("topk_indices", topk_indices); #endif const int64_t* next_token_indices = topk_indices->Data(); diff --git a/onnxruntime/contrib_ops/rocm/bert/attention.cc b/onnxruntime/contrib_ops/rocm/bert/attention.cc index 8258bf78c61f0..027d1c5e68ebd 100644 --- a/onnxruntime/contrib_ops/rocm/bert/attention.cc +++ b/onnxruntime/contrib_ops/rocm/bert/attention.cc @@ -41,7 +41,8 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { const Tensor* extra_add_qk = context->Input(5); auto& device_prop = GetDeviceProp(); - ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(), weights->Shape(), bias->Shape(), mask_index, past, extra_add_qk, device_prop.maxThreadsPerBlock)); + ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(), weights->Shape(), bias->Shape(), + mask_index, past, extra_add_qk, device_prop.maxThreadsPerBlock)); // input shape (batch_size, sequence_length, input_hidden_size) const auto& shape = input->Shape(); @@ -92,26 +93,28 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { reinterpret_cast(input->Data()), k, &one, reinterpret_cast(gemm_buffer.get()), n)); - size_t workSpaceSize = GetAttentionWorkspaceSize(element_size, batch_size, num_heads_, head_size, sequence_length, past_sequence_length); - auto temp_buffer = GetScratchBuffer(workSpaceSize); + size_t workSpaceSize = GetAttentionWorkspaceSize(element_size, batch_size, num_heads_, head_size, + sequence_length, past_sequence_length); + + auto work_space = GetScratchBuffer(workSpaceSize); if (!LaunchAttentionKernel( device_prop, 
Stream(), - reinterpret_cast(gemm_buffer.get()), - nullptr == mask_index ? nullptr : mask_index->Data(), - nullptr == mask_index ? gsl::span() : mask_index->Shape().GetDims(), - output->MutableData(), + rocblas, + element_size, batch_size, sequence_length, num_heads_, head_size, - temp_buffer.get(), - rocblas, - element_size, - is_unidirectional_, past_sequence_length, + is_unidirectional_, + reinterpret_cast(gemm_buffer.get()), + nullptr == mask_index ? nullptr : mask_index->Data(), + nullptr == mask_index ? gsl::span() : mask_index->Shape().GetDims(), nullptr == past ? nullptr : past->Data(), nullptr == extra_add_qk ? nullptr : extra_add_qk->Data(), + work_space.get(), + output->MutableData(), nullptr == present ? nullptr : present->MutableData())) { // Get last error to reset it to hipSuccess. HIP_CALL(hipGetLastError()); diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu b/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu index c87c12a1f9ea0..dd1861f06667e 100644 --- a/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu @@ -32,9 +32,9 @@ limitations under the License. using namespace onnxruntime::rocm; using namespace hipcub; -#define CHECK_ROCM(expr) \ +#define CHECK_ROCM(expr) \ if (!HIP_CALL(expr)) { \ - return false; \ + return false; \ } namespace onnxruntime { @@ -45,9 +45,12 @@ static size_t AlignTo(size_t a, size_t b) { return CeilDiv(a, b) * b; } -size_t GetAttentionScratchSize(size_t element_size, int batch_size, int num_heads, int sequence_length, int all_sequence_length) { - const size_t len = batch_size * num_heads * sequence_length * all_sequence_length; - const size_t bytes = len * element_size; +size_t GetAttentionScratchSize(size_t element_size, + int batch_size, + int num_heads, + int sequence_length, + int all_sequence_length) { + const size_t bytes = element_size * batch_size * num_heads * sequence_length * all_sequence_length; const size_t alignment = 256; const size_t bytesAligned = AlignTo(bytes, alignment); @@ -61,19 +64,35 @@ size_t GetAttentionWorkspaceSize( int head_size, int sequence_length, int past_sequence_length) { - size_t qkv_size = 3 * batch_size * sequence_length * num_heads * head_size * element_size; - return qkv_size + 2 * GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length, past_sequence_length + sequence_length); + size_t qkv_size = element_size * 3 * batch_size * sequence_length * num_heads * head_size; + return qkv_size + 2 * GetAttentionScratchSize(element_size, batch_size, num_heads, + sequence_length, past_sequence_length + sequence_length); } template bool QkvToContext( - const hipDeviceProp_t& prop, rocblas_handle& rocblas, hipStream_t stream, - const int batch_size, const int sequence_length, const int num_heads, const int head_size, const size_t element_size, - const T* input, T* output, T* workspace, - const int* mask_index, gsl::span mask_index_dims, - bool is_unidirectional, int past_sequence_length, const T* past, const T* extra_add_qk, T* present, bool use_persistent_softmax) { + const hipDeviceProp_t& prop, + rocblas_handle& rocblas, + hipStream_t stream, + const int batch_size, + const int sequence_length, + const int num_heads, + const int head_size, + const size_t element_size, + const T* input, + T* output, + T* workspace, + const int* mask_index, + gsl::span mask_index_dims, + bool is_unidirectional, + int past_sequence_length, + const T* past, + const T* extra_add_qk, + T* present, + bool use_persistent_softmax) { const int 
all_sequence_length = past_sequence_length + sequence_length; - const size_t bytes = GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length, all_sequence_length); + const size_t bytes = GetAttentionScratchSize(element_size, batch_size, num_heads, + sequence_length, all_sequence_length); T* scratch1 = workspace; T* scratch2 = scratch1 + (bytes / element_size); T* scratch3 = scratch2 + (bytes / element_size); @@ -81,7 +100,8 @@ bool QkvToContext( const int max_threads_per_block = prop.maxThreadsPerBlock; // input should be BxSx3xNxH => scratch3: 3xBxNxSxH - if (!LaunchTransQkv(stream, 3, sequence_length, batch_size, head_size, num_heads, max_threads_per_block, false, input, scratch3)) { + if (!LaunchTransQkv(stream, 3, sequence_length, batch_size, head_size, num_heads, + max_threads_per_block, false, input, scratch3)) { return false; } @@ -101,7 +121,8 @@ bool QkvToContext( // past_v (BxNxS'xH) + v (BxNxSxH) => present_v (BxNxS*xH) const int present_size_per_batch = all_sequence_length * head_size; if (nullptr != present) { - if (!LaunchConcatPastToPresent(stream, all_sequence_length, sequence_length, batch_size, head_size, num_heads, max_threads_per_block, past, k, present)) { + if (!LaunchConcatPastToPresent(stream, all_sequence_length, sequence_length, batch_size, head_size, num_heads, + max_threads_per_block, past, k, present)) { return false; } @@ -120,29 +141,33 @@ bool QkvToContext( typedef typename ToHipType::MappedType HipT; - //float one = 1.0f; - //float zero = 0.f; + // float one = 1.0f; + // float zero = 0.f; const HipT one = ToHipType::FromFloat(1.0f); const HipT zero = ToHipType::FromFloat(0.f); // For raw attention mask, the scalar if 1/sqrt(H) is moved to softmax computation. - //float temp_alpha = use_raw_attention_mask ? one : rsqrt_head_size; + // float temp_alpha = use_raw_attention_mask ? one : rsqrt_head_size; const HipT alpha = use_raw_attention_mask ? one : ToHipType::FromFloat(rsqrt_head_size); if (!ROCBLAS_CALL(rocblasGemmStridedBatchedHelper( - rocblas, rocblas_operation_transpose, rocblas_operation_none, all_sequence_length, sequence_length, head_size, &alpha, k, head_size, present_size_per_batch, - q, head_size, size_per_batch, &zero, scratch1, all_sequence_length, temp_matrix_size, batches))) { + rocblas, rocblas_operation_transpose, rocblas_operation_none, + all_sequence_length, sequence_length, head_size, + &alpha, k, head_size, present_size_per_batch, + q, head_size, size_per_batch, + &zero, scratch1, all_sequence_length, temp_matrix_size, batches))) { return false; } // apply softmax and store result P to scratch2: BxNxSxS* if (use_raw_attention_mask) { // 2d, 3d or 4d attention mask const int mask_dimension = static_cast(mask_index_dims.size()); - const int64_t max_sequence_length = mask_dimension == 4 ? mask_index_dims.at(3) : 0; + const int max_sequence_length = mask_dimension == 4 ? static_cast(mask_index_dims.at(3)) : 0; - T* persistent_softmax_workspace = scratch1; // replace Q*K' in place with masked score if persistent softmax is selected. - if (!ComputeSoftmaxWithRawMask(stream, all_sequence_length, sequence_length, batch_size, num_heads, mask_index, nullptr, extra_add_qk, scratch1, scratch2, - is_unidirectional, rsqrt_head_size, mask_dimension, static_cast(max_sequence_length), + T* persistent_softmax_workspace = scratch1; // replace Q*K' in place if persistent softmax is selected. 
+ if (!ComputeSoftmaxWithRawMask(stream, all_sequence_length, sequence_length, batch_size, num_heads, + mask_index, nullptr, extra_add_qk, scratch1, scratch2, + is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, use_persistent_softmax, persistent_softmax_workspace)) { return false; } @@ -150,91 +175,110 @@ bool QkvToContext( ORT_ENFORCE(mask_index_dims.size() == 1); // mask_index has 1D shape: either (batch_size) or (2*batch_size). Only the later one has start postions. const int* mask_start = (mask_index_dims.at(0) > batch_size) ? mask_index + batch_size : nullptr; - if (!ComputeSoftmaxWithMask1D(stream, all_sequence_length, sequence_length, batch_size, num_heads, mask_index, mask_start, extra_add_qk, scratch1, scratch2, is_unidirectional)) { + if (!ComputeSoftmaxWithMask1D(stream, all_sequence_length, sequence_length, batch_size, num_heads, + mask_index, mask_start, extra_add_qk, scratch1, scratch2, is_unidirectional)) { return false; } } else { // no mask - if (!ComputeSoftmax(stream, all_sequence_length, sequence_length, batch_size, num_heads, extra_add_qk, scratch1, scratch2, is_unidirectional)) { + if (!ComputeSoftmax(stream, all_sequence_length, sequence_length, batch_size, num_heads, + extra_add_qk, scratch1, scratch2, is_unidirectional)) { return false; } } // compute P*V (as V*P), and store in scratch3: BxNxSxH if (!ROCBLAS_CALL(rocblasGemmStridedBatchedHelper( - rocblas, rocblas_operation_none, rocblas_operation_none, head_size, sequence_length, all_sequence_length, &one, v, head_size, present_size_per_batch, - scratch2, all_sequence_length, temp_matrix_size, &zero, scratch3, head_size, size_per_batch, batches))) { + rocblas, rocblas_operation_none, rocblas_operation_none, + head_size, sequence_length, all_sequence_length, + &one, v, head_size, present_size_per_batch, + scratch2, all_sequence_length, temp_matrix_size, + &zero, scratch3, head_size, size_per_batch, batches))) { return false; } // scratch3 is BxNxSxH, transpose to output BxSxNxH - return LaunchTransCtx(stream, sequence_length, batch_size, head_size, num_heads, max_threads_per_block, false, scratch3, output); + return LaunchTransCtx(stream, sequence_length, batch_size, head_size, num_heads, + max_threads_per_block, false, scratch3, output); } bool LaunchAttentionKernel( const hipDeviceProp_t& prop, hipStream_t stream, - const void* input, - const int* mask_index, - gsl::span mask_index_dims, - void* output, - const int batch_size, - const int sequence_length, - const int num_heads, - const int head_size, - void* workspace, rocblas_handle& rocblas, const size_t element_size, - bool is_unidirectional, + int batch_size, + int sequence_length, + int num_heads, + int head_size, int past_sequence_length, + bool is_unidirectional, + const void* input, + const int* mask_index, + gsl::span mask_index_dims, const void* past, const void* extra_add_qk, + void* workspace, + void* output, void* present) { // For testing, environment variable ORT_TRANSFORMER_OPTIONS=1 could enable persistent softmax const TransformerOptions* options = TransformerOptions::GetInstance(); bool use_persistent_softmax = options->IsPrecisionMode() && !options->DisablePersistentSoftmax(); if (element_size == 2) { - return QkvToContext(prop, rocblas, stream, - batch_size, sequence_length, num_heads, head_size, element_size, - reinterpret_cast(input), reinterpret_cast<__half*>(output), reinterpret_cast<__half*>(workspace), - mask_index, mask_index_dims, is_unidirectional, - past_sequence_length, reinterpret_cast(past), 
reinterpret_cast(extra_add_qk), - reinterpret_cast<__half*>(present), use_persistent_softmax); + return QkvToContext( + prop, rocblas, stream, batch_size, sequence_length, num_heads, head_size, element_size, + reinterpret_cast(input), + reinterpret_cast<__half*>(output), + reinterpret_cast<__half*>(workspace), + mask_index, + mask_index_dims, + is_unidirectional, + past_sequence_length, + reinterpret_cast(past), + reinterpret_cast(extra_add_qk), + reinterpret_cast<__half*>(present), + use_persistent_softmax); } else { - return QkvToContext(prop, rocblas, stream, - batch_size, sequence_length, num_heads, head_size, element_size, - reinterpret_cast(input), reinterpret_cast(output), reinterpret_cast(workspace), - mask_index, mask_index_dims, is_unidirectional, - past_sequence_length, reinterpret_cast(past), reinterpret_cast(extra_add_qk), - reinterpret_cast(present), use_persistent_softmax); + return QkvToContext( + prop, rocblas, stream, batch_size, sequence_length, num_heads, head_size, element_size, + reinterpret_cast(input), + reinterpret_cast(output), + reinterpret_cast(workspace), + mask_index, + mask_index_dims, + is_unidirectional, + past_sequence_length, + reinterpret_cast(past), + reinterpret_cast(extra_add_qk), + reinterpret_cast(present), + use_persistent_softmax); } } - template bool DecoderQkvToContext( - const hipDeviceProp_t& prop, - hipStream_t stream, - rocblas_handle& rocblas, - const size_t element_size, - const int batch_size, - const int sequence_length, - const int kv_sequence_length, - const int num_heads, - const int head_size, - const bool static_kv, - const bool use_past, - const bool has_layer_state, - const bool has_key_padding_mask, - const T* gemm_query_buffer, - const T* gemm_kv_buffer, - const bool* key_padding_mask, - const T* key_cache, - const T* value_cache, - T* qkv_buffer, - T* workspace_buffer, - T* output, - T* new_key_cache, - T* new_value_cache) { + const hipDeviceProp_t& prop, + hipStream_t stream, + rocblas_handle& rocblas, + const size_t element_size, + const int batch_size, + const int sequence_length, + const int kv_sequence_length, + const int num_heads, + const int head_size, + const bool static_kv, + const bool use_past, + const bool has_layer_state, + const bool has_key_padding_mask, + const T* gemm_query_buffer, + const T* gemm_kv_buffer, + const bool* key_padding_mask, + const T* key_cache, + const T* value_cache, + T* qkv_buffer, + T* workspace_buffer, + T* output, + T* new_key_cache, + T* new_value_cache) { const int max_threads_per_block = prop.maxThreadsPerBlock; const int BN = batch_size * num_heads; const int BHN = BN * head_size; @@ -325,16 +369,20 @@ bool DecoderQkvToContext( const int strideB = sequence_length * head_size; if (use_past && static_kv) { if (!ROCBLAS_CALL(rocblasGemmStridedBatchedHelper( - rocblas, rocblas_operation_transpose, rocblas_operation_none, - kv_sequence_length, sequence_length, head_size, &alpha, key_cache, head_size, strideA, - q, head_size, strideB, &zero, scratch1, kv_sequence_length, temp_matrix_size, BN))) { + rocblas, rocblas_operation_transpose, rocblas_operation_none, + kv_sequence_length, sequence_length, head_size, + &alpha, key_cache, head_size, strideA, + q, head_size, strideB, + &zero, scratch1, kv_sequence_length, temp_matrix_size, BN))) { return false; } } else { if (!ROCBLAS_CALL(rocblasGemmStridedBatchedHelper( - rocblas, rocblas_operation_transpose, rocblas_operation_none, - kv_sequence_length, sequence_length, head_size, &alpha, k, head_size, strideA, - q, head_size, strideB, &zero, 
scratch1, kv_sequence_length, temp_matrix_size, BN))) { + rocblas, rocblas_operation_transpose, rocblas_operation_none, + kv_sequence_length, sequence_length, head_size, + &alpha, k, head_size, strideA, + q, head_size, strideB, + &zero, scratch1, kv_sequence_length, temp_matrix_size, BN))) { return false; } } @@ -342,7 +390,7 @@ bool DecoderQkvToContext( if (has_key_padding_mask) { if (!ComputeSoftmaxWithRawMask(stream, kv_sequence_length, sequence_length, batch_size, num_heads, nullptr, key_padding_mask, nullptr, scratch1, scratch2, - false, 1, 2, static_cast(0), false, nullptr)) { + false, 1, 2, static_cast(0), false, nullptr)) { return false; } } else { @@ -355,16 +403,20 @@ bool DecoderQkvToContext( // compute P*V (as V*P), and store in scratch3: BxNxSxH if (use_past && static_kv) { if (!ROCBLAS_CALL(rocblasGemmStridedBatchedHelper( - rocblas, rocblas_operation_none, rocblas_operation_none, - head_size, sequence_length, kv_sequence_length, &one, value_cache, head_size, strideA, - scratch2, kv_sequence_length, temp_matrix_size, &zero, scratch3, head_size, strideB, BN))) { + rocblas, rocblas_operation_none, rocblas_operation_none, + head_size, sequence_length, kv_sequence_length, + &one, value_cache, head_size, strideA, + scratch2, kv_sequence_length, temp_matrix_size, + &zero, scratch3, head_size, strideB, BN))) { return false; } } else { if (!ROCBLAS_CALL(rocblasGemmStridedBatchedHelper( - rocblas, rocblas_operation_none, rocblas_operation_none, - head_size, sequence_length, kv_sequence_length, &one, v, head_size, strideA, - scratch2, kv_sequence_length, temp_matrix_size, &zero, scratch3, head_size, strideB, BN))) { + rocblas, rocblas_operation_none, rocblas_operation_none, + head_size, sequence_length, kv_sequence_length, + &one, v, head_size, strideA, + scratch2, kv_sequence_length, temp_matrix_size, + &zero, scratch3, head_size, strideB, BN))) { return false; } } @@ -374,81 +426,80 @@ bool DecoderQkvToContext( num_heads, max_threads_per_block, true, scratch3, output); } - bool LaunchDecoderAttentionKernel( - const hipDeviceProp_t& prop, - hipStream_t stream, - rocblas_handle& rocblas, - const size_t element_size, - const int batch_size, - const int sequence_length, - const int kv_sequence_length, - const int num_heads, - const int head_size, - const bool static_kv, - const bool use_past, - const bool has_layer_state, - const bool has_key_padding_mask, - const void* gemm_query_buffer, - const void* gemm_kv_buffer, - const bool* key_padding_mask, - const void* key_cache, - const void* value_cache, - void* qkv_buffer, - void* workspace_buffer, - void* output, - void* new_key_cache, - void* new_value_cache) { + const hipDeviceProp_t& prop, + hipStream_t stream, + rocblas_handle& rocblas, + const size_t element_size, + const int batch_size, + const int sequence_length, + const int kv_sequence_length, + const int num_heads, + const int head_size, + const bool static_kv, + const bool use_past, + const bool has_layer_state, + const bool has_key_padding_mask, + const void* gemm_query_buffer, + const void* gemm_kv_buffer, + const bool* key_padding_mask, + const void* key_cache, + const void* value_cache, + void* qkv_buffer, + void* workspace_buffer, + void* output, + void* new_key_cache, + void* new_value_cache) { if (element_size == 2) { return DecoderQkvToContext( - prop, - stream, - rocblas, - element_size, - batch_size, - sequence_length, - kv_sequence_length, - num_heads, - head_size, - static_kv, - use_past, - has_layer_state, - has_key_padding_mask, - 
reinterpret_cast(gemm_query_buffer), - reinterpret_cast(gemm_kv_buffer), - key_padding_mask, - reinterpret_cast(key_cache), - reinterpret_cast(value_cache), - reinterpret_cast(qkv_buffer), - reinterpret_cast(workspace_buffer), - reinterpret_cast(output), - reinterpret_cast(new_key_cache), - reinterpret_cast(new_value_cache)); + prop, + stream, + rocblas, + element_size, + batch_size, + sequence_length, + kv_sequence_length, + num_heads, + head_size, + static_kv, + use_past, + has_layer_state, + has_key_padding_mask, + reinterpret_cast(gemm_query_buffer), + reinterpret_cast(gemm_kv_buffer), + key_padding_mask, + reinterpret_cast(key_cache), + reinterpret_cast(value_cache), + reinterpret_cast(qkv_buffer), + reinterpret_cast(workspace_buffer), + reinterpret_cast(output), + reinterpret_cast(new_key_cache), + reinterpret_cast(new_value_cache)); } else { return DecoderQkvToContext( - prop, - stream, - rocblas, - element_size, - batch_size, - sequence_length, - kv_sequence_length, - num_heads, - head_size, - static_kv, - use_past, - has_layer_state, - has_key_padding_mask, - reinterpret_cast(gemm_query_buffer), - reinterpret_cast(gemm_kv_buffer), - key_padding_mask, - reinterpret_cast(key_cache), - reinterpret_cast(value_cache), - reinterpret_cast(qkv_buffer), - reinterpret_cast(workspace_buffer), - reinterpret_cast(output), - reinterpret_cast(new_key_cache), - reinterpret_cast(new_value_cache)); + prop, + stream, + rocblas, + element_size, + batch_size, + sequence_length, + kv_sequence_length, + num_heads, + head_size, + static_kv, + use_past, + has_layer_state, + has_key_padding_mask, + reinterpret_cast(gemm_query_buffer), + reinterpret_cast(gemm_kv_buffer), + key_padding_mask, + reinterpret_cast(key_cache), + reinterpret_cast(value_cache), + reinterpret_cast(qkv_buffer), + reinterpret_cast(workspace_buffer), + reinterpret_cast(output), + reinterpret_cast(new_key_cache), + reinterpret_cast(new_value_cache)); } } diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_impl.h b/onnxruntime/contrib_ops/rocm/bert/attention_impl.h index 7a17c2c38b60f..cce12fad74749 100644 --- a/onnxruntime/contrib_ops/rocm/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/rocm/bert/attention_impl.h @@ -10,65 +10,66 @@ namespace onnxruntime { namespace contrib { namespace rocm { -size_t GetAttentionScratchSize(size_t element_size, int batch_size, - int num_heads, int sequence_length, int all_sequence_length); +size_t GetAttentionScratchSize( + size_t element_size, + int batch_size, + int num_heads, + int sequence_length, + int all_sequence_length); size_t GetAttentionWorkspaceSize( size_t element_size, - int batchsize, + int batch_size, int num_heads, int head_size, int sequence_length, int past_sequence_length); bool LaunchAttentionKernel( - const hipDeviceProp_t& prop, // Device Properties - hipStream_t stream, // rocm stream - const void* input, // Input tensor - const int* mask_index, // Attention mask raw data or index - // (end position of each sequence, - // or end positions and start positions). - // NULL means no mask. - gsl::span mask_index_dims, // Mask index shape - void* output, // Output tensor - int batch_size, // Batch size (B) - int sequence_length, // Sequence length (S) - int num_heads, // Number of attention heads (N) - int head_size, // Hidden layer size per head (H) - void* workspace, // Temporary buffer - rocblas_handle& rocblas, // Rocblas handle - const size_t element_size, // Element size of input tensor - bool is_unidirectional, // Whether there is unidirecitonal mask. 
- int past_sequence_length, // Sequence length in past state - const void* past, // Past state input - const void* extra_add_qk, // Additional Add - void* present // Present state output + const hipDeviceProp_t& prop, // Device Properties + hipStream_t stream, // cuda stream + rocblas_handle& rocblas, // Rocblas handle + const size_t element_size, // Element size of input tensor + int batch_size, // Batch size (B) + int sequence_length, // Sequence length (S) + int num_heads, // Number of attention heads (N) + int head_size, // Hidden layer size per head (H) + int past_sequence_length, // Sequence length in past state + bool is_unidirectional, // Whether there is unidirectional mask. + const void* input, // Input tensor + const int* mask_index, // Attention mask raw data or index. NULL means no mask. + gsl::span mask_index_dims, // Mask index shape + const void* past, // Past state input + const void* extra_add_qk, // Additional Add + void* workspace, // Temporary buffer + void* output, // Output tensor + void* present // Present state output ); bool LaunchDecoderAttentionKernel( - const hipDeviceProp_t& prop, // Device Properties - hipStream_t stream, // Cuda stream - rocblas_handle& rocblas, // Rocblas handle - const size_t element_size, // Element size of input tensor - const int batch_size, // Batch size (B) - const int sequence_length, // Sequence length (S) - const int kv_sequence_length, // Key/Value/Cache sequence length - const int num_heads, // Number of attention heads (N) - const int head_size, // Hidden layer size per head (H) - const bool static_kv, // Whether cross attention or not - const bool use_past, // Whether use cache or not - const bool has_layer_state, // Whether output cache or not - const bool has_key_padding_mask, // Whether use key_padding_mask or not - const void* gemm_query_buffer, // Query buffer - const void* gemm_kv_buffer, // Key and value buffer - const bool* key_padding_mask, // Key padding mask - const void* key_cache, // Input key cache - const void* value_cache, // Input value cache - void* qkv_buffer, // Temporary buffer - void* workspace_buffer, // Temporary buffer - void* output, // Output tensor - void* new_key_cache, // New_key_cache tensor - void* new_value_cache // New_value_cache tensor + const hipDeviceProp_t& prop, // Device Properties + hipStream_t stream, // Cuda stream + rocblas_handle& rocblas, // Rocblas handle + const size_t element_size, // Element size of input tensor + const int batch_size, // Batch size (B) + const int sequence_length, // Sequence length (S) + const int kv_sequence_length, // Key/Value/Cache sequence length + const int num_heads, // Number of attention heads (N) + const int head_size, // Hidden layer size per head (H) + const bool static_kv, // Whether cross attention or not + const bool use_past, // Whether use cache or not + const bool has_layer_state, // Whether output cache or not + const bool has_key_padding_mask, // Whether use key_padding_mask or not + const void* gemm_query_buffer, // Query buffer + const void* gemm_kv_buffer, // Key and value buffer + const bool* key_padding_mask, // Key padding mask + const void* key_cache, // Input key cache + const void* value_cache, // Input value cache + void* qkv_buffer, // Temporary buffer + void* workspace_buffer, // Temporary buffer + void* output, // Output tensor + void* new_key_cache, // New_key_cache tensor + void* new_value_cache // New_value_cache tensor ); bool LaunchTransCtx(hipStream_t stream, @@ -133,57 +134,57 @@ bool 
LaunchConcatPastToPresent(hipStream_t stream, const half* k_v, half* present); -inline rocblas_status _compat_rocblas_gemm_strided_batched_ex(rocblas_handle handle, - rocblas_operation transa, - rocblas_operation transb, - int m, - int n, - int k, - const void* alpha, - const void* A, - rocblas_datatype a_type, - rocblas_int lda, - rocblas_stride stride_A, - const void* b, - rocblas_datatype b_type, - rocblas_int ldb, - rocblas_stride stride_b, - const void* beta, - void* c, - rocblas_datatype c_type, - rocblas_int ldc, - rocblas_stride stride_c, - rocblas_int batch_count, - rocblas_datatype compute_type, - rocblas_gemm_algo algo) { - return rocblas_gemm_strided_batched_ex(handle, - transa, - transb, - m, // m - n, // n - k, // k - alpha, // alpha - A, // A - a_type, // A type - lda, // lda - stride_A, // strideA - b, // B - b_type, // B type - ldb, // ldb - stride_b, // strideB - beta, // beta - c, // C - c_type, // C type - ldc, // ldc - stride_c, // strideC - c, // D = C - c_type, // D type = C type - ldc, // ldd = ldc - stride_c, // strideD = strideC - batch_count, // batch count - compute_type, - algo, - 0, 0); +inline rocblas_status _compat_rocblas_gemm_strided_batched_ex(rocblas_handle handle, + rocblas_operation transa, + rocblas_operation transb, + int m, + int n, + int k, + const void* alpha, + const void* A, + rocblas_datatype a_type, + rocblas_int lda, + rocblas_stride stride_A, + const void* b, + rocblas_datatype b_type, + rocblas_int ldb, + rocblas_stride stride_b, + const void* beta, + void* c, + rocblas_datatype c_type, + rocblas_int ldc, + rocblas_stride stride_c, + rocblas_int batch_count, + rocblas_datatype compute_type, + rocblas_gemm_algo algo) { + return rocblas_gemm_strided_batched_ex(handle, + transa, + transb, + m, // m + n, // n + k, // k + alpha, // alpha + A, // A + a_type, // A type + lda, // lda + stride_A, // strideA + b, // B + b_type, // B type + ldb, // ldb + stride_b, // strideB + beta, // beta + c, // C + c_type, // C type + ldc, // ldc + stride_c, // strideC + c, // D = C + c_type, // D type = C type + ldc, // ldd = ldc + stride_c, // strideD = strideC + batch_count, // batch count + compute_type, + algo, + 0, 0); } // Compatible for CublasMathModeSetter diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_softmax.h b/onnxruntime/contrib_ops/rocm/bert/attention_softmax.h index cbd7472a17dc1..7ff41737c0c52 100644 --- a/onnxruntime/contrib_ops/rocm/bert/attention_softmax.h +++ b/onnxruntime/contrib_ops/rocm/bert/attention_softmax.h @@ -60,7 +60,9 @@ __device__ inline void Softmax(const int all_sequence_length, for (int i = threadIdx.x; i < valid_end; i += TPB) { if (i >= valid_start) { const int index = offset + i; - float input_at_idx = add_before_softmax == nullptr ? float(input[index]) : float(input[index] + add_before_softmax[index]); + float input_at_idx = add_before_softmax == nullptr + ? static_cast(input[index]) + : static_cast(input[index] + add_before_softmax[index]); if (thread_data_max < input_at_idx) { thread_data_max = input_at_idx; } @@ -92,7 +94,9 @@ __device__ inline void Softmax(const int all_sequence_length, for (int i = threadIdx.x; i < all_sequence_length; i += TPB) { const int index = offset + i; - float input_at_idx = add_before_softmax == nullptr ? float(input[index]) : float(input[index] + add_before_softmax[index]); + float input_at_idx = add_before_softmax == nullptr + ? static_cast(input[index]) + : static_cast(input[index] + add_before_softmax[index]); const float val = (i >= valid_start && i < valid_end) ? 
expf(input_at_idx - max_block) * sum_reverse_block : 0.f; output[index] = T(val); } @@ -124,7 +128,8 @@ __device__ inline void SoftmaxSmall(const int all_sequence_length, if (is_unidirectional) { int end_unid = all_sequence_length - sequence_length + (blockIdx.x % sequence_length) + 1; if (end_unid <= valid_start) { - // In this situation, mask of [0, end_unid) and [valid_start, valid_end) has -10000, and [end_unid, valid_start) and [valid_end, all_seq_len) has -20000. + // In this situation, mask of [0, end_unid) and [valid_start, valid_end) has -10000, + // and [end_unid, valid_start) and [valid_end, all_seq_len) has -20000. // So [0, end_unid) will also have value after softmax. is_valid = threadIdx.x < end_unid; } else { @@ -138,7 +143,9 @@ __device__ inline void SoftmaxSmall(const int all_sequence_length, // Infinity divided by Infinity is a NAN. Thus, softmax gets a NAN if one or more item are large enough. // a math transform as below is leveraged to get a stable softmax: // e^xi/(e^x1 + ...e^xn) = e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max)) - float input_data = add_before_softmax == nullptr ? float(input[index]) : float(input[index] + add_before_softmax[index]); + float input_data = add_before_softmax == nullptr + ? static_cast(input[index]) + : static_cast(input[index] + add_before_softmax[index]); float thread_data_max = is_valid ? input_data : float(-ROCMRT_INF_F); const auto max = BlockReduce(tmp_storage).Reduce(thread_data_max, hipcub::Max(), end); @@ -199,7 +206,7 @@ __device__ inline void SoftmaxWithRawMaskSmall(const int all_sequence_length, const int sequence_index = blockIdx.x % sequence_length; if (is_unidirectional) { - int from_index = all_sequence_length - sequence_length + sequence_index; // offset of from token in all sequence length. + int from_index = all_sequence_length - sequence_length + sequence_index; // offset in all sequence length. 
if (threadIdx.x > from_index) { thread_data = -10000.0f; } @@ -212,7 +219,8 @@ __device__ inline void SoftmaxWithRawMaskSmall(const int all_sequence_length, } else if (mask_dimension == 3) { mask_offset = (batch_index * sequence_length + sequence_index) * all_sequence_length + threadIdx.x; } else if (mask_dimension == 4) { - mask_offset = (batch_index * max_sequence_length + all_sequence_length - sequence_length + sequence_index) * max_sequence_length + threadIdx.x; + int from_index = all_sequence_length - sequence_length + sequence_index; + mask_offset = (batch_index * max_sequence_length + from_index) * max_sequence_length + threadIdx.x; } if (nullptr == key_padding_mask) { @@ -257,41 +265,53 @@ __device__ inline void SoftmaxWithRawMaskSmall(const int all_sequence_length, } template -__global__ void SoftmaxKernelSmall(const int all_sequence_length, const int sequence_length, const T* add_before_softmax, const T* input, T* output, bool is_unidirectional) { - SoftmaxSmall(all_sequence_length, sequence_length, all_sequence_length, 0, add_before_softmax, input, output, is_unidirectional); +__global__ void SoftmaxKernelSmall(const int all_sequence_length, const int sequence_length, + const T* add_before_softmax, const T* input, T* output, bool is_unidirectional) { + SoftmaxSmall(all_sequence_length, sequence_length, all_sequence_length, 0, + add_before_softmax, input, output, is_unidirectional); } template -__global__ void SoftmaxKernel(const int all_sequence_length, const int sequence_length, const T* add_before_softmax, const T* input, T* output) { - Softmax(all_sequence_length, sequence_length, all_sequence_length, 0, add_before_softmax, input, output); +__global__ void SoftmaxKernel(const int all_sequence_length, const int sequence_length, + const T* add_before_softmax, const T* input, T* output) { + Softmax(all_sequence_length, sequence_length, all_sequence_length, 0, + add_before_softmax, input, output); } template bool ComputeSoftmax( - hipStream_t stream, const int all_sequence_length, const int sequence_length, const int batch_size, const int num_heads, + hipStream_t stream, + const int all_sequence_length, const int sequence_length, const int batch_size, const int num_heads, const T* add_before_softmax, const T* input, T* output, bool is_unidirectional) { const dim3 grid(sequence_length * num_heads, batch_size, 1); if (all_sequence_length <= 32) { const int blockSize = 32; - hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxKernelSmall), grid, blockSize, 0, stream, all_sequence_length, sequence_length, add_before_softmax, input, output, is_unidirectional); + hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxKernelSmall), grid, blockSize, 0, stream, + all_sequence_length, sequence_length, add_before_softmax, input, output, is_unidirectional); } else if (all_sequence_length <= 64) { const int blockSize = 64; - hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxKernelSmall), grid, blockSize, 0, stream, all_sequence_length, sequence_length, add_before_softmax, input, output, is_unidirectional); + hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxKernelSmall), grid, blockSize, 0, stream, + all_sequence_length, sequence_length, add_before_softmax, input, output, is_unidirectional); } else if (all_sequence_length <= 128) { const int blockSize = 128; - hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxKernelSmall), grid, blockSize, 0, stream, all_sequence_length, sequence_length, add_before_softmax, input, output, is_unidirectional); + hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxKernelSmall), grid, blockSize, 0, stream, + 
all_sequence_length, sequence_length, add_before_softmax, input, output, is_unidirectional); } else if (all_sequence_length <= 256) { const int blockSize = 256; - hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxKernelSmall), grid, blockSize, 0, stream, all_sequence_length, sequence_length, add_before_softmax, input, output, is_unidirectional); + hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxKernelSmall), grid, blockSize, 0, stream, + all_sequence_length, sequence_length, add_before_softmax, input, output, is_unidirectional); } else if (all_sequence_length <= 512) { const int blockSize = 512; - hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxKernelSmall), grid, blockSize, 0, stream, all_sequence_length, sequence_length, add_before_softmax, input, output, is_unidirectional); + hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxKernelSmall), grid, blockSize, 0, stream, + all_sequence_length, sequence_length, add_before_softmax, input, output, is_unidirectional); } else if (all_sequence_length <= 1024) { const int blockSize = 1024; - hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxKernelSmall), grid, blockSize, 0, stream, all_sequence_length, sequence_length, add_before_softmax, input, output, is_unidirectional); + hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxKernelSmall), grid, blockSize, 0, stream, + all_sequence_length, sequence_length, add_before_softmax, input, output, is_unidirectional); } else if (!is_unidirectional) { const int blockSize = 1024; - hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxKernel), grid, blockSize, 0, stream, all_sequence_length, sequence_length, add_before_softmax, input, output); + hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxKernel), grid, blockSize, 0, stream, + all_sequence_length, sequence_length, add_before_softmax, input, output); } else { ORT_THROW("Attention ROCM operator does not support total sequence length > 1024."); } @@ -300,7 +320,10 @@ bool ComputeSoftmax( } template -__global__ void MaskedSoftmaxKernelSmall(const int all_sequence_length, const int sequence_length, const int* mask_end, const int* mask_start, const T* add_before_softmax, const T* input, T* output, bool is_unidirectional) { +__global__ void MaskedSoftmaxKernelSmall(const int all_sequence_length, const int sequence_length, + const int* mask_end, const int* mask_start, + const T* add_before_softmax, const T* input, T* output, + bool is_unidirectional) { __shared__ int start_position; __shared__ int end_position; @@ -317,11 +340,14 @@ __global__ void MaskedSoftmaxKernelSmall(const int all_sequence_length, const in } __syncthreads(); - SoftmaxSmall(all_sequence_length, sequence_length, end_position, start_position, add_before_softmax, input, output, is_unidirectional); + SoftmaxSmall(all_sequence_length, sequence_length, end_position, start_position, + add_before_softmax, input, output, is_unidirectional); } template -__global__ void MaskedSoftmaxKernel(const int all_sequence_length, const int sequence_length, const int* mask_end, const int* mask_start, const T* add_before_softmax, const T* input, T* output) { +__global__ void MaskedSoftmaxKernel(const int all_sequence_length, const int sequence_length, + const int* mask_end, const int* mask_start, + const T* add_before_softmax, const T* input, T* output) { __shared__ int start_position; __shared__ int end_position; @@ -338,42 +364,72 @@ __global__ void MaskedSoftmaxKernel(const int all_sequence_length, const int seq } __syncthreads(); - Softmax(all_sequence_length, sequence_length, end_position, start_position, add_before_softmax, input, output); + Softmax(all_sequence_length, 
sequence_length, end_position, start_position, + add_before_softmax, input, output); } template -__global__ void SoftmaxWithRawMaskSmallKernel(const int all_sequence_length, const int sequence_length, const int* attention_mask, const bool* key_padding_mask, const T* add_before_softmax, const T* input, - T* output, const bool is_unidirectional, const float rsqrt_head_size, const int mask_dimension, const int max_sequence_length, +__global__ void SoftmaxWithRawMaskSmallKernel(const int all_sequence_length, + const int sequence_length, + const int* attention_mask, + const bool* key_padding_mask, + const T* add_before_softmax, + const T* input, T* output, + const bool is_unidirectional, + const float rsqrt_head_size, + const int mask_dimension, + const int max_sequence_length, const bool skip_softmax) { - SoftmaxWithRawMaskSmall(all_sequence_length, sequence_length, attention_mask, key_padding_mask, add_before_softmax, input, output, is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, skip_softmax); + SoftmaxWithRawMaskSmall( + all_sequence_length, sequence_length, + attention_mask, key_padding_mask, add_before_softmax, input, output, + is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, + skip_softmax); } template -bool ComputeSoftmaxWithMask1D(hipStream_t stream, const int all_sequence_length, const int sequence_length, const int batch_size, const int num_heads, - const int* mask_index, const int* mask_start, const T* add_before_softmax, const T* input, T* output, const bool is_unidirectional) { +bool ComputeSoftmaxWithMask1D( + hipStream_t stream, + const int all_sequence_length, const int sequence_length, const int batch_size, const int num_heads, + const int* mask_index, const int* mask_start, + const T* add_before_softmax, const T* input, T* output, const bool is_unidirectional) { const dim3 grid(sequence_length * num_heads, batch_size, 1); if (all_sequence_length <= 32) { const int blockSize = 32; - hipLaunchKernelGGL(HIP_KERNEL_NAME(MaskedSoftmaxKernelSmall), grid, blockSize, 0, stream, all_sequence_length, sequence_length, mask_index, mask_start, add_before_softmax, input, output, is_unidirectional); + hipLaunchKernelGGL(HIP_KERNEL_NAME(MaskedSoftmaxKernelSmall), grid, blockSize, 0, stream, + all_sequence_length, sequence_length, mask_index, mask_start, + add_before_softmax, input, output, is_unidirectional); } else if (all_sequence_length <= 64) { const int blockSize = 64; - hipLaunchKernelGGL(HIP_KERNEL_NAME(MaskedSoftmaxKernelSmall), grid, blockSize, 0, stream, all_sequence_length, sequence_length, mask_index, mask_start, add_before_softmax, input, output, is_unidirectional); + hipLaunchKernelGGL(HIP_KERNEL_NAME(MaskedSoftmaxKernelSmall), grid, blockSize, 0, stream, + all_sequence_length, sequence_length, mask_index, mask_start, + add_before_softmax, input, output, is_unidirectional); } else if (all_sequence_length <= 128) { const int blockSize = 128; - hipLaunchKernelGGL(HIP_KERNEL_NAME(MaskedSoftmaxKernelSmall), grid, blockSize, 0, stream, all_sequence_length, sequence_length, mask_index, mask_start, add_before_softmax, input, output, is_unidirectional); + hipLaunchKernelGGL(HIP_KERNEL_NAME(MaskedSoftmaxKernelSmall), grid, blockSize, 0, stream, + all_sequence_length, sequence_length, mask_index, mask_start, + add_before_softmax, input, output, is_unidirectional); } else if (all_sequence_length <= 256) { const int blockSize = 256; - hipLaunchKernelGGL(HIP_KERNEL_NAME(MaskedSoftmaxKernelSmall), grid, blockSize, 0, stream, 
all_sequence_length, sequence_length, mask_index, mask_start, add_before_softmax, input, output, is_unidirectional); + hipLaunchKernelGGL(HIP_KERNEL_NAME(MaskedSoftmaxKernelSmall), grid, blockSize, 0, stream, + all_sequence_length, sequence_length, mask_index, mask_start, + add_before_softmax, input, output, is_unidirectional); } else if (all_sequence_length <= 512) { const int blockSize = 512; - hipLaunchKernelGGL(HIP_KERNEL_NAME(MaskedSoftmaxKernelSmall), grid, blockSize, 0, stream, all_sequence_length, sequence_length, mask_index, mask_start, add_before_softmax, input, output, is_unidirectional); + hipLaunchKernelGGL(HIP_KERNEL_NAME(MaskedSoftmaxKernelSmall), grid, blockSize, 0, stream, + all_sequence_length, sequence_length, mask_index, mask_start, + add_before_softmax, input, output, is_unidirectional); } else if (all_sequence_length <= 1024) { const int blockSize = 1024; - hipLaunchKernelGGL(HIP_KERNEL_NAME(MaskedSoftmaxKernelSmall), grid, blockSize, 0, stream, all_sequence_length, sequence_length, mask_index, mask_start, add_before_softmax, input, output, is_unidirectional); + hipLaunchKernelGGL(HIP_KERNEL_NAME(MaskedSoftmaxKernelSmall), grid, blockSize, 0, stream, + all_sequence_length, sequence_length, mask_index, mask_start, + add_before_softmax, input, output, is_unidirectional); } else if (!is_unidirectional) { const int blockSize = 1024; - hipLaunchKernelGGL(HIP_KERNEL_NAME(MaskedSoftmaxKernel), grid, blockSize, 0, stream, all_sequence_length, sequence_length, mask_index, mask_start, add_before_softmax, input, output); + hipLaunchKernelGGL(HIP_KERNEL_NAME(MaskedSoftmaxKernel), grid, blockSize, 0, stream, + all_sequence_length, sequence_length, mask_index, mask_start, + add_before_softmax, input, output); } else { ORT_THROW("Attention ROCM operator does not support total sequence length > 1024."); } @@ -382,42 +438,83 @@ bool ComputeSoftmaxWithMask1D(hipStream_t stream, const int all_sequence_length, } template -bool ComputeSoftmaxWithRawMask(hipStream_t stream, const int all_sequence_length, const int sequence_length, const int batch_size, const int num_heads, - const int* attention_mask, const bool* key_padding_mask, const T* add_before_softmax, const T* input, T* output, const bool is_unidirectional, - const float rsqrt_head_size, const int mask_dimension, const int max_sequence_length, const bool use_persistent_softmax, T* persistent_softmax_workspace) { +bool ComputeSoftmaxWithRawMask(hipStream_t stream, + const int all_sequence_length, + const int sequence_length, + const int batch_size, + const int num_heads, + const int* attention_mask, + const bool* key_padding_mask, + const T* add_before_softmax, + const T* input, + T* output, + const bool is_unidirectional, + const float rsqrt_head_size, + const int mask_dimension, + const int max_sequence_length, + const bool use_persistent_softmax, + T* persistent_softmax_workspace) { const dim3 grid(sequence_length * num_heads, batch_size, 1); T* out = use_persistent_softmax ? 
persistent_softmax_workspace : output; if (all_sequence_length <= 32) { const int blockSize = 32; - hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxWithRawMaskSmallKernel), grid, blockSize, 0, stream, all_sequence_length, sequence_length, attention_mask, key_padding_mask, add_before_softmax, input, out, is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, use_persistent_softmax); + hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxWithRawMaskSmallKernel), grid, blockSize, 0, stream, + all_sequence_length, sequence_length, + attention_mask, key_padding_mask, add_before_softmax, input, out, + is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, + use_persistent_softmax); } else if (all_sequence_length <= 64) { const int blockSize = 64; - hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxWithRawMaskSmallKernel), grid, blockSize, 0, stream, all_sequence_length, sequence_length, attention_mask, key_padding_mask, add_before_softmax, input, out, is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, use_persistent_softmax); + hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxWithRawMaskSmallKernel), grid, blockSize, 0, stream, + all_sequence_length, sequence_length, + attention_mask, key_padding_mask, add_before_softmax, input, out, + is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, + use_persistent_softmax); } else if (all_sequence_length <= 128) { const int blockSize = 128; - hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxWithRawMaskSmallKernel), grid, blockSize, 0, stream, all_sequence_length, sequence_length, attention_mask, key_padding_mask, add_before_softmax, input, out, is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, use_persistent_softmax); + hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxWithRawMaskSmallKernel), grid, blockSize, 0, stream, + all_sequence_length, sequence_length, + attention_mask, key_padding_mask, add_before_softmax, input, out, + is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, + use_persistent_softmax); } else if (all_sequence_length <= 256) { const int blockSize = 256; - hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxWithRawMaskSmallKernel), grid, blockSize, 0, stream, all_sequence_length, sequence_length, attention_mask, key_padding_mask, add_before_softmax, input, out, is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, use_persistent_softmax); + hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxWithRawMaskSmallKernel), grid, blockSize, 0, stream, + all_sequence_length, sequence_length, + attention_mask, key_padding_mask, add_before_softmax, input, out, + is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, + use_persistent_softmax); } else if (all_sequence_length <= 512) { const int blockSize = 512; - hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxWithRawMaskSmallKernel), grid, blockSize, 0, stream, all_sequence_length, sequence_length, attention_mask, key_padding_mask, add_before_softmax, input, out, is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, use_persistent_softmax); + hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxWithRawMaskSmallKernel), grid, blockSize, 0, stream, + all_sequence_length, sequence_length, + attention_mask, key_padding_mask, add_before_softmax, input, out, + is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, + use_persistent_softmax); } else if (all_sequence_length <= 1024) { const int blockSize = 1024; - hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxWithRawMaskSmallKernel), grid, blockSize, 
0, stream, all_sequence_length, sequence_length, attention_mask, key_padding_mask, add_before_softmax, input, out, is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, use_persistent_softmax); + hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxWithRawMaskSmallKernel), grid, blockSize, 0, stream, + all_sequence_length, sequence_length, + attention_mask, key_padding_mask, add_before_softmax, input, out, + is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, + use_persistent_softmax); } else { ORT_THROW("Attention ROCM operator does not support total sequence length > 1024."); } if (use_persistent_softmax) { - dispatch_warpwise_softmax_forward(stream, output, persistent_softmax_workspace, all_sequence_length, all_sequence_length, batch_size * num_heads * sequence_length); + dispatch_warpwise_softmax_forward(stream, + output, + persistent_softmax_workspace, + all_sequence_length, + all_sequence_length, + batch_size * num_heads * sequence_length); } return HIP_CALL(hipPeekAtLastError()); } - } // namespace rocm } // namespace contrib } // namespace onnxruntime
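
Note on the workspace sizing used by the ROCm attention path above: GetAttentionScratchSize reserves one B x N x S x S_total attention-score buffer rounded up to a 256-byte boundary, and GetAttentionWorkspaceSize adds the packed Q/K/V buffer plus two such scratch buffers. A minimal host-side sketch of the same arithmetic (local helper names, not the ORT symbols; assumes the products fit in size_t):

#include <cstddef>

// Round bytes up to the next multiple of `alignment` (the helpers above use 256).
static size_t AlignUp(size_t bytes, size_t alignment) {
  return ((bytes + alignment - 1) / alignment) * alignment;
}

// One B x N x S x S_total attention-score buffer, 256-byte aligned.
static size_t ScratchBytes(size_t element_size, int batch_size, int num_heads,
                           int sequence_length, int all_sequence_length) {
  const size_t bytes = element_size * batch_size * num_heads *
                       static_cast<size_t>(sequence_length) * all_sequence_length;
  return AlignUp(bytes, 256);
}

// Packed Q/K/V (3 x B x S x N x H elements) plus two scratch buffers sized for
// S x (past_S + S) scores, mirroring GetAttentionWorkspaceSize above.
static size_t WorkspaceBytes(size_t element_size, int batch_size, int num_heads,
                             int head_size, int sequence_length,
                             int past_sequence_length) {
  const size_t qkv_bytes = element_size * 3 * batch_size *
                           static_cast<size_t>(sequence_length) * num_heads * head_size;
  return qkv_bytes + 2 * ScratchBytes(element_size, batch_size, num_heads,
                                      sequence_length,
                                      past_sequence_length + sequence_length);
}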
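
The Softmax/SoftmaxSmall kernels rely on the max-subtraction identity spelled out in their comments, e^{x_i} / sum_j e^{x_j} = e^{x_i - m} / sum_j e^{x_j - m} with m the row maximum, so that exp never overflows for large scores. A scalar C++ illustration of the same two-pass computation for a single attention row (illustrative helper, not part of the patch; assumes a non-empty row):

#include <algorithm>
#include <cmath>
#include <vector>

// Two passes per row: find the maximum, then normalize exp(x - max).
std::vector<float> StableSoftmaxRow(const std::vector<float>& scores) {
  const float row_max = *std::max_element(scores.begin(), scores.end());
  std::vector<float> probs(scores.size());
  float sum = 0.0f;
  for (size_t i = 0; i < scores.size(); ++i) {
    probs[i] = std::exp(scores[i] - row_max);
    sum += probs[i];
  }
  for (float& p : probs) {
    p /= sum;
  }
  return probs;
}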
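
The ComputeSoftmax, ComputeSoftmaxWithMask1D, and ComputeSoftmaxWithRawMask dispatchers above share one pattern: launch the "small" kernel with the smallest block size in {32, 64, 128, 256, 512, 1024} that covers all_sequence_length, fall back to a generic kernel only on the non-unidirectional paths, and throw once the total sequence length exceeds 1024. A compact restatement of that selection (hypothetical helper, for illustration only):

#include <initializer_list>

// Smallest supported thread-block size that covers one attention row, or -1
// when the row is longer than 1024 and a different code path must handle it.
inline int PickSoftmaxBlockSize(int all_sequence_length) {
  for (int block_size : {32, 64, 128, 256, 512, 1024}) {
    if (all_sequence_length <= block_size) {
      return block_size;
    }
  }
  return -1;
}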