
[V1][Model] Add V1 support for Qwen2-VL #11668

Closed
wants to merge 4 commits

Conversation

@imkero (Contributor) commented Jan 1, 2025

What's changed:

  1. Allow using a function to determine the dynamic dimensions of a tensor for torch.compile (M-RoPE uses a 2-D position tensor, unlike common RoPE, and both share the same implementation in the Qwen2 LM's forward fn); see the sketch below
  2. Modify dummy data retrieval in profile_run so Qwen2-VL can launch
  3. Add M-RoPE support to the V1 gpu_model_runner
  4. Add support for encoder output as a tuple (embeddings: torch.Tensor, modality: str) in gpu_model_runner for Qwen2-VL
  5. Use the token_id instead of the token string of image_token and video_token in Qwen2-VL's preprocessing for better performance

This PR should make Qwen2-VL work in V1 with chunked prefill and prefix caching enabled.
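
For illustration, a minimal sketch of what change 1 enables (the decorator and dynamic_arg_dims names follow the snippet discussed later in this thread; the exact signature and import path are assumptions, not the code in this PR):

from torch import nn
from vllm.compilation.decorators import support_torch_compile  # assumed import path

@support_torch_compile(
    dynamic_arg_dims={
        "input_ids": 0,
        # dim 1 for M-RoPE positions of shape (3, seq_len),
        # dim 0 for common RoPE positions of shape (seq_len, )
        "positions": lambda tensor: tensor.ndim - 1,
    })
class Qwen2Model(nn.Module):
    ...

A callable lets the dynamic dimension be resolved per tensor at runtime, so the same compiled forward serves both position layouts.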

github-actions bot commented Jan 1, 2025

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@@ -791,6 +791,7 @@ def _parse_video_data(


class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
    _placeholder_map: Optional[dict[str, list[int]]] = None
Member

I think we should initialize this in the init method to avoid confusing it with a static class variable.

@DarkLight1337 (Member) commented Jan 1, 2025

Apart from this, the processor-related changes in the model file LGTM.

@ywang96 (Member) commented Jan 1, 2025

Hello @imkero! Much appreciated that you made this PR!

The reason I haven't spent too much time on Qwen2-VL is that I want to see if there's a way to move M-RoPE inside the model file for Qwen2-VL, since it is so specific to this model.

You would also need to change the implementation of _process_image_input and _process_video_input for this model to make it work properly on V1 (the returned embeddings need to be a NestedTensor, with the first dimension matching the total number of multimodal data items involved in the batch for fine-grained scheduling).

Feel free to take changes from here into this PR.
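
For illustration, a sketch of the per-item split described above, assuming Qwen2-VL-style grid metadata (image_grid_thw) and a 2x2 spatial merge; the helper below is hypothetical and not the code in this PR:

import torch

def split_image_embeddings(
    image_embeds: torch.Tensor,       # (total_patches, hidden_size)
    image_grid_thw: torch.Tensor,     # (num_images, 3): t, h, w per image
    spatial_merge_size: int = 2,
) -> tuple[torch.Tensor, ...]:
    # Each image contributes t * h * w patches, reduced by the spatial merge,
    # so splitting along dim 0 yields one tensor per multimodal data item,
    # which is what V1's fine-grained scheduling expects.
    sizes = (image_grid_thw.prod(dim=-1) // (spatial_merge_size ** 2)).tolist()
    return tuple(image_embeds.split(sizes))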

Comment on lines +829 to +838

if not self._placeholder_map:
    # NOTE: Only Qwen2VLProcessor in transformers 4.47.0 has
    # image_token and video_token registered
    encode_fn = hf_processor.tokenizer.encode
    self._placeholder_map = {
        "image": encode_fn(hf_processor.image_token),
        "video": encode_fn(hf_processor.video_token),
    }
placeholder = self._placeholder_map

Member

Also, we can set this at initialization time.
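
A minimal sketch of that suggestion, moving the lookup into the constructor (the __init__ signature and the _get_hf_processor helper are assumptions about the surrounding base class):

class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Resolve placeholder token ids once, up front, instead of lazily
        # caching them in a class-level attribute.
        hf_processor = self._get_hf_processor()
        encode_fn = hf_processor.tokenizer.encode
        self._placeholder_map: dict[str, list[int]] = {
            "image": encode_fn(hf_processor.image_token),
            "video": encode_fn(hf_processor.video_token),
        }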

Comment on lines +579 to +582

encoder_outputs.append((
    encoder_output[0][start_idx:end_idx],  # embedding tensor
    encoder_output[1],  # modality
@ywang96 (Member) commented Jan 1, 2025
My thought is that we don't necessarily need the modality key here.

We can leverage the fact that no two mm_positions from any modalities can possibly overlap, and now that

def merge_multimodal_embeddings(
    input_ids: torch.Tensor,
    inputs_embeds: torch.Tensor,
    multimodal_embeddings: NestedTensors,
    placeholder_token_id: Union[int, List[int]],
) -> torch.Tensor:
    """
    Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
    positions in ``inputs_embeds`` corresponding to placeholder tokens in
    ``input_ids``.

    ``placeholder_token_id`` can be a list of token ids (e.g, token ids
    of img_start, img_break, and img_end tokens) when needed: This means
    the order of these tokens in the ``input_ids`` MUST MATCH the order of
    their embeddings in ``multimodal_embeddings`` since we need to
    slice-merge instead of individually scattering.

can apply the embedding replacement based on a list of token ids (so we can simply pass [self.config.image_token_id, self.config.video_token_id] here).

Therefore, all we need to do should be just sorting the mm_positions and their corresponding mm_inputs in the following code (which also needs to be modified to support the video modality for Qwen2-VL in this PR):

vllm/vllm/v1/request.py

Lines 51 to 59 in 11d8a09

# Multi-modal input metadata.
mm_positions = self.inputs.multi_modal_placeholders
if mm_positions:
    # FIXME(woosuk): Support other modalities.
    self.mm_positions = mm_positions.get("image", [])
else:
    self.mm_positions = []
# Output of the mm input mapper (e.g., image tensors).
self.mm_inputs: List[MultiModalKwargs] = []

WDYT?
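
For illustration, a sketch of the sorting proposed above, assuming each placeholder range exposes an offset key (a hypothetical adaptation of the request.py snippet, not code from this PR):

# Gather placeholder ranges from all modalities and keep them in prompt order,
# so the embeddings can later be slice-merged by placeholder token id.
mm_positions = self.inputs.multi_modal_placeholders
if mm_positions:
    all_positions = [
        pos
        for modality in ("image", "video")
        for pos in mm_positions.get(modality, [])
    ]
    # Sorting by offset preserves prompt order, since ranges from different
    # modalities never overlap.
    self.mm_positions = sorted(all_positions, key=lambda pos: pos["offset"])
else:
    self.mm_positions = []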

Member

On second thought, let me actually work on this design for llava-onevision too.

@ywang96 (Member) commented Jan 10, 2025

Hello @imkero! Please feel free to take a look at the updated code in https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llava_onevision.py for dealing with multiple modalities.

In particular, I think you can pretty much adopt the same code below for Qwen2-VL without changing the interface for the model runner and encoder cache. Let me know if you need any help; I'm happy to work on this PR as well if you don't have the bandwidth!

def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
    modalities = {}

    # Preserve the order of modalities if there are multiple of them
    # from the order of kwargs.
    for input_key in kwargs:
        if input_key == "pixel_values" and "images" not in modalities:
            modalities["images"] = self._parse_and_validate_image_input(
                **kwargs)
        if input_key == "pixel_values_videos" and "videos" not in modalities:  # noqa E501
            modalities["videos"] = self._parse_and_validate_video_input(
                **kwargs)

    return modalities

def get_multimodal_embeddings(
        self, **kwargs) -> Optional[List[Tuple[NestedTensors, str]]]:

    modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
    if not modalities:
        return None

    # The resulting multimodal_embeddings is a tuple of tensors, with each
    # tensor corresponding to a multimodal data item (image or video).
    multimodal_embeddings: tuple[torch.Tensor, ...] = ()

    # NOTE: It is important to iterate over the keys in this dictionary
    # to preserve the order of the modalities.
    for modality in modalities:
        if modality == "images":
            image_input = modalities["images"]
            vision_embeddings = self._process_image_input(image_input)
            multimodal_embeddings += tuple(vision_embeddings)
        if modality == "videos":
            video_input = modalities["videos"]
            video_embeddings = self._process_video_pixels(video_input)
            multimodal_embeddings += tuple(video_embeddings)

    return multimodal_embeddings

inputs_embeds = merge_multimodal_embeddings(
    input_ids, inputs_embeds, multimodal_embeddings,
    [self.config.image_token_index, self.config.video_token_index])

@imkero (Contributor, Author) commented Jan 10, 2025

@ywang96 Sorry for the late response. I'll continue working on this PR soon.

mergify bot commented Jan 11, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @imkero.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jan 11, 2025
@baifanxxx commented Jan 13, 2025

Hi,
Thank you for your contribution. Qwen2-VL is a promising MLLM and needs V1 support. When I ran your PR, I found some errors.

ERROR 01-13 09:10:24 core.py:200]   File "/media/SSD6/personal_files/bf/vllm/vllm/v1/worker/gpu_worker.py", line 134, in determine_num_available_blocks
ERROR 01-13 09:10:24 core.py:200]     self.model_runner.profile_run()
ERROR 01-13 09:10:24 core.py:200]   File "/media/SSD6/personal_files/bf/vllm/vllm/v1/worker/gpu_model_runner.py", line 830, in profile_run
ERROR 01-13 09:10:24 core.py:200]     assert len(dummy_encoder_outputs) == max_num_mm_items, (
ERROR 01-13 09:10:24 core.py:200] AssertionError: Expected dimension 0 of encoder outputs to match the number of multimodal data items: 13, got len(dummy_encoder_outputs)=1 instead. This is most likely due to the 'get_multimodal_embeddings' method of the model not implemented correctly

It seems the dummy data in profile running is not correct. Then, I print some values.

self.max_num_encoder_input_tokens 16384
self.encoder_cache_size 16384
max_tokens_per_mm_item 1225
max_num_mm_items 13
batched_dummy_mm_inputs['pixel_values']  torch.Size([13, 4900, 1176])
dummy_encoder_outputs[0][0] torch.Size([15925, 3584])
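
(The numbers line up: 13 items × 1225 tokens per item = 15,925 rows, so the single returned tensor appears to contain all items concatenated, rather than one tensor per item as the model runner expects.)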

Could you help me? I would appreciate it, and I hope that Qwen2-VL will be supported by V1 soon.

Thank you

Best regards

@ywang96 (Member) commented Jan 13, 2025


@baifanxxx I'll start taking a look at this PR tomorrow. @imkero has already done a great job of adding MRoPE in v1 with torch compile support, so it shouldn't take us too long to get this PR into a functional stage!

@mergify mergify bot removed the needs-rebase label Jan 13, 2025
dynamic_arg_dims={
    "input_ids": 0,
    # dim 1 for mrope in shape (3, seq_len), else dim 0 in shape (seq_len, )
    "positions": lambda tensor: tensor.ndim - 1,
Member

does -1 work here?

Contributor Author

The value here is passed through to PyTorch's implementation, torch._dynamo.mark_dynamic(tensor, dim), which seems to assume that dim is a non-negative integer:

https://github.com/pytorch/pytorch/blob/95b41d2aa43c606d65e127d4825c08baf9fcacd9/torch/_dynamo/decorators.py#L464

Member

You can do the conversion here:

for k, dims in dynamic_arg_dims.items():

Iterate over the dims and convert -1 to tensor.ndim - 1.
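
A minimal sketch of that conversion, assuming the tensors can be looked up by argument name at the point where the loop above runs (the helper and variable names are assumptions, not code from this PR):

from typing import Union

import torch

def mark_dynamic_dims(
    tensors: dict[str, torch.Tensor],
    dynamic_arg_dims: dict[str, Union[int, list[int]]],
) -> None:
    # Resolve negative dims against the actual tensor before calling
    # torch._dynamo.mark_dynamic, so -1 means "last axis" regardless of
    # whether positions are shaped (seq_len,) or (3, seq_len).
    for name, dims in dynamic_arg_dims.items():
        tensor = tensors[name]
        dim_list = [dims] if isinstance(dims, int) else list(dims)
        for dim in dim_list:
            if dim < 0:
                dim = tensor.ndim + dim
            torch._dynamo.mark_dynamic(tensor, dim)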

@Zhiy-Zhang commented Jan 14, 2025

Hi, when I run Qwen2-VL (2B) using this PR, I get this error:

ERROR 01-14 21:21:53 core.py:197]   File "/home/user/vllm/v1/worker/gpu_worker.py", line 134, in determine_num_available_blocks
ERROR 01-14 21:21:53 core.py:197]     self.model_runner.profile_run()
ERROR 01-14 21:21:53 core.py:197]   File "/home/user/vllm/v1/worker/gpu_model_runner.py", line 880, in profile_run
ERROR 01-14 21:21:53 core.py:197]     assert max_num_mm_items_encoder_budget > 0, (
ERROR 01-14 21:21:53 core.py:197] AssertionError: Encoder cache budget=16384 is too small to support the maximum possible size of multimodal embeddings=31850.
ERROR 01-14 21:21:53 core.py:197] 
CRITICAL 01-14 21:21:53 core_client.py:146] Got fatal signal from worker processes, shutting down. See stack trace above for root cause issue.

Is there a problem with the profile_run?

@imkero imkero marked this pull request as draft January 14, 2025 14:37
@imkero (Contributor, Author) commented Jan 14, 2025

@baifanxxx @Zhiy-Zhang
I think the current dummy data / profile_run approach does not work well with Qwen2-VL. This should be fixed before we finally bring V1 support to Qwen2-VL.

Actually, I commented out this assertion (assert len(dummy_encoder_outputs) == max_num_mm_items) to allow the program to move forward while developing this PR. @baifanxxx

Also, I modified the value of encoder_budget and Qwen2-VL's image processor's max_pixels while developing. @Zhiy-Zhang
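
For illustration, one way to reduce the maximum multimodal embedding size while experimenting is to cap the image processor's max_pixels via mm_processor_kwargs (a workaround sketch, not part of this PR; the value shown is arbitrary):

from vllm import LLM

# A smaller max_pixels lowers the per-image token count, which can bring the
# worst-case multimodal embedding size back under the encoder cache budget.
llm = LLM(
    model="Qwen/Qwen2-VL-2B-Instruct",
    mm_processor_kwargs={"max_pixels": 1280 * 28 * 28},
)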

@Zhiy-Zhang

Thank you very much for your reply. Has this change ("modified the value of encoder_budget, and Qwen2-VL's image processor's max_pixels while developing") already been merged into this PR's branch?

@imkero (Contributor, Author) commented Jan 18, 2025

Continued in #12128. Thanks @ywang96

@imkero imkero closed this Jan 18, 2025