
Commit 69d5e1d

Authored by maleksan85 (Aleksandr Malyshev) and Aleksandr Malyshev
[BUGFIX] Restored handling of ROCM FA output as before adaptation of llama3.2 (#241)
* improved handling of output to be the same as before
* after merge correction

---------

Co-authored-by: Aleksandr Malyshev <[email protected]>
1 parent: 16cedce

1 file changed: +2 −4 lines changed

vllm/attention/backends/rocm_flash_attn.py (+2 −4)
@@ -610,18 +610,17 @@ def forward(
             assert attn_metadata.num_encoder_tokens is not None
             num_prefill_tokens = attn_metadata.num_encoder_tokens
 
+        output = torch.empty_like(query)
         # Query for decode. KV is not needed because it is already cached.
         decode_query = query[num_prefill_tokens:]
-
         # QKV for prefill.
         query = query[:num_prefill_tokens]
+
         if key is not None and value is not None:
             key = key[:num_prefill_tokens]
             value = value[:num_prefill_tokens]
 
         if prefill_meta := attn_metadata.prefill_metadata:
-            output = torch.empty_like(query)
-
             # Prompt run.
             # normal attention and DECODER
             if attn_type == AttentionType.DECODER and (
@@ -738,7 +737,6 @@ def forward(
         if decode_meta := attn_metadata.decode_metadata:
             # Decoding run.
             # Whether to use rocm custom paged attention or not
-            output = torch.empty_like(decode_query)
             num_seqs, num_heads, head_size = decode_query.shape
             block_size = value_cache.shape[3]
             gqa_ratio = num_heads // self.num_kv_heads
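The substance of the change: `output` is allocated once, before the query is split into prefill and decode slices, instead of separately inside each branch. Presumably this matters for batches containing both prefill and decode tokens, where the decode branch's later `torch.empty_like(decode_query)` would replace the buffer already holding the prefill results, and would also be sized for only the decode tokens. A minimal sketch of the restored pattern follows; it is not the vLLM code, the shapes are made up, and the attention kernels are stood in by plain copies, since only the buffer handling is of interest here:

import torch

# Hypothetical batch: a mix of prefill and decode tokens.
num_prefill_tokens, num_decode_tokens = 5, 3
num_heads, head_size = 8, 64
num_tokens = num_prefill_tokens + num_decode_tokens

query = torch.randn(num_tokens, num_heads, head_size)

# Restored pattern: one allocation up front, sized for the whole batch.
output = torch.empty_like(query)

# Mirrors the slicing in the diff above.
decode_query = query[num_prefill_tokens:]
prefill_query = query[:num_prefill_tokens]

# Each branch writes into its own slice of the shared buffer, so running
# both branches on a mixed batch cannot clobber the other's results.
output[:num_prefill_tokens] = prefill_query   # stand-in for the prefill kernel
output[num_prefill_tokens:] = decode_query    # stand-in for paged attention

assert output.shape == query.shape

With a single up-front allocation, neither branch needs to know whether the other ran; each only ever fills its own slice of a buffer sized for the full token batch.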
