Commit 8d92864

attafosu and Copilot authored
multimodal support for unified attn (#423)
- Enables multimodal support for unified attention

Signed-off-by: attafosu <[email protected]>
Signed-off-by: Thomas Atta-Fosu <[email protected]>
Co-authored-by: Copilot <[email protected]>
1 parent 8ef73e6 commit 8d92864

File tree

3 files changed: +132 -31 lines changed

- tests/full_tests/ci_gsm8k_tests.sh
- vllm_gaudi/ops/hpu_rotary_embedding.py
- vllm_gaudi/v1/worker/hpu_model_runner.py

tests/full_tests/ci_gsm8k_tests.sh

Lines changed: 9 additions & 0 deletions
@@ -218,6 +218,14 @@ run_qwen2_5_vl_test() {
     echo "✅ Test with multimodal-support with qwen2.5-vl-7b passed."
 }
 
+# Multimodal-support + unified attention with qwen2.5-vl
+run_qwen2_5_vl_unified_attn_test() {
+    echo "➡️ Testing Qwen2.5-VL-7B with unified attention..."
+    VLLM_SKIP_WARMUP=true VLLM_UNIFIED_ATTN=True PT_HPU_LAZY_MODE=1 VLLM_USE_V1=1 \
+        python -u "${VLLM_GAUDI_PREFIX}/tests/models/language/generation/generation_mm.py" --model-card-path "${VLLM_GAUDI_PREFIX}/tests/full_tests/model_cards/qwen2.5-vl-7b.yaml"
+    echo "✅ Test multimodal-support + unified attention with qwen2.5-vl-7b passed."
+}
+
 # Spec decode with ngram
 run_spec_decode_ngram_test() {
     echo "➡️ Testing Spec-decode with ngram..."

@@ -292,6 +300,7 @@ launch_all_tests() {
     run_gsm8k_deepseek_test
     run_gsm8k_qwen3_30b_test
     run_qwen2_5_vl_test
+    run_qwen2_5_vl_unified_attn_test
     run_spec_decode_ngram_test
     run_spec_decode_eagle3_test
     run_spec_decode_eagle3_num_spec_2_test

vllm_gaudi/ops/hpu_rotary_embedding.py

Lines changed: 6 additions & 0 deletions
@@ -657,6 +657,12 @@ def forward_oot(
     ) -> tuple[torch.Tensor, torch.Tensor]:
         from habana_frameworks.torch.hpex.kernels import (RotaryPosEmbeddingMode, apply_rotary_pos_emb)
 
+        # NOTE (attafosu): positions is expected to be a 2D tensor [3, seq_len],
+        # but the unified_attention API sends it as 3D [3, seq_len, 1],
+        # so we flatten it to 2D here.
+        if positions.ndim == 3:
+            assert positions.shape[-1] == 1, "Expected last dimension to be 1 for 3d positions"
+            positions = positions.squeeze(-1)
         num_tokens = positions.shape[-1]
         cos_sin = self.cos_sin_cache[positions]
         cos, sin = cos_sin.chunk(2, dim=-1)
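
For reference, a minimal standalone sketch (shapes invented for this note) of what the new guard does: a 3D M-RoPE position tensor of shape [3, seq_len, 1], as handed over by the unified-attention path, is squeezed back to the 2D [3, seq_len] layout that the rotary code indexes the cos/sin cache with.

import torch

# Hypothetical positions as sent by the unified-attention path: [3, seq_len, 1].
positions = torch.arange(3 * 7).reshape(3, 7, 1)

if positions.ndim == 3:
    assert positions.shape[-1] == 1, "Expected last dimension to be 1 for 3d positions"
    positions = positions.squeeze(-1)  # back to the expected 2D layout [3, seq_len]

num_tokens = positions.shape[-1]
print(positions.shape, num_tokens)  # torch.Size([3, 7]) 7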

vllm_gaudi/v1/worker/hpu_model_runner.py

Lines changed: 117 additions & 31 deletions
@@ -1440,6 +1440,41 @@ def _gather_mm_embeddings(
 
         return mm_embeds, is_mm_embed
 
+    def _get_model_mm_inputs(
+        self,
+        token_ids: torch.Tensor,
+        total_num_scheduled_tokens: Optional[int],
+        scheduler_output: "SchedulerOutput",
+        req_ids: list[str],
+    ) -> tuple[torch.Tensor | None, dict[str, Any] | None]:
+        inputs_embeds = None
+        model_mm_kwargs = None
+        if self.supports_mm_inputs:
+            # Run the multimodal encoder if any.
+            with self.profiler.record_event('internal', 'prepare_input_encoders'):
+                self._execute_mm_encoder(scheduler_output, req_ids)
+
+            mm_embeds, is_mm_embed = self._gather_mm_embeddings(scheduler_output,
+                                                                req_ids,
+                                                                total_num_scheduled_tokens=total_num_scheduled_tokens)
+            # TODO: Only get embeddings for valid token_ids. Ignore token_ids[<pad_idxs>] # noqa
+            # This may require moving multimodal input preps into _prepare_inputs, # noqa
+            # to avoid padding issues.
+            htorch.core.mark_step()
+            inputs_embeds = self.model.embed_input_ids(
+                token_ids,
+                multimodal_embeddings=mm_embeds,
+                is_multimodal=is_mm_embed,
+            )
+
+            model_mm_kwargs = self._extract_mm_kwargs(scheduler_output)
+            model_mm_kwargs = MultiModalKwargs.as_kwargs(
+                model_mm_kwargs,
+                device=self.device,
+            )
+
+        return inputs_embeds, model_mm_kwargs
+
     def get_model(self) -> torch.nn.Module:
         if isinstance(self.model, HpuModelAdapter):
             return self.model.model
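
The important contract of the new helper is the returned (inputs_embeds, model_mm_kwargs) pair: text token ids are embedded, and the slots flagged by is_mm_embed are filled with the encoder's multimodal embeddings. The toy sketch below approximates that merge with plain tensors; the internals of embed_input_ids, the mask layout, and all shapes here are illustrative assumptions, not the vLLM implementation.

import torch

hidden, vocab = 8, 32
embed_table = torch.randn(vocab, hidden)             # stand-in for the text embedding table

token_ids = torch.tensor([[1, 2, 3, 4, 5, 6]])       # [1, seq_len], batch-first
is_mm_embed = torch.tensor([[False, True, True, False, False, False]])  # image placeholder slots
mm_embeds = torch.randn(2, hidden)                    # one row per placeholder, from the mm encoder

inputs_embeds = embed_table[token_ids]                # [1, seq_len, hidden] text embeddings
inputs_embeds[is_mm_embed] = mm_embeds                # overwrite placeholder slots with image features

print(inputs_embeds.shape)                            # torch.Size([1, 6, 8])
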
@@ -1644,6 +1679,54 @@ def _align_and_pad_mrope_positions(self, req_ids: list[str], context_lens: list[
             dst_start += target_len
         return mrope_position_tensor
 
+    # modified from: vllm/v1/worker/gpu_model_runner.py:_calc_mrope_positions
+    def get_unified_mrope_position_ids(self, req_ids: list[str], num_computed_tokens: torch.Tensor,
+                                       num_scheduled_tokens: torch.Tensor, target_len: int,
+                                       padding_gen: int) -> torch.Tensor:
+        out_shape = (3, target_len)
+        mrope_position_tensor = torch.full(out_shape, padding_gen, dtype=torch.int32, device='cpu')
+        mrope_pos_ptr = 0
+        for index, req_id in enumerate(req_ids):
+            req = self.requests[req_id]
+            assert req.mrope_positions is not None
+
+            context_len = num_computed_tokens[index]
+            query_len = num_scheduled_tokens[index]
+            num_prompt_tokens = len(
+                req.prompt_token_ids)  # The gpu runner uses either prompt_token_ids or prompt_embeds # noqa: E501
+
+            if context_len + query_len > num_prompt_tokens:
+                prompt_part_len = max(0, num_prompt_tokens - context_len)
+                completion_part_len = max(0, query_len - prompt_part_len)
+            else:
+                prompt_part_len = query_len
+                completion_part_len = 0
+
+            assert query_len == prompt_part_len + completion_part_len
+            if prompt_part_len > 0:
+                # prompt's mrope_positions are pre-computed
+                dst_start = mrope_pos_ptr
+                dst_end = mrope_pos_ptr + prompt_part_len
+                src_start = context_len
+                src_end = context_len + prompt_part_len
+                mrope_position_tensor[:, dst_start:dst_end].copy_(req.mrope_positions[:, src_start:src_end],
+                                                                  non_blocking=True)
+
+                mrope_pos_ptr += prompt_part_len
+            if completion_part_len > 0:
+                # compute completion's mrope_positions on-the-fly
+                dst_start = mrope_pos_ptr
+                dst_end = mrope_pos_ptr + completion_part_len
+                pos_for_mrope = MRotaryEmbedding.get_next_input_positions(
+                    mrope_position_delta=req.mrope_position_delta,
+                    context_len=context_len + prompt_part_len,
+                    seq_len=context_len + prompt_part_len + completion_part_len,
+                )
+                mrope_position_tensor[:, dst_start:dst_end] = torch.tensor(pos_for_mrope, dtype=torch.int32)
+                mrope_pos_ptr += completion_part_len
+
+        return mrope_position_tensor.to('hpu', non_blocking=True)
+
     def _skip_bucketing(self, seq_lens, num_blocks):
         return (len(seq_lens), 0, 0)
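
The prompt/completion split above decides which M-RoPE positions can simply be copied from the request's precomputed mrope_positions and which must be generated for fresh decode tokens. With invented numbers (not taken from this PR), the same arithmetic looks like this:

# Illustrative values only: a 10-token prompt, 8 tokens already computed,
# 5 tokens scheduled in this step.
num_prompt_tokens = 10
context_len = 8
query_len = 5

if context_len + query_len > num_prompt_tokens:
    prompt_part_len = max(0, num_prompt_tokens - context_len)   # 2 positions copied from the prompt
    completion_part_len = max(0, query_len - prompt_part_len)   # 3 positions computed on the fly
else:
    prompt_part_len = query_len
    completion_part_len = 0

assert query_len == prompt_part_len + completion_part_len
print(prompt_part_len, completion_part_len)  # 2 3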

@@ -2857,10 +2940,19 @@ def prepare_unified_batch(self, scheduler_output):
         self._prepare_input_ids(scheduler_output)
         input_ids_hpu = self.input_ids_hpu
 
-        return create_unified_batch(self.input_batch.req_ids, all_token_ids, num_computed_tokens, num_scheduled_tokens,
-                                    num_prompt_tokens, block_table, self.block_size, self.dtype,
-                                    self.unified_attn_persistent_ctx, self.unified_bucketing_fn, self.get_dp_padding,
-                                    input_ids_hpu, num_decodes)
+        batch = create_unified_batch(self.input_batch.req_ids, all_token_ids, num_computed_tokens, num_scheduled_tokens,
+                                     num_prompt_tokens, block_table, self.block_size, self.dtype,
+                                     self.unified_attn_persistent_ctx, self.unified_bucketing_fn, self.get_dp_padding,
+                                     input_ids_hpu, num_decodes)
+        if self.uses_mrope:
+            batch.token_positions = self.get_unified_mrope_position_ids(
+                self.input_batch.req_ids,
+                num_computed_tokens,
+                num_scheduled_tokens,
+                target_len=batch.token_ids.size(0),
+                padding_gen=-1,
+            )
+        return batch
 
     @torch.inference_mode()
     def unified_execute_model(self,
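
A toy illustration (values invented for this note) of the buffer that prepare_unified_batch now attaches as batch.token_positions: a [3, target_len] tensor pre-filled with padding_gen, where only the columns that belong to scheduled tokens receive real positions. In the real tensor the three rows carry the temporal/height/width components and generally differ; they are kept identical here for brevity.

import torch

target_len = 8     # bucketed/padded token count, e.g. batch.token_ids.size(0)
padding_gen = -1   # padding value for unused slots
real_tokens = 5    # tokens actually scheduled across all requests

token_positions = torch.full((3, target_len), padding_gen, dtype=torch.int32)
# Pretend the per-request loop filled the first `real_tokens` columns.
token_positions[:, :real_tokens] = torch.arange(real_tokens, dtype=torch.int32)

print(token_positions)
# tensor([[ 0,  1,  2,  3,  4, -1, -1, -1],
#         [ 0,  1,  2,  3,  4, -1, -1, -1],
#         [ 0,  1,  2,  3,  4, -1, -1, -1]], dtype=torch.int32)
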
@@ -2871,6 +2963,18 @@
         with self.profiler.record_event('internal', 'prepare_unified_batch'):
             batch = self.prepare_unified_batch(scheduler_output)
         htorch.core.mark_step()
+
+        # Prepare multimodal inputs if any
+        inputs_embeds, model_mm_kwargs = self._get_model_mm_inputs(
+            batch.token_ids.unsqueeze(
+                0  # NOTE(attafosu): unsqueeze at dim 0 so the input token shape matches the batch-first format expected by the model.embed_input_ids() call in _get_model_mm_inputs and downstream model components.  # noqa E501
+            ),
+            batch.token_ids.shape[0],
+            scheduler_output,
+            self.input_batch.req_ids,
+        )
+        htorch.core.mark_step()
+
         if self.is_driver_worker:
             unified_attn_cfg = self._get_unified_config(batch.attn_metadata, batch.logits_indices)
             (phase, qlen, num_shared_blocks, num_unique_blocks, num_logits) = unified_attn_cfg
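
The unsqueeze(0) above is purely a reshape: it turns the unified batch's flat token-id vector into the batch-first layout that embed_input_ids expects. A trivial sketch with invented values:

import torch

token_ids = torch.tensor([101, 7592, 2088, 0, 0])  # flat [num_tokens] layout used by the unified batch
batched = token_ids.unsqueeze(0)                    # [1, num_tokens], batch-first

print(token_ids.shape, batched.shape)               # torch.Size([5]) torch.Size([1, 5])
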
@@ -2888,6 +2992,8 @@
             kv_caches=self.kv_caches,
             lora_logits_mask=None,
             lora_mask=None,
+            inputs_embeds=inputs_embeds,
+            model_mm_kwargs=model_mm_kwargs,
             warmup_mode=warmup_mode)
         selected_req_ids = [batch.req_ids_cpu[idx] for idx in batch.logits_groups_cpu.tolist()]
         htorch.core.mark_step()

@@ -3127,33 +3233,13 @@ def sample_tokens(self, grammar_output: "GrammarOutput | None") -> ModelRunnerOu
         for idx, (req_id, prompt_len, token_ids, position_ids, attn_metadata, logits_indices,
                   logits_requests) in enumerate(zip(*shallow_tuple(prefill_data))):
 
-            inputs_embeds = None
-            model_mm_kwargs = None
-            if self.supports_mm_inputs:
-                # Run the multimodal encoder if any.
-                with self.profiler.record_event('internal', 'prepare_input_encoders'):
-                    self._execute_mm_encoder(scheduler_output, req_id)
-                htorch.core.mark_step()
-
-                mm_embeds, is_mm_embed = self._gather_mm_embeddings(scheduler_output,
-                                                                    req_id,
-                                                                    total_num_scheduled_tokens=token_ids.shape[-1])
-                htorch.core.mark_step()
-
-                # TODO: Only get embeddings for valid token_ids. Ignore token_ids[<pad_idxs>] # noqa E501
-                # This may require moving multimodal input preps into _prepare_inputs, # noqa E501
-                # to avoid padding issues.
-                inputs_embeds = self.model.embed_input_ids(
-                    token_ids,
-                    multimodal_embeddings=mm_embeds,
-                    is_multimodal=is_mm_embed,
-                )
-
-                model_mm_kwargs = self._extract_mm_kwargs(scheduler_output)
-                model_mm_kwargs = MultiModalKwargs.as_kwargs(
-                    model_mm_kwargs,
-                    device=self.device,
-                )
+            # Prepare multimodal inputs if any
+            inputs_embeds, model_mm_kwargs = self._get_model_mm_inputs(
+                token_ids,
+                token_ids.shape[-1],
+                scheduler_output,
+                req_id,
+            )
 
             lora_mask, lora_logits_mask = self._configure_lora(token_ids, self.requests, req_id, True)
 