Commit 22be3f4

New ML yaml + changes to allow for Spectral Codec training with text context (#14894)
* Add new config
* update wandb configs
* update config
* add separate tokenizer for text condition
* update codec loading
* add it tokenizer
* fix attempt 1
* add an additional +1 for dataset
* Apply isort and black reformatting

Signed-off-by: Jason <[email protected]>
Signed-off-by: blisc <[email protected]>
Co-authored-by: blisc <[email protected]>
1 parent 066d622 commit 22be3f4

3 files changed: +286 -10 lines changed

Lines changed: 261 additions & 0 deletions
@@ -0,0 +1,261 @@
+name: Magpie-TTS-ML
+
+quadratic_duration: 20 # both training and validation datasets can apply same quadratic_duration.
+# Dataset metadata for each manifest
+# https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/tts/data/vocoder_dataset.py#L39-L41
+train_ds_meta: ???
+val_ds_meta: ???
+
+model:
+  use_lhotse: true
+  model_type: "decoder_ce" # single_encoder_sv_tts, decoder_context_tts or decoder_pretrain_synthesizer
+  use_text_conditioning_encoder: true # If true, distilbert will be used to encode context_text if provided.
+  text_conditioning_tokenizer_name: text_ce_tokenizer
+  context_duration_min: 5.0
+  context_duration_max: 5.0
+  load_cached_codes_if_available: true
+  prior_scaling_factor: 0.5
+  prior_end_step: 12000
+  prior_scaledown_start_step: 8000
+  indefinite_prior_prob: 0. # If > 0, then prior will be applied after prior_end_step with this probability.
+  alignment_loss_scale: 0.002
+  embedding_dim: 768
+  codecmodel_path: ???
+  cfg_unconditional_prob: 0.1
+  # Alignment encoder parameters, to binarize the prior
+  # This is used for attention-constrained training and inference
+  use_alignment_encoder: false
+
+  # Local transformer parameters for autoregressive codebook prediction within a frame
+  local_transformer_type: "autoregressive" # "none", "autoregressive", "maskgit"
+  # Below args are only relevant if use_local_transformer is true
+  local_transformer_loss_scale: 1.0
+  local_transformer_n_layers: 1
+  local_transformer_n_heads: 1
+  local_transformer_hidden_dim: 256
+
+  text_context_remapping_json: null # JSON file defining mapping of multiple text contexts to a single text context. Does not need to cover all text contexts.
+  text_context_remapping_prob: 0.0 # Probability of remapping the original text context to a remapped text context.
+
+  text_tokenizers: # Add more languages for multi-lingual TTS
+    english_phoneme:
+      _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer
+      punct: true
+      apostrophe: true
+      pad_with_space: false
+      g2p:
+        _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p
+        phoneme_dict: "scripts/tts_dataset_files/ipa_cmudict-0.7b_nv23.01.txt"
+        heteronyms: "scripts/tts_dataset_files/heteronyms-052722"
+        phoneme_probability: 0.8
+        ignore_ambiguous_words: false
+        use_chars: true
+        use_stresses: true
+    spanish_phoneme:
+      _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer
+      locale: es-ES
+      punct: true
+      apostrophe: true
+      pad_with_space: true
+      g2p:
+        _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p
+        locale: es-ES
+        phoneme_dict: "scripts/tts_dataset_files/es_ES/es_ES_nv230301.dict"
+        phoneme_probability: 0.8
+        ignore_ambiguous_words: false
+        use_chars: true
+        use_stresses: true
+    german_phoneme:
+      _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer
+      locale: de-DE
+      punct: true
+      apostrophe: true
+      pad_with_space: true
+      g2p:
+        _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p
+        locale: 'de-DE'
+        phoneme_dict: "scripts/tts_dataset_files/de/de_nv230119.dict"
+        heteronyms: "scripts/tts_dataset_files/de/de_nv230119.heteronym"
+        phoneme_probability: 0.8
+        ignore_ambiguous_words: false
+        use_chars: true
+        use_stresses: true
+        grapheme_case: mixed
+        grapheme_prefix: '#'
+    mandarin_phoneme:
+      _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.ChinesePhonemesTokenizer
+      punct: true
+      apostrophe: true
+      pad_with_space: true
+      g2p:
+        _target_: nemo.collections.tts.g2p.models.zh_cn_pinyin.ChineseG2p
+        phoneme_dict: "scripts/tts_dataset_files/zh/36finals/ipa_dict_nv23.05.txt"
+        word_segmenter: "jieba"
+        phoneme_prefix: ""
+        phoneme_case: "lower"
+        tone_prefix: "#"
+        ascii_letter_prefix: ""
+        ascii_letter_case: "upper"
+    french_chartokenizer:
+      _target_: AutoTokenizer
+      pretrained_model: "google/byt5-small"
+    hindi_phoneme:
+      _target_: AutoTokenizer
+      pretrained_model: "google/byt5-small"
+    italian_phoneme:
+      _target_: AutoTokenizer
+      pretrained_model: "google/byt5-small"
+    vietnamese_phoneme:
+      _target_: AutoTokenizer
+      pretrained_model: "google/byt5-small"
+    text_ce_tokenizer:
+      _target_: AutoTokenizer
+      pretrained_model: "google/byt5-small"
+
+  train_ds:
+    use_lhotse: ${model.use_lhotse}
+    volume_norm: true
+
+    dataset:
+      min_duration: 0.2
+      min_context_speaker_similarity: 0.6
+      max_cer: 0.03
+      batch_duration : ??? # in seconds. Adjust based on your GPU memory.
+      quadratic_duration: ${quadratic_duration}
+      use_bucketing: true
+      num_buckets: 20
+      bucket_buffer_size: 20_000
+      shuffle_buffer_size: 20_000
+      num_cuts_for_bins_estimate: 20_000
+      shard_seed: "trng"
+      drop_last: true
+      shuffle: true
+      num_workers: 6
+      pin_memory: true
+
+      input_cfg:
+        - type: lhotse_shar
+          shar_path: ???
+          weight: 1.0
+          tags:
+            tokenizer_names: ["english_phoneme"]
+
+
+  validation_ds:
+    use_lhotse: ${model.use_lhotse}
+    volume_norm: true
+
+    dataset:
+      min_duration: 0.2
+      min_context_speaker_similarity: 0.6
+      max_cer: 0.03
+      batch_duration: ??? # recommend to use smaller batch_duration for validation dataset than training dataset.
+      quadratic_duration: ${quadratic_duration}
+      use_bucketing: false
+      force_finite: true
+      drop_last: false
+      shuffle: false
+      num_workers: 2
+      pin_memory: true
+
+      input_cfg:
+        - type: lhotse_shar
+          shar_path: ???
+          weight: 1.0
+          tags:
+            tokenizer_names: ["english_phoneme"]
+
+  encoder:
+    n_layers: 6
+    d_model: 768
+    d_ffn: 3072
+    sa_n_heads: 12
+    kernel_size: 3
+    p_dropout: 0.1
+    p_dropout_out: 0.0
+    has_xattn: false
+    is_causal: true
+    apply_norm_out: true
+    max_length_causal_mask: 2048
+    use_learnable_pos_emb: true
+
+  context_encoder: # Only used for multi_encoder_context_tts and decoder_ce
+    n_layers: 1
+    d_model: 768
+    d_ffn: 3072
+    sa_n_heads: 12
+    kernel_size: 3
+    p_dropout: 0.1
+    p_dropout_out: 0.0
+    has_xattn: false
+    is_causal: false
+    apply_norm_out: true
+    max_length_causal_mask: 2048
+    use_learnable_pos_emb: true
+
+  decoder:
+    n_layers: 12
+    d_model: 768
+    d_ffn: 3072
+    sa_n_heads: 12
+    kernel_size: 1
+    p_dropout: 0.1
+    p_dropout_out: 0.0
+    has_xattn: true
+    xa_d_head: 128
+    xa_d_memory: 768
+    xa_n_heads: 1
+    is_causal: true
+    apply_norm_to_cond: true
+    apply_norm_out: true
+    max_length_causal_mask: 2048
+    use_learnable_pos_emb: true
+    make_prior_window_strict: true
+
+  optim:
+    _target_: torch.optim.AdamW
+    lr: 2e-4
+
+    sched:
+      name: ExponentialLR
+      gamma: 0.998
+
+trainer:
+  num_nodes: 1
+  devices: -1
+  accelerator: gpu
+  strategy: ddp_find_unused_parameters_true
+  precision: 32
+  max_steps: ???
+  accumulate_grad_batches: 1
+  enable_checkpointing: False # Provided by exp_manager
+  logger: false # Provided by exp_manager
+  log_every_n_steps: 100
+  check_val_every_n_epoch: 1
+  limit_train_batches: 1_000
+  val_check_interval: 1_000
+  num_sanity_val_steps: 0
+  benchmark: false
+  use_distributed_sampler: false # required because Lhotse has its own handling
+  gradient_clip_val: 2.5
+
+exp_manager:
+  exp_dir: null
+  name: ${name}
+  create_tensorboard_logger: true
+  create_wandb_logger: false
+  wandb_logger_kwargs:
+    entity: null
+    project: null
+    group: null
+    name: ${name}
+    resume: true # enable resume to ensure continuous training log metrics merged on the previous run id.
+  create_checkpoint_callback: true
+  checkpoint_callback_params:
+    monitor: val_loss
+    mode: min
+    save_top_k: 5
+    save_best_model: true
+    always_save_nemo: true
+  resume_if_exists: true
+  resume_ignore_no_checkpoint: true
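
The config above refers to tokenizers by name: `text_conditioning_tokenizer_name` and every `tags.tokenizer_names` entry should match a key under `text_tokenizers`. Below is a minimal sketch of a pre-flight check for that, assuming names are resolved by exact key match and that `input_cfg` sits under `dataset`. `find_unknown_tokenizers` is a hypothetical helper, not part of NeMo, and the dict is a hand-trimmed mirror of the YAML rather than the loaded file.

```python
from typing import Any


def find_unknown_tokenizers(cfg: dict[str, Any]) -> list[str]:
    """Return tokenizer names that are referenced but not defined under model.text_tokenizers."""
    model = cfg["model"]
    defined = set(model["text_tokenizers"])
    referenced = []

    # Tokenizer used for the text-context (conditioning) branch.
    if model.get("use_text_conditioning_encoder"):
        referenced.append(model["text_conditioning_tokenizer_name"])

    # Tokenizers attached to each Lhotse input source through its tags.
    for split in ("train_ds", "validation_ds"):
        sources = model.get(split, {}).get("dataset", {}).get("input_cfg", [])
        for source in sources:
            referenced.extend(source.get("tags", {}).get("tokenizer_names", []))

    # Keep first-seen order while dropping duplicates.
    unknown = []
    for name in referenced:
        if name not in defined and name not in unknown:
            unknown.append(name)
    return unknown


# Hand-trimmed mirror of the config (assumption: only the keys this check needs).
english_only = {"dataset": {"input_cfg": [{"tags": {"tokenizer_names": ["english_phoneme"]}}]}}
cfg = {
    "model": {
        "use_text_conditioning_encoder": True,
        "text_conditioning_tokenizer_name": "text_ce_tokenizer",
        "text_tokenizers": {"english_phoneme": {}, "text_ce_tokenizer": {}},
        "train_ds": english_only,
        "validation_ds": english_only,
    }
}

print(find_unknown_tokenizers(cfg))  # [] because every referenced name is defined
```

Running a check like this before training would surface a misspelled tokenizer name early, instead of at dataset construction time.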

nemo/collections/tts/data/text_to_speech_dataset.py

Lines changed: 3 additions & 1 deletion
@@ -550,7 +550,9 @@ def __getitem__(self, index):
                 example['context_audio_codes_len'] = context_audio_codes_len
             else:
                 # @shehzeenh: Added this condition so that a batch does not have a mix of context_audio and context_audio_codes
-                context_audio = torch.zeros(self.codec_model_samples_per_frame, dtype=torch.float32)
+                # @blisc: Added a +1. If we send in exactly 882 samples, then a conv layer complains about padding.
+                # Adding 883 works. This occurs when we use text context during inference.
+                context_audio = torch.zeros(self.codec_model_samples_per_frame + 1, dtype=torch.float32)
                 context_audio_len = context_audio.shape[0]
                 example['context_audio'] = context_audio
                 example['context_audio_len'] = context_audio_len

nemo/collections/tts/models/magpietts.py

Lines changed: 22 additions & 9 deletions
@@ -94,9 +94,13 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
         if trainer is not None:
             self.world_size = trainer.num_nodes * trainer.num_devices
 
-        # load codec
-        codec_model = AudioCodecModel.restore_from(cfg.get('codecmodel_path'), strict=False)
-
+        # load codec, disable loading of loss modules not needed during inference
+        codec_model_cfg = AudioCodecModel.restore_from(cfg.get('codecmodel_path'), return_config=True)
+        if "use_scl_loss" in codec_model_cfg:
+            codec_model_cfg.use_scl_loss = False
+        codec_model = AudioCodecModel.restore_from(
+            cfg.get('codecmodel_path'), strict=False, override_config_path=codec_model_cfg
+        )
         self.sample_rate = codec_model.sample_rate
         self.codec_model_samples_per_frame = codec_model.samples_per_frame
         # del codec discriminator to free memory
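
The hunk above restores the codec in two passes: one with `return_config=True` to get only the stored config, then one with the edited config passed as `override_config_path`. Below is a minimal sketch of that pattern, using a hypothetical `StubCodec` in place of `AudioCodecModel` so it runs without NeMo or a checkpoint. The stub's behaviour and its config values are assumptions for illustration, and a plain dict stands in for the `DictConfig` used in the diff.

```python
class StubCodec:
    """Hypothetical stand-in for AudioCodecModel; mimics only the keyword arguments used in the diff."""

    # Pretend this is the config stored inside the checkpoint (made-up values).
    _saved_cfg = {"sample_rate": 22050, "use_scl_loss": True}

    def __init__(self, cfg):
        self.cfg = cfg
        # The optional loss module is built only when the config asks for it.
        self.scl_loss = object() if cfg.get("use_scl_loss") else None

    @classmethod
    def restore_from(cls, path, strict=True, return_config=False, override_config_path=None):
        if override_config_path is not None:
            cfg = dict(override_config_path)
        else:
            cfg = dict(cls._saved_cfg)
        if return_config:
            return cfg  # first pass: config only, no model is instantiated
        return cls(cfg)


def load_codec_without_scl_loss(path):
    # Pass 1: fetch the stored config without building the model.
    codec_cfg = StubCodec.restore_from(path, return_config=True)
    # Only touch the flag if the checkpoint's config actually has it.
    if "use_scl_loss" in codec_cfg:
        codec_cfg["use_scl_loss"] = False
    # Pass 2: build the model from the edited config.
    return StubCodec.restore_from(path, strict=False, override_config_path=codec_cfg)


codec = load_codec_without_scl_loss("codec.nemo")  # placeholder path, never opened by the stub
print(codec.scl_loss is None)  # True: the loss module was skipped
```

The membership test before flipping the flag lets the same loading code work for codec checkpoints whose config has no such key.
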
@@ -108,16 +112,25 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
         vector_quantizer = cfg.get('vector_quantizer')
         if vector_quantizer is not None:
             vector_quantizer = instantiate(vector_quantizer)
-            self.num_audio_codebooks = vector_quantizer.num_codebooks
-            self.codebook_size = vector_quantizer.codebook_size
+            num_audio_codebooks = vector_quantizer.num_codebooks
+            codebook_size = vector_quantizer.codebook_size
             codec_converter = VectorQuantizerIndexConverter(
                 vector_quantizer_original=codec_model.vector_quantizer,
                 vector_quantizer_new=vector_quantizer,
             )
+            data_num_audio_codebooks = codec_model.vector_quantizer.num_codebooks
         else:
-            self.num_audio_codebooks = codec_model.num_codebooks
-            self.codebook_size = codec_model.codebook_size
+            num_audio_codebooks = codec_model.num_codebooks
+            data_num_audio_codebooks = num_audio_codebooks
+            codebook_size = codec_model.codebook_size
             codec_converter = None
+        # The dataloader needs to know the number of codebooks that the context codes were stored in
+        # In the case where there are no context codes saved, and there is no context audio (in the text context path),
+        # We create a dummy context code tensor that is only [context_BOS, context_EOS] that is repeated for
+        # data_num_audio_codebooks
+        self.data_num_audio_codebooks = data_num_audio_codebooks
+        self.num_audio_codebooks = num_audio_codebooks
+        self.codebook_size = codebook_size
 
         # Our codebooks start with actual audio codec tokens, followed by special tokens.
         # The `forced_*` options are for backward compatibility for models trained with older code.
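
The new comment in the hunk above describes the fallback for the text-context path. With no cached context codes and no context audio, the dataset builds a dummy code tensor holding only `[context_BOS, context_EOS]`, repeated once per codebook in `data_num_audio_codebooks`. Below is a minimal sketch of that shape logic, with nested lists standing in for tensors. `make_dummy_context_codes` and the token ids are hypothetical, chosen only for illustration.

```python
def make_dummy_context_codes(
    context_bos_id: int, context_eos_id: int, data_num_audio_codebooks: int
) -> list[list[int]]:
    """Build a (data_num_audio_codebooks, 2) grid holding only the BOS and EOS context tokens."""
    return [[context_bos_id, context_eos_id] for _ in range(data_num_audio_codebooks)]


# Made-up values: codes stored with 8 codebooks, while the model itself may
# use a different count once a replacement vector quantizer is configured.
codes = make_dummy_context_codes(context_bos_id=2018, context_eos_id=2019, data_num_audio_codebooks=8)
print(len(codes), len(codes[0]))  # 8 2
```

Sizing the dummy with `data_num_audio_codebooks` instead of `num_audio_codebooks` keeps it shaped like codes stored on disk. Presumably this matters only when `vector_quantizer` is overridden and the two counts differ.
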
@@ -2648,7 +2661,7 @@ def get_dataset(self, dataset_cfg, dataset_type):
             audio_eos_id=self.audio_eos_id,
             context_audio_bos_id=self.context_audio_bos_id,
             context_audio_eos_id=self.context_audio_eos_id,
-            num_audio_codebooks=self.num_audio_codebooks,
+            num_audio_codebooks=self.data_num_audio_codebooks,
             codec_model_samples_per_frame=self.codec_model_samples_per_frame,
             prior_scaling_factor=self.cfg.prior_scaling_factor,
             load_cached_codes_if_available=self.cfg.load_cached_codes_if_available,
@@ -2678,7 +2691,7 @@ def get_lhotse_dataloader(self, dataset_cfg, mode='train') -> torch.utils.data.D
             audio_eos_id=self.audio_eos_id,
             context_audio_bos_id=self.context_audio_bos_id,
             context_audio_eos_id=self.context_audio_eos_id,
-            num_audio_codebooks=self.num_audio_codebooks,
+            num_audio_codebooks=self.data_num_audio_codebooks,
             prior_scaling_factor=self.cfg.prior_scaling_factor,
             load_cached_codes_if_available=self.cfg.load_cached_codes_if_available,
             dataset_type=mode, # train or test used for setting phone prob to 1.0 in test dataset (worker_init_fn)
