
Conversation

yannicks1 (Collaborator) commented Jun 3, 2025

[CB] get number of blocks from compiler mock implementation

This is a first draft of how the message passing from the Spyre compiler to vLLM could work on the vLLM side.

The process consists of the following steps (a rough sketch in code follows the list):

  • For warmup, vLLM reserves the required minimum of 4 pages/blocks (num_blocks=4).
  • The num_blocks (4) dimension is marked as dynamic (torch._dynamo.mark_dynamic()) for warmup forward calls only.
  • The Spyre compiler calculates the maximum number of pages/blocks it can accommodate and writes this to a .json file.
  • torch_sendnn reads the value from the .json file and can return it via a function.
  • vLLM calls the above-mentioned torch_sendnn function to get the number of available blocks/pages and sets num_blocks=N.
  • For actual inference, vLLM then adjusts the list of free blocks/pages and the KV cache size accordingly (num_blocks=N).
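
As a rough illustration of the above on the vLLM side, here is a minimal sketch of marking the block dimension dynamic during warmup and later querying the compiler-reported block count. The torch_sendnn accessor name and the KV cache layout (block dimension first) are assumptions, not the actual API:

import torch

WARMUP_NUM_BLOCKS = 4  # minimum number of pages/blocks reserved for warmup


def warmup_forward(model, input_ids, kv_cache):
    # kv_cache is assumed to carry a leading num_blocks dimension (= 4 during
    # warmup); mark it dynamic so the compiled graph does not bake in the 4.
    torch._dynamo.mark_dynamic(kv_cache, 0)
    return model(input_ids, kv_cache=kv_cache)


def get_runtime_num_blocks() -> int:
    # After compilation, the Spyre compiler writes the maximum number of
    # blocks to a .json file; torch_sendnn is expected to expose that value
    # via a function (the name below is hypothetical).
    from torch_sendnn import get_num_available_blocks
    return get_num_available_blocks()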

github-actions bot commented Jun 3, 2025

👋 Hi! Thank you for contributing to vLLM support on Spyre.
Just a reminder: make sure that your code passes all the linting checks, otherwise your PR cannot be merged. To do so, first install the linting requirements, then run format.sh and commit the changes. This can be done with uv directly:

uv sync --frozen --group lint --active --inexact

Or this can be done with pip:

uv pip compile --group lint > requirements-lint.txt
pip install -r requirements-lint.txt
bash format.sh

Now you are good to go 🚀

sducouedic (Collaborator) commented Jun 4, 2025

I realised all the changes in this PR only apply to the sendnn_decoder backend, and we still need the old logic for the other backends.

Maybe everything works by renaming the function get_num_blocks_from_compiler_mock to get_num_blocks, and in that function either return max_batch_size * max_model_len // self.BLOCK_SIZE for the cpu backend or call the torch_sendnn function for the torch_sendnn backend.
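
Roughly, that dispatch could look like the sketch below (written as a free function for clarity; the torch_sendnn accessor name is an assumption):

def get_num_blocks(backend: str, max_batch_size: int, max_model_len: int,
                   block_size: int) -> int:
    if backend == 'sendnn_decoder':
        # value computed by the Spyre compiler and exposed by torch_sendnn
        # (the accessor name below is hypothetical)
        from torch_sendnn import get_num_available_blocks
        return get_num_available_blocks()
    # cpu backend: keep the old sizing logic
    return max_batch_size * max_model_len // block_size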

sducouedic (Collaborator) commented Jun 4, 2025

I guess we will probably need to adapt max_model_len or BLOCK_SIZE to be consistent with the number of blocks? To be confirmed with the compiler team.

Rephrasing and update after today's meeting: in the previous implementation, the number of blocks was set based on max_model_len, max_batch_size and BLOCK_SIZE, and the input requests were thoroughly checked against those values. There are probably new checks to be done in order to avoid obscure errors (e.g. checking the max_model_len input).

yannicks1 and others added 4 commits June 4, 2025 12:43
yannicks1 (Author) commented:

> Rephrasing and update after today's meeting: in the previous implementation, the number of blocks was set based on max_model_len, max_batch_size and BLOCK_SIZE, and the input requests were thoroughly checked against those values. There are probably new checks to be done in order to avoid obscure errors (e.g. checking the max_model_len input).

Not entirely sure what you mean here. This code does not change anything except for setting n_blocks = 4 for the warmup. This should always be enough, since we only do 2 prompts (block size is fixed to 64, a Spyre constraint). After warmup it is set to what it was before.

yannicks1 (Author) commented:

Thanks for the great feedback @sducouedic, I addressed it all :)

sducouedic (Collaborator) commented Jun 4, 2025

> Not entirely sure what you mean here. This code does not change anything except for setting n_blocks = 4 for the warmup. This should always be enough, since we only do 2 prompts (block size is fixed to 64, a Spyre constraint). After warmup it is set to what it was before.

The idea is that there is a dependency between max_model_len, max_batch_size, and the number of blocks. Depending on the values of max_model_len and max_num_seqs set by the user, we might not have enough blocks to serve a batch of sequences. This issue couldn't arise before because the number of blocks was set based on these values, but that is no longer the case. Somehow we need to enforce that max_model_len and max_num_seqs can be served. Tom suggested doing a check and raising an error if that is not the case.
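
A minimal sketch of such a check, assuming the available block count comes from the compiler and the other values from the vLLM config (names are illustrative):

def check_block_budget(num_blocks_available: int, max_num_seqs: int,
                       max_model_len: int, block_size: int) -> None:
    # blocks needed to serve a full batch of full-length sequences
    min_req_num_blocks = max_num_seqs * max_model_len // block_size
    if num_blocks_available < min_req_num_blocks:
        raise ValueError(
            f"Only {num_blocks_available} KV cache blocks are available, but "
            f"{min_req_num_blocks} are needed to serve max_num_seqs="
            f"{max_num_seqs} at max_model_len={max_model_len} with "
            f"block_size={block_size}.")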

sducouedic (Collaborator) commented:

But you're right: as long as your temporary function is used, that error won't happen. The comment is about when we start using the torch_sendnn function.

yannicks1 (Author) commented:

I think I get you now: you mean to check something like
num_blocks_spyre >= max_batch_size * max_model_len // block_size, right?
We can already incorporate this for the future, yes!

sducouedic (Collaborator) commented:

> I think I get you now: you mean to check something like num_blocks_spyre >= max_batch_size * max_model_len // block_size, right? We can already incorporate this for the future, yes!

Yes, correct.

yannicks1 (Author) commented:

@sducouedic I have already incorporated your suggested check here:

max_model_len = \
    self.model_runner.vllm_config.scheduler_config.max_model_len
block_size = self.model_runner.BLOCK_SIZE  # type: ignore[union-attr]

min_req_num_blocks = max_batch_size * max_model_len // block_size
Collaborator:

Upstream vLLM has this check as:

min_req_num_blocks = max_model_len // block_size

I think it's more correct to only ensure you have enough blocks to run a single full-size request, because ideally you want to be able to set a high batch size to run many smaller requests on a long-context model. For example, say a model has a context length of 1m tokens, and with your current hardware you can only deploy it in a way where you have enough kv-cache available for 1m tokens. You wouldn't want to be forced to set the max batch size to only 1, because in practice very few requests will use the full context length. Ideally, you'd be able to set the max batch size much higher, like 256, to still run many smaller requests in parallel.

That requires scheduling with preemption, though, to kick out request(s) from the batch when you run out of kv-cache blocks. I think we'd need to hook up our kv cache management with the scheduler and implement preemption in the model runner to make that happen. I haven't looked at how hard that would be to do yet.
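
So the check would only guarantee room for a single full-length request, something like the sketch below (illustrative names again):

def check_min_blocks(num_blocks_available: int, max_model_len: int,
                     block_size: int) -> None:
    # upstream-style check: only require room for one full-length request;
    # serving more requests in parallel is left to the scheduler, eventually
    # with preemption when blocks run out
    min_req_num_blocks = max_model_len // block_size
    if num_blocks_available < min_req_num_blocks:
        raise ValueError(
            f"{num_blocks_available} KV cache blocks cannot hold even a "
            f"single request of max_model_len={max_model_len} tokens "
            f"(block_size={block_size}).")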

Collaborator (Author):

Thanks Joe, this makes a lot of sense. Changed it.

Collaborator (Author):

FYI: we need more than the above min_req_num_blocks for one of the test cases...


min_req_num_blocks = max_batch_size * max_model_len // block_size

if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND == 'sendnn_decoder':
Collaborator:

sendnn_decoder no longer exists since #186

max_model_len = \
    self.model_runner.vllm_config.scheduler_config.max_model_len
block_size = self.model_runner.BLOCK_SIZE  # type: ignore[union-attr]

min_req_num_blocks = max_model_len // block_size
Collaborator:

Is this always what the min_req_num_blocks will be? Is there a case where we will not handle full max_model_len requests?

Collaborator (Author):

Joe's comment here says that this is how it is handled upstream. More logic in the scheduler will follow to handle e.g. running out of blocks...

# TODO: replace num_blocks_spyre by calling a function in
# torch_sendnn which returns the value set by the Spyre compiler
num_blocks_spyre = max_batch_size * min_req_num_blocks
assert num_blocks_spyre >= min_req_num_blocks, (
Collaborator:

If min_req_num_blocks stays same as above, we may hit a case where we can handle only a portion of the max_model_len. Should we fail in this case, or just let the user know we may not be able to support up to the full max_model_len for each request?

Collaborator (Author):

I think Joe's comment covers this as well. Preemption will hopefully soon be implemented (or reused from upstream)...

return num_blocks_spyre
else: # dynamo backend 'eager'
# TODO: how do we get a meaningful value for CPU here
num_blocks_cpu = max_batch_size * min_req_num_blocks
Collaborator:

At least with CUDA, I believe you would profile a run and see what the peak memory usage was, then use some percentage of whatever was left over to determine the number of blocks that could fit.

https://github.com/vllm-project/vllm/blob/ace5cdaff0cf021ff02ddbe39ea814f2ed2e56b7/vllm/worker/worker.py#L232

There is also a version of this for the CPU worker that calculates it based on:

num_cpu_blocks = int(self.cache_config.cpu_kvcache_space_bytes //
                     cache_block_size)

https://github.com/vllm-project/vllm/blob/ace5cdaff0cf021ff02ddbe39ea814f2ed2e56b7/vllm/worker/cpu_worker.py#L239
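
For reference, a simplified sketch of how that CPU block count falls out of the configured KV cache space; the per-block byte size here is an approximation of what upstream derives from the model shape and KV cache dtype:

def num_cpu_blocks(cpu_kvcache_space_bytes: int, block_size: int,
                   num_layers: int, num_kv_heads: int, head_size: int,
                   dtype_bytes: int) -> int:
    # one block stores keys and values for `block_size` tokens in every layer
    cache_block_size = (2 * num_layers * block_size
                        * num_kv_heads * head_size * dtype_bytes)
    return int(cpu_kvcache_space_bytes // cache_block_size)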

Collaborator (Author):

@JRosenkranz as far as I can tell self.cache_config.cpu_kvcache_space_bytes is set by the user here:
https://github.com/vllm-project/vllm/blob/7e8d97dd3f0aaf05265f947997310ca3827d3c06/vllm/platforms/cpu.py#L128

I also think it is not critical to have a super meaningful value here obtained with profiling, as this is merely the CPU path to test/validate the Spyre plugin code, not an actual CPU worker.

Thanks for the review!

yannicks1 (Author) commented:

Do you guys think we can merge this PR now and insert the torch_sendnn function once it is available?

Note: the PR in its current form does not change any behavior, but as it involves quite some refactoring, it couldn't hurt to merge instead of waiting for the one-line change inserting the function...

@JRosenkranz @tdoublep @joerunde @sducouedic @nikolaospapandreou

yannicks1 marked this pull request as ready for review on June 13, 2025, 11:12
sducouedic (Collaborator) commented:

I agree this should be merged

JRosenkranz (Collaborator) commented:

> Do you guys think we can merge this PR now and insert the torch_sendnn function once it is available?
>
> Note: the PR in its current form does not change any behavior, but as it involves quite some refactoring, it couldn't hurt to merge instead of waiting for the one-line change inserting the function...
>
> @JRosenkranz @tdoublep @joerunde @sducouedic @nikolaospapandreou

Yes, this looks good to me to be merged

yannicks1 merged commit a2c68c3 into main on June 13, 2025
20 checks passed
yannicks1 deleted the ysc-mock-read-n-pages branch on June 13, 2025, 14:39