Implement kv cache #74
Conversation
Force-pushed from 61f3b16 to 75fb8cd
Other than the one comment about adding comments to _update_cos_sin_tables, this looks good to me! Would it be possible to document the speedup we get with the KV cache, compared to, say, HF LLaMA2? It doesn't need to be super thorough, just a quick gut check that we have roughly the right speedup.
Oh, and finally, could you paste the output of pytest so we can merge? (I believe the dataloading tests currently fail due to a missing assets/ folder that we're working on adding, but the rest should pass.)
There were bugs, but most didn't come from my changes:
I added a test for the KV cache, but it is slow, so I also added a slow marker in pyproject.toml; slow tests can be marked with `@pytest.mark.slow`.
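For reference, this is the standard pytest marker pattern (the test name and body here are illustrative, not the PR's actual test):

```python
# Illustrative sketch of the pytest "slow" marker convention.
# The marker is registered in pyproject.toml under
# [tool.pytest.ini_options], e.g.:
#   markers = ["slow: marks tests as slow", "gpu: requires a GPU"]
import pytest

@pytest.mark.slow
def test_kv_cache_generation():
    ...  # expensive end-to-end generation check goes here
```

Slow tests can then be deselected with `pytest -m "not slow"` or run on their own with `pytest -m slow`.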
With the attached script for a 1B model I get the following results (the test was done on a single NVIDIA A6000 GPU):
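For reference, a minimal timing harness for this kind of gut check might look like the sketch below. The model name, prompt, and token count are placeholders, not the setup used for the numbers above:

```python
# Hypothetical cache-vs-no-cache timing comparison via HF generate().
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

name = "meta-llama/Llama-2-7b-hf"  # placeholder; any causal LM works
tok = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name, torch_dtype=torch.float16).cuda()
inputs = tok("The quick brown fox", return_tensors="pt").to("cuda")

for use_cache in (True, False):
    torch.cuda.synchronize()
    start = time.perf_counter()
    model.generate(**inputs, max_new_tokens=128, do_sample=False, use_cache=use_cache)
    torch.cuda.synchronize()
    print(f"use_cache={use_cache}: {time.perf_counter() - start:.2f}s")
```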
I added a test that checks the generated sequences. It is not very stable:
Force-pushed from fb5bdfa to 4c4f73c
Force-pushed from ce7c6bc to c5a72bc
OK, so now checks are passing, but they do not include the two tests that I added (they are marked as both slow and gpu). One of them involves downloading the 1B model and using it to generate short sequences.
Force-pushed from dc3d9bd to 4cec766
I had to force a minimum version of multiprocess, which depends on a newer version of dill than what apache-beam wants. I could not find common ground, so I removed apache-beam from the requirements and had to add ipython.
This looks good to me, except, as discussed, please revert the formatting changes from tests/. We can add those formatting changes in a separate PR.
Force-pushed from 99994e4 to ec80c71
Looks good to me except for the two things we discussed offline. I was also wondering if we could look at test coverage, but that could be done in a separate PR.
Nice! For future reference and documentation, could you comment on what the two things you discussed offline are, @ruixin31?
* Allow parsing parameters from a config file
* Address nits
So what Rui suggested is to use xformers' new masking, xops.fmha.attn_bias.LowerTriangularFromBottomRightMask(). I added comments about that in the code but reverted to using custom masks to make installation faster (we don't use beam search now, so it doesn't justify the extra installation step).
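For reference, a minimal sketch of how that xformers bias would be used for a cached decode step (shapes and values are illustrative; this is not the code merged in this PR, which keeps custom masks):

```python
# The bottom-right-aligned causal mask lets a single new query attend
# to the full cached key sequence without building an explicit mask.
import torch
import xformers.ops as xops

B, H, D = 1, 8, 64
q = torch.randn(B, 1, H, D, device="cuda", dtype=torch.float16)    # one new token
k = torch.randn(B, 128, H, D, device="cuda", dtype=torch.float16)  # cache + new token
v = torch.randn(B, 128, H, D, device="cuda", dtype=torch.float16)

out = xops.memory_efficient_attention(
    q, k, v,
    attn_bias=xops.fmha.attn_bias.LowerTriangularFromBottomRightMask(),
)
```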
Explanation:
The KV cache is an inference trick to avoid re-computing the whole sequence when only one token was added. It stores the keys and values of each attention layer in a list for later use.
If `use_cache` is set to `True`, the HuggingFace generator handles this automatically and will only input the last token along with the KV cache when generating a sequence. The past keys and values are extracted from the cache and concatenated to those of the new token in each attention layer. Only one query attends to this sequence to generate the next token.
Positional embeddings need to be offset because they can no longer rely on the position of the query in the sequence (there is only one query).
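A minimal sketch of this decode step, with hypothetical names (`rope` here is a simplified single-head rotary embedding, not the repo's implementation):

```python
import torch

def rope(x, offset=0):
    # Simplified rotary embedding: positions start at `offset`, so the
    # one new token is embedded at the end of the cached sequence
    # instead of at position 0.
    B, T, D = x.shape
    pos = torch.arange(offset, offset + T, dtype=x.dtype).view(1, T, 1)
    inv_freq = 1.0 / (10000 ** (torch.arange(0, D, 2, dtype=x.dtype) / D))
    ang = pos * inv_freq.view(1, 1, -1)          # (1, T, D/2)
    x1, x2 = x[..., 0::2], x[..., 1::2]
    return torch.cat([x1 * ang.cos() - x2 * ang.sin(),
                      x1 * ang.sin() + x2 * ang.cos()], dim=-1)

def decode_step(x, wq, wk, wv, past_k, past_v):
    # x: hidden state of the single new token, shape (B, 1, dim)
    offset = past_k.shape[1]                     # position = cache length
    q = rope(wq(x), offset=offset)
    k = torch.cat([past_k, rope(wk(x), offset=offset)], dim=1)
    v = torch.cat([past_v, wv(x)], dim=1)
    # The single query attends to the whole concatenated sequence; no
    # causal mask is needed since nothing lies after the last position.
    attn = torch.softmax(q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5, dim=-1)
    return attn @ v, (k, v)                      # output and updated cache
```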
Tests:
LLaMA2 and a custom-trained model have been tested qualitatively with the generation.py script with this change (with and without the --use-cache flag). Jeopardy with LLaMA2 was also run to ensure no regression.
Changes:
* The attention forward returns the KV cache if `use_cache` is set (returns None if not used), thus it now returns 3 output values instead of 2
* `past_key_values` and `use_cache` are taken by the model and passed all the way down to the attention computation
* The positional embedding takes an extra offset argument to offset the queries and keys in the time sequence. This is useful to offset the new query and key to the end of the KV-cache sequence (see the sketch after this list)
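A toy sketch of that plumbing (class and argument names are hypothetical, and causal masking of the prompt is omitted for brevity): `use_cache` and `past_key_values` flow from the model down to each attention layer, and each layer always returns 3 values, with None in the cache slot when unused.

```python
import torch
import torch.nn as nn

class AttnLayer(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.wq, self.wk, self.wv = (nn.Linear(dim, dim) for _ in range(3))

    def forward(self, x, past_key_value=None, use_cache=False):
        q, k, v = self.wq(x), self.wk(x), self.wv(x)
        if past_key_value is not None:
            # Prepend the cached keys/values to the new ones.
            k = torch.cat([past_key_value[0], k], dim=1)
            v = torch.cat([past_key_value[1], v], dim=1)
        attn = torch.softmax(q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5, dim=-1)
        # 3 outputs instead of 2: the cache slot is None when unused.
        return attn @ v, attn, ((k, v) if use_cache else None)

class Model(nn.Module):
    def __init__(self, n_layers=2, dim=16, vocab=100):
        super().__init__()
        self.embed = nn.Embedding(vocab, dim)
        self.layers = nn.ModuleList(AttnLayer(dim) for _ in range(n_layers))
        self.head = nn.Linear(dim, vocab)

    def forward(self, tokens, past_key_values=None, use_cache=False):
        x = self.embed(tokens)
        cache = []
        for i, layer in enumerate(self.layers):
            past = None if past_key_values is None else past_key_values[i]
            x, _, kv = layer(x, past_key_value=past, use_cache=use_cache)
            cache.append(kv)
        return self.head(x), (cache if use_cache else None)
```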