Skip to content

Commit 28afee6

Browse files
committed
Feat: Clip dot products
Ampere FNA only for now. Adds optional clipping of dot products to a user-specified floating-point range (min, max). - Issue: #249
1 parent aa73afc commit 28afee6

File tree

25 files changed

+1529
-337
lines changed

25 files changed

+1529
-337
lines changed

csrc/autogen/include/natten_autogen/cuda/reference/kernels.h

Lines changed: 252 additions & 84 deletions
Large diffs are not rendered by default.

csrc/autogen/src/cuda/reference/source_0.cu

Lines changed: 126 additions & 42 deletions
Large diffs are not rendered by default.

csrc/autogen/src/cuda/reference/source_1.cu

Lines changed: 126 additions & 42 deletions
Large diffs are not rendered by default.

csrc/autogen/src/cuda/reference/source_2.cu

Lines changed: 126 additions & 42 deletions
Large diffs are not rendered by default.

csrc/autogen/src/cuda/reference/source_3.cu

Lines changed: 126 additions & 42 deletions
Large diffs are not rendered by default.

csrc/include/natten/cuda/fna/fna_backward.cuh

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,11 @@ void fna_backward_generic(
8989
float attn_scale,
9090
IntTuple query_tile_shape,
9191
IntTuple key_tile_shape,
92-
IntTuple num_splits_key) {
92+
IntTuple num_splits_key,
93+
bool has_dot_product_min,
94+
bool has_dot_product_max,
95+
float dot_product_min,
96+
float dot_product_max) {
9397
static constexpr auto kRank =
9498
std::tuple_size<decltype(spatial_extent)>::value;
9599
using Dim = typename GetDim<kRank>::type;
@@ -157,6 +161,18 @@ void fna_backward_generic(
157161

158162
p.num_splits_key = tuple_to_na_dim<Dim>(num_splits_key);
159163

164+
// Optional dot product clipping
165+
p.has_dot_product_clip = has_dot_product_min || has_dot_product_max;
166+
p.has_dot_product_min = has_dot_product_min;
167+
p.has_dot_product_max = has_dot_product_max;
168+
if (has_dot_product_min) {
169+
p.dot_product_min = dot_product_min;
170+
}
171+
if (has_dot_product_max) {
172+
p.dot_product_max = dot_product_max;
173+
}
174+
//
175+
160176
int64_t size_bytes = p.workspace_size();
161177
if (size_bytes) {
162178
void* workspace_ptr = nullptr;

csrc/include/natten/cuda/fna/fna_forward.cuh

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,11 @@ void fna_forward_generic(
7979
float attn_scale,
8080
void* logsumexp_ptr,
8181
IntTuple query_tile_shape,
82-
IntTuple key_tile_shape) {
82+
IntTuple key_tile_shape,
83+
bool has_dot_product_min,
84+
bool has_dot_product_max,
85+
float dot_product_min,
86+
float dot_product_max) {
8387
static constexpr auto kRank =
8488
std::tuple_size<decltype(spatial_extent)>::value;
8589
using Dim = typename GetDim<kRank>::type;
@@ -167,6 +171,18 @@ void fna_forward_generic(
167171
p.query_tile_shape = tuple_to_na_dim<Dim>(query_tile_shape);
168172
p.key_tile_shape = tuple_to_na_dim<Dim>(key_tile_shape);
169173

174+
// Optional dot product clipping
175+
p.has_dot_product_clip = has_dot_product_min || has_dot_product_max;
176+
p.has_dot_product_min = has_dot_product_min;
177+
p.has_dot_product_max = has_dot_product_max;
178+
if (has_dot_product_min) {
179+
p.dot_product_min = dot_product_min;
180+
}
181+
if (has_dot_product_max) {
182+
p.dot_product_max = dot_product_max;
183+
}
184+
//
185+
170186
if (smem_bytes > 0xc000) {
171187
auto err = cudaFuncSetAttribute(
172188
kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);

csrc/include/natten/cuda/fna/kernel_backward.h

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -692,6 +692,17 @@ struct FusedNeighborhoodAttentionBackwardKernel {
692692
bool is_fully_block_sparse = false;
693693
bool has_q_padding = false;
694694

695+
// Optional dot product clipping -- all must be set explicitly for avoiding
696+
// comparisons.
697+
bool has_dot_product_clip = false;
698+
bool has_dot_product_min = false;
699+
bool has_dot_product_max = false;
700+
accum_t dot_product_min =
701+
-cutlass::platform::numeric_limits<accum_t>::infinity();
702+
accum_t dot_product_max =
703+
cutlass::platform::numeric_limits<accum_t>::infinity();
704+
//
705+
695706
// Dimensions/strides
696707
int32_t head_dim = -1;
697708
int32_t head_dim_value = -1;
@@ -1615,7 +1626,6 @@ struct FusedNeighborhoodAttentionBackwardKernel {
16151626
mma.set_prologue_done(kPrologueQK);
16161627
mma.set_zero_outside_bounds(/*!skipBoundsChecks*/ true);
16171628
mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum);
1618-
accum = cutlass::multiplies<typename Mma::FragmentC>()(scale, accum);
16191629

16201630
// Epilogue: add LSE + exp and store that to our shared memory buffer
16211631
// shmem <- (matmul_result -
@@ -1629,6 +1639,30 @@ struct FusedNeighborhoodAttentionBackwardKernel {
16291639
auto lane_offset = MatmulQK::AccumLambdaIterator::get_lane_offset(
16301640
lane_id, warp_id, output_tile_coords);
16311641

1642+
// (Optional) clip dot products -- MUST BE DONE PRIOR TO MASKING &
1643+
// SCALING.
1644+
if (p.has_dot_product_clip) {
1645+
if (not p.has_dot_product_max) {
1646+
for (int i = 0; i < Mma::FragmentC::kElements; ++i) {
1647+
accum[i] = cutlass::fast_max(accum[i], p.dot_product_min);
1648+
}
1649+
} else if (not p.has_dot_product_min) {
1650+
for (int i = 0; i < Mma::FragmentC::kElements; ++i) {
1651+
accum[i] = cutlass::fast_min(accum[i], p.dot_product_max);
1652+
}
1653+
} else {
1654+
// assert(p.has_dot_product_min && p.has_dot_product_max);
1655+
for (int i = 0; i < Mma::FragmentC::kElements; ++i) {
1656+
accum[i] = cutlass::fast_max(
1657+
cutlass::fast_min(accum[i], p.dot_product_max),
1658+
p.dot_product_min);
1659+
}
1660+
}
1661+
}
1662+
1663+
// Dot product scale
1664+
accum = cutlass::multiplies<typename Mma::FragmentC>()(scale, accum);
1665+
16321666
if (not p.is_fully_block_sparse) {
16331667
// Neighborhood Attention masking
16341668
Dim first_col, query_bound, row_idx;

csrc/include/natten/cuda/fna/kernel_forward.h

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,17 @@ struct FusedNeighborhoodAttentionKernel {
200200
bool is_fully_block_sparse = false;
201201
bool has_kv_padding = false;
202202

203+
// Optional dot product clipping -- all must be set explicitly for avoiding
204+
// comparisons.
205+
bool has_dot_product_clip = false;
206+
bool has_dot_product_min = false;
207+
bool has_dot_product_max = false;
208+
accum_t dot_product_min =
209+
-cutlass::platform::numeric_limits<accum_t>::infinity();
210+
accum_t dot_product_max =
211+
cutlass::platform::numeric_limits<accum_t>::infinity();
212+
//
213+
203214
// Moves pointers to what we should process
204215
// Returns "false" if there is no work to do
205216
CUTLASS_DEVICE bool advance_to_block() {
@@ -734,6 +745,27 @@ struct FusedNeighborhoodAttentionKernel {
734745
MM1::Mma::drain_cp_asyncs();
735746
}
736747

748+
// (Optional) clip dot products -- MUST BE DONE PRIOR TO MASKING &
749+
// SCALING.
750+
if (p.has_dot_product_clip) {
751+
if (not p.has_dot_product_max) {
752+
for (int i = 0; i < MM0::Mma::FragmentC::kElements; ++i) {
753+
accum[i] = cutlass::fast_max(accum[i], p.dot_product_min);
754+
}
755+
} else if (not p.has_dot_product_min) {
756+
for (int i = 0; i < MM0::Mma::FragmentC::kElements; ++i) {
757+
accum[i] = cutlass::fast_min(accum[i], p.dot_product_max);
758+
}
759+
} else {
760+
// assert(p.has_dot_product_min && p.has_dot_product_max);
761+
for (int i = 0; i < MM0::Mma::FragmentC::kElements; ++i) {
762+
accum[i] = cutlass::fast_max(
763+
cutlass::fast_min(accum[i], p.dot_product_max),
764+
p.dot_product_min);
765+
}
766+
}
767+
}
768+
737769
if (not p.is_fully_block_sparse) {
738770
// Neighborhood Attention masking
739771
Dim first_col, key_bound, row_idx;

csrc/include/natten/cuda/reference/fna_reference_backward.hpp

Lines changed: 107 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,11 @@ void __global__ fna_bwd_reference_dQ_kernel(
7474
Causal is_causal,
7575
QKVLayout qkv_layout,
7676
float attn_scale,
77-
int num_additional_kv) {
77+
int num_additional_kv,
78+
bool has_dot_product_min,
79+
bool has_dot_product_max,
80+
float dot_product_min,
81+
float dot_product_max) {
7882
using namespace cute;
7983

8084
auto attention_mask =
@@ -111,6 +115,20 @@ void __global__ fna_bwd_reference_dQ_kernel(
111115
acc_dov += mDO(idx_Q, idx_D1, idx_L) * mV(idx_K, idx_D1, idx_L);
112116
acc_doo += mDO(idx_Q, idx_D1, idx_L) * mO(idx_Q, idx_D1, idx_L);
113117
} // for idx_D1
118+
119+
// (Optional) clip dot products -- MUST BE DONE PRIOR TO MASKING &
120+
// SCALING.
121+
if (has_dot_product_min || has_dot_product_max) {
122+
if (not has_dot_product_max) {
123+
acc_qk = cutlass::fast_max(acc_qk, dot_product_min);
124+
} else if (not has_dot_product_min) {
125+
acc_qk = cutlass::fast_min(acc_qk, dot_product_max);
126+
} else {
127+
acc_qk = cutlass::fast_max(
128+
cutlass::fast_min(acc_qk, dot_product_max), dot_product_min);
129+
}
130+
}
131+
114132
acc_qk *= attn_scale;
115133
acc_dov *= attn_scale;
116134
acc_doo *= attn_scale;
@@ -186,7 +204,11 @@ void __global__ fna_bwd_reference_dK_kernel(
186204
Causal is_causal,
187205
QKVLayout qkv_layout,
188206
float attn_scale,
189-
int num_additional_kv) {
207+
int num_additional_kv,
208+
bool has_dot_product_min,
209+
bool has_dot_product_max,
210+
float dot_product_min,
211+
float dot_product_max) {
190212
using namespace cute;
191213

192214
auto attention_mask =
@@ -223,6 +245,20 @@ void __global__ fna_bwd_reference_dK_kernel(
223245
acc_dov += mDO(idx_Q, idx_D1, idx_L) * mV(idx_K, idx_D1, idx_L);
224246
acc_doo += mDO(idx_Q, idx_D1, idx_L) * mO(idx_Q, idx_D1, idx_L);
225247
} // for idx_D1
248+
249+
// (Optional) clip dot products -- MUST BE DONE PRIOR TO MASKING &
250+
// SCALING.
251+
if (has_dot_product_min || has_dot_product_max) {
252+
if (not has_dot_product_max) {
253+
acc_qk = cutlass::fast_max(acc_qk, dot_product_min);
254+
} else if (not has_dot_product_min) {
255+
acc_qk = cutlass::fast_min(acc_qk, dot_product_max);
256+
} else {
257+
acc_qk = cutlass::fast_max(
258+
cutlass::fast_min(acc_qk, dot_product_max), dot_product_min);
259+
}
260+
}
261+
226262
acc_qk *= attn_scale;
227263
acc_dov *= attn_scale;
228264
acc_doo *= attn_scale;
@@ -299,7 +335,11 @@ void __global__ fna_bwd_reference_dV_kernel(
299335
Causal is_causal,
300336
QKVLayout qkv_layout,
301337
float attn_scale,
302-
int num_additional_kv) {
338+
int num_additional_kv,
339+
bool has_dot_product_min,
340+
bool has_dot_product_max,
341+
float dot_product_min,
342+
float dot_product_max) {
303343
using namespace cute;
304344

305345
auto attention_mask =
@@ -333,6 +373,20 @@ void __global__ fna_bwd_reference_dV_kernel(
333373
ElementAccumulator rK = mK(idx_K, idx_D0, idx_L);
334374
acc_qk += rQ * rK;
335375
} // for idx_D0
376+
377+
// (Optional) clip dot products -- MUST BE DONE PRIOR TO MASKING &
378+
// SCALING.
379+
if (has_dot_product_min || has_dot_product_max) {
380+
if (not has_dot_product_max) {
381+
acc_qk = cutlass::fast_max(acc_qk, dot_product_min);
382+
} else if (not has_dot_product_min) {
383+
acc_qk = cutlass::fast_min(acc_qk, dot_product_max);
384+
} else {
385+
acc_qk = cutlass::fast_max(
386+
cutlass::fast_min(acc_qk, dot_product_max), dot_product_min);
387+
}
388+
}
389+
336390
acc_qk *= attn_scale;
337391

338392
auto id = make_identity_tensor(make_shape(1, 1));
@@ -408,7 +462,11 @@ void fna_bwd_reference_dQ(
408462
QKVLayout qkv_layout,
409463
float attn_scale,
410464
int num_additional_kv,
411-
cudaStream_t stream) {
465+
cudaStream_t stream,
466+
bool has_dot_product_min,
467+
bool has_dot_product_max,
468+
float dot_product_min,
469+
float dot_product_max) {
412470
using namespace cute;
413471

414472
// Only so that we don't oversubscribe shmem when seqlen is large.
@@ -447,7 +505,11 @@ void fna_bwd_reference_dQ(
447505
Causal{},
448506
qkv_layout,
449507
attn_scale,
450-
num_additional_kv);
508+
num_additional_kv,
509+
has_dot_product_min,
510+
has_dot_product_max,
511+
dot_product_min,
512+
dot_product_max);
451513
}
452514

453515
/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -480,7 +542,11 @@ void fna_bwd_reference_dK(
480542
QKVLayout qkv_layout,
481543
float attn_scale,
482544
int num_additional_kv,
483-
cudaStream_t stream) {
545+
cudaStream_t stream,
546+
bool has_dot_product_min,
547+
bool has_dot_product_max,
548+
float dot_product_min,
549+
float dot_product_max) {
484550
using namespace cute;
485551

486552
// Only so that we don't oversubscribe shmem when seqlen is large.
@@ -519,7 +585,11 @@ void fna_bwd_reference_dK(
519585
Causal{},
520586
qkv_layout,
521587
attn_scale,
522-
num_additional_kv);
588+
num_additional_kv,
589+
has_dot_product_min,
590+
has_dot_product_max,
591+
dot_product_min,
592+
dot_product_max);
523593
}
524594

525595
/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -552,7 +622,11 @@ void fna_bwd_reference_dV(
552622
QKVLayout qkv_layout,
553623
float attn_scale,
554624
int num_additional_kv,
555-
cudaStream_t stream) {
625+
cudaStream_t stream,
626+
bool has_dot_product_min,
627+
bool has_dot_product_max,
628+
float dot_product_min,
629+
float dot_product_max) {
556630
using namespace cute;
557631

558632
// Only so that we don't oversubscribe shmem when seqlen is large.
@@ -591,7 +665,11 @@ void fna_bwd_reference_dV(
591665
Causal{},
592666
qkv_layout,
593667
attn_scale,
594-
num_additional_kv);
668+
num_additional_kv,
669+
has_dot_product_min,
670+
has_dot_product_max,
671+
dot_product_min,
672+
dot_product_max);
595673
}
596674

597675
/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -619,7 +697,11 @@ void fna_reference_backward(
619697
NADim dilation,
620698
Causal is_causal,
621699
float attn_scale,
622-
cudaStream_t stream) {
700+
cudaStream_t stream,
701+
bool has_dot_product_min,
702+
bool has_dot_product_max,
703+
float dot_product_min,
704+
float dot_product_max) {
623705
using namespace cute;
624706

625707
// No GQA/MQA for now
@@ -694,7 +776,11 @@ void fna_reference_backward(
694776
qkv_layout,
695777
attn_scale,
696778
num_additional_kv,
697-
stream);
779+
stream,
780+
has_dot_product_min,
781+
has_dot_product_max,
782+
dot_product_min,
783+
dot_product_max);
698784
fna_bwd_reference_dK(
699785
problem_shape,
700786
mQ,
@@ -711,7 +797,11 @@ void fna_reference_backward(
711797
qkv_layout,
712798
attn_scale,
713799
num_additional_kv,
714-
stream);
800+
stream,
801+
has_dot_product_min,
802+
has_dot_product_max,
803+
dot_product_min,
804+
dot_product_max);
715805
fna_bwd_reference_dV(
716806
problem_shape,
717807
mQ,
@@ -728,7 +818,11 @@ void fna_reference_backward(
728818
qkv_layout,
729819
attn_scale,
730820
num_additional_kv,
731-
stream);
821+
stream,
822+
has_dot_product_min,
823+
has_dot_product_max,
824+
dot_product_min,
825+
dot_product_max);
732826
}
733827

734828
} // namespace reference

0 commit comments

Comments (0)