FlashAttention support? #221
Comments
Hey! Flash attention is orthogonal to QLoRA, meaning that you can combine the two. In fact, we had implemented it for LLaMA at some point but didn't end up keeping it. More generally, if your base model uses flash attention you can use it with QLoRA. I would look for models that implement flash attention or implement it for your favorite base model and then finetune with QLoRA.
Thanks for replying! Great to learn that there are no inherent issues preventing FlashAttention from being combined with QLoRA. With the latest FlashAttention2 promising even further performance improvements, and given that the memory pressure of standard attention rapidly increases with context size, having built-in support even just for Llama (and now Llama2, which has a native 4k context size) would be great. While using QLoRA for finetuning is pretty straightforward, adapting it to take advantage of FlashAttention is not so obvious, considering its low-level nature. There is a so-called monkey patch available for LLaMA, but I've not personally been successful in applying it to the QLoRA code. So, if there is already existing code demonstrating how to use it in practice with QLoRA, it would certainly benefit many if it could be uploaded to the repository.
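For reference, here is a minimal sketch of what such a monkey patch looks like, in the spirit of the FastChat/Axolotl patches. It assumes flash-attn 2.x and a transformers version around 4.31 whose LlamaAttention exposes q_proj/k_proj/v_proj, rotary_emb, and o_proj; it ignores the KV cache, padding masks, and the grouped-query attention used by Llama2 70B, so treat it as a starting point for causal fine-tuning rather than a drop-in implementation:

    # Sketch of a LLaMA flash-attention monkey patch (illustrative, not the repo's code)
    import transformers
    from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
    from flash_attn import flash_attn_func


    def _flash_forward(self, hidden_states, attention_mask=None, position_ids=None,
                       past_key_value=None, output_attentions=False, use_cache=False):
        bsz, q_len, _ = hidden_states.size()

        # Standard LLaMA projections, shaped (bsz, num_heads, q_len, head_dim)
        q = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Rotary position embeddings, as in the stock implementation
        cos, sin = self.rotary_emb(v, seq_len=q_len)
        q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids)

        # flash_attn_func expects (bsz, q_len, num_heads, head_dim) tensors in fp16/bf16
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))
        attn_output = flash_attn_func(q, k, v, causal=True)

        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
        return self.o_proj(attn_output), None, None


    def replace_llama_attn_with_flash_attn():
        # Must run before the model is instantiated so every layer picks up the patch
        transformers.models.llama.modeling_llama.LlamaAttention.forward = _flash_forward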
That's a good point! I agree we should look into this. If someone wants to contribute an example in the meantime we would appreciate the help.
Perhaps some of the code from Axolotl could be used. It's a trainer which employs QLoRA and different attention mechanisms, including FlashAttention. I haven't been able to make FlashAttention work with it yet, but with xformers attention (another supported method) I could train Llama-13B with 4096-token sequences in less than 16 GB of VRAM (at a batch size of 1), which is almost unbelievable. On the other hand, training speed did not appear to increase. It's likely that FlashAttention would yield similar or better benefits and make 30B-class LLMs trainable on a 24 GB GPU with long sequences.
We would also be very happy to see FlashAttention 2 support added to this tool.
I gave this a shot, implementing flash attention in the same way that FastChat and Axolotl do, but it doesn't seem to work. I was wondering if anyone more familiar with CUDA could understand what is going wrong. The error message is:
Same boat here. I tried testing both versions of flash attention individually using the monkey-patching code from FastChat (https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py), but got stuck on a similar error to the one @ehartford reported.
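For anyone reproducing this, the FastChat patch linked above is applied roughly as follows (a sketch; it assumes the fastchat package is importable and, as noted in the later comments, the call has to happen before the model is constructed):

    # Apply the FastChat monkey patch first, then load the model
    from fastchat.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn

    replace_llama_attn_with_flash_attn()
    # ...only now call AutoModelForCausalLM.from_pretrained(...)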
Jon Durbin has a fork here that implements flash attention. I'll see if I can adapt the code into a PR.
Yup, I tried that already; it throws the same error for Llama2 70B. It's the same monkey patch code from FastChat that I tried to integrate.
Here is the full stack trace:
I got around that problem by moving the patch code to an earlier point before loading the model. But I hit another error:
It seems, then, that Flash Attention does not support 4-bit, only 16-bit.
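A sketch of the dtype side of this, assuming the underlying constraint is that the flash-attn kernels only accept fp16/bf16 tensors: with QLoRA the 4-bit weights are dequantized at compute time, so setting the 4-bit compute dtype to bf16 keeps the activations reaching attention in a supported dtype (the model id is a placeholder):

    import torch
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,  # the dtype the attention kernel actually sees
    )

    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",  # placeholder model id
        quantization_config=bnb_config,
        device_map="auto",
    )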
You can also gate the patch on the GPU, since these flash-attention kernels require an Ampere-or-newer card (compute capability 8.0+):

    # Only patch on GPUs that flash-attn supports (Ampere or newer)
    if torch.cuda.get_device_capability()[0] >= 8:
        from utils.llama_patch import replace_attn_with_flash_attn
        replace_attn_with_flash_attn()
I had both of these issues. I got training working by running the attention patch before loading the model and the remaining patch step before the training loop. Check the link below to get the llama_patch file used. Credits to philschmid.
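Putting the pieces of this thread together, the ordering that seems to work is roughly the sketch below. The utils.llama_patch import path and helper name are assumptions based on the file mentioned above; the model id and LoRA hyperparameters are placeholders:

    import torch
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig
    from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

    # 1) Patch the attention implementation *before* the model is instantiated.
    from utils.llama_patch import replace_attn_with_flash_attn  # assumed path, see link
    replace_attn_with_flash_attn()

    # 2) Load the base model in 4-bit with a bf16 compute dtype (as in the earlier sketch).
    bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4",
                                    bnb_4bit_compute_dtype=torch.bfloat16)
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",  # placeholder model id
        quantization_config=bnb_config,
        device_map="auto",
    )

    # 3) Prepare for k-bit training and attach the LoRA adapters.
    model = prepare_model_for_kbit_training(model)
    model = get_peft_model(model, LoraConfig(
        task_type="CAUSAL_LM", r=64, lora_alpha=16, lora_dropout=0.1,
        target_modules=["q_proj", "v_proj"],
    ))

    # 4) Before the training loop, cast any modules left in fp32 (e.g. the RMSNorm
    #    layers) to bf16 so flash-attn never receives fp32 activations.
    for module in model.modules():
        if module.__class__.__name__ == "LlamaRMSNorm":
            module.to(torch.bfloat16)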
This might be more of a general question, but is it possible to use FlashAttention with QLoRA in order to further decrease memory requirements when finetuning?
I would guess that in principle it could be done, but has anybody actually attempted implementing it?