Commit 69d5e1d 1 parent 16cedce commit 69d5e1d Copy full SHA for 69d5e1d
File tree 1 file changed +2
-4
lines changed
1 file changed +2
-4
lines changed Original file line number Diff line number Diff line change @@ -610,18 +610,17 @@ def forward(
610
610
assert attn_metadata .num_encoder_tokens is not None
611
611
num_prefill_tokens = attn_metadata .num_encoder_tokens
612
612
613
+ output = torch .empty_like (query )
613
614
# Query for decode. KV is not needed because it is already cached.
614
615
decode_query = query [num_prefill_tokens :]
615
-
616
616
# QKV for prefill.
617
617
query = query [:num_prefill_tokens ]
618
+
618
619
if key is not None and value is not None :
619
620
key = key [:num_prefill_tokens ]
620
621
value = value [:num_prefill_tokens ]
621
622
622
623
if prefill_meta := attn_metadata .prefill_metadata :
623
- output = torch .empty_like (query )
624
-
625
624
# Prompt run.
626
625
# normal attention and DECODER
627
626
if attn_type == AttentionType .DECODER and (
@@ -738,7 +737,6 @@ def forward(
738
737
if decode_meta := attn_metadata .decode_metadata :
739
738
# Decoding run.
740
739
# Whether to use rocm custom paged attention or not
741
- output = torch .empty_like (decode_query )
742
740
num_seqs , num_heads , head_size = decode_query .shape
743
741
block_size = value_cache .shape [3 ]
744
742
gqa_ratio = num_heads // self .num_kv_heads
You can’t perform that action at this time.
0 commit comments