diff --git a/egs/emilia/TTS/README.md b/egs/emilia/TTS/README.md new file mode 100644 index 0000000000..d55ff10c30 --- /dev/null +++ b/egs/emilia/TTS/README.md @@ -0,0 +1,76 @@ +# Results +| LLM Model | Flow matching Model | Seed-TTS test_zh CER | Comment | +|---------------------------------------|----------|-----------|--------| +| pretrained cosyvoice2 llm | pretrained cosyvoice2 unet | 1.45% | See [paper](https://arxiv.org/abs/2412.10117)| +| pretrained cosyvoice2 llm | f5-tts-small (wenetspeech4tts) | 1.79% (16 steps) | See [PR](https://github.com/k2-fsa/icefall/pull/1880)| +| llasa_cosyvoice2_token llm (Emilia 50k hours ZH) | f5-tts-small (wenetspeech4tts) | 1.81% (16 steps) | | + +# Introduction + +[**Emilia**](https://huggingface.co/datasets/amphion/Emilia-Dataset) starts with over 101k +hours of speech across six languages, covering a wide range of speaking styles to enable more natural and spontaneous speech generation. + +See https://arxiv.org/pdf/2407.05361. + +# Llasa (cosyvoice2 token) + +./llasa_cosyvoice2_token contains the code for training qwen2.5-0.5b models to predict cosyvoice2 semantic tokens. + +Generated samples and training logs of [Emilia](https://huggingface.co/datasets/amphion/Emilia-Dataset) 50k hours Chinese data can be found [here](https://huggingface.co/yuekai/llasa_cosyvoice2_token_qwen_0.5b/tree/main). + +Preparation: + +``` +# extract cosyvoice2 semantic tokens +bash prepare.sh --stage 3 --stop_stage 4 + +# Or you could use the prepared tokens. +huggingface-cli download yuekai/emilia_cosyvoice_v2_token --local-dir emilia_cosyvoice_v2_token +``` + +The training command is given below: + +``` +# docker: ghcr.io/swivid/f5-tts:main +# pip install -r llasa_cosyvoice2_token/requirements.txt + +WANDB_KEY=$your_wandb_key +wandb login ${WANDB_KEY} +huggingface-cli download Qwen/Qwen2.5-0.5B-Instruct --local-dir Qwen2.5-0.5B-Instruct +torchrun --nproc_per_node=8 train.py config.json +``` + +To inference with Icefall Emilia trained Chinese Llasa_cosyvoice2_token model, we need to use cosyvoice2 token flow matching [model](https://github.com/k2-fsa/icefall/pull/1880): +``` +cd icefall/egs/wenetspeech4tts/TTS +huggingface-cli login +huggingface-cli download --local-dir ${exp_dir} yuekai/llasa_cosyvoice2_token_qwen_0.5b +huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir bigvgan_v2_24khz_100band_256x +vocoder=./bigvgan_v2_24khz_100band_256x +split=test_zh +llm_path=llasa_cosyvoice2_token_qwen_0.5b/checkpoint-800000 + +huggingface-cli download --local-dir f5-tts-small-wenetspeech4tts-basic yuekai/f5-tts-semantic-token-small-wenetspeech4tts-basic +model_path=f5-tts-small-wenetspeech4tts-basic/epoch-10-avg-5.pt +torchrun --nproc_per_node=2 \ + f5-tts/infer_dist.py \ + --output_dir $output_dir \ + --batch_size 1 \ + --num_workers 2 \ + --llm-model-name-or-path $llm_path \ + --flow-matching-model-path $model_path \ + --decoder-dim 768 --nhead 12 --num-decoder-layers 18 \ + --use-cosyvoice-semantic-token True \ + --vocoder-dir $vocoder \ + --split-name $split -top-k 50 -top-p 0.95 -temperature 0.8 \ + --tokenizer-dir Qwen/Qwen2.5-0.5B-Instruct +# compute cer +huggingface-cli download yuekai/seed_tts_eval --local-dir seed_tts_eval --repo-type dataset +manifest=./seed_tts_eval/seedtts_testset/zh/meta.lst +bash local/compute_wer.sh $output_dir $manifest +``` + +# Credits +- [Llasa](https://arxiv.org/abs/2502.04128) +- [CosyVoice](https://github.com/FunAudioLLM/CosyVoice) +- [S3Tokenizer](https://github.com/xingchensong/S3Tokenizer/tree/main) diff --git a/egs/emilia/TTS/llasa_cosyvoice2_token/config.json b/egs/emilia/TTS/llasa_cosyvoice2_token/config.json new file mode 100644 index 0000000000..858edae84d --- /dev/null +++ b/egs/emilia/TTS/llasa_cosyvoice2_token/config.json @@ -0,0 +1,27 @@ +{ + "llm_model_name_or_path": "./Qwen2.5-0.5B-Instruct", + "data_path": ["./emilia_cosyvoice_v2_token/cosy_v2_tokens_ZH.jsonl"], + "bf16": false, + "output_dir": "./exp_zh", + "num_train_epochs": 3, + "per_device_train_batch_size": 8, + "per_device_eval_batch_size": 8, + "gradient_accumulation_steps": 1, + "evaluation_strategy": "steps", + "eval_steps": 1000, + "save_strategy": "steps", + "save_steps": 5000, + "save_total_limit": 100, + "learning_rate": 0.00005, + "weight_decay": 0.01, + "adam_beta2": 0.95, + "warmup_ratio": 0.03, + "lr_scheduler_type": "cosine", + "logging_steps": 100, + "report_to": "wandb", + "model_max_length": 2048, + "gradient_checkpointing": false, + "dataloader_num_workers": 4, + "dataloader_prefetch_factor": 4, + "deepspeed": "ds_config_zero2.json" +} diff --git a/egs/emilia/TTS/llasa_cosyvoice2_token/ds_config_zero2.json b/egs/emilia/TTS/llasa_cosyvoice2_token/ds_config_zero2.json new file mode 100644 index 0000000000..b0b139598f --- /dev/null +++ b/egs/emilia/TTS/llasa_cosyvoice2_token/ds_config_zero2.json @@ -0,0 +1,47 @@ +{ + "fp16": { + "enabled": "auto", + "loss_scale": 0, + "loss_scale_window": 1000, + "initial_scale_power": 64, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "bf16": { + "enabled": "auto" + }, + "optimizer": { + "type": "AdamW", + "params": { + "lr": "auto", + "betas": "auto", + "eps": "auto", + "weight_decay": "auto" + } + }, + + "scheduler": { + "type": "WarmupCosineLR", + "params": { + "total_num_steps": "auto", + "warmup_min_ratio": 0.03, + "warmup_num_steps": "auto", + "cos_min_ratio": 0.1 + } + }, + + "zero_optimization": { + "stage": 2, + "overlap_comm": false, + "contiguous_gradients": true, + "sub_group_size": 1e9, + "reduce_bucket_size": "auto" + }, + + "gradient_accumulation_steps": "auto", + "gradient_clipping": 1.0, + "steps_per_print": 100, + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "wall_clock_breakdown": false +} diff --git a/egs/emilia/TTS/llasa_cosyvoice2_token/requirements.txt b/egs/emilia/TTS/llasa_cosyvoice2_token/requirements.txt new file mode 100644 index 0000000000..11574c1909 --- /dev/null +++ b/egs/emilia/TTS/llasa_cosyvoice2_token/requirements.txt @@ -0,0 +1,8 @@ +torch +transformers +wandb +datasets +accelerate>=0.26.0 +deepspeed +flash-attn +s3tokenizer diff --git a/egs/emilia/TTS/llasa_cosyvoice2_token/train.py b/egs/emilia/TTS/llasa_cosyvoice2_token/train.py new file mode 100644 index 0000000000..e3c6fcae61 --- /dev/null +++ b/egs/emilia/TTS/llasa_cosyvoice2_token/train.py @@ -0,0 +1,184 @@ +# Modified from https://github.com/zhenye234/LLaSA_training/blob/main/train_tts.py +""" Example Usage +WANDB_KEY=$your_wandb_key +wandb login ${WANDB_KEY} +huggingface-cli download yuekai/emilia_cosyvoice_v2_token --local-dir emilia_cosyvoice_v2_token +huggingface-cli download Qwen/Qwen2.5-0.5B-Instruct --local-dir Qwen2.5-0.5B-Instruct +torchrun --nproc_per_node=8 train.py config.json +""" +import json +import os +import random +import sys +from dataclasses import dataclass, field +from functools import partial +from typing import List, Optional + +import numpy as np +import torch +import torch.nn as nn +import transformers +import wandb +from datasets import load_dataset +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoTokenizer, + DataCollatorForSeq2Seq, + HfArgumentParser, + Trainer, + TrainingArguments, +) +from transformers.trainer_pt_utils import LabelSmoother + +IGNORE_TOKEN_ID = LabelSmoother.ignore_index +TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}" + + +@dataclass +class ModelArguments: + llm_model_name_or_path: Optional[str] = field( + default="meta-llama/Llama-3.2-1B-Instruct" + ) + + +@dataclass +class DataArguments: + data_path: List[str] = field( + default=None, + metadata={"help": "Root path(s) to the data. Can be single path or list."}, + ) + + +@dataclass +class CustomTrainingArguments(TrainingArguments): + optim: str = field(default="adamw_torch_fused") + model_max_length: int = field( + default=2048, + metadata={"help": "Maximum sequence length"}, + ) + logging_steps: int = field(default=100, metadata={"help": "Log every X updates"}) + report_to: Optional[str] = field( + default=None, + metadata={"help": "The integration to report the results and logs to."}, + ) + run_name: Optional[str] = field( + default=None, metadata={"help": "The name of the run for logging."} + ) + gradient_checkpointing: bool = field(default=False) + lr_scheduler_type: str = field( + default="cosine", metadata={"help": "The learning rate scheduler to use."} + ) + remove_unused_columns: bool = field(default=False) + + +def data_collator(batch, tokenizer, original_tokenizer_vocab_size, cut_off_len=2048): + speech_generation_start_index = tokenizer.convert_tokens_to_ids( + "<|SPEECH_GENERATION_START|>" + ) + assistant_index = tokenizer.convert_tokens_to_ids("assistant") + input_ids_list = [] + for i, item in enumerate(batch): + text, code = item["text"], item["code"] + message = [ + {"role": "user", "content": f"Convert the text to speech: {text}"}, + {"role": "assistant", "content": "<|SPEECH_GENERATION_START|>"}, + ] + + input_ids = tokenizer.apply_chat_template( + message, + tokenize=True, + chat_template=TEMPLATE, + ) + + code = [c + original_tokenizer_vocab_size for c in code] + + idx = input_ids.index(speech_generation_start_index) + input_ids = input_ids[:idx] + code + input_ids[idx + 1 :] + if len(input_ids) < cut_off_len: + input_ids_list.append(input_ids) + + max_len = max([len(input_ids) for input_ids in input_ids_list]) + input_ids_list = [ + input_ids + [tokenizer.pad_token_id] * (max_len - len(input_ids)) + for input_ids in input_ids_list + ] + input_ids = torch.tensor(input_ids_list, dtype=torch.int) + attention_mask = input_ids.ne(tokenizer.pad_token_id) + + target_ids = input_ids.clone() + target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID + mask_indices = torch.where(input_ids == assistant_index) + for i in range(mask_indices[0].size(0)): + row = mask_indices[0][i] + col = mask_indices[1][i] + # + 2 to skip: 'assistant', '\n' + target_ids[row, : col + 2] = IGNORE_TOKEN_ID + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "labels": target_ids.to(dtype=torch.int64), + } + + +def main(): + parser = transformers.HfArgumentParser( + (ModelArguments, DataArguments, CustomTrainingArguments) + ) + assert len(sys.argv) == 2 and sys.argv[1].endswith(".json") + ( + model_args, + data_args, + training_args, + ) = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + + is_main_process = training_args.local_rank in [-1, 0] + if training_args.report_to == "wandb" and is_main_process: + wandb.init( + project="llm_tts", + config=training_args.to_sanitized_dict(), + name=training_args.run_name, + ) + + model = AutoModelForCausalLM.from_pretrained( + model_args.llm_model_name_or_path, + torch_dtype=torch.float16, + attn_implementation="flash_attention_2", + ) + + tokenizer = AutoTokenizer.from_pretrained(model_args.llm_model_name_or_path) + original_tokenizer_vocab_size = len(tokenizer) + cosyvoice2_token_size = 6561 + new_tokens = [f"<|s_{i}|>" for i in range(cosyvoice2_token_size)] + [ + "<|SPEECH_GENERATION_START|>" + ] + num_added_tokens = tokenizer.add_tokens(new_tokens) + + model.resize_token_embeddings(len(tokenizer)) + model.vocab_size = len(tokenizer) + + dataset = load_dataset("json", data_files=data_args.data_path) + dataset = dataset["train"] + train_test_split = dataset.train_test_split(test_size=100, seed=42) + train_dataset, eval_dataset = train_test_split["train"], train_test_split["test"] + + trainer = Trainer( + model=model, + tokenizer=tokenizer, + args=training_args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + data_collator=lambda features: data_collator( + features, tokenizer, original_tokenizer_vocab_size + ), + ) + + if is_main_process: + trainer.add_callback(transformers.integrations.WandbCallback()) + + trainer.train(resume_from_checkpoint=None) + trainer.save_model(training_args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/egs/emilia/TTS/local/extract_cosyvoice2_token.py b/egs/emilia/TTS/local/extract_cosyvoice2_token.py new file mode 100644 index 0000000000..2a6d1d3805 --- /dev/null +++ b/egs/emilia/TTS/local/extract_cosyvoice2_token.py @@ -0,0 +1,200 @@ +# Copyright (c) 2024 Tsinghua Univ. (authors: Xingchen Song) +# 2025 (authors: Yuekai Zhang) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Example Usage +torchrun --nproc_per_node=8 --nnodes=1 \ + --rdzv_id=2024 --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \ + local/extract_cosyvoice2_token.py --data_dir $data_dir \ + --jsonl_file $jsonl_file_basename \ + --device "cuda" \ + --output_dir $output_dir \ + --batch_size 32 \ + --num_workers 2 \ + --model "speech_tokenizer_v2_25hz" + +""" + +import argparse +import json +import os +from pathlib import Path + +import s3tokenizer +import torch +import torch.distributed as dist +from lhotse.serialization import load_jsonl +from torch.utils.data import DataLoader, Dataset, DistributedSampler +from tqdm import tqdm + + +class AudioDataset(Dataset): + def __init__(self, data_dir, jsonl_file): + self.data = [] + # convert data_dir to Path object + self.data_dir = Path(data_dir) + # jsonl_files = self.data_dir.glob("*.jsonl") + jsonl_files = [self.data_dir / jsonl_file] + for jsonl_file in jsonl_files: + for item in tqdm( + # Note: People's Speech manifest.json is really a JSONL. + load_jsonl(jsonl_file), + desc=f"Processing {jsonl_file}", + ): + self.data.append(item) + break + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + file_path = self.data_dir / self.data[idx]["wav"] + audio = s3tokenizer.load_audio(file_path) + if audio.shape[0] / 16000 > 30: + print( + f"do not support extract speech token for audio longer than 30s, file_path: {file_path}" # noqa + ) + mel = torch.zeros(128, 0) + else: + mel = s3tokenizer.log_mel_spectrogram(audio) + return self.data[idx], mel + + +def collate_fn(batch): + keys = [item[0] for item in batch] + mels = [item[1] for item in batch] + mels, mels_lens = s3tokenizer.padding(mels) + return keys, mels, mels_lens + + +def init_distributed(): + world_size = int(os.environ.get("WORLD_SIZE", 1)) + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + rank = int(os.environ.get("RANK", 0)) + print( + "Inference on multiple gpus, this gpu {}".format(local_rank) + + ", rank {}, world_size {}".format(rank, world_size) + ) + torch.cuda.set_device(local_rank) + dist.init_process_group("nccl") + return world_size, local_rank, rank + + +def get_args(): + parser = argparse.ArgumentParser(description="extract speech code") + parser.add_argument( + "--model", + required=True, + type=str, + choices=[ + "speech_tokenizer_v1", + "speech_tokenizer_v1_25hz", + "speech_tokenizer_v2_25hz", + ], + help="model version", + ) + parser.add_argument( + "--data_dir", + required=True, + type=str, + help="each line contains `wav_name wav_path`", + ) + parser.add_argument( + "--jsonl_file", + required=True, + type=str, + help="each line contains `wav_name wav_path`", + ) + parser.add_argument( + "--device", + required=True, + type=str, + choices=["cuda", "cpu"], + help="device for inference", + ) + parser.add_argument( + "--output_dir", required=True, type=str, help="dir to save result" + ) + parser.add_argument( + "--batch_size", + required=True, + type=int, + help="batch size (per-device) for inference", + ) + parser.add_argument( + "--num_workers", type=int, default=4, help="workers for dataloader" + ) + parser.add_argument( + "--prefetch", type=int, default=5, help="prefetch for dataloader" + ) + args = parser.parse_args() + return args + + +def main(): + args = get_args() + os.makedirs(args.output_dir, exist_ok=True) + + if args.device == "cuda": + assert torch.cuda.is_available() + world_size, local_rank, rank = init_distributed() + else: + world_size, local_rank, rank = 1, 0, 0 + + device = torch.device(args.device) + model = s3tokenizer.load_model(args.model).to(device) + dataset = AudioDataset(args.data_dir, args.jsonl_file) + + if args.device == "cuda": + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[local_rank] + ) + sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank) + else: + sampler = None + + dataloader = DataLoader( + dataset, + batch_size=args.batch_size, + sampler=sampler, + shuffle=False, + num_workers=args.num_workers, + prefetch_factor=args.prefetch, + collate_fn=collate_fn, + ) + + total_steps = len(dataset) + + if rank == 0: + progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs") + + writer = open(f"{args.output_dir}/part_{rank + 1}_of_{world_size}", "w") + for keys, mels, mels_lens in dataloader: + codes, codes_lens = model(mels.to(device), mels_lens.to(device)) + for i, k in enumerate(keys): + code = codes[i, : codes_lens[i].item()].tolist() + k["code"] = code + writer.write(json.dumps(k, ensure_ascii=False) + "\n") + if rank == 0: + progress_bar.update(world_size * len(keys)) + + if rank == 0: + progress_bar.close() + writer.close() + if args.device == "cuda": + dist.barrier() + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/egs/emilia/TTS/prepare.sh b/egs/emilia/TTS/prepare.sh new file mode 100755 index 0000000000..8abcfaf612 --- /dev/null +++ b/egs/emilia/TTS/prepare.sh @@ -0,0 +1,115 @@ +#!/usr/bin/env bash + +set -eou pipefail + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +stage=3 +stop_stage=4 + +# Please download the OpenDataLab format from HuggingFace, you can specify the revision argument to fc71e07e8572f5f3be1dbd02ed3172a4d298f152, which is the old format. +# https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07e8572f5f3be1dbd02ed3172a4d298f152 +dl_dir=$PWD/download + +prefix="emilia" +# zh, en, ja, ko, de, fr +lang_set=("de" "en" "zh" "ja" "ko" "fr") +. shared/parse_options.sh || exit 1 + + +# All files generated by this script are saved in "data". +# You can safely remove "data" and rerun this script to regenerate it. +mkdir -p data +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "dl_dir: $dl_dir" + log "Stage 0: Download data" + # Extract the downloaded data: + cat $dl_dir/raw/EN/EN_B00008.tar.gz.* > $dl_dir/raw/EN/EN_B00008.tar.gz + for lang in "${lang_set[@]}"; do + lang_upper=$(echo "${lang}" | tr '[:lower:]' '[:upper:]') + folder=$dl_dir/raw/${lang_upper} + for file in $folder/*.tar.gz; do + echo "Processing ${file}" + tar -xzvf $file -C $folder + done + done +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Prepare emilia manifest (used by ./f5-tts)" + # We assume that you have downloaded the Emilia corpus + # to $dl_dir/emilia + mkdir -p data/manifests + for lang in "${lang_set[@]}"; do + echo "Processing ${lang}" + if [ ! -e data/manifests/.emilia.${lang}.done ]; then + lhotse prepare emilia $dl_dir data/manifests --num-jobs 30 --lang "${lang}" + touch data/manifests/.emilia.${lang}.done + fi + done +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Generate fbank (used by ./f5-tts)" + mkdir -p data/fbank + for lang in "${lang_set[@]}"; do + echo "Processing ${lang}" + if [ ! -e data/fbank/.emilia.${lang}.done ]; then + ./local/compute_mel_feat.py --dataset-parts $lang --split 100 --prefix ${prefix} + touch data/fbank/.emilia.${lang}.done + fi + done +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3: Extract cosyvoice2 FSQ token (used by ./llaasa_cosyvoice2_token)" + for lang in "${lang_set[@]}"; do + lang_upper=$(echo "${lang}" | tr '[:lower:]' '[:upper:]') + data_dir=$dl_dir/raw/${lang_upper} + # for all jsonl files in data_dir + for jsonl_file in $data_dir/*.jsonl; do + # get the file basename + jsonl_file_basename=$(basename $jsonl_file) + echo "Processing $jsonl_file" + output_dir="./cosy_v2_tokens_${lang_upper}/${jsonl_file_basename%.jsonl}" + echo "output_dir: $output_dir" + # skip if the output_dir exists + if [ -e $output_dir ]; then + echo "Output directory $output_dir already exists, skipping" + continue + fi + mkdir -p $output_dir + torchrun --nproc_per_node=8 --nnodes=1 \ + --rdzv_id=2024 --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \ + local/extract_cosyvoice2_token.py --data_dir $data_dir \ + --jsonl_file $jsonl_file_basename \ + --device "cuda" \ + --output_dir $output_dir \ + --batch_size 32 \ + --num_workers 2 \ + --model "speech_tokenizer_v2_25hz" # or "speech_tokenizer_v1_25hz + done + done +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Merge cosyvoice2 FSQ token (used by ./llaasa_cosyvoice2_token)" + for lang in "${lang_set[@]}"; do + lang_upper=$(echo "${lang}" | tr '[:lower:]' '[:upper:]') + cosy_token_dir="./cosy_v2_tokens_${lang_upper}" + for dir in $cosy_token_dir/*; do + echo "Processing $dir" + dir_basename=$(basename $dir) + echo "dir_basename: $dir_basename" + cat $dir/part* > $dir/${dir_basename}.jsonl + done + cat $cosy_token_dir/${lang_upper}*/*.jsonl > $cosy_token_dir/cosy_v2_tokens_${lang_upper}.jsonl + done +fi diff --git a/egs/emilia/TTS/shared b/egs/emilia/TTS/shared new file mode 120000 index 0000000000..4c5e91438c --- /dev/null +++ b/egs/emilia/TTS/shared @@ -0,0 +1 @@ +../../../icefall/shared/ \ No newline at end of file diff --git a/egs/wenetspeech4tts/TTS/README.md b/egs/wenetspeech4tts/TTS/README.md index 8329ae9484..f1c57d853d 100644 --- a/egs/wenetspeech4tts/TTS/README.md +++ b/egs/wenetspeech4tts/TTS/README.md @@ -9,20 +9,6 @@ [**WenetSpeech4TTS**](https://huggingface.co/datasets/Wenetspeech4TTS/WenetSpeech4TTS) is a multi-domain **Mandarin** corpus derived from the open-sourced [WenetSpeech](https://arxiv.org/abs/2110.03370) dataset. -> [!CAUTION] -> The next-gen Kaldi framework provides tools and models for generating high-quality, synthetic speech (Text-to-Speech, TTS). -> While these recipes has the potential to advance various fields such as accessibility, language education, and AI-driven solutions, it also carries certain ethical and legal responsibilities. -> -> By using this framework, you agree to the following: -> 1. Legal and Ethical Use: You shall not use this framework, or any models derived from it, for any unlawful or unethical purposes. This includes, but is not limited to: Creating voice clones without the explicit, informed consent of the individual whose voice is being cloned. Engaging in any form of identity theft, impersonation, or fraud using cloned voices. Violating any local, national, or international laws regarding privacy, intellectual property, or personal data. -> -> 2. Responsibility of Use: The users of this framework are solely responsible for ensuring that their use of voice cloning technologies complies with all applicable laws and ethical guidelines. We explicitly disclaim any liability for misuse of the technology. -> -> 3. Attribution and Use of Open-Source Components: This project is provided under the Apache 2.0 license. Users must adhere to the terms of this license and provide appropriate attribution when required. -> -> 4. No Warranty: This framework is provided “as-is,” without warranty of any kind, either express or implied. We do not guarantee that the use of this software will comply with legal requirements or that it will not infringe the rights of third parties. - - # [VALL-E](https://arxiv.org/abs/2301.02111) ./valle contains the code for training VALL-E TTS model. @@ -186,3 +172,5 @@ bash local/compute_wer.sh $output_dir $manifest - [VALL-E](https://github.com/lifeiteng/vall-e) - [F5-TTS](https://github.com/SWivid/F5-TTS) - [CosyVoice](https://github.com/FunAudioLLM/CosyVoice) +- [S3Tokenizer](https://github.com/xingchensong/S3Tokenizer/tree/main) +- [Spark-TTS](https://github.com/SparkAudio/Spark-TTS) diff --git a/egs/wenetspeech4tts/TTS/f5-tts/infer.py b/egs/wenetspeech4tts/TTS/f5-tts/infer.py index 6964a43be6..b90657d0e2 100644 --- a/egs/wenetspeech4tts/TTS/f5-tts/infer.py +++ b/egs/wenetspeech4tts/TTS/f5-tts/infer.py @@ -108,13 +108,6 @@ def get_parser(): help="Interpolate semantic token to match mel frames for CosyVoice", ) - parser.add_argument( - "--use-cosyvoice-semantic-token", - type=str2bool, - default=False, - help="Whether to use cosyvoice semantic token to replace text token.", - ) - parser.add_argument( "--split-name", type=str, diff --git a/egs/wenetspeech4tts/TTS/f5-tts/infer_dist.py b/egs/wenetspeech4tts/TTS/f5-tts/infer_dist.py new file mode 100644 index 0000000000..636720f032 --- /dev/null +++ b/egs/wenetspeech4tts/TTS/f5-tts/infer_dist.py @@ -0,0 +1,373 @@ +# Copyright (c) 2024 Tsinghua Univ. (authors: Xingchen Song) +# 2025 (authors: Yuekai Zhang) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from https://github.com/xingchensong/S3Tokenizer/blob/main/s3tokenizer/cli.py +""" Example Usage +split=test_zh +llm_path=f5-tts/exp_zh/checkpoint-805000 +huggingface-cli download --local-dir f5-tts-small-wenetspeech4tts-basic yuekai/f5-tts-semantic-token-small-wenetspeech4tts-basic +model_path=f5-tts-small-wenetspeech4tts-basic/epoch-10-avg-5.pt +huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir ./bigvgan_v2_24khz_100band_256x +vocoder=./bigvgan_v2_24khz_100band_256x +torchrun --nproc_per_node=2 \ + f5-tts/infer_dist.py \ + --output_dir $output_dir \ + --batch_size 1 \ + --num_workers 2 \ + --llm-model-name-or-path $llm_path \ + --flow-matching-model-path $model_path \ + --decoder-dim 768 --nhead 12 --num-decoder-layers 18 \ + --use-cosyvoice-semantic-token True \ + --vocoder-dir $vocoder \ + --split-name $split -top-k 50 -top-p 0.95 -temperature 0.8 \ + --tokenizer-dir Qwen/Qwen2.5-0.5B-Instruct +""" + +import argparse +import json +import os +from pathlib import Path + +import s3tokenizer +import torch +import torch.distributed as dist +import torch.nn.functional as F +import torchaudio +from bigvganinference import BigVGANInference +from datasets import load_dataset +from lhotse.serialization import load_jsonl +from llm_tts import LLMTTS +from model.modules import MelSpec +from torch.utils.data import DataLoader, Dataset, DistributedSampler +from tqdm import tqdm +from train import ( + add_model_arguments, + get_model, + get_tokenizer, + interpolate_tokens, + load_F5_TTS_pretrained_checkpoint, +) + +from icefall.checkpoint import load_checkpoint + +TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}" + + +def get_args(): + parser = argparse.ArgumentParser(description="extract speech code") + parser.add_argument( + "--s3-tokenizer-name", + required=False, + type=str, + choices=[ + "speech_tokenizer_v1", + "speech_tokenizer_v1_25hz", + "speech_tokenizer_v2_25hz", + ], + help="model version", + ) + parser.add_argument( + "--split-name", + type=str, + default="wenetspeech4tts", + choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"], + help="huggingface dataset split name", + ) + parser.add_argument( + "--output-dir", required=True, type=str, help="dir to save result" + ) + parser.add_argument( + "--batch-size", + required=True, + type=int, + help="batch size (per-device) for inference", + ) + parser.add_argument( + "--num-workers", type=int, default=4, help="workers for dataloader" + ) + parser.add_argument( + "--prefetch", type=int, default=5, help="prefetch for dataloader" + ) + parser.add_argument( + "--llm-model-name-or-path", + required=True, + type=str, + help="model version", + ) + parser.add_argument( + "--tokenizer-dir", + required=True, + type=str, + help="tokenizer dir", + ) + parser.add_argument( + "--vocoder-dir", + required=True, + type=str, + help="vocoder dir", + ) + parser.add_argument( + "--flow-matching-model-path", + required=True, + type=str, + help="flow matching model path", + ) + parser.add_argument( + "--top-k", + type=int, + default=50, + help="top k for sampling", + ) + parser.add_argument( + "--top-p", + type=float, + default=0.95, + help="top p for sampling", + ) + parser.add_argument( + "--temperature", + type=float, + default=0.8, + help="temperature for sampling", + ) + add_model_arguments(parser) + args = parser.parse_args() + return args + + +def padded_mel_batch(ref_mels): + max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax() + padded_ref_mels = [] + for mel in ref_mels: + padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value=0) + padded_ref_mels.append(padded_ref_mel) + padded_ref_mels = torch.stack(padded_ref_mels) + padded_ref_mels = padded_ref_mels.permute(0, 2, 1) + return padded_ref_mels + + +def data_collator(batch, tokenizer, mel_spectrogram): + speech_generation_start_index = tokenizer.convert_tokens_to_ids( + "<|SPEECH_GENERATION_START|>" + ) + assistant_index = tokenizer.convert_tokens_to_ids("assistant") + target_sample_rate = 24000 + hop_length = 256 + target_rms = 0.1 + input_ids_list, ref_mel_list, ref_mel_len_list = [], [], [] + for i, item in enumerate(batch): + prompt_text, target_text, prompt_audio_codes = ( + item["prompt_text"], + item["target_text"], + item["prompt_audio_cosy2_tokens"], + ) + message = [ + { + "role": "user", + "content": f"Convert the text to speech: {prompt_text + target_text}", + }, + {"role": "assistant", "content": "<|SPEECH_GENERATION_START|>"}, + ] + + input_ids = tokenizer.apply_chat_template( + message, + tokenize=True, + chat_template=TEMPLATE, + ) + + prompt_audio_codes = [c + 151665 for c in prompt_audio_codes] + + idx = input_ids.index(speech_generation_start_index) + input_ids = input_ids[:idx] + prompt_audio_codes + input_ids_list.append(input_ids) + + # get flow matching model's prompt mel spectrogram + ref_audio_org, ref_sr = ( + item["prompt_audio"]["array"], + item["prompt_audio"]["sampling_rate"], + ) + ref_audio_org = torch.from_numpy(ref_audio_org).unsqueeze(0).float() + ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio_org))) + if ref_rms < target_rms: + ref_audio_org = ref_audio_org * target_rms / ref_rms + + if ref_sr != target_sample_rate: + resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate) + ref_audio = resampler(ref_audio_org) + else: + ref_audio = ref_audio_org + + # Duration in mel frame length + ref_mel_len = ref_audio.shape[-1] // hop_length + # to mel spectrogram + ref_mel = mel_spectrogram(ref_audio) + ref_mel = ref_mel.squeeze(0) + + ref_mel_list.append(ref_mel) + ref_mel_len_list.append(ref_mel_len) + + max_len = max([len(input_ids) for input_ids in input_ids_list]) + input_ids_list = [ + [tokenizer.pad_token_id] * (max_len - len(input_ids)) + input_ids + for input_ids in input_ids_list + ] + input_ids = torch.tensor(input_ids_list, dtype=torch.int64) + attention_mask = input_ids.ne(tokenizer.pad_token_id).long() + ids = [item["id"] for item in batch] + + ref_mel_batch = padded_mel_batch(ref_mel_list) + ref_mel_len_batch = torch.LongTensor(ref_mel_len_list) + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "ids": ids, + "ref_mel_batch": ref_mel_batch, + "ref_mel_len_batch": ref_mel_len_batch, + } + + +def init_distributed(): + world_size = int(os.environ.get("WORLD_SIZE", 1)) + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + rank = int(os.environ.get("RANK", 0)) + print( + "Inference on multiple gpus, this gpu {}".format(local_rank) + + ", rank {}, world_size {}".format(rank, world_size) + ) + torch.cuda.set_device(local_rank) + dist.init_process_group("nccl") + return world_size, local_rank, rank + + +def main(): + args = get_args() + os.makedirs(args.output_dir, exist_ok=True) + + assert torch.cuda.is_available() + world_size, local_rank, rank = init_distributed() + device = torch.device(f"cuda:{local_rank}") + model = LLMTTS( + model_dir=args.llm_model_name_or_path, + tokenizer_dir=args.tokenizer_dir, + s3_tokenizer_name=args.s3_tokenizer_name, + device=device, + ) + + vocoder = BigVGANInference.from_pretrained(args.vocoder_dir, use_cuda_kernel=False) + vocoder = vocoder.eval().to(device) + + flow_matching_model = get_model(args).eval().to(device) + _ = load_checkpoint( + args.flow_matching_model_path, + model=flow_matching_model, + ) + + dataset = load_dataset( + "yuekai/seed_tts_cosy2", + split=args.split_name, + trust_remote_code=True, + ) + + sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank) + + mel_spectrogram = MelSpec( + n_fft=1024, + hop_length=256, + win_length=1024, + n_mel_channels=100, + target_sample_rate=24000, + mel_spec_type="bigvgan", + ) + + dataloader = DataLoader( + dataset, + batch_size=args.batch_size, + sampler=sampler, + shuffle=False, + num_workers=args.num_workers, + prefetch_factor=args.prefetch, + collate_fn=lambda x: data_collator(x, model.tokenizer, mel_spectrogram), + ) + + total_steps = len(dataset) + + if rank == 0: + progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs") + + for batch in dataloader: + generate_codes = model.inference_batch( + batch["input_ids"], + batch["attention_mask"], + top_k=args.top_k, + top_p=args.top_p, + temperature=args.temperature, + ) + flow_matching_input_tokens, total_mel_lens = [], [] + for i, code in enumerate(generate_codes): + flow_matching_input_token = interpolate_tokens(code) + total_mel_len = len(flow_matching_input_token) + flow_matching_input_tokens.append(flow_matching_input_token) + total_mel_lens.append(total_mel_len) + total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device) + ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch[ + "ref_mel_len_batch" + ].to(device) + + max_len = max([len(tokens) for tokens in flow_matching_input_tokens]) + # pad tokens to the same length + for i, tokens in enumerate(flow_matching_input_tokens): + flow_matching_input_tokens[i] = torch.tensor( + tokens + [-1] * (max_len - len(tokens)), dtype=torch.long + ) + flow_matching_input_tokens = torch.stack(flow_matching_input_tokens).to(device) + generated, _ = flow_matching_model.sample( + cond=ref_mels, + text=flow_matching_input_tokens, + duration=total_mel_lens, + lens=ref_mel_lens, + steps=16, + cfg_strength=2.0, + sway_sampling_coef=-1, + no_ref_audio=False, + seed=0, + ) + + for i, gen in enumerate(generated): + gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0) + gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32) + + generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu() + target_rms = 0.1 + target_sample_rate = 24_000 + # if ref_rms_list[i] < target_rms: + # generated_wave = generated_wave * ref_rms_list[i] / target_rms + utt = batch["ids"][i] + torchaudio.save( + f"{args.output_dir}/{utt}.wav", + generated_wave, + target_sample_rate, + ) + + if rank == 0: + progress_bar.update(world_size * len(batch["ids"])) + + if rank == 0: + progress_bar.close() + + dist.barrier() + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/egs/wenetspeech4tts/TTS/f5-tts/llm_tts.py b/egs/wenetspeech4tts/TTS/f5-tts/llm_tts.py new file mode 100644 index 0000000000..1d0fdc5c8d --- /dev/null +++ b/egs/wenetspeech4tts/TTS/f5-tts/llm_tts.py @@ -0,0 +1,110 @@ +# Copyright (c) 2025 SparkAudio +# 2025 Xinsheng Wang (w.xinshawn@gmail.com) +# 2025 Yuekai Zhang +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from https://github.com/SparkAudio/Spark-TTS/blob/main/cli/SparkTTS.py + +import re +from pathlib import Path +from typing import List + +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + + +class LLMTTS: + """ + LLM-TTS for text-to-speech generation. + """ + + def __init__( + self, + model_dir: Path, + tokenizer_dir: Path, + s3_tokenizer_name: str, + device: torch.device, + ): + """ + Initializes the LLMTTS model with the provided configurations and device. + + Args: + model_dir (Path): Directory containing the model and config files. + tokenizer_dir (Path): Directory containing the tokenizer files. + s3_tokenizer_name (str): Name of the tokenizer file on S3. + device (torch.device): Device to run the model on. + """ + self.device = device + + self.model = AutoModelForCausalLM.from_pretrained( + model_dir, + torch_dtype=torch.float16, + device_map=device, + attn_implementation="flash_attention_2", + ) + + tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir) + self.original_vocab_size = len(tokenizer) + self.cosyvoice2_token_vocab_size = 6561 + new_tokens = [f"<|s_{i}|>" for i in range(self.cosyvoice2_token_vocab_size)] + [ + "<|SPEECH_GENERATION_START|>" + ] + num_added_tokens = tokenizer.add_tokens(new_tokens) + tokenizer.padding_side = "left" + self.tokenizer = tokenizer + self.assistant_index = tokenizer.convert_tokens_to_ids("assistant") + + @torch.no_grad() + def inference_batch( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + temperature: float = 0.8, + top_k: float = 50, + top_p: float = 0.95, + max_new_tokens: int = 1024, + ) -> torch.Tensor: + """ + Performs inference to generate speech from text, incorporating prompt audio and/or text. + + Args: + input_ids (torch.Tensor): Input IDs for the model. + attention_mask (torch.Tensor): Attention mask for the model. + temperature (float, optional): Sampling temperature for controlling randomness. Default is 0.8. + top_k (float, optional): Top-k sampling parameter. Default is 50. + top_p (float, optional): Top-p (nucleus) sampling parameter. Default is 0.95. + max_new_tokens (int, optional): Maximum number of tokens to generate. Default is 1024. + + Returns: + torch.Tensor: Generated waveform as a tensor. + """ + generated_ids = self.model.generate( + input_ids=input_ids.to(self.device), + attention_mask=attention_mask.to(self.device), + max_new_tokens=max_new_tokens, + do_sample=True, + top_k=top_k, + top_p=top_p, + temperature=temperature, + ) + results = [] + generated_ids = generated_ids.cpu().tolist() + for i in range(len(generated_ids)): + assistant_index = generated_ids[i].index(self.assistant_index) + padding_index = len(generated_ids[i]) + # WAR: harding coding assistant_index + 2, for the current template Assistant: \n + result = generated_ids[i][assistant_index + 2 :] + result = [token - self.original_vocab_size for token in result] + result = [token for token in result if token >= 0] + results.append(result) + return results diff --git a/egs/wenetspeech4tts/TTS/f5-tts/train.py b/egs/wenetspeech4tts/TTS/f5-tts/train.py index 5333b3f277..343d0c65ca 100755 --- a/egs/wenetspeech4tts/TTS/f5-tts/train.py +++ b/egs/wenetspeech4tts/TTS/f5-tts/train.py @@ -118,6 +118,13 @@ def add_model_arguments(parser: argparse.ArgumentParser): help="Number of Decoder layers.", ) + parser.add_argument( + "--use-cosyvoice-semantic-token", + type=str2bool, + default=False, + help="Whether to use cosyvoice semantic token to replace text token.", + ) + def get_parser(): parser = argparse.ArgumentParser( @@ -313,13 +320,6 @@ def get_parser(): help="perform OOM check on dataloader batches before starting training.", ) - parser.add_argument( - "--use-cosyvoice-semantic-token", - type=str2bool, - default=False, - help="Whether to use cosyvoice semantic token to replace text token.", - ) - add_model_arguments(parser) return parser