Merged
261 changes: 261 additions & 0 deletions examples/tts/conf/magpietts/magpietts_multilingual_v2_lhotse.yaml
@@ -0,0 +1,261 @@
name: Magpie-TTS-ML

quadratic_duration: 20 # the same quadratic_duration can be applied to both the training and validation datasets.
# Dataset metadata for each manifest
# https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/tts/data/vocoder_dataset.py#L39-L41
train_ds_meta: ???
val_ds_meta: ???

model:
  use_lhotse: true
  model_type: "decoder_ce" # other options: single_encoder_sv_tts, decoder_context_tts, decoder_pretrain_synthesizer
  use_text_conditioning_encoder: true # If true, DistilBERT is used to encode context_text when provided.
  text_conditioning_tokenizer_name: text_ce_tokenizer
  context_duration_min: 5.0
  context_duration_max: 5.0
  load_cached_codes_if_available: true
  prior_scaling_factor: 0.5
  prior_end_step: 12000
  prior_scaledown_start_step: 8000
  indefinite_prior_prob: 0. # If > 0, the prior will be applied after prior_end_step with this probability.
  alignment_loss_scale: 0.002
  embedding_dim: 768
  codecmodel_path: ???
  cfg_unconditional_prob: 0.1
  # Alignment encoder parameters, to binarize the prior.
  # This is used for attention-constrained training and inference.
  use_alignment_encoder: false

  # Local transformer parameters for autoregressive codebook prediction within a frame
  local_transformer_type: "autoregressive" # "none", "autoregressive", "maskgit"
  # The args below are only relevant if local_transformer_type is not "none"
  local_transformer_loss_scale: 1.0
  local_transformer_n_layers: 1
  local_transformer_n_heads: 1
  local_transformer_hidden_dim: 256

  text_context_remapping_json: null # JSON file defining a mapping of multiple text contexts to a single text context. Does not need to cover all text contexts.
  text_context_remapping_prob: 0.0 # Probability of remapping the original text context to a remapped text context.

  text_tokenizers: # Add more languages for multi-lingual TTS
    english_phoneme:
      _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer
      punct: true
      apostrophe: true
      pad_with_space: false
      g2p:
        _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p
        phoneme_dict: "scripts/tts_dataset_files/ipa_cmudict-0.7b_nv23.01.txt"
        heteronyms: "scripts/tts_dataset_files/heteronyms-052722"
        phoneme_probability: 0.8
        ignore_ambiguous_words: false
        use_chars: true
        use_stresses: true
    spanish_phoneme:
      _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer
      locale: es-ES
      punct: true
      apostrophe: true
      pad_with_space: true
      g2p:
        _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p
        locale: es-ES
        phoneme_dict: "scripts/tts_dataset_files/es_ES/es_ES_nv230301.dict"
        phoneme_probability: 0.8
        ignore_ambiguous_words: false
        use_chars: true
        use_stresses: true
    german_phoneme:
      _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer
      locale: de-DE
      punct: true
      apostrophe: true
      pad_with_space: true
      g2p:
        _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p
        locale: 'de-DE'
        phoneme_dict: "scripts/tts_dataset_files/de/de_nv230119.dict"
        heteronyms: "scripts/tts_dataset_files/de/de_nv230119.heteronym"
        phoneme_probability: 0.8
        ignore_ambiguous_words: false
        use_chars: true
        use_stresses: true
        grapheme_case: mixed
        grapheme_prefix: '#'
    mandarin_phoneme:
      _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.ChinesePhonemesTokenizer
      punct: true
      apostrophe: true
      pad_with_space: true
      g2p:
        _target_: nemo.collections.tts.g2p.models.zh_cn_pinyin.ChineseG2p
        phoneme_dict: "scripts/tts_dataset_files/zh/36finals/ipa_dict_nv23.05.txt"
        word_segmenter: "jieba"
        phoneme_prefix: ""
        phoneme_case: "lower"
        tone_prefix: "#"
        ascii_letter_prefix: ""
        ascii_letter_case: "upper"
    french_chartokenizer:
      _target_: AutoTokenizer
      pretrained_model: "google/byt5-small"
    hindi_phoneme:
      _target_: AutoTokenizer
      pretrained_model: "google/byt5-small"
    italian_phoneme:
      _target_: AutoTokenizer
      pretrained_model: "google/byt5-small"
    vietnamese_phoneme:
      _target_: AutoTokenizer
      pretrained_model: "google/byt5-small"
    text_ce_tokenizer:
      _target_: AutoTokenizer
      pretrained_model: "google/byt5-small"

  train_ds:
    use_lhotse: ${model.use_lhotse}
    volume_norm: true

    dataset:
      min_duration: 0.2
      min_context_speaker_similarity: 0.6
      max_cer: 0.03
      batch_duration: ??? # in seconds; adjust based on your GPU memory.
      quadratic_duration: ${quadratic_duration}
      use_bucketing: true
      num_buckets: 20
      bucket_buffer_size: 20_000
      shuffle_buffer_size: 20_000
      num_cuts_for_bins_estimate: 20_000
      shard_seed: "trng"
      drop_last: true
      shuffle: true
      num_workers: 6
      pin_memory: true

      input_cfg:
        - type: lhotse_shar
          shar_path: ???
          weight: 1.0
          tags:
            tokenizer_names: ["english_phoneme"]


  validation_ds:
    use_lhotse: ${model.use_lhotse}
    volume_norm: true

    dataset:
      min_duration: 0.2
      min_context_speaker_similarity: 0.6
      max_cer: 0.03
      batch_duration: ??? # a smaller batch_duration than for training is recommended for the validation dataset.
      quadratic_duration: ${quadratic_duration}
      use_bucketing: false
      force_finite: true
      drop_last: false
      shuffle: false
      num_workers: 2
      pin_memory: true

      input_cfg:
        - type: lhotse_shar
          shar_path: ???
          weight: 1.0
          tags:
            tokenizer_names: ["english_phoneme"]

  encoder:
    n_layers: 6
    d_model: 768
    d_ffn: 3072
    sa_n_heads: 12
    kernel_size: 3
    p_dropout: 0.1
    p_dropout_out: 0.0
    has_xattn: false
    is_causal: true
    apply_norm_out: true
    max_length_causal_mask: 2048
    use_learnable_pos_emb: true

  context_encoder: # Only used for multi_encoder_context_tts and decoder_ce
    n_layers: 1
    d_model: 768
    d_ffn: 3072
    sa_n_heads: 12
    kernel_size: 3
    p_dropout: 0.1
    p_dropout_out: 0.0
    has_xattn: false
    is_causal: false
    apply_norm_out: true
    max_length_causal_mask: 2048
    use_learnable_pos_emb: true

  decoder:
    n_layers: 12
    d_model: 768
    d_ffn: 3072
    sa_n_heads: 12
    kernel_size: 1
    p_dropout: 0.1
    p_dropout_out: 0.0
    has_xattn: true
    xa_d_head: 128
    xa_d_memory: 768
    xa_n_heads: 1
    is_causal: true
    apply_norm_to_cond: true
    apply_norm_out: true
    max_length_causal_mask: 2048
    use_learnable_pos_emb: true
    make_prior_window_strict: true

  optim:
    _target_: torch.optim.AdamW
    lr: 2e-4

    sched:
      name: ExponentialLR
      gamma: 0.998

trainer:
  num_nodes: 1
  devices: -1
  accelerator: gpu
  strategy: ddp_find_unused_parameters_true
  precision: 32
  max_steps: ???
  accumulate_grad_batches: 1
  enable_checkpointing: False # Provided by exp_manager
  logger: false # Provided by exp_manager
  log_every_n_steps: 100
  check_val_every_n_epoch: 1
  limit_train_batches: 1_000
  val_check_interval: 1_000
  num_sanity_val_steps: 0
  benchmark: false
  use_distributed_sampler: false # required because Lhotse has its own handling
  gradient_clip_val: 2.5

exp_manager:
  exp_dir: null
  name: ${name}
  create_tensorboard_logger: true
  create_wandb_logger: false
  wandb_logger_kwargs:
    entity: null
    project: null
    group: null
    name: ${name}
    resume: true # enable resume so that metrics logged after resuming training are merged into the previous run id.
  create_checkpoint_callback: true
  checkpoint_callback_params:
    monitor: val_loss
    mode: min
    save_top_k: 5
    save_best_model: true
    always_save_nemo: true
  resume_if_exists: true
  resume_ignore_no_checkpoint: true
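
The text_context_remapping_json option in the config above points to a JSON file that maps several text contexts onto a single one. Below is a minimal illustration of how such a file could be produced; the flat original-to-remapped schema and the context names are assumptions based on the config comment, not a documented format:

```python
import json

# Hypothetical mapping: several fine-grained text contexts collapse into one.
# Keys and values here are invented placeholders.
text_context_remapping = {
    "calm_female_narration": "female_narration",
    "warm_female_narration": "female_narration",
    "deep_male_narration": "male_narration",
}

with open("text_context_remapping.json", "w") as f:
    json.dump(text_context_remapping, f, indent=2)
```

With text_context_remapping_prob greater than zero, the original text context would then be swapped for its remapped value with that probability during training.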
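Several fields in the config are left as mandatory placeholders (`???`). The sketch below fills them programmatically with OmegaConf before training; all paths and values are placeholders, the nesting follows the layout shown above, and whether you set these here or via Hydra command-line overrides depends on your setup:

```python
from omegaconf import OmegaConf

# Load the config added by this PR and fill in the mandatory "???" fields.
cfg = OmegaConf.load("examples/tts/conf/magpietts/magpietts_multilingual_v2_lhotse.yaml")

cfg.train_ds_meta = "/data/manifests/train_ds_meta.json"                    # placeholder path
cfg.val_ds_meta = "/data/manifests/val_ds_meta.json"                        # placeholder path
cfg.model.codecmodel_path = "/models/audio_codec.nemo"                      # placeholder path
cfg.model.train_ds.dataset.batch_duration = 300                             # seconds; tune for GPU memory
cfg.model.train_ds.dataset.input_cfg[0].shar_path = "/data/shar/train"      # placeholder path
cfg.model.validation_ds.dataset.batch_duration = 100                        # smaller than the training value
cfg.model.validation_ds.dataset.input_cfg[0].shar_path = "/data/shar/val"   # placeholder path
cfg.trainer.max_steps = 200_000

# Inspect the filled-in config before handing it to the training entrypoint.
print(OmegaConf.to_yaml(cfg))
```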
4 changes: 3 additions & 1 deletion nemo/collections/tts/data/text_to_speech_dataset.py
@@ -550,7 +550,9 @@ def __getitem__(self, index):
             example['context_audio_codes_len'] = context_audio_codes_len
         else:
             # @shehzeenh: Added this condition so that a batch does not have a mix of context_audio and context_audio_codes
-            context_audio = torch.zeros(self.codec_model_samples_per_frame, dtype=torch.float32)
+            # @blisc: Added a +1. If we send in exactly 882 samples, then a conv layer complains about padding.
+            # Adding 883 works. This occurs when we use text context during inference.
+            context_audio = torch.zeros(self.codec_model_samples_per_frame + 1, dtype=torch.float32)
Comment on lines +553 to +555

Collaborator:
I am not sure if there is a clean way to handle reflect/replicate padding when the input is short. We could have each codec architecture define a minimum length it can handle, and then have pad_audio zero pad to at least that length: https://github.com/blisc/NeMo/blob/magpietts_2503/nemo/collections/tts/models/audio_codec.py#L453

Collaborator Author:
I agree, we should do that.

             context_audio_len = context_audio.shape[0]
             example['context_audio'] = context_audio
             example['context_audio_len'] = context_audio_len
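
A rough sketch of the reviewer's suggestion above: each codec architecture would expose a minimum input length it can handle, and pad_audio would zero-pad short inputs up to it. The helper name and the min_input_length parameter are assumptions for illustration, not the existing AudioCodecModel API:

```python
import torch
import torch.nn.functional as F


def pad_audio_to_min_length(audio: torch.Tensor, audio_len: torch.Tensor, min_input_length: int):
    """Zero-pad a batch of audio (B, T) on the right so that T >= min_input_length.

    min_input_length would be defined per codec architecture, chosen large enough
    that reflect/replicate padding in the first conv layers never sees a too-short input.
    """
    if audio.shape[-1] < min_input_length:
        audio = F.pad(audio, (0, min_input_length - audio.shape[-1]))
    # audio_len is returned unchanged so downstream masking still reflects the true length.
    return audio, audio_len
```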
31 changes: 22 additions & 9 deletions nemo/collections/tts/models/magpietts.py
@@ -94,9 +94,13 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
         if trainer is not None:
             self.world_size = trainer.num_nodes * trainer.num_devices

-        # load codec
-        codec_model = AudioCodecModel.restore_from(cfg.get('codecmodel_path'), strict=False)
-
+        # load codec, disable loading of loss modules not needed during inference
+        codec_model_cfg = AudioCodecModel.restore_from(cfg.get('codecmodel_path'), return_config=True)
+        if "use_scl_loss" in codec_model_cfg:
+            codec_model_cfg.use_scl_loss = False
+        codec_model = AudioCodecModel.restore_from(
+            cfg.get('codecmodel_path'), strict=False, override_config_path=codec_model_cfg
+        )
         self.sample_rate = codec_model.sample_rate
         self.codec_model_samples_per_frame = codec_model.samples_per_frame
         # del codec discriminator to free memory
@@ -108,16 +112,25 @@
         vector_quantizer = cfg.get('vector_quantizer')
         if vector_quantizer is not None:
             vector_quantizer = instantiate(vector_quantizer)
-            self.num_audio_codebooks = vector_quantizer.num_codebooks
-            self.codebook_size = vector_quantizer.codebook_size
+            num_audio_codebooks = vector_quantizer.num_codebooks
+            codebook_size = vector_quantizer.codebook_size
             codec_converter = VectorQuantizerIndexConverter(
                 vector_quantizer_original=codec_model.vector_quantizer,
                 vector_quantizer_new=vector_quantizer,
             )
+            data_num_audio_codebooks = codec_model.vector_quantizer.num_codebooks
         else:
-            self.num_audio_codebooks = codec_model.num_codebooks
-            self.codebook_size = codec_model.codebook_size
+            num_audio_codebooks = codec_model.num_codebooks
+            data_num_audio_codebooks = num_audio_codebooks
+            codebook_size = codec_model.codebook_size
             codec_converter = None
+        # The dataloader needs to know the number of codebooks that the context codes were stored in.
+        # In the case where there are no context codes saved and there is no context audio (the text context path),
+        # we create a dummy context code tensor that contains only [context_BOS, context_EOS], repeated for
+        # data_num_audio_codebooks.
+        self.data_num_audio_codebooks = data_num_audio_codebooks
+        self.num_audio_codebooks = num_audio_codebooks
+        self.codebook_size = codebook_size

         # Our codebooks start with actual audio codec tokens, followed by special tokens.
         # The `forced_*` options are for backward compatibility for models trained with older code.
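
As the comment in the hunk above describes, the text-context path has neither cached context codes nor context audio, so the dataset falls back to a dummy [context_BOS, context_EOS] code tensor with data_num_audio_codebooks rows. A minimal sketch of that construction; the helper and its exact placement in the dataset code are assumptions:

```python
import torch


def make_dummy_context_codes(context_bos_id: int, context_eos_id: int,
                             data_num_audio_codebooks: int) -> torch.Tensor:
    # One [BOS, EOS] pair per codebook -> shape (data_num_audio_codebooks, 2).
    pair = torch.tensor([context_bos_id, context_eos_id], dtype=torch.long)
    return pair.unsqueeze(0).repeat(data_num_audio_codebooks, 1)
```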
@@ -2648,7 +2661,7 @@ def get_dataset(self, dataset_cfg, dataset_type):
             audio_eos_id=self.audio_eos_id,
             context_audio_bos_id=self.context_audio_bos_id,
             context_audio_eos_id=self.context_audio_eos_id,
-            num_audio_codebooks=self.num_audio_codebooks,
+            num_audio_codebooks=self.data_num_audio_codebooks,
             codec_model_samples_per_frame=self.codec_model_samples_per_frame,
             prior_scaling_factor=self.cfg.prior_scaling_factor,
             load_cached_codes_if_available=self.cfg.load_cached_codes_if_available,
@@ -2678,7 +2691,7 @@ def get_lhotse_dataloader(self, dataset_cfg, mode='train') -> torch.utils.data.DataLoader:
             audio_eos_id=self.audio_eos_id,
             context_audio_bos_id=self.context_audio_bos_id,
             context_audio_eos_id=self.context_audio_eos_id,
-            num_audio_codebooks=self.num_audio_codebooks,
+            num_audio_codebooks=self.data_num_audio_codebooks,
             prior_scaling_factor=self.cfg.prior_scaling_factor,
             load_cached_codes_if_available=self.cfg.load_cached_codes_if_available,
             dataset_type=mode,  # train or test; used for setting phone prob to 1.0 in the test dataset (worker_init_fn)