-
Notifications
You must be signed in to change notification settings - Fork 26
[Continuous batching] Initial cb test #52
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 4 commits
Commits
Show all changes
20 commits
Select commit
Hold shift + click to select a range
d3f42ed
initial cb test
nikolaospapandreou d6b7afa
make tkv, active_pages optional in SpyreCausalLM class for the V0 tests
nikolaospapandreou 9d4c961
sync with dev branch, new classes for static and continuous batching
nikolaospapandreou ac1369a
format
nikolaospapandreou 8cd2318
remove manual testing and fix formatting
yannicks1 6cf29aa
remove tkv2fms
yannicks1 a6942a3
remove unnecessary class variables
yannicks1 dbc24e3
tidy up class variables
yannicks1 fb43f8c
simplify code: req_ids2idx and active_pages will be reset in prepare …
yannicks1 04af67b
renaming variable
yannicks1 1135210
removing batch padding in prefil stage
yannicks1 a184b0b
indices always list of Trues since no padding or removed sequences...
yannicks1 98bf15a
fix active/free page handling
yannicks1 e1dd52b
avoiding unnecessary tensor construction
yannicks1 c54fcee
fix sorting indifference token/position_ids vs masks
yannicks1 47ed1e7
refactoring not requiring req_ids2idx
yannicks1 cbb5980
removing unsused class variables, simplifying code
yannicks1 717f05f
use VLLM_SPYRE_MAX_BATCH_SIZE to control (decoding) batch size on AIU…
yannicks1 80850ca
removing unnecessary helper functions for schedule and add_request
yannicks1 fd670af
removing unused argument
yannicks1 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
import os | ||
import time | ||
|
||
from vllm import LLM, SamplingParams | ||
|
||
max_tokens1 = 10 | ||
max_tokens2 = 5 | ||
max_tokens3 = 7 | ||
max_tokens = max([max_tokens1, max_tokens2, max_tokens3]) | ||
max_num_seqs = 2 | ||
|
||
os.environ["VLLM_SPYRE_WARMUP_PROMPT_LENS"] = '64' | ||
os.environ["VLLM_SPYRE_WARMUP_NEW_TOKENS"] = str(max_tokens) | ||
os.environ['VLLM_SPYRE_WARMUP_BATCH_SIZES'] = '4' | ||
|
||
# defining here to be able to run/debug directly from VSC (not via terminal) | ||
os.environ['VLLM_SPYRE_DYNAMO_BACKEND'] = 'eager' | ||
os.environ['VLLM_SPYRE_USE_CB'] = '1' | ||
os.environ['VLLM_USE_V1'] = '1' | ||
|
||
# Sample prompts. | ||
template = ( | ||
"Below is an instruction that describes a task. Write a response that " | ||
"appropriately completes the request. Be polite in your response to the " | ||
"user.\n\n### Instruction:\n{}\n\n### Response:") | ||
|
||
prompt1 = template.format( | ||
"Provide a list of instructions for preparing chicken soup for a family " | ||
"of four.") | ||
|
||
prompt2 = template.format("Provide instructions for preparing chicken soup.") | ||
|
||
prompt3 = template.format( | ||
"Provide a list of instructions for preparing chicken soup for a family.") | ||
|
||
prompts = [ | ||
prompt1, | ||
prompt2, | ||
prompt3, | ||
] | ||
|
||
# Create a sampling params object. | ||
sampling_params1 = SamplingParams(max_tokens=max_tokens1, | ||
temperature=0.0, | ||
ignore_eos=True) | ||
|
||
sampling_params2 = SamplingParams(max_tokens=max_tokens2, | ||
temperature=0.0, | ||
ignore_eos=True) | ||
|
||
sampling_params3 = SamplingParams(max_tokens=max_tokens3, | ||
temperature=0.0, | ||
ignore_eos=True) | ||
|
||
sampling_params = [ | ||
sampling_params1, | ||
sampling_params2, | ||
sampling_params3, | ||
] | ||
|
||
# Create an LLM. | ||
llm = LLM(model="/models/llama-194m", | ||
tokenizer="/models/llama-194m", | ||
max_model_len=2048, | ||
block_size=2048, | ||
max_num_seqs=max_num_seqs) | ||
|
||
# Generate texts from the prompts. The output is a list of RequestOutput objects | ||
# that contain the prompt, generated text, and other information. | ||
print("=============== GENERATE") | ||
t0 = time.time() | ||
outputs = llm.generate(prompts, sampling_params) | ||
print("Time elaspsed for %d tokens is %.2f sec" % | ||
(len(outputs[0].outputs[0].token_ids), time.time() - t0)) | ||
print("===============") | ||
for output in outputs: | ||
prompt = output.prompt | ||
generated_text = output.outputs[0].text | ||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") | ||
print("===============") | ||
for output in outputs: | ||
print(output.outputs[0]) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -68,21 +68,20 @@ def __init__( | |
max_decode_length, | ||
) | ||
|
||
# horizontal offset in physical KV cache memory block | ||
self.tkv: int = 0 | ||
|
||
def forward( | ||
self, | ||
input_ids: torch.Tensor, | ||
positions: torch.Tensor, | ||
masks: torch.Tensor, | ||
is_prompt: bool, | ||
tkv: Optional[int] = None, | ||
active_pages: Optional[list[int]] = None, | ||
) -> torch.Tensor: | ||
|
||
if is_prompt: | ||
self.tkv = 0 | ||
if not envs_spyre.VLLM_SPYRE_USE_CB: | ||
self.model.past_key_value_states = None | ||
self.tkv = tkv | ||
|
||
|
||
if is_prompt and not envs_spyre.VLLM_SPYRE_USE_CB: | ||
self.model.past_key_value_states = None | ||
|
||
extra_kwargs = {} | ||
if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND != "sendnn_decoder": | ||
|
@@ -128,7 +127,7 @@ def forward( | |
self.model.sample_mask = matrix.unsqueeze(0) | ||
|
||
# prefil of batch size 1 | ||
logits, self.tkv = self.model( | ||
logits = self.model( | ||
self.model.sample_token_id, | ||
position_ids=self.model.sample_position, | ||
mask=self.model.sample_mask, | ||
|
@@ -153,16 +152,18 @@ def forward( | |
masks[0, :, :] = self.model.sample_mask | ||
|
||
# normal prefil or decoding step | ||
logits, self.tkv = self.model( | ||
logits = self.model( | ||
input_ids, | ||
position_ids=positions, | ||
mask=masks, | ||
use_cache=True, | ||
only_last_token=True, | ||
tkv=self.tkv, | ||
active_pages=[i for i in range(input_ids.shape[0])], | ||
#active_pages=[i for i in range(input_ids.shape[0])], | ||
|
||
active_pages=active_pages, | ||
**extra_kwargs, | ||
) | ||
|
||
if TESTING_CB and self.tkv >= (6 + 64): | ||
# update sample_token_id, sample_position and sample_mask | ||
self.model.update_sample_inputs(logits=logits[0, :]) | ||
|
@@ -430,7 +431,7 @@ def forward( | |
page, :, :tkv, :] = key_value_states[layer][1][ | ||
idx, :, :, :] # [1, 8, L, 128] | ||
|
||
return logits, tkv + 1 | ||
return logits | ||
|
||
def update_sample_inputs( | ||
self, | ||
|
@@ -494,4 +495,4 @@ def forward( | |
for tensor in layer: | ||
torch._dynamo.mark_dynamic(tensor, 2) | ||
|
||
return logits, tkv + 1 | ||
return logits |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.