Implement kv cache #74
Merged: 9 commits merged into mlfoundations:main from kv_cache on Dec 13, 2023

Conversation

@jmercat (Collaborator) commented Nov 8, 2023

Explanation:

KV cache is an inference trick to avoid re-computing the whole sequence when only one token has been added. It stores the keys and values of each attention layer in a list for later use.
If use_cache is set to True, the HuggingFace generator handles this automatically and only feeds the last token, along with the KV cache, when generating a sequence.
The past keys and values are extracted from the cache and concatenated with the next token's keys and values in each attention layer. A single query then attends to this sequence to generate the next token.
Positional embeddings need to be offset because they can no longer rely on the position of the query in the sequence (there is only one query).
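To make the mechanism concrete, here is a minimal sketch of cached attention for a single layer. This is illustrative only, not the open_lm implementation; the tensor layout (batch, seq, heads, head_dim) and all names are assumptions:

```python
import torch

def attention_with_kv_cache(q, k, v, past_key_values=None, use_cache=False):
    # q, k, v: (batch, seq_len, n_heads, head_dim). During cached decoding,
    # seq_len is 1: only the newly added token is passed in.
    if past_key_values is not None:
        past_k, past_v = past_key_values
        # Append the new token's key/value to everything cached so far.
        k = torch.cat([past_k, k], dim=1)
        v = torch.cat([past_v, v], dim=1)

    # Note: rotary/positional embeddings must be applied to the new q and k with
    # an offset equal to the cached length, so they are encoded as if they sat
    # at the end of the full sequence (not shown here).

    # The new query attends over the full (cached + new) key/value sequence.
    # No causal mask is needed for a single new query attending to all
    # earlier positions.
    scores = torch.einsum("bqhd,bkhd->bhqk", q, k) / (q.shape[-1] ** 0.5)
    out = torch.einsum("bhqk,bkhd->bqhd", scores.softmax(dim=-1), v)

    # Hand the updated cache back so the next decoding step can reuse it.
    new_cache = (k, v) if use_cache else None
    return out, new_cache
```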

Tests:

LLaMA2 and a custom-trained model have been tested qualitatively with the generation.py script with this change (with and without the --use-cache flag). Jeopardy with LLaMA2 was also run to ensure there is no regression.

Changes:

  • The model now outputs past_key_values regardless of use_cache (it returns None when the cache is not used), so it returns 3 output values instead of 2 (a rough sketch of the resulting call pattern follows this list)
  • Optional inputs past_key_values and use_cache are accepted by the model and passed all the way down to the attention computation
  • The positional embedding takes an extra offset argument that shifts the queries and keys in the time sequence. This is used to place the new query and key at the end of the kv cache sequence
  • head rotary is not compatible with offsets but will run anyway, since it doesn't need to offset anything
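
A rough sketch of the resulting call pattern; argument and return names are taken from the description above and are illustrative, not the exact open_lm signatures:

```python
def model_forward(layers, norm, output, x, past_key_values=None, use_cache=False):
    # Sketch of the updated forward pass (illustrative names).
    # layers: per-layer callables returning (hidden, (key, value) or None)
    # norm, output: final norm and output projection modules
    if past_key_values is None:
        past_key_values = [None] * len(layers)

    # Offset positional embeddings by the cached length so new tokens are
    # encoded as if they sat right after the cached sequence.
    offset = 0 if past_key_values[0] is None else past_key_values[0][0].shape[1]

    new_cache = []
    for layer, past in zip(layers, past_key_values):
        x, kv = layer(x, past_key_values=past, use_cache=use_cache, offset=offset)
        new_cache.append(kv)

    logits = output(norm(x))
    # The model now returns three values; the cache slot is None when unused.
    return logits, x, new_cache if use_cache else None
```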

@jmercat jmercat force-pushed the kv_cache branch 5 times, most recently from 61f3b16 to 75fb8cd Compare November 10, 2023 00:53
Review threads on open_lm/model.py, open_lm/positional_embedding/rotary.py, and open_lm/train.py (resolved).
@achalddave (Collaborator) commented:

Other than the one comment about adding comments to _update_cos_sin_tables, this looks good to me! Would it be possible to document the speedup we get with KV-cache, compared to say HF LLaMa2? Doesn't need to be super thorough, just a quick gut check that we have roughly the right speedup.

@achalddave achalddave self-assigned this Nov 10, 2023
@achalddave (Collaborator) commented:

oh and finally, could you paste the outputs of pytest so we can merge? (I believe dataloading tests currently fail due to a missing assets/ folder that we're working on adding, but the rest should pass)

@jmercat (Collaborator, Author) commented Nov 13, 2023

> oh and finally, could you paste the outputs of pytest so we can merge? (I believe dataloading tests currently fail due to a missing assets/ folder that we're working on adding, but the rest should pass)

[screenshot: pytest output]

There were bugs, but most didn't come from my changes:

  • I added ignore_parse_errors to MockDataArgs.
  • In train.py l.144 (sample chunk) I updated the call to the new args input.
  • In the accumulation test I updated SimpleModel to output 3 items (this one was due to my changes).

I added a test for the kv cache, but it is slow, so I also added a slow marker in pyproject.toml; slow tests can be marked with @pytest.mark.slow.
Tests marked as slow can be skipped with pytest . -m "not slow", which might help with #59 (comment).
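
For reference, marking and deselecting a slow test looks roughly like this (the test name and body are placeholders, not the actual test added in this PR):

```python
import pytest

# Tests carrying the `slow` marker (registered in pyproject.toml) are collected
# normally but can be deselected with: pytest . -m "not slow"
@pytest.mark.slow
def test_kv_cache_generation():
    # Placeholder body; the real test generates with and without the cache
    # and compares the results.
    pass
```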

@jmercat (Collaborator, Author) commented Nov 13, 2023

[image attached]

@jmercat (Collaborator, Author) commented Nov 13, 2023

With the attached script and a 1B model I get the following results:


Context length | Generation length | Time with cache | Time without cache | Time gain (%)
--------------------------------------------------------------------------------------------
512            | 512               | 18.1522          | 30.4054             | 40.30          
512            | 1024              | 20.8708          | 80.5402             | 74.09          
512            | 1536              | 29.9303          | 152.6310            | 80.39          
1024           | 512               | 15.6641          | 50.1081             | 68.74          
1024           | 1024              | 25.5443          | 123.8644            | 79.38          
1536           | 512               | 19.4633          | 73.8159             | 73.63          
--------------------------------------------------------------------------------------------

time_generate.txt

(the test was done on a single NVIDIA A6000 GPU)

@jmercat (Collaborator, Author) commented Nov 15, 2023

I added a test that checks the generated sequences. It is not very stable:
if we generate long sequences (sometimes as few as 64 tokens), the cached and non-cached generations start to diverge, while two non-cached generations do not. I considered this not a bug but the result of compounding numerical errors, so I set the test to compare only the first 32 generated tokens. This ends up passing.
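
A minimal sketch of that comparison, assuming the (logits, hidden, past_key_values) model interface described in this PR; the helper and names are illustrative, not the actual test code:

```python
import torch

def greedy_generate(model, prompt_ids, n_new, use_cache):
    # Greedy decoding; assumes the model returns (logits, hidden, past_key_values).
    tokens = prompt_ids
    past = None
    step_input = prompt_ids
    for _ in range(n_new):
        logits, _, past = model(step_input, past_key_values=past, use_cache=use_cache)
        next_token = logits[:, -1:, :].argmax(dim=-1)  # (batch, 1)
        tokens = torch.cat([tokens, next_token], dim=1)
        # With the cache, only the newly generated token is fed back in; without
        # it, the full sequence is re-processed every step.
        step_input = next_token if use_cache else tokens
    return tokens

def check_cache_consistency(model, prompt_ids, n_generate=64, n_compare=32):
    with torch.no_grad():
        cached = greedy_generate(model, prompt_ids, n_generate, use_cache=True)
        uncached = greedy_generate(model, prompt_ids, n_generate, use_cache=False)
    prompt_len = prompt_ids.shape[1]
    # Compare only the first n_compare generated tokens: beyond that, small
    # numerical differences between the two paths can compound and flip an argmax.
    assert torch.equal(
        cached[:, prompt_len : prompt_len + n_compare],
        uncached[:, prompt_len : prompt_len + n_compare],
    )
```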

@jmercat jmercat force-pushed the kv_cache branch 2 times, most recently from fb5bdfa to 4c4f73c Compare November 15, 2023 05:23
@jmercat jmercat changed the title Implement kv cache similar to HF Implement kv cache Nov 28, 2023
@jmercat (Collaborator, Author) commented Nov 28, 2023

I've added support for kv_cache plus an input sequence (previously it only worked with a single new token as input).
However, it assumes the input sequence comes right after the kv_cache in the sequence and does not support arbitrary input position indices.

I've made 2 separate tests: one for speed on a randomly initialized model, and one for consistency that loads the 1B pre-trained OpenLM weights and checks that the generated results are the same with and without kv_cache.

I didn't produce any test for beam search...

Two tests were not passing, but I don't think the failures come from my changes:
[screenshot: failing tests]

@jmercat jmercat force-pushed the kv_cache branch 3 times, most recently from ce7c6bc to c5a72bc Compare November 29, 2023 00:40
@jmercat (Collaborator, Author) commented Nov 29, 2023

OK, so now checks are passing, but they do not include the two tests that I added (they are marked as both slow and gpu). One of them involves downloading the 1B model and using it to generate short sequences.

@jmercat jmercat force-pushed the kv_cache branch 6 times, most recently from dc3d9bd to 4cec766 Compare December 9, 2023 01:00
@jmercat (Collaborator, Author) commented Dec 9, 2023

I had to force a minimum version of multiprocess, which depends on a newer version of dill than what apache-beam wants. I could not find a common ground, so I removed apache-beam from the requirements and had to add ipython.

@achalddave achalddave changed the base branch from main to package-imports December 9, 2023 01:16
@achalddave achalddave changed the base branch from package-imports to main December 9, 2023 01:16
@achalddave (Collaborator) left a comment:

This looks good to me; as discussed, just revert the formatting changes from tests/. We can add those formatting changes in a separate PR.

@jmercat jmercat force-pushed the kv_cache branch 4 times, most recently from 99994e4 to ec80c71 Compare December 9, 2023 01:47
@ruixin31 (Contributor) commented:

Looks good to me except for the two things we discussed offline. I was also wondering if we could look at the coverage of tests but that could be done in a separate PR.

@achalddave (Collaborator) commented:

nice! For future reference and documentation, could you comment what the two things you discussed offline are, @ruixin31?

@jmercat (Collaborator, Author) commented Dec 13, 2023

> nice! For future reference and documentation, could you comment what the two things you discussed offline are, @ruixin31?

What Rui suggested is to use xformers' new masking, xops.fmha.attn_bias.LowerTriangularFromBottomRightMask().
It works nicely and should be easier to read and faster when using beam search. Sadly, it is not compatible with llm_foundry because their dependencies don't match (pip install -r requirements.txt would fail, though the packages can be installed successfully in two steps...).

I added comments about this in the code but reverted to using custom masks to keep installation simple (we don't use beam search right now, so it doesn't justify the extra installation step).
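
For reference, this is roughly how that mask would be used with xformers' memory-efficient attention (a sketch, not the code that was merged):

```python
import xformers.ops as xops

def cached_attention_xformers(q, k, v):
    # q holds only the new queries; k and v already include the cached keys and
    # values concatenated in front of the new ones (batch, seq, heads, head_dim).
    # The bottom-right-aligned causal mask lets a query block that is shorter
    # than the key sequence attend causally to the full sequence, so no manual
    # offset of the attention mask is needed.
    bias = xops.fmha.attn_bias.LowerTriangularFromBottomRightMask()
    return xops.memory_efficient_attention(q, k, v, attn_bias=bias)
```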

@achalddave achalddave merged commit e016855 into mlfoundations:main Dec 13, 2023
2 checks passed
@jmercat jmercat deleted the kv_cache branch December 14, 2023 19:54