Commit a25576c

Support processors that do not return BatchFeature

Signed-off-by: Lukas Geiger <[email protected]>
Parent: f4abcf2

1 file changed: +22 additions, -3 deletions

vllm/inputs/registry.py

Lines changed: 22 additions & 3 deletions
@@ -1,16 +1,20 @@
 # SPDX-License-Identifier: Apache-2.0
 from collections.abc import Mapping
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union
+from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union, cast
 
 from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
 from typing_extensions import TypeVar
 
+from vllm.jsontree import JSONTree, json_map_leaves
+from vllm.logger import init_logger
 from vllm.transformers_utils.processor import cached_processor_from_config
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils import resolve_mm_processor_kwargs
 
 if TYPE_CHECKING:
+    import torch
+
     from vllm.config import ModelConfig
     from vllm.multimodal import (MultiModalDataDict, MultiModalPlaceholderDict,
                                  MultiModalRegistry)
@@ -20,6 +24,8 @@
 _C = TypeVar("_C", bound=PretrainedConfig, default=PretrainedConfig)
 _P = TypeVar("_P", bound=ProcessorMixin, default=ProcessorMixin)
 
+logger = init_logger(__name__)
+
 
 @dataclass(frozen=True)
 class InputContext:
@@ -133,7 +139,7 @@ def call_hf_processor(
         hf_processor: ProcessorMixin,
         data: Mapping[str, object],
         kwargs: Mapping[str, object] = {},
-    ) -> BatchFeature:
+    ) -> Union[BatchFeature, JSONTree["torch.Tensor"]]:
         """
         Call `hf_processor` on the prompt `data`
         (text, image, audio...) with configurable options `kwargs`.
@@ -155,7 +161,20 @@ def call_hf_processor(
 
         try:
             output = hf_processor(**data, **merged_kwargs, return_tensors="pt")
-            return output.to(dtype=self.model_config.dtype)
+            if isinstance(output, BatchFeature):
+                return output.to(dtype=self.model_config.dtype)
+
+            def maybe_cast_dtype(x: torch.Tensor):
+                # This mimics the behavior of transformers.BatchFeature
+                dtype = self.model_config.dtype
+                return x.to(dtype=dtype) if x.is_floating_point() else x
+
+            logger.warning_once(
+                f"{type(hf_processor).__name__} did not return `BatchFeature`. "
+                "Make sure to match the behaviour of `ProcessorMixin` when "
+                "implementing custom processors.")
+            output = cast(JSONTree["torch.Tensor"], output)
+            return json_map_leaves(maybe_cast_dtype, output)
         except Exception as exc:
             msg = (f"Failed to apply {type(hf_processor).__name__} "
                    f"on data={data} with kwargs={merged_kwargs}")
