Skip to content

[Core] Cast multimodal input in hf processor #18862

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits on Jun 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions vllm/inputs/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union

import torch
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
from typing_extensions import TypeVar

from vllm.jsontree import JSONTree, json_map_leaves
from vllm.logger import init_logger
from vllm.transformers_utils.processor import cached_processor_from_config
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import resolve_mm_processor_kwargs
Expand All @@ -21,6 +24,8 @@
_C = TypeVar("_C", bound=PretrainedConfig, default=PretrainedConfig)
_P = TypeVar("_P", bound=ProcessorMixin, default=ProcessorMixin)

logger = init_logger(__name__)


@dataclass(frozen=True)
class InputContext:
Expand Down Expand Up @@ -134,7 +139,7 @@ def call_hf_processor(
hf_processor: ProcessorMixin,
data: Mapping[str, object],
kwargs: Mapping[str, object] = {},
) -> BatchFeature:
) -> Union[BatchFeature, JSONTree]:
"""
Call `hf_processor` on the prompt `data`
(text, image, audio...) with configurable options `kwargs`.
Expand All @@ -154,8 +159,25 @@ def call_hf_processor(
allow_var_kwargs=True,
)

def maybe_cast_dtype(x):
    # Cast floating-point tensors to the model's configured dtype;
    # all other leaf values (non-tensors, integer/bool tensors) pass
    # through unchanged.
    # This mimics the behavior of transformers.BatchFeature
    if isinstance(x, torch.Tensor) and x.is_floating_point():
        return x.to(dtype=self.model_config.dtype)
    return x

try:
return hf_processor(**data, **merged_kwargs, return_tensors="pt")
output = hf_processor(**data, **merged_kwargs, return_tensors="pt")
# this emulates output.to(dtype=self.model_config.dtype)
cast_output = json_map_leaves(maybe_cast_dtype, output)
if isinstance(output, BatchFeature):
return BatchFeature(cast_output)

logger.warning_once(
f"{type(hf_processor).__name__} did not return `BatchFeature`. "
"Make sure to match the behaviour of `ProcessorMixin` when "
"implementing custom processors.")
return cast_output

except Exception as exc:
msg = (f"Failed to apply {type(hf_processor).__name__} "
f"on data={data} with kwargs={merged_kwargs}")
Expand Down
8 changes: 1 addition & 7 deletions vllm/multimodal/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -747,17 +747,11 @@ def as_kwargs(
batched_inputs: BatchedTensorInputs,
*,
device: torch.types.Device,
dtype: Optional[torch.dtype] = None,
) -> BatchedTensorInputs:
json_inputs = cast(JSONTree[torch.Tensor], batched_inputs)

def maybe_cast_dtype(x: torch.Tensor):
# This mimics the behavior of transformers.BatchFeature
return x.to(dtype=dtype) if x.is_floating_point() else x

json_mapped = json_map_leaves(
# NOTE: Cast the dtype before sending it to device
lambda x: maybe_cast_dtype(x).to(device=device, non_blocking=True),
lambda x: x.to(device=device, non_blocking=True),
json_inputs,
)

Expand Down
1 change: 0 additions & 1 deletion vllm/spec_decode/draft_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,6 @@ def execute_model(
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(
multi_modal_kwargs,
dtype=self.model_runner.model_config.dtype,
device=self.device,
),
**model_execute_kwargs,
Expand Down
2 changes: 0 additions & 2 deletions vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -949,7 +949,6 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
batched_mm_inputs = MultiModalKwargs.as_kwargs(
batched_mm_inputs,
dtype=self.model_config.dtype,
device=self.device,
)

Expand Down Expand Up @@ -1943,7 +1942,6 @@ def profile_run(self) -> None:
[dummy_mm_kwargs] * max_num_mm_items)
batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
batched_dummy_mm_inputs,
dtype=self.model_config.dtype,
device=self.device,
)

Expand Down
2 changes: 0 additions & 2 deletions vllm/v1/worker/tpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,7 +718,6 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
batched_mm_inputs = MultiModalKwargs.as_kwargs(
batched_mm_inputs,
dtype=self.model_config.dtype,
device=self.device,
)

Expand Down Expand Up @@ -1560,7 +1559,6 @@ def _get_mm_dummy_batch(self, modality: str,
batch_size)
return MultiModalKwargs.as_kwargs(
batched_dummy_mm_inputs,
dtype=self.model_config.dtype,
device=self.device,
)

Expand Down
1 change: 0 additions & 1 deletion vllm/worker/cpu_enc_dec_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,6 @@ def execute_model(
model_input.encoder_input_positions,
**MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs or {},
dtype=self.model_config.dtype,
device=self.device,
),
"intermediate_tensors":
Expand Down
1 change: 0 additions & 1 deletion vllm/worker/cpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,7 +630,6 @@ def execute_model(
if model_input.multi_modal_kwargs is not None:
multimodal_kwargs = MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs,
dtype=self.model_config.dtype,
device=self.device,
)
execute_model_kwargs = {}
Expand Down
1 change: 0 additions & 1 deletion vllm/worker/cpu_pooling_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ def execute_model(
model_input.input_positions,
**MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs or {},
dtype=self.model_config.dtype,
device=self.device,
),
**cross_enc_kwargs,
Expand Down
1 change: 0 additions & 1 deletion vllm/worker/enc_dec_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,6 @@ def execute_model(
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(
multi_modal_kwargs,
dtype=self.model_config.dtype,
device=self.device,
),
**seqlen_agnostic_kwargs,
Expand Down
1 change: 0 additions & 1 deletion vllm/worker/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1848,7 +1848,6 @@ def execute_model(
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(
multi_modal_kwargs,
dtype=self.model_config.dtype,
device=self.device,
),
**seqlen_agnostic_kwargs,
Expand Down
1 change: 0 additions & 1 deletion vllm/worker/multi_step_neuron_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ def execute_model(
input_block_ids=model_input.input_block_ids,
**MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs or {},
dtype=self.model_config.dtype,
device=self.device,
),
)
Expand Down
1 change: 0 additions & 1 deletion vllm/worker/multi_step_neuronx_distributed_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ def execute_model(
sampling_params=sampling_params,
**MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs or {},
dtype=self.model_config.dtype,
device=self.device,
),
)
Expand Down
2 changes: 0 additions & 2 deletions vllm/worker/neuron_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,6 @@ def execute_model(
adapter_ids=model_input.adapter_ids,
**MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs or {},
dtype=self.model_config.dtype,
device=self.device,
),
)
Expand All @@ -408,7 +407,6 @@ def execute_model(
input_block_ids=model_input.input_block_ids,
**MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs or {},
dtype=self.model_config.dtype,
device=self.device,
),
)
Expand Down
1 change: 0 additions & 1 deletion vllm/worker/pooling_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,6 @@ def execute_model(
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(
multi_modal_kwargs,
dtype=self.model_config.dtype,
device=self.device,
),
**cross_enc_kwargs,
Expand Down
1 change: 0 additions & 1 deletion vllm/worker/xpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,6 @@ def execute_model(
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs or {},
dtype=self.model_config.dtype,
device=self.device,
),
)
Expand Down