From 3b61d7c56c0b40c354d4c5b15b2a39e225d17645 Mon Sep 17 00:00:00 2001 From: Ali Hassani Date: Tue, 12 Aug 2025 17:35:37 -0400 Subject: [PATCH 1/2] Feat: Clip dot products Ampere FNA only for now. Allows optionally clipping dot products according to some floating point range (min, max). - Issue: #249 --- csrc/include/natten/cuda/fna/fna_backward.cuh | 18 ++- csrc/include/natten/cuda/fna/fna_forward.cuh | 18 ++- .../include/natten/cuda/fna/kernel_backward.h | 36 +++++- csrc/include/natten/cuda/fna/kernel_forward.h | 32 +++++ .../cuda/reference/fna_reference_backward.hpp | 120 ++++++++++++++++-- .../cuda/reference/fna_reference_forward.hpp | 32 ++++- csrc/include/natten/fna.h | 36 +++++- csrc/include/natten/reference.h | 36 +++++- csrc/src/fna_backward.cu | 48 +++++-- csrc/src/fna_forward.cu | 48 +++++-- csrc/src/reference_backward.cu | 48 +++++-- csrc/src/reference_forward.cu | 48 +++++-- scripts/autogen_reference_fna.py | 24 +++- src/natten/backends/__init__.py | 28 +++- src/natten/backends/configs/checks.py | 74 ++++++++++- src/natten/backends/fna.py | 45 ++++++- src/natten/backends/reference.py | 45 ++++++- src/natten/functional.py | 45 ++++++- tests/test_fna.py | 58 +++++++++ tests/utils.py | 19 +++ 20 files changed, 773 insertions(+), 85 deletions(-) diff --git a/csrc/include/natten/cuda/fna/fna_backward.cuh b/csrc/include/natten/cuda/fna/fna_backward.cuh index 4d1a5821..29340970 100644 --- a/csrc/include/natten/cuda/fna/fna_backward.cuh +++ b/csrc/include/natten/cuda/fna/fna_backward.cuh @@ -89,7 +89,11 @@ void fna_backward_generic( float attn_scale, IntTuple query_tile_shape, IntTuple key_tile_shape, - IntTuple num_splits_key) { + IntTuple num_splits_key, + bool has_dot_product_min, + bool has_dot_product_max, + float dot_product_min, + float dot_product_max) { static constexpr auto kRank = std::tuple_size::value; using Dim = typename GetDim::type; @@ -157,6 +161,18 @@ void fna_backward_generic( p.num_splits_key = tuple_to_na_dim(num_splits_key); + // Optional dot product clipping + p.has_dot_product_clip = has_dot_product_min || has_dot_product_max; + p.has_dot_product_min = has_dot_product_min; + p.has_dot_product_max = has_dot_product_max; + if (has_dot_product_min) { + p.dot_product_min = dot_product_min; + } + if (has_dot_product_max) { + p.dot_product_max = dot_product_max; + } + // + int64_t size_bytes = p.workspace_size(); if (size_bytes) { void* workspace_ptr = nullptr; diff --git a/csrc/include/natten/cuda/fna/fna_forward.cuh b/csrc/include/natten/cuda/fna/fna_forward.cuh index d6010cb9..f81c95d5 100644 --- a/csrc/include/natten/cuda/fna/fna_forward.cuh +++ b/csrc/include/natten/cuda/fna/fna_forward.cuh @@ -79,7 +79,11 @@ void fna_forward_generic( float attn_scale, void* logsumexp_ptr, IntTuple query_tile_shape, - IntTuple key_tile_shape) { + IntTuple key_tile_shape, + bool has_dot_product_min, + bool has_dot_product_max, + float dot_product_min, + float dot_product_max) { static constexpr auto kRank = std::tuple_size::value; using Dim = typename GetDim::type; @@ -167,6 +171,18 @@ void fna_forward_generic( p.query_tile_shape = tuple_to_na_dim(query_tile_shape); p.key_tile_shape = tuple_to_na_dim(key_tile_shape); + // Optional dot product clipping + p.has_dot_product_clip = has_dot_product_min || has_dot_product_max; + p.has_dot_product_min = has_dot_product_min; + p.has_dot_product_max = has_dot_product_max; + if (has_dot_product_min) { + p.dot_product_min = dot_product_min; + } + if (has_dot_product_max) { + p.dot_product_max = dot_product_max; + } + // + if (smem_bytes > 
0xc000) { auto err = cudaFuncSetAttribute( kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes); diff --git a/csrc/include/natten/cuda/fna/kernel_backward.h b/csrc/include/natten/cuda/fna/kernel_backward.h index ab7f640c..1ff7df1c 100644 --- a/csrc/include/natten/cuda/fna/kernel_backward.h +++ b/csrc/include/natten/cuda/fna/kernel_backward.h @@ -692,6 +692,17 @@ struct FusedNeighborhoodAttentionBackwardKernel { bool is_fully_block_sparse = false; bool has_q_padding = false; + // Optional dot product clipping -- all must be set explicitly for avoiding + // comparisons. + bool has_dot_product_clip = false; + bool has_dot_product_min = false; + bool has_dot_product_max = false; + accum_t dot_product_min = + -cutlass::platform::numeric_limits::infinity(); + accum_t dot_product_max = + cutlass::platform::numeric_limits::infinity(); + // + // Dimensions/strides int32_t head_dim = -1; int32_t head_dim_value = -1; @@ -1615,7 +1626,6 @@ struct FusedNeighborhoodAttentionBackwardKernel { mma.set_prologue_done(kPrologueQK); mma.set_zero_outside_bounds(/*!skipBoundsChecks*/ true); mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); - accum = cutlass::multiplies()(scale, accum); // Epilogue: add LSE + exp and store that to our shared memory buffer // shmem <- (matmul_result - @@ -1629,6 +1639,30 @@ struct FusedNeighborhoodAttentionBackwardKernel { auto lane_offset = MatmulQK::AccumLambdaIterator::get_lane_offset( lane_id, warp_id, output_tile_coords); + // (Optional) clip dot products -- MUST BE DONE PRIOR TO MASKING & + // SCALING. + if (p.has_dot_product_clip) { + if (not p.has_dot_product_max) { + for (int i = 0; i < Mma::FragmentC::kElements; ++i) { + accum[i] = cutlass::fast_max(accum[i], p.dot_product_min); + } + } else if (not p.has_dot_product_min) { + for (int i = 0; i < Mma::FragmentC::kElements; ++i) { + accum[i] = cutlass::fast_min(accum[i], p.dot_product_max); + } + } else { + // assert(p.has_dot_product_min && p.has_dot_product_max); + for (int i = 0; i < Mma::FragmentC::kElements; ++i) { + accum[i] = cutlass::fast_max( + cutlass::fast_min(accum[i], p.dot_product_max), + p.dot_product_min); + } + } + } + + // Dot product scale + accum = cutlass::multiplies()(scale, accum); + if (not p.is_fully_block_sparse) { // Neighborhood Attention masking Dim first_col, query_bound, row_idx; diff --git a/csrc/include/natten/cuda/fna/kernel_forward.h b/csrc/include/natten/cuda/fna/kernel_forward.h index 0d53c9dc..29780321 100644 --- a/csrc/include/natten/cuda/fna/kernel_forward.h +++ b/csrc/include/natten/cuda/fna/kernel_forward.h @@ -200,6 +200,17 @@ struct FusedNeighborhoodAttentionKernel { bool is_fully_block_sparse = false; bool has_kv_padding = false; + // Optional dot product clipping -- all must be set explicitly for avoiding + // comparisons. + bool has_dot_product_clip = false; + bool has_dot_product_min = false; + bool has_dot_product_max = false; + accum_t dot_product_min = + -cutlass::platform::numeric_limits::infinity(); + accum_t dot_product_max = + cutlass::platform::numeric_limits::infinity(); + // + // Moves pointers to what we should process // Returns "false" if there is no work to do CUTLASS_DEVICE bool advance_to_block() { @@ -734,6 +745,27 @@ struct FusedNeighborhoodAttentionKernel { MM1::Mma::drain_cp_asyncs(); } + // (Optional) clip dot products -- MUST BE DONE PRIOR TO MASKING & + // SCALING. 
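+      // Split into three cases so each element-wise loop stays branch-free:
+      // min-only clamps from below, max-only clamps from above, and both
+      // clamp the dot product into [dot_product_min, dot_product_max].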
+ if (p.has_dot_product_clip) { + if (not p.has_dot_product_max) { + for (int i = 0; i < MM0::Mma::FragmentC::kElements; ++i) { + accum[i] = cutlass::fast_max(accum[i], p.dot_product_min); + } + } else if (not p.has_dot_product_min) { + for (int i = 0; i < MM0::Mma::FragmentC::kElements; ++i) { + accum[i] = cutlass::fast_min(accum[i], p.dot_product_max); + } + } else { + // assert(p.has_dot_product_min && p.has_dot_product_max); + for (int i = 0; i < MM0::Mma::FragmentC::kElements; ++i) { + accum[i] = cutlass::fast_max( + cutlass::fast_min(accum[i], p.dot_product_max), + p.dot_product_min); + } + } + } + if (not p.is_fully_block_sparse) { // Neighborhood Attention masking Dim first_col, key_bound, row_idx; diff --git a/csrc/include/natten/cuda/reference/fna_reference_backward.hpp b/csrc/include/natten/cuda/reference/fna_reference_backward.hpp index 1544c1e1..1ea4c4a0 100644 --- a/csrc/include/natten/cuda/reference/fna_reference_backward.hpp +++ b/csrc/include/natten/cuda/reference/fna_reference_backward.hpp @@ -74,7 +74,11 @@ void __global__ fna_bwd_reference_dQ_kernel( Causal is_causal, QKVLayout qkv_layout, float attn_scale, - int num_additional_kv) { + int num_additional_kv, + bool has_dot_product_min, + bool has_dot_product_max, + float dot_product_min, + float dot_product_max) { using namespace cute; auto attention_mask = @@ -111,6 +115,20 @@ void __global__ fna_bwd_reference_dQ_kernel( acc_dov += mDO(idx_Q, idx_D1, idx_L) * mV(idx_K, idx_D1, idx_L); acc_doo += mDO(idx_Q, idx_D1, idx_L) * mO(idx_Q, idx_D1, idx_L); } // for idx_D1 + + // (Optional) clip dot products -- MUST BE DONE PRIOR TO MASKING & + // SCALING. + if (has_dot_product_min || has_dot_product_max) { + if (not has_dot_product_max) { + acc_qk = cutlass::fast_max(acc_qk, dot_product_min); + } else if (not has_dot_product_min) { + acc_qk = cutlass::fast_min(acc_qk, dot_product_max); + } else { + acc_qk = cutlass::fast_max( + cutlass::fast_min(acc_qk, dot_product_max), dot_product_min); + } + } + acc_qk *= attn_scale; acc_dov *= attn_scale; acc_doo *= attn_scale; @@ -186,7 +204,11 @@ void __global__ fna_bwd_reference_dK_kernel( Causal is_causal, QKVLayout qkv_layout, float attn_scale, - int num_additional_kv) { + int num_additional_kv, + bool has_dot_product_min, + bool has_dot_product_max, + float dot_product_min, + float dot_product_max) { using namespace cute; auto attention_mask = @@ -223,6 +245,20 @@ void __global__ fna_bwd_reference_dK_kernel( acc_dov += mDO(idx_Q, idx_D1, idx_L) * mV(idx_K, idx_D1, idx_L); acc_doo += mDO(idx_Q, idx_D1, idx_L) * mO(idx_Q, idx_D1, idx_L); } // for idx_D1 + + // (Optional) clip dot products -- MUST BE DONE PRIOR TO MASKING & + // SCALING. 
+ if (has_dot_product_min || has_dot_product_max) { + if (not has_dot_product_max) { + acc_qk = cutlass::fast_max(acc_qk, dot_product_min); + } else if (not has_dot_product_min) { + acc_qk = cutlass::fast_min(acc_qk, dot_product_max); + } else { + acc_qk = cutlass::fast_max( + cutlass::fast_min(acc_qk, dot_product_max), dot_product_min); + } + } + acc_qk *= attn_scale; acc_dov *= attn_scale; acc_doo *= attn_scale; @@ -299,7 +335,11 @@ void __global__ fna_bwd_reference_dV_kernel( Causal is_causal, QKVLayout qkv_layout, float attn_scale, - int num_additional_kv) { + int num_additional_kv, + bool has_dot_product_min, + bool has_dot_product_max, + float dot_product_min, + float dot_product_max) { using namespace cute; auto attention_mask = @@ -333,6 +373,20 @@ void __global__ fna_bwd_reference_dV_kernel( ElementAccumulator rK = mK(idx_K, idx_D0, idx_L); acc_qk += rQ * rK; } // for idx_D0 + + // (Optional) clip dot products -- MUST BE DONE PRIOR TO MASKING & + // SCALING. + if (has_dot_product_min || has_dot_product_max) { + if (not has_dot_product_max) { + acc_qk = cutlass::fast_max(acc_qk, dot_product_min); + } else if (not has_dot_product_min) { + acc_qk = cutlass::fast_min(acc_qk, dot_product_max); + } else { + acc_qk = cutlass::fast_max( + cutlass::fast_min(acc_qk, dot_product_max), dot_product_min); + } + } + acc_qk *= attn_scale; auto id = make_identity_tensor(make_shape(1, 1)); @@ -408,7 +462,11 @@ void fna_bwd_reference_dQ( QKVLayout qkv_layout, float attn_scale, int num_additional_kv, - cudaStream_t stream) { + cudaStream_t stream, + bool has_dot_product_min, + bool has_dot_product_max, + float dot_product_min, + float dot_product_max) { using namespace cute; // Only so that we don't oversubscribe shmem when seqlen is large. @@ -447,7 +505,11 @@ void fna_bwd_reference_dQ( Causal{}, qkv_layout, attn_scale, - num_additional_kv); + num_additional_kv, + has_dot_product_min, + has_dot_product_max, + dot_product_min, + dot_product_max); } ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -480,7 +542,11 @@ void fna_bwd_reference_dK( QKVLayout qkv_layout, float attn_scale, int num_additional_kv, - cudaStream_t stream) { + cudaStream_t stream, + bool has_dot_product_min, + bool has_dot_product_max, + float dot_product_min, + float dot_product_max) { using namespace cute; // Only so that we don't oversubscribe shmem when seqlen is large. @@ -519,7 +585,11 @@ void fna_bwd_reference_dK( Causal{}, qkv_layout, attn_scale, - num_additional_kv); + num_additional_kv, + has_dot_product_min, + has_dot_product_max, + dot_product_min, + dot_product_max); } ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -552,7 +622,11 @@ void fna_bwd_reference_dV( QKVLayout qkv_layout, float attn_scale, int num_additional_kv, - cudaStream_t stream) { + cudaStream_t stream, + bool has_dot_product_min, + bool has_dot_product_max, + float dot_product_min, + float dot_product_max) { using namespace cute; // Only so that we don't oversubscribe shmem when seqlen is large. 
@@ -591,7 +665,11 @@ void fna_bwd_reference_dV( Causal{}, qkv_layout, attn_scale, - num_additional_kv); + num_additional_kv, + has_dot_product_min, + has_dot_product_max, + dot_product_min, + dot_product_max); } ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -619,7 +697,11 @@ void fna_reference_backward( NADim dilation, Causal is_causal, float attn_scale, - cudaStream_t stream) { + cudaStream_t stream, + bool has_dot_product_min, + bool has_dot_product_max, + float dot_product_min, + float dot_product_max) { using namespace cute; // No GQA/MQA for now @@ -694,7 +776,11 @@ void fna_reference_backward( qkv_layout, attn_scale, num_additional_kv, - stream); + stream, + has_dot_product_min, + has_dot_product_max, + dot_product_min, + dot_product_max); fna_bwd_reference_dK( problem_shape, mQ, @@ -711,7 +797,11 @@ void fna_reference_backward( qkv_layout, attn_scale, num_additional_kv, - stream); + stream, + has_dot_product_min, + has_dot_product_max, + dot_product_min, + dot_product_max); fna_bwd_reference_dV( problem_shape, mQ, @@ -728,7 +818,11 @@ void fna_reference_backward( qkv_layout, attn_scale, num_additional_kv, - stream); + stream, + has_dot_product_min, + has_dot_product_max, + dot_product_min, + dot_product_max); } } // namespace reference diff --git a/csrc/include/natten/cuda/reference/fna_reference_forward.hpp b/csrc/include/natten/cuda/reference/fna_reference_forward.hpp index d316961a..bf7c87e0 100644 --- a/csrc/include/natten/cuda/reference/fna_reference_forward.hpp +++ b/csrc/include/natten/cuda/reference/fna_reference_forward.hpp @@ -73,7 +73,11 @@ void __global__ fna_reference_kernel( Causal is_causal, QKVLayout qkv_layout, float attn_scale, - int num_additional_kv) { + int num_additional_kv, + bool has_dot_product_min, + bool has_dot_product_max, + float dot_product_min, + float dot_product_max) { using namespace cute; auto attention_mask = @@ -133,6 +137,20 @@ void __global__ fna_reference_kernel( ElementAccumulator eK = mK(idx_K + offset_K, idx_D, idx_L); acc += eQ * eK; } + + // (Optional) clip dot products -- MUST BE DONE PRIOR TO MASKING & + // SCALING. + if (has_dot_product_min || has_dot_product_max) { + if (not has_dot_product_max) { + acc = cutlass::fast_max(acc, dot_product_min); + } else if (not has_dot_product_min) { + acc = cutlass::fast_min(acc, dot_product_max); + } else { + acc = cutlass::fast_max( + cutlass::fast_min(acc, dot_product_max), dot_product_min); + } + } + acc = acc * attn_scale; auto frag = make_tensor(Shape<_1, _1>{}); frag(0) = acc; @@ -231,7 +249,11 @@ void fna_reference_forward( NADim dilation, Causal is_causal, float attn_scale, - cudaStream_t stream) { + cudaStream_t stream, + bool has_dot_product_min, + bool has_dot_product_max, + float dot_product_min, + float dot_product_max) { using namespace cute; // Only so that we don't oversubscribe shmem when seqlen is large. 
@@ -317,7 +339,11 @@ void fna_reference_forward( Causal{}, qkv_layout, attn_scale, - num_additional_kv); + num_additional_kv, + has_dot_product_min, + has_dot_product_max, + dot_product_min, + dot_product_max); NATTEN_CUDA_CHECK(cudaGetLastError()); } diff --git a/csrc/include/natten/fna.h b/csrc/include/natten/fna.h index 2900370b..a9774730 100644 --- a/csrc/include/natten/fna.h +++ b/csrc/include/natten/fna.h @@ -45,7 +45,11 @@ void na1d_forward( const std::tuple& is_causal, float attn_scale, const std::tuple& query_tile_size, - const std::tuple& key_tile_size); + const std::tuple& key_tile_size, + bool has_dot_product_min, + bool has_dot_product_max, + float dot_product_min, + float dot_product_max); void na2d_forward( at::Tensor& out, @@ -59,7 +63,11 @@ void na2d_forward( const std::tuple& is_causal, float attn_scale, const std::tuple& query_tile_size, - const std::tuple& key_tile_size); + const std::tuple& key_tile_size, + bool has_dot_product_min, + bool has_dot_product_max, + float dot_product_min, + float dot_product_max); void na3d_forward( at::Tensor& out, @@ -73,7 +81,11 @@ void na3d_forward( const std::tuple& is_causal, float attn_scale, const std::tuple& query_tile_size, - const std::tuple& key_tile_size); + const std::tuple& key_tile_size, + bool has_dot_product_min, + bool has_dot_product_max, + float dot_product_min, + float dot_product_max); // Backward @@ -95,7 +107,11 @@ void na1d_backward( const std::tuple& query_tile_size, const std::tuple& key_tile_size, const std::tuple& num_splits_key, - bool compute_delta_with_torch); + bool compute_delta_with_torch, + bool has_dot_product_min, + bool has_dot_product_max, + float dot_product_min, + float dot_product_max); void na2d_backward( at::Tensor& grad_query, @@ -115,7 +131,11 @@ void na2d_backward( const std::tuple& query_tile_size, const std::tuple& key_tile_size, const std::tuple& num_splits_key, - bool compute_delta_with_torch); + bool compute_delta_with_torch, + bool has_dot_product_min, + bool has_dot_product_max, + float dot_product_min, + float dot_product_max); void na3d_backward( at::Tensor& grad_query, @@ -135,6 +155,10 @@ void na3d_backward( const std::tuple& query_tile_size, const std::tuple& key_tile_size, const std::tuple& num_splits_key, - bool compute_delta_with_torch); + bool compute_delta_with_torch, + bool has_dot_product_min, + bool has_dot_product_max, + float dot_product_min, + float dot_product_max); } // namespace natten diff --git a/csrc/include/natten/reference.h b/csrc/include/natten/reference.h index d85477f0..02b710ae 100644 --- a/csrc/include/natten/reference.h +++ b/csrc/include/natten/reference.h @@ -45,7 +45,11 @@ void reference_na1d_forward( const std::tuple& is_causal, float attn_scale, const std::tuple& qkv_shape, - int num_extra_kv); + int num_extra_kv, + bool has_dot_product_min, + bool has_dot_product_max, + float dot_product_min, + float dot_product_max); void reference_na2d_forward( at::Tensor& out, @@ -59,7 +63,11 @@ void reference_na2d_forward( const std::tuple& is_causal, float attn_scale, const std::tuple& qkv_shape, - int num_extra_kv); + int num_extra_kv, + bool has_dot_product_min, + bool has_dot_product_max, + float dot_product_min, + float dot_product_max); void reference_na3d_forward( at::Tensor& out, @@ -73,7 +81,11 @@ void reference_na3d_forward( const std::tuple& is_causal, float attn_scale, const std::tuple& qkv_shape, - int num_extra_kv); + int num_extra_kv, + bool has_dot_product_min, + bool has_dot_product_max, + float dot_product_min, + float dot_product_max); // 
Backward @@ -93,7 +105,11 @@ void reference_na1d_backward( const std::tuple& is_causal, float attn_scale, const std::tuple& qkv_shape, - int num_extra_kv); + int num_extra_kv, + bool has_dot_product_min, + bool has_dot_product_max, + float dot_product_min, + float dot_product_max); void reference_na2d_backward( at::Tensor& grad_query, @@ -111,7 +127,11 @@ void reference_na2d_backward( const std::tuple& is_causal, float attn_scale, const std::tuple& qkv_shape, - int num_extra_kv); + int num_extra_kv, + bool has_dot_product_min, + bool has_dot_product_max, + float dot_product_min, + float dot_product_max); void reference_na3d_backward( at::Tensor& grad_query, @@ -129,6 +149,10 @@ void reference_na3d_backward( const std::tuple& is_causal, float attn_scale, const std::tuple& qkv_shape, - int num_extra_kv); + int num_extra_kv, + bool has_dot_product_min, + bool has_dot_product_max, + float dot_product_min, + float dot_product_max); } // namespace natten diff --git a/csrc/src/fna_backward.cu b/csrc/src/fna_backward.cu index 031da27e..d4a8bcdf 100644 --- a/csrc/src/fna_backward.cu +++ b/csrc/src/fna_backward.cu @@ -75,7 +75,11 @@ void fna_generic_backward( const StdNADim& query_tile_size, const StdNADim& key_tile_size, const StdNADim& num_splits_key, - bool compute_delta_with_torch) { + bool compute_delta_with_torch, + bool has_dot_product_min, + bool has_dot_product_max, + float dot_product_min, + float dot_product_max) { static_assert( std::tuple_size_v > 0 && std::tuple_size_v < 4); static constexpr int kNADim = std::tuple_size_v; @@ -195,7 +199,11 @@ void fna_generic_backward( attn_scale, query_tile_size, key_tile_size, - num_splits_key); + num_splits_key, + has_dot_product_min, + has_dot_product_max, + dot_product_min, + dot_product_max); } else { NATTEN_FAILURE( "Fused kernels are only available on devices with " @@ -225,7 +233,11 @@ void na1d_backward( const std::tuple& query_tile_size, const std::tuple& key_tile_size, const std::tuple& num_splits_key, - bool compute_delta_with_torch) { + bool compute_delta_with_torch, + bool has_dot_product_min, + bool has_dot_product_max, + float dot_product_min, + float dot_product_max) { TORCH_CHECK(query.dim() == 4, "Tensors must be 4-D."); fna_generic_backward( @@ -247,7 +259,11 @@ void na1d_backward( query_tile_size, key_tile_size, num_splits_key, - compute_delta_with_torch); + compute_delta_with_torch, + has_dot_product_min, + has_dot_product_max, + dot_product_min, + dot_product_max); } void na2d_backward( @@ -268,7 +284,11 @@ void na2d_backward( const std::tuple& query_tile_size, const std::tuple& key_tile_size, const std::tuple& num_splits_key, - bool compute_delta_with_torch) { + bool compute_delta_with_torch, + bool has_dot_product_min, + bool has_dot_product_max, + float dot_product_min, + float dot_product_max) { TORCH_CHECK(query.dim() == 5, "Tensors must be 5-D."); fna_generic_backward( @@ -290,7 +310,11 @@ void na2d_backward( query_tile_size, key_tile_size, num_splits_key, - compute_delta_with_torch); + compute_delta_with_torch, + has_dot_product_min, + has_dot_product_max, + dot_product_min, + dot_product_max); } void na3d_backward( @@ -311,7 +335,11 @@ void na3d_backward( const std::tuple& query_tile_size, const std::tuple& key_tile_size, const std::tuple& num_splits_key, - bool compute_delta_with_torch) { + bool compute_delta_with_torch, + bool has_dot_product_min, + bool has_dot_product_max, + float dot_product_min, + float dot_product_max) { TORCH_CHECK(query.dim() == 6, "Tensors must be 6-D."); fna_generic_backward( @@ -333,7 
+361,11 @@ void na3d_backward( query_tile_size, key_tile_size, num_splits_key, - compute_delta_with_torch); + compute_delta_with_torch, + has_dot_product_min, + has_dot_product_max, + dot_product_min, + dot_product_max); } } // namespace natten diff --git a/csrc/src/fna_forward.cu b/csrc/src/fna_forward.cu index 3320a91a..ada9cff5 100644 --- a/csrc/src/fna_forward.cu +++ b/csrc/src/fna_forward.cu @@ -54,7 +54,11 @@ void fna_generic_forward( float attn_scale, const StdNADim& qkv_shape, const StdNADim& query_tile_size, - const StdNADim& key_tile_size) { + const StdNADim& key_tile_size, + bool has_dot_product_min, + bool has_dot_product_max, + float dot_product_min, + float dot_product_max) { static_assert( std::tuple_size_v > 0 && std::tuple_size_v < 4); static constexpr int kNADim = std::tuple_size_v; @@ -130,7 +134,11 @@ void fna_generic_forward( logsumexp.has_value() ? static_cast(logsumexp.value().data_ptr()) : nullptr, query_tile_size, - key_tile_size); + key_tile_size, + has_dot_product_min, + has_dot_product_max, + dot_product_min, + dot_product_max); } else { NATTEN_FAILURE( "Fused kernels are only available on devices with " @@ -154,7 +162,11 @@ void na1d_forward( const std::tuple& is_causal, float attn_scale, const std::tuple& query_tile_size, - const std::tuple& key_tile_size) { + const std::tuple& key_tile_size, + bool has_dot_product_min, + bool has_dot_product_max, + float dot_product_min, + float dot_product_max) { TORCH_CHECK(query.dim() == 4, "Tensors must be 4-D."); fna_generic_forward( @@ -170,7 +182,11 @@ void na1d_forward( attn_scale, {query.size(1)}, query_tile_size, - key_tile_size); + key_tile_size, + has_dot_product_min, + has_dot_product_max, + dot_product_min, + dot_product_max); } void na2d_forward( @@ -185,7 +201,11 @@ void na2d_forward( const std::tuple& is_causal, float attn_scale, const std::tuple& query_tile_size, - const std::tuple& key_tile_size) { + const std::tuple& key_tile_size, + bool has_dot_product_min, + bool has_dot_product_max, + float dot_product_min, + float dot_product_max) { TORCH_CHECK(query.dim() == 5, "Tensors must be 5-D."); fna_generic_forward( @@ -201,7 +221,11 @@ void na2d_forward( attn_scale, {query.size(1), query.size(2)}, query_tile_size, - key_tile_size); + key_tile_size, + has_dot_product_min, + has_dot_product_max, + dot_product_min, + dot_product_max); } void na3d_forward( @@ -216,7 +240,11 @@ void na3d_forward( const std::tuple& is_causal, float attn_scale, const std::tuple& query_tile_size, - const std::tuple& key_tile_size) { + const std::tuple& key_tile_size, + bool has_dot_product_min, + bool has_dot_product_max, + float dot_product_min, + float dot_product_max) { TORCH_CHECK(query.dim() == 6, "Tensors must be 6-D."); fna_generic_forward( @@ -232,7 +260,11 @@ void na3d_forward( attn_scale, {query.size(1), query.size(2), query.size(3)}, query_tile_size, - key_tile_size); + key_tile_size, + has_dot_product_min, + has_dot_product_max, + dot_product_min, + dot_product_max); } } // namespace natten diff --git a/csrc/src/reference_backward.cu b/csrc/src/reference_backward.cu index dde9304a..857e539c 100644 --- a/csrc/src/reference_backward.cu +++ b/csrc/src/reference_backward.cu @@ -72,7 +72,11 @@ void reference_na_generic_backward( const StdCausal& is_causal, float attn_scale, const StdNADim& qkv_shape, - int num_extra_kv) { + int num_extra_kv, + bool has_dot_product_min, + bool has_dot_product_max, + float dot_product_min, + float dot_product_max) { static_assert( std::tuple_size_v > 0 && std::tuple_size_v < 4); static constexpr 
int kNADim = std::tuple_size_v; @@ -177,7 +181,11 @@ void reference_na_generic_backward( stride_, dilation_, attn_scale, - cuda_stream); + cuda_stream, + has_dot_product_min, + has_dot_product_max, + dot_product_min, + dot_product_max); #else TORCH_CHECK(false, "libnatten not compiled with CUTLASS."); #endif @@ -199,7 +207,11 @@ void reference_na1d_backward( const std::tuple& is_causal, float attn_scale, const std::tuple& qkv_shape, - int num_extra_kv) { + int num_extra_kv, + bool has_dot_product_min, + bool has_dot_product_max, + float dot_product_min, + float dot_product_max) { TORCH_CHECK(query.dim() == 4, "Tensors must be 4-D."); reference_na_generic_backward( @@ -218,7 +230,11 @@ void reference_na1d_backward( is_causal, attn_scale, qkv_shape, - num_extra_kv); + num_extra_kv, + has_dot_product_min, + has_dot_product_max, + dot_product_min, + dot_product_max); } void reference_na2d_backward( @@ -237,7 +253,11 @@ void reference_na2d_backward( const std::tuple& is_causal, float attn_scale, const std::tuple& qkv_shape, - int num_extra_kv) { + int num_extra_kv, + bool has_dot_product_min, + bool has_dot_product_max, + float dot_product_min, + float dot_product_max) { TORCH_CHECK(query.dim() == 4, "Tensors must be 4-D."); reference_na_generic_backward( @@ -256,7 +276,11 @@ void reference_na2d_backward( is_causal, attn_scale, qkv_shape, - num_extra_kv); + num_extra_kv, + has_dot_product_min, + has_dot_product_max, + dot_product_min, + dot_product_max); } void reference_na3d_backward( @@ -275,7 +299,11 @@ void reference_na3d_backward( const std::tuple& is_causal, float attn_scale, const std::tuple& qkv_shape, - int num_extra_kv) { + int num_extra_kv, + bool has_dot_product_min, + bool has_dot_product_max, + float dot_product_min, + float dot_product_max) { TORCH_CHECK(query.dim() == 4, "Tensors must be 4-D."); reference_na_generic_backward( @@ -294,7 +322,11 @@ void reference_na3d_backward( is_causal, attn_scale, qkv_shape, - num_extra_kv); + num_extra_kv, + has_dot_product_min, + has_dot_product_max, + dot_product_min, + dot_product_max); } } // namespace natten diff --git a/csrc/src/reference_forward.cu b/csrc/src/reference_forward.cu index ce600da6..39218906 100644 --- a/csrc/src/reference_forward.cu +++ b/csrc/src/reference_forward.cu @@ -68,7 +68,11 @@ void reference_na_generic_forward( const StdCausal& is_causal, float attn_scale, const StdNADim& qkv_shape, - int num_extra_kv) { + int num_extra_kv, + bool has_dot_product_min, + bool has_dot_product_max, + float dot_product_min, + float dot_product_max) { static_assert( std::tuple_size_v > 0 && std::tuple_size_v < 4); static constexpr int kNADim = std::tuple_size_v; @@ -152,7 +156,11 @@ void reference_na_generic_forward( stride_, dilation_, attn_scale, - cuda_stream); + cuda_stream, + has_dot_product_min, + has_dot_product_max, + dot_product_min, + dot_product_max); #else TORCH_CHECK(false, "libnatten not compiled with CUTLASS."); #endif @@ -170,7 +178,11 @@ void reference_na1d_forward( const std::tuple& is_causal, float attn_scale, const std::tuple& qkv_shape, - int num_extra_kv) { + int num_extra_kv, + bool has_dot_product_min, + bool has_dot_product_max, + float dot_product_min, + float dot_product_max) { TORCH_CHECK(query.dim() == 4, "Tensors must be 4-D."); reference_na_generic_forward( @@ -185,7 +197,11 @@ void reference_na1d_forward( is_causal, attn_scale, qkv_shape, - num_extra_kv); + num_extra_kv, + has_dot_product_min, + has_dot_product_max, + dot_product_min, + dot_product_max); } void reference_na2d_forward( @@ -200,7 +216,11 
@@ void reference_na2d_forward( const std::tuple& is_causal, float attn_scale, const std::tuple& qkv_shape, - int num_extra_kv) { + int num_extra_kv, + bool has_dot_product_min, + bool has_dot_product_max, + float dot_product_min, + float dot_product_max) { TORCH_CHECK(query.dim() == 4, "Tensors must be 4-D."); reference_na_generic_forward( @@ -215,7 +235,11 @@ void reference_na2d_forward( is_causal, attn_scale, qkv_shape, - num_extra_kv); + num_extra_kv, + has_dot_product_min, + has_dot_product_max, + dot_product_min, + dot_product_max); } void reference_na3d_forward( @@ -230,7 +254,11 @@ void reference_na3d_forward( const std::tuple& is_causal, float attn_scale, const std::tuple& qkv_shape, - int num_extra_kv) { + int num_extra_kv, + bool has_dot_product_min, + bool has_dot_product_max, + float dot_product_min, + float dot_product_max) { TORCH_CHECK(query.dim() == 4, "Tensors must be 4-D."); reference_na_generic_forward( @@ -245,7 +273,11 @@ void reference_na3d_forward( is_causal, attn_scale, qkv_shape, - num_extra_kv); + num_extra_kv, + has_dot_product_min, + has_dot_product_max, + dot_product_min, + dot_product_max); } } // namespace natten diff --git a/scripts/autogen_reference_fna.py b/scripts/autogen_reference_fna.py index 5784d58f..530ad5f5 100644 --- a/scripts/autogen_reference_fna.py +++ b/scripts/autogen_reference_fna.py @@ -35,7 +35,9 @@ {DimType} stride, {DimType} dilation, float attn_scale, - cudaStream_t stream); + cudaStream_t stream, + bool has_dot_product_min, bool has_dot_product_max, + float dot_product_min, float dot_product_max); """ @@ -57,7 +59,9 @@ {DimType} stride, {DimType} dilation, float attn_scale, - cudaStream_t stream) {{ + cudaStream_t stream, + bool has_dot_product_min, bool has_dot_product_max, + float dot_product_min, float dot_product_max) {{ using Causal = {Causal}; @@ -79,7 +83,9 @@ dilation, Causal{{}}, attn_scale, - stream); + stream, + has_dot_product_min, has_dot_product_max, + dot_product_min, dot_product_max); }} """ @@ -106,7 +112,9 @@ {DimType} stride, {DimType} dilation, float attn_scale, - cudaStream_t stream); + cudaStream_t stream, + bool has_dot_product_min, bool has_dot_product_max, + float dot_product_min, float dot_product_max); """ @@ -132,7 +140,9 @@ {DimType} stride, {DimType} dilation, float attn_scale, - cudaStream_t stream) {{ + cudaStream_t stream, + bool has_dot_product_min, bool has_dot_product_max, + float dot_product_min, float dot_product_max) {{ using Causal = {Causal}; @@ -158,7 +168,9 @@ dilation, Causal{{}}, attn_scale, - stream); + stream, + has_dot_product_min, has_dot_product_max, + dot_product_min, dot_product_max); }} """ diff --git a/src/natten/backends/__init__.py b/src/natten/backends/__init__.py index c76b3caf..b65c2f34 100644 --- a/src/natten/backends/__init__.py +++ b/src/natten/backends/__init__.py @@ -22,7 +22,7 @@ ################################################################################################# -from typing import List +from typing import List, Optional from ..utils import log @@ -83,21 +83,37 @@ def choose_backend( - query: Tensor, key: Tensor, value: Tensor, torch_compile: bool + query: Tensor, + key: Tensor, + value: Tensor, + torch_compile: bool, + has_dot_product_clip: Optional[bool] = False, ) -> str: - if can_run_cutlass_blackwell_fna(query, key, value): + if can_run_cutlass_blackwell_fna( + query, key, value, has_dot_product_clip=has_dot_product_clip + ): logger.debug("Backend not set; picked Blackwell FNA kernel.") return "blackwell-fna" - if can_run_cutlass_hopper_fna(query, key, 
value): + if can_run_cutlass_hopper_fna( + query, key, value, has_dot_product_clip=has_dot_product_clip + ): logger.debug("Backend not set; picked Hopper FNA kernel.") return "hopper-fna" - if can_run_cutlass_fna(query, key, value): + if can_run_cutlass_fna( + query, key, value, has_dot_product_clip=has_dot_product_clip + ): logger.debug("Backend not set; picked CUTLASS (2.X) FNA kernel.") return "cutlass-fna" - if can_run_flex_attention(query, key, value, torch_compile=torch_compile): + if can_run_flex_attention( + query, + key, + value, + torch_compile=torch_compile, + has_dot_product_clip=has_dot_product_clip, + ): logger.debug("Backend not set; picked Flex Attention kernel.") return "flex-fna" diff --git a/src/natten/backends/configs/checks.py b/src/natten/backends/configs/checks.py index b9071626..a38a3d77 100644 --- a/src/natten/backends/configs/checks.py +++ b/src/natten/backends/configs/checks.py @@ -22,6 +22,7 @@ ################################################################################################# import functools import math +from typing import Optional import torch from torch import Tensor @@ -37,10 +38,20 @@ def can_run_cutlass_blackwell_fmha( - query: Tensor, key: Tensor, value: Tensor, raise_error: bool = False + query: Tensor, + key: Tensor, + value: Tensor, + raise_error: bool = False, + has_dot_product_clip: Optional[bool] = False, ) -> bool: target_fn = functools.partial(log_or_raise_error, raise_error=raise_error) + if has_dot_product_clip: + target_fn( + "Blackwell FMHA does not support clipping dot products (dot_product_{min,max})." + ) + return False + if not HAS_LIBNATTEN: target_fn("Can't run Blackwell FMHA; NATTEN was not built with libnatten.") return False @@ -112,10 +123,20 @@ def can_run_cutlass_blackwell_fmha( def can_run_cutlass_blackwell_fna( - query: Tensor, key: Tensor, value: Tensor, raise_error: bool = False + query: Tensor, + key: Tensor, + value: Tensor, + raise_error: bool = False, + has_dot_product_clip: Optional[bool] = False, ) -> bool: target_fn = functools.partial(log_or_raise_error, raise_error=raise_error) + if has_dot_product_clip: + target_fn( + "Blackwell FNA does not support clipping dot products (dot_product_{min,max})." + ) + return False + if not HAS_LIBNATTEN: target_fn("Can't run Blackwell FNA; NATTEN was not built with libnatten.") return False @@ -191,10 +212,20 @@ def can_run_cutlass_blackwell_fna( def can_run_cutlass_hopper_fmha( - query: Tensor, key: Tensor, value: Tensor, raise_error: bool = False + query: Tensor, + key: Tensor, + value: Tensor, + raise_error: bool = False, + has_dot_product_clip: Optional[bool] = False, ) -> bool: target_fn = functools.partial(log_or_raise_error, raise_error=raise_error) + if has_dot_product_clip: + target_fn( + "Hopper FMHA does not support clipping dot products (dot_product_{min,max})." + ) + return False + if not HAS_LIBNATTEN: target_fn("Can't run Hopper FMHA; NATTEN was not built with libnatten.") return False @@ -274,10 +305,20 @@ def can_run_cutlass_hopper_fmha( def can_run_cutlass_hopper_fna( - query: Tensor, key: Tensor, value: Tensor, raise_error: bool = False + query: Tensor, + key: Tensor, + value: Tensor, + raise_error: bool = False, + has_dot_product_clip: Optional[bool] = False, ) -> bool: target_fn = functools.partial(log_or_raise_error, raise_error=raise_error) + if has_dot_product_clip: + target_fn( + "Hopper FNA does not support clipping dot products (dot_product_{min,max})." 
+ ) + return False + if not HAS_LIBNATTEN: target_fn("Can't run Hopper FNA; NATTEN was not built with libnatten.") return False @@ -361,10 +402,20 @@ def can_run_cutlass_hopper_fna( def can_run_cutlass_fmha( - query: Tensor, key: Tensor, value: Tensor, raise_error: bool = False + query: Tensor, + key: Tensor, + value: Tensor, + raise_error: bool = False, + has_dot_product_clip: Optional[bool] = False, ) -> bool: target_fn = functools.partial(log_or_raise_error, raise_error=raise_error) + if has_dot_product_clip: + target_fn( + "CUTLASS FMHA does not support clipping dot products (dot_product_{min,max})." + ) + return False + if not HAS_LIBNATTEN: target_fn("Can't run CUTLASS FMHA; NATTEN was not built with libnatten.") return False @@ -436,7 +487,11 @@ def can_run_cutlass_fmha( def can_run_cutlass_fna( - query: Tensor, key: Tensor, value: Tensor, raise_error: bool = False + query: Tensor, + key: Tensor, + value: Tensor, + raise_error: bool = False, + has_dot_product_clip: Optional[bool] = False, ) -> bool: target_fn = functools.partial(log_or_raise_error, raise_error=raise_error) @@ -524,9 +579,16 @@ def can_run_flex_attention( value: Tensor, torch_compile: bool, raise_error: bool = False, + has_dot_product_clip: Optional[bool] = False, ) -> bool: target_fn = functools.partial(log_or_raise_error, raise_error=raise_error) + if has_dot_product_clip: + target_fn( + "Flex FMHA/FNA does not support clipping dot products (dot_product_{min,max})." + ) + return False + if not _FLEX_SUPPORTED: target_fn("Can't run NATTEN with Flex Attention with torch < 2.7.") return False diff --git a/src/natten/backends/fna.py b/src/natten/backends/fna.py index bff2051a..e8bf4c03 100644 --- a/src/natten/backends/fna.py +++ b/src/natten/backends/fna.py @@ -100,6 +100,8 @@ def forward( scale: float, forward_config: CutlassFnaForwardConfigType, backward_config: CutlassFnaBackwardConfigType, + dot_product_min: Optional[float] = None, + dot_product_max: Optional[float] = None, ) -> Tuple[Tensor, Tensor]: kernel_size, stride, dilation, is_causal = check_all_args( na_dim, kernel_size, stride, dilation, is_causal @@ -135,6 +137,10 @@ def forward( scale, q_tile_shape, kv_tile_shape, + dot_product_min is not None, + dot_product_max is not None, + dot_product_min or 0.0, + dot_product_max or 0.0, ) ctx.save_for_backward(query, key, value, logsumexp, output) @@ -144,6 +150,8 @@ def forward( ctx.is_causal = is_causal ctx.scale = scale ctx.backward_config = backward_config + ctx.dot_product_min = dot_product_min + ctx.dot_product_max = dot_product_max return output, logsumexp @@ -160,6 +168,8 @@ def backward(ctx, grad_out: Tensor, grad_lse: Tensor) -> Tuple[ NoneType, NoneType, NoneType, + NoneType, + NoneType, ]: query, key, value, logsumexp, output = ctx.saved_tensors d_output = grad_out.contiguous() @@ -202,9 +212,26 @@ def backward(ctx, grad_out: Tensor, grad_lse: Tensor) -> Tuple[ k_tile_shape, num_kv_splits, compute_delta_with_pt, + ctx.dot_product_min is not None, + ctx.dot_product_max is not None, + ctx.dot_product_min or 0.0, + ctx.dot_product_max or 0.0, ) - return d_query, d_key, d_value, None, None, None, None, None, None, None + return ( + d_query, + d_key, + d_value, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) return CutlassFnaGenericAutogradFn @@ -237,6 +264,8 @@ def cutlass_fna_generic( backward_kv_splits: Optional[DimensionType] = None, backward_use_pt_reduction: bool = False, return_lse: bool = False, + dot_product_min: Optional[float] = None, + dot_product_max: Optional[float] 
= None, ) -> Union[Tensor, Tuple[Tensor, Tensor]]: na_tensor_checks(query, key, value, must_match_head_dims=False) @@ -288,6 +317,8 @@ def cutlass_fna_generic( scale, forward_config, backward_config, + dot_product_min, + dot_product_max, ) if return_lse: @@ -312,6 +343,8 @@ def na1d_cutlass_fna( backward_kv_splits: Optional[Dimension1DType] = None, backward_use_pt_reduction: bool = False, return_lse: bool = False, + dot_product_min: Optional[float] = None, + dot_product_max: Optional[float] = None, ) -> Union[Tensor, Tuple[Tensor, Tensor]]: return cutlass_fna_generic( query=query, @@ -329,6 +362,8 @@ def na1d_cutlass_fna( backward_kv_splits=backward_kv_splits, backward_use_pt_reduction=backward_use_pt_reduction, return_lse=return_lse, + dot_product_min=dot_product_min, + dot_product_max=dot_product_max, ) @@ -348,6 +383,8 @@ def na2d_cutlass_fna( backward_kv_splits: Optional[Dimension2DType] = None, backward_use_pt_reduction: bool = False, return_lse: bool = False, + dot_product_min: Optional[float] = None, + dot_product_max: Optional[float] = None, ) -> Union[Tensor, Tuple[Tensor, Tensor]]: return cutlass_fna_generic( query=query, @@ -365,6 +402,8 @@ def na2d_cutlass_fna( backward_kv_splits=backward_kv_splits, backward_use_pt_reduction=backward_use_pt_reduction, return_lse=return_lse, + dot_product_min=dot_product_min, + dot_product_max=dot_product_max, ) @@ -384,6 +423,8 @@ def na3d_cutlass_fna( backward_kv_splits: Optional[Dimension3DType] = None, backward_use_pt_reduction: bool = False, return_lse: bool = False, + dot_product_min: Optional[float] = None, + dot_product_max: Optional[float] = None, ) -> Union[Tensor, Tuple[Tensor, Tensor]]: return cutlass_fna_generic( query=query, @@ -401,4 +442,6 @@ def na3d_cutlass_fna( backward_kv_splits=backward_kv_splits, backward_use_pt_reduction=backward_use_pt_reduction, return_lse=return_lse, + dot_product_min=dot_product_min, + dot_product_max=dot_product_max, ) diff --git a/src/natten/backends/reference.py b/src/natten/backends/reference.py index 26aaa216..588d9b55 100644 --- a/src/natten/backends/reference.py +++ b/src/natten/backends/reference.py @@ -94,6 +94,8 @@ def forward( scale: float, qkv_shape: DimensionType, num_extra_kv: int, + dot_product_min: Optional[float] = None, + dot_product_max: Optional[float] = None, ) -> Tuple[Tensor, Tensor]: kernel_size, stride, dilation, is_causal = check_all_args( na_dim, kernel_size, stride, dilation, is_causal @@ -127,6 +129,10 @@ def forward( scale, qkv_shape, num_extra_kv, + dot_product_min is not None, + dot_product_max is not None, + dot_product_min or 0.0, + dot_product_max or 0.0, ) ctx.save_for_backward(query, key, value, logsumexp, output) @@ -137,6 +143,8 @@ def forward( ctx.scale = scale ctx.qkv_shape = qkv_shape ctx.num_extra_kv = num_extra_kv + ctx.dot_product_min = dot_product_min + ctx.dot_product_max = dot_product_max return output, logsumexp @@ -153,6 +161,8 @@ def backward(ctx, grad_out: Tensor, grad_lse: Tensor) -> Tuple[ NoneType, NoneType, NoneType, + NoneType, + NoneType, ]: query, key, value, logsumexp, output = ctx.saved_tensors d_output = grad_out.contiguous() @@ -177,9 +187,26 @@ def backward(ctx, grad_out: Tensor, grad_lse: Tensor) -> Tuple[ ctx.scale, ctx.qkv_shape, ctx.num_extra_kv, + ctx.dot_product_min is not None, + ctx.dot_product_max is not None, + ctx.dot_product_min or 0.0, + ctx.dot_product_max or 0.0, ) - return d_query, d_key, d_value, None, None, None, None, None, None, None + return ( + d_query, + d_key, + d_value, + None, + None, + None, + None, + None, + 
None, + None, + None, + None, + ) return ReferenceFnaGenericAutogradFn @@ -208,6 +235,8 @@ def reference_fna_generic( additional_keys: Optional[Tensor] = None, additional_values: Optional[Tensor] = None, return_lse: bool = False, + dot_product_min: Optional[float] = None, + dot_product_max: Optional[float] = None, ) -> Union[Tensor, Tuple[Tensor, Tensor]]: na_tensor_checks(query, key, value, must_match_head_dims=False) @@ -261,6 +290,8 @@ def reference_fna_generic( scale, qkv_shape, num_extra_kv, + dot_product_min, + dot_product_max, ) output = output.reshape( query.shape[0], *qkv_shape, query.shape[-2], value.shape[-1] @@ -285,6 +316,8 @@ def na1d_reference( additional_keys: Optional[Tensor] = None, additional_values: Optional[Tensor] = None, return_lse: bool = False, + dot_product_min: Optional[float] = None, + dot_product_max: Optional[float] = None, ) -> Union[Tensor, Tuple[Tensor, Tensor]]: return reference_fna_generic( query=query, @@ -298,6 +331,8 @@ def na1d_reference( additional_keys=additional_keys, additional_values=additional_values, return_lse=return_lse, + dot_product_min=dot_product_min, + dot_product_max=dot_product_max, ) @@ -313,6 +348,8 @@ def na2d_reference( additional_keys: Optional[Tensor] = None, additional_values: Optional[Tensor] = None, return_lse: bool = False, + dot_product_min: Optional[float] = None, + dot_product_max: Optional[float] = None, ) -> Union[Tensor, Tuple[Tensor, Tensor]]: return reference_fna_generic( query=query, @@ -326,6 +363,8 @@ def na2d_reference( additional_keys=additional_keys, additional_values=additional_values, return_lse=return_lse, + dot_product_min=dot_product_min, + dot_product_max=dot_product_max, ) @@ -341,6 +380,8 @@ def na3d_reference( additional_keys: Optional[Tensor] = None, additional_values: Optional[Tensor] = None, return_lse: bool = False, + dot_product_min: Optional[float] = None, + dot_product_max: Optional[float] = None, ) -> Union[Tensor, Tuple[Tensor, Tensor]]: return reference_fna_generic( query=query, @@ -354,4 +395,6 @@ def na3d_reference( additional_keys=additional_keys, additional_values=additional_values, return_lse=return_lse, + dot_product_min=dot_product_min, + dot_product_max=dot_product_max, ) diff --git a/src/natten/functional.py b/src/natten/functional.py index 30258ea2..19152d74 100644 --- a/src/natten/functional.py +++ b/src/natten/functional.py @@ -333,6 +333,8 @@ def neighborhood_attention_generic( run_persistent_kernel: bool = True, kernel_schedule: Optional[Union[str, KernelSchedule]] = None, torch_compile: bool = False, + dot_product_min: Optional[float] = None, + dot_product_max: Optional[float] = None, ) -> Tensor: na_tensor_checks(query, key, value) @@ -355,7 +357,23 @@ def neighborhood_attention_generic( is_causal=is_causal, ) - if is_self_attention(query, kernel_size=kernel_size, is_causal=is_causal): + has_dot_product_clip = dot_product_min is not None or dot_product_max is not None + + if has_dot_product_clip: + if dot_product_min is not None and not isinstance(dot_product_min, float): + raise ValueError( + f"`dot_product_min` must be a float or None, got {type(dot_product_min)=}." + ) + if dot_product_max is not None and not isinstance(dot_product_max, float): + raise ValueError( + f"`dot_product_max` must be a float or None, got {type(dot_product_max)=}." + ) + + # NOTE: FMHA backends don't support dot product clipping yet, so fast path must be disabled for + # those cases. 
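+    # Only the CUTLASS (2.X) FNA and reference backends implement clipping for
+    # now; `choose_backend` below is informed via `has_dot_product_clip`, so
+    # Blackwell/Hopper/Flex are never selected when a clip range is set, e.g.:
+    #   na2d(q, k, v, kernel_size=(7, 7), dot_product_min=-1.0, dot_product_max=1.0)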
+ if not has_dot_product_clip and is_self_attention( + query, kernel_size=kernel_size, is_causal=is_causal + ): logger.debug( f"{query.shape=} with {kernel_size=} and {is_causal=} is self attention. " "Calling attention instead of neighborhood attention directly." @@ -384,13 +402,20 @@ def neighborhood_attention_generic( scale = scale or query.shape[-1] ** -0.5 - backend = backend or choose_backend(query, key, value, torch_compile=torch_compile) + backend = backend or choose_backend( + query, + key, + value, + torch_compile=torch_compile, + has_dot_product_clip=has_dot_product_clip, + ) has_additional_attention = ( additional_keys is not None and additional_values is not None ) if backend == "blackwell-fna": + assert not has_dot_product_clip outputs = cutlass_blackwell_fna_generic( query=query, key=key, @@ -409,6 +434,7 @@ def neighborhood_attention_generic( ) elif backend == "hopper-fna": + assert not has_dot_product_clip outputs = cutlass_hopper_fna_generic( query=query, key=key, @@ -443,9 +469,12 @@ def neighborhood_attention_generic( backward_kv_splits=backward_kv_splits, backward_use_pt_reduction=backward_use_pt_reduction, return_lse=has_additional_attention, + dot_product_min=dot_product_min, + dot_product_max=dot_product_max, ) elif backend == "flex-fna": + assert not has_dot_product_clip outputs = flex_fna_generic( query=query, key=key, @@ -513,6 +542,8 @@ def na1d( run_persistent_kernel: bool = True, kernel_schedule: Optional[Union[str, KernelSchedule]] = None, torch_compile: bool = False, + dot_product_min: Optional[float] = None, + dot_product_max: Optional[float] = None, ) -> Tensor: """Computes 1-D neighborhood attention. @@ -665,6 +696,8 @@ def na1d( run_persistent_kernel=run_persistent_kernel, kernel_schedule=kernel_schedule, torch_compile=torch_compile, + dot_product_min=dot_product_min, + dot_product_max=dot_product_max, ) @@ -690,6 +723,8 @@ def na2d( run_persistent_kernel: bool = True, kernel_schedule: Optional[Union[str, KernelSchedule]] = None, torch_compile: bool = False, + dot_product_min: Optional[float] = None, + dot_product_max: Optional[float] = None, ) -> Tensor: """Computes 2-D neighborhood attention. @@ -852,6 +887,8 @@ def na2d( run_persistent_kernel=run_persistent_kernel, kernel_schedule=kernel_schedule, torch_compile=torch_compile, + dot_product_min=dot_product_min, + dot_product_max=dot_product_max, ) @@ -877,6 +914,8 @@ def na3d( run_persistent_kernel: bool = True, kernel_schedule: Optional[Union[str, KernelSchedule]] = None, torch_compile: bool = False, + dot_product_min: Optional[float] = None, + dot_product_max: Optional[float] = None, ) -> Tensor: """Computes 3-D neighborhood attention. 
@@ -1039,4 +1078,6 @@ def na3d( run_persistent_kernel=run_persistent_kernel, kernel_schedule=kernel_schedule, torch_compile=torch_compile, + dot_product_min=dot_product_min, + dot_product_max=dot_product_max, ) diff --git a/tests/test_fna.py b/tests/test_fna.py index 5067d15b..0a9e0840 100644 --- a/tests/test_fna.py +++ b/tests/test_fna.py @@ -86,6 +86,8 @@ def _test_all_dtypes_against_reference( is_causal=None, additional_kv_length=0, configs_to_test=5, + dot_product_min=None, + dot_product_max=None, ): torch.set_default_device("cuda") @@ -108,6 +110,8 @@ def _test_all_dtypes_against_reference( reference_backend="reference", reference_fmha_backend="reference", dtype=torch.float32, + dot_product_min=dot_product_min, + dot_product_max=dot_product_max, ) # TODO: write note on why backprop eps is different when additional_kv_length > 0 @@ -226,6 +230,20 @@ def test_1d_against_reference(self): is_causal=is_causal, additional_kv_length=additional_kv_length, ) + self._test_all_dtypes_against_reference( + batch=batch, + heads=heads, + head_dim=head_dim, + head_dim_v=head_dim_v, + input_shape=input_shape, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + is_causal=is_causal, + additional_kv_length=additional_kv_length, + dot_product_min=-1.0, + dot_product_max=1.0, + ) @skip_if_libnatten_is_not_supported() def test_2d_against_reference(self): @@ -271,6 +289,20 @@ def test_2d_against_reference(self): is_causal=is_causal, additional_kv_length=additional_kv_length, ) + self._test_all_dtypes_against_reference( + batch=batch, + heads=heads, + head_dim=head_dim, + head_dim_v=head_dim_v, + input_shape=input_shape, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + is_causal=is_causal, + additional_kv_length=additional_kv_length, + dot_product_min=-1.0, + dot_product_max=1.0, + ) @skip_if_not_running_extended_tests() @skip_if_libnatten_is_not_supported() @@ -367,6 +399,20 @@ def test_3d_against_reference(self): is_causal=is_causal, additional_kv_length=additional_kv_length, ) + self._test_all_dtypes_against_reference( + batch=batch, + heads=heads, + head_dim=head_dim, + head_dim_v=head_dim_v, + input_shape=input_shape, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + is_causal=is_causal, + additional_kv_length=additional_kv_length, + dot_product_min=-1.0, + dot_product_max=1.0, + ) @skip_if_not_running_extended_tests() @skip_if_libnatten_is_not_supported() @@ -487,6 +533,16 @@ def _test_rand_sweep_against_reference( random.choice(range(8, 513, 8)) if ENABLE_ADDITIONAL_KV_TESTS else 0 ) + dot_product_min = random.choice([None, random.uniform(-2.0, 2.0)]) + dot_product_max = random.choice( + [ + None, + random.uniform( + (dot_product_min or -2.0) + 2.0, (dot_product_min or -2.0) + 4.0 + ), + ] + ) + self._test_all_dtypes_against_reference( batch=batch, heads=heads, @@ -499,6 +555,8 @@ def _test_rand_sweep_against_reference( is_causal=is_causal, additional_kv_length=additional_kv_length, configs_to_test=configs_to_test, + dot_product_min=dot_product_min, + dot_product_max=dot_product_max, ) @skip_if_libnatten_is_not_supported() diff --git a/tests/utils.py b/tests/utils.py index 65e0c1ad..11fe7ef3 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -66,6 +66,8 @@ def __init__( reference_fmha_backend: str, dtype: torch.dtype, head_dim_v: Optional[int] = None, + dot_product_min: Optional[float] = None, + dot_product_max: Optional[float] = None, ): assert isinstance(input_shape, tuple) na_dim = len(input_shape) @@ -85,6 +87,12 @@ def __init__( 
self.reference_backend = reference_backend self.reference_fmha_backend = reference_fmha_backend + self.dot_product_min = dot_product_min + self.dot_product_max = dot_product_max + self.has_dot_product_clip = ( + dot_product_min is not None or dot_product_max is not None + ) + with torch.no_grad(): q_ref, k_ref, v_ref, d_out_ref = ( torch.randn( @@ -164,6 +172,8 @@ def __init__( additional_keys=additional_k_ref, additional_values=additional_v_ref, return_lse=False, + dot_product_min=dot_product_min, + dot_product_max=dot_product_max, ) else: @@ -181,6 +191,8 @@ def __init__( additional_values=additional_v_ref, backend=reference_backend, attention_kwargs={"backend": reference_fmha_backend}, + dot_product_min=dot_product_min, + dot_product_max=dot_product_max, ) self.out_ref = out_ref_.data.clone().float() # type: ignore[union-attr] @@ -246,6 +258,11 @@ def test( f"{kernel_size=}, {stride=}, {dilation=}, {is_causal=}, {additional_kv_length=},\n" f"{q_tile_shape=}, {kv_tile_shape=}, {run_persistent_kernel=}, {kernel_schedule=}, " f"{torch_compile=}" + + ( + f", {self.dot_product_min=}, {self.dot_product_max=}, " + if self.has_dot_product_clip + else "" + ) + ( f"\n{backward_q_tile_shape=}, {backward_kv_tile_shape=}, " f"{backward_kv_splits=}, {backward_use_pt_reduction=}." @@ -300,6 +317,8 @@ def test( kernel_schedule=kernel_schedule, torch_compile=torch_compile, attention_kwargs={"backend": target_fmha_backend}, + dot_product_min=self.dot_product_min, + dot_product_max=self.dot_product_max, ) out = out_.data.clone().float() From c402a65a21a2a57ba31af122451eb80447ea07d5 Mon Sep 17 00:00:00 2001 From: Ali Hassani Date: Wed, 20 Aug 2025 12:37:06 -0400 Subject: [PATCH 2/2] [EXP] Mask off out of bound dot products WARNING: does not successfully pass the entire unit test sweep. Does not affect existing functionality, and only fails to pass with the same error threshold as existing use cases (without clipping) --- .../fna/epilogue/epilogue_rescale_output.h | 8 +- .../include/natten/cuda/fna/kernel_backward.h | 24 +++--- csrc/include/natten/cuda/fna/kernel_forward.h | 32 +++++--- .../cuda/reference/fna_reference_backward.hpp | 77 ++++++++++++------- .../cuda/reference/fna_reference_forward.hpp | 24 +++--- 5 files changed, 106 insertions(+), 59 deletions(-) diff --git a/csrc/include/natten/cuda/fna/epilogue/epilogue_rescale_output.h b/csrc/include/natten/cuda/fna/epilogue/epilogue_rescale_output.h index 0499111d..a3241900 100644 --- a/csrc/include/natten/cuda/fna/epilogue/epilogue_rescale_output.h +++ b/csrc/include/natten/cuda/fna/epilogue/epilogue_rescale_output.h @@ -179,7 +179,9 @@ class MemoryEfficientAttentionNormalize { multiplies mul_add_source; multiply_add mul_add_accumulator; - ElementCompute alpha = isLast ? (1 / s_prime_[row]) : 1; + auto s_prime = s_prime_[row]; + auto scale = s_prime == 0 ? 0 : 1 / s_prime; + ElementCompute alpha = isLast ? scale : 1; ElementCompute beta = alpha * m_prime_[row]; intermediate = mul_add_source(beta, converted_source); // X = beta * C @@ -209,7 +211,9 @@ class MemoryEfficientAttentionNormalize { ComputeFragment intermediate; multiplies mul_accumulator; - ElementCompute alpha = isLast ? (1 / s_prime_[row]) : 1; + auto s_prime = s_prime_[row]; + auto scale = s_prime == 0 ? 0 : 1 / s_prime; + ElementCompute alpha = isLast ? 
scale : 1; intermediate = mul_accumulator( alpha, converted_accumulator); // X = alpha * C + uniform diff --git a/csrc/include/natten/cuda/fna/kernel_backward.h b/csrc/include/natten/cuda/fna/kernel_backward.h index 1ff7df1c..a1251581 100644 --- a/csrc/include/natten/cuda/fna/kernel_backward.h +++ b/csrc/include/natten/cuda/fna/kernel_backward.h @@ -1639,30 +1639,34 @@ struct FusedNeighborhoodAttentionBackwardKernel { auto lane_offset = MatmulQK::AccumLambdaIterator::get_lane_offset( lane_id, warp_id, output_tile_coords); - // (Optional) clip dot products -- MUST BE DONE PRIOR TO MASKING & - // SCALING. + // Dot product scale + accum = cutlass::multiplies()(scale, accum); + + // (Optional) clip dot products (mask off out of bound dot products) if (p.has_dot_product_clip) { if (not p.has_dot_product_max) { for (int i = 0; i < Mma::FragmentC::kElements; ++i) { - accum[i] = cutlass::fast_max(accum[i], p.dot_product_min); + accum[i] = accum[i] < p.dot_product_min + ? -cutlass::platform::numeric_limits::infinity() + : accum[i]; } } else if (not p.has_dot_product_min) { for (int i = 0; i < Mma::FragmentC::kElements; ++i) { - accum[i] = cutlass::fast_min(accum[i], p.dot_product_max); + accum[i] = accum[i] > p.dot_product_max + ? -cutlass::platform::numeric_limits::infinity() + : accum[i]; } } else { // assert(p.has_dot_product_min && p.has_dot_product_max); for (int i = 0; i < Mma::FragmentC::kElements; ++i) { - accum[i] = cutlass::fast_max( - cutlass::fast_min(accum[i], p.dot_product_max), - p.dot_product_min); + accum[i] = + (accum[i] < p.dot_product_min || accum[i] > p.dot_product_max) + ? -cutlass::platform::numeric_limits::infinity() + : accum[i]; } } } - // Dot product scale - accum = cutlass::multiplies()(scale, accum); - if (not p.is_fully_block_sparse) { // Neighborhood Attention masking Dim first_col, query_bound, row_idx; diff --git a/csrc/include/natten/cuda/fna/kernel_forward.h b/csrc/include/natten/cuda/fna/kernel_forward.h index 29780321..5fb05050 100644 --- a/csrc/include/natten/cuda/fna/kernel_forward.h +++ b/csrc/include/natten/cuda/fna/kernel_forward.h @@ -745,23 +745,30 @@ struct FusedNeighborhoodAttentionKernel { MM1::Mma::drain_cp_asyncs(); } - // (Optional) clip dot products -- MUST BE DONE PRIOR TO MASKING & - // SCALING. + // (Optional) clip dot products (mask off out of bound dot products) if (p.has_dot_product_clip) { if (not p.has_dot_product_max) { for (int i = 0; i < MM0::Mma::FragmentC::kElements; ++i) { - accum[i] = cutlass::fast_max(accum[i], p.dot_product_min); + accum[i] = accum[i] * p.scale; + accum[i] = accum[i] < p.dot_product_min + ? -cutlass::platform::numeric_limits::infinity() + : accum[i]; } } else if (not p.has_dot_product_min) { for (int i = 0; i < MM0::Mma::FragmentC::kElements; ++i) { - accum[i] = cutlass::fast_min(accum[i], p.dot_product_max); + accum[i] = accum[i] * p.scale; + accum[i] = accum[i] > p.dot_product_max + ? -cutlass::platform::numeric_limits::infinity() + : accum[i]; } } else { // assert(p.has_dot_product_min && p.has_dot_product_max); for (int i = 0; i < MM0::Mma::FragmentC::kElements; ++i) { - accum[i] = cutlass::fast_max( - cutlass::fast_min(accum[i], p.dot_product_max), - p.dot_product_min); + accum[i] = accum[i] * p.scale; + accum[i] = + (accum[i] < p.dot_product_min || accum[i] > p.dot_product_max) + ? -cutlass::platform::numeric_limits::infinity() + : accum[i]; } } } @@ -823,7 +830,7 @@ struct FusedNeighborhoodAttentionKernel { last_kv_col, is_first_kv_iter, iteratorC_tile_offset, - p.scale); + p.has_dot_product_clip ? 
1.0 : p.scale); // Output results to shared-memory @@ -999,8 +1006,13 @@ struct FusedNeighborhoodAttentionKernel { map_index_to_coord((int32_t)thread_id(), problem_size_0_m); auto query_offset = (query_idx * p.lse_strideM).sum(); if (is_coord_within_upper_bound(query_idx, problem_size_0_m)) { - p.logsumexp_ptr[query_offset] = accum_t(mi[thread_id()] / kLog2e) + - cutlass::fast_log(accum_t(s_prime[thread_id()])); + if (mi[thread_id()] == + -cutlass::platform::numeric_limits::infinity()) { + p.logsumexp_ptr[query_offset] = 0.0f; + } else { + p.logsumexp_ptr[query_offset] = accum_t(mi[thread_id()] / kLog2e) + + cutlass::fast_log(accum_t(s_prime[thread_id()])); + } //} else if (query_offset < lse_dim) { // p.logsumexp_ptr[query_offset] = // cutlass::platform::numeric_limits::infinity(); diff --git a/csrc/include/natten/cuda/reference/fna_reference_backward.hpp b/csrc/include/natten/cuda/reference/fna_reference_backward.hpp index 1ea4c4a0..b7e7458c 100644 --- a/csrc/include/natten/cuda/reference/fna_reference_backward.hpp +++ b/csrc/include/natten/cuda/reference/fna_reference_backward.hpp @@ -116,23 +116,30 @@ void __global__ fna_bwd_reference_dQ_kernel( acc_doo += mDO(idx_Q, idx_D1, idx_L) * mO(idx_Q, idx_D1, idx_L); } // for idx_D1 - // (Optional) clip dot products -- MUST BE DONE PRIOR TO MASKING & - // SCALING. + acc_qk *= attn_scale; + acc_dov *= attn_scale; + acc_doo *= attn_scale; + + // (Optional) clip dot products (mask off out of bound dot products) if (has_dot_product_min || has_dot_product_max) { if (not has_dot_product_max) { - acc_qk = cutlass::fast_max(acc_qk, dot_product_min); + acc_qk = acc_qk < dot_product_min + ? -cutlass::platform::numeric_limits< + ElementAccumulator>::infinity() + : acc_qk; } else if (not has_dot_product_min) { - acc_qk = cutlass::fast_min(acc_qk, dot_product_max); + acc_qk = acc_qk > dot_product_max + ? -cutlass::platform::numeric_limits< + ElementAccumulator>::infinity() + : acc_qk; } else { - acc_qk = cutlass::fast_max( - cutlass::fast_min(acc_qk, dot_product_max), dot_product_min); + acc_qk = (acc_qk < dot_product_min || acc_qk > dot_product_max) + ? -cutlass::platform::numeric_limits< + ElementAccumulator>::infinity() + : acc_qk; } } - acc_qk *= attn_scale; - acc_dov *= attn_scale; - acc_doo *= attn_scale; - auto id = make_identity_tensor(make_shape(1, 1)); auto frag = make_tensor(Shape<_1, _1>{}); frag(0) = acc_qk; @@ -246,23 +253,30 @@ void __global__ fna_bwd_reference_dK_kernel( acc_doo += mDO(idx_Q, idx_D1, idx_L) * mO(idx_Q, idx_D1, idx_L); } // for idx_D1 - // (Optional) clip dot products -- MUST BE DONE PRIOR TO MASKING & - // SCALING. + acc_qk *= attn_scale; + acc_dov *= attn_scale; + acc_doo *= attn_scale; + + // (Optional) clip dot products (mask off out of bound dot products) if (has_dot_product_min || has_dot_product_max) { if (not has_dot_product_max) { - acc_qk = cutlass::fast_max(acc_qk, dot_product_min); + acc_qk = acc_qk < dot_product_min + ? -cutlass::platform::numeric_limits< + ElementAccumulator>::infinity() + : acc_qk; } else if (not has_dot_product_min) { - acc_qk = cutlass::fast_min(acc_qk, dot_product_max); + acc_qk = acc_qk > dot_product_max + ? -cutlass::platform::numeric_limits< + ElementAccumulator>::infinity() + : acc_qk; } else { - acc_qk = cutlass::fast_max( - cutlass::fast_min(acc_qk, dot_product_max), dot_product_min); + acc_qk = (acc_qk < dot_product_min || acc_qk > dot_product_max) + ? 
-cutlass::platform::numeric_limits< + ElementAccumulator>::infinity() + : acc_qk; } } - acc_qk *= attn_scale; - acc_dov *= attn_scale; - acc_doo *= attn_scale; - auto id = make_identity_tensor(make_shape(1, 1)); auto frag = make_tensor(Shape<_1, _1>{}); frag(0) = acc_qk; @@ -374,21 +388,28 @@ void __global__ fna_bwd_reference_dV_kernel( acc_qk += rQ * rK; } // for idx_D0 - // (Optional) clip dot products -- MUST BE DONE PRIOR TO MASKING & - // SCALING. + acc_qk *= attn_scale; + + // (Optional) clip dot products (mask off out of bound dot products) if (has_dot_product_min || has_dot_product_max) { if (not has_dot_product_max) { - acc_qk = cutlass::fast_max(acc_qk, dot_product_min); + acc_qk = acc_qk < dot_product_min + ? -cutlass::platform::numeric_limits< + ElementAccumulator>::infinity() + : acc_qk; } else if (not has_dot_product_min) { - acc_qk = cutlass::fast_min(acc_qk, dot_product_max); + acc_qk = acc_qk > dot_product_max + ? -cutlass::platform::numeric_limits< + ElementAccumulator>::infinity() + : acc_qk; } else { - acc_qk = cutlass::fast_max( - cutlass::fast_min(acc_qk, dot_product_max), dot_product_min); + acc_qk = (acc_qk < dot_product_min || acc_qk > dot_product_max) + ? -cutlass::platform::numeric_limits< + ElementAccumulator>::infinity() + : acc_qk; } } - acc_qk *= attn_scale; - auto id = make_identity_tensor(make_shape(1, 1)); auto frag = make_tensor(Shape<_1, _1>{}); frag(0) = acc_qk; diff --git a/csrc/include/natten/cuda/reference/fna_reference_forward.hpp b/csrc/include/natten/cuda/reference/fna_reference_forward.hpp index bf7c87e0..3747801c 100644 --- a/csrc/include/natten/cuda/reference/fna_reference_forward.hpp +++ b/csrc/include/natten/cuda/reference/fna_reference_forward.hpp @@ -138,20 +138,26 @@ void __global__ fna_reference_kernel( acc += eQ * eK; } - // (Optional) clip dot products -- MUST BE DONE PRIOR TO MASKING & - // SCALING. + acc = acc * attn_scale; + + // (Optional) clip dot products (mask off out of bound dot products) if (has_dot_product_min || has_dot_product_max) { if (not has_dot_product_max) { - acc = cutlass::fast_max(acc, dot_product_min); + acc = acc < dot_product_min ? -cutlass::platform::numeric_limits< + ElementAccumulator>::infinity() + : acc; } else if (not has_dot_product_min) { - acc = cutlass::fast_min(acc, dot_product_max); + acc = acc > dot_product_max ? -cutlass::platform::numeric_limits< + ElementAccumulator>::infinity() + : acc; } else { - acc = cutlass::fast_max( - cutlass::fast_min(acc, dot_product_max), dot_product_min); + acc = (acc < dot_product_min || acc > dot_product_max) + ? -cutlass::platform::numeric_limits< + ElementAccumulator>::infinity() + : acc; } } - acc = acc * attn_scale; auto frag = make_tensor(Shape<_1, _1>{}); frag(0) = acc; attention_mask.apply_mask( @@ -212,17 +218,17 @@ void __global__ fna_reference_kernel( __syncthreads(); } + ElementAccumulator scale = sum == 0.0f ? 0.0f : 1.0f / sum; for (int i = 0; i < DimPerThread; ++i) { int idx_D = threadIdx.x + i * blockDim.x; if (idx_D < size<1>(mO)) { - ElementAccumulator scale = 1.0f / sum; mO(idx_Q + offset_Q, idx_D, idx_L) = static_cast(final_acc[i] * scale); } } if (threadIdx.x == 0) { - mLSE(idx_Q + offset_Q, idx_L) = log(sum) + maxS; + mLSE(idx_Q + offset_Q, idx_L) = sum == 0.0f ? 0.0f : (log(sum) + maxS); } } }
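
Note: the sketch below is a minimal PyTorch reference for the semantics the two
patches implement, included for illustration only. It is not part of the patch,
and the function and argument names here are assumptions, not NATTEN API. It
scales dot products before clipping, masks out-of-range dot products to -inf
rather than clamping them (PATCH 2), and guards the fully-masked-row case the
same way the kernels above do: s_prime == 0 yields a zero output row, and
logsumexp is written as 0.0. Plain self-attention stands in for the
neighborhood mask.

import torch


def clipped_attention_reference(
    q, k, v, scale, dot_product_min=None, dot_product_max=None
):
    # q, k: [batch, heads, seq, head_dim]; v: [batch, heads, seq, head_dim_v].
    neg_inf = float("-inf")

    # Dot product scale is applied BEFORE clipping, matching PATCH 2.
    s = torch.einsum("bhqd,bhkd->bhqk", q, k) * scale

    # Mask (not clamp) out-of-range dot products to -inf.
    if dot_product_min is not None:
        s = s.masked_fill(s < dot_product_min, neg_inf)
    if dot_product_max is not None:
        s = s.masked_fill(s > dot_product_max, neg_inf)

    # Row max; -inf exactly when every dot product in the row was masked.
    m = s.amax(dim=-1, keepdim=True)
    fully_masked = m == neg_inf

    # Substitute m = 0 on fully masked rows so exp(s - m) is exp(-inf) = 0
    # instead of exp(-inf - (-inf)) = nan.
    e = torch.exp(s - m.masked_fill(fully_masked, 0.0))
    denom = e.sum(dim=-1, keepdim=True)  # s_prime; 0 on fully masked rows

    # Mirror `s_prime == 0 ? 0 : 1 / s_prime`: fully masked rows -> zero output.
    attn = e / denom.masked_fill(denom == 0.0, 1.0)
    out = attn @ v

    # Mirror the kernel's LSE guard: write 0.0 for fully masked rows.
    lse = (m + denom.log()).squeeze(-1)
    lse = lse.masked_fill(fully_masked.squeeze(-1), 0.0)
    return out, lse

With dot_product_min=-1.0 and dot_product_max=1.0 this mirrors the range
exercised by the new unit tests. Masking instead of clamping leaves the
attention distribution over surviving entries unchanged, at the cost of rows
that can become entirely masked, which is why the epilogue and logsumexp
guards above are needed.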