Fix: Discard KV cache for last token before reusing prompt cache for prompt + suffix #79

Open · wants to merge 2 commits into main
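What the change does: the cache built from INITIAL_PROMPT alone is deep-copied, and the key/value entries for its final cached position are dropped before the cache is reused for INITIAL_PROMPT + suffix. That way generate() recomputes the boundary token from the concatenated input instead of trusting the entry cached from the prefix alone (the two can differ, for instance when appending a suffix changes how the last prompt token is tokenized). Below is a minimal sketch of just that trimming step, factored into a helper; trim_last_token is a hypothetical name, not part of the PR.

import copy
from transformers import DynamicCache

def trim_last_token(cache: DynamicCache) -> DynamicCache:
    # Work on a copy so the original prompt cache stays reusable.
    trimmed = copy.deepcopy(cache)
    # Drop the last sequence position (dim 2 of each [batch, heads, seq_len, head_dim] tensor) in every layer.
    trimmed.key_cache = [k[:, :, :-1] for k in trimmed.key_cache]
    trimmed.value_cache = [v[:, :, :-1] for v in trimmed.value_cache]
    return trimmed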
53 changes: 31 additions & 22 deletions performance_optimization/prompt_reuse.py
@@ -1,10 +1,8 @@
 # This example showcases re-using a prompt for all your generation.

-# For this to work correctly, please install transformers from source with the following command:
-# pip install git+https://github.com/huggingface/transformers
-import os, torch, copy
+import torch, copy
 from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache
-device = "cuda"
+device = "cuda" if torch.cuda.is_available() else "cpu"
 ckpt = "meta-llama/Meta-Llama-3.1-8B-Instruct"

 INITIAL_PROMPT = "From now on, you are going to answer all my questions with historical details. Make sure to always add a bit of french here and there, for style."
@@ -15,26 +13,37 @@
 tokenizer = AutoTokenizer.from_pretrained(ckpt)

 prompt_cache = DynamicCache()
-inputs = tokenizer(INITIAL_PROMPT, return_tensors="pt").to("cuda")
+inputs = tokenizer(INITIAL_PROMPT, return_tensors="pt").to(device)
 with torch.no_grad():
     prompt_cache = model(**inputs, past_key_values = prompt_cache).past_key_values
+prompt_cache_fixed = copy.deepcopy(prompt_cache)
+prompt_cache_fixed.key_cache = [x[:, :, :-1] for x in prompt_cache.key_cache]
+prompt_cache_fixed.value_cache = [x[:, :, :-1] for x in prompt_cache.value_cache]


 prompt = "Why are french people obsessed with french?"
-new_inputs = tokenizer(INITIAL_PROMPT + prompt, return_tensors="pt").to("cuda")
-past_key_values = copy.deepcopy(prompt_cache)
-outputs = model.generate(**new_inputs, past_key_values=past_key_values,max_new_tokens=20)
-response = tokenizer.batch_decode(outputs)[0]
-print(response)
-"""
-
-"""
-
-prompt = "What is the best city to swim in?"
-new_inputs = tokenizer(INITIAL_PROMPT + prompt, return_tensors="pt").to("cuda")
-outputs = model.generate(**new_inputs, past_key_values=copy.deepcopy(prompt_cache),max_new_tokens=20)
-response = tokenizer.batch_decode(outputs)[0]
-print(response)
-"""
-
-"""
+new_inputs = tokenizer(INITIAL_PROMPT + prompt, return_tensors="pt").to(device)
+
+outputs_baseline = model.generate(**new_inputs, max_new_tokens=20, do_sample=False)
+response_baseline = tokenizer.batch_decode(outputs_baseline)[0]
+
+outputs_fixed = model.generate(**new_inputs, past_key_values=copy.deepcopy(prompt_cache_fixed), max_new_tokens=20, do_sample=False)
+response_fixed = tokenizer.batch_decode(outputs_fixed)[0]
+
+outputs_unfixed = model.generate(**new_inputs, past_key_values=copy.deepcopy(prompt_cache), max_new_tokens=20, do_sample=False)
+response_unfixed = tokenizer.batch_decode(outputs_unfixed)[0]
+
+print()
+print("Baseline:")
+print(response_baseline)
+print()
+print("Fixed:")
+print(response_fixed)
+print()
+print("Unfixed:")
+print(response_unfixed)
+print()
+
+# The fixed version should be the same as the baseline, while the unfixed version should be different.
+print("Fixed matches baseline:", response_fixed == response_baseline)
+print("Unfixed matches baseline:", response_unfixed == response_baseline)