2929from vllm_spyre .platform import SpyrePlatform
3030# yapf conflicts with ruff for this block
3131# yapf: disable
32- from vllm_spyre .v1 .worker .spyre_input_batch import (BaseInputBatch ,
33- BaseRequestState ,
34- PoolingInputBatch ,
35- PoolingRequestState ,
36- SamplingInputBatch ,
37- SamplingRequestState )
32+ from vllm_spyre .v1 .worker .spyre_input_batch import (
33+ BaseInputBatch , BaseRequestState , PoolingInputBatch , PoolingRequestState ,
34+ SamplingInputBatch , SamplingRequestState , get_builtin_logits_processors )
3835
3936# yapf: enable
4037if TYPE_CHECKING :
@@ -306,12 +303,17 @@ def load_model(self, prompt_lens: Iterable[int],
306303 )
307304
308305 def build_input_batch (self ) -> SamplingInputBatch :
306+ # Define logits processors.
307+ # TODO(Max): logits processor list should be extensible via engine
308+ # constructor argument; for now the list is fixed to builtin processors
309+ logits_processors = get_builtin_logits_processors (self .vllm_config )
309310 return SamplingInputBatch (
310311 max_num_reqs = self .scheduler_config .max_num_seqs ,
311312 max_model_len = self .model_config .max_model_len ,
312313 device = self .device ,
313314 pin_memory = self .pin_memory ,
314315 vocab_size = self .model_config .get_vocab_size (),
316+ logitsprocs = logits_processors ,
315317 )
316318
317319 @property
@@ -810,8 +812,7 @@ def __init__(
810812 max_model_len = vllm_config .model_config .max_model_len ,
811813 device = self .device ,
812814 pin_memory = self .pin_memory ,
813- vocab_size = vllm_config .model_config .get_vocab_size (),
814- )
815+ vocab_size = vllm_config .model_config .get_vocab_size ())
815816
816817 def pre_warmup (self ) -> None :
817818 # Set the number of kv cache blocks to the minimal value of 2 which is
@@ -1351,9 +1352,18 @@ def build_input_batch(self) -> PoolingInputBatch:
13511352 def load_model (self , prompt_lens : Iterable [int ],
13521353 num_decode_tokens : Iterable [int ]) -> None :
13531354
1354- if self .model_config .task == "embed" :
1355+ task = self .model_config .task
1356+ if task is None :
1357+ # Task is being deprecated upstream because the models
1358+ # support several tasks at once. But for now, here we need
1359+ # to know the task to load the model with
1360+ # AutoModelForSequenceClassification
1361+ task = self .model_config ._get_default_pooling_task (
1362+ self .model_config .architectures )
1363+
1364+ if task == "embed" :
13551365 self .model = AutoModel .from_pretrained (self .model_config .model )
1356- elif self . model_config . task == "classify" :
1366+ elif task == "classify" :
13571367 class_model = AutoModelForSequenceClassification .from_pretrained (
13581368 self .model_config .model )
13591369 if hasattr (class_model , "bert" ):
@@ -1368,7 +1378,7 @@ def load_model(self, prompt_lens: Iterable[int],
13681378 "Bert or Roberta for sequence classification" )
13691379 self .classifier = class_model .classifier
13701380 else :
1371- raise ValueError (f"Unsupported task { self . model_config . task } " )
1381+ raise ValueError (f"Unsupported task { task } " )
13721382
13731383 model_class_name = type (self .model ).__name__
13741384 self .is_roberta = "roberta" in model_class_name .lower ()
@@ -1393,7 +1403,7 @@ def load_model(self, prompt_lens: Iterable[int],
13931403 dynamic = False ,
13941404 backend = envs_spyre .VLLM_SPYRE_DYNAMO_BACKEND )
13951405
1396- if self . model_config . task == "classify" :
1406+ if task == "classify" :
13971407 tokenizer = AutoTokenizer .from_pretrained (self .model_config .model )
13981408 output = tokenizer (text = "foo" , text_pair = "bar" )
13991409 self .use_token_type_ids = "token_type_ids" in output
@@ -1404,13 +1414,13 @@ def load_model(self, prompt_lens: Iterable[int],
14041414 if hasattr (Pooler , "from_config_with_defaults" ):
14051415 # TODO: remove this when we no longer support
14061416 # vllm version v0.9.2
1407- if self . model_config . task == "embed" :
1417+ if task == "embed" :
14081418 self .pooler = Pooler .from_config_with_defaults (
14091419 pooler_config ,
14101420 pooling_type = PoolingType .CLS ,
14111421 normalize = True ,
14121422 softmax = False )
1413- elif self . model_config . task == "classify" :
1423+ elif task == "classify" :
14141424 self .pooler = ClassifierPooler (config = self .model_config ,
14151425 pooler = self ._pooler ,
14161426 classifier = self .classifier )
@@ -1428,10 +1438,10 @@ def load_model(self, prompt_lens: Iterable[int],
14281438 if 'default_pooling_type' in annotations :
14291439 extra_args ['default_pooling_type' ] = PoolingType .CLS
14301440
1431- if self . model_config . task == "embed" :
1441+ if task == "embed" :
14321442 self .pooler = Pooler .for_embed (pooler_config = pooler_config ,
14331443 ** extra_args )
1434- elif self . model_config . task == "classify" :
1444+ elif task == "classify" :
14351445 self .pooler = ClassifierPooler (
14361446 pooling = self ._pooler ,
14371447 classifier = self .classifier ,
0 commit comments