1
1
# SPDX-License-Identifier: Apache-2.0
2
2
from collections .abc import Mapping
3
3
from dataclasses import dataclass
4
- from typing import TYPE_CHECKING , Any , NamedTuple , Optional , Union
4
+ from typing import TYPE_CHECKING , Any , NamedTuple , Optional , Union , cast
5
5
6
6
from transformers import BatchFeature , PretrainedConfig , ProcessorMixin
7
7
from typing_extensions import TypeVar
8
8
9
+ from vllm .jsontree import JSONTree , json_map_leaves
10
+ from vllm .logger import init_logger
9
11
from vllm .transformers_utils .processor import cached_processor_from_config
10
12
from vllm .transformers_utils .tokenizer import AnyTokenizer
11
13
from vllm .utils import resolve_mm_processor_kwargs
12
14
13
15
if TYPE_CHECKING :
16
+ import torch
17
+
14
18
from vllm .config import ModelConfig
15
19
from vllm .multimodal import (MultiModalDataDict , MultiModalPlaceholderDict ,
16
20
MultiModalRegistry )
20
24
_C = TypeVar ("_C" , bound = PretrainedConfig , default = PretrainedConfig )
21
25
_P = TypeVar ("_P" , bound = ProcessorMixin , default = ProcessorMixin )
22
26
27
+ logger = init_logger (__name__ )
28
+
23
29
24
30
@dataclass (frozen = True )
25
31
class InputContext :
@@ -133,7 +139,7 @@ def call_hf_processor(
133
139
hf_processor : ProcessorMixin ,
134
140
data : Mapping [str , object ],
135
141
kwargs : Mapping [str , object ] = {},
136
- ) -> BatchFeature :
142
+ ) -> Union [ BatchFeature , JSONTree [ "torch.Tensor" ]] :
137
143
"""
138
144
Call `hf_processor` on the prompt `data`
139
145
(text, image, audio...) with configurable options `kwargs`.
@@ -155,7 +161,20 @@ def call_hf_processor(
155
161
156
162
try :
157
163
output = hf_processor (** data , ** merged_kwargs , return_tensors = "pt" )
158
- return output .to (dtype = self .model_config .dtype )
164
+ if isinstance (output , BatchFeature ):
165
+ return output .to (dtype = self .model_config .dtype )
166
+
167
+ def maybe_cast_dtype (x : torch .Tensor ):
168
+ # This mimics the behavior of transformers.BatchFeature
169
+ dtype = self .model_config .dtype
170
+ return x .to (dtype = dtype ) if x .is_floating_point () else x
171
+
172
+ logger .warning_once (
173
+ f"{ type (hf_processor ).__name__ } did not return `BatchFeature`. "
174
+ "Make sure to match the behaviour of `ProcessorMixin` when "
175
+ "implementing custom processors." )
176
+ output = cast (JSONTree ["torch.Tensor" ], output )
177
+ return json_map_leaves (maybe_cast_dtype , output )
159
178
except Exception as exc :
160
179
msg = (f"Failed to apply { type (hf_processor ).__name__ } "
161
180
f"on data={ data } with kwargs={ merged_kwargs } " )
0 commit comments