diff --git a/README.md b/README.md index 5e3cfd5a..873e8d6a 100644 --- a/README.md +++ b/README.md @@ -31,7 +31,7 @@ - [x] 2025/08 - - [x] Thanks to the contribution from NVIDIA Yuekai Zhang, add triton trtllm runtime support + - [x] Thanks to the contribution from NVIDIA Yuekai Zhang, add triton trtllm runtime support and cosyvoice2 grpo training support - [x] 2025/07 diff --git a/examples/grpo/cosyvoice2/Dockerfile b/examples/grpo/cosyvoice2/Dockerfile new file mode 100644 index 00000000..17d80ed5 --- /dev/null +++ b/examples/grpo/cosyvoice2/Dockerfile @@ -0,0 +1,6 @@ +FROM verlai/verl:app-verl0.4-vllm0.8.5-mcore0.12.2-te2.2 +COPY requirements.txt /myworkspace/requirements.txt +RUN pip install -r /myworkspace/requirements.txt +RUN pip install -U nvidia-pytriton +RUN git clone https://github.com/yuekaizhang/verl.git /myworkspace/verl -b thread && cd /myworkspace/verl && pip install --no-deps -e . +RUN git clone https://github.com/yuekaizhang/PytritonSenseVoice.git /myworkspace/PytritonSenseVoice && cd /myworkspace/PytritonSenseVoice && pip install -e . \ No newline at end of file diff --git a/examples/grpo/cosyvoice2/README.md b/examples/grpo/cosyvoice2/README.md new file mode 100644 index 00000000..8783aa16 --- /dev/null +++ b/examples/grpo/cosyvoice2/README.md @@ -0,0 +1,125 @@ +# CosyVoice2 LLM Reinforcement Learning Recipe + +This recipe demonstrates how to fine-tune the **CosyVoice2** large language model with reinforcement learning algorithms—specifically **GRPO**—using the [veRL](https://github.com/volcengine/verl) framework. Our experiments show that applying GRPO reduces the character error rate (CER) on the CosyVoice3 `zero_shot_zh` set from 4.08% to 3.36%. + +## Table of Contents + +- [Environment Setup](#environment-setup) +- [Data Preparation](#data-preparation) +- [Reward Function & ASR Server](#reward-function--asr-server) +- [Training](#training) +- [Evaluation](#evaluation) +- [Export Model](#export-model) +- [Results](#results) +- [Acknowledgement](#acknowledgement) + +## Environment Setup +We recommend using the pre-built Docker image below. Alternatively, you can manually install the dependencies following the Dockerfile. +```bash +docker pull soar97/verl:app-verl0.4-vllm0.8.5-mcore0.12.2-te2.2 +``` +If Docker is not available, you can refer to `run.sh` `stage -2` to install the dependencies locally. + +## Data Preparation + +`prepare_data.py` expects a JSON/JSONL file with at least the following schema: + +```jsonc +{ + "text": "An example sentence to be synthesized." +} +``` +You can download the JSONL files from the metadata directory of the [SparkAudio/voxbox](https://huggingface.co/datasets/SparkAudio/voxbox/tree/main/metadata) dataset on Hugging Face. + +Stage `0` converts raw JSONL files into the parquet format expected by veRL: + +```bash +bash run.sh 0 0 +``` +Create two JSONL files—`train.jsonl` and `test.jsonl`. +The script will then generate two Parquet files: + +``` +data/parquet_tiny/train.parquet +data/parquet_tiny/test.parquet +``` + +Each sample is automatically wrapped into a CosyVoice2-style prompt so that the LLM learns to output CosyVoice2 speech tokens. + + +## Reward Function & ASR Server + +To compute rewards, we run a lightweight server that: + +1. Converts generated speech tokens back to a 16 kHz waveform with the **CosyVoice2** pretrained U-Net model. +2. Transcribes the waveform with **SenseVoice** ASR. +3. Calculates the pinyin-level error rate relative to the ground-truth text and maps it to a score between 0 and 1. + +Start the server (stage `1`) in a dedicated terminal or on a separate GPU: + +```bash +bash run.sh 1 1 +# Triton server listens on ports 8000/8001/8002 +``` + +The custom reward implementation is located in [`reward_tts.py`](./reward_tts.py) and calls the server to obtain the reward score. + +## Training + +Run stage `2` to start GRPO training: + +```bash +bash run.sh 2 2 +``` + +Key CLI arguments passed to `verl.trainer.main_ppo`: + +* `algorithm.adv_estimator=grpo` – use GRPO instead of PPO. +* `data.train_files=data/parquet_aishell3/train.parquet` and `data.val_files=data/parquet_aishell3/test.parquet` +* `custom_reward_function.path=reward_tts.py` – custom reward function described above. + +Adjust `CUDA_VISIBLE_DEVICES`, batch sizes, and other hyperparameters to match your hardware. +> [!TIP] +> Note: the lm_head bias is disabled during training to make the model compatible with VLLM and Transformers' Qwen model. + +## Evaluation + +After training is complete, collect the sharded FSDP weights and export a Hugging Face-style checkpoint (stage `3`): + +```bash +bash run.sh 3 3 # merges weights into $llm_path/merged_hf_model +``` + +You can then evaluate the model on the CosyVoice3 zero-shot Chinese test set (stage `4`): + +```bash +bash run.sh 4 4 +``` + +This command launches distributed inference via `infer_dataset.py` and computes WER with `scripts/compute_wer.sh`. + +> [!TIP] +> The script also supports the Seed-TTS test set by setting `dataset=test_zh`. + +## Export Model + +To use the RL-trained model with the official CosyVoice repository: + +```bash +bash run.sh 5 5 +``` + +The script converts the Hugging Face checkpoint back into the format expected by the CosyVoice repository. +> [!TIP] +> However, we observed a slight accuracy drop when using the RL-trained model after conversion, compared with the Hugging Face format. + +## Results + +| Model | Seed-TTS `test_zh` CER | CosyVoice3 `zero_shot_zh` CER | Comment | +|-------|------------------------|------------------------------|---------| +| CosyVoice2 LLM (official) | 1.45% | 4.08% | See the [paper](https://arxiv.org/abs/2412.10117) | +| CosyVoice2 LLM + GRPO | 1.37% | **3.36%** | See the [decoding results](yuekai/official-cosyvoice-llm-grpo-aishell3), Hugging Face-format model | + +## Acknowledgement + +This work was inspired by the implementation in [ch-tts-llasa-rl-grpo](https://github.com/channel-io/ch-tts-llasa-rl-grpo). diff --git a/examples/grpo/cosyvoice2/huggingface_to_pretrained.py b/examples/grpo/cosyvoice2/huggingface_to_pretrained.py new file mode 100644 index 00000000..ca49fc3b --- /dev/null +++ b/examples/grpo/cosyvoice2/huggingface_to_pretrained.py @@ -0,0 +1,71 @@ + +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +python3 hf2pretrained.py --hf-cosyvoice2-llm-path /workspace/rl-exp/checkpoint-400 --output-path /workspace/CosyVoice2-0.5B/llm-new.pt +""" +from argparse import ArgumentParser +import torch +from safetensors import safe_open +from transformers import AutoTokenizer + + +def get_args(): + parser = ArgumentParser() + + parser.add_argument( + "--hf-cosyvoice2-llm-path", + type=str, + default=None, + help="The RL trained CosyVoice2 model path in HuggingFace format", + ) + parser.add_argument( + "--output-path", + type=str, + default="./llm.pt", + help="The path to save the llm.pt", + ) + args = parser.parse_args() + return args + + +if __name__ == "__main__": + args = get_args() + + tokenizer = AutoTokenizer.from_pretrained(args.hf_cosyvoice2_llm_path) + speech_start_idx = tokenizer.convert_tokens_to_ids("<|s_0|>") + cosyvoice2_token_size = 6561 + 3 + llm_embedding_vocab_size = 2 + + hf_tensors = {} + with safe_open(f"{args.hf_cosyvoice2_llm_path}/model.safetensors", framework="pt", device="cpu") as f: + for k in f.keys(): + if k.startswith("lm_head.bias"): + # RL trained model disable bias for lm_head + continue + new_k = "llm.model." + k + hf_tensors[new_k] = f.get_tensor(k) + if k.startswith("lm_head"): + hf_tensors["llm_decoder.weight"] = f.get_tensor(k)[speech_start_idx:speech_start_idx + cosyvoice2_token_size] + hf_tensors["llm_decoder.bias"] = torch.zeros_like(hf_tensors["llm_decoder.weight"][:, 0]) + if k.startswith("model.embed_tokens"): + hf_tensors["speech_embedding.weight"] = f.get_tensor(k)[speech_start_idx:speech_start_idx + cosyvoice2_token_size] + hf_tensors["llm_embedding.weight"] = f.get_tensor(k)[speech_start_idx + cosyvoice2_token_size:speech_start_idx + cosyvoice2_token_size + llm_embedding_vocab_size] + + # use tie_word_embeddings=True + hf_tensors["llm.model.model.embed_tokens.weight"] = hf_tensors["llm.model.model.embed_tokens.weight"][:151936] + hf_tensors["llm.model.lm_head.weight"] = hf_tensors["llm.model.model.embed_tokens.weight"] + + torch.save(hf_tensors, args.output_path) diff --git a/examples/grpo/cosyvoice2/infer_dataset.py b/examples/grpo/cosyvoice2/infer_dataset.py new file mode 100644 index 00000000..4dcbc967 --- /dev/null +++ b/examples/grpo/cosyvoice2/infer_dataset.py @@ -0,0 +1,397 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Example Usage +dataset=zero_shot_zh +output_dir=./outputs_rl_aishell3_step${step}_${dataset}_jit_trt_fp16_reward_tts + +token2wav_path=/workspace/CosyVoice2-0.5B +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ +torchrun --nproc_per_node=8 \ + infer_dataset.py \ + --output-dir $output_dir \ + --llm-model-name-or-path $llm_path/merged_hf_model \ + --token2wav-path $token2wav_path \ + --split-name ${dataset} || exit 1 +""" + +import argparse +import json +import os +import sys +from pathlib import Path + +import torch +import torch.distributed as dist +import torch.nn.functional as F +import torchaudio +from cosyvoice.cli.cosyvoice import CosyVoice2 +from cosyvoice.utils.file_utils import load_wav +from datasets import load_dataset +from transformers import AutoTokenizer, AutoModelForCausalLM +from torch.utils.data import DataLoader, Dataset, DistributedSampler +from tqdm import tqdm +import soundfile as sf +import s3tokenizer +from functools import partial + +sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS") +try: + torch.multiprocessing.set_start_method("spawn") +except RuntimeError: + pass + + +TEMPLATE = "{% for message in messages %}{%- if message['role'] == 'user' %}{{- '<|im_start|>' + message['role'] + '\n' + 'Convert the text to speech: ' + message['content'] + '<|im_end|>\n'}}{%- elif message['role'] == 'assistant' %}{{- '<|im_start|>' + message['role'] + '\n' + '<|SPEECH_GENERATION_START|>' + message['content']}}{%- endif %}{%- endfor %}" + + +def audio_decode_cosyvoice2( + audio_tokens, prompt_text, prompt_speech_16k, codec_decoder +): + """ + Generate audio from tokens with optional tone and prompt embedding. + """ + model_inputs_dict = codec_decoder.frontend.frontend_zero_shot( + "empty", prompt_text, prompt_speech_16k, 24000 + ) + tts_mel, _ = codec_decoder.model.flow.inference( + token=audio_tokens.to(codec_decoder.model.device), + token_len=torch.tensor([audio_tokens.shape[1]], dtype=torch.int32).to( + codec_decoder.model.device + ), + prompt_token=model_inputs_dict["flow_prompt_speech_token"].to( + codec_decoder.model.device + ), + prompt_token_len=torch.tensor( + [model_inputs_dict["flow_prompt_speech_token_len"]], dtype=torch.int32 + ).to(codec_decoder.model.device), + prompt_feat=model_inputs_dict["prompt_speech_feat"].to( + codec_decoder.model.device + ), + prompt_feat_len=model_inputs_dict["prompt_speech_feat_len"].to( + codec_decoder.model.device + ), + embedding=model_inputs_dict["flow_embedding"].to(codec_decoder.model.device), + finalize=True, + ) + + audio_hat, _ = codec_decoder.model.hift.inference( + speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0) + ) + + return audio_hat + + +def extract_speech_ids(speech_tokens_str): + """Extract speech IDs from token strings like <|s_23456|>""" + speech_ids = [] + for token_str in speech_tokens_str: + if token_str.startswith('<|s_') and token_str.endswith('|>'): + num_str = token_str[4:-2] + num = int(num_str) + speech_ids.append(num) + else: + print(f"Unexpected token: {token_str}") + return speech_ids + + +def convert_cosy2_tokens_to_speech_id_str(cosy2_tokens): + """Convert CosyVoice2 tokens to speech IDs string like <|s_23456|>""" + speech_id_str = "" + for token in cosy2_tokens: + speech_id_str += f"<|s_{token}|>" + return speech_id_str + + +def get_args(): + parser = argparse.ArgumentParser(description="Speech generation using LLM + CosyVoice2") + parser.add_argument( + "--split-name", + type=str, + default="wenetspeech4tts", + help="huggingface dataset split name, see yuekai/CV3-Eval, yuekai/seed_tts_cosy2", + ) + parser.add_argument( + "--output-dir", required=True, type=str, help="dir to save result" + ) + parser.add_argument( + "--batch-size", + default=1, + type=int, + help="batch size (per-device) for inference", + ) + parser.add_argument( + "--num-workers", type=int, default=1, help="workers for dataloader" + ) + parser.add_argument( + "--prefetch", type=int, default=5, help="prefetch for dataloader" + ) + parser.add_argument( + "--llm-model-name-or-path", + required=True, + type=str, + help="LLM model path (includes both model and tokenizer)", + ) + parser.add_argument( + "--token2wav-path", + required=True, + type=str, + help="CosyVoice2 token2wav model path", + ) + parser.add_argument( + "--prompt-text", + type=str, + default=None, + help="The prompt text for CosyVoice2", + ) + parser.add_argument( + "--prompt-speech-path", + type=str, + default=None, + help="The path to the prompt speech for CosyVoice2", + ) + parser.add_argument( + "--top-p", + type=float, + default=0.95, + help="top p for sampling", + ) + parser.add_argument( + "--temperature", + type=float, + default=0.8, + help="temperature for sampling", + ) + parser.add_argument( + "--top-k", + type=int, + default=50, + help="top k for sampling", + ) + args = parser.parse_args() + return args + + +def data_collator(batch, tokenizer, s3_tokenizer): + """Simplified data collator for batch_size=1 processing""" + target_sample_rate = 16000 # CosyVoice2 uses 16kHz for prompt audio + device = s3_tokenizer.device if s3_tokenizer is not None else torch.device("cpu") + input_ids_list, prompt_audio_list, prompt_text_list = [], [], [] + mels, prompt_audio_cosy2tokens_list = [], [] + for item in batch: + prompt_text, target_text = ( + item["prompt_text"], + item["target_text"], + ) + prompt_text_list.append(prompt_text) + # Combine prompt and target text + full_text = prompt_text + target_text + + # get prompt audio for CosyVoice2 (convert to 16kHz) + ref_audio_org, ref_sr = ( + item["prompt_audio"]["array"], + item["prompt_audio"]["sampling_rate"], + ) + ref_audio_org = torch.from_numpy(ref_audio_org).float().unsqueeze(0) + # ref_audio_org = ref_audio_org.mean(dim=0, keepdim=True) + print(ref_audio_org.shape) + + if ref_sr != target_sample_rate: + resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate) + ref_audio = resampler(ref_audio_org) + else: + ref_audio = ref_audio_org + + prompt_audio_list.append(ref_audio) + + if "prompt_audio_cosy2_tokens" in item: + prompt_audio_cosy2tokens = item["prompt_audio_cosy2_tokens"] + prompt_audio_cosy2tokens_list.append(prompt_audio_cosy2tokens) + else: + # convert to float first + mels.append(s3tokenizer.log_mel_spectrogram(ref_audio.squeeze(0))) + + if len(mels) > 0: + mels, mels_lens = s3tokenizer.padding(mels) + codes, codes_lens = s3_tokenizer.quantize(mels.to(device), mels_lens.to(device)) + for i in range(len(codes)): + prompt_audio_cosy2tokens_list.append(codes[i, :codes_lens[i].item()]) + for prompt_audio_cosy2tokens in prompt_audio_cosy2tokens_list: + prompt_audio_cosy2_id_str = convert_cosy2_tokens_to_speech_id_str(prompt_audio_cosy2tokens) + # Create chat template for LLM generation + chat = [ + {"role": "user", "content": full_text}, + {"role": "assistant", "content": prompt_audio_cosy2_id_str} + ] + if 'system' in tokenizer.chat_template: + tokenizer.chat_template = TEMPLATE + input_ids = tokenizer.apply_chat_template( + chat, + tokenize=True, + return_tensors='pt', + continue_final_message=True + ) + input_ids_list.append(input_ids.squeeze(0)) + + # For batch_size=1, no need to pad + if len(input_ids_list) == 1: + input_ids = input_ids_list[0].unsqueeze(0) + else: + # Handle batch > 1 if needed + max_len = max([len(input_ids) for input_ids in input_ids_list]) + input_ids_list = [ + torch.cat([torch.full((max_len - len(input_ids),), tokenizer.pad_token_id), input_ids]) + for input_ids in input_ids_list + ] + input_ids = torch.stack(input_ids_list) + + ids = [item["id"] for item in batch] + + return { + "input_ids": input_ids, + "ids": ids, + "prompt_text": prompt_text_list, + "prompt_audio_list": prompt_audio_list, + } + + +def init_distributed(): + world_size = int(os.environ.get("WORLD_SIZE", 1)) + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + rank = int(os.environ.get("RANK", 0)) + print( + "Inference on multiple gpus, this gpu {}".format(local_rank) + + ", rank {}, world_size {}".format(rank, world_size) + ) + torch.cuda.set_device(local_rank) + dist.init_process_group("nccl") + return world_size, local_rank, rank + + +def main(): + args = get_args() + os.makedirs(args.output_dir, exist_ok=True) + + assert torch.cuda.is_available() + world_size, local_rank, rank = init_distributed() + device = torch.device(f"cuda:{local_rank}") + + # Load LLM model and tokenizer directly + tokenizer = AutoTokenizer.from_pretrained(args.llm_model_name_or_path) + model = AutoModelForCausalLM.from_pretrained(args.llm_model_name_or_path) + model.eval() + model.to(device) + + cosyvoice_codec = CosyVoice2( + args.token2wav_path, load_jit=True, load_trt=True, fp16=True + ) + if args.prompt_speech_path: + prompt_speech_16k = load_wav(args.prompt_speech_path, 16000) + else: + prompt_speech_16k = None + s3_tokenizer = s3tokenizer.load_model("speech_tokenizer_v2_25hz").to(device) if 'zero' in args.split_name else None + dataset_name = "yuekai/CV3-Eval" if 'zero' in args.split_name else "yuekai/seed_tts_cosy2" + dataset = load_dataset( + dataset_name, + split=args.split_name, + trust_remote_code=True, + ) + + sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank) + + dataloader = DataLoader( + dataset, + batch_size=args.batch_size, + sampler=sampler, + shuffle=False, + num_workers=args.num_workers, + prefetch_factor=args.prefetch, + collate_fn=partial(data_collator, tokenizer=tokenizer, s3_tokenizer=s3_tokenizer), + ) + + total_steps = len(dataset) + + if rank == 0: + progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs") + + for batch in dataloader: + with torch.no_grad(): + input_ids = batch["input_ids"].to(device) + + # Generate speech tokens using LLM + outputs = model.generate( + input_ids, + max_new_tokens=2048, # Max length for generation + do_sample=True, + top_p=args.top_p, + temperature=args.temperature, + top_k=args.top_k, + ) + + # Process each sample in the batch + for i in range(len(batch["ids"])): + # Extract generated tokens (excluding input) + input_length = input_ids[i].shape[0] + generated_ids = outputs[i][input_length:-1] # Remove last token if needed + speech_tokens_str = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + + # Extract speech IDs from token strings like <|s_23456|> + speech_ids = extract_speech_ids(speech_tokens_str) + + if len(speech_ids) == 0: + print(f"Warning: No speech tokens generated for sample {batch['ids'][i]}, skipping") + continue + + # Convert to tensor for CosyVoice2 + audio_tokens = torch.tensor(speech_ids, dtype=torch.long, device=device).unsqueeze(0) + + if args.prompt_text is not None: + current_prompt_text = args.prompt_text + current_prompt_audio = prompt_speech_16k + else: + current_prompt_text = batch["prompt_text"][i] + current_prompt_audio = batch["prompt_audio_list"][i] + + if current_prompt_audio is not None: + # Generate audio using CosyVoice2 + audio_hat = audio_decode_cosyvoice2( + audio_tokens, + current_prompt_text, + current_prompt_audio, + cosyvoice_codec, + ) + + # Convert to numpy and save + generated_wave = audio_hat.squeeze(0).cpu().numpy() + target_sample_rate = 24000 + + utt = batch["ids"][i] + sf.write(f"{args.output_dir}/{utt}.wav", generated_wave, target_sample_rate) + + print(f"Generated audio for sample {utt} with {len(speech_ids)} tokens") + else: + print(f"Warning: No prompt audio available for sample {batch['ids'][i]}, skipping") + + if rank == 0: + progress_bar.update(world_size * len(batch["ids"])) + + if rank == 0: + progress_bar.close() + + dist.barrier() + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/examples/grpo/cosyvoice2/prepare_data.py b/examples/grpo/cosyvoice2/prepare_data.py new file mode 100644 index 00000000..46c3c097 --- /dev/null +++ b/examples/grpo/cosyvoice2/prepare_data.py @@ -0,0 +1,86 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Preprocess the Text to Speech dataset to parquet format +""" + +import argparse +import os +import re + +import datasets + +from verl.utils.hdfs_io import copy, makedirs + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--train_file", required=True, help="Path to training JSON/JSONL file") + parser.add_argument("--test_file", required=True, help="Path to test JSON/JSONL file") + parser.add_argument("--local_dir", default=None, required=True) + parser.add_argument("--hdfs_dir", default=None) + + args = parser.parse_args() + + # Load datasets from local JSON files + train_dataset = datasets.load_dataset("json", data_files=args.train_file)['train'] + test_dataset = datasets.load_dataset("json", data_files=args.test_file)['train'] + + # add a row to each data item that represents a unique id + def make_map_fn(split): + def process_fn(example, idx): + text = example.pop("text") + + # use cosyvoice2 official huggingface compatible checkpoint template + question = text + answer = "" + + data = { + "data_source": f"{args.train_file}_{args.test_file}", # Use file names as data source + "prompt": [ + { + "role": "user", + "content": question, + }, + { + "role": "assistant", + "content": answer, + }, + ], + "ability": "text-to-speech", + "reward_model": {"style": "rule", "ground_truth": text}, + "extra_info": { + "split": split, + "index": idx, + "text": text, + }, + } + return data + + return process_fn + + train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True) + test_dataset = test_dataset.map(function=make_map_fn("test"), with_indices=True) + + local_dir = args.local_dir + hdfs_dir = args.hdfs_dir + + print(train_dataset) + print(test_dataset) + train_dataset.to_parquet(os.path.join(local_dir, "train.parquet")) + test_dataset.to_parquet(os.path.join(local_dir, "test.parquet")) + + if hdfs_dir is not None: + makedirs(hdfs_dir) + + copy(src=local_dir, dst=hdfs_dir) diff --git a/examples/grpo/cosyvoice2/pretrained_to_huggingface.py b/examples/grpo/cosyvoice2/pretrained_to_huggingface.py new file mode 100644 index 00000000..161a11fc --- /dev/null +++ b/examples/grpo/cosyvoice2/pretrained_to_huggingface.py @@ -0,0 +1,135 @@ +#!/usr/bin/env python3 + +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: Instruct TTS + python3 infer.py \ + --token2wav-path /workspace/CosyVoice2-0.5B \ + --prompt-text "吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝,营养更均衡,本节目由豆本豆豆奶特约播出。" \ + --prompt-speech-path ./assets/prompt_audio.wav \ + --model-path ./transformers_cosyvoice2_llm \ + --input-text "用四川话说<|endofprompt|>扁担长,板凳宽,扁担绑在板凳上。吃葡萄不吐葡萄皮,不吃葡萄倒吐葡萄皮。" +""" +from cosyvoice.cli.cosyvoice import CosyVoice2 +import sys +from argparse import ArgumentParser +from transformers import AutoTokenizer, AutoModelForCausalLM +import torch + +sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS") + + +def get_args(): + parser = ArgumentParser() + + parser.add_argument( + "--pretrained-cosyvoice2-path", + type=str, + default="/workspace/CosyVoice2-0.5B", + help="Token2Wav path, default to %(default)r", + ) + parser.add_argument( + "--save-path", + type=str, + default='./transformers_cosyvoice2_llm', + help="The path to save the model", + ) + args = parser.parse_args() + return args + + +if __name__ == "__main__": + args = get_args() + cosy2_model = CosyVoice2( + args.pretrained_cosyvoice2_path, load_jit=False, load_trt=False, fp16=False + ) + + llm = cosy2_model.model.llm.llm.model + + speech_embedding = cosy2_model.model.llm.speech_embedding + llm_decoder = cosy2_model.model.llm.llm_decoder + llm_embedding = cosy2_model.model.llm.llm_embedding + + tokenizer = AutoTokenizer.from_pretrained(f"{args.pretrained_cosyvoice2_path}/CosyVoice-BlankEN") + special_tokens = { + 'eos_token': '<|endoftext|>', + 'pad_token': '<|endoftext|>', + 'additional_special_tokens': [ + '<|im_start|>', '<|im_end|>', '<|endofprompt|>', + '[breath]', '', '', '[noise]', + '[laughter]', '[cough]', '[clucking]', '[accent]', + '[quick_breath]', + "", "", + "[hissing]", "[sigh]", "[vocalized-noise]", + "[lipsmack]", "[mn]" + ] + } + tokenizer.add_special_tokens(special_tokens) + + original_tokenizer_vocab_size = len(tokenizer) + cosyvoice2_token_size = 6561 + new_tokens = [f"<|s_{i}|>" for i in range(cosyvoice2_token_size)] + [ + "<|eos1|>", "<|eos2|>", "<|eos3|>", "<|sos|>", "<|task_id|>" + ] + num_added_tokens = tokenizer.add_tokens(new_tokens) + + llm.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=128) + vocab_size = llm.get_input_embeddings().weight.shape[0] + + feature_size = speech_embedding.embedding_dim + new_lm_head = torch.nn.Linear(in_features=feature_size, out_features=vocab_size, bias=True) + + with torch.no_grad(): + # set the weight and bias of the new lm_head to 0 + new_lm_head.weight.data.zero_() + # make bias value -inf + new_lm_head.bias.data.fill_(-float('inf')) + new_lm_head.weight[original_tokenizer_vocab_size:original_tokenizer_vocab_size + cosyvoice2_token_size + 3] = llm_decoder.weight + new_lm_head.bias[original_tokenizer_vocab_size:original_tokenizer_vocab_size + cosyvoice2_token_size + 3] = llm_decoder.bias + + llm.lm_head = new_lm_head + input_embeddings = llm.get_input_embeddings() + + with torch.no_grad(): + input_embeddings.weight[original_tokenizer_vocab_size:original_tokenizer_vocab_size + cosyvoice2_token_size + 3] = speech_embedding.weight + input_embeddings.weight[original_tokenizer_vocab_size + cosyvoice2_token_size + 3:original_tokenizer_vocab_size + cosyvoice2_token_size + 3 + 2] = llm_embedding.weight + + eos_token_ids = [original_tokenizer_vocab_size + cosyvoice2_token_size, + original_tokenizer_vocab_size + cosyvoice2_token_size + 1, + original_tokenizer_vocab_size + cosyvoice2_token_size + 2] + llm.generation_config.eos_token_id = eos_token_ids + llm.generation_config.temperature = 1.0 + llm.generation_config.top_p = 0.8 + llm.generation_config.top_k = 25 + + llm.config.eos_token_id = original_tokenizer_vocab_size + cosyvoice2_token_size + llm.config.vocab_size = vocab_size + llm.config.tie_word_embeddings = False + llm.config.use_bias = True + llm.to(torch.bfloat16) + llm.save_pretrained(args.save_path) + + TEMPLATE = ( + "{%- for message in messages %}" + "{%- if message['role'] == 'user' %}" + "{{- '<|sos|>' + message['content'] + '<|task_id|>' }}" + "{%- elif message['role'] == 'assistant' %}" + "{{- message['content']}}" + "{%- endif %}" + "{%- endfor %}" + ) + tokenizer.chat_template = TEMPLATE + tokenizer.save_pretrained(args.save_path) diff --git a/examples/grpo/cosyvoice2/requirements.txt b/examples/grpo/cosyvoice2/requirements.txt new file mode 100644 index 00000000..50f4edd7 --- /dev/null +++ b/examples/grpo/cosyvoice2/requirements.txt @@ -0,0 +1,31 @@ +conformer==0.3.2 +diffusers==0.29.0 +gdown==5.1.0 +gradio +hydra-core==1.3.2 +HyperPyYAML==1.2.2 +inflect==7.3.1 +librosa==0.10.2 +lightning==2.2.4 +matplotlib==3.7.5 +modelscope==1.15.0 +networkx==3.1 +omegaconf==2.3.0 +onnx==1.16.0 +onnxruntime-gpu==1.18.0 +protobuf==4.25 +pydantic==2.7.0 +pyworld==0.3.4 +rich==13.7.1 +soundfile==0.12.1 +tensorboard==2.14.0 +wget==3.2 +WeTextProcessing==1.0.3 +s3tokenizer +tensorrt +sherpa_onnx +jiwer +zhon +numpy==1.25.2 +pypinyin +openai-whisper \ No newline at end of file diff --git a/examples/grpo/cosyvoice2/reward_tts.py b/examples/grpo/cosyvoice2/reward_tts.py new file mode 100644 index 00000000..4c407611 --- /dev/null +++ b/examples/grpo/cosyvoice2/reward_tts.py @@ -0,0 +1,233 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Reward calculation for CosyVoice2-0.5B. +""" + +from __future__ import annotations + +import re +import json +import time +import argparse +from typing import List + +import numpy as np +import requests + + +REWARD_SERVER_URL = "http://localhost:8000/v2/models/token2wav_asr/infer" + + +def _parse_ids(token_str: str) -> List[int]: + return [int(t) for t in re.findall(r"<\|s_(\d+)\|>", token_str)] + + +def _remote_reward(tokens: List[int], ground_truth: str, timeout: float = 200.0) -> float: + """Send token IDs and ground-truth text to the Triton server and get reward.""" + + tokens_arr = np.array(tokens, dtype=np.int32).reshape(1, -1) + lens_arr = np.array([[tokens_arr.shape[1]]], dtype=np.int32) + + gt_arr = np.array([ground_truth.encode("utf-8")], dtype=object) + + payload = { + "inputs": [ + { + "name": "TOKENS", + "shape": list(tokens_arr.shape), + "datatype": "INT32", + "data": tokens_arr.tolist(), + }, + { + "name": "TOKEN_LENS", + "shape": list(lens_arr.shape), + "datatype": "INT32", + "data": lens_arr.tolist(), + }, + { + "name": "GT_TEXT", + "shape": [1, 1], + "datatype": "BYTES", + "data": [ground_truth], + }, + ] + } + rsp = requests.post( + REWARD_SERVER_URL, + headers={"Content-Type": "application/json"}, + json=payload, + timeout=timeout, + verify=False, + params={"request_id": "0"}, + ) + rsp.raise_for_status() + result = rsp.json() + + try: + # Reward is returned as the first output + return float(result["outputs"][0]["data"][0]) + except (KeyError, IndexError, TypeError): + return 0.0 + + +def compute_score( + data_source: str, + solution_str: str, + ground_truth: str, + extra_info: dict | None = None, + *, + debug_dump: bool = False, +) -> float: + """Return reward in [0, 1] using the Triton ASR service. + + The reward is based on the pinyin-level WER between the ASR transcript + produced from *solution_str* and the provided *ground_truth* text. + """ + + # Decode token IDs + ids = _parse_ids(solution_str) + + # Query remote server for reward + try: + reward = _remote_reward(ids, ground_truth) + except Exception as e: + reward = 0.0 + + if debug_dump: + print( + f"\033[92m[{data_source}] Remote reward: {reward:.4f}\033[0m" + ) + + return reward + + +# CLI quick test +if __name__ == "__main__": + import sys + + def get_args(): + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description="Test TTS CER scoring with data from JSONL file", + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--input", "-i", + type=str, + default="data/emilia_zh-cosy-tiny-test.jsonl", + help="Path to input JSONL file" + ) + + parser.add_argument( + "--max-samples", "-n", + type=int, + default=None, + help="Maximum number of samples to process (default: all)" + ) + + parser.add_argument( + "--no-interactive", + action="store_true", + help="Run in non-interactive mode (process all samples without prompts)" + ) + + parser.add_argument( + "--debug", + action="store_true", + help="Enable debug mode" + ) + + return parser.parse_args() + + def load_jsonl(file_path: str): + """Load data from jsonl file.""" + data = [] + with open(file_path, 'r', encoding='utf-8') as f: + for line in f: + data.append(json.loads(line.strip())) + return data + + def code_to_solution_str(code_list: List[int]) -> str: + """Convert code list to solution string format.""" + return ''.join([f"<|s_{code}|>" for code in code_list]) + + # Parse command line arguments + args = get_args() + + try: + # Load data from jsonl file + print(f"Loading data from: {args.input}") + data_list = load_jsonl(args.input) + print(f"Loaded {len(data_list)} samples") + + # Limit samples if specified + if args.max_samples is not None: + data_list = data_list[:args.max_samples] + print(f"Processing first {len(data_list)} samples (limited by --max-samples)") + + # Process each sample + begin_time = time.time() + for i, sample in enumerate(data_list): + print(f"\n--- Sample {i+1}/{len(data_list)} ---") + print(f"Index: {sample.get('index', 'unknown')}") + print(f"Text: {sample['text']}") + + # Extract required fields + code_list = sample['code'] + ground_truth = sample['text'] + data_source = sample.get('index', f'sample_{i}') # Use index as data_source + + # Convert code list to solution string + solution_str = code_to_solution_str(code_list) + print(f"Solution tokens: {len(code_list)} tokens") + if args.debug: + print(f"Solution string: {solution_str}") + else: + print(f"Solution string preview: {solution_str[:100]}..." if len(solution_str) > 100 else f"Solution string: {solution_str}") + + # Call compute_score function + try: + score = compute_score( + data_source=data_source, + solution_str=solution_str, + ground_truth=ground_truth, + extra_info=None, + debug_dump=args.debug + ) + print(f"Final Score: {score:.4f}") + except Exception as e: + print(f"Error computing score: {e}") + + # Ask user if they want to continue (for interactive mode) + if not args.no_interactive and i < len(data_list) - 1: + try: + response = input("\nPress Enter to continue or 'q' to quit: ").strip().lower() + if response == 'q': + break + except KeyboardInterrupt: + print("\nStopped by user") + break + + print(f"\nProcessed {min(i+1, len(data_list))} samples") + end_time = time.time() + print(f"Time taken: {end_time - begin_time} seconds") + except FileNotFoundError: + print(f"Error: File not found - {args.input}") + print("Please check the file path or use --input to specify correct path") + print("Run with --help for usage information") + except Exception as e: + print(f"Error: {e}") diff --git a/examples/grpo/cosyvoice2/run.sh b/examples/grpo/cosyvoice2/run.sh new file mode 100644 index 00000000..ce97ab31 --- /dev/null +++ b/examples/grpo/cosyvoice2/run.sh @@ -0,0 +1,159 @@ +#!/usr/bin/env bash + +set -eou pipefail + +stage=-1 +stop_stage=4 + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +export PYTHONPATH=/workspace/CosyVoice +model_scope_model_path=./CosyVoice2-0.5B +sft_model_path=./transformers_cosyvoice2_llm + +if [ $stage -le -2 ] && [ $stop_stage -ge -2 ]; then + log "stage -2: install dependencies locally if pre-built docker image is not available" + conda create -n cosyvoice2 python=3.10 -y + conda activate cosyvoice2 + # install verl + git clone https://github.com/yuekaizhang/verl.git -b thread + cd verl + USE_MEGATRON=0 bash scripts/install_vllm_sglang_mcore.sh + pip install --no-deps -e . + cd - + # install requirements + pip install -r requirements.txt + pip install -U nvidia-pytriton + git clone https://github.com/yuekaizhang/PytritonSenseVoice.git && cd PytritonSenseVoice && pip install -e . +fi + +if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then + log "stage -1: download official CosyVoice2-0.5B LLM model and convert to huggingface compatible checkpoint" + modelscope download --model iic/CosyVoice2-0.5B --local_dir $model_scope_model_path + python3 pretrained_to_huggingface.py \ + --pretrained-cosyvoice2-path $model_scope_model_path \ + --save-path $sft_model_path + + # Or, you could use the following command to download the huggingface compatible checkpoint + # huggingface-cli download --local-dir $sft_model_path yuekai/cosyvoice2_llm + + # Note: we remove the lm_head's bias to make it compatible with the Qwen2.5-0.5B model in Transformers. +fi + +data_dir=data/parquet_aishell3 +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "stage 0: prepare data into verl format" + mkdir -p $data_dir + wget -O data/aishell-3.jsonl https://huggingface.co/datasets/SparkAudio/voxbox/resolve/main/metadata/aishell-3.jsonl + # total 88035 samples + head -n 80000 data/aishell-3.jsonl > data/train.jsonl + tail -n 100 data/aishell-3.jsonl > data/test.jsonl + python prepare_data.py \ + --train_file data/train.jsonl \ + --test_file data/test.jsonl \ + --local_dir $data_dir +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "stage 1: start token2wav asr server for reward function" + python3 token2wav_asr_server.py --number-of-devices 8 +fi + +exp_name=official_llm_aishell3_grpo +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "stage 2: grpo train" + export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" + export MKL_SERVICE_FORCE_INTEL=TRUE + n_gpus_per_node=8 + micro_batch_size=4 + train_batch_size=32 + python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=$data_dir/train.parquet \ + data.val_files=$data_dir/test.parquet \ + data.train_batch_size=$train_batch_size \ + data.max_prompt_length=1024 \ + data.max_response_length=512 \ + data.truncation='error' \ + actor_rollout_ref.model.use_remove_padding=False \ + actor_rollout_ref.model.path=$sft_model_path \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.ppo_mini_batch_size=32 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=$micro_batch_size \ + actor_rollout_ref.actor.use_kl_loss=False \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=$micro_batch_size \ + actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.do_sample=true \ + actor_rollout_ref.rollout.temperature=0.8 \ + actor_rollout_ref.rollout.top_p=0.95 \ + actor_rollout_ref.rollout.top_k=25 \ + actor_rollout_ref.rollout.n=4 \ + actor_rollout_ref.rollout.val_kwargs.do_sample=true \ + actor_rollout_ref.rollout.val_kwargs.temperature=0.8 \ + actor_rollout_ref.rollout.val_kwargs.top_p=0.95 \ + actor_rollout_ref.rollout.val_kwargs.top_k=25 \ + reward_model.reward_manager=prime \ + custom_reward_function.path=reward_tts.py \ + custom_reward_function.name=compute_score \ + trainer.project_name='cosyvoice2_grpo' \ + trainer.experiment_name=$exp_name \ + trainer.logger=['console','wandb'] \ + trainer.n_gpus_per_node=$n_gpus_per_node \ + trainer.nnodes=1 \ + trainer.save_freq=100 \ + trainer.test_freq=100 \ + trainer.resume_mode='auto' \ + trainer.total_epochs=1 \ + trainer.val_before_train=False +fi + +steps=(100 200 300 400 500) +for step in ${steps[@]}; do +llm_path=./checkpoints/cosyvoice2_grpo/$exp_name/global_step_${step} +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "stage 3: merge the model" + python -m verl.model_merger merge \ + --backend fsdp \ + --local_dir $llm_path/actor \ + --target_dir $llm_path/merged_hf_model || exit 1 +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "stage 4: Test the model" + dataset=zero_shot_zh # from CosyVoice3 test set + # dataset=test_zh # from seed_tts test set + output_dir=./outputs_${exp_name}_${step}_${dataset} + + token2wav_path=/workspace/CosyVoice2-0.5B + model_path=$llm_path/merged_hf_model + + CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ + torchrun --nproc_per_node=8 \ + infer_dataset.py \ + --output-dir $output_dir \ + --llm-model-name-or-path $model_path \ + --token2wav-path $token2wav_path \ + --split-name ${dataset} || exit 1 + + bash scripts/compute_wer.sh $output_dir ${dataset} +fi +done + +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + log "stage 5: Convert the RL trained model to CosyVoice repo format" + python3 huggingface_to_pretrained.py \ + --hf-cosyvoice2-llm-path $llm_path/merged_hf_model \ + --output-path /workspace/CosyVoice2-0.5B/llm-new.pt + # You need to manually move the llm-new.pt to overwrite /workspace/CosyVoice2-0.5B/llm.pt + # However, we found that the RL trained model accuracy would slightly drop after this conversion. + # Please be careful or use the huggingface format inference code. +fi \ No newline at end of file diff --git a/examples/grpo/cosyvoice2/scripts/compute_wer.sh b/examples/grpo/cosyvoice2/scripts/compute_wer.sh new file mode 100644 index 00000000..43a6872b --- /dev/null +++ b/examples/grpo/cosyvoice2/scripts/compute_wer.sh @@ -0,0 +1,33 @@ +wav_dir=$1 +wav_files=$(ls $wav_dir/*.wav) +# if wav_files is empty, then exit +if [ -z "$wav_files" ]; then + exit 1 +fi +split_name=$2 +model_path=models/sherpa-onnx-paraformer-zh-2023-09-14 + +if [ ! -d $model_path ]; then + pip install sherpa-onnx + wget -nc https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2 + mkdir models + tar xvf sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2 -C models +fi + +python3 scripts/offline-decode-files.py \ + --tokens=$model_path/tokens.txt \ + --paraformer=$model_path/model.int8.onnx \ + --num-threads=2 \ + --decoding-method=greedy_search \ + --debug=false \ + --sample-rate=24000 \ + --log-dir $wav_dir \ + --feature-dim=80 \ + --split-name $split_name \ + --name sherpa_onnx \ + $wav_files + +# python3 scripts/paraformer-pytriton-client.py \ +# --log-dir $wav_dir \ +# --split-name $split_name \ +# $wav_files \ No newline at end of file diff --git a/examples/grpo/cosyvoice2/scripts/offline-decode-files.py b/examples/grpo/cosyvoice2/scripts/offline-decode-files.py new file mode 100644 index 00000000..847d4343 --- /dev/null +++ b/examples/grpo/cosyvoice2/scripts/offline-decode-files.py @@ -0,0 +1,756 @@ +#!/usr/bin/env python3 +# +# Copyright (c) 2023 by manyeyes +# Copyright (c) 2023 Xiaomi Corporation + +""" +This file demonstrates how to use sherpa-onnx Python API to transcribe +file(s) with a non-streaming model. + +(1) For paraformer + + ./python-api-examples/offline-decode-files.py \ + --tokens=/path/to/tokens.txt \ + --paraformer=/path/to/paraformer.onnx \ + --num-threads=2 \ + --decoding-method=greedy_search \ + --debug=false \ + --sample-rate=16000 \ + --feature-dim=80 \ + /path/to/0.wav \ + /path/to/1.wav + +(2) For transducer models from icefall + + ./python-api-examples/offline-decode-files.py \ + --tokens=/path/to/tokens.txt \ + --encoder=/path/to/encoder.onnx \ + --decoder=/path/to/decoder.onnx \ + --joiner=/path/to/joiner.onnx \ + --num-threads=2 \ + --decoding-method=greedy_search \ + --debug=false \ + --sample-rate=16000 \ + --feature-dim=80 \ + /path/to/0.wav \ + /path/to/1.wav + +(3) For CTC models from NeMo + +python3 ./python-api-examples/offline-decode-files.py \ + --tokens=./sherpa-onnx-nemo-ctc-en-citrinet-512/tokens.txt \ + --nemo-ctc=./sherpa-onnx-nemo-ctc-en-citrinet-512/model.onnx \ + --num-threads=2 \ + --decoding-method=greedy_search \ + --debug=false \ + ./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/0.wav \ + ./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/1.wav \ + ./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/8k.wav + +(4) For Whisper models + +python3 ./python-api-examples/offline-decode-files.py \ + --whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \ + --whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \ + --tokens=./sherpa-onnx-whisper-base.en/base.en-tokens.txt \ + --whisper-task=transcribe \ + --num-threads=1 \ + ./sherpa-onnx-whisper-base.en/test_wavs/0.wav \ + ./sherpa-onnx-whisper-base.en/test_wavs/1.wav \ + ./sherpa-onnx-whisper-base.en/test_wavs/8k.wav + +(5) For CTC models from WeNet + +python3 ./python-api-examples/offline-decode-files.py \ + --wenet-ctc=./sherpa-onnx-zh-wenet-wenetspeech/model.onnx \ + --tokens=./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt \ + ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/0.wav \ + ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/1.wav \ + ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/8k.wav + +(6) For tdnn models of the yesno recipe from icefall + +python3 ./python-api-examples/offline-decode-files.py \ + --sample-rate=8000 \ + --feature-dim=23 \ + --tdnn-model=./sherpa-onnx-tdnn-yesno/model-epoch-14-avg-2.onnx \ + --tokens=./sherpa-onnx-tdnn-yesno/tokens.txt \ + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_0_1_0_0_0_1.wav \ + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_0_1_0.wav \ + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_1_1_1.wav + +Please refer to +https://k2-fsa.github.io/sherpa/onnx/index.html +to install sherpa-onnx and to download non-streaming pre-trained models +used in this file. +""" +import argparse +import time +import wave +from pathlib import Path +from typing import List, Tuple, Dict, Iterable, TextIO, Union + +import numpy as np +import sherpa_onnx +import soundfile as sf +from datasets import load_dataset +import logging +from collections import defaultdict +import kaldialign +from zhon.hanzi import punctuation +import string +punctuation_all = punctuation + string.punctuation +Pathlike = Union[str, Path] + + +def remove_punctuation(text: str) -> str: + for x in punctuation_all: + if x == '\'': + continue + text = text.replace(x, '') + return text + + +def store_transcripts( + filename: Pathlike, texts: Iterable[Tuple[str, str, str]], char_level: bool = False +) -> None: + """Save predicted results and reference transcripts to a file. + + Args: + filename: + File to save the results to. + texts: + An iterable of tuples. The first element is the cur_id, the second is + the reference transcript and the third element is the predicted result. + If it is a multi-talker ASR system, the ref and hyp may also be lists of + strings. + Returns: + Return None. + """ + with open(filename, "w", encoding="utf8") as f: + for cut_id, ref, hyp in texts: + if char_level: + ref = list("".join(ref)) + hyp = list("".join(hyp)) + print(f"{cut_id}:\tref={ref}", file=f) + print(f"{cut_id}:\thyp={hyp}", file=f) + + +def write_error_stats( + f: TextIO, + test_set_name: str, + results: List[Tuple[str, str]], + enable_log: bool = True, + compute_CER: bool = False, + sclite_mode: bool = False, +) -> float: + """Write statistics based on predicted results and reference transcripts. + + It will write the following to the given file: + + - WER + - number of insertions, deletions, substitutions, corrects and total + reference words. For example:: + + Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606 + reference words (2337 correct) + + - The difference between the reference transcript and predicted result. + An instance is given below:: + + THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES + + The above example shows that the reference word is `EDISON`, + but it is predicted to `ADDISON` (a substitution error). + + Another example is:: + + FOR THE FIRST DAY (SIR->*) I THINK + + The reference word `SIR` is missing in the predicted + results (a deletion error). + results: + An iterable of tuples. The first element is the cut_id, the second is + the reference transcript and the third element is the predicted result. + enable_log: + If True, also print detailed WER to the console. + Otherwise, it is written only to the given file. + Returns: + Return None. + """ + subs: Dict[Tuple[str, str], int] = defaultdict(int) + ins: Dict[str, int] = defaultdict(int) + dels: Dict[str, int] = defaultdict(int) + + # `words` stores counts per word, as follows: + # corr, ref_sub, hyp_sub, ins, dels + words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0]) + num_corr = 0 + ERR = "*" + + if compute_CER: + for i, res in enumerate(results): + cut_id, ref, hyp = res + ref = list("".join(ref)) + hyp = list("".join(hyp)) + results[i] = (cut_id, ref, hyp) + + for cut_id, ref, hyp in results: + ali = kaldialign.align(ref, hyp, ERR, sclite_mode=sclite_mode) + for ref_word, hyp_word in ali: + if ref_word == ERR: + ins[hyp_word] += 1 + words[hyp_word][3] += 1 + elif hyp_word == ERR: + dels[ref_word] += 1 + words[ref_word][4] += 1 + elif hyp_word != ref_word: + subs[(ref_word, hyp_word)] += 1 + words[ref_word][1] += 1 + words[hyp_word][2] += 1 + else: + words[ref_word][0] += 1 + num_corr += 1 + ref_len = sum([len(r) for _, r, _ in results]) + sub_errs = sum(subs.values()) + ins_errs = sum(ins.values()) + del_errs = sum(dels.values()) + tot_errs = sub_errs + ins_errs + del_errs + tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len) + + if enable_log: + logging.info( + f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} " + f"[{tot_errs} / {ref_len}, {ins_errs} ins, " + f"{del_errs} del, {sub_errs} sub ]" + ) + + print(f"%WER = {tot_err_rate}", file=f) + print( + f"Errors: {ins_errs} insertions, {del_errs} deletions, " + f"{sub_errs} substitutions, over {ref_len} reference " + f"words ({num_corr} correct)", + file=f, + ) + print( + "Search below for sections starting with PER-UTT DETAILS:, " + "SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:", + file=f, + ) + + print("", file=f) + print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f) + for cut_id, ref, hyp in results: + ali = kaldialign.align(ref, hyp, ERR) + combine_successive_errors = True + if combine_successive_errors: + ali = [[[x], [y]] for x, y in ali] + for i in range(len(ali) - 1): + if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]: + ali[i + 1][0] = ali[i][0] + ali[i + 1][0] + ali[i + 1][1] = ali[i][1] + ali[i + 1][1] + ali[i] = [[], []] + ali = [ + [ + list(filter(lambda a: a != ERR, x)), + list(filter(lambda a: a != ERR, y)), + ] + for x, y in ali + ] + ali = list(filter(lambda x: x != [[], []], ali)) + ali = [ + [ + ERR if x == [] else " ".join(x), + ERR if y == [] else " ".join(y), + ] + for x, y in ali + ] + + print( + f"{cut_id}:\t" + + " ".join( + ( + ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})" + for ref_word, hyp_word in ali + ) + ), + file=f, + ) + + print("", file=f) + print("SUBSTITUTIONS: count ref -> hyp", file=f) + + for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True): + print(f"{count} {ref} -> {hyp}", file=f) + + print("", file=f) + print("DELETIONS: count ref", file=f) + for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True): + print(f"{count} {ref}", file=f) + + print("", file=f) + print("INSERTIONS: count hyp", file=f) + for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True): + print(f"{count} {hyp}", file=f) + + print("", file=f) + print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f) + for _, word, counts in sorted( + [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True + ): + (corr, ref_sub, hyp_sub, ins, dels) = counts + tot_errs = ref_sub + hyp_sub + ins + dels + ref_count = corr + ref_sub + dels + hyp_count = corr + hyp_sub + ins + + print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f) + return float(tot_err_rate) + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--tokens", + type=str, + help="Path to tokens.txt", + ) + + parser.add_argument( + "--hotwords-file", + type=str, + default="", + help=""" + The file containing hotwords, one words/phrases per line, like + HELLO WORLD + 你好世界 + """, + ) + + parser.add_argument( + "--hotwords-score", + type=float, + default=1.5, + help=""" + The hotword score of each token for biasing word/phrase. Used only if + --hotwords-file is given. + """, + ) + + parser.add_argument( + "--modeling-unit", + type=str, + default="", + help=""" + The modeling unit of the model, valid values are cjkchar, bpe, cjkchar+bpe. + Used only when hotwords-file is given. + """, + ) + + parser.add_argument( + "--bpe-vocab", + type=str, + default="", + help=""" + The path to the bpe vocabulary, the bpe vocabulary is generated by + sentencepiece, you can also export the bpe vocabulary through a bpe model + by `scripts/export_bpe_vocab.py`. Used only when hotwords-file is given + and modeling-unit is bpe or cjkchar+bpe. + """, + ) + + parser.add_argument( + "--encoder", + default="", + type=str, + help="Path to the encoder model", + ) + + parser.add_argument( + "--decoder", + default="", + type=str, + help="Path to the decoder model", + ) + + parser.add_argument( + "--joiner", + default="", + type=str, + help="Path to the joiner model", + ) + + parser.add_argument( + "--paraformer", + default="", + type=str, + help="Path to the model.onnx from Paraformer", + ) + + parser.add_argument( + "--nemo-ctc", + default="", + type=str, + help="Path to the model.onnx from NeMo CTC", + ) + + parser.add_argument( + "--wenet-ctc", + default="", + type=str, + help="Path to the model.onnx from WeNet CTC", + ) + + parser.add_argument( + "--tdnn-model", + default="", + type=str, + help="Path to the model.onnx for the tdnn model of the yesno recipe", + ) + + parser.add_argument( + "--num-threads", + type=int, + default=1, + help="Number of threads for neural network computation", + ) + + parser.add_argument( + "--whisper-encoder", + default="", + type=str, + help="Path to whisper encoder model", + ) + + parser.add_argument( + "--whisper-decoder", + default="", + type=str, + help="Path to whisper decoder model", + ) + + parser.add_argument( + "--whisper-language", + default="", + type=str, + help="""It specifies the spoken language in the input audio file. + Example values: en, fr, de, zh, jp. + Available languages for multilingual models can be found at + https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10 + If not specified, we infer the language from the input audio file. + """, + ) + + parser.add_argument( + "--whisper-task", + default="transcribe", + choices=["transcribe", "translate"], + type=str, + help="""For multilingual models, if you specify translate, the output + will be in English. + """, + ) + + parser.add_argument( + "--whisper-tail-paddings", + default=-1, + type=int, + help="""Number of tail padding frames. + We have removed the 30-second constraint from whisper, so you need to + choose the amount of tail padding frames by yourself. + Use -1 to use a default value for tail padding. + """, + ) + + parser.add_argument( + "--blank-penalty", + type=float, + default=0.0, + help=""" + The penalty applied on blank symbol during decoding. + Note: It is a positive value that would be applied to logits like + this `logits[:, 0] -= blank_penalty` (suppose logits.shape is + [batch_size, vocab] and blank id is 0). + """, + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="Valid values are greedy_search and modified_beam_search", + ) + parser.add_argument( + "--debug", + type=bool, + default=False, + help="True to show debug messages", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="""Sample rate of the feature extractor. Must match the one + expected by the model. Note: The input sound files can have a + different sample rate from this argument.""", + ) + + parser.add_argument( + "--feature-dim", + type=int, + default=80, + help="Feature dimension. Must match the one expected by the model", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to decode. Each file must be of WAVE" + "format with a single channel, and each sample has 16-bit, " + "i.e., int16_t. " + "The sample rate of the file can be arbitrary and does not need to " + "be 16 kHz", + ) + + parser.add_argument( + "--name", + type=str, + default="", + help="The directory containing the input sound files to decode", + ) + + parser.add_argument( + "--log-dir", + type=str, + default="", + help="The directory containing the input sound files to decode", + ) + + parser.add_argument( + "--label", + type=str, + default=None, + help="wav_base_name label", + ) + + # Dataset related arguments for loading labels when label file is not provided + parser.add_argument( + "--dataset-name", + type=str, + default="yuekai/seed_tts_cosy2", + help="Huggingface dataset name for loading labels", + ) + + parser.add_argument( + "--split-name", + type=str, + default="wenetspeech4tts", + help="Dataset split name for loading labels", + ) + + return parser.parse_args() + + +def assert_file_exists(filename: str): + assert Path(filename).is_file(), ( + f"{filename} does not exist!\n" + "Please refer to " + "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it" + ) + + +def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]: + """ + Args: + wave_filename: + Path to a wave file. It should be single channel and can be of type + 32-bit floating point PCM. Its sample rate does not need to be 24kHz. + + Returns: + Return a tuple containing: + - A 1-D array of dtype np.float32 containing the samples, + which are normalized to the range [-1, 1]. + - Sample rate of the wave file. + """ + + samples, sample_rate = sf.read(wave_filename, dtype="float32") + assert ( + samples.ndim == 1 + ), f"Expected single channel, but got {samples.ndim} channels." + + samples_float32 = samples.astype(np.float32) + + return samples_float32, sample_rate + + +def normalize_text_alimeeting(text: str) -> str: + """ + Text normalization similar to M2MeT challenge baseline. + See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl + """ + import re + text = text.replace('\u00A0', '') # test_hard + text = text.replace(" ", "") + text = text.replace("", "") + text = text.replace("<%>", "") + text = text.replace("<->", "") + text = text.replace("<$>", "") + text = text.replace("<#>", "") + text = text.replace("<_>", "") + text = text.replace("", "") + text = text.replace("`", "") + text = text.replace("&", "") + text = text.replace(",", "") + if re.search("[a-zA-Z]", text): + text = text.upper() + text = text.replace("A", "A") + text = text.replace("a", "A") + text = text.replace("b", "B") + text = text.replace("c", "C") + text = text.replace("k", "K") + text = text.replace("t", "T") + text = text.replace(",", "") + text = text.replace("丶", "") + text = text.replace("。", "") + text = text.replace("、", "") + text = text.replace("?", "") + text = remove_punctuation(text) + return text + + +def main(): + args = get_args() + assert_file_exists(args.tokens) + assert args.num_threads > 0, args.num_threads + + assert len(args.nemo_ctc) == 0, args.nemo_ctc + assert len(args.wenet_ctc) == 0, args.wenet_ctc + assert len(args.whisper_encoder) == 0, args.whisper_encoder + assert len(args.whisper_decoder) == 0, args.whisper_decoder + assert len(args.tdnn_model) == 0, args.tdnn_model + + assert_file_exists(args.paraformer) + + recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer( + paraformer=args.paraformer, + tokens=args.tokens, + num_threads=args.num_threads, + sample_rate=args.sample_rate, + feature_dim=args.feature_dim, + decoding_method=args.decoding_method, + debug=args.debug, + ) + + print("Started!") + start_time = time.time() + + streams, results = [], [] + total_duration = 0 + + for i, wave_filename in enumerate(args.sound_files): + assert_file_exists(wave_filename) + samples, sample_rate = read_wave(wave_filename) + duration = len(samples) / sample_rate + total_duration += duration + s = recognizer.create_stream() + s.accept_waveform(sample_rate, samples) + + streams.append(s) + if i % 10 == 0: + recognizer.decode_streams(streams) + results += [s.result.text for s in streams] + streams = [] + print(f"Processed {i} files") + # process the last batch + if streams: + recognizer.decode_streams(streams) + results += [s.result.text for s in streams] + end_time = time.time() + print("Done!") + + results_dict = {} + for wave_filename, result in zip(args.sound_files, results): + print(f"{wave_filename}\n{result}") + print("-" * 10) + wave_basename = Path(wave_filename).stem + results_dict[wave_basename] = result + + elapsed_seconds = end_time - start_time + rtf = elapsed_seconds / total_duration + print(f"num_threads: {args.num_threads}") + print(f"decoding_method: {args.decoding_method}") + print(f"Wave duration: {total_duration:.3f} s") + print(f"Elapsed time: {elapsed_seconds:.3f} s") + print( + f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}" + ) + + # Load labels either from file or from dataset + labels_dict = {} + + if args.label: + # Load labels from file (original functionality) + print(f"Loading labels from file: {args.label}") + with open(args.label, "r") as f: + for line in f: + # fields = line.strip().split(" ") + # fields = [item for item in fields if item] + # assert len(fields) == 4 + # prompt_text, prompt_audio, text, audio_path = fields + + fields = line.strip().split("|") + fields = [item for item in fields if item] + assert len(fields) == 4 + audio_path, prompt_text, prompt_audio, text = fields + labels_dict[Path(audio_path).stem] = normalize_text_alimeeting(text) + else: + # Load labels from dataset (new functionality) + print(f"Loading labels from dataset: {args.dataset_name}, split: {args.split_name}") + if 'zero' in args.split_name: + dataset_name = "yuekai/CV3-Eval" + else: + dataset_name = "yuekai/seed_tts_cosy2" + dataset = load_dataset( + dataset_name, + split=args.split_name, + trust_remote_code=True, + ) + + for item in dataset: + audio_id = item["id"] + labels_dict[audio_id] = normalize_text_alimeeting(item["target_text"]) + + print(f"Loaded {len(labels_dict)} labels from dataset") + + # Perform evaluation if labels are available + if labels_dict: + + final_results = [] + for key, value in results_dict.items(): + if key in labels_dict: + final_results.append((key, labels_dict[key], value)) + else: + print(f"Warning: No label found for {key}, skipping...") + + if final_results: + store_transcripts( + filename=f"{args.log_dir}/recogs-{args.name}.txt", texts=final_results + ) + with open(f"{args.log_dir}/errs-{args.name}.txt", "w") as f: + write_error_stats(f, "test-set", final_results, enable_log=True) + + with open(f"{args.log_dir}/errs-{args.name}.txt", "r") as f: + print(f.readline()) # WER + print(f.readline()) # Detailed errors + else: + print("No matching labels found for evaluation") + else: + print("No labels available for evaluation") + + +if __name__ == "__main__": + main() diff --git a/examples/grpo/cosyvoice2/token2wav_asr_server.py b/examples/grpo/cosyvoice2/token2wav_asr_server.py new file mode 100644 index 00000000..8a6cb6e1 --- /dev/null +++ b/examples/grpo/cosyvoice2/token2wav_asr_server.py @@ -0,0 +1,346 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Pytriton server for token2wav conversion and ASR""" + +from datasets import load_dataset +from cosyvoice.cli.cosyvoice import CosyVoice2 +from omnisense.models import OmniSenseVoiceSmall +from pytriton.proxy.types import Request +from pytriton.triton import Triton, TritonConfig +from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor +from pytriton.decorators import batch +import argparse +import io +import logging +from typing import Any, List +import numpy as np +import torch +from scipy.signal import resample +import sys +import random +import re +from jiwer import wer +from pypinyin import lazy_pinyin, Style +from tn.chinese.normalizer import Normalizer as ZhNormalizer + +# Chinese text normalizer (cached globally) +zh_tn_model = ZhNormalizer( + cache_dir="./cache", + remove_erhua=False, + remove_interjections=False, + remove_puncts=True, + overwrite_cache=True, +) + + +sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS") + +logger = logging.getLogger("token2wav_asr_server") + + +class _ASR_Server: + """Wraps a single OmniSenseVoiceSmall model instance for Triton.""" + + def __init__(self, device_id: int): + self._model = OmniSenseVoiceSmall("iic/SenseVoiceSmall", quantize=False, device_id=device_id) + + @batch + def __call__(self, WAV: np.ndarray, WAV_LENS: np.ndarray, LANGUAGE: np.ndarray, TEXT_NORM: np.ndarray): + """ + WAV: np.ndarray, WAV_LENS: np.ndarray + LANGUAGE: np.ndarray, TEXTNORM: np.ndarray for backward compatibility, not used + See: https://github.com/modelscope/FunASR/tree/main/runtime/triton_gpu + """ + logger.debug("WAV: %s, WAV_LENS: %s, shapes: %s %s", type(WAV), type(WAV_LENS), WAV.shape, WAV_LENS.shape) + wavs = [WAV[i, :WAV_LENS[i, 0]] for i in range(len(WAV))] + + results = self._model.transcribe_single_batch( + wavs, + language="zh", + textnorm="woitn", + ) + texts = [result.text for result in results] + transcripts = np.char.encode(np.array(texts).reshape(-1, 1), "utf-8") + return {"TRANSCRIPTS": transcripts} + + +def audio_decode_cosyvoice2( + audio_tokens, prompt_text, prompt_speech_16k, codec_decoder +): + """ + Generate audio from tokens with optional tone and prompt embedding. + """ + model_inputs_dict = codec_decoder.frontend.frontend_zero_shot( + "empty", prompt_text, prompt_speech_16k, 24000 + ) + tts_mel, _ = codec_decoder.model.flow.inference( + token=audio_tokens.to(codec_decoder.model.device), + token_len=torch.tensor([audio_tokens.shape[1]], dtype=torch.int32).to( + codec_decoder.model.device + ), + prompt_token=model_inputs_dict["flow_prompt_speech_token"].to( + codec_decoder.model.device + ), + prompt_token_len=torch.tensor( + [model_inputs_dict["flow_prompt_speech_token_len"]], dtype=torch.int32 + ).to(codec_decoder.model.device), + prompt_feat=model_inputs_dict["prompt_speech_feat"].to( + codec_decoder.model.device + ), + prompt_feat_len=model_inputs_dict["prompt_speech_feat_len"].to( + codec_decoder.model.device + ), + embedding=model_inputs_dict["flow_embedding"].to(codec_decoder.model.device), + finalize=True, + ) + + audio_hat, _ = codec_decoder.model.hift.inference( + speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0) + ) + + return audio_hat + + +def get_random_prompt_from_dataset(dataset): + """ + Get random prompt text and speech from the pre-loaded dataset. + Returns (prompt_text, prompt_speech_16k) + """ + random_idx = random.randint(0, len(dataset) - 1) + sample = dataset[random_idx] + + # Extract audio data + audio_data = sample["audio"] + audio_array = audio_data["array"] + sample_rate = audio_data["sampling_rate"] + + # Convert audio to 16kHz if needed + if sample_rate != 16000: + num_samples = int(len(audio_array) * (16000 / sample_rate)) + audio_array = resample(audio_array, num_samples) + + # Convert to torch tensor + prompt_speech_16k = torch.from_numpy(audio_array).float().unsqueeze(0) + prompt_text = sample["text"] + # remove space in prompt_text + prompt_text = prompt_text.replace(" ", "") + return prompt_text, prompt_speech_16k + + +class _Token2Wav_ASR: + """Wraps a single OmniSenseVoiceSmall model instance for Triton.""" + + def __init__(self, device_id: int): + self.asr_model = OmniSenseVoiceSmall("iic/SenseVoiceSmall", quantize=False, device_id=device_id) + self.dataset = load_dataset("yuekai/aishell", "test", trust_remote_code=True)["test"] + + # Make sure the CosyVoice2 decoder lives on the same GPU as the ASR model + # CosyVoice2 internally uses generic "cuda" device, so we first switch the + # current CUDA context to the desired card before the object is created. + # Afterwards, all parameters loaded with the generic "cuda" device will + # reside on this GPU. We keep the selected id in `self.device_id` and + # will set the context again for every forward call to avoid race + # conditions when several instances are used in the same process. + + self.device_id = device_id + + # Construct the TTS codec decoder under the correct CUDA device context + with torch.cuda.device(self.device_id): + self.codec_decoder = CosyVoice2( + "/workspace/CosyVoice2-0.5B", load_jit=True, load_trt=True, fp16=True + ) + + @batch + def __call__(self, TOKENS: np.ndarray, TOKEN_LENS: np.ndarray, GT_TEXT: np.ndarray): + """ + WAV: np.ndarray, WAV_LENS: np.ndarray + LANGUAGE: np.ndarray, TEXTNORM: np.ndarray for backward compatibility, not used + See: https://github.com/modelscope/FunASR/tree/main/runtime/triton_gpu + """ + # Ensure the default CUDA device is set correctly for this invocation + torch.cuda.set_device(self.device_id) + + if self.device_id == 0: + print(f"device_id: {self.device_id}, TOKENS: {TOKENS.shape}, TOKEN_LENS: {TOKEN_LENS.shape}") + + tokens_list = [TOKENS[i, :TOKEN_LENS[i, 0]] for i in range(len(TOKENS))] + + # Decode ground-truth text strings (BYTES → str) + if GT_TEXT.ndim == 2: + gt_texts = [GT_TEXT[i, 0].decode("utf-8") for i in range(len(GT_TEXT))] + else: + gt_texts = [GT_TEXT[i].decode("utf-8") for i in range(len(GT_TEXT))] + + wavs = [] + for tokens in tokens_list: + prompt_text, prompt_speech_16k = get_random_prompt_from_dataset(self.dataset) + audio_tokens = torch.tensor(tokens, dtype=torch.long, device=self.asr_model.device).unsqueeze(0) + audio_hat = audio_decode_cosyvoice2( + audio_tokens, + prompt_text, + prompt_speech_16k, + self.codec_decoder, + ) + # resample to 16000 using soundfile + audio_hat = audio_hat.squeeze(0).float().cpu() + audio_hat = audio_hat.numpy() + num_samples = int(len(audio_hat) * (16000 / 24000)) + audio_hat = resample(audio_hat, num_samples) + wavs.append(audio_hat) + + results = self.asr_model.transcribe_single_batch( + wavs, + language="zh", + textnorm="woitn", + ) + texts = [result.text for result in results] + + # ---------------- Reward computation ---------------- + rewards = [] + for gt_text, hyp_text in zip(gt_texts, texts): + gt_norm = zh_tn_model.normalize(gt_text).lower() + hyp_norm = zh_tn_model.normalize(hyp_text).lower() + + gt_pinyin = lazy_pinyin( + gt_norm, + style=Style.TONE3, + tone_sandhi=True, + neutral_tone_with_five=True, + ) + hyp_pinyin = lazy_pinyin( + hyp_norm, + style=Style.TONE3, + tone_sandhi=True, + neutral_tone_with_five=True, + ) + + c = float(wer(" ".join(gt_pinyin), " ".join(hyp_pinyin))) + reward_val = 1.0 - np.tanh(3.0 * c) + reward_val = max(0.0, min(1.0, reward_val)) + rewards.append(reward_val) + print(f"gt_text: {gt_text}, hyp_text: {hyp_text}, reward_val: {reward_val}") + + transcripts = np.char.encode(np.array(texts).reshape(-1, 1), "utf-8") + rewards_arr = np.array(rewards, dtype=np.float32).reshape(-1, 1) + + return {"REWARDS": rewards_arr, "TRANSCRIPTS": transcripts} + + +def _infer_function_factory(device_ids: List[int], model_name: str): + """Creates a list of inference functions, one for each requested device ID.""" + infer_funcs = [] + for device_id in device_ids: + if model_name == "sensevoice": + infer_funcs.append(_ASR_Server(device_id=device_id)) + else: + infer_funcs.append(_Token2Wav_ASR(device_id=device_id)) + return infer_funcs + + +def main(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--max-batch-size", + type=int, + default=32, + help="Batch size of request.", + required=False, + ) + parser.add_argument( + "--verbose", + action="store_true", + default=False, + ) + parser.add_argument( + "--number-of-instances-per-device", + type=int, + default=1, + help="Number of model instances to load.", + required=False, + ) + parser.add_argument( + "--number-of-devices", + type=int, + default=8, + help="Number of devices to use.", + ) + parser.add_argument( + "--model-name", + type=str, + default="token2wav_asr", + choices=["token2wav_asr", "sensevoice"], + help="Model name.", + ) + + args = parser.parse_args() + + log_level = logging.DEBUG if args.verbose else logging.INFO + logging.basicConfig(level=log_level, format="%(asctime)s - %(levelname)s - %(name)s: %(message)s") + + triton_config = TritonConfig( + http_port=8000, + grpc_port=8001, + metrics_port=8002, + ) + + device_ids = [i for i in range(args.number_of_devices)] + device_ids = device_ids * args.number_of_instances_per_device + + with Triton(config=triton_config) as triton: + logger.info("Loading SenseVoice model on device ids: %s", device_ids) + if args.model_name == "sensevoice": + triton.bind( + model_name="sensevoice", + infer_func=_infer_function_factory(device_ids, args.model_name), + inputs=[ + Tensor(name="WAV", dtype=np.float32, shape=(-1,)), + Tensor(name="WAV_LENS", dtype=np.int32, shape=(-1,)), + Tensor(name="LANGUAGE", dtype=np.int32, shape=(-1,)), + Tensor(name="TEXT_NORM", dtype=np.int32, shape=(-1,)), + ], + outputs=[ + Tensor(name="TRANSCRIPTS", dtype=bytes, shape=(-1,)), + ], + config=ModelConfig( + max_batch_size=args.max_batch_size, + batcher=DynamicBatcher(max_queue_delay_microseconds=10000), # 10ms + ), + strict=True, + ) + else: + triton.bind( + model_name="token2wav_asr", + infer_func=_infer_function_factory(device_ids, args.model_name), + inputs=[ + Tensor(name="TOKENS", dtype=np.int32, shape=(-1,)), + Tensor(name="TOKEN_LENS", dtype=np.int32, shape=(-1,)), + Tensor(name="GT_TEXT", dtype=bytes, shape=(-1,)), + ], + outputs=[ + Tensor(name="REWARDS", dtype=np.float32, shape=(-1,)), + Tensor(name="TRANSCRIPTS", dtype=bytes, shape=(-1,)), + ], + config=ModelConfig( + max_batch_size=args.max_batch_size, + batcher=DynamicBatcher(max_queue_delay_microseconds=10000), # 10ms + ), + strict=True, + ) + logger.info("Serving inference") + triton.serve() + + +if __name__ == "__main__": + main()