# CosyVoice2 LLM Reinforcement Learning Recipe

This recipe demonstrates how to fine-tune the **CosyVoice2** large language model with reinforcement learning algorithms—specifically **GRPO**—using the [veRL](https://github.com/volcengine/verl) framework. Our experiments show that applying GRPO reduces the character error rate (CER) on the CosyVoice3 `zero_shot_zh` set from 4.08% to 3.36%.

## Table of Contents

- [Environment Setup](#environment-setup)
- [Data Preparation](#data-preparation)
- [Reward Function & ASR Server](#reward-function--asr-server)
- [Training](#training)
- [Evaluation](#evaluation)
- [Export Model](#export-model)
- [Results](#results)
- [Acknowledgement](#acknowledgement)

## Environment Setup

We recommend using the pre-built Docker image below. Alternatively, you can manually install the dependencies following the Dockerfile.

```bash
docker pull soar97/verl:app-verl0.4-vllm0.8.5-mcore0.12.2-te2.2
```

If Docker is not available, refer to stage `-2` of `run.sh` to install the dependencies locally.

## Data Preparation

`prepare_data.py` expects a JSON/JSONL file with at least the following schema:

```jsonc
{
  "text": "An example sentence to be synthesized."
}
```

You can download the JSONL files from the metadata directory of the [SparkAudio/voxbox](https://huggingface.co/datasets/SparkAudio/voxbox/tree/main/metadata) dataset on Hugging Face.

Stage `0` converts raw JSONL files into the parquet format expected by veRL:

```bash
bash run.sh 0 0
```

Provide two JSONL files, `train.jsonl` and `test.jsonl`, as input; the script will then generate two Parquet files:

```
data/parquet_tiny/train.parquet
data/parquet_tiny/test.parquet
```

Each sample is automatically wrapped into a CosyVoice2-style prompt so that the LLM learns to output CosyVoice2 speech tokens.

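For orientation, here is a minimal sketch of what one converted row might look like. This is not the actual `prepare_data.py`: the column names (`data_source`, `prompt`, `reward_model`, `extra_info`) follow veRL's usual parquet layout and are assumptions here, the chat-style prompt is a placeholder for the real CosyVoice2 template, and the file paths are illustrative.

```python
# Hedged sketch only; the real conversion lives in prepare_data.py.
import json

import pandas as pd  # pandas + pyarrow assumed available

def build_row(example: dict, idx: int) -> dict:
    text = example["text"]
    return {
        "data_source": "tts",  # placeholder tag
        # Placeholder prompt; the real CosyVoice2-style template is defined in prepare_data.py.
        "prompt": [{"role": "user", "content": text}],
        # The ground-truth text is what the ASR-based reward compares against.
        "reward_model": {"style": "rule", "ground_truth": text},
        "extra_info": {"index": idx},
    }

with open("data/train.jsonl", encoding="utf-8") as f:  # illustrative path
    rows = [build_row(json.loads(line), i) for i, line in enumerate(f)]

pd.DataFrame(rows).to_parquet("data/parquet_tiny/train.parquet")
```
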
## Reward Function & ASR Server

To compute rewards, we run a lightweight server that:

1. Converts generated speech tokens back to a 16 kHz waveform with the **CosyVoice2** pretrained U-Net model.
2. Transcribes the waveform with **SenseVoice** ASR.
3. Calculates the pinyin-level error rate relative to the ground-truth text and maps it to a score between 0 and 1 (see the sketch after this list).

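The following is a minimal sketch of step 3, not the code in `reward_tts.py`: it assumes `pypinyin` for the pinyin conversion, uses a plain Levenshtein distance over pinyin tokens, and clips the resulting score to [0, 1].

```python
from pypinyin import lazy_pinyin  # assumed dependency, used only for this sketch

def edit_distance(ref, hyp):
    # Standard dynamic-programming Levenshtein distance over token lists.
    dp = list(range(len(hyp) + 1))
    for i in range(1, len(ref) + 1):
        prev, dp[0] = dp[0], i
        for j in range(1, len(hyp) + 1):
            cur = dp[j]
            dp[j] = min(dp[j] + 1, dp[j - 1] + 1, prev + (ref[i - 1] != hyp[j - 1]))
            prev = cur
    return dp[-1]

def pinyin_reward(ground_truth: str, transcript: str) -> float:
    ref, hyp = lazy_pinyin(ground_truth), lazy_pinyin(transcript)
    if not ref:
        return 0.0
    error_rate = edit_distance(ref, hyp) / len(ref)
    return max(0.0, 1.0 - error_rate)  # clip so the reward stays within [0, 1]

# Example: a perfect transcription scores 1.0, a wrong one scores lower.
print(pinyin_reward("今天天气不错", "今天天气不错"))
```
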
Start the server (stage `1`) in a dedicated terminal or on a separate GPU:

```bash
bash run.sh 1 1
# Triton server listens on ports 8000/8001/8002
```

The custom reward implementation is located in [`reward_tts.py`](./reward_tts.py) and calls the server to obtain the reward score.

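As a rough illustration of how such a hook plugs into veRL (this is not the actual `reward_tts.py`): veRL calls the function named by `custom_reward_function.name` (`compute_score` by default) with the data source, the generated text, and the ground truth. The endpoint URL and payload below are placeholders standing in for the Triton round-trip.

```python
# Hedged sketch of a veRL custom reward hook; reward_tts.py is the real implementation.
import requests

def compute_score(data_source, solution_str, ground_truth, extra_info=None):
    # solution_str carries the LLM output (CosyVoice2 speech tokens as text).
    # Placeholder URL and payload: the actual request format is defined by the
    # Triton server started in stage 1.
    resp = requests.post(
        "http://localhost:8000/score",
        json={"tokens": solution_str, "text": ground_truth},
        timeout=60,
    )
    resp.raise_for_status()
    return float(resp.json()["score"])  # score in [0, 1], as described above
```
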
## Training

Run stage `2` to start GRPO training:

```bash
bash run.sh 2 2
```

Key CLI arguments passed to `verl.trainer.main_ppo`:

* `algorithm.adv_estimator=grpo` – use GRPO instead of PPO (a sketch of the group-relative advantage is shown after this list).
* `data.train_files=data/parquet_aishell3/train.parquet` and `data.val_files=data/parquet_aishell3/test.parquet` – training and validation Parquet files.
* `custom_reward_function.path=reward_tts.py` – custom reward function described above.

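For intuition, GRPO replaces PPO's learned value baseline with a group-relative one: several responses are sampled per prompt and each reward is normalized against its group's statistics. A minimal sketch of that advantage computation (not veRL's implementation):

```python
import numpy as np

def group_relative_advantages(rewards, eps=1e-6):
    # rewards: ASR-based scores of the responses sampled for one prompt.
    r = np.asarray(rewards, dtype=np.float64)
    return (r - r.mean()) / (r.std() + eps)

# Example: four sampled utterances scored by the reward server.
print(group_relative_advantages([0.9, 0.7, 1.0, 0.4]))
```
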
Adjust `CUDA_VISIBLE_DEVICES`, batch sizes, and other hyperparameters to match your hardware.

> [!TIP]
> The `lm_head` bias is disabled during training to make the model compatible with vLLM and Transformers' Qwen model.

## Evaluation

After training is complete, collect the sharded FSDP weights and export a Hugging Face-style checkpoint (stage `3`):

```bash
bash run.sh 3 3 # merges weights into $llm_path/merged_hf_model
```

You can then evaluate the model on the CosyVoice3 zero-shot Chinese test set (stage `4`):

```bash
bash run.sh 4 4
```

This command launches distributed inference via `infer_dataset.py` and computes WER with `scripts/compute_wer.sh`.

> [!TIP]
> The script also supports the Seed-TTS test set by setting `dataset=test_zh`.

## Export Model

To use the RL-trained model with the official CosyVoice repository:

```bash
bash run.sh 5 5
```

The script converts the Hugging Face checkpoint back into the format expected by the CosyVoice repository.

> [!TIP]
> We observed a slight accuracy drop when using the RL-trained model after this conversion, compared with the Hugging Face-format checkpoint.

## Results

| Model | Seed-TTS `test_zh` CER | CosyVoice3 `zero_shot_zh` CER | Comment |
|-------|------------------------|-------------------------------|---------|
| CosyVoice2 LLM (official) | 1.45% | 4.08% | See the [paper](https://arxiv.org/abs/2412.10117) |
| CosyVoice2 LLM + GRPO | 1.37% | **3.36%** | See the [decoding results](yuekai/official-cosyvoice-llm-grpo-aishell3), Hugging Face-format model |

## Acknowledgement

This work was inspired by the implementation in [ch-tts-llasa-rl-grpo](https://github.com/channel-io/ch-tts-llasa-rl-grpo).