@@ -1133,6 +1133,26 @@ def _preprocess_inputs(self, inputs: Dict[str, Any]):
                         self.previous_kv_lens_offsets_cuda[:num_gen_requests])
         return inputs
 
+    def _postprocess_inputs(self, inputs: Dict[str, Any]):
+        """
+        Postprocess to make sure the model forward pass doesn't change the
+        inputs. It is only needed during CUDA graph capture, because all
+        other cases prepare new inputs before the model forward pass.
+        """
+        if self.enable_spec_decode and not self._disable_overlap_scheduler:
+            if inputs['attn_metadata'].kv_cache_manager is not None:
+                num_seqs = inputs['attn_metadata'].num_seqs
+                num_ctx_requests = inputs['attn_metadata'].num_contexts
+                num_gen_requests = inputs['attn_metadata'].num_generations
+                num_ctx_tokens = inputs['attn_metadata'].num_ctx_tokens
+                previous_batch_tokens = inputs['input_ids'].shape[
+                    0] - num_ctx_tokens
+                # Undo the in-place offsets that _preprocess_inputs applied.
+                inputs['position_ids'][0, num_ctx_tokens:] -= (
+                    self.previous_pos_id_offsets_cuda[:previous_batch_tokens])
+                inputs['attn_metadata'].kv_lens_cuda[
+                    num_ctx_requests:num_seqs] -= (
+                        self.previous_kv_lens_offsets_cuda[:num_gen_requests])
+
     def _get_all_rank_num_tokens(self, attn_metadata: AttentionMetadata):
         if self.enable_attention_dp:
             return list(self.dist.tp_allgather(attn_metadata.num_tokens))
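The new `_postprocess_inputs` is the exact inverse of the offset arithmetic in `_preprocess_inputs`: preprocess adds the cached `previous_pos_id_offsets_cuda` / `previous_kv_lens_offsets_cuda` in place, and postprocess subtracts them again. This round trip matters for CUDA graph capture, where the graph later replays against the very same input buffers. Below is a minimal, self-contained sketch of that invariant; `OffsetRoundTrip`, its field names, and the tensor shapes are hypothetical illustrations, not the TensorRT-LLM API.

```python
import torch


class OffsetRoundTrip:
    """Hypothetical stand-in for the preprocess/postprocess offset pair."""

    def __init__(self, max_tokens: int):
        # Offsets cached from the previous batch (illustrative values).
        self.previous_pos_id_offsets = torch.arange(max_tokens)

    def preprocess(self, inputs: dict) -> dict:
        # Shift position ids in place, as _preprocess_inputs does.
        n = inputs['position_ids'].shape[-1]
        inputs['position_ids'][0, :] += self.previous_pos_id_offsets[:n]
        return inputs

    def postprocess(self, inputs: dict) -> None:
        # Exact inverse: subtract the same offsets in place.
        n = inputs['position_ids'].shape[-1]
        inputs['position_ids'][0, :] -= self.previous_pos_id_offsets[:n]


rt = OffsetRoundTrip(max_tokens=8)
orig = torch.arange(8).unsqueeze(0)
inputs = {'position_ids': orig.clone()}
rt.postprocess(rt.preprocess(inputs))
# The pair is a no-op, so the buffers the CUDA graph captured are restored.
assert torch.equal(inputs['position_ids'], orig)
```

Without the postprocess step, a capture-time forward would leave the offsets baked into the shared buffers, and every later replay (which reuses those buffers) would see doubly shifted position ids and kv lengths.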
@@ -2206,8 +2226,12 @@ def capture_forward_fn(inputs: Dict[str, Any]):
                         gather_ids=gather_ids,
                         gather_context_logits=gather_context_logits)
 
+            def capture_postprocess_fn(inputs: Dict[str, Any]):
+                self._postprocess_inputs(inputs)
+
             self.cuda_graph_runner.capture(batch_size,
-                                           capture_forward_fn, inputs)
+                                           capture_forward_fn, inputs,
+                                           capture_postprocess_fn)
 
         # here we don't need to use the context since cuda graph capture didn't run any kernels.
         # maybe we need a cleaner way to do this.
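For context, here is a hedged sketch of how a graph runner could thread the new optional postprocess callback through capture. The actual `CUDAGraphRunner.capture` signature in TensorRT-LLM may differ; the warm-up count and stream handling below are illustrative assumptions, and running it requires a CUDA device. The key point is that the callback fires after every forward that mutates the inputs, both during warm-up and after the captured pass, so the buffers are identical for every subsequent replay.

```python
from typing import Any, Callable, Dict, Optional

import torch


def capture(batch_size: int,
            forward_fn: Callable[[Dict[str, Any]], Any],
            inputs: Dict[str, Any],
            postprocess_fn: Optional[Callable[[Dict[str, Any]], None]] = None):
    """Hypothetical capture helper, not the real CUDAGraphRunner.capture."""
    graph = torch.cuda.CUDAGraph()

    # Warm up on a side stream before capture, as the usual torch pattern does.
    stream = torch.cuda.Stream()
    stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(stream):
        for _ in range(2):
            forward_fn(inputs)
            if postprocess_fn is not None:
                postprocess_fn(inputs)  # restore in-place mutated inputs
    torch.cuda.current_stream().wait_stream(stream)

    # Record the graph against the (restored) input buffers.
    with torch.cuda.graph(graph):
        output = forward_fn(inputs)
    if postprocess_fn is not None:
        postprocess_fn(inputs)  # inputs must match on every later replay

    return graph, output
```

Passing the callback into the runner, rather than calling `_postprocess_inputs` at the call site, keeps the restore step next to each forward the runner issues, so warm-up iterations cannot accumulate offsets either.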