diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py index 13ddff4bb..0998dece3 100644 --- a/hopper/flash_attn_interface.py +++ b/hopper/flash_attn_interface.py @@ -14,7 +14,7 @@ def maybe_contiguous(x): return x.contiguous() if x is not None and x.stride(-1) != 1 else x -def _flash_attn_forward(q, k, v, softmax_scale, causal): +def _flash_attn_forward(q, k, v, softmax_scale, causal, descale_q = 1.0, descale_k = 1.0, descale_v = 1.0): q, k, v = [maybe_contiguous(x) for x in (q, k, v)] out, q, k, v, out_padded, softmax_lse, S_dmask = flashattn_hopper_cuda.fwd( q, @@ -22,6 +22,9 @@ def _flash_attn_forward(q, k, v, softmax_scale, causal): v, None, softmax_scale, + descale_q, + descale_k, + descale_v, causal, ) return out, q, k, v, out_padded, softmax_lse, S_dmask