From e08b6d442ca883cf708be905aebfe8ede23ea372 Mon Sep 17 00:00:00 2001
From: Jay Shah
Date: Fri, 23 Aug 2024 15:57:50 -0700
Subject: [PATCH] add descaling factors to _flash_attn_forward

---
 hopper/flash_attn_interface.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py
index 13ddff4bb..0998dece3 100644
--- a/hopper/flash_attn_interface.py
+++ b/hopper/flash_attn_interface.py
@@ -14,7 +14,7 @@ def maybe_contiguous(x):
     return x.contiguous() if x is not None and x.stride(-1) != 1 else x
 
 
-def _flash_attn_forward(q, k, v, softmax_scale, causal):
+def _flash_attn_forward(q, k, v, softmax_scale, causal, descale_q = 1.0, descale_k = 1.0, descale_v = 1.0):
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
     out, q, k, v, out_padded, softmax_lse, S_dmask = flashattn_hopper_cuda.fwd(
         q,
@@ -22,6 +22,9 @@ def _flash_attn_forward(q, k, v, softmax_scale, causal):
         v,
         None,
         softmax_scale,
+        descale_q,
+        descale_k,
+        descale_v,
         causal,
     )
     return out, q, k, v, out_padded, softmax_lse, S_dmask
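
Usage note (not part of the patch): a minimal sketch of how a caller might
exercise the new parameters. It assumes the flashattn_hopper_cuda extension is
built and a Hopper GPU is available; the import path, tensor shapes, and
descale values below are illustrative, and the descale factors stand in for
the per-tensor dequantization scales an FP8 quantization step would normally
produce.

    import torch

    # Import path depends on how the hopper package is built/installed;
    # the module lives at hopper/flash_attn_interface.py in the repo.
    from flash_attn_interface import _flash_attn_forward

    # (batch, seqlen, nheads, headdim) layout, as elsewhere in the interface.
    q = torch.randn(2, 512, 8, 128, device="cuda", dtype=torch.float16)
    k = torch.randn(2, 512, 8, 128, device="cuda", dtype=torch.float16)
    v = torch.randn(2, 512, 8, 128, device="cuda", dtype=torch.float16)
    softmax_scale = q.shape[-1] ** (-0.5)

    # Existing callers are unchanged: the descale factors default to 1.0.
    out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_forward(
        q, k, v, softmax_scale, causal=False
    )

    # With FP8 inputs, a caller would pass per-tensor descale factors
    # (values here are placeholders, e.g. amax / dtype_max from quantization).
    out, *_ = _flash_attn_forward(
        q, k, v, softmax_scale, causal=True,
        descale_q=0.5, descale_k=0.5, descale_v=0.5,
    )

Keeping the defaults at 1.0 makes the change backward compatible: existing
fp16/bf16 call sites need no modification, while FP8 callers can thread their
quantization scales through to the kernel.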