@@ -1440,6 +1440,41 @@ def _gather_mm_embeddings(
 
         return mm_embeds, is_mm_embed
 
+    def _get_model_mm_inputs(
+        self,
+        token_ids: torch.Tensor,
+        total_num_scheduled_tokens: Optional[int],
+        scheduler_output: "SchedulerOutput",
+        req_ids: list[str],
+    ) -> tuple[torch.Tensor | None, dict[str, Any] | None]:
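+        """Prepare multimodal inputs for the current step, if supported.
+
+        Runs the multimodal encoder for any scheduled encoder inputs, merges the
+        resulting embeddings with the text-token embeddings via
+        model.embed_input_ids(), and extracts the multimodal kwargs to forward
+        to the model. Returns (inputs_embeds, model_mm_kwargs); both are None
+        when the model does not support multimodal inputs.
+        """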
+        inputs_embeds = None
+        model_mm_kwargs = None
+        if self.supports_mm_inputs:
+            # Run the multimodal encoder if any.
+            with self.profiler.record_event('internal', 'prepare_input_encoders'):
+                self._execute_mm_encoder(scheduler_output, req_ids)
+
+            mm_embeds, is_mm_embed = self._gather_mm_embeddings(
+                scheduler_output, req_ids, total_num_scheduled_tokens=total_num_scheduled_tokens)
+            # TODO: Only get embeddings for valid token_ids. Ignore token_ids[<pad_idxs>]  # noqa
+            # This may require moving multimodal input preps into _prepare_inputs,  # noqa
+            # to avoid padding issues.
+            htorch.core.mark_step()
+            inputs_embeds = self.model.embed_input_ids(
+                token_ids,
+                multimodal_embeddings=mm_embeds,
+                is_multimodal=is_mm_embed,
+            )
+
+            model_mm_kwargs = self._extract_mm_kwargs(scheduler_output)
+            model_mm_kwargs = MultiModalKwargs.as_kwargs(
+                model_mm_kwargs,
+                device=self.device,
+            )
+
+        return inputs_embeds, model_mm_kwargs
+
     def get_model(self) -> torch.nn.Module:
         if isinstance(self.model, HpuModelAdapter):
             return self.model.model
@@ -1644,6 +1679,54 @@ def _align_and_pad_mrope_positions(self, req_ids: list[str], context_lens: list[
             dst_start += target_len
         return mrope_position_tensor
 
+    # modified from: vllm/v1/worker/gpu_model_runner.py:_calc_mrope_positions
+    def get_unified_mrope_position_ids(self, req_ids: list[str], num_computed_tokens: torch.Tensor,
+                                       num_scheduled_tokens: torch.Tensor, target_len: int,
+                                       padding_gen: int) -> torch.Tensor:
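+        """Build the (3, target_len) M-RoPE position tensor for a unified batch.
+
+        Prompt positions are copied from each request's pre-computed
+        mrope_positions; completion positions are generated on the fly from
+        mrope_position_delta. Slots beyond the scheduled tokens keep the
+        padding_gen fill value. The tensor is assembled on CPU and moved to HPU.
+        """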
+        out_shape = (3, target_len)
+        mrope_position_tensor = torch.full(out_shape, padding_gen, dtype=torch.int32, device='cpu')
+        mrope_pos_ptr = 0
+        for index, req_id in enumerate(req_ids):
+            req = self.requests[req_id]
+            assert req.mrope_positions is not None
+
+            context_len = num_computed_tokens[index]
+            query_len = num_scheduled_tokens[index]
+            # The gpu runner uses either prompt_token_ids or prompt_embeds
+            num_prompt_tokens = len(req.prompt_token_ids)
+
+            if context_len + query_len > num_prompt_tokens:
+                prompt_part_len = max(0, num_prompt_tokens - context_len)
+                completion_part_len = max(0, query_len - prompt_part_len)
+            else:
+                prompt_part_len = query_len
+                completion_part_len = 0
+
+            assert query_len == prompt_part_len + completion_part_len
+            if prompt_part_len > 0:
+                # prompt's mrope_positions are pre-computed
+                dst_start = mrope_pos_ptr
+                dst_end = mrope_pos_ptr + prompt_part_len
+                src_start = context_len
+                src_end = context_len + prompt_part_len
+                mrope_position_tensor[:, dst_start:dst_end].copy_(
+                    req.mrope_positions[:, src_start:src_end], non_blocking=True)
+
+                mrope_pos_ptr += prompt_part_len
+            if completion_part_len > 0:
+                # compute completion's mrope_positions on-the-fly
+                dst_start = mrope_pos_ptr
+                dst_end = mrope_pos_ptr + completion_part_len
+                pos_for_mrope = MRotaryEmbedding.get_next_input_positions(
+                    mrope_position_delta=req.mrope_position_delta,
+                    context_len=context_len + prompt_part_len,
+                    seq_len=context_len + prompt_part_len + completion_part_len,
+                )
+                mrope_position_tensor[:, dst_start:dst_end] = torch.tensor(pos_for_mrope, dtype=torch.int32)
+                mrope_pos_ptr += completion_part_len
+
+        return mrope_position_tensor.to('hpu', non_blocking=True)
+
     def _skip_bucketing(self, seq_lens, num_blocks):
         return (len(seq_lens), 0, 0)
 
@@ -2857,10 +2940,19 @@ def prepare_unified_batch(self, scheduler_output):
         self._prepare_input_ids(scheduler_output)
         input_ids_hpu = self.input_ids_hpu
 
-        return create_unified_batch(self.input_batch.req_ids, all_token_ids, num_computed_tokens, num_scheduled_tokens,
-                                    num_prompt_tokens, block_table, self.block_size, self.dtype,
-                                    self.unified_attn_persistent_ctx, self.unified_bucketing_fn, self.get_dp_padding,
-                                    input_ids_hpu, num_decodes)
+        batch = create_unified_batch(self.input_batch.req_ids, all_token_ids, num_computed_tokens, num_scheduled_tokens,
+                                     num_prompt_tokens, block_table, self.block_size, self.dtype,
+                                     self.unified_attn_persistent_ctx, self.unified_bucketing_fn, self.get_dp_padding,
+                                     input_ids_hpu, num_decodes)
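+        # M-RoPE models need 3D (temporal/height/width) position ids; build them
+        # here, padding slots past the scheduled tokens with -1.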
+        if self.uses_mrope:
+            batch.token_positions = self.get_unified_mrope_position_ids(
+                self.input_batch.req_ids,
+                num_computed_tokens,
+                num_scheduled_tokens,
+                target_len=batch.token_ids.size(0),
+                padding_gen=-1,
+            )
+        return batch
 
     @torch.inference_mode()
     def unified_execute_model(self,
@@ -2871,6 +2963,18 @@ def unified_execute_model(self,
         with self.profiler.record_event('internal', 'prepare_unified_batch'):
             batch = self.prepare_unified_batch(scheduler_output)
         htorch.core.mark_step()
+
+        # Prepare multimodal inputs if any.
+        # NOTE(attafosu): token_ids is unsqueezed at dim 0 so that its shape matches
+        # the batch-first layout expected by the model.embed_input_ids() call inside
+        # _get_model_mm_inputs and by downstream model components.
+        inputs_embeds, model_mm_kwargs = self._get_model_mm_inputs(
+            batch.token_ids.unsqueeze(0),
+            batch.token_ids.shape[0],
+            scheduler_output,
+            self.input_batch.req_ids,
+        )
+        htorch.core.mark_step()
+
         if self.is_driver_worker:
             unified_attn_cfg = self._get_unified_config(batch.attn_metadata, batch.logits_indices)
             (phase, qlen, num_shared_blocks, num_unique_blocks, num_logits) = unified_attn_cfg
@@ -2888,6 +2992,8 @@ def unified_execute_model(self,
             kv_caches=self.kv_caches,
             lora_logits_mask=None,
             lora_mask=None,
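+            # Pass the merged multimodal embeddings and kwargs (if any) to the model.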
+            inputs_embeds=inputs_embeds,
+            model_mm_kwargs=model_mm_kwargs,
             warmup_mode=warmup_mode)
         selected_req_ids = [batch.req_ids_cpu[idx] for idx in batch.logits_groups_cpu.tolist()]
         htorch.core.mark_step()
@@ -3127,33 +3233,13 @@ def sample_tokens(self, grammar_output: "GrammarOutput | None") -> ModelRunnerOu
         for idx, (req_id, prompt_len, token_ids, position_ids, attn_metadata, logits_indices,
                   logits_requests) in enumerate(zip(*shallow_tuple(prefill_data))):
 
-            inputs_embeds = None
-            model_mm_kwargs = None
-            if self.supports_mm_inputs:
-                # Run the multimodal encoder if any.
-                with self.profiler.record_event('internal', 'prepare_input_encoders'):
-                    self._execute_mm_encoder(scheduler_output, req_id)
-                htorch.core.mark_step()
-
-                mm_embeds, is_mm_embed = self._gather_mm_embeddings(
-                    scheduler_output, req_id, total_num_scheduled_tokens=token_ids.shape[-1])
-                htorch.core.mark_step()
-
-                # TODO: Only get embeddings for valid token_ids. Ignore token_ids[<pad_idxs>]  # noqa E501
-                # This may require moving multimodal input preps into _prepare_inputs,  # noqa E501
-                # to avoid padding issues.
-                inputs_embeds = self.model.embed_input_ids(
-                    token_ids,
-                    multimodal_embeddings=mm_embeds,
-                    is_multimodal=is_mm_embed,
-                )
-
-                model_mm_kwargs = self._extract_mm_kwargs(scheduler_output)
-                model_mm_kwargs = MultiModalKwargs.as_kwargs(
-                    model_mm_kwargs,
-                    device=self.device,
-                )
+            # Prepare multimodal inputs if any
+            inputs_embeds, model_mm_kwargs = self._get_model_mm_inputs(
+                token_ids,
+                token_ids.shape[-1],
+                scheduler_output,
+                req_id,
+            )
 
             lora_mask, lora_logits_mask = self._configure_lora(token_ids, self.requests, req_id, True)
 