@@ -175,6 +175,24 @@ def patch_warmup_shapes(warmup_shapes: Union[list[tuple[int, int, int]],
                       ','.join(str(val) for val in warmup_new_tokens))
 
 
+def extract_output(req_output):
+    """Extract text, token_ids, tokens, and logprobs from request output."""
+
+    result = {}
+    result['text'] = req_output.outputs[0].text
+
+    # TODO: Workaround for V1, if request does not fit in a warmup shape
+    # token_ids may be filled with -1.
+    token_ids = [t for t in req_output.outputs[0].token_ids if t >= 0]
+    result['token_ids'] = tuple(token_ids)
+    result['tokens'] = tuple(req_output.outputs[0].logprobs[i][t].decoded_token
+                             for i, t in enumerate(token_ids))
+    result['logprobs'] = tuple(req_output.outputs[0].logprobs[i][t].logprob
+                               for i, t in enumerate(token_ids))
+
+    return result
+
+
 # vLLM / Spyre
 def generate_spyre_vllm_output(
     model: str,
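Assuming the new helper behaves as written, it can be sanity-checked without spinning up an engine by feeding it a stubbed request output. The `SimpleNamespace` stub below is a hypothetical stand-in that only mirrors the attribute shape `extract_output` touches (`outputs[0].text`, `.token_ids`, `.logprobs`), not vLLM's real `RequestOutput`:

```python
from types import SimpleNamespace

def _lp(logprob, decoded_token):
    # Stand-in for a vLLM Logprob entry (assumed shape, illustration only).
    return SimpleNamespace(logprob=logprob, decoded_token=decoded_token)

stub = SimpleNamespace(outputs=[
    SimpleNamespace(
        text="hello world",
        # Trailing -1s emulate the V1 warmup-shape padding the helper drops.
        token_ids=[101, 2088, -1, -1],
        # One dict per generated position, keyed by token id.
        logprobs=[
            {101: _lp(-0.12, "hello")},
            {2088: _lp(-0.34, " world")},
        ],
    )
])

result = extract_output(stub)
assert result['token_ids'] == (101, 2088)
assert result['tokens'] == ("hello", " world")
assert result['logprobs'] == (-0.12, -0.34)
```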
@@ -226,20 +244,7 @@ def generate_spyre_vllm_output(
     results = []
 
     for req_output in vllm_outputs:
-        result = {}
-        result['text'] = req_output.outputs[0].text
-        # TODO: Workaround for V1, if request does not fit in a warmup shape
-        # token_ids may be filled with -1.
-        token_ids = [t for t in req_output.outputs[0].token_ids if t >= 0]
-        result['token_ids'] = tuple(token_ids)
-        result['tokens'] = tuple([
-            req_output.outputs[0].logprobs[i][t].decoded_token
-            for i, t in enumerate(result['token_ids'])
-        ])
-        result['logprobs'] = tuple([
-            req_output.outputs[0].logprobs[i][t].logprob
-            for i, t in enumerate(result['token_ids'])
-        ])
+        result = extract_output(req_output)
         results.append(result)
 
     force_engine_shutdown(vllm_model)
@@ -554,26 +559,50 @@ def _default_test_models(isEmbeddings=False):
     return params
 
 
-def create_text_prompt(model: str, min_tokens: int, max_tokens: int) -> str:
+def create_text_prompt(model: str, min_token_length: int,
+                       max_token_length: int) -> str:
     """Create a text prompt for the specified model that will tokenize to within
     the specified token length range."""
     tokenizer = AutoTokenizer.from_pretrained(model)
     pepper = "🌶️"
     pepper_tokens = len(tokenizer.encode(pepper, add_special_tokens=False))
 
     # Find a good starting number of peppers
-    prompt = pepper * (min_tokens // pepper_tokens + 1)
+    prompt = pepper * (min_token_length // pepper_tokens + 1)
 
     # And add more until we're over the minimum token length
-    while len(tokenizer.encode(prompt)) <= min_tokens:
+    while len(tokenizer.encode(prompt)) <= min_token_length:
         prompt += pepper
 
     # Make sure this prompt is within the specified range
-    assert min_tokens < len(tokenizer.encode(prompt)) < max_tokens
+    assert min_token_length < len(tokenizer.encode(prompt)) < max_token_length
 
     return prompt
 
 
+def create_seq_prompt(model: str, token_length: int) -> str:
+    """Create a repeating sequential number prompt for the specified
+    model that will tokenize to exactly the specified token length."""
+
+    tokenizer = AutoTokenizer.from_pretrained(model)
+
+    # 20-token pattern
+    pattern = "0 1 2 3 4 5 6 7 8 9 "
+
+    # Repeat to token_length
+    repeat_count = (token_length // 20) + 1
+    text_prompt = pattern * repeat_count
+
+    # Tokenize and slice
+    tokens = tokenizer.encode(text_prompt)[:token_length]
+
+    # Assert exact token length
+    assert len(tokens) == token_length, \
+        f"Token length mismatch: {len(tokens)} != {token_length}"
+
+    return tokenizer.decode(tokens)
+
+
 def create_random_request(
     request_id: int,
     num_tokens: int,
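Both prompt helpers are straightforward to exercise directly. A minimal sketch, assuming the tokenizer can be fetched locally; the model name here is an arbitrary illustration, not one the test suite necessarily uses:

```python
model = "ibm-granite/granite-3.0-2b-instruct"  # hypothetical model choice

# Lands somewhere strictly between 64 and 128 tokens; the exact count
# depends on how the model's tokenizer splits the pepper emoji.
fuzzy_prompt = create_text_prompt(model, min_token_length=64,
                                  max_token_length=128)

# Exactly 64 tokens of the repeating digit pattern.
exact_prompt = create_seq_prompt(model, token_length=64)
```

The two helpers trade precision differently: `create_text_prompt` only guarantees a range, since emoji token counts vary by tokenizer, while `create_seq_prompt` slices the encoded ids to the exact length and decodes them back, relying on digit-and-space text round-tripping cleanly through encode/decode.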