Skip to content

Commit 0b357ba

Browse files
authored
Merge pull request #1583 from FunAudioLLM/dev/lyuxiang.lx
Dev/lyuxiang.lx
2 parents 0867ebc + cf61501 commit 0b357ba

File tree

13 files changed

+2379
-1
lines changed

13 files changed

+2379
-1
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131

3232
- [x] 2025/08
3333

34-
- [x] Thanks to the contribution from NVIDIA Yuekai Zhang, add triton trtllm runtime support
34+
- [x] Thanks to the contribution from NVIDIA Yuekai Zhang, add triton trtllm runtime support and cosyvoice2 grpo training support
3535

3636
- [x] 2025/07
3737

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
FROM verlai/verl:app-verl0.4-vllm0.8.5-mcore0.12.2-te2.2
2+
COPY requirements.txt /myworkspace/requirements.txt
3+
RUN pip install -r /myworkspace/requirements.txt
4+
RUN pip install -U nvidia-pytriton
5+
RUN git clone https://github.com/yuekaizhang/verl.git /myworkspace/verl -b thread && cd /myworkspace/verl && pip install --no-deps -e .
6+
RUN git clone https://github.com/yuekaizhang/PytritonSenseVoice.git /myworkspace/PytritonSenseVoice && cd /myworkspace/PytritonSenseVoice && pip install -e .

examples/grpo/cosyvoice2/README.md

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
# CosyVoice2 LLM Reinforcement Learning Recipe
2+
3+
This recipe demonstrates how to fine-tune the **CosyVoice2** large language model with reinforcement learning algorithms—specifically **GRPO**—using the [veRL](https://github.com/volcengine/verl) framework. Our experiments show that applying GRPO reduces the character error rate (CER) on the CosyVoice3 `zero_shot_zh` set from 4.08% to 3.36%.
4+
5+
## Table of Contents
6+
7+
- [Environment Setup](#environment-setup)
8+
- [Data Preparation](#data-preparation)
9+
- [Reward Function & ASR Server](#reward-function--asr-server)
10+
- [Training](#training)
11+
- [Evaluation](#evaluation)
12+
- [Export Model](#export-model)
13+
- [Results](#results)
14+
- [Acknowledgement](#acknowledgement)
15+
16+
## Environment Setup
17+
We recommend using the pre-built Docker image below. Alternatively, you can manually install the dependencies following the Dockerfile.
18+
```bash
19+
docker pull soar97/verl:app-verl0.4-vllm0.8.5-mcore0.12.2-te2.2
20+
```
21+
If Docker is not available, you can refer to `run.sh` `stage -2` to install the dependencies locally.
22+
23+
## Data Preparation
24+
25+
`prepare_data.py` expects a JSON/JSONL file with at least the following schema:
26+
27+
```jsonc
28+
{
29+
"text": "An example sentence to be synthesized."
30+
}
31+
```
32+
You can download the JSONL files from the metadata directory of the [SparkAudio/voxbox](https://huggingface.co/datasets/SparkAudio/voxbox/tree/main/metadata) dataset on Hugging Face.
33+
34+
Stage `0` converts raw JSONL files into the parquet format expected by veRL:
35+
36+
```bash
37+
bash run.sh 0 0
38+
```
39+
Create two JSONL files—`train.jsonl` and `test.jsonl`.
40+
The script will then generate two Parquet files:
41+
42+
```
43+
data/parquet_tiny/train.parquet
44+
data/parquet_tiny/test.parquet
45+
```
46+
47+
Each sample is automatically wrapped into a CosyVoice2-style prompt so that the LLM learns to output CosyVoice2 speech tokens.
48+
49+
50+
## Reward Function & ASR Server
51+
52+
To compute rewards, we run a lightweight server that:
53+
54+
1. Converts generated speech tokens back to a 16 kHz waveform with the **CosyVoice2** pretrained U-Net model.
55+
2. Transcribes the waveform with **SenseVoice** ASR.
56+
3. Calculates the pinyin-level error rate relative to the ground-truth text and maps it to a score between 0 and 1.
57+
58+
Start the server (stage `1`) in a dedicated terminal or on a separate GPU:
59+
60+
```bash
61+
bash run.sh 1 1
62+
# Triton server listens on ports 8000/8001/8002
63+
```
64+
65+
The custom reward implementation is located in [`reward_tts.py`](./reward_tts.py) and calls the server to obtain the reward score.
66+
67+
## Training
68+
69+
Run stage `2` to start GRPO training:
70+
71+
```bash
72+
bash run.sh 2 2
73+
```
74+
75+
Key CLI arguments passed to `verl.trainer.main_ppo`:
76+
77+
* `algorithm.adv_estimator=grpo` – use GRPO instead of PPO.
78+
* `data.train_files=data/parquet_aishell3/train.parquet` and `data.val_files=data/parquet_aishell3/test.parquet`
79+
* `custom_reward_function.path=reward_tts.py` – custom reward function described above.
80+
81+
Adjust `CUDA_VISIBLE_DEVICES`, batch sizes, and other hyperparameters to match your hardware.
82+
> [!TIP]
83+
> Note: the lm_head bias is disabled during training to make the model compatible with VLLM and Transformers' Qwen model.
84+
85+
## Evaluation
86+
87+
After training is complete, collect the sharded FSDP weights and export a Hugging Face-style checkpoint (stage `3`):
88+
89+
```bash
90+
bash run.sh 3 3 # merges weights into $llm_path/merged_hf_model
91+
```
92+
93+
You can then evaluate the model on the CosyVoice3 zero-shot Chinese test set (stage `4`):
94+
95+
```bash
96+
bash run.sh 4 4
97+
```
98+
99+
This command launches distributed inference via `infer_dataset.py` and computes WER with `scripts/compute_wer.sh`.
100+
101+
> [!TIP]
102+
> The script also supports the Seed-TTS test set by setting `dataset=test_zh`.
103+
104+
## Export Model
105+
106+
To use the RL-trained model with the official CosyVoice repository:
107+
108+
```bash
109+
bash run.sh 5 5
110+
```
111+
112+
The script converts the Hugging Face checkpoint back into the format expected by the CosyVoice repository.
113+
> [!TIP]
114+
> However, we observed a slight accuracy drop when using the RL-trained model after conversion, compared with the Hugging Face format.
115+
116+
## Results
117+
118+
| Model | Seed-TTS `test_zh` CER | CosyVoice3 `zero_shot_zh` CER | Comment |
119+
|-------|------------------------|------------------------------|---------|
120+
| CosyVoice2 LLM (official) | 1.45% | 4.08% | See the [paper](https://arxiv.org/abs/2412.10117) |
121+
| CosyVoice2 LLM + GRPO | 1.37% | **3.36%** | See the [decoding results](yuekai/official-cosyvoice-llm-grpo-aishell3), Hugging Face-format model |
122+
123+
## Acknowledgement
124+
125+
This work was inspired by the implementation in [ch-tts-llasa-rl-grpo](https://github.com/channel-io/ch-tts-llasa-rl-grpo).
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
2+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
3+
# SPDX-License-Identifier: Apache-2.0
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
"""
17+
python3 hf2pretrained.py --hf-cosyvoice2-llm-path /workspace/rl-exp/checkpoint-400 --output-path /workspace/CosyVoice2-0.5B/llm-new.pt
18+
"""
19+
from argparse import ArgumentParser
20+
import torch
21+
from safetensors import safe_open
22+
from transformers import AutoTokenizer
23+
24+
25+
def get_args():
26+
parser = ArgumentParser()
27+
28+
parser.add_argument(
29+
"--hf-cosyvoice2-llm-path",
30+
type=str,
31+
default=None,
32+
help="The RL trained CosyVoice2 model path in HuggingFace format",
33+
)
34+
parser.add_argument(
35+
"--output-path",
36+
type=str,
37+
default="./llm.pt",
38+
help="The path to save the llm.pt",
39+
)
40+
args = parser.parse_args()
41+
return args
42+
43+
44+
if __name__ == "__main__":
45+
args = get_args()
46+
47+
tokenizer = AutoTokenizer.from_pretrained(args.hf_cosyvoice2_llm_path)
48+
speech_start_idx = tokenizer.convert_tokens_to_ids("<|s_0|>")
49+
cosyvoice2_token_size = 6561 + 3
50+
llm_embedding_vocab_size = 2
51+
52+
hf_tensors = {}
53+
with safe_open(f"{args.hf_cosyvoice2_llm_path}/model.safetensors", framework="pt", device="cpu") as f:
54+
for k in f.keys():
55+
if k.startswith("lm_head.bias"):
56+
# RL trained model disable bias for lm_head
57+
continue
58+
new_k = "llm.model." + k
59+
hf_tensors[new_k] = f.get_tensor(k)
60+
if k.startswith("lm_head"):
61+
hf_tensors["llm_decoder.weight"] = f.get_tensor(k)[speech_start_idx:speech_start_idx + cosyvoice2_token_size]
62+
hf_tensors["llm_decoder.bias"] = torch.zeros_like(hf_tensors["llm_decoder.weight"][:, 0])
63+
if k.startswith("model.embed_tokens"):
64+
hf_tensors["speech_embedding.weight"] = f.get_tensor(k)[speech_start_idx:speech_start_idx + cosyvoice2_token_size]
65+
hf_tensors["llm_embedding.weight"] = f.get_tensor(k)[speech_start_idx + cosyvoice2_token_size:speech_start_idx + cosyvoice2_token_size + llm_embedding_vocab_size]
66+
67+
# use tie_word_embeddings=True
68+
hf_tensors["llm.model.model.embed_tokens.weight"] = hf_tensors["llm.model.model.embed_tokens.weight"][:151936]
69+
hf_tensors["llm.model.lm_head.weight"] = hf_tensors["llm.model.model.embed_tokens.weight"]
70+
71+
torch.save(hf_tensors, args.output_path)

0 commit comments

Comments
 (0)