add descaling factors to _flash_attn_forward
jayhshah committed Aug 23, 2024
1 parent c0ec763 commit e08b6d4
Showing 1 changed file with 4 additions and 1 deletion.

hopper/flash_attn_interface.py
@@ -14,14 +14,17 @@
 def maybe_contiguous(x):
     return x.contiguous() if x is not None and x.stride(-1) != 1 else x

-def _flash_attn_forward(q, k, v, softmax_scale, causal):
+def _flash_attn_forward(q, k, v, softmax_scale, causal, descale_q = 1.0, descale_k = 1.0, descale_v = 1.0):
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
     out, q, k, v, out_padded, softmax_lse, S_dmask = flashattn_hopper_cuda.fwd(
         q,
         k,
         v,
         None,
         softmax_scale,
+        descale_q,
+        descale_k,
+        descale_v,
         causal,
     )
     return out, q, k, v, out_padded, softmax_lse, S_dmask
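
For context, a hypothetical usage sketch (not part of the commit): it assumes the Hopper extension is built so that flash_attn_interface imports on a CUDA machine, and it uses a toy per-tensor FP8 quantization convention in which x ≈ x_fp8 * scale, so the descale factor handed to the kernel is the quantization scale itself. The tensor shapes, the quantize_fp8 helper, and the FP8 dtype choice are all illustrative, not taken from the repository.

import torch
from flash_attn_interface import _flash_attn_forward  # hopper/ build of flash-attention

# Toy per-tensor FP8 quantization: x_fp8 = x / scale, so x ~= x_fp8 * scale
# and the "descale" factor the kernel needs is scale itself. (Hypothetical
# helper; the real quantization scheme lives outside this commit.)
def quantize_fp8(x):
    scale = x.abs().max().clamp(min=1e-12) / torch.finfo(torch.float8_e4m3fn).max
    return (x / scale).to(torch.float8_e4m3fn), scale.item()

batch, seqlen, nheads, headdim = 2, 1024, 16, 128
softmax_scale = headdim ** -0.5

q_hp = torch.randn(batch, seqlen, nheads, headdim, device="cuda")
k_hp = torch.randn(batch, seqlen, nheads, headdim, device="cuda")
v_hp = torch.randn(batch, seqlen, nheads, headdim, device="cuda")

q, descale_q = quantize_fp8(q_hp)
k, descale_k = quantize_fp8(k_hp)
v, descale_v = quantize_fp8(v_hp)

out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_forward(
    q, k, v, softmax_scale, causal=False,
    descale_q=descale_q, descale_k=descale_k, descale_v=descale_v,
)

Note that the defaults of 1.0 keep the change backward compatible: callers that do not quantize their inputs can ignore the new parameters, since descaling by 1.0 is a no-op.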
