qLoRA support #398
Comments
The related issue in qLoRA:
My full stack trace:
We're not planning to write a custom kernel for 4-bit. Can you just cast the input (q, k, v) to fp16/bf16, call FlashAttention, then convert the output back to whichever dtype you need?
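A minimal sketch of that workaround, assuming flash-attn 2.x is installed; the function name and shapes here are illustrative, not code from this thread:

```python
import torch
from flash_attn import flash_attn_func  # assumes flash-attn 2.x

def attention_with_cast(q, k, v, causal=True):
    # q, k, v: (batch, seqlen, nheads, headdim), possibly fp32 coming out of
    # a dequantized 4-bit (qLoRA) linear layer.
    orig_dtype = q.dtype
    # FlashAttention kernels only accept fp16/bf16, so cast the inputs down...
    q, k, v = (t.to(torch.bfloat16) for t in (q, k, v))
    out = flash_attn_func(q, k, v, causal=causal)
    # ...and cast the output back to whatever dtype the rest of the model expects.
    return out.to(orig_dtype)
```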
Understood, thank you.
Hi, did you check the results of FlashAttention 2 with qLoRA 4-bit? I use the flash_attn_func function to compute GQA attention, and I cast the input (q, k, v) to bf16: query_states, key_states, value_states = query_states.to(torch.bfloat16), key_states.to(torch.bfloat16), value_states.to(torch.bfloat16). During training the loss is healthy, but the qLoRA adapter model file seems broken: when I load the adapter, the generation results are wrong.
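For reference, a hedged sketch of where the casts typically go in a GQA attention forward when combining FlashAttention with a 4-bit base model; the module and variable names are illustrative assumptions, not the commenter's actual code, and a common pitfall is forgetting to cast the attention output back to the original dtype before the output projection:

```python
import torch
from flash_attn import flash_attn_func  # assumes flash-attn 2.x

def gqa_attention_forward(query_states, key_states, value_states, o_proj, causal=True):
    # query_states: (batch, seqlen, n_heads, head_dim)
    # key_states / value_states: (batch, seqlen, n_kv_heads, head_dim);
    # flash-attn 2 supports n_kv_heads < n_heads (grouped-query attention).
    input_dtype = query_states.dtype
    q = query_states.to(torch.bfloat16)
    k = key_states.to(torch.bfloat16)
    v = value_states.to(torch.bfloat16)
    attn_output = flash_attn_func(q, k, v, causal=causal)
    # Cast back so the rest of the (LoRA-wrapped) model keeps its original precision.
    attn_output = attn_output.to(input_dtype)
    batch, seqlen, n_heads, head_dim = attn_output.shape
    return o_proj(attn_output.reshape(batch, seqlen, n_heads * head_dim))
```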
I tried adding FlashAttention to qLoRA and received the following error:
RuntimeError: FlashAttention only support fp16 and bf16 data type
Is it possible to add support for 4-bit qLoRA?