@@ -105,6 +105,7 @@ def extract_speech_ids(speech_tokens_str):
             print(f"Unexpected token: {token_str}")
     return speech_ids
 
+
 def convert_cosy2_tokens_to_speech_id_str(cosy2_tokens):
     """Convert CosyVoice2 tokens to speech IDs string like <|s_23456|>"""
     speech_id_str = ""
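For orientation, here is a minimal sketch of the round trip these two helpers implement, inferred from the `<|s_23456|>` convention in the docstring and the `Unexpected token` fallback above, not copied from the file:

```python
# Sketch only: assumes the <|s_N|> convention shown in the docstring above.
def convert_cosy2_tokens_to_speech_id_str(cosy2_tokens):
    # [23456, 7] -> "<|s_23456|><|s_7|>"
    return "".join(f"<|s_{int(t)}|>" for t in cosy2_tokens)

def extract_speech_ids(speech_tokens_str):
    # ["<|s_23456|>", "foo"] -> [23456], with a warning on non-speech tokens
    speech_ids = []
    for token_str in speech_tokens_str:
        if token_str.startswith("<|s_") and token_str.endswith("|>"):
            speech_ids.append(int(token_str[4:-2]))
        else:
            print(f"Unexpected token: {token_str}")
    return speech_ids

assert convert_cosy2_tokens_to_speech_id_str([23456, 7]) == "<|s_23456|><|s_7|>"
assert extract_speech_ids(["<|s_23456|>", "<|s_7|>"]) == [23456, 7]
```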
@@ -182,14 +183,13 @@ def get_args():
     return args
 
 
-
 def data_collator(batch, tokenizer, s3_tokenizer):
     """Simplified data collator for batch_size=1 processing"""
     target_sample_rate = 16000  # CosyVoice2 uses 16kHz for prompt audio
     device = s3_tokenizer.device if s3_tokenizer is not None else torch.device("cpu")
     input_ids_list, prompt_audio_list, prompt_text_list = [], [], []
     mels, prompt_audio_cosy2tokens_list = [], []
-    for i, item in enumerate(batch):
+    for item in batch:
         prompt_text, target_text = (
             item["prompt_text"],
             item["target_text"],
@@ -227,7 +227,7 @@ def data_collator(batch, tokenizer, s3_tokenizer):
         codes, codes_lens = s3_tokenizer.quantize(mels.to(device), mels_lens.to(device))
         for i in range(len(codes)):
             prompt_audio_cosy2tokens_list.append(codes[i, :codes_lens[i].item()])
-    for i, prompt_audio_cosy2tokens in enumerate(prompt_audio_cosy2tokens_list):
+    for prompt_audio_cosy2tokens in prompt_audio_cosy2tokens_list:
         prompt_audio_cosy2_id_str = convert_cosy2_tokens_to_speech_id_str(prompt_audio_cosy2tokens)
         # Create chat template for LLM generation
         chat = [
@@ -244,7 +244,6 @@ def data_collator(batch, tokenizer, s3_tokenizer):
         )
         input_ids_list.append(input_ids.squeeze(0))
 
-
     # For batch_size=1, no need to pad
     if len(input_ids_list) == 1:
         input_ids = input_ids_list[0].unsqueeze(0)
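The batch_size>1 branch (continued in the next hunk) pads every prompt to a common length before `torch.stack`. For context, a minimal sketch of left-padding for causal-LM generation; the helper name, the pad token, and the padding side are assumptions here, not taken from this file:

```python
import torch
import torch.nn.functional as F

def left_pad_and_stack(input_ids_list, pad_token_id):
    # Hypothetical helper: left-pad so generation starts flush at the
    # right edge of every row, which is what decoder-only models expect.
    max_len = max(ids.shape[0] for ids in input_ids_list)
    padded = [
        F.pad(ids, (max_len - ids.shape[0], 0), value=pad_token_id)  # pad on the left
        for ids in input_ids_list
    ]
    return torch.stack(padded)  # shape: (batch, max_len)
```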
@@ -256,7 +255,7 @@ def data_collator(batch, tokenizer, s3_tokenizer):
             for input_ids in input_ids_list
         ]
         input_ids = torch.stack(input_ids_list)
-
+
     ids = [item["id"] for item in batch]
 
     return {
@@ -287,7 +286,7 @@ def main():
     assert torch.cuda.is_available()
     world_size, local_rank, rank = init_distributed()
     device = torch.device(f"cuda:{local_rank}")
-
+
     # Load LLM model and tokenizer directly
     tokenizer = AutoTokenizer.from_pretrained(args.llm_model_name_or_path)
     model = AutoModelForCausalLM.from_pretrained(args.llm_model_name_or_path)
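`init_distributed` is defined elsewhere in this file; a typical implementation reads the environment variables that `torchrun` sets. A sketch under that assumption (the actual helper may differ):

```python
import os
import torch

def init_distributed():
    # Assumed behavior: one process per GPU, driven by torchrun's environment.
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    rank = int(os.environ.get("RANK", "0"))
    if world_size > 1:
        torch.distributed.init_process_group(backend="nccl")
    torch.cuda.set_device(local_rank)
    return world_size, local_rank, rank
```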
@@ -329,7 +328,7 @@ def main():
     for batch in dataloader:
         with torch.no_grad():
             input_ids = batch["input_ids"].to(device)
-
+
             # Generate speech tokens using LLM
             outputs = model.generate(
                 input_ids,
@@ -339,31 +338,31 @@ def main():
                 temperature=args.temperature,
                 top_k=args.top_k,
             )
-
+
             # Process each sample in the batch
             for i in range(len(batch["ids"])):
                 # Extract generated tokens (excluding input)
                 input_length = input_ids[i].shape[0]
                 generated_ids = outputs[i][input_length:-1]  # drop the trailing end-of-generation token
                 speech_tokens_str = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
-
+
                 # Extract speech IDs from token strings like <|s_23456|>
                 speech_ids = extract_speech_ids(speech_tokens_str)
-
+
                 if len(speech_ids) == 0:
                     print(f"Warning: No speech tokens generated for sample {batch['ids'][i]}, skipping")
                     continue
-
+
                 # Convert to tensor for CosyVoice2
                 audio_tokens = torch.tensor(speech_ids, dtype=torch.long, device=device).unsqueeze(0)
-
+
                 if args.prompt_text is not None:
                     current_prompt_text = args.prompt_text
                     current_prompt_audio = prompt_speech_16k
                 else:
                     current_prompt_text = batch["prompt_text"][i]
                     current_prompt_audio = batch["prompt_audio_list"][i]
-
+
                 if current_prompt_audio is not None:
                     # Generate audio using CosyVoice2
                     audio_hat = audio_decode_cosyvoice2(
@@ -372,18 +371,17 @@ def main():
                         current_prompt_audio,
                         cosyvoice_codec,
                     )
-
+
                     # Convert to numpy and save
                     generated_wave = audio_hat.squeeze(0).cpu().numpy()
                     target_sample_rate = 24000
-
+
                     utt = batch["ids"][i]
                     sf.write(f"{args.output_dir}/{utt}.wav", generated_wave, target_sample_rate)
 
                     print(f"Generated audio for sample {utt} with {len(speech_ids)} tokens")
                 else:
                     print(f"Warning: No prompt audio available for sample {batch['ids'][i]}, skipping")
-
 
         if rank == 0:
             progress_bar.update(world_size * len(batch["ids"]))
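One subtlety worth noting: for decoder-only models, `model.generate` returns the prompt followed by the new tokens, so the loop above slices off the prompt (`input_length`) and drops the final end-of-generation token with `:-1`. A toy illustration of that slicing (stand-in tensors, not real model output):

```python
import torch

prompt = torch.tensor([101, 102, 103])              # stand-in chat-template prompt
output = torch.tensor([101, 102, 103, 7, 8, 9, 2])  # generate() echoes the prompt; 2 plays the EOS role
new_tokens = output[prompt.shape[0]:-1]             # skip the prompt, drop the trailing EOS
assert new_tokens.tolist() == [7, 8, 9]
```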