
from vllm_spyre.v1.core.scheduler import ContinuousBatchingSpyreScheduler

+DISABLE_ASSERTS = False  # set True to skip all step checks while debugging
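+# Each check below is written as "assert DISABLE_ASSERTS or <cond>", so
+# flipping this flag turns every assertion into a no-op for a full debug run.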
+

def augment_checked_steps(
        checked_steps: list[dict[str, Any]]) -> deque[dict[str, Any]]:
@@ -105,11 +107,13 @@ def check_scheduler_inference_steps(
        generated_prompts.append(request.prompt_token_ids)

    # Setup the engine
-    engine_args = EngineArgs(model=model,
-                             tokenizer=model,
-                             max_model_len=max_model_len,
-                             max_num_seqs=max_num_seqs,
-                             num_gpu_blocks_override=available_blocks)
+    engine_args = EngineArgs(
+        model=model,
+        tokenizer=model,
+        max_model_len=max_model_len,
+        max_num_seqs=max_num_seqs,
+        num_gpu_blocks_override=available_blocks,
+    )
    vllm_config = engine_args.create_engine_config()
    executor_class = Executor.get_class(vllm_config)
    engine_core = EngineCore(vllm_config=vllm_config,
@@ -139,17 +143,18 @@ def check_scheduler_inference_steps(
            r.request_id for r in request_outputs if r.finished
        ]

-        assert (scheduler.tkv == step_ref["tkv"]
-                ), f"Step {step}, tkv: {scheduler.tkv}"
-        assert waiting == step_ref[
-            "waiting"], f"Step {step}, waiting: {waiting}"
-        assert running == step_ref[
-            "running"], f"Step {step}, running: {running}"
-        assert (out_reqs_ids == step_ref["request_outputs"]
-                ), f"Step {step}, request outputs: {out_reqs_ids}"
+        assert DISABLE_ASSERTS or (scheduler.tkv == step_ref["tkv"]
+                                   ), f"Step {step}, tkv: {scheduler.tkv}"
+        assert (DISABLE_ASSERTS or waiting
+                == step_ref["waiting"]), f"Step {step}, waiting: {waiting}"
+        assert (DISABLE_ASSERTS or running
+                == step_ref["running"]), f"Step {step}, running: {running}"
+        assert DISABLE_ASSERTS or (
+            out_reqs_ids == step_ref["request_outputs"]
+        ), f"Step {step}, request outputs: {out_reqs_ids}"

        ref_finished_reqs = step_ref.get("finished_requests", [])
-        assert (
+        assert DISABLE_ASSERTS or (
            out_reqs_finished == ref_finished_reqs
        ), f"Step {step}, finished request output: {out_reqs_finished}"

@@ -166,27 +171,31 @@ def check_scheduler_inference_steps(
            [len(blocks) for blocks in req_ids2blocks.values()])

        if step > 0:
-            assert (
+            assert DISABLE_ASSERTS or (
                n_reserved_blocks == step_ref["n_reserved_blocks"]
            ), f"Step {step}, n_reserved_blocks: {n_reserved_blocks}"
-            assert (n_used_blocks == step_ref["n_used_blocks"]
-                    ), f"Step {step}, n_used_blocks: {n_used_blocks}"
+            assert DISABLE_ASSERTS or (
+                n_used_blocks == step_ref["n_used_blocks"]
+            ), f"Step {step}, n_used_blocks: {n_used_blocks}"

-            assert len(req_ids2blocks) == len(req_ids2reserved_blocks)
+            assert DISABLE_ASSERTS or len(req_ids2blocks) == len(
+                req_ids2reserved_blocks)
            for req_id in req_ids2blocks:
                # current number of used blocks should be at most the reserved
-                assert len(
-                    req_ids2blocks[req_id]) <= req_ids2reserved_blocks[req_id]
+                assert (DISABLE_ASSERTS or len(req_ids2blocks[req_id])
+                        <= req_ids2reserved_blocks[req_id])
                # update requested/reserved blocks to check in last step
-                # Note: overwrite and not max because of reduce_left_padding()
+                # Note: overwrite and not max
+                # because of reduce_left_padding()
                requested_blocks[req_id] = len(req_ids2blocks[req_id])
                reserved_blocks[req_id] = req_ids2reserved_blocks[req_id]

        # last step: check that sequences used all their reserved blocks
        # Note: no early stopping, all sequences produce max_num_tokens
        if len(checked_steps) == 0:
            for req_id in requested_blocks:
-                assert requested_blocks[req_id] == reserved_blocks[req_id]
+                assert (DISABLE_ASSERTS
+                        or requested_blocks[req_id] == reserved_blocks[req_id])

        # Perform next step
        step_output = engine_core.step()
@@ -197,15 +206,17 @@ def check_scheduler_inference_steps(
        for output in request_outputs:
            new_token_ids = output.new_token_ids
            new_logprobs = output.new_logprobs.logprobs
-            assert len(new_token_ids) == 1 and len(new_logprobs) == 1
+            assert DISABLE_ASSERTS or len(new_token_ids) == 1 and len(
+                new_logprobs) == 1

            collected_outputs[output.request_id]["token_ids"].append(
                new_token_ids[0])
            collected_outputs[output.request_id]["logprobs"].append(
                new_logprobs[0][0])

    output_keys = sorted(int(k) for k in collected_outputs)
-    assert output_keys[0] == 0 and output_keys[-1] == len(output_keys) - 1
+    assert (DISABLE_ASSERTS
+            or output_keys[0] == 0 and output_keys[-1] == len(output_keys) - 1)

    # convert dict of dicts to ordered list and make values immutable
    collected_outputs_new = []
@@ -216,4 +227,6 @@ def check_scheduler_inference_steps(
            output[k] = tuple(list_values)
        collected_outputs_new.append(output)

+    # shut down the engine core so worker resources are released
+    engine_core.shutdown()
    return collected_outputs_new, generated_prompts
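
The recurring edit in this diff is one pattern: each `assert <cond>` becomes `assert DISABLE_ASSERTS or <cond>`. Because `or` short-circuits, setting the flag to True skips every comparison, so a debugging run plays out all scheduler steps instead of stopping at the first mismatch. A minimal standalone sketch of the pattern (the check_tkv helper and its values are illustrative, not part of this diff):

DISABLE_ASSERTS = False  # flip to True to observe a full run without failing

def check_tkv(step: int, actual: int, expected: int) -> None:
    # "or" short-circuits: with DISABLE_ASSERTS = True the comparison on
    # the right is never evaluated and the assert always passes.
    assert DISABLE_ASSERTS or actual == expected, f"Step {step}, tkv: {actual}"

check_tkv(step=0, actual=64, expected=64)  # passes either way

One property of this design is that the assertion messages stay intact, so re-enabling the flag costs nothing; the trade-off is that every assert carries the extra `DISABLE_ASSERTS or` prefix.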