[Not for merge] Add Emilia Training Recipe for Llasa (cosyvoice2 token) #1887
Open: yuekaizhang wants to merge 8 commits into k2-fsa:master from yuekaizhang:emilia
Commits
- 540430d  add token extraction
- fa65870  add training codes
- 0f7ebb7  add llasa infer
- d2b473a  add eval seed tts
- 7623939  clean code
- bc6e113  remove run.sh
- c473192  update results
- 1653b76  update readme and requirements
Files changed:
README (new file, 94 lines):
# Results

| LLM Model | Flow matching Model | Seed-TTS test_zh CER | Comment |
|-----------|---------------------|----------------------|---------|
| pretrained cosyvoice2 llm | pretrained cosyvoice2 unet | 1.45% | See [paper](https://arxiv.org/abs/2412.10117) |
| pretrained cosyvoice2 llm | f5-tts-small (wenetspeech4tts) | 1.79% (16 steps) | See [PR](https://github.com/k2-fsa/icefall/pull/1880) |
| llasa_cosyvoice2_token llm (Emilia 50k hours ZH) | f5-tts-small (wenetspeech4tts) | 1.81% (16 steps) | |

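The CER numbers above are character error rates on the Seed-TTS test_zh transcripts. A minimal sketch of the metric (plain edit distance over characters; the recipe itself uses `local/compute_wer.sh`, whose exact text normalization may differ):

```python
# Minimal CER sketch: Levenshtein distance over characters, summed over utterances.
# Illustrative only; local/compute_wer.sh may normalize text differently.
def edit_distance(ref: str, hyp: str) -> int:
    dp = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, 1):
        prev, dp[0] = dp[0], i
        for j, h in enumerate(hyp, 1):
            prev, dp[j] = dp[j], min(dp[j] + 1, dp[j - 1] + 1, prev + (r != h))
    return dp[-1]


def cer(refs: list[str], hyps: list[str]) -> float:
    errors = sum(edit_distance(r, h) for r, h in zip(refs, hyps))
    return errors / sum(len(r) for r in refs)


print(cer(["你好世界"], ["你好时节"]))  # 0.5
```
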
# Introduction

[**Emilia**](https://huggingface.co/datasets/amphion/Emilia-Dataset) starts with over 101k hours of speech across six languages, covering a wide range of speaking styles to enable more natural and spontaneous speech generation.

See https://arxiv.org/pdf/2407.05361.

> [!CAUTION]
> The next-gen Kaldi framework provides tools and models for generating high-quality, synthetic speech (Text-to-Speech, TTS).
> While these recipes have the potential to advance various fields such as accessibility, language education, and AI-driven solutions, they also carry certain ethical and legal responsibilities.
>
> By using this framework, you agree to the following:
> 1. Legal and Ethical Use: You shall not use this framework, or any models derived from it, for any unlawful or unethical purposes. This includes, but is not limited to: creating voice clones without the explicit, informed consent of the individual whose voice is being cloned; engaging in any form of identity theft, impersonation, or fraud using cloned voices; violating any local, national, or international laws regarding privacy, intellectual property, or personal data.
>
> 2. Responsibility of Use: The users of this framework are solely responsible for ensuring that their use of voice cloning technologies complies with all applicable laws and ethical guidelines. We explicitly disclaim any liability for misuse of the technology.
>
> 3. Attribution and Use of Open-Source Components: This project is provided under the Apache 2.0 license. Users must adhere to the terms of this license and provide appropriate attribution when required.
>
> 4. No Warranty: This framework is provided “as-is,” without warranty of any kind, either express or implied. We do not guarantee that the use of this software will comply with legal requirements or that it will not infringe the rights of third parties.

# Llasa (cosyvoice2 token)

./llasa_cosyvoice2_token contains the code for training Qwen2.5-0.5B models to predict CosyVoice2 semantic tokens.

Generated samples and training logs for the 50k-hour Chinese subset of [Emilia](https://huggingface.co/datasets/amphion/Emilia-Dataset) can be found [here](https://huggingface.co/yuekai/llasa_cosyvoice2_token_qwen_0.5b/tree/main).

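The model follows the Llasa setup: the Qwen2.5 tokenizer is extended with 6561 CosyVoice2 speech tokens `<|s_0|>` … `<|s_6560|>` plus a `<|SPEECH_GENERATION_START|>` placeholder, and each training sample is a chat-formatted prompt whose assistant turn is the utterance's speech-token sequence (see `train.py` below). A minimal sketch of how one prompt is laid out; the string-level construction here is only illustrative of what the collator builds from ids:

```python
# Illustrative sketch of the prompt layout used in train.py (not part of the recipe).
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
original_vocab_size = len(tokenizer)
tokenizer.add_tokens([f"<|s_{i}|>" for i in range(6561)] + ["<|SPEECH_GENERATION_START|>"])

text = "你好，世界"
codes = [12, 345, 6022]  # CosyVoice2 semantic tokens extracted by prepare.sh (dummy values here)

# Same chat layout as train.py: the user turn carries the text, the assistant turn the speech tokens.
prompt = (
    f"<|im_start|>user\nConvert the text to speech: {text}<|im_end|>\n"
    f"<|im_start|>assistant\n" + "".join(f"<|s_{c}|>" for c in codes) + "<|im_end|>"
)
input_ids = tokenizer(prompt, add_special_tokens=False)["input_ids"]
# train.py builds the same sequence directly from ids: speech token c maps to id
# original_vocab_size + c, and the loss is computed only on the assistant span.
print(len(input_ids), input_ids[-4:])
```
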
Preparation:

```
# extract cosyvoice2 semantic tokens
bash prepare.sh --stage 3 --stop_stage 4

# Or you could use the prepared tokens.
huggingface-cli download yuekai/emilia_cosyvoice_v2_token --local-dir emilia_cosyvoice_v2_token
```
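
Each extracted manifest (for example `cosy_v2_tokens_ZH.jsonl`, referenced in `config.json` below) is a JSON-lines file; `train.py` reads it via `datasets.load_dataset("json", ...)` and uses the `text` and `code` fields. A quick way to inspect one record (any extra fields beyond these two are an assumption and are ignored by the collator):

```python
# Peek at the first record of an extracted token manifest.
import json

with open("emilia_cosyvoice_v2_token/cosy_v2_tokens_ZH.jsonl", encoding="utf-8") as f:
    record = json.loads(next(f))

print(record["text"])        # transcript
print(len(record["code"]))   # number of CosyVoice2 semantic tokens (ids in [0, 6560])
```
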
The training command is given below:

```
# docker: ghcr.io/swivid/f5-tts:main
# pip install k2==1.24.4.dev20241030+cuda12.4.torch2.4.0 -f https://k2-fsa.github.io/k2/cuda.html
# pip install -r llasa_cosyvoice2_token/requirements.txt
# pip install -r icefall/egs/wenetspeech4tts/TTS/f5-tts/requirements.txt

WANDB_KEY=$your_wandb_key
wandb login ${WANDB_KEY}
huggingface-cli download Qwen/Qwen2.5-0.5B-Instruct --local-dir Qwen2.5-0.5B-Instruct
torchrun --nproc_per_node=8 train.py config.json
```
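
With `per_device_train_batch_size` 8 and `gradient_accumulation_steps` 1 from `config.json` below, the 8-process `torchrun` launch gives an effective global batch of 64 sequences per optimizer step (assuming a single 8-GPU node); a trivial check:

```python
# Effective global batch size under the assumed single-node, 8-GPU launch.
per_device_train_batch_size = 8   # from config.json
gradient_accumulation_steps = 1   # from config.json
nproc_per_node = 8                # from the torchrun command above

print(per_device_train_batch_size * gradient_accumulation_steps * nproc_per_node)  # 64
```
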
To run inference with the Emilia-trained Chinese Llasa_cosyvoice2_token model from icefall, we need to use the cosyvoice2 token flow matching [model](https://github.com/k2-fsa/icefall/pull/1880):

```
cd icefall/egs/wenetspeech4tts/TTS
huggingface-cli login
huggingface-cli download --local-dir ${exp_dir} yuekai/llasa_cosyvoice2_token_qwen_0.5b
huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir bigvgan_v2_24khz_100band_256x
vocoder=./bigvgan_v2_24khz_100band_256x
split=test_zh
llm_path=llasa_cosyvoice2_token_qwen_0.5b/checkpoint-800000

huggingface-cli download --local-dir f5-tts-small-wenetspeech4tts-basic yuekai/f5-tts-semantic-token-small-wenetspeech4tts-basic
model_path=f5-tts-small-wenetspeech4tts-basic/epoch-10-avg-5.pt
torchrun --nproc_per_node=2 \
  f5-tts/infer_dist.py \
    --output_dir $output_dir \
    --batch_size 1 \
    --num_workers 2 \
    --llm-model-name-or-path $llm_path \
    --flow-matching-model-path $model_path \
    --decoder-dim 768 --nhead 12 --num-decoder-layers 18 \
    --use-cosyvoice-semantic-token True \
    --vocoder-dir $vocoder \
    --split-name $split --top-k 50 --top-p 0.95 --temperature 0.8 \
    --tokenizer-dir Qwen/Qwen2.5-0.5B-Instruct

# compute cer
huggingface-cli download yuekai/seed_tts_eval --local-dir seed_tts_eval --repo-type dataset
manifest=./seed_tts_eval/seedtts_testset/zh/meta.lst
bash local/compute_wer.sh $output_dir $manifest
```
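
Conceptually, `infer_dist.py` (from the wenetspeech4tts F5-TTS recipe, PR #1880) first lets the LLM generate `<|s_i|>` tokens, maps them back to integer CosyVoice2 token ids, and then hands those ids to the flow matching model and BigVGAN vocoder. A minimal sketch of that LLM stage only; the paths and generation settings below mirror the command above, but the code itself is illustrative, not the actual script:

```python
# Illustrative sketch of the LLM stage: generate speech tokens for one sentence and map
# them back to integer CosyVoice2 token ids (the real pipeline lives in f5-tts/infer_dist.py).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
original_vocab_size = len(tokenizer)
tokenizer.add_tokens([f"<|s_{i}|>" for i in range(6561)] + ["<|SPEECH_GENERATION_START|>"])

# Assumes the checkpoint directory downloaded above; adjust the path if yours differs.
model = AutoModelForCausalLM.from_pretrained(
    "llasa_cosyvoice2_token_qwen_0.5b/checkpoint-800000", torch_dtype=torch.float16
).cuda().eval()

text = "你好，世界"
prompt = f"<|im_start|>user\nConvert the text to speech: {text}<|im_end|>\n<|im_start|>assistant\n"
input_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).input_ids.cuda()

with torch.no_grad():
    out = model.generate(input_ids, max_new_tokens=1024, do_sample=True,
                         top_k=50, top_p=0.95, temperature=0.8)

# Speech tokens occupy ids >= original_vocab_size; subtracting the offset recovers the
# CosyVoice2 token ids consumed by the flow matching model and the BigVGAN vocoder.
codes = [i - original_vocab_size for i in out[0, input_ids.shape[1]:].tolist()
         if original_vocab_size <= i < original_vocab_size + 6561]
print(len(codes))
```
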
# Credits

- [Llasa](https://arxiv.org/abs/2502.04128)
- [CosyVoice](https://github.com/FunAudioLLM/CosyVoice)
- [S3Tokenizer](https://github.com/xingchensong/S3Tokenizer/tree/main)
llasa_cosyvoice2_token/config.json (new file, 27 lines):
```json
{
  "llm_model_name_or_path": "./Qwen2.5-0.5B-Instruct",
  "data_path": ["./emilia_cosyvoice_v2_token/cosy_v2_tokens_ZH.jsonl"],
  "bf16": false,
  "output_dir": "./exp_zh",
  "num_train_epochs": 3,
  "per_device_train_batch_size": 8,
  "per_device_eval_batch_size": 8,
  "gradient_accumulation_steps": 1,
  "evaluation_strategy": "steps",
  "eval_steps": 1000,
  "save_strategy": "steps",
  "save_steps": 5000,
  "save_total_limit": 100,
  "learning_rate": 0.00005,
  "weight_decay": 0.01,
  "adam_beta2": 0.95,
  "warmup_ratio": 0.03,
  "lr_scheduler_type": "cosine",
  "logging_steps": 100,
  "report_to": "wandb",
  "model_max_length": 2048,
  "gradient_checkpointing": false,
  "dataloader_num_workers": 4,
  "dataloader_prefetch_factor": 4,
  "deepspeed": "ds_config_zero2.json"
}
```
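
`train.py` (shown further below) hands this JSON file to `transformers.HfArgumentParser`, which distributes the keys across the `ModelArguments`, `DataArguments`, and `CustomTrainingArguments` dataclasses. A minimal sketch of that mechanism with toy stand-in dataclasses (illustrative only, not the recipe's classes):

```python
# Minimal sketch of how a JSON config is mapped onto dataclasses by HfArgumentParser,
# mirroring what train.py does with its ModelArguments / DataArguments / CustomTrainingArguments.
import json
import tempfile
from dataclasses import dataclass, field
from typing import List, Optional

from transformers import HfArgumentParser


@dataclass
class ToyModelArguments:  # illustrative stand-in
    llm_model_name_or_path: Optional[str] = None


@dataclass
class ToyDataArguments:  # illustrative stand-in
    data_path: List[str] = field(default_factory=list)


config = {
    "llm_model_name_or_path": "./Qwen2.5-0.5B-Instruct",
    "data_path": ["./emilia_cosyvoice_v2_token/cosy_v2_tokens_ZH.jsonl"],
}

with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump(config, f)

model_args, data_args = HfArgumentParser(
    (ToyModelArguments, ToyDataArguments)
).parse_json_file(f.name)
print(model_args.llm_model_name_or_path, data_args.data_path)
```
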
egs/emilia/TTS/llasa_cosyvoice2_token/ds_config_zero2.json (new file, 47 lines):
```json
{
  "fp16": {
    "enabled": "auto",
    "loss_scale": 0,
    "loss_scale_window": 1000,
    "initial_scale_power": 64,
    "hysteresis": 2,
    "min_loss_scale": 1
  },
  "bf16": {
    "enabled": "auto"
  },
  "optimizer": {
    "type": "AdamW",
    "params": {
      "lr": "auto",
      "betas": "auto",
      "eps": "auto",
      "weight_decay": "auto"
    }
  },

  "scheduler": {
    "type": "WarmupCosineLR",
    "params": {
      "total_num_steps": "auto",
      "warmup_min_ratio": 0.03,
      "warmup_num_steps": "auto",
      "cos_min_ratio": 0.1
    }
  },

  "zero_optimization": {
    "stage": 2,
    "overlap_comm": false,
    "contiguous_gradients": true,
    "sub_group_size": 1e9,
    "reduce_bucket_size": "auto"
  },

  "gradient_accumulation_steps": "auto",
  "gradient_clipping": 1.0,
  "steps_per_print": 100,
  "train_batch_size": "auto",
  "train_micro_batch_size_per_gpu": "auto",
  "wall_clock_breakdown": false
}
```
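
The `"auto"` entries in this file are filled in at launch time by the Hugging Face Trainer's DeepSpeed integration from the matching `TrainingArguments` fields (learning rate, betas, batch sizes, and so on), which is why `config.json` only specifies them once. A minimal sketch of the equivalent wiring if the arguments were built in Python instead of via the JSON config (illustrative only):

```python
# Illustrative: the same DeepSpeed wiring constructed in Python instead of via config.json.
# The HF Trainer replaces the "auto" entries in ds_config_zero2.json with these values.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./exp_zh",
    per_device_train_batch_size=8,
    gradient_accumulation_steps=1,
    learning_rate=5e-5,
    adam_beta2=0.95,
    weight_decay=0.01,
    deepspeed="ds_config_zero2.json",  # path to the file above
)
```
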
llasa_cosyvoice2_token/requirements.txt (new file, 8 lines):
```
torch
transformers
wandb
datasets
accelerate>=0.26.0
deepspeed
flash-attn
s3tokenizer
```
llasa_cosyvoice2_token/train.py (new file, 184 lines):
```python
# Modified from https://github.com/zhenye234/LLaSA_training/blob/main/train_tts.py
""" Example Usage
WANDB_KEY=$your_wandb_key
wandb login ${WANDB_KEY}
huggingface-cli download yuekai/emilia_cosyvoice_v2_token --local-dir emilia_cosyvoice_v2_token
huggingface-cli download Qwen/Qwen2.5-0.5B-Instruct --local-dir Qwen2.5-0.5B-Instruct
torchrun --nproc_per_node=8 train.py config.json
"""
import json
import os
import random
import sys
from dataclasses import dataclass, field
from functools import partial
from typing import List, Optional

import numpy as np
import torch
import torch.nn as nn
import transformers
import wandb
from datasets import load_dataset
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
)
from transformers.trainer_pt_utils import LabelSmoother

IGNORE_TOKEN_ID = LabelSmoother.ignore_index
TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"


@dataclass
class ModelArguments:
    llm_model_name_or_path: Optional[str] = field(
        default="meta-llama/Llama-3.2-1B-Instruct"
    )


@dataclass
class DataArguments:
    data_path: List[str] = field(
        default=None,
        metadata={"help": "Root path(s) to the data. Can be single path or list."},
    )


@dataclass
class CustomTrainingArguments(TrainingArguments):
    optim: str = field(default="adamw_torch_fused")
    model_max_length: int = field(
        default=2048,
        metadata={"help": "Maximum sequence length"},
    )
    logging_steps: int = field(default=100, metadata={"help": "Log every X updates"})
    report_to: Optional[str] = field(
        default=None,
        metadata={"help": "The integration to report the results and logs to."},
    )
    run_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the run for logging."}
    )
    gradient_checkpointing: bool = field(default=False)
    lr_scheduler_type: str = field(
        default="cosine", metadata={"help": "The learning rate scheduler to use."}
    )
    remove_unused_columns: bool = field(default=False)


def data_collator(batch, tokenizer, original_tokenizer_vocab_size, cut_off_len=2048):
    speech_generation_start_index = tokenizer.convert_tokens_to_ids(
        "<|SPEECH_GENERATION_START|>"
    )
    assistant_index = tokenizer.convert_tokens_to_ids("assistant")
    input_ids_list = []
    for i, item in enumerate(batch):
        text, code = item["text"], item["code"]
        message = [
            {"role": "user", "content": f"Convert the text to speech: {text}"},
            {"role": "assistant", "content": "<|SPEECH_GENERATION_START|>"},
        ]

        input_ids = tokenizer.apply_chat_template(
            message,
            tokenize=True,
            chat_template=TEMPLATE,
        )

        code = [c + original_tokenizer_vocab_size for c in code]

        idx = input_ids.index(speech_generation_start_index)
        input_ids = input_ids[:idx] + code + input_ids[idx + 1 :]
        if len(input_ids) < cut_off_len:
            input_ids_list.append(input_ids)

    max_len = max([len(input_ids) for input_ids in input_ids_list])
    input_ids_list = [
        input_ids + [tokenizer.pad_token_id] * (max_len - len(input_ids))
        for input_ids in input_ids_list
    ]
    input_ids = torch.tensor(input_ids_list, dtype=torch.int)
    attention_mask = input_ids.ne(tokenizer.pad_token_id)

    target_ids = input_ids.clone()
    target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID
    mask_indices = torch.where(input_ids == assistant_index)
    for i in range(mask_indices[0].size(0)):
        row = mask_indices[0][i]
        col = mask_indices[1][i]
        # + 2 to skip: 'assistant', '\n'
        target_ids[row, : col + 2] = IGNORE_TOKEN_ID
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": target_ids.to(dtype=torch.int64),
    }


def main():
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, CustomTrainingArguments)
    )
    assert len(sys.argv) == 2 and sys.argv[1].endswith(".json")
    (
        model_args,
        data_args,
        training_args,
    ) = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))

    is_main_process = training_args.local_rank in [-1, 0]
    if training_args.report_to == "wandb" and is_main_process:
        wandb.init(
            project="llm_tts",
            config=training_args.to_sanitized_dict(),
            name=training_args.run_name,
        )

    model = AutoModelForCausalLM.from_pretrained(
        model_args.llm_model_name_or_path,
        torch_dtype=torch.float16,
        attn_implementation="flash_attention_2",
    )

    tokenizer = AutoTokenizer.from_pretrained(model_args.llm_model_name_or_path)
    original_tokenizer_vocab_size = len(tokenizer)
    cosyvoice2_token_size = 6561
    new_tokens = [f"<|s_{i}|>" for i in range(cosyvoice2_token_size)] + [
        "<|SPEECH_GENERATION_START|>"
    ]
    num_added_tokens = tokenizer.add_tokens(new_tokens)

    model.resize_token_embeddings(len(tokenizer))
    model.vocab_size = len(tokenizer)

    dataset = load_dataset("json", data_files=data_args.data_path)
    dataset = dataset["train"]
    train_test_split = dataset.train_test_split(test_size=100, seed=42)
    train_dataset, eval_dataset = train_test_split["train"], train_test_split["test"]

    trainer = Trainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=lambda features: data_collator(
            features, tokenizer, original_tokenizer_vocab_size
        ),
    )

    if is_main_process:
        trainer.add_callback(transformers.integrations.WandbCallback())

    trainer.train(resume_from_checkpoint=None)
    trainer.save_model(training_args.output_dir)


if __name__ == "__main__":
    main()
```
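
The key detail in `data_collator` above is the loss mask: every position up to and including the literal `assistant` token and the following newline is set to `IGNORE_TOKEN_ID`, and padding positions are masked too, so cross-entropy is computed only on the speech-token span of each sequence. A toy, tokenizer-free illustration of that masking rule (the tiny id table below is invented purely for the example):

```python
# Toy illustration of the masking rule in data_collator, with an invented id table.
import torch

IGNORE_TOKEN_ID = -100
PAD, ASSISTANT, NEWLINE = 0, 7, 8          # made-up ids
input_ids = torch.tensor([[5, 6, ASSISTANT, NEWLINE, 101, 102, 103, PAD]])

labels = input_ids.clone()
labels[labels == PAD] = IGNORE_TOKEN_ID
rows, cols = torch.where(input_ids == ASSISTANT)
for row, col in zip(rows, cols):
    labels[row, : col + 2] = IGNORE_TOKEN_ID  # mask up to and including 'assistant', '\n'

print(labels)
# [[-100, -100, -100, -100, 101, 102, 103, -100]]
```
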
Review comment: I think the terms & conditions may have been taken from another framework and the name changed? It may be safest to just delete this. (Assuming we decide it makes sense to merge the PR overall, which we can discuss separately.)
Reply (author): Yeah, I copied it from the libritts recipe here: https://github.com/k2-fsa/icefall/tree/master/egs/libritts/TTS#readme. Deleted now.