@@ -62,9 +62,10 @@ def find_closest_graph_batch_size(self, batch_size):
6262 else :
6363 return None
6464
65- def _capture_decode (self , decode_func , input_ids : torch . Tensor , infer_state : InferStateInfo ):
65+ def _capture_decode (self , decode_func , infer_state : InferStateInfo ):
6666 dist_group : CustomProcessGroup = infer_state .dist_group
6767 graph_obj = torch .cuda .CUDAGraph ()
68+ input_ids = infer_state .input_ids
6869 batch_size = input_ids .shape [0 ]
6970 infer_state .max_len_in_batch = self .graph_max_len_in_batch
7071 infer_state .total_token_num = self .graph_max_len_in_batch * batch_size
@@ -78,27 +79,26 @@ def _capture_decode(self, decode_func, input_ids: torch.Tensor, infer_state: Inf
7879 # 中的 tensor。
7980 for _ in range (1 ):
8081 torch .cuda .synchronize ()
81- decode_func (input_ids , copy .copy (infer_state ))
82+ decode_func (copy .copy (infer_state ))
8283 torch .cuda .synchronize ()
8384
8485 with lightllm_capture_graph (dist_group ):
8586 with torch .cuda .graph (graph_obj , pool = self .mempool ):
86- model_output = decode_func (input_ids , infer_state )
87- self .graph [batch_size ] = (graph_obj , input_ids , infer_state , model_output )
87+ model_output = decode_func (infer_state )
88+ self .graph [batch_size ] = (graph_obj , infer_state , model_output )
8889 graph_obj .replay ()
8990 return model_output
9091
9192 def _capture_decode_overlap (
9293 self ,
9394 decode_func ,
94- input_ids : torch .Tensor ,
9595 infer_state : InferStateInfo ,
96- input_ids1 : torch .Tensor ,
9796 infer_state1 : InferStateInfo ,
9897 ):
9998 dist_group : CustomProcessGroup = infer_state .dist_group
10099 dist_group1 = infer_state1 .dist_group
101100 graph_obj = torch .cuda .CUDAGraph ()
101+ input_ids = infer_state .input_ids
102102 batch_size = input_ids .shape [0 ]
103103 infer_state .max_len_in_batch = self .graph_max_len_in_batch
104104 infer_state .total_token_num = self .graph_max_len_in_batch * batch_size
@@ -107,17 +107,15 @@ def _capture_decode_overlap(
107107 # warmup
108108 for _ in range (1 ):
109109 torch .cuda .synchronize ()
110- decode_func (input_ids , copy .copy (infer_state ), input_ids1 , copy .copy (infer_state1 ))
110+ decode_func (copy .copy (infer_state ), copy .copy (infer_state1 ))
111111 torch .cuda .synchronize ()
112112 with lightllm_capture_graph (dist_group1 ):
113113 with lightllm_capture_graph (dist_group ):
114114 with torch .cuda .graph (graph_obj , pool = self .mempool ):
115- model_output , model_output1 = decode_func (input_ids , infer_state , input_ids1 , infer_state1 )
115+ model_output , model_output1 = decode_func (infer_state , infer_state1 )
116116 self .graph [batch_size ] = (
117117 graph_obj ,
118- input_ids ,
119118 infer_state ,
120- input_ids1 ,
121119 infer_state1 ,
122120 model_output ,
123121 model_output1 ,
@@ -128,59 +126,50 @@ def _capture_decode_overlap(
128126 def capture_decode (
129127 self ,
130128 decode_func ,
131- input_ids : torch .Tensor ,
132129 infer_state : InferStateInfo ,
133- input_ids1 : Optional [torch .Tensor ] = None ,
134- infer_state1 : Optional [torch .Tensor ] = None ,
130+ infer_state1 : Optional [InferStateInfo ] = None ,
135131 ):
136132 """
137133 Capture the cuda graph for the decoding stage.
138134 input_ids1 and infer_state1 is used for the overlap.
139135 """
140136 if self .enable_decode_microbatch_overlap :
141- return self ._capture_decode_overlap (decode_func , input_ids , infer_state , input_ids1 , infer_state1 )
137+ return self ._capture_decode_overlap (decode_func , infer_state , infer_state1 )
142138 else :
143- assert input_ids1 is None and infer_state1 is None
144- return self ._capture_decode (decode_func , input_ids , infer_state )
139+ assert infer_state1 is None
140+ return self ._capture_decode (decode_func , infer_state )
145141
146- def _replay (self , input_ids : torch .Tensor , infer_state : InferStateInfo ):
147- batch_size = input_ids .shape [0 ]
148- graph_obj , graph_input_ids , graph_infer_state , graph_output = self .graph [batch_size ]
149- graph_input_ids .copy_ (input_ids )
142+ def _replay (self , infer_state : InferStateInfo ):
143+ batch_size = infer_state .input_ids .shape [0 ]
144+ graph_obj , graph_infer_state , graph_output = self .graph [batch_size ]
150145 graph_infer_state .copy_for_cuda_graph (infer_state )
151146 graph_obj .replay ()
152147 return graph_output
153148
154149 def _replay_overlap (
155150 self ,
156- input_ids : torch .Tensor ,
157151 infer_state : InferStateInfo ,
158- input_ids1 : torch .Tensor ,
159152 infer_state1 : InferStateInfo ,
160153 ):
161- batch_size = input_ids .shape [0 ]
154+ batch_size = infer_state . input_ids .shape [0 ]
162155 (
163156 graph_obj ,
164- graph_input_ids ,
165157 graph_infer_state ,
166- graph_input_ids1 ,
167158 graph_infer_state1 ,
168159 graph_model_output ,
169160 graph_model_output1 ,
170161 ) = self .graph [batch_size ]
171- graph_input_ids .copy_ (input_ids )
172162 graph_infer_state .copy_for_cuda_graph (infer_state )
173- graph_input_ids1 .copy_ (input_ids1 )
174163 graph_infer_state1 .copy_for_cuda_graph (infer_state1 )
175164 graph_obj .replay ()
176165 return graph_model_output , graph_model_output1
177166
178- def replay (self , input_ids , infer_state , input_ids1 = None , infer_state1 = None ):
167+ def replay (self , infer_state , infer_state1 = None ):
179168 if self .enable_decode_microbatch_overlap :
180- return self ._replay_overlap (input_ids , infer_state , input_ids1 , infer_state1 )
169+ return self ._replay_overlap (infer_state , infer_state1 )
181170 else :
182- assert input_ids1 is None and infer_state1 is None
183- return self ._replay (input_ids , infer_state )
171+ assert infer_state1 is None
172+ return self ._replay (infer_state )
184173
185174 @torch .no_grad ()
186175 def warmup (self , model ):