New ML yaml + changes to allow for Spectral Codec training with text context #14894
Merged: blisc merged 11 commits into NVIDIA-NeMo:magpietts_2508 from blisc:magpietts_2508_jasondev0 on Oct 21, 2025.
Commits (11, all authored by blisc):
cebb6ee  Add new config
d1f6e49  update wandb configs
2b01cb1  update config
7d2b988  add separate tokenizer for text condition
3480b46  update codec loading
9131c64  merge latest changes
a436004  add it tokenizer
5c895e1  fix attempt 1
f977815  add an additional +1 for dataset
f4c7181  Merge branch 'magpietts_2508' into magpietts_2508_jasondev0
c91a808  Apply isort and black reformatting
examples/tts/conf/magpietts/magpietts_multilingual_v2_lhotse.yaml (new file, 261 additions, 0 deletions):
name: Magpie-TTS-ML

quadratic_duration: 20 # Both the training and validation datasets can use the same quadratic_duration.
# Dataset metadata for each manifest
# https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/tts/data/vocoder_dataset.py#L39-L41
train_ds_meta: ???
val_ds_meta: ???

model:
  use_lhotse: true
  model_type: "decoder_ce" # other options: single_encoder_sv_tts, decoder_context_tts, decoder_pretrain_synthesizer
  use_text_conditioning_encoder: true # If true, distilbert will be used to encode context_text if provided.
  text_conditioning_tokenizer_name: text_ce_tokenizer
  context_duration_min: 5.0
  context_duration_max: 5.0
  load_cached_codes_if_available: true
  prior_scaling_factor: 0.5
  prior_end_step: 12000
  prior_scaledown_start_step: 8000
  indefinite_prior_prob: 0. # If > 0, the prior will be applied after prior_end_step with this probability.
  alignment_loss_scale: 0.002
  embedding_dim: 768
  codecmodel_path: ???
  cfg_unconditional_prob: 0.1
  # Alignment encoder parameters, used to binarize the prior.
  # This is used for attention-constrained training and inference.
  use_alignment_encoder: false

  # Local transformer parameters for autoregressive codebook prediction within a frame.
  local_transformer_type: "autoregressive" # "none", "autoregressive", "maskgit"
  # The args below are only relevant if local_transformer_type is not "none".
  local_transformer_loss_scale: 1.0
  local_transformer_n_layers: 1
  local_transformer_n_heads: 1
  local_transformer_hidden_dim: 256

  text_context_remapping_json: null # JSON file defining a mapping of multiple text contexts to a single text context. Does not need to cover all text contexts.
  text_context_remapping_prob: 0.0 # Probability of remapping the original text context to its remapped text context.
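  # As an illustration (keys and values hypothetical), a remapping JSON could
  # collapse several fine-grained text contexts into one shared label:
  #   {"calm_male_narration": "narration", "formal_male_narration": "narration"}
  # With text_context_remapping_prob > 0, a matching context is then swapped
  # for its remapped label with that probability during training.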

  text_tokenizers: # Add more languages for multi-lingual TTS
    english_phoneme:
      _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer
      punct: true
      apostrophe: true
      pad_with_space: false
      g2p:
        _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p
        phoneme_dict: "scripts/tts_dataset_files/ipa_cmudict-0.7b_nv23.01.txt"
        heteronyms: "scripts/tts_dataset_files/heteronyms-052722"
        phoneme_probability: 0.8
        ignore_ambiguous_words: false
        use_chars: true
        use_stresses: true
    spanish_phoneme:
      _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer
      locale: es-ES
      punct: true
      apostrophe: true
      pad_with_space: true
      g2p:
        _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p
        locale: es-ES
        phoneme_dict: "scripts/tts_dataset_files/es_ES/es_ES_nv230301.dict"
        phoneme_probability: 0.8
        ignore_ambiguous_words: false
        use_chars: true
        use_stresses: true
    german_phoneme:
      _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer
      locale: de-DE
      punct: true
      apostrophe: true
      pad_with_space: true
      g2p:
        _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p
        locale: 'de-DE'
        phoneme_dict: "scripts/tts_dataset_files/de/de_nv230119.dict"
        heteronyms: "scripts/tts_dataset_files/de/de_nv230119.heteronym"
        phoneme_probability: 0.8
        ignore_ambiguous_words: false
        use_chars: true
        use_stresses: true
        grapheme_case: mixed
        grapheme_prefix: '#'
    mandarin_phoneme:
      _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.ChinesePhonemesTokenizer
      punct: true
      apostrophe: true
      pad_with_space: true
      g2p:
        _target_: nemo.collections.tts.g2p.models.zh_cn_pinyin.ChineseG2p
        phoneme_dict: "scripts/tts_dataset_files/zh/36finals/ipa_dict_nv23.05.txt"
        word_segmenter: "jieba"
        phoneme_prefix: ""
        phoneme_case: "lower"
        tone_prefix: "#"
        ascii_letter_prefix: ""
        ascii_letter_case: "upper"
    french_chartokenizer:
      _target_: AutoTokenizer
      pretrained_model: "google/byt5-small"
    hindi_phoneme:
      _target_: AutoTokenizer
      pretrained_model: "google/byt5-small"
    italian_phoneme:
      _target_: AutoTokenizer
      pretrained_model: "google/byt5-small"
    vietnamese_phoneme:
      _target_: AutoTokenizer
      pretrained_model: "google/byt5-small"
    text_ce_tokenizer:
      _target_: AutoTokenizer
      pretrained_model: "google/byt5-small"

  train_ds:
    use_lhotse: ${model.use_lhotse}
    volume_norm: true

    dataset:
      min_duration: 0.2
      min_context_speaker_similarity: 0.6
      max_cer: 0.03
      batch_duration: ??? # In seconds; adjust based on your GPU memory.
      quadratic_duration: ${quadratic_duration}
      use_bucketing: true
      num_buckets: 20
      bucket_buffer_size: 20_000
      shuffle_buffer_size: 20_000
      num_cuts_for_bins_estimate: 20_000
      shard_seed: "trng"
      drop_last: true
      shuffle: true
      num_workers: 6
      pin_memory: true

      input_cfg:
        - type: lhotse_shar
          shar_path: ???
          weight: 1.0
          tags:
            tokenizer_names: ["english_phoneme"]
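        # More corpora can be mixed in by appending weighted entries (the path
        # below is hypothetical); each entry's tags.tokenizer_names picks one
        # of the tokenizers defined under model.text_tokenizers, e.g.:
        # - type: lhotse_shar
        #   shar_path: /data/shar/spanish_train
        #   weight: 0.5
        #   tags:
        #     tokenizer_names: ["spanish_phoneme"]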

  validation_ds:
    use_lhotse: ${model.use_lhotse}
    volume_norm: true

    dataset:
      min_duration: 0.2
      min_context_speaker_similarity: 0.6
      max_cer: 0.03
      batch_duration: ??? # A smaller batch_duration is recommended for validation than for training.
      quadratic_duration: ${quadratic_duration}
      use_bucketing: false
      force_finite: true
      drop_last: false
      shuffle: false
      num_workers: 2
      pin_memory: true

      input_cfg:
        - type: lhotse_shar
          shar_path: ???
          weight: 1.0
          tags:
            tokenizer_names: ["english_phoneme"]

  encoder:
    n_layers: 6
    d_model: 768
    d_ffn: 3072
    sa_n_heads: 12
    kernel_size: 3
    p_dropout: 0.1
    p_dropout_out: 0.0
    has_xattn: false
    is_causal: true
    apply_norm_out: true
    max_length_causal_mask: 2048
    use_learnable_pos_emb: true

  context_encoder: # Only used for multi_encoder_context_tts and decoder_ce
    n_layers: 1
    d_model: 768
    d_ffn: 3072
    sa_n_heads: 12
    kernel_size: 3
    p_dropout: 0.1
    p_dropout_out: 0.0
    has_xattn: false
    is_causal: false
    apply_norm_out: true
    max_length_causal_mask: 2048
    use_learnable_pos_emb: true

  decoder:
    n_layers: 12
    d_model: 768
    d_ffn: 3072
    sa_n_heads: 12
    kernel_size: 1
    p_dropout: 0.1
    p_dropout_out: 0.0
    has_xattn: true
    xa_d_head: 128
    xa_d_memory: 768
    xa_n_heads: 1
    is_causal: true
    apply_norm_to_cond: true
    apply_norm_out: true
    max_length_causal_mask: 2048
    use_learnable_pos_emb: true
    make_prior_window_strict: true

  optim:
    _target_: torch.optim.AdamW
    lr: 2e-4

    sched:
      name: ExponentialLR
      gamma: 0.998
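    # With ExponentialLR, the learning rate after k scheduler steps is
    # lr * gamma**k; e.g. 2e-4 * 0.998**1000 ≈ 2.7e-5.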

trainer:
  num_nodes: 1
  devices: -1
  accelerator: gpu
  strategy: ddp_find_unused_parameters_true
  precision: 32
  max_steps: ???
  accumulate_grad_batches: 1
  enable_checkpointing: False # Provided by exp_manager
  logger: false # Provided by exp_manager
  log_every_n_steps: 100
  check_val_every_n_epoch: 1
  limit_train_batches: 1_000
  val_check_interval: 1_000
  num_sanity_val_steps: 0
  benchmark: false
  use_distributed_sampler: false # Required because Lhotse has its own handling.
  gradient_clip_val: 2.5

exp_manager:
  exp_dir: null
  name: ${name}
  create_tensorboard_logger: true
  create_wandb_logger: false
  wandb_logger_kwargs:
    entity: null
    project: null
    group: null
    name: ${name}
    resume: true # Enable resume so that a continued run logs its metrics onto the previous run id.
  create_checkpoint_callback: true
  checkpoint_callback_params:
    monitor: val_loss
    mode: min
    save_top_k: 5
    save_best_model: true
    always_save_nemo: true
  resume_if_exists: true
  resume_ignore_no_checkpoint: true
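
For reference, a minimal sketch (all paths and values below are hypothetical placeholders) of filling the required ??? fields programmatically with OmegaConf; the same keys can equally be set via Hydra command-line overrides.

from omegaconf import OmegaConf

cfg = OmegaConf.load("examples/tts/conf/magpietts/magpietts_multilingual_v2_lhotse.yaml")

# Hypothetical placeholder values; set them for your environment.
cfg.train_ds_meta = "/data/manifests/train"
cfg.val_ds_meta = "/data/manifests/val"
cfg.model.codecmodel_path = "/models/audio_codec.nemo"
cfg.model.train_ds.dataset.batch_duration = 300          # seconds of audio per batch
cfg.model.train_ds.dataset.input_cfg[0].shar_path = "/data/shar/english_train"
cfg.model.validation_ds.dataset.batch_duration = 100     # smaller than training, per the comment above
cfg.model.validation_ds.dataset.input_cfg[0].shar_path = "/data/shar/english_val"
cfg.trainer.max_steps = 200_000

# No missing ('???') values may remain when the config is consumed.
assert "???" not in OmegaConf.to_yaml(cfg)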

Review comment:
I am not sure if there is a clean way to handle reflect/replicate padding when the input is short. We could have each codec architecture define a minimum length it can handle, and then have pad_audio zero-pad to at least that length: https://github.com/blisc/NeMo/blob/magpietts_2503/nemo/collections/tts/models/audio_codec.py#L453

Reply:
I agree, we should do that.
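
A minimal sketch of that proposal, assuming a hypothetical per-codec MIN_INPUT_SAMPLES attribute and a simplified pad_audio signature (the real pad_audio in nemo/collections/tts/models/audio_codec.py differs):

import torch

# Hypothetical per-codec constant: the shortest input the codec's
# reflect/replicate padding layers can handle without erroring.
MIN_INPUT_SAMPLES = 2048

def pad_audio(audio: torch.Tensor, audio_len: torch.Tensor):
    """Zero-pad a [B, T] batch so every item is at least MIN_INPUT_SAMPLES long."""
    target_len = max(int(audio_len.max()), MIN_INPUT_SAMPLES)
    if audio.shape[1] < target_len:
        pad = torch.zeros(audio.shape[0], target_len - audio.shape[1],
                          dtype=audio.dtype, device=audio.device)
        audio = torch.cat([audio, pad], dim=1)
    # Clamp lengths so the zero padding counts as valid input for
    # utterances shorter than the codec's minimum.
    return audio, audio_len.clamp(min=MIN_INPUT_SAMPLES)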