@@ -74,7 +74,11 @@ void __global__ fna_bwd_reference_dQ_kernel(
7474 Causal is_causal,
7575 QKVLayout qkv_layout,
7676 float attn_scale,
77- int num_additional_kv) {
77+ int num_additional_kv,
78+ bool has_dot_product_min,
79+ bool has_dot_product_max,
80+ float dot_product_min,
81+ float dot_product_max) {
7882 using namespace cute ;
7983
8084 auto attention_mask =
@@ -111,6 +115,20 @@ void __global__ fna_bwd_reference_dQ_kernel(
111115 acc_dov += mDO (idx_Q, idx_D1, idx_L) * mV (idx_K, idx_D1, idx_L);
112116 acc_doo += mDO (idx_Q, idx_D1, idx_L) * mO (idx_Q, idx_D1, idx_L);
113117 } // for idx_D1
118+
119+ // (Optional) clip dot products -- MUST BE DONE PRIOR TO MASKING &
120+ // SCALING.
121+ if (has_dot_product_min || has_dot_product_max) {
122+ if (not has_dot_product_max) {
123+ acc_qk = cutlass::fast_max (acc_qk, dot_product_min);
124+ } else if (not has_dot_product_min) {
125+ acc_qk = cutlass::fast_min (acc_qk, dot_product_max);
126+ } else {
127+ acc_qk = cutlass::fast_max (
128+ cutlass::fast_min (acc_qk, dot_product_max), dot_product_min);
129+ }
130+ }
131+
114132 acc_qk *= attn_scale;
115133 acc_dov *= attn_scale;
116134 acc_doo *= attn_scale;
@@ -186,7 +204,11 @@ void __global__ fna_bwd_reference_dK_kernel(
186204 Causal is_causal,
187205 QKVLayout qkv_layout,
188206 float attn_scale,
189- int num_additional_kv) {
207+ int num_additional_kv,
208+ bool has_dot_product_min,
209+ bool has_dot_product_max,
210+ float dot_product_min,
211+ float dot_product_max) {
190212 using namespace cute ;
191213
192214 auto attention_mask =
@@ -223,6 +245,20 @@ void __global__ fna_bwd_reference_dK_kernel(
223245 acc_dov += mDO (idx_Q, idx_D1, idx_L) * mV (idx_K, idx_D1, idx_L);
224246 acc_doo += mDO (idx_Q, idx_D1, idx_L) * mO (idx_Q, idx_D1, idx_L);
225247 } // for idx_D1
248+
249+ // (Optional) clip dot products -- MUST BE DONE PRIOR TO MASKING &
250+ // SCALING.
251+ if (has_dot_product_min || has_dot_product_max) {
252+ if (not has_dot_product_max) {
253+ acc_qk = cutlass::fast_max (acc_qk, dot_product_min);
254+ } else if (not has_dot_product_min) {
255+ acc_qk = cutlass::fast_min (acc_qk, dot_product_max);
256+ } else {
257+ acc_qk = cutlass::fast_max (
258+ cutlass::fast_min (acc_qk, dot_product_max), dot_product_min);
259+ }
260+ }
261+
226262 acc_qk *= attn_scale;
227263 acc_dov *= attn_scale;
228264 acc_doo *= attn_scale;
@@ -299,7 +335,11 @@ void __global__ fna_bwd_reference_dV_kernel(
299335 Causal is_causal,
300336 QKVLayout qkv_layout,
301337 float attn_scale,
302- int num_additional_kv) {
338+ int num_additional_kv,
339+ bool has_dot_product_min,
340+ bool has_dot_product_max,
341+ float dot_product_min,
342+ float dot_product_max) {
303343 using namespace cute ;
304344
305345 auto attention_mask =
@@ -333,6 +373,20 @@ void __global__ fna_bwd_reference_dV_kernel(
333373 ElementAccumulator rK = mK (idx_K, idx_D0, idx_L);
334374 acc_qk += rQ * rK;
335375 } // for idx_D0
376+
377+ // (Optional) clip dot products -- MUST BE DONE PRIOR TO MASKING &
378+ // SCALING.
379+ if (has_dot_product_min || has_dot_product_max) {
380+ if (not has_dot_product_max) {
381+ acc_qk = cutlass::fast_max (acc_qk, dot_product_min);
382+ } else if (not has_dot_product_min) {
383+ acc_qk = cutlass::fast_min (acc_qk, dot_product_max);
384+ } else {
385+ acc_qk = cutlass::fast_max (
386+ cutlass::fast_min (acc_qk, dot_product_max), dot_product_min);
387+ }
388+ }
389+
336390 acc_qk *= attn_scale;
337391
338392 auto id = make_identity_tensor (make_shape (1 , 1 ));
@@ -408,7 +462,11 @@ void fna_bwd_reference_dQ(
408462 QKVLayout qkv_layout,
409463 float attn_scale,
410464 int num_additional_kv,
411- cudaStream_t stream) {
465+ cudaStream_t stream,
466+ bool has_dot_product_min,
467+ bool has_dot_product_max,
468+ float dot_product_min,
469+ float dot_product_max) {
412470 using namespace cute ;
413471
414472 // Only so that we don't oversubscribe shmem when seqlen is large.
@@ -447,7 +505,11 @@ void fna_bwd_reference_dQ(
447505 Causal{},
448506 qkv_layout,
449507 attn_scale,
450- num_additional_kv);
508+ num_additional_kv,
509+ has_dot_product_min,
510+ has_dot_product_max,
511+ dot_product_min,
512+ dot_product_max);
451513}
452514
453515// ///////////////////////////////////////////////////////////////////////////////////////////////
@@ -480,7 +542,11 @@ void fna_bwd_reference_dK(
480542 QKVLayout qkv_layout,
481543 float attn_scale,
482544 int num_additional_kv,
483- cudaStream_t stream) {
545+ cudaStream_t stream,
546+ bool has_dot_product_min,
547+ bool has_dot_product_max,
548+ float dot_product_min,
549+ float dot_product_max) {
484550 using namespace cute ;
485551
486552 // Only so that we don't oversubscribe shmem when seqlen is large.
@@ -519,7 +585,11 @@ void fna_bwd_reference_dK(
519585 Causal{},
520586 qkv_layout,
521587 attn_scale,
522- num_additional_kv);
588+ num_additional_kv,
589+ has_dot_product_min,
590+ has_dot_product_max,
591+ dot_product_min,
592+ dot_product_max);
523593}
524594
525595// ///////////////////////////////////////////////////////////////////////////////////////////////
@@ -552,7 +622,11 @@ void fna_bwd_reference_dV(
552622 QKVLayout qkv_layout,
553623 float attn_scale,
554624 int num_additional_kv,
555- cudaStream_t stream) {
625+ cudaStream_t stream,
626+ bool has_dot_product_min,
627+ bool has_dot_product_max,
628+ float dot_product_min,
629+ float dot_product_max) {
556630 using namespace cute ;
557631
558632 // Only so that we don't oversubscribe shmem when seqlen is large.
@@ -591,7 +665,11 @@ void fna_bwd_reference_dV(
591665 Causal{},
592666 qkv_layout,
593667 attn_scale,
594- num_additional_kv);
668+ num_additional_kv,
669+ has_dot_product_min,
670+ has_dot_product_max,
671+ dot_product_min,
672+ dot_product_max);
595673}
596674
597675// ///////////////////////////////////////////////////////////////////////////////////////////////
@@ -619,7 +697,11 @@ void fna_reference_backward(
619697 NADim dilation,
620698 Causal is_causal,
621699 float attn_scale,
622- cudaStream_t stream) {
700+ cudaStream_t stream,
701+ bool has_dot_product_min,
702+ bool has_dot_product_max,
703+ float dot_product_min,
704+ float dot_product_max) {
623705 using namespace cute ;
624706
625707 // No GQA/MQA for now
@@ -694,7 +776,11 @@ void fna_reference_backward(
694776 qkv_layout,
695777 attn_scale,
696778 num_additional_kv,
697- stream);
779+ stream,
780+ has_dot_product_min,
781+ has_dot_product_max,
782+ dot_product_min,
783+ dot_product_max);
698784 fna_bwd_reference_dK (
699785 problem_shape,
700786 mQ ,
@@ -711,7 +797,11 @@ void fna_reference_backward(
711797 qkv_layout,
712798 attn_scale,
713799 num_additional_kv,
714- stream);
800+ stream,
801+ has_dot_product_min,
802+ has_dot_product_max,
803+ dot_product_min,
804+ dot_product_max);
715805 fna_bwd_reference_dV (
716806 problem_shape,
717807 mQ ,
@@ -728,7 +818,11 @@ void fna_reference_backward(
728818 qkv_layout,
729819 attn_scale,
730820 num_additional_kv,
731- stream);
821+ stream,
822+ has_dot_product_min,
823+ has_dot_product_max,
824+ dot_product_min,
825+ dot_product_max);
732826}
733827
734828} // namespace reference
0 commit comments