@@ -152,12 +152,11 @@ def get_output(self) -> ModelRunnerOutput:
         # Release the device tensor once the copy has completed
         self._async_copy_ready_event.synchronize()

-        sampled_token_ids_np = self._sampled_token_ids_cpu.numpy()
-        valid_sampled_token_ids = [sampled_token_ids_np[i] for i in range(len(sampled_token_ids_np))]
+        valid_sampled_token_ids = self._sampled_token_ids_cpu.tolist()
         del self._sampled_token_ids
         for i in self._invalid_req_indices:
             if i < len(valid_sampled_token_ids):
-                valid_sampled_token_ids[i] = np.array([], dtype=np.int32)
+                valid_sampled_token_ids[i].clear()

         output = self._model_runner_output
         output.sampled_token_ids[:len(valid_sampled_token_ids)] = valid_sampled_token_ids
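A minimal sketch (not the patched vLLM code itself; tensor contents and indices are made up) of the pattern this hunk adopts: `Tensor.tolist()` turns the copied-back CPU tensor into nested Python lists in a single call, and `list.clear()` empties the rows of invalid requests in place instead of overwriting them with empty numpy arrays.

```python
import torch

# Assumed shape: one sampled token per request, (num_reqs, 1) on the CPU.
sampled_token_ids_cpu = torch.tensor([[11], [22], [33]])

# A single call converts the whole tensor into nested Python lists.
valid_sampled_token_ids = sampled_token_ids_cpu.tolist()  # [[11], [22], [33]]

# Invalid requests are emptied in place; consumers now see [] instead of
# an empty np.ndarray.
invalid_req_indices = [1]  # hypothetical example indices
for i in invalid_req_indices:
    if i < len(valid_sampled_token_ids):
        valid_sampled_token_ids[i].clear()

print(valid_sampled_token_ids)  # [[11], [], [33]]
```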
@@ -2903,7 +2902,7 @@ def unified_execute_model(self,
             self.input_batch.req_id_to_index.copy()

         with self.profiler.record_event('internal', 'unified_postprocess'):
-            sampled_token_ids: list[np.ndarray] = [np.array([], dtype=np.int32) for _ in batch.req_ids_cpu]
+            sampled_token_ids: list[list[int]] = [[] for _ in batch.req_ids_cpu]
             if self.use_async_scheduling:
                 sampled_token_ids_hpu = sampler_output.sampled_token_ids.view(-1, 1)
                 self.input_batch.prev_sampled_token_ids = sampled_token_ids_hpu.flatten()
@@ -2916,11 +2915,8 @@ def unified_execute_model(self,
                 }
             else:
                 sampled_token_ids_cpu = sampler_output.sampled_token_ids.cpu()
-
-                sampled_token_ids_np = sampled_token_ids_cpu.numpy()
-                for req_id, tokens_array in zip(selected_req_ids, sampled_token_ids_np):
-                    idx = self.input_batch.req_id_to_index[req_id]
-                    sampled_token_ids[idx] = tokens_array
+                for req_id, tokens in zip(selected_req_ids, sampled_token_ids_cpu.tolist()):
+                    sampled_token_ids[self.input_batch.req_id_to_index[req_id]].extend(tokens)

                 #TODO: add support for multi-token output
                 assert sampled_token_ids_cpu.size(1) == 1, 'Currently only single token output is supported!'
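As a self-contained illustration of the bucketing done in this hunk (request ids, indices, and sampler output below are made up): one empty list is pre-allocated per request in the batch, and each selected request's row of the CPU tensor is appended to its bucket via the `req_id_to_index` mapping.

```python
import torch

# Hypothetical batch: three requests, two of which were sampled this step.
req_id_to_index = {"req-a": 0, "req-b": 1, "req-c": 2}
selected_req_ids = ["req-b", "req-c"]
sampled_token_ids_cpu = torch.tensor([[7], [9]])  # one token per selected request

# One empty bucket per request in the batch.
sampled_token_ids: list[list[int]] = [[] for _ in req_id_to_index]

# tolist() yields one Python list per row; extend the matching bucket in place.
for req_id, tokens in zip(selected_req_ids, sampled_token_ids_cpu.tolist()):
    sampled_token_ids[req_id_to_index[req_id]].extend(tokens)

print(sampled_token_ids)  # [[], [7], [9]]
```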
@@ -2941,7 +2937,7 @@ def unified_execute_model(self,
                 num_tokens = len(token_ids)
                 self.input_batch.token_ids_cpu[i, seq_len:seq_len + num_tokens] = token_ids
                 self.input_batch.num_tokens[i] += len(token_ids)
-                req_state.output_token_ids.extend(token_ids.tolist())
+                req_state.output_token_ids.extend(token_ids)

         if self.use_async_scheduling:
             model_runner_output = ModelRunnerOutput(
@@ -3366,9 +3362,7 @@ def sample_tokens(self, grammar_output: "GrammarOutput | None") -> ModelRunnerOu
             self.input_batch.req_id_to_index.copy()

         max_req_index = max(self.input_batch.req_id_to_index.values())
-        postprocessed_sampled_token_ids: list[np.ndarray] = [
-            np.array([], dtype=np.int32) for _ in range(max_req_index + 1)
-        ]
+        postprocessed_sampled_token_ids: list[list[int]] = [[] for _ in range(max_req_index + 1)]
         if self.use_async_scheduling:
             self.input_batch.prev_sampled_token_ids = sampled_token_ids.flatten()
             # self.input_batch.prev_sampled_token_ids_invalid_indices
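A short sketch of the pre-allocation above (index map values are made up): sizing the list by the largest index plus one keeps every slot addressable even if the index map is not dense, which `len(req_id_to_index)` would not guarantee.

```python
# Hypothetical, possibly non-dense index map.
req_id_to_index = {"req-a": 0, "req-c": 3}
max_req_index = max(req_id_to_index.values())
postprocessed_sampled_token_ids: list[list[int]] = [[] for _ in range(max_req_index + 1)]
print(len(postprocessed_sampled_token_ids))  # 4: slot 3 is addressable
```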
@@ -3392,10 +3386,9 @@ def sample_tokens(self, grammar_output: "GrammarOutput | None") -> ModelRunnerOu
         else:
             decode_sampled_token_ids = [tensor.cpu()[:num_decodes] for tensor in decode_sampled_token_ids]
         if decode_sampled_token_ids + prefill_sampled_token_ids:
-            sampled_token_ids_tensor = torch.cat(decode_sampled_token_ids + prefill_sampled_token_ids)
-            sampled_token_ids_np = sampled_token_ids_tensor.cpu().numpy().flatten()
+            sampled_token_ids_list = torch.cat(decode_sampled_token_ids + prefill_sampled_token_ids).tolist()
         else:
-            sampled_token_ids_np = np.array([], dtype=np.int32)
+            sampled_token_ids_list = []
         sampled_token_requests = \
             decode_sampled_requests + prefill_sampled_requests
         max_req_index = max(self.input_batch.req_id_to_index.values())
@@ -3405,10 +3398,9 @@ def sample_tokens(self, grammar_output: "GrammarOutput | None") -> ModelRunnerOu
         for i, req_id in enumerate(sampled_token_requests):
             num_tokens = spec_decode_num_tokens[
                 i] if spec_decode_num_tokens is not None and i < num_decodes else 1
-            req_idx = self.input_batch.req_id_to_index[req_id]
-            postprocessed_sampled_token_ids[req_idx] = np.array(sampled_token_ids_np[start_idx:start_idx +
-                                                                                     num_tokens],
-                                                                dtype=np.int32)
+            postprocessed_sampled_token_ids[
+                self.input_batch.req_id_to_index[req_id]] += sampled_token_ids_list[start_idx:start_idx +
+                                                                                    num_tokens]
             start_idx += num_tokens

         ################## RETURN ##################
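A standalone sketch of the slicing loop above, with made-up tensors and request ids, and assuming 1-D per-request tensors so that `tolist()` yields a flat list of ints: decode and prefill outputs are concatenated into one flat Python list, and a running offset hands each request its `num_tokens`-long slice.

```python
import torch

# Hypothetical data: two decode requests followed by one prefill request,
# each having sampled a single token (no speculative tokens here).
decode_sampled_token_ids = [torch.tensor([101]), torch.tensor([102])]
prefill_sampled_token_ids = [torch.tensor([103])]
sampled_token_requests = ["req-a", "req-b", "req-c"]
req_id_to_index = {"req-a": 0, "req-b": 1, "req-c": 2}
spec_decode_num_tokens = None
num_decodes = 2

# One flat Python list of every sampled token, decode requests first.
sampled_token_ids_list = torch.cat(decode_sampled_token_ids +
                                   prefill_sampled_token_ids).tolist()

postprocessed: list[list[int]] = [[] for _ in range(max(req_id_to_index.values()) + 1)]

# Walk the flat list with a running offset, giving each request its slice.
start_idx = 0
for i, req_id in enumerate(sampled_token_requests):
    num_tokens = (spec_decode_num_tokens[i]
                  if spec_decode_num_tokens is not None and i < num_decodes else 1)
    postprocessed[req_id_to_index[req_id]] += sampled_token_ids_list[start_idx:start_idx + num_tokens]
    start_idx += num_tokens

print(postprocessed)  # [[101], [102], [103]]
```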
@@ -3431,7 +3423,7 @@ def sample_tokens(self, grammar_output: "GrammarOutput | None") -> ModelRunnerOu
         # the sampled tokens back, because there's no direct communication
         # between the first-stage worker and the last-stage worker.
         for req_idx, sampled_ids in enumerate(postprocessed_sampled_token_ids[:num_reqs]):
-            if sampled_ids is None:
+            if not sampled_ids:
                 continue

             start_idx = self.input_batch.num_tokens_no_spec[req_idx]
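Since the per-request entries are now plain lists, the skip condition changes from an identity check to a truthiness check; a tiny sketch (example values only) of why `not sampled_ids` covers both the old `None` sentinel and the new empty-list case:

```python
# Both None and [] are falsy, so `not sampled_ids` skips them; only a
# non-empty list of token ids is processed.
for sampled_ids in [None, [], [42]]:
    if not sampled_ids:
        continue
    print("process", sampled_ids)  # reached only for [42]
```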