
Commit 62d0826

fix lint
1 parent d1c354e commit 62d0826

7 files changed: +71 additions, -68 deletions

examples/grpo/cosyvoice2/huggingface_to_pretrained.py

Lines changed: 2 additions & 1 deletion
@@ -21,6 +21,7 @@
 from safetensors import safe_open
 from transformers import AutoTokenizer
 
+
 def get_args():
     parser = ArgumentParser()
 
@@ -39,6 +40,7 @@ def get_args():
     args = parser.parse_args()
     return args
 
+
 if __name__ == "__main__":
     args = get_args()
 
@@ -67,4 +69,3 @@ def get_args():
     hf_tensors["llm.model.lm_head.weight"] = hf_tensors["llm.model.model.embed_tokens.weight"]
 
     torch.save(hf_tensors, args.output_path)
-

examples/grpo/cosyvoice2/infer_dataset.py

Lines changed: 14 additions & 16 deletions
@@ -105,6 +105,7 @@ def extract_speech_ids(speech_tokens_str):
             print(f"Unexpected token: {token_str}")
     return speech_ids
 
+
 def convert_cosy2_tokens_to_speech_id_str(cosy2_tokens):
     """Convert CosyVoice2 tokens to speech IDs string like <|s_23456|>"""
     speech_id_str = ""
@@ -182,14 +183,13 @@ def get_args():
     return args
 
 
-
 def data_collator(batch, tokenizer, s3_tokenizer):
     """Simplified data collator for batch_size=1 processing"""
     target_sample_rate = 16000  # CosyVoice2 uses 16kHz for prompt audio
     device = s3_tokenizer.device if s3_tokenizer is not None else torch.device("cpu")
     input_ids_list, prompt_audio_list, prompt_text_list = [], [], []
     mels, prompt_audio_cosy2tokens_list = [], []
-    for i, item in enumerate(batch):
+    for item in batch:
         prompt_text, target_text = (
             item["prompt_text"],
             item["target_text"],
@@ -227,7 +227,7 @@ def data_collator(batch, tokenizer, s3_tokenizer):
     codes, codes_lens = s3_tokenizer.quantize(mels.to(device), mels_lens.to(device))
     for i in range(len(codes)):
         prompt_audio_cosy2tokens_list.append(codes[i, :codes_lens[i].item()])
-    for i, prompt_audio_cosy2tokens in enumerate(prompt_audio_cosy2tokens_list):
+    for prompt_audio_cosy2tokens in prompt_audio_cosy2tokens_list:
         prompt_audio_cosy2_id_str = convert_cosy2_tokens_to_speech_id_str(prompt_audio_cosy2tokens)
         # Create chat template for LLM generation
         chat = [
@@ -244,7 +244,6 @@ def data_collator(batch, tokenizer, s3_tokenizer):
         )
         input_ids_list.append(input_ids.squeeze(0))
 
-
     # For batch_size=1, no need to pad
     if len(input_ids_list) == 1:
         input_ids = input_ids_list[0].unsqueeze(0)
@@ -256,7 +255,7 @@ def data_collator(batch, tokenizer, s3_tokenizer):
         for input_ids in input_ids_list
     ]
     input_ids = torch.stack(input_ids_list)
-
+
     ids = [item["id"] for item in batch]
 
     return {
@@ -287,7 +286,7 @@ def main():
     assert torch.cuda.is_available()
     world_size, local_rank, rank = init_distributed()
     device = torch.device(f"cuda:{local_rank}")
-
+
     # Load LLM model and tokenizer directly
     tokenizer = AutoTokenizer.from_pretrained(args.llm_model_name_or_path)
     model = AutoModelForCausalLM.from_pretrained(args.llm_model_name_or_path)
@@ -329,7 +328,7 @@ def main():
     for batch in dataloader:
         with torch.no_grad():
             input_ids = batch["input_ids"].to(device)
-
+
             # Generate speech tokens using LLM
             outputs = model.generate(
                 input_ids,
@@ -339,31 +338,31 @@ def main():
                 temperature=args.temperature,
                 top_k=args.top_k,
             )
-
+
             # Process each sample in the batch
             for i in range(len(batch["ids"])):
                 # Extract generated tokens (excluding input)
                 input_length = input_ids[i].shape[0]
                 generated_ids = outputs[i][input_length:-1]  # Remove last token if needed
                 speech_tokens_str = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
-
+
                 # Extract speech IDs from token strings like <|s_23456|>
                 speech_ids = extract_speech_ids(speech_tokens_str)
-
+
                 if len(speech_ids) == 0:
                     print(f"Warning: No speech tokens generated for sample {batch['ids'][i]}, skipping")
                     continue
-
+
                 # Convert to tensor for CosyVoice2
                 audio_tokens = torch.tensor(speech_ids, dtype=torch.long, device=device).unsqueeze(0)
-
+
                 if args.prompt_text is not None:
                     current_prompt_text = args.prompt_text
                     current_prompt_audio = prompt_speech_16k
                 else:
                     current_prompt_text = batch["prompt_text"][i]
                     current_prompt_audio = batch["prompt_audio_list"][i]
-
+
                 if current_prompt_audio is not None:
                     # Generate audio using CosyVoice2
                     audio_hat = audio_decode_cosyvoice2(
@@ -372,18 +371,17 @@ def main():
                         current_prompt_audio,
                         cosyvoice_codec,
                     )
-
+
                     # Convert to numpy and save
                     generated_wave = audio_hat.squeeze(0).cpu().numpy()
                     target_sample_rate = 24000
-
+
                     utt = batch["ids"][i]
                     sf.write(f"{args.output_dir}/{utt}.wav", generated_wave, target_sample_rate)
 
                     print(f"Generated audio for sample {utt} with {len(speech_ids)} tokens")
                 else:
                     print(f"Warning: No prompt audio available for sample {batch['ids'][i]}, skipping")
-
 
         if rank == 0:
             progress_bar.update(world_size * len(batch["ids"]))
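
For context on the loop above: a causal LM's `generate` output begins with the prompt tokens, which is why the script slices off `input_length` tokens (and the trailing stop token) before decoding. A minimal, self-contained sketch of that post-processing, using the `<|s_N|>` token format shown in this diff; the regex mirrors `_parse_ids` in reward_tts.py, and the tensors are illustrative stand-ins rather than the script's real data:

import re
import torch

# Illustrative stand-ins; in infer_dataset.py these come from the tokenized
# batch and from model.generate().
input_ids = torch.tensor([[101, 102, 103]])             # prompt tokens
outputs = torch.tensor([[101, 102, 103, 7, 8, 9, 2]])   # prompt + speech tokens + stop token

input_length = input_ids[0].shape[0]
generated_ids = outputs[0][input_length:-1]  # drop the prompt and the trailing stop token

# The real script decodes IDs to strings like "<|s_23456|>" with
# tokenizer.batch_decode; here the decoded strings are faked so the
# extraction step runs stand-alone.
decoded = [f"<|s_{int(t)}|>" for t in generated_ids]
speech_ids = [int(m) for s in decoded for m in re.findall(r"<\|s_(\d+)\|>", s)]
assert speech_ids == [7, 8, 9]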

examples/grpo/cosyvoice2/prepare_data.py

Lines changed: 0 additions & 2 deletions
@@ -23,8 +23,6 @@
 
 from verl.utils.hdfs_io import copy, makedirs
 
-
-
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--train_file", required=True, help="Path to training JSON/JSONL file")

examples/grpo/cosyvoice2/pretrained_to_huggingface.py

Lines changed: 9 additions & 7 deletions
@@ -31,7 +31,6 @@
 sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
 
 
-
 def get_args():
     parser = ArgumentParser()
 
@@ -96,17 +95,20 @@ def get_args():
     # set the weight and bias of the new lm_head to 0
     new_lm_head.weight.data.zero_()
     new_lm_head.bias.data.zero_()
-    new_lm_head.weight[original_tokenizer_vocab_size:original_tokenizer_vocab_size+cosyvoice2_token_size+3] = llm_decoder.weight
-    new_lm_head.bias[original_tokenizer_vocab_size:original_tokenizer_vocab_size+cosyvoice2_token_size+3] = llm_decoder.bias
+    new_lm_head.weight[original_tokenizer_vocab_size:original_tokenizer_vocab_size + cosyvoice2_token_size + 3] = llm_decoder.weight
+    new_lm_head.bias[original_tokenizer_vocab_size:original_tokenizer_vocab_size + cosyvoice2_token_size + 3] = llm_decoder.bias
 
     llm.lm_head = new_lm_head
     input_embeddings = llm.get_input_embeddings()
 
     with torch.no_grad():
-        input_embeddings.weight[original_tokenizer_vocab_size:original_tokenizer_vocab_size+cosyvoice2_token_size+3] = speech_embedding.weight
-        input_embeddings.weight[original_tokenizer_vocab_size+cosyvoice2_token_size+3:original_tokenizer_vocab_size+cosyvoice2_token_size+3+2] = llm_embedding.weight
+        input_embeddings.weight[original_tokenizer_vocab_size:original_tokenizer_vocab_size + cosyvoice2_token_size + 3] = speech_embedding.weight
+        input_embeddings.weight[original_tokenizer_vocab_size + cosyvoice2_token_size + 3:original_tokenizer_vocab_size + cosyvoice2_token_size + 3 + 2] = llm_embedding.weight
 
-    eos_token_ids = [original_tokenizer_vocab_size + cosyvoice2_token_size, original_tokenizer_vocab_size + cosyvoice2_token_size + 1, original_tokenizer_vocab_size + cosyvoice2_token_size + 2]
+    eos_token_ids = [original_tokenizer_vocab_size + cosyvoice2_token_size,
+                     original_tokenizer_vocab_size + cosyvoice2_token_size + 1,
+                     original_tokenizer_vocab_size + cosyvoice2_token_size + 2,
+                     original_tokenizer_vocab_size + cosyvoice2_token_size + 3]
     llm.generation_config.eos_token_id = eos_token_ids
     llm.generation_config.temperature = 1.0
     llm.generation_config.top_p = 0.8
@@ -121,4 +123,4 @@ def get_args():
 
     TEMPLATE = "{%- for message in messages %}{%- if message['role'] == 'user' %}{{- '<|sos|>' + message['content'] + '<|task_id|>' }}{%- elif message['role'] == 'assistant' %}{{- message['content']}}{%- endif %}{%- endfor %}"
     tokenizer.chat_template = TEMPLATE
-    tokenizer.save_pretrained(args.save_path)
\ No newline at end of file
+    tokenizer.save_pretrained(args.save_path)
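
The slices in the hunk above partition the expanded embedding table: rows [0, original_tokenizer_vocab_size) keep the original text vocabulary, the next cosyvoice2_token_size + 3 rows receive the speech tokens plus three special tokens, and the two rows after that are copied from llm_embedding. The commit also grows eos_token_id from three candidates to four. A small sketch of that index arithmetic; the sizes below are made up for illustration (the real values come from the tokenizer and the CosyVoice2 checkpoint):

# Illustrative sizes only; the real script derives these from the
# tokenizer and the CosyVoice2 checkpoint.
original_tokenizer_vocab_size = 151_936
cosyvoice2_token_size = 6_561

speech_start = original_tokenizer_vocab_size
speech_end = speech_start + cosyvoice2_token_size + 3   # speech tokens + 3 specials
extra_end = speech_end + 2                              # 2 rows from llm_embedding

# Four stop-token candidates; the fourth entry is the one this commit adds.
eos_token_ids = [speech_start + cosyvoice2_token_size + k for k in range(4)]

print(f"speech rows: [{speech_start}, {speech_end})")
print(f"extra rows:  [{speech_end}, {extra_end})")
print(f"eos ids:     {eos_token_ids}")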

examples/grpo/cosyvoice2/reward_tts.py

Lines changed: 23 additions & 20 deletions
@@ -18,7 +18,10 @@
 
 from __future__ import annotations
 
-import os, re, warnings, json, time, argparse
+import re
+import json
+import time
+import argparse
 from typing import List
 
 import numpy as np
@@ -31,6 +34,7 @@
 def _parse_ids(token_str: str) -> List[int]:
     return [int(t) for t in re.findall(r"<\|s_(\d+)\|>", token_str)]
 
+
 def _remote_reward(tokens: List[int], ground_truth: str, timeout: float = 200.0) -> float:
     """Send token IDs and ground-truth text to the Triton server and get reward."""
 
@@ -100,7 +104,6 @@ def compute_score(
     try:
         reward = _remote_reward(ids, ground_truth)
     except Exception as e:
-        warnings.warn(f"Remote reward server error: {e}; returning 0.0")
         reward = 0.0
 
     if debug_dump:
@@ -110,92 +113,92 @@ def compute_score(
 
     return reward
 
+
 # CLI quick test
 if __name__ == "__main__":
     import sys
-
+
     def get_args():
         """Parse command line arguments."""
         parser = argparse.ArgumentParser(
             description="Test TTS CER scoring with data from JSONL file",
             formatter_class=argparse.ArgumentDefaultsHelpFormatter
         )
-
+
         parser.add_argument(
             "--input", "-i",
             type=str,
             default="data/emilia_zh-cosy-tiny-test.jsonl",
             help="Path to input JSONL file"
         )
-
+
        parser.add_argument(
            "--max-samples", "-n",
            type=int,
            default=None,
            help="Maximum number of samples to process (default: all)"
        )
-
+
        parser.add_argument(
            "--no-interactive",
            action="store_true",
            help="Run in non-interactive mode (process all samples without prompts)"
        )
-
-
+
        parser.add_argument(
            "--debug",
            action="store_true",
            help="Enable debug mode"
        )
-
+
        return parser.parse_args()
-
+
     def load_jsonl(file_path: str):
         """Load data from jsonl file."""
         data = []
         with open(file_path, 'r', encoding='utf-8') as f:
             for line in f:
                 data.append(json.loads(line.strip()))
         return data
-
+
     def code_to_solution_str(code_list: List[int]) -> str:
         """Convert code list to solution string format."""
         return ''.join([f"<|s_{code}|>" for code in code_list])
-
+
     # Parse command line arguments
     args = get_args()
-
+
     try:
         # Load data from jsonl file
         print(f"Loading data from: {args.input}")
         data_list = load_jsonl(args.input)
         print(f"Loaded {len(data_list)} samples")
-
+
         # Limit samples if specified
         if args.max_samples is not None:
             data_list = data_list[:args.max_samples]
             print(f"Processing first {len(data_list)} samples (limited by --max-samples)")
-
+
         # Process each sample
         begin_time = time.time()
         for i, sample in enumerate(data_list):
             print(f"\n--- Sample {i+1}/{len(data_list)} ---")
             print(f"Index: {sample.get('index', 'unknown')}")
             print(f"Text: {sample['text']}")
-
+
             # Extract required fields
             code_list = sample['code']
             ground_truth = sample['text']
             data_source = sample.get('index', f'sample_{i}')  # Use index as data_source
-
+
             # Convert code list to solution string
             solution_str = code_to_solution_str(code_list)
             print(f"Solution tokens: {len(code_list)} tokens")
             if args.debug:
                 print(f"Solution string: {solution_str}")
             else:
                 print(f"Solution string preview: {solution_str[:100]}..." if len(solution_str) > 100 else f"Solution string: {solution_str}")
-
+
             # Call compute_score function
             try:
                 score = compute_score(
@@ -208,7 +211,7 @@ def code_to_solution_str(code_list: List[int]) -> str:
                 print(f"Final Score: {score:.4f}")
             except Exception as e:
                 print(f"Error computing score: {e}")
-
+
             # Ask user if they want to continue (for interactive mode)
             if not args.no_interactive and i < len(data_list) - 1:
                 try:
@@ -218,7 +221,7 @@ def code_to_solution_str(code_list: List[int]) -> str:
                 except KeyboardInterrupt:
                     print("\nStopped by user")
                     break
-
+
         print(f"\nProcessed {min(i+1, len(data_list))} samples")
         end_time = time.time()
         print(f"Time taken: {end_time - begin_time} seconds")
