From 7757218a6ac8488c4b2c62acd1652ed27c91ad43 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 14 Oct 2024 11:29:48 +0800 Subject: [PATCH 01/27] copy files from Matcha-TTS --- egs/ljspeech/TTS/matcha_tts/__init__.py | 0 egs/ljspeech/TTS/matcha_tts/hifigan/LICENSE | 21 + egs/ljspeech/TTS/matcha_tts/hifigan/README.md | 101 ++++ .../TTS/matcha_tts/hifigan/__init__.py | 0 egs/ljspeech/TTS/matcha_tts/hifigan/config.py | 28 ++ .../TTS/matcha_tts/hifigan/denoiser.py | 64 +++ egs/ljspeech/TTS/matcha_tts/hifigan/env.py | 17 + .../TTS/matcha_tts/hifigan/meldataset.py | 217 +++++++++ egs/ljspeech/TTS/matcha_tts/hifigan/models.py | 368 +++++++++++++++ egs/ljspeech/TTS/matcha_tts/hifigan/xutils.py | 60 +++ .../TTS/matcha_tts/models/__init__.py | 0 .../matcha_tts/models/baselightningmodule.py | 210 +++++++++ .../matcha_tts/models/components/__init__.py | 0 .../matcha_tts/models/components/decoder.py | 443 ++++++++++++++++++ .../models/components/flow_matching.py | 132 ++++++ .../models/components/text_encoder.py | 410 ++++++++++++++++ .../models/components/transformer.py | 316 +++++++++++++ .../TTS/matcha_tts/models/matcha_tts.py | 244 ++++++++++ egs/ljspeech/TTS/matcha_tts/text/__init__.py | 60 +++ egs/ljspeech/TTS/matcha_tts/text/cleaners.py | 129 +++++ egs/ljspeech/TTS/matcha_tts/text/numbers.py | 71 +++ egs/ljspeech/TTS/matcha_tts/text/symbols.py | 17 + egs/ljspeech/TTS/matcha_tts/train.py | 122 +++++ egs/ljspeech/TTS/matcha_tts/utils/__init__.py | 5 + egs/ljspeech/TTS/matcha_tts/utils/audio.py | 82 ++++ .../utils/generate_data_statistics.py | 113 +++++ .../utils/get_durations_from_trained_model.py | 195 ++++++++ .../TTS/matcha_tts/utils/instantiators.py | 56 +++ .../TTS/matcha_tts/utils/logging_utils.py | 53 +++ egs/ljspeech/TTS/matcha_tts/utils/model.py | 90 ++++ .../utils/monotonic_align/__init__.py | 22 + .../matcha_tts/utils/monotonic_align/core.pyx | 47 ++ .../matcha_tts/utils/monotonic_align/setup.py | 9 + egs/ljspeech/TTS/matcha_tts/utils/pylogger.py | 21 + .../TTS/matcha_tts/utils/rich_utils.py | 101 ++++ egs/ljspeech/TTS/matcha_tts/utils/utils.py | 259 ++++++++++ 36 files changed, 4083 insertions(+) create mode 100644 egs/ljspeech/TTS/matcha_tts/__init__.py create mode 100644 egs/ljspeech/TTS/matcha_tts/hifigan/LICENSE create mode 100644 egs/ljspeech/TTS/matcha_tts/hifigan/README.md create mode 100644 egs/ljspeech/TTS/matcha_tts/hifigan/__init__.py create mode 100644 egs/ljspeech/TTS/matcha_tts/hifigan/config.py create mode 100644 egs/ljspeech/TTS/matcha_tts/hifigan/denoiser.py create mode 100644 egs/ljspeech/TTS/matcha_tts/hifigan/env.py create mode 100644 egs/ljspeech/TTS/matcha_tts/hifigan/meldataset.py create mode 100644 egs/ljspeech/TTS/matcha_tts/hifigan/models.py create mode 100644 egs/ljspeech/TTS/matcha_tts/hifigan/xutils.py create mode 100644 egs/ljspeech/TTS/matcha_tts/models/__init__.py create mode 100644 egs/ljspeech/TTS/matcha_tts/models/baselightningmodule.py create mode 100644 egs/ljspeech/TTS/matcha_tts/models/components/__init__.py create mode 100644 egs/ljspeech/TTS/matcha_tts/models/components/decoder.py create mode 100644 egs/ljspeech/TTS/matcha_tts/models/components/flow_matching.py create mode 100644 egs/ljspeech/TTS/matcha_tts/models/components/text_encoder.py create mode 100644 egs/ljspeech/TTS/matcha_tts/models/components/transformer.py create mode 100644 egs/ljspeech/TTS/matcha_tts/models/matcha_tts.py create mode 100644 egs/ljspeech/TTS/matcha_tts/text/__init__.py create mode 100644 egs/ljspeech/TTS/matcha_tts/text/cleaners.py create mode 100644 
egs/ljspeech/TTS/matcha_tts/text/numbers.py create mode 100644 egs/ljspeech/TTS/matcha_tts/text/symbols.py create mode 100644 egs/ljspeech/TTS/matcha_tts/train.py create mode 100644 egs/ljspeech/TTS/matcha_tts/utils/__init__.py create mode 100644 egs/ljspeech/TTS/matcha_tts/utils/audio.py create mode 100644 egs/ljspeech/TTS/matcha_tts/utils/generate_data_statistics.py create mode 100644 egs/ljspeech/TTS/matcha_tts/utils/get_durations_from_trained_model.py create mode 100644 egs/ljspeech/TTS/matcha_tts/utils/instantiators.py create mode 100644 egs/ljspeech/TTS/matcha_tts/utils/logging_utils.py create mode 100644 egs/ljspeech/TTS/matcha_tts/utils/model.py create mode 100644 egs/ljspeech/TTS/matcha_tts/utils/monotonic_align/__init__.py create mode 100644 egs/ljspeech/TTS/matcha_tts/utils/monotonic_align/core.pyx create mode 100644 egs/ljspeech/TTS/matcha_tts/utils/monotonic_align/setup.py create mode 100644 egs/ljspeech/TTS/matcha_tts/utils/pylogger.py create mode 100644 egs/ljspeech/TTS/matcha_tts/utils/rich_utils.py create mode 100644 egs/ljspeech/TTS/matcha_tts/utils/utils.py diff --git a/egs/ljspeech/TTS/matcha_tts/__init__.py b/egs/ljspeech/TTS/matcha_tts/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/egs/ljspeech/TTS/matcha_tts/hifigan/LICENSE b/egs/ljspeech/TTS/matcha_tts/hifigan/LICENSE new file mode 100644 index 0000000000..91751daed8 --- /dev/null +++ b/egs/ljspeech/TTS/matcha_tts/hifigan/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2020 Jungil Kong + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/egs/ljspeech/TTS/matcha_tts/hifigan/README.md b/egs/ljspeech/TTS/matcha_tts/hifigan/README.md new file mode 100644 index 0000000000..5db2585045 --- /dev/null +++ b/egs/ljspeech/TTS/matcha_tts/hifigan/README.md @@ -0,0 +1,101 @@ +# HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis + +### Jungil Kong, Jaehyeon Kim, Jaekyoung Bae + +In our [paper](https://arxiv.org/abs/2010.05646), +we proposed HiFi-GAN: a GAN-based model capable of generating high fidelity speech efficiently.
+We provide our implementation and pretrained models as open source in this repository. + +**Abstract :** +Several recent work on speech synthesis have employed generative adversarial networks (GANs) to produce raw waveforms. +Although such methods improve the sampling efficiency and memory usage, +their sample quality has not yet reached that of autoregressive and flow-based generative models. +In this work, we propose HiFi-GAN, which achieves both efficient and high-fidelity speech synthesis. +As speech audio consists of sinusoidal signals with various periods, +we demonstrate that modeling periodic patterns of an audio is crucial for enhancing sample quality. +A subjective human evaluation (mean opinion score, MOS) of a single speaker dataset indicates that our proposed method +demonstrates similarity to human quality while generating 22.05 kHz high-fidelity audio 167.9 times faster than +real-time on a single V100 GPU. We further show the generality of HiFi-GAN to the mel-spectrogram inversion of unseen +speakers and end-to-end speech synthesis. Finally, a small footprint version of HiFi-GAN generates samples 13.4 times +faster than real-time on CPU with comparable quality to an autoregressive counterpart. + +Visit our [demo website](https://jik876.github.io/hifi-gan-demo/) for audio samples. + +## Pre-requisites + +1. Python >= 3.6 +2. Clone this repository. +3. Install python requirements. Please refer [requirements.txt](requirements.txt) +4. Download and extract the [LJ Speech dataset](https://keithito.com/LJ-Speech-Dataset/). + And move all wav files to `LJSpeech-1.1/wavs` + +## Training + +``` +python train.py --config config_v1.json +``` + +To train V2 or V3 Generator, replace `config_v1.json` with `config_v2.json` or `config_v3.json`.
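+For example, to train the V2 generator:
+
+```
+python train.py --config config_v2.json
+```
+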
+Checkpoints and a copy of the configuration file are saved in the `cp_hifigan` directory by default.
+You can change the path by adding the `--checkpoint_path` option. + +Validation loss during training with the V1 generator.
+![validation loss](./validation_loss.png) + +## Pretrained Model + +You can also use pretrained models we provide.
+[Download pretrained models](https://drive.google.com/drive/folders/1-eEYTB5Av9jNql0WGBlRoi-WH2J7bp5Y?usp=sharing)
+Details of each folder are as follows: + +| Folder Name | Generator | Dataset | Fine-Tuned | +| ------------ | --------- | --------- | ------------------------------------------------------ | +| LJ_V1 | V1 | LJSpeech | No | +| LJ_V2 | V2 | LJSpeech | No | +| LJ_V3 | V3 | LJSpeech | No | +| LJ_FT_T2_V1 | V1 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) | +| LJ_FT_T2_V2 | V2 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) | +| LJ_FT_T2_V3 | V3 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) | +| VCTK_V1 | V1 | VCTK | No | +| VCTK_V2 | V2 | VCTK | No | +| VCTK_V3 | V3 | VCTK | No | +| UNIVERSAL_V1 | V1 | Universal | No | + +We provide the universal model with discriminator weights that can be used as a base for transfer learning to other datasets. + +## Fine-Tuning + +1. Generate mel-spectrograms in numpy format using [Tacotron2](https://github.com/NVIDIA/tacotron2) with teacher-forcing.
+ The file name of the generated mel-spectrogram should match the audio file and the extension should be `.npy`.
+ Example: + ` Audio File : LJ001-0001.wav +Mel-Spectrogram File : LJ001-0001.npy` +2. Create `ft_dataset` folder and copy the generated mel-spectrogram files into it.
+3. Run the following command. + ``` + python train.py --fine_tuning True --config config_v1.json + ``` + For other command line options, please refer to the training section. + +## Inference from wav file + +1. Make `test_files` directory and copy wav files into the directory. +2. Run the following command. + ` python inference.py --checkpoint_file [generator checkpoint file path]` + Generated wav files are saved in `generated_files` by default.
+ You can change the path by adding `--output_dir` option. + +## Inference for end-to-end speech synthesis + +1. Make `test_mel_files` directory and copy generated mel-spectrogram files into the directory.
+ You can generate mel-spectrograms using [Tacotron2](https://github.com/NVIDIA/tacotron2), + [Glow-TTS](https://github.com/jaywalnut310/glow-tts) and so forth. +2. Run the following command. + ` python inference_e2e.py --checkpoint_file [generator checkpoint file path]` + Generated wav files are saved in `generated_files_from_mel` by default.
+ You can change the path by adding `--output_dir` option. + +## Acknowledgements + +We referred to [WaveGlow](https://github.com/NVIDIA/waveglow), [MelGAN](https://github.com/descriptinc/melgan-neurips) +and [Tacotron2](https://github.com/NVIDIA/tacotron2) to implement this. diff --git a/egs/ljspeech/TTS/matcha_tts/hifigan/__init__.py b/egs/ljspeech/TTS/matcha_tts/hifigan/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/egs/ljspeech/TTS/matcha_tts/hifigan/config.py b/egs/ljspeech/TTS/matcha_tts/hifigan/config.py new file mode 100644 index 0000000000..b3abea9e15 --- /dev/null +++ b/egs/ljspeech/TTS/matcha_tts/hifigan/config.py @@ -0,0 +1,28 @@ +v1 = { + "resblock": "1", + "num_gpus": 0, + "batch_size": 16, + "learning_rate": 0.0004, + "adam_b1": 0.8, + "adam_b2": 0.99, + "lr_decay": 0.999, + "seed": 1234, + "upsample_rates": [8, 8, 2, 2], + "upsample_kernel_sizes": [16, 16, 4, 4], + "upsample_initial_channel": 512, + "resblock_kernel_sizes": [3, 7, 11], + "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + "resblock_initial_channel": 256, + "segment_size": 8192, + "num_mels": 80, + "num_freq": 1025, + "n_fft": 1024, + "hop_size": 256, + "win_size": 1024, + "sampling_rate": 22050, + "fmin": 0, + "fmax": 8000, + "fmax_loss": None, + "num_workers": 4, + "dist_config": {"dist_backend": "nccl", "dist_url": "tcp://localhost:54321", "world_size": 1}, +} diff --git a/egs/ljspeech/TTS/matcha_tts/hifigan/denoiser.py b/egs/ljspeech/TTS/matcha_tts/hifigan/denoiser.py new file mode 100644 index 0000000000..9fd33312a0 --- /dev/null +++ b/egs/ljspeech/TTS/matcha_tts/hifigan/denoiser.py @@ -0,0 +1,64 @@ +# Code modified from Rafael Valle's implementation https://github.com/NVIDIA/waveglow/blob/5bc2a53e20b3b533362f974cfa1ea0267ae1c2b1/denoiser.py + +"""Waveglow style denoiser can be used to remove the artifacts from the HiFiGAN generated audio.""" +import torch + + +class Denoiser(torch.nn.Module): + """Removes model bias from audio produced with waveglow""" + + def __init__(self, vocoder, filter_length=1024, n_overlap=4, win_length=1024, mode="zeros"): + super().__init__() + self.filter_length = filter_length + self.hop_length = int(filter_length / n_overlap) + self.win_length = win_length + + dtype, device = next(vocoder.parameters()).dtype, next(vocoder.parameters()).device + self.device = device + if mode == "zeros": + mel_input = torch.zeros((1, 80, 88), dtype=dtype, device=device) + elif mode == "normal": + mel_input = torch.randn((1, 80, 88), dtype=dtype, device=device) + else: + raise Exception(f"Mode {mode} if not supported") + + def stft_fn(audio, n_fft, hop_length, win_length, window): + spec = torch.stft( + audio, + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + window=window, + return_complex=True, + ) + spec = torch.view_as_real(spec) + return torch.sqrt(spec.pow(2).sum(-1)), torch.atan2(spec[..., -1], spec[..., 0]) + + self.stft = lambda x: stft_fn( + audio=x, + n_fft=self.filter_length, + hop_length=self.hop_length, + win_length=self.win_length, + window=torch.hann_window(self.win_length, device=device), + ) + self.istft = lambda x, y: torch.istft( + torch.complex(x * torch.cos(y), x * torch.sin(y)), + n_fft=self.filter_length, + hop_length=self.hop_length, + win_length=self.win_length, + window=torch.hann_window(self.win_length, device=device), + ) + + with torch.no_grad(): + bias_audio = vocoder(mel_input).float().squeeze(0) + bias_spec, _ = self.stft(bias_audio) + + self.register_buffer("bias_spec", bias_spec[:, :, 0][:, :, None]) + 
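+    # `bias_spec` is the magnitude spectrum of the audio the vocoder produces for a
+    # silent (all-zero or random-noise) mel input, i.e. the vocoder's constant bias.
+    # In `forward` below it is scaled by `strength`, subtracted from the STFT magnitude
+    # of the input audio, and the result is resynthesised with the original phase.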
+ @torch.inference_mode() + def forward(self, audio, strength=0.0005): + audio_spec, audio_angles = self.stft(audio) + audio_spec_denoised = audio_spec - self.bias_spec.to(audio.device) * strength + audio_spec_denoised = torch.clamp(audio_spec_denoised, 0.0) + audio_denoised = self.istft(audio_spec_denoised, audio_angles) + return audio_denoised diff --git a/egs/ljspeech/TTS/matcha_tts/hifigan/env.py b/egs/ljspeech/TTS/matcha_tts/hifigan/env.py new file mode 100644 index 0000000000..9ea4f948a3 --- /dev/null +++ b/egs/ljspeech/TTS/matcha_tts/hifigan/env.py @@ -0,0 +1,17 @@ +""" from https://github.com/jik876/hifi-gan """ + +import os +import shutil + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.__dict__ = self + + +def build_env(config, config_name, path): + t_path = os.path.join(path, config_name) + if config != t_path: + os.makedirs(path, exist_ok=True) + shutil.copyfile(config, os.path.join(path, config_name)) diff --git a/egs/ljspeech/TTS/matcha_tts/hifigan/meldataset.py b/egs/ljspeech/TTS/matcha_tts/hifigan/meldataset.py new file mode 100644 index 0000000000..8b43ea7965 --- /dev/null +++ b/egs/ljspeech/TTS/matcha_tts/hifigan/meldataset.py @@ -0,0 +1,217 @@ +""" from https://github.com/jik876/hifi-gan """ + +import math +import os +import random + +import numpy as np +import torch +import torch.utils.data +from librosa.filters import mel as librosa_mel_fn +from librosa.util import normalize +from scipy.io.wavfile import read + +MAX_WAV_VALUE = 32768.0 + + +def load_wav(full_path): + sampling_rate, data = read(full_path) + return data, sampling_rate + + +def dynamic_range_compression(x, C=1, clip_val=1e-5): + return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) + + +def dynamic_range_decompression(x, C=1): + return np.exp(x) / C + + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression_torch(x, C=1): + return torch.exp(x) / C + + +def spectral_normalize_torch(magnitudes): + output = dynamic_range_compression_torch(magnitudes) + return output + + +def spectral_de_normalize_torch(magnitudes): + output = dynamic_range_decompression_torch(magnitudes) + return output + + +mel_basis = {} +hann_window = {} + + +def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): + if torch.min(y) < -1.0: + print("min value is ", torch.min(y)) + if torch.max(y) > 1.0: + print("max value is ", torch.max(y)) + + global mel_basis, hann_window # pylint: disable=global-statement + if fmax not in mel_basis: + mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) + mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device) + hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect" + ) + y = y.squeeze(1) + + spec = torch.view_as_real( + torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window[str(y.device)], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) + ) + + spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) + + spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec) + spec = spectral_normalize_torch(spec) + + return spec + + +def get_dataset_filelist(a): + with open(a.input_training_file, encoding="utf-8") as fi: + 
training_files = [ + os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0 + ] + + with open(a.input_validation_file, encoding="utf-8") as fi: + validation_files = [ + os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0 + ] + return training_files, validation_files + + +class MelDataset(torch.utils.data.Dataset): + def __init__( + self, + training_files, + segment_size, + n_fft, + num_mels, + hop_size, + win_size, + sampling_rate, + fmin, + fmax, + split=True, + shuffle=True, + n_cache_reuse=1, + device=None, + fmax_loss=None, + fine_tuning=False, + base_mels_path=None, + ): + self.audio_files = training_files + random.seed(1234) + if shuffle: + random.shuffle(self.audio_files) + self.segment_size = segment_size + self.sampling_rate = sampling_rate + self.split = split + self.n_fft = n_fft + self.num_mels = num_mels + self.hop_size = hop_size + self.win_size = win_size + self.fmin = fmin + self.fmax = fmax + self.fmax_loss = fmax_loss + self.cached_wav = None + self.n_cache_reuse = n_cache_reuse + self._cache_ref_count = 0 + self.device = device + self.fine_tuning = fine_tuning + self.base_mels_path = base_mels_path + + def __getitem__(self, index): + filename = self.audio_files[index] + if self._cache_ref_count == 0: + audio, sampling_rate = load_wav(filename) + audio = audio / MAX_WAV_VALUE + if not self.fine_tuning: + audio = normalize(audio) * 0.95 + self.cached_wav = audio + if sampling_rate != self.sampling_rate: + raise ValueError(f"{sampling_rate} SR doesn't match target {self.sampling_rate} SR") + self._cache_ref_count = self.n_cache_reuse + else: + audio = self.cached_wav + self._cache_ref_count -= 1 + + audio = torch.FloatTensor(audio) + audio = audio.unsqueeze(0) + + if not self.fine_tuning: + if self.split: + if audio.size(1) >= self.segment_size: + max_audio_start = audio.size(1) - self.segment_size + audio_start = random.randint(0, max_audio_start) + audio = audio[:, audio_start : audio_start + self.segment_size] + else: + audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), "constant") + + mel = mel_spectrogram( + audio, + self.n_fft, + self.num_mels, + self.sampling_rate, + self.hop_size, + self.win_size, + self.fmin, + self.fmax, + center=False, + ) + else: + mel = np.load(os.path.join(self.base_mels_path, os.path.splitext(os.path.split(filename)[-1])[0] + ".npy")) + mel = torch.from_numpy(mel) + + if len(mel.shape) < 3: + mel = mel.unsqueeze(0) + + if self.split: + frames_per_seg = math.ceil(self.segment_size / self.hop_size) + + if audio.size(1) >= self.segment_size: + mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1) + mel = mel[:, :, mel_start : mel_start + frames_per_seg] + audio = audio[:, mel_start * self.hop_size : (mel_start + frames_per_seg) * self.hop_size] + else: + mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), "constant") + audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), "constant") + + mel_loss = mel_spectrogram( + audio, + self.n_fft, + self.num_mels, + self.sampling_rate, + self.hop_size, + self.win_size, + self.fmin, + self.fmax_loss, + center=False, + ) + + return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze()) + + def __len__(self): + return len(self.audio_files) diff --git a/egs/ljspeech/TTS/matcha_tts/hifigan/models.py b/egs/ljspeech/TTS/matcha_tts/hifigan/models.py new file mode 100644 index 0000000000..d209d9a4e9 --- /dev/null +++ 
b/egs/ljspeech/TTS/matcha_tts/hifigan/models.py @@ -0,0 +1,368 @@ +""" from https://github.com/jik876/hifi-gan """ + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d +from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm + +from .xutils import get_padding, init_weights + +LRELU_SLOPE = 0.1 + + +class ResBlock1(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): + super().__init__() + self.h = h + self.convs1 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + ] + ) + self.convs2.apply(init_weights) + + def forward(self, x): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class ResBlock2(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)): + super().__init__() + self.h = h + self.convs = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + ] + ) + self.convs.apply(init_weights) + + def forward(self, x): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +class Generator(torch.nn.Module): + def __init__(self, h): + super().__init__() + self.h = h + self.num_kernels = len(h.resblock_kernel_sizes) + self.num_upsamples = len(h.upsample_rates) + self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3)) + resblock = ResBlock1 if h.resblock == "1" else ResBlock2 + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): + self.ups.append( + weight_norm( + ConvTranspose1d( + h.upsample_initial_channel // (2**i), + h.upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = h.upsample_initial_channel // (2 ** (i + 1)) + for _, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)): + self.resblocks.append(resblock(h, ch, k, d)) + + self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) + 
self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + + def forward(self, x): + x = self.conv_pre(x) + for i in range(self.num_upsamples): + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + print("Removing weight norm...") + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + + +class DiscriminatorP(torch.nn.Module): + def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): + super().__init__() + self.period = period + norm_f = weight_norm if use_spectral_norm is False else spectral_norm + self.convs = nn.ModuleList( + [ + norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), + ] + ) + self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x): + fmap = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiPeriodDiscriminator(torch.nn.Module): + def __init__(self): + super().__init__() + self.discriminators = nn.ModuleList( + [ + DiscriminatorP(2), + DiscriminatorP(3), + DiscriminatorP(5), + DiscriminatorP(7), + DiscriminatorP(11), + ] + ) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for _, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class DiscriminatorS(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super().__init__() + norm_f = weight_norm if use_spectral_norm is False else spectral_norm + self.convs = nn.ModuleList( + [ + norm_f(Conv1d(1, 128, 15, 1, padding=7)), + norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)), + norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)), + norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)), + norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), + ] + ) + self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x): + fmap = [] + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiScaleDiscriminator(torch.nn.Module): + def __init__(self): + super().__init__() + self.discriminators = nn.ModuleList( + [ + 
DiscriminatorS(use_spectral_norm=True), + DiscriminatorS(), + DiscriminatorS(), + ] + ) + self.meanpools = nn.ModuleList([AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)]) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + if i != 0: + y = self.meanpools[i - 1](y) + y_hat = self.meanpools[i - 1](y_hat) + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +def feature_loss(fmap_r, fmap_g): + loss = 0 + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + loss += torch.mean(torch.abs(rl - gl)) + + return loss * 2 + + +def discriminator_loss(disc_real_outputs, disc_generated_outputs): + loss = 0 + r_losses = [] + g_losses = [] + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + r_loss = torch.mean((1 - dr) ** 2) + g_loss = torch.mean(dg**2) + loss += r_loss + g_loss + r_losses.append(r_loss.item()) + g_losses.append(g_loss.item()) + + return loss, r_losses, g_losses + + +def generator_loss(disc_outputs): + loss = 0 + gen_losses = [] + for dg in disc_outputs: + l = torch.mean((1 - dg) ** 2) + gen_losses.append(l) + loss += l + + return loss, gen_losses diff --git a/egs/ljspeech/TTS/matcha_tts/hifigan/xutils.py b/egs/ljspeech/TTS/matcha_tts/hifigan/xutils.py new file mode 100644 index 0000000000..eefadcb7a1 --- /dev/null +++ b/egs/ljspeech/TTS/matcha_tts/hifigan/xutils.py @@ -0,0 +1,60 @@ +""" from https://github.com/jik876/hifi-gan """ + +import glob +import os + +import matplotlib +import torch +from torch.nn.utils import weight_norm + +matplotlib.use("Agg") +import matplotlib.pylab as plt + + +def plot_spectrogram(spectrogram): + fig, ax = plt.subplots(figsize=(10, 2)) + im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none") + plt.colorbar(im, ax=ax) + + fig.canvas.draw() + plt.close() + + return fig + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def apply_weight_norm(m): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + weight_norm(m) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +def load_checkpoint(filepath, device): + assert os.path.isfile(filepath) + print(f"Loading '{filepath}'") + checkpoint_dict = torch.load(filepath, map_location=device) + print("Complete.") + return checkpoint_dict + + +def save_checkpoint(filepath, obj): + print(f"Saving checkpoint to {filepath}") + torch.save(obj, filepath) + print("Complete.") + + +def scan_checkpoint(cp_dir, prefix): + pattern = os.path.join(cp_dir, prefix + "????????") + cp_list = glob.glob(pattern) + if len(cp_list) == 0: + return None + return sorted(cp_list)[-1] diff --git a/egs/ljspeech/TTS/matcha_tts/models/__init__.py b/egs/ljspeech/TTS/matcha_tts/models/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/egs/ljspeech/TTS/matcha_tts/models/baselightningmodule.py b/egs/ljspeech/TTS/matcha_tts/models/baselightningmodule.py new file mode 100644 index 0000000000..f8abe7b44f --- /dev/null +++ b/egs/ljspeech/TTS/matcha_tts/models/baselightningmodule.py @@ -0,0 +1,210 @@ +""" +This is a base lightning module that can be used to train a model. 
+The benefit of this abstraction is that all the logic outside of model definition can be reused for different models. +""" +import inspect +from abc import ABC +from typing import Any, Dict + +import torch +from lightning import LightningModule +from lightning.pytorch.utilities import grad_norm + +from matcha import utils +from matcha.utils.utils import plot_tensor + +log = utils.get_pylogger(__name__) + + +class BaseLightningClass(LightningModule, ABC): + def update_data_statistics(self, data_statistics): + if data_statistics is None: + data_statistics = { + "mel_mean": 0.0, + "mel_std": 1.0, + } + + self.register_buffer("mel_mean", torch.tensor(data_statistics["mel_mean"])) + self.register_buffer("mel_std", torch.tensor(data_statistics["mel_std"])) + + def configure_optimizers(self) -> Any: + optimizer = self.hparams.optimizer(params=self.parameters()) + if self.hparams.scheduler not in (None, {}): + scheduler_args = {} + # Manage last epoch for exponential schedulers + if "last_epoch" in inspect.signature(self.hparams.scheduler.scheduler).parameters: + if hasattr(self, "ckpt_loaded_epoch"): + current_epoch = self.ckpt_loaded_epoch - 1 + else: + current_epoch = -1 + + scheduler_args.update({"optimizer": optimizer}) + scheduler = self.hparams.scheduler.scheduler(**scheduler_args) + scheduler.last_epoch = current_epoch + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "interval": self.hparams.scheduler.lightning_args.interval, + "frequency": self.hparams.scheduler.lightning_args.frequency, + "name": "learning_rate", + }, + } + + return {"optimizer": optimizer} + + def get_losses(self, batch): + x, x_lengths = batch["x"], batch["x_lengths"] + y, y_lengths = batch["y"], batch["y_lengths"] + spks = batch["spks"] + + dur_loss, prior_loss, diff_loss, *_ = self( + x=x, + x_lengths=x_lengths, + y=y, + y_lengths=y_lengths, + spks=spks, + out_size=self.out_size, + durations=batch["durations"], + ) + return { + "dur_loss": dur_loss, + "prior_loss": prior_loss, + "diff_loss": diff_loss, + } + + def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + self.ckpt_loaded_epoch = checkpoint["epoch"] # pylint: disable=attribute-defined-outside-init + + def training_step(self, batch: Any, batch_idx: int): + loss_dict = self.get_losses(batch) + self.log( + "step", + float(self.global_step), + on_step=True, + prog_bar=True, + logger=True, + sync_dist=True, + ) + + self.log( + "sub_loss/train_dur_loss", + loss_dict["dur_loss"], + on_step=True, + on_epoch=True, + logger=True, + sync_dist=True, + ) + self.log( + "sub_loss/train_prior_loss", + loss_dict["prior_loss"], + on_step=True, + on_epoch=True, + logger=True, + sync_dist=True, + ) + self.log( + "sub_loss/train_diff_loss", + loss_dict["diff_loss"], + on_step=True, + on_epoch=True, + logger=True, + sync_dist=True, + ) + + total_loss = sum(loss_dict.values()) + self.log( + "loss/train", + total_loss, + on_step=True, + on_epoch=True, + logger=True, + prog_bar=True, + sync_dist=True, + ) + + return {"loss": total_loss, "log": loss_dict} + + def validation_step(self, batch: Any, batch_idx: int): + loss_dict = self.get_losses(batch) + self.log( + "sub_loss/val_dur_loss", + loss_dict["dur_loss"], + on_step=True, + on_epoch=True, + logger=True, + sync_dist=True, + ) + self.log( + "sub_loss/val_prior_loss", + loss_dict["prior_loss"], + on_step=True, + on_epoch=True, + logger=True, + sync_dist=True, + ) + self.log( + "sub_loss/val_diff_loss", + loss_dict["diff_loss"], + on_step=True, + on_epoch=True, + logger=True, + 
sync_dist=True, + ) + + total_loss = sum(loss_dict.values()) + self.log( + "loss/val", + total_loss, + on_step=True, + on_epoch=True, + logger=True, + prog_bar=True, + sync_dist=True, + ) + + return total_loss + + def on_validation_end(self) -> None: + if self.trainer.is_global_zero: + one_batch = next(iter(self.trainer.val_dataloaders)) + if self.current_epoch == 0: + log.debug("Plotting original samples") + for i in range(2): + y = one_batch["y"][i].unsqueeze(0).to(self.device) + self.logger.experiment.add_image( + f"original/{i}", + plot_tensor(y.squeeze().cpu()), + self.current_epoch, + dataformats="HWC", + ) + + log.debug("Synthesising...") + for i in range(2): + x = one_batch["x"][i].unsqueeze(0).to(self.device) + x_lengths = one_batch["x_lengths"][i].unsqueeze(0).to(self.device) + spks = one_batch["spks"][i].unsqueeze(0).to(self.device) if one_batch["spks"] is not None else None + output = self.synthesise(x[:, :x_lengths], x_lengths, n_timesteps=10, spks=spks) + y_enc, y_dec = output["encoder_outputs"], output["decoder_outputs"] + attn = output["attn"] + self.logger.experiment.add_image( + f"generated_enc/{i}", + plot_tensor(y_enc.squeeze().cpu()), + self.current_epoch, + dataformats="HWC", + ) + self.logger.experiment.add_image( + f"generated_dec/{i}", + plot_tensor(y_dec.squeeze().cpu()), + self.current_epoch, + dataformats="HWC", + ) + self.logger.experiment.add_image( + f"alignment/{i}", + plot_tensor(attn.squeeze().cpu()), + self.current_epoch, + dataformats="HWC", + ) + + def on_before_optimizer_step(self, optimizer): + self.log_dict({f"grad_norm/{k}": v for k, v in grad_norm(self, norm_type=2).items()}) diff --git a/egs/ljspeech/TTS/matcha_tts/models/components/__init__.py b/egs/ljspeech/TTS/matcha_tts/models/components/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/egs/ljspeech/TTS/matcha_tts/models/components/decoder.py b/egs/ljspeech/TTS/matcha_tts/models/components/decoder.py new file mode 100644 index 0000000000..1137cd7008 --- /dev/null +++ b/egs/ljspeech/TTS/matcha_tts/models/components/decoder.py @@ -0,0 +1,443 @@ +import math +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from conformer import ConformerBlock +from diffusers.models.activations import get_activation +from einops import pack, rearrange, repeat + +from matcha.models.components.transformer import BasicTransformerBlock + + +class SinusoidalPosEmb(torch.nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even" + + def forward(self, x, scale=1000): + if x.ndim < 1: + x = x.unsqueeze(0) + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb) + emb = scale * x.unsqueeze(1) * emb.unsqueeze(0) + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + + +class Block1D(torch.nn.Module): + def __init__(self, dim, dim_out, groups=8): + super().__init__() + self.block = torch.nn.Sequential( + torch.nn.Conv1d(dim, dim_out, 3, padding=1), + torch.nn.GroupNorm(groups, dim_out), + nn.Mish(), + ) + + def forward(self, x, mask): + output = self.block(x * mask) + return output * mask + + +class ResnetBlock1D(torch.nn.Module): + def __init__(self, dim, dim_out, time_emb_dim, groups=8): + super().__init__() + self.mlp = torch.nn.Sequential(nn.Mish(), torch.nn.Linear(time_emb_dim, dim_out)) + + self.block1 = Block1D(dim, dim_out, groups=groups) + 
self.block2 = Block1D(dim_out, dim_out, groups=groups) + + self.res_conv = torch.nn.Conv1d(dim, dim_out, 1) + + def forward(self, x, mask, time_emb): + h = self.block1(x, mask) + h += self.mlp(time_emb).unsqueeze(-1) + h = self.block2(h, mask) + output = h + self.res_conv(x * mask) + return output + + +class Downsample1D(nn.Module): + def __init__(self, dim): + super().__init__() + self.conv = torch.nn.Conv1d(dim, dim, 3, 2, 1) + + def forward(self, x): + return self.conv(x) + + +class TimestepEmbedding(nn.Module): + def __init__( + self, + in_channels: int, + time_embed_dim: int, + act_fn: str = "silu", + out_dim: int = None, + post_act_fn: Optional[str] = None, + cond_proj_dim=None, + ): + super().__init__() + + self.linear_1 = nn.Linear(in_channels, time_embed_dim) + + if cond_proj_dim is not None: + self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False) + else: + self.cond_proj = None + + self.act = get_activation(act_fn) + + if out_dim is not None: + time_embed_dim_out = out_dim + else: + time_embed_dim_out = time_embed_dim + self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out) + + if post_act_fn is None: + self.post_act = None + else: + self.post_act = get_activation(post_act_fn) + + def forward(self, sample, condition=None): + if condition is not None: + sample = sample + self.cond_proj(condition) + sample = self.linear_1(sample) + + if self.act is not None: + sample = self.act(sample) + + sample = self.linear_2(sample) + + if self.post_act is not None: + sample = self.post_act(sample) + return sample + + +class Upsample1D(nn.Module): + """A 1D upsampling layer with an optional convolution. + + Parameters: + channels (`int`): + number of channels in the inputs and outputs. + use_conv (`bool`, default `False`): + option to use a convolution. + use_conv_transpose (`bool`, default `False`): + option to use a convolution transpose. + out_channels (`int`, optional): + number of output channels. Defaults to `channels`. 
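+    Note: with the default `use_conv_transpose=True`, upsampling is a learned
+    `ConvTranspose1d`; otherwise nearest-neighbour interpolation is used, optionally
+    followed by a convolution when `use_conv=True`.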
+ """ + + def __init__(self, channels, use_conv=False, use_conv_transpose=True, out_channels=None, name="conv"): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + self.name = name + + self.conv = None + if use_conv_transpose: + self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1) + elif use_conv: + self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1) + + def forward(self, inputs): + assert inputs.shape[1] == self.channels + if self.use_conv_transpose: + return self.conv(inputs) + + outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest") + + if self.use_conv: + outputs = self.conv(outputs) + + return outputs + + +class ConformerWrapper(ConformerBlock): + def __init__( # pylint: disable=useless-super-delegation + self, + *, + dim, + dim_head=64, + heads=8, + ff_mult=4, + conv_expansion_factor=2, + conv_kernel_size=31, + attn_dropout=0, + ff_dropout=0, + conv_dropout=0, + conv_causal=False, + ): + super().__init__( + dim=dim, + dim_head=dim_head, + heads=heads, + ff_mult=ff_mult, + conv_expansion_factor=conv_expansion_factor, + conv_kernel_size=conv_kernel_size, + attn_dropout=attn_dropout, + ff_dropout=ff_dropout, + conv_dropout=conv_dropout, + conv_causal=conv_causal, + ) + + def forward( + self, + hidden_states, + attention_mask, + encoder_hidden_states=None, + encoder_attention_mask=None, + timestep=None, + ): + return super().forward(x=hidden_states, mask=attention_mask.bool()) + + +class Decoder(nn.Module): + def __init__( + self, + in_channels, + out_channels, + channels=(256, 256), + dropout=0.05, + attention_head_dim=64, + n_blocks=1, + num_mid_blocks=2, + num_heads=4, + act_fn="snake", + down_block_type="transformer", + mid_block_type="transformer", + up_block_type="transformer", + ): + super().__init__() + channels = tuple(channels) + self.in_channels = in_channels + self.out_channels = out_channels + + self.time_embeddings = SinusoidalPosEmb(in_channels) + time_embed_dim = channels[0] * 4 + self.time_mlp = TimestepEmbedding( + in_channels=in_channels, + time_embed_dim=time_embed_dim, + act_fn="silu", + ) + + self.down_blocks = nn.ModuleList([]) + self.mid_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + output_channel = in_channels + for i in range(len(channels)): # pylint: disable=consider-using-enumerate + input_channel = output_channel + output_channel = channels[i] + is_last = i == len(channels) - 1 + resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) + transformer_blocks = nn.ModuleList( + [ + self.get_block( + down_block_type, + output_channel, + attention_head_dim, + num_heads, + dropout, + act_fn, + ) + for _ in range(n_blocks) + ] + ) + downsample = ( + Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1) + ) + + self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample])) + + for i in range(num_mid_blocks): + input_channel = channels[-1] + out_channels = channels[-1] + + resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) + + transformer_blocks = nn.ModuleList( + [ + self.get_block( + mid_block_type, + output_channel, + attention_head_dim, + num_heads, + dropout, + act_fn, + ) + for _ in range(n_blocks) + ] + ) + + self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks])) + + channels = channels[::-1] + (channels[0],) + for i 
in range(len(channels) - 1): + input_channel = channels[i] + output_channel = channels[i + 1] + is_last = i == len(channels) - 2 + + resnet = ResnetBlock1D( + dim=2 * input_channel, + dim_out=output_channel, + time_emb_dim=time_embed_dim, + ) + transformer_blocks = nn.ModuleList( + [ + self.get_block( + up_block_type, + output_channel, + attention_head_dim, + num_heads, + dropout, + act_fn, + ) + for _ in range(n_blocks) + ] + ) + upsample = ( + Upsample1D(output_channel, use_conv_transpose=True) + if not is_last + else nn.Conv1d(output_channel, output_channel, 3, padding=1) + ) + + self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample])) + + self.final_block = Block1D(channels[-1], channels[-1]) + self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1) + + self.initialize_weights() + # nn.init.normal_(self.final_proj.weight) + + @staticmethod + def get_block(block_type, dim, attention_head_dim, num_heads, dropout, act_fn): + if block_type == "conformer": + block = ConformerWrapper( + dim=dim, + dim_head=attention_head_dim, + heads=num_heads, + ff_mult=1, + conv_expansion_factor=2, + ff_dropout=dropout, + attn_dropout=dropout, + conv_dropout=dropout, + conv_kernel_size=31, + ) + elif block_type == "transformer": + block = BasicTransformerBlock( + dim=dim, + num_attention_heads=num_heads, + attention_head_dim=attention_head_dim, + dropout=dropout, + activation_fn=act_fn, + ) + else: + raise ValueError(f"Unknown block type {block_type}") + + return block + + def initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv1d): + nn.init.kaiming_normal_(m.weight, nonlinearity="relu") + + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + elif isinstance(m, nn.GroupNorm): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + elif isinstance(m, nn.Linear): + nn.init.kaiming_normal_(m.weight, nonlinearity="relu") + + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x, mask, mu, t, spks=None, cond=None): + """Forward pass of the UNet1DConditional model. + + Args: + x (torch.Tensor): shape (batch_size, in_channels, time) + mask (_type_): shape (batch_size, 1, time) + t (_type_): shape (batch_size) + spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None. + cond (_type_, optional): placeholder for future use. Defaults to None. 
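+            mu (torch.Tensor): output of the encoder; same shape as `x`, concatenated with `x` along the channel dimension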
+ + Raises: + ValueError: _description_ + ValueError: _description_ + + Returns: + _type_: _description_ + """ + + t = self.time_embeddings(t) + t = self.time_mlp(t) + + x = pack([x, mu], "b * t")[0] + + if spks is not None: + spks = repeat(spks, "b c -> b c t", t=x.shape[-1]) + x = pack([x, spks], "b * t")[0] + + hiddens = [] + masks = [mask] + for resnet, transformer_blocks, downsample in self.down_blocks: + mask_down = masks[-1] + x = resnet(x, mask_down, t) + x = rearrange(x, "b c t -> b t c") + mask_down = rearrange(mask_down, "b 1 t -> b t") + for transformer_block in transformer_blocks: + x = transformer_block( + hidden_states=x, + attention_mask=mask_down, + timestep=t, + ) + x = rearrange(x, "b t c -> b c t") + mask_down = rearrange(mask_down, "b t -> b 1 t") + hiddens.append(x) # Save hidden states for skip connections + x = downsample(x * mask_down) + masks.append(mask_down[:, :, ::2]) + + masks = masks[:-1] + mask_mid = masks[-1] + + for resnet, transformer_blocks in self.mid_blocks: + x = resnet(x, mask_mid, t) + x = rearrange(x, "b c t -> b t c") + mask_mid = rearrange(mask_mid, "b 1 t -> b t") + for transformer_block in transformer_blocks: + x = transformer_block( + hidden_states=x, + attention_mask=mask_mid, + timestep=t, + ) + x = rearrange(x, "b t c -> b c t") + mask_mid = rearrange(mask_mid, "b t -> b 1 t") + + for resnet, transformer_blocks, upsample in self.up_blocks: + mask_up = masks.pop() + x = resnet(pack([x, hiddens.pop()], "b * t")[0], mask_up, t) + x = rearrange(x, "b c t -> b t c") + mask_up = rearrange(mask_up, "b 1 t -> b t") + for transformer_block in transformer_blocks: + x = transformer_block( + hidden_states=x, + attention_mask=mask_up, + timestep=t, + ) + x = rearrange(x, "b t c -> b c t") + mask_up = rearrange(mask_up, "b t -> b 1 t") + x = upsample(x * mask_up) + + x = self.final_block(x, mask_up) + output = self.final_proj(x * mask_up) + + return output * mask diff --git a/egs/ljspeech/TTS/matcha_tts/models/components/flow_matching.py b/egs/ljspeech/TTS/matcha_tts/models/components/flow_matching.py new file mode 100644 index 0000000000..5cad7431ef --- /dev/null +++ b/egs/ljspeech/TTS/matcha_tts/models/components/flow_matching.py @@ -0,0 +1,132 @@ +from abc import ABC + +import torch +import torch.nn.functional as F + +from matcha.models.components.decoder import Decoder +from matcha.utils.pylogger import get_pylogger + +log = get_pylogger(__name__) + + +class BASECFM(torch.nn.Module, ABC): + def __init__( + self, + n_feats, + cfm_params, + n_spks=1, + spk_emb_dim=128, + ): + super().__init__() + self.n_feats = n_feats + self.n_spks = n_spks + self.spk_emb_dim = spk_emb_dim + self.solver = cfm_params.solver + if hasattr(cfm_params, "sigma_min"): + self.sigma_min = cfm_params.sigma_min + else: + self.sigma_min = 1e-4 + + self.estimator = None + + @torch.inference_mode() + def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None): + """Forward diffusion + + Args: + mu (torch.Tensor): output of encoder + shape: (batch_size, n_feats, mel_timesteps) + mask (torch.Tensor): output_mask + shape: (batch_size, 1, mel_timesteps) + n_timesteps (int): number of diffusion steps + temperature (float, optional): temperature for scaling noise. Defaults to 1.0. + spks (torch.Tensor, optional): speaker ids. Defaults to None. 
+ shape: (batch_size, spk_emb_dim) + cond: Not used but kept for future purposes + + Returns: + sample: generated mel-spectrogram + shape: (batch_size, n_feats, mel_timesteps) + """ + z = torch.randn_like(mu) * temperature + t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device) + return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond) + + def solve_euler(self, x, t_span, mu, mask, spks, cond): + """ + Fixed euler solver for ODEs. + Args: + x (torch.Tensor): random noise + t_span (torch.Tensor): n_timesteps interpolated + shape: (n_timesteps + 1,) + mu (torch.Tensor): output of encoder + shape: (batch_size, n_feats, mel_timesteps) + mask (torch.Tensor): output_mask + shape: (batch_size, 1, mel_timesteps) + spks (torch.Tensor, optional): speaker ids. Defaults to None. + shape: (batch_size, spk_emb_dim) + cond: Not used but kept for future purposes + """ + t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] + + # I am storing this because I can later plot it by putting a debugger here and saving it to a file + # Or in future might add like a return_all_steps flag + sol = [] + + for step in range(1, len(t_span)): + dphi_dt = self.estimator(x, mask, mu, t, spks, cond) + + x = x + dt * dphi_dt + t = t + dt + sol.append(x) + if step < len(t_span) - 1: + dt = t_span[step + 1] - t + + return sol[-1] + + def compute_loss(self, x1, mask, mu, spks=None, cond=None): + """Computes diffusion loss + + Args: + x1 (torch.Tensor): Target + shape: (batch_size, n_feats, mel_timesteps) + mask (torch.Tensor): target mask + shape: (batch_size, 1, mel_timesteps) + mu (torch.Tensor): output of encoder + shape: (batch_size, n_feats, mel_timesteps) + spks (torch.Tensor, optional): speaker embedding. Defaults to None. + shape: (batch_size, spk_emb_dim) + + Returns: + loss: conditional flow matching loss + y: conditional flow + shape: (batch_size, n_feats, mel_timesteps) + """ + b, _, t = mu.shape + + # random timestep + t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype) + # sample noise p(x_0) + z = torch.randn_like(x1) + + y = (1 - (1 - self.sigma_min) * t) * z + t * x1 + u = x1 - (1 - self.sigma_min) * z + + loss = F.mse_loss(self.estimator(y, mask, mu, t.squeeze(), spks), u, reduction="sum") / ( + torch.sum(mask) * u.shape[1] + ) + return loss, y + + +class CFM(BASECFM): + def __init__(self, in_channels, out_channel, cfm_params, decoder_params, n_spks=1, spk_emb_dim=64): + super().__init__( + n_feats=in_channels, + cfm_params=cfm_params, + n_spks=n_spks, + spk_emb_dim=spk_emb_dim, + ) + + in_channels = in_channels + (spk_emb_dim if n_spks > 1 else 0) + # Just change the architecture of the estimator here + self.estimator = Decoder(in_channels=in_channels, out_channels=out_channel, **decoder_params) diff --git a/egs/ljspeech/TTS/matcha_tts/models/components/text_encoder.py b/egs/ljspeech/TTS/matcha_tts/models/components/text_encoder.py new file mode 100644 index 0000000000..a388d05d63 --- /dev/null +++ b/egs/ljspeech/TTS/matcha_tts/models/components/text_encoder.py @@ -0,0 +1,410 @@ +""" from https://github.com/jaywalnut310/glow-tts """ + +import math + +import torch +import torch.nn as nn +from einops import rearrange + +import matcha.utils as utils +from matcha.utils.model import sequence_mask + +log = utils.get_pylogger(__name__) + + +class LayerNorm(nn.Module): + def __init__(self, channels, eps=1e-4): + super().__init__() + self.channels = channels + self.eps = eps + + self.gamma = torch.nn.Parameter(torch.ones(channels)) + self.beta = 
torch.nn.Parameter(torch.zeros(channels)) + + def forward(self, x): + n_dims = len(x.shape) + mean = torch.mean(x, 1, keepdim=True) + variance = torch.mean((x - mean) ** 2, 1, keepdim=True) + + x = (x - mean) * torch.rsqrt(variance + self.eps) + + shape = [1, -1] + [1] * (n_dims - 2) + x = x * self.gamma.view(*shape) + self.beta.view(*shape) + return x + + +class ConvReluNorm(nn.Module): + def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout): + super().__init__() + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.p_dropout = p_dropout + + self.conv_layers = torch.nn.ModuleList() + self.norm_layers = torch.nn.ModuleList() + self.conv_layers.append(torch.nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2)) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.relu_drop = torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Dropout(p_dropout)) + for _ in range(n_layers - 1): + self.conv_layers.append( + torch.nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2) + ) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward(self, x, x_mask): + x_org = x + for i in range(self.n_layers): + x = self.conv_layers[i](x * x_mask) + x = self.norm_layers[i](x) + x = self.relu_drop(x) + x = x_org + self.proj(x) + return x * x_mask + + +class DurationPredictor(nn.Module): + def __init__(self, in_channels, filter_channels, kernel_size, p_dropout): + super().__init__() + self.in_channels = in_channels + self.filter_channels = filter_channels + self.p_dropout = p_dropout + + self.drop = torch.nn.Dropout(p_dropout) + self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2) + self.norm_1 = LayerNorm(filter_channels) + self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2) + self.norm_2 = LayerNorm(filter_channels) + self.proj = torch.nn.Conv1d(filter_channels, 1, 1) + + def forward(self, x, x_mask): + x = self.conv_1(x * x_mask) + x = torch.relu(x) + x = self.norm_1(x) + x = self.drop(x) + x = self.conv_2(x * x_mask) + x = torch.relu(x) + x = self.norm_2(x) + x = self.drop(x) + x = self.proj(x * x_mask) + return x * x_mask + + +class RotaryPositionalEmbeddings(nn.Module): + """ + ## RoPE module + + Rotary encoding transforms pairs of features by rotating in the 2D plane. + That is, it organizes the $d$ features as $\frac{d}{2}$ pairs. + Each pair can be considered a coordinate in a 2D plane, and the encoding will rotate it + by an angle depending on the position of the token. 
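+    Concretely, pair $i$ at token position $m$ is rotated by the angle $m\theta_i$,
+    where $\theta_i = 10000^{-\frac{2(i-1)}{d}}$ (see `_build_cache` below).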
+ """ + + def __init__(self, d: int, base: int = 10_000): + r""" + * `d` is the number of features $d$ + * `base` is the constant used for calculating $\Theta$ + """ + super().__init__() + + self.base = base + self.d = int(d) + self.cos_cached = None + self.sin_cached = None + + def _build_cache(self, x: torch.Tensor): + r""" + Cache $\cos$ and $\sin$ values + """ + # Return if cache is already built + if self.cos_cached is not None and x.shape[0] <= self.cos_cached.shape[0]: + return + + # Get sequence length + seq_len = x.shape[0] + + # $\Theta = {\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ + theta = 1.0 / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device) + + # Create position indexes `[0, 1, ..., seq_len - 1]` + seq_idx = torch.arange(seq_len, device=x.device).float().to(x.device) + + # Calculate the product of position index and $\theta_i$ + idx_theta = torch.einsum("n,d->nd", seq_idx, theta) + + # Concatenate so that for row $m$ we have + # $[m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}, m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}]$ + idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1) + + # Cache them + self.cos_cached = idx_theta2.cos()[:, None, None, :] + self.sin_cached = idx_theta2.sin()[:, None, None, :] + + def _neg_half(self, x: torch.Tensor): + # $\frac{d}{2}$ + d_2 = self.d // 2 + + # Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$ + return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1) + + def forward(self, x: torch.Tensor): + """ + * `x` is the Tensor at the head of a key or a query with shape `[seq_len, batch_size, n_heads, d]` + """ + # Cache $\cos$ and $\sin$ values + x = rearrange(x, "b h t d -> t b h d") + + self._build_cache(x) + + # Split the features, we can choose to apply rotary embeddings only to a partial set of features. 
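+        # In this model `d` is set to half of each attention head's channels (see
+        # MultiHeadAttention below), so only the first half of the features is rotated
+        # and `x_pass` is carried through unchanged.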
+ x_rope, x_pass = x[..., : self.d], x[..., self.d :] + + # Calculate + # $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$ + neg_half_x = self._neg_half(x_rope) + + x_rope = (x_rope * self.cos_cached[: x.shape[0]]) + (neg_half_x * self.sin_cached[: x.shape[0]]) + + return rearrange(torch.cat((x_rope, x_pass), dim=-1), "t b h d -> b h t d") + + +class MultiHeadAttention(nn.Module): + def __init__( + self, + channels, + out_channels, + n_heads, + heads_share=True, + p_dropout=0.0, + proximal_bias=False, + proximal_init=False, + ): + super().__init__() + assert channels % n_heads == 0 + + self.channels = channels + self.out_channels = out_channels + self.n_heads = n_heads + self.heads_share = heads_share + self.proximal_bias = proximal_bias + self.p_dropout = p_dropout + self.attn = None + + self.k_channels = channels // n_heads + self.conv_q = torch.nn.Conv1d(channels, channels, 1) + self.conv_k = torch.nn.Conv1d(channels, channels, 1) + self.conv_v = torch.nn.Conv1d(channels, channels, 1) + + # from https://nn.labml.ai/transformers/rope/index.html + self.query_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5) + self.key_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5) + + self.conv_o = torch.nn.Conv1d(channels, out_channels, 1) + self.drop = torch.nn.Dropout(p_dropout) + + torch.nn.init.xavier_uniform_(self.conv_q.weight) + torch.nn.init.xavier_uniform_(self.conv_k.weight) + if proximal_init: + self.conv_k.weight.data.copy_(self.conv_q.weight.data) + self.conv_k.bias.data.copy_(self.conv_q.bias.data) + torch.nn.init.xavier_uniform_(self.conv_v.weight) + + def forward(self, x, c, attn_mask=None): + q = self.conv_q(x) + k = self.conv_k(c) + v = self.conv_v(c) + + x, self.attn = self.attention(q, k, v, mask=attn_mask) + + x = self.conv_o(x) + return x + + def attention(self, query, key, value, mask=None): + b, d, t_s, t_t = (*key.size(), query.size(2)) + query = rearrange(query, "b (h c) t-> b h t c", h=self.n_heads) + key = rearrange(key, "b (h c) t-> b h t c", h=self.n_heads) + value = rearrange(value, "b (h c) t-> b h t c", h=self.n_heads) + + query = self.query_rotary_pe(query) + key = self.key_rotary_pe(key) + + scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels) + + if self.proximal_bias: + assert t_s == t_t, "Proximal bias is only available for self-attention." 
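+            # _attention_bias_proximal returns -log(1 + |i - j|), a bias that softly
+            # favors attending to nearby positions over distant ones.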
+ scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype) + if mask is not None: + scores = scores.masked_fill(mask == 0, -1e4) + p_attn = torch.nn.functional.softmax(scores, dim=-1) + p_attn = self.drop(p_attn) + output = torch.matmul(p_attn, value) + output = output.transpose(2, 3).contiguous().view(b, d, t_t) + return output, p_attn + + @staticmethod + def _attention_bias_proximal(length): + r = torch.arange(length, dtype=torch.float32) + diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) + return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) + + +class FFN(nn.Module): + def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + + self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2) + self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size, padding=kernel_size // 2) + self.drop = torch.nn.Dropout(p_dropout) + + def forward(self, x, x_mask): + x = self.conv_1(x * x_mask) + x = torch.relu(x) + x = self.drop(x) + x = self.conv_2(x * x_mask) + return x * x_mask + + +class Encoder(nn.Module): + def __init__( + self, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size=1, + p_dropout=0.0, + **kwargs, + ): + super().__init__() + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + + self.drop = torch.nn.Dropout(p_dropout) + self.attn_layers = torch.nn.ModuleList() + self.norm_layers_1 = torch.nn.ModuleList() + self.ffn_layers = torch.nn.ModuleList() + self.norm_layers_2 = torch.nn.ModuleList() + for _ in range(self.n_layers): + self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout)) + self.norm_layers_1.append(LayerNorm(hidden_channels)) + self.ffn_layers.append( + FFN( + hidden_channels, + hidden_channels, + filter_channels, + kernel_size, + p_dropout=p_dropout, + ) + ) + self.norm_layers_2.append(LayerNorm(hidden_channels)) + + def forward(self, x, x_mask): + attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + for i in range(self.n_layers): + x = x * x_mask + y = self.attn_layers[i](x, x, attn_mask) + y = self.drop(y) + x = self.norm_layers_1[i](x + y) + y = self.ffn_layers[i](x, x_mask) + y = self.drop(y) + x = self.norm_layers_2[i](x + y) + x = x * x_mask + return x + + +class TextEncoder(nn.Module): + def __init__( + self, + encoder_type, + encoder_params, + duration_predictor_params, + n_vocab, + n_spks=1, + spk_emb_dim=128, + ): + super().__init__() + self.encoder_type = encoder_type + self.n_vocab = n_vocab + self.n_feats = encoder_params.n_feats + self.n_channels = encoder_params.n_channels + self.spk_emb_dim = spk_emb_dim + self.n_spks = n_spks + + self.emb = torch.nn.Embedding(n_vocab, self.n_channels) + torch.nn.init.normal_(self.emb.weight, 0.0, self.n_channels**-0.5) + + if encoder_params.prenet: + self.prenet = ConvReluNorm( + self.n_channels, + self.n_channels, + self.n_channels, + kernel_size=5, + n_layers=3, + p_dropout=0.5, + ) + else: + self.prenet = lambda x, x_mask: x + + self.encoder = Encoder( + encoder_params.n_channels + (spk_emb_dim if n_spks > 1 else 0), + encoder_params.filter_channels, + 
encoder_params.n_heads, + encoder_params.n_layers, + encoder_params.kernel_size, + encoder_params.p_dropout, + ) + + self.proj_m = torch.nn.Conv1d(self.n_channels + (spk_emb_dim if n_spks > 1 else 0), self.n_feats, 1) + self.proj_w = DurationPredictor( + self.n_channels + (spk_emb_dim if n_spks > 1 else 0), + duration_predictor_params.filter_channels_dp, + duration_predictor_params.kernel_size, + duration_predictor_params.p_dropout, + ) + + def forward(self, x, x_lengths, spks=None): + """Run forward pass to the transformer based encoder and duration predictor + + Args: + x (torch.Tensor): text input + shape: (batch_size, max_text_length) + x_lengths (torch.Tensor): text input lengths + shape: (batch_size,) + spks (torch.Tensor, optional): speaker ids. Defaults to None. + shape: (batch_size,) + + Returns: + mu (torch.Tensor): average output of the encoder + shape: (batch_size, n_feats, max_text_length) + logw (torch.Tensor): log duration predicted by the duration predictor + shape: (batch_size, 1, max_text_length) + x_mask (torch.Tensor): mask for the text input + shape: (batch_size, 1, max_text_length) + """ + x = self.emb(x) * math.sqrt(self.n_channels) + x = torch.transpose(x, 1, -1) + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) + + x = self.prenet(x, x_mask) + if self.n_spks > 1: + x = torch.cat([x, spks.unsqueeze(-1).repeat(1, 1, x.shape[-1])], dim=1) + x = self.encoder(x, x_mask) + mu = self.proj_m(x) * x_mask + + x_dp = torch.detach(x) + logw = self.proj_w(x_dp, x_mask) + + return mu, logw, x_mask diff --git a/egs/ljspeech/TTS/matcha_tts/models/components/transformer.py b/egs/ljspeech/TTS/matcha_tts/models/components/transformer.py new file mode 100644 index 0000000000..dd1afa3aff --- /dev/null +++ b/egs/ljspeech/TTS/matcha_tts/models/components/transformer.py @@ -0,0 +1,316 @@ +from typing import Any, Dict, Optional + +import torch +import torch.nn as nn +from diffusers.models.attention import ( + GEGLU, + GELU, + AdaLayerNorm, + AdaLayerNormZero, + ApproximateGELU, +) +from diffusers.models.attention_processor import Attention +from diffusers.models.lora import LoRACompatibleLinear +from diffusers.utils.torch_utils import maybe_allow_in_graph + + +class SnakeBeta(nn.Module): + """ + A modified Snake function which uses separate parameters for the magnitude of the periodic components + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + References: + - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://arxiv.org/abs/2006.08195 + Examples: + >>> a1 = snakebeta(256) + >>> x = torch.randn(256) + >>> x = a1(x) + """ + + def __init__(self, in_features, out_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True): + """ + Initialization. + INPUT: + - in_features: shape of the input + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + alpha is initialized to 1 by default, higher values = higher-frequency. + beta is initialized to 1 by default, higher values = higher-magnitude. + alpha will be trained along with the rest of your model. 
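+        When alpha_logscale is True the parameters are stored as log-values (initialized
+        to zero) and exponentiated in forward(), which keeps the effective alpha and beta
+        positive; exp(0) = 1 reproduces the default initialization of 1.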
+ """ + super().__init__() + self.in_features = out_features if isinstance(out_features, list) else [out_features] + self.proj = LoRACompatibleLinear(in_features, out_features) + + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: # log scale alphas initialized to zeros + self.alpha = nn.Parameter(torch.zeros(self.in_features) * alpha) + self.beta = nn.Parameter(torch.zeros(self.in_features) * alpha) + else: # linear scale alphas initialized to ones + self.alpha = nn.Parameter(torch.ones(self.in_features) * alpha) + self.beta = nn.Parameter(torch.ones(self.in_features) * alpha) + + self.alpha.requires_grad = alpha_trainable + self.beta.requires_grad = alpha_trainable + + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + """ + Forward pass of the function. + Applies the function to the input elementwise. + SnakeBeta ∶= x + 1/b * sin^2 (xa) + """ + x = self.proj(x) + if self.alpha_logscale: + alpha = torch.exp(self.alpha) + beta = torch.exp(self.beta) + else: + alpha = self.alpha + beta = self.beta + + x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(torch.sin(x * alpha), 2) + + return x + + +class FeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. + """ + + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + mult: int = 4, + dropout: float = 0.0, + activation_fn: str = "geglu", + final_dropout: bool = False, + ): + super().__init__() + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + + if activation_fn == "gelu": + act_fn = GELU(dim, inner_dim) + if activation_fn == "gelu-approximate": + act_fn = GELU(dim, inner_dim, approximate="tanh") + elif activation_fn == "geglu": + act_fn = GEGLU(dim, inner_dim) + elif activation_fn == "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim) + elif activation_fn == "snakebeta": + act_fn = SnakeBeta(dim, inner_dim) + + self.net = nn.ModuleList([]) + # project in + self.net.append(act_fn) + # project dropout + self.net.append(nn.Dropout(dropout)) + # project out + self.net.append(LoRACompatibleLinear(inner_dim, dim_out)) + # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout + if final_dropout: + self.net.append(nn.Dropout(dropout)) + + def forward(self, hidden_states): + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states + + +@maybe_allow_in_graph +class BasicTransformerBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. 
In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", + final_dropout: bool = False, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + + self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + if self.use_ada_layer_norm: + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif self.use_ada_layer_norm_zero: + self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + self.norm2 = ( + AdaLayerNorm(dim, num_embeds_ada_norm) + if self.use_ada_layer_norm + else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + ) + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + # scale_qk=False, # uncomment this to not to use flash attention + ) # is self-attn if encoder_hidden_states is none + else: + self.norm2 = None + self.attn2 = None + + # 3. 
Feed-forward + self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + ): + # Notice that normalization is always applied before the real computation in the following blocks. + # 1. Self-Attention + if self.use_ada_layer_norm: + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.use_ada_layer_norm_zero: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + else: + norm_hidden_states = self.norm1(hidden_states) + + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=encoder_attention_mask if self.only_cross_attention else attention_mask, + **cross_attention_kwargs, + ) + if self.use_ada_layer_norm_zero: + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = attn_output + hidden_states + + # 2. Cross-Attention + if self.attn2 is not None: + norm_hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + # 3. Feed-forward + norm_hidden_states = self.norm3(hidden_states) + + if self.use_ada_layer_norm_zero: + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0: + raise ValueError( + f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." 
+ ) + + num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size + ff_output = torch.cat( + [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)], + dim=self._chunk_dim, + ) + else: + ff_output = self.ff(norm_hidden_states) + + if self.use_ada_layer_norm_zero: + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = ff_output + hidden_states + + return hidden_states diff --git a/egs/ljspeech/TTS/matcha_tts/models/matcha_tts.py b/egs/ljspeech/TTS/matcha_tts/models/matcha_tts.py new file mode 100644 index 0000000000..07f95ad2e3 --- /dev/null +++ b/egs/ljspeech/TTS/matcha_tts/models/matcha_tts.py @@ -0,0 +1,244 @@ +import datetime as dt +import math +import random + +import torch + +import matcha.utils.monotonic_align as monotonic_align +from matcha import utils +from matcha.models.baselightningmodule import BaseLightningClass +from matcha.models.components.flow_matching import CFM +from matcha.models.components.text_encoder import TextEncoder +from matcha.utils.model import ( + denormalize, + duration_loss, + fix_len_compatibility, + generate_path, + sequence_mask, +) + +log = utils.get_pylogger(__name__) + + +class MatchaTTS(BaseLightningClass): # 🍵 + def __init__( + self, + n_vocab, + n_spks, + spk_emb_dim, + n_feats, + encoder, + decoder, + cfm, + data_statistics, + out_size, + optimizer=None, + scheduler=None, + prior_loss=True, + use_precomputed_durations=False, + ): + super().__init__() + + self.save_hyperparameters(logger=False) + + self.n_vocab = n_vocab + self.n_spks = n_spks + self.spk_emb_dim = spk_emb_dim + self.n_feats = n_feats + self.out_size = out_size + self.prior_loss = prior_loss + self.use_precomputed_durations = use_precomputed_durations + + if n_spks > 1: + self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim) + + self.encoder = TextEncoder( + encoder.encoder_type, + encoder.encoder_params, + encoder.duration_predictor_params, + n_vocab, + n_spks, + spk_emb_dim, + ) + + self.decoder = CFM( + in_channels=2 * encoder.encoder_params.n_feats, + out_channel=encoder.encoder_params.n_feats, + cfm_params=cfm, + decoder_params=decoder, + n_spks=n_spks, + spk_emb_dim=spk_emb_dim, + ) + + self.update_data_statistics(data_statistics) + + @torch.inference_mode() + def synthesise(self, x, x_lengths, n_timesteps, temperature=1.0, spks=None, length_scale=1.0): + """ + Generates mel-spectrogram from text. Returns: + 1. encoder outputs + 2. decoder outputs + 3. generated alignment + + Args: + x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids. + shape: (batch_size, max_text_length) + x_lengths (torch.Tensor): lengths of texts in batch. + shape: (batch_size,) + n_timesteps (int): number of steps to use for reverse diffusion in decoder. + temperature (float, optional): controls variance of terminal distribution. + spks (bool, optional): speaker ids. + shape: (batch_size,) + length_scale (float, optional): controls speech pace. + Increase value to slow down generated speech and vice versa. 
+ + Returns: + dict: { + "encoder_outputs": torch.Tensor, shape: (batch_size, n_feats, max_mel_length), + # Average mel spectrogram generated by the encoder + "decoder_outputs": torch.Tensor, shape: (batch_size, n_feats, max_mel_length), + # Refined mel spectrogram improved by the CFM + "attn": torch.Tensor, shape: (batch_size, max_text_length, max_mel_length), + # Alignment map between text and mel spectrogram + "mel": torch.Tensor, shape: (batch_size, n_feats, max_mel_length), + # Denormalized mel spectrogram + "mel_lengths": torch.Tensor, shape: (batch_size,), + # Lengths of mel spectrograms + "rtf": float, + # Real-time factor + """ + # For RTF computation + t = dt.datetime.now() + + if self.n_spks > 1: + # Get speaker embedding + spks = self.spk_emb(spks.long()) + + # Get encoder_outputs `mu_x` and log-scaled token durations `logw` + mu_x, logw, x_mask = self.encoder(x, x_lengths, spks) + + w = torch.exp(logw) * x_mask + w_ceil = torch.ceil(w) * length_scale + y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() + y_max_length = y_lengths.max() + y_max_length_ = fix_len_compatibility(y_max_length) + + # Using obtained durations `w` construct alignment map `attn` + y_mask = sequence_mask(y_lengths, y_max_length_).unsqueeze(1).to(x_mask.dtype) + attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2) + attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1) + + # Align encoded text and get mu_y + mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2)) + mu_y = mu_y.transpose(1, 2) + encoder_outputs = mu_y[:, :, :y_max_length] + + # Generate sample tracing the probability flow + decoder_outputs = self.decoder(mu_y, y_mask, n_timesteps, temperature, spks) + decoder_outputs = decoder_outputs[:, :, :y_max_length] + + t = (dt.datetime.now() - t).total_seconds() + rtf = t * 22050 / (decoder_outputs.shape[-1] * 256) + + return { + "encoder_outputs": encoder_outputs, + "decoder_outputs": decoder_outputs, + "attn": attn[:, :, :y_max_length], + "mel": denormalize(decoder_outputs, self.mel_mean, self.mel_std), + "mel_lengths": y_lengths, + "rtf": rtf, + } + + def forward(self, x, x_lengths, y, y_lengths, spks=None, out_size=None, cond=None, durations=None): + """ + Computes 3 losses: + 1. duration loss: loss between predicted token durations and those extracted by Monotinic Alignment Search (MAS). + 2. prior loss: loss between mel-spectrogram and encoder outputs. + 3. flow matching loss: loss between mel-spectrogram and decoder outputs. + + Args: + x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids. + shape: (batch_size, max_text_length) + x_lengths (torch.Tensor): lengths of texts in batch. + shape: (batch_size,) + y (torch.Tensor): batch of corresponding mel-spectrograms. + shape: (batch_size, n_feats, max_mel_length) + y_lengths (torch.Tensor): lengths of mel-spectrograms in batch. + shape: (batch_size,) + out_size (int, optional): length (in mel's sampling rate) of segment to cut, on which decoder will be trained. + Should be divisible by 2^{num of UNet downsamplings}. Needed to increase batch size. + spks (torch.Tensor, optional): speaker ids. 
+ shape: (batch_size,) + """ + if self.n_spks > 1: + # Get speaker embedding + spks = self.spk_emb(spks) + + # Get encoder_outputs `mu_x` and log-scaled token durations `logw` + mu_x, logw, x_mask = self.encoder(x, x_lengths, spks) + y_max_length = y.shape[-1] + + y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask) + attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2) + + if self.use_precomputed_durations: + attn = generate_path(durations.squeeze(1), attn_mask.squeeze(1)) + else: + # Use MAS to find most likely alignment `attn` between text and mel-spectrogram + with torch.no_grad(): + const = -0.5 * math.log(2 * math.pi) * self.n_feats + factor = -0.5 * torch.ones(mu_x.shape, dtype=mu_x.dtype, device=mu_x.device) + y_square = torch.matmul(factor.transpose(1, 2), y**2) + y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y) + mu_square = torch.sum(factor * (mu_x**2), 1).unsqueeze(-1) + log_prior = y_square - y_mu_double + mu_square + const + + attn = monotonic_align.maximum_path(log_prior, attn_mask.squeeze(1)) + attn = attn.detach() # b, t_text, T_mel + + # Compute loss between predicted log-scaled durations and those obtained from MAS + # refered to as prior loss in the paper + logw_ = torch.log(1e-8 + torch.sum(attn.unsqueeze(1), -1)) * x_mask + dur_loss = duration_loss(logw, logw_, x_lengths) + + # Cut a small segment of mel-spectrogram in order to increase batch size + # - "Hack" taken from Grad-TTS, in case of Grad-TTS, we cannot train batch size 32 on a 24GB GPU without it + # - Do not need this hack for Matcha-TTS, but it works with it as well + if not isinstance(out_size, type(None)): + max_offset = (y_lengths - out_size).clamp(0) + offset_ranges = list(zip([0] * max_offset.shape[0], max_offset.cpu().numpy())) + out_offset = torch.LongTensor( + [torch.tensor(random.choice(range(start, end)) if end > start else 0) for start, end in offset_ranges] + ).to(y_lengths) + attn_cut = torch.zeros(attn.shape[0], attn.shape[1], out_size, dtype=attn.dtype, device=attn.device) + y_cut = torch.zeros(y.shape[0], self.n_feats, out_size, dtype=y.dtype, device=y.device) + + y_cut_lengths = [] + for i, (y_, out_offset_) in enumerate(zip(y, out_offset)): + y_cut_length = out_size + (y_lengths[i] - out_size).clamp(None, 0) + y_cut_lengths.append(y_cut_length) + cut_lower, cut_upper = out_offset_, out_offset_ + y_cut_length + y_cut[i, :, :y_cut_length] = y_[:, cut_lower:cut_upper] + attn_cut[i, :, :y_cut_length] = attn[i, :, cut_lower:cut_upper] + + y_cut_lengths = torch.LongTensor(y_cut_lengths) + y_cut_mask = sequence_mask(y_cut_lengths).unsqueeze(1).to(y_mask) + + attn = attn_cut + y = y_cut + y_mask = y_cut_mask + + # Align encoded text with mel-spectrogram and get mu_y segment + mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2)) + mu_y = mu_y.transpose(1, 2) + + # Compute loss of the decoder + diff_loss, _ = self.decoder.compute_loss(x1=y, mask=y_mask, mu=mu_y, spks=spks, cond=cond) + + if self.prior_loss: + prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask) + prior_loss = prior_loss / (torch.sum(y_mask) * self.n_feats) + else: + prior_loss = 0 + + return dur_loss, prior_loss, diff_loss, attn diff --git a/egs/ljspeech/TTS/matcha_tts/text/__init__.py b/egs/ljspeech/TTS/matcha_tts/text/__init__.py new file mode 100644 index 0000000000..dc3427f0b4 --- /dev/null +++ b/egs/ljspeech/TTS/matcha_tts/text/__init__.py @@ -0,0 +1,60 @@ +""" from https://github.com/keithito/tacotron """ +from matcha.text import 
cleaners +from matcha.text.symbols import symbols + +# Mappings from symbol to numeric ID and vice versa: +_symbol_to_id = {s: i for i, s in enumerate(symbols)} +_id_to_symbol = {i: s for i, s in enumerate(symbols)} # pylint: disable=unnecessary-comprehension + + +def text_to_sequence(text, cleaner_names): + """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. + Args: + text: string to convert to a sequence + cleaner_names: names of the cleaner functions to run the text through + Returns: + List of integers corresponding to the symbols in the text + """ + sequence = [] + + clean_text = _clean_text(text, cleaner_names) + for symbol in clean_text: + try: + if symbol in '_()[]# ̃': + continue + symbol_id = _symbol_to_id[symbol] + except Exception as ex: + print(text) + print(clean_text) + raise RuntimeError(f'text: {text}, clean_text: {clean_text}, ex: {ex}, symbol: {symbol}') + sequence += [symbol_id] + return sequence, clean_text + + +def cleaned_text_to_sequence(cleaned_text): + """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. + Args: + text: string to convert to a sequence + Returns: + List of integers corresponding to the symbols in the text + """ + sequence = [_symbol_to_id[symbol] for symbol in cleaned_text] + return sequence + + +def sequence_to_text(sequence): + """Converts a sequence of IDs back to a string""" + result = "" + for symbol_id in sequence: + s = _id_to_symbol[symbol_id] + result += s + return result + + +def _clean_text(text, cleaner_names): + for name in cleaner_names: + cleaner = getattr(cleaners, name) + if not cleaner: + raise Exception("Unknown cleaner: %s" % name) + text = cleaner(text) + return text diff --git a/egs/ljspeech/TTS/matcha_tts/text/cleaners.py b/egs/ljspeech/TTS/matcha_tts/text/cleaners.py new file mode 100644 index 0000000000..33cdc9fc61 --- /dev/null +++ b/egs/ljspeech/TTS/matcha_tts/text/cleaners.py @@ -0,0 +1,129 @@ +""" from https://github.com/keithito/tacotron + +Cleaners are transformations that run over the input text at both training and eval time. + +Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" +hyperparameter. Some cleaners are English-specific. You'll typically want to use: + 1. "english_cleaners" for English text + 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using + the Unidecode library (https://pypi.python.org/pypi/Unidecode) + 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update + the symbols in symbols.py to match your data). +""" + +import logging +import re + +import phonemizer +from unidecode import unidecode + +# To avoid excessive logging we set the log level of the phonemizer package to Critical +critical_logger = logging.getLogger("phonemizer") +critical_logger.setLevel(logging.CRITICAL) + +# Intializing the phonemizer globally significantly reduces the speed +# now the phonemizer is not initialising at every call +# Might be less flexible, but it is much-much faster +global_phonemizer = phonemizer.backend.EspeakBackend( + language="en-us", + preserve_punctuation=True, + with_stress=True, + language_switch="remove-flags", + logger=critical_logger, +) + + +# Regular expression matching whitespace: +_whitespace_re = re.compile(r"\s+") + +# List of (regular expression, replacement) pairs for abbreviations: +_abbreviations = [ + (re.compile("\\b%s\\." 
% x[0], re.IGNORECASE), x[1]) + for x in [ + ("mrs", "misess"), + ("mr", "mister"), + ("dr", "doctor"), + ("st", "saint"), + ("co", "company"), + ("jr", "junior"), + ("maj", "major"), + ("gen", "general"), + ("drs", "doctors"), + ("rev", "reverend"), + ("lt", "lieutenant"), + ("hon", "honorable"), + ("sgt", "sergeant"), + ("capt", "captain"), + ("esq", "esquire"), + ("ltd", "limited"), + ("col", "colonel"), + ("ft", "fort"), + ] +] + + +def expand_abbreviations(text): + for regex, replacement in _abbreviations: + text = re.sub(regex, replacement, text) + return text + + +def lowercase(text): + return text.lower() + + +def collapse_whitespace(text): + return re.sub(_whitespace_re, " ", text) + +def remove_parentheses(text): + text = text.replace("(", "") + text = text.replace(")", "") + text = text.replace("[", "") + text = text.replace("]", "") + return text + + +def convert_to_ascii(text): + return unidecode(text) + + +def basic_cleaners(text): + """Basic pipeline that lowercases and collapses whitespace without transliteration.""" + text = lowercase(text) + text = collapse_whitespace(text) + return text + + +def transliteration_cleaners(text): + """Pipeline for non-English text that transliterates to ASCII.""" + text = convert_to_ascii(text) + text = lowercase(text) + text = collapse_whitespace(text) + return text + + +def english_cleaners2(text): + """Pipeline for English text, including abbreviation expansion. + punctuation + stress""" + text = convert_to_ascii(text) + text = lowercase(text) + text = expand_abbreviations(text) + text = remove_parentheses(text) + phonemes = global_phonemizer.phonemize([text], strip=True, njobs=1)[0] + phonemes = collapse_whitespace(phonemes) + return phonemes + + +# I am removing this due to incompatibility with several version of python +# However, if you want to use it, you can uncomment it +# and install piper-phonemize with the following command: +# pip install piper-phonemize + +# import piper_phonemize +# def english_cleaners_piper(text): +# """Pipeline for English text, including abbreviation expansion. 
+ punctuation + stress""" +# text = convert_to_ascii(text) +# text = lowercase(text) +# text = expand_abbreviations(text) +# phonemes = "".join(piper_phonemize.phonemize_espeak(text=text, voice="en-US")[0]) +# phonemes = collapse_whitespace(phonemes) +# return phonemes diff --git a/egs/ljspeech/TTS/matcha_tts/text/numbers.py b/egs/ljspeech/TTS/matcha_tts/text/numbers.py new file mode 100644 index 0000000000..f99a8686dc --- /dev/null +++ b/egs/ljspeech/TTS/matcha_tts/text/numbers.py @@ -0,0 +1,71 @@ +""" from https://github.com/keithito/tacotron """ + +import re + +import inflect + +_inflect = inflect.engine() +_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])") +_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)") +_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)") +_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)") +_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)") +_number_re = re.compile(r"[0-9]+") + + +def _remove_commas(m): + return m.group(1).replace(",", "") + + +def _expand_decimal_point(m): + return m.group(1).replace(".", " point ") + + +def _expand_dollars(m): + match = m.group(1) + parts = match.split(".") + if len(parts) > 2: + return match + " dollars" + dollars = int(parts[0]) if parts[0] else 0 + cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 + if dollars and cents: + dollar_unit = "dollar" if dollars == 1 else "dollars" + cent_unit = "cent" if cents == 1 else "cents" + return f"{dollars} {dollar_unit}, {cents} {cent_unit}" + elif dollars: + dollar_unit = "dollar" if dollars == 1 else "dollars" + return f"{dollars} {dollar_unit}" + elif cents: + cent_unit = "cent" if cents == 1 else "cents" + return f"{cents} {cent_unit}" + else: + return "zero dollars" + + +def _expand_ordinal(m): + return _inflect.number_to_words(m.group(0)) + + +def _expand_number(m): + num = int(m.group(0)) + if num > 1000 and num < 3000: + if num == 2000: + return "two thousand" + elif num > 2000 and num < 2010: + return "two thousand " + _inflect.number_to_words(num % 100) + elif num % 100 == 0: + return _inflect.number_to_words(num // 100) + " hundred" + else: + return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ") + else: + return _inflect.number_to_words(num, andword="") + + +def normalize_numbers(text): + text = re.sub(_comma_number_re, _remove_commas, text) + text = re.sub(_pounds_re, r"\1 pounds", text) + text = re.sub(_dollars_re, _expand_dollars, text) + text = re.sub(_decimal_number_re, _expand_decimal_point, text) + text = re.sub(_ordinal_re, _expand_ordinal, text) + text = re.sub(_number_re, _expand_number, text) + return text diff --git a/egs/ljspeech/TTS/matcha_tts/text/symbols.py b/egs/ljspeech/TTS/matcha_tts/text/symbols.py new file mode 100644 index 0000000000..7018df549a --- /dev/null +++ b/egs/ljspeech/TTS/matcha_tts/text/symbols.py @@ -0,0 +1,17 @@ +""" from https://github.com/keithito/tacotron + +Defines the set of symbols used in text input to the model. 
+""" +_pad = "_" +_punctuation = ';:,.!?¡¿—…"«»“” ' +_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" +_letters_ipa = ( + "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ" +) + + +# Export all symbols: +symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa) + +# Special symbol ids +SPACE_ID = symbols.index(" ") diff --git a/egs/ljspeech/TTS/matcha_tts/train.py b/egs/ljspeech/TTS/matcha_tts/train.py new file mode 100644 index 0000000000..d1d64c6c44 --- /dev/null +++ b/egs/ljspeech/TTS/matcha_tts/train.py @@ -0,0 +1,122 @@ +from typing import Any, Dict, List, Optional, Tuple + +import hydra +import lightning as L +import rootutils +from lightning import Callback, LightningDataModule, LightningModule, Trainer +from lightning.pytorch.loggers import Logger +from omegaconf import DictConfig + +from matcha import utils + +rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) +# ------------------------------------------------------------------------------------ # +# the setup_root above is equivalent to: +# - adding project root dir to PYTHONPATH +# (so you don't need to force user to install project as a package) +# (necessary before importing any local modules e.g. `from src import utils`) +# - setting up PROJECT_ROOT environment variable +# (which is used as a base for paths in "configs/paths/default.yaml") +# (this way all filepaths are the same no matter where you run the code) +# - loading environment variables from ".env" in root dir +# +# you can remove it if you: +# 1. either install project as a package or move entry files to project root dir +# 2. set `root_dir` to "." in "configs/paths/default.yaml" +# +# more info: https://github.com/ashleve/rootutils +# ------------------------------------------------------------------------------------ # + + +log = utils.get_pylogger(__name__) + + +@utils.task_wrapper +def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: + """Trains the model. Can additionally evaluate on a testset, using best weights obtained during + training. + + This method is wrapped in optional @task_wrapper decorator, that controls the behavior during + failure. Useful for multiruns, saving info about the crash, etc. + + :param cfg: A DictConfig configuration composed by Hydra. + :return: A tuple with metrics and dict with all instantiated objects. 
+ """ + # set seed for random number generators in pytorch, numpy and python.random + if cfg.get("seed"): + L.seed_everything(cfg.seed, workers=True) + + log.info(f"Instantiating datamodule <{cfg.data._target_}>") # pylint: disable=protected-access + datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data) + + log.info(f"Instantiating model <{cfg.model._target_}>") # pylint: disable=protected-access + model: LightningModule = hydra.utils.instantiate(cfg.model) + + log.info("Instantiating callbacks...") + callbacks: List[Callback] = utils.instantiate_callbacks(cfg.get("callbacks")) + + log.info("Instantiating loggers...") + logger: List[Logger] = utils.instantiate_loggers(cfg.get("logger")) + + log.info(f"Instantiating trainer <{cfg.trainer._target_}>") # pylint: disable=protected-access + trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger) + + object_dict = { + "cfg": cfg, + "datamodule": datamodule, + "model": model, + "callbacks": callbacks, + "logger": logger, + "trainer": trainer, + } + + if logger: + log.info("Logging hyperparameters!") + utils.log_hyperparameters(object_dict) + + if cfg.get("train"): + log.info("Starting training!") + trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path")) + + train_metrics = trainer.callback_metrics + + if cfg.get("test"): + log.info("Starting testing!") + ckpt_path = trainer.checkpoint_callback.best_model_path + if ckpt_path == "": + log.warning("Best ckpt not found! Using current weights for testing...") + ckpt_path = None + trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path) + log.info(f"Best ckpt path: {ckpt_path}") + + test_metrics = trainer.callback_metrics + + # merge train and test metrics + metric_dict = {**train_metrics, **test_metrics} + + return metric_dict, object_dict + + +@hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml") +def main(cfg: DictConfig) -> Optional[float]: + """Main entry point for training. + + :param cfg: DictConfig configuration composed by Hydra. + :return: Optional[float] with optimized metric value. + """ + # apply extra utilities + # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.) 
+ utils.extras(cfg) + + # train the model + metric_dict, _ = train(cfg) + + # safely retrieve metric value for hydra-based hyperparameter optimization + metric_value = utils.get_metric_value(metric_dict=metric_dict, metric_name=cfg.get("optimized_metric")) + + # return optimized metric + return metric_value + + +if __name__ == "__main__": + main() # pylint: disable=no-value-for-parameter diff --git a/egs/ljspeech/TTS/matcha_tts/utils/__init__.py b/egs/ljspeech/TTS/matcha_tts/utils/__init__.py new file mode 100644 index 0000000000..074db64611 --- /dev/null +++ b/egs/ljspeech/TTS/matcha_tts/utils/__init__.py @@ -0,0 +1,5 @@ +from matcha.utils.instantiators import instantiate_callbacks, instantiate_loggers +from matcha.utils.logging_utils import log_hyperparameters +from matcha.utils.pylogger import get_pylogger +from matcha.utils.rich_utils import enforce_tags, print_config_tree +from matcha.utils.utils import extras, get_metric_value, task_wrapper diff --git a/egs/ljspeech/TTS/matcha_tts/utils/audio.py b/egs/ljspeech/TTS/matcha_tts/utils/audio.py new file mode 100644 index 0000000000..0bcd74df47 --- /dev/null +++ b/egs/ljspeech/TTS/matcha_tts/utils/audio.py @@ -0,0 +1,82 @@ +import numpy as np +import torch +import torch.utils.data +from librosa.filters import mel as librosa_mel_fn +from scipy.io.wavfile import read + +MAX_WAV_VALUE = 32768.0 + + +def load_wav(full_path): + sampling_rate, data = read(full_path) + return data, sampling_rate + + +def dynamic_range_compression(x, C=1, clip_val=1e-5): + return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) + + +def dynamic_range_decompression(x, C=1): + return np.exp(x) / C + + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression_torch(x, C=1): + return torch.exp(x) / C + + +def spectral_normalize_torch(magnitudes): + output = dynamic_range_compression_torch(magnitudes) + return output + + +def spectral_de_normalize_torch(magnitudes): + output = dynamic_range_decompression_torch(magnitudes) + return output + + +mel_basis = {} +hann_window = {} + + +def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): + if torch.min(y) < -1.0: + print("min value is ", torch.min(y)) + if torch.max(y) > 1.0: + print("max value is ", torch.max(y)) + + global mel_basis, hann_window # pylint: disable=global-statement + if f"{str(fmax)}_{str(y.device)}" not in mel_basis: + mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) + mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device) + hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect" + ) + y = y.squeeze(1) + + spec = torch.view_as_real( + torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window[str(y.device)], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) + ) + + spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) + + spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec) + spec = spectral_normalize_torch(spec) + + return spec diff --git a/egs/ljspeech/TTS/matcha_tts/utils/generate_data_statistics.py b/egs/ljspeech/TTS/matcha_tts/utils/generate_data_statistics.py new file mode 100644 index 0000000000..3b8cd67c91 --- /dev/null +++ 
b/egs/ljspeech/TTS/matcha_tts/utils/generate_data_statistics.py @@ -0,0 +1,113 @@ +r""" +The file creates a pickle file where the values needed for loading of dataset is stored and the model can load it +when needed. + +Parameters from hparam.py will be used +""" +import argparse +import json +import os +import sys +from pathlib import Path + +import rootutils +import torch +from hydra import compose, initialize +from omegaconf import open_dict +from tqdm.auto import tqdm + +from matcha.data.text_mel_datamodule import TextMelDataModule +from matcha.utils.logging_utils import pylogger + +log = pylogger.get_pylogger(__name__) + + +def compute_data_statistics(data_loader: torch.utils.data.DataLoader, out_channels: int): + """Generate data mean and standard deviation helpful in data normalisation + + Args: + data_loader (torch.utils.data.Dataloader): _description_ + out_channels (int): mel spectrogram channels + """ + total_mel_sum = 0 + total_mel_sq_sum = 0 + total_mel_len = 0 + + for batch in tqdm(data_loader, leave=False): + mels = batch["y"] + mel_lengths = batch["y_lengths"] + + total_mel_len += torch.sum(mel_lengths) + total_mel_sum += torch.sum(mels) + total_mel_sq_sum += torch.sum(torch.pow(mels, 2)) + + data_mean = total_mel_sum / (total_mel_len * out_channels) + data_std = torch.sqrt((total_mel_sq_sum / (total_mel_len * out_channels)) - torch.pow(data_mean, 2)) + + return {"mel_mean": data_mean.item(), "mel_std": data_std.item()} + + +def main(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "-i", + "--input-config", + type=str, + default="vctk.yaml", + help="The name of the yaml config file under configs/data", + ) + + parser.add_argument( + "-b", + "--batch-size", + type=int, + default="256", + help="Can have increased batch size for faster computation", + ) + + parser.add_argument( + "-f", + "--force", + action="store_true", + default=False, + required=False, + help="force overwrite the file", + ) + args = parser.parse_args() + output_file = Path(args.input_config).with_suffix(".json") + + if os.path.exists(output_file) and not args.force: + print("File already exists. Use -f to force overwrite") + sys.exit(1) + + with initialize(version_base="1.3", config_path="../../configs/data"): + cfg = compose(config_name=args.input_config, return_hydra_config=True, overrides=[]) + + root_path = rootutils.find_root(search_from=__file__, indicator=".project-root") + + with open_dict(cfg): + print(cfg) + del cfg["hydra"] + del cfg["_target_"] + cfg["data_statistics"] = None + cfg["seed"] = 1234 + cfg["batch_size"] = args.batch_size + cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"])) + cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"])) + cfg["load_durations"] = False + + text_mel_datamodule = TextMelDataModule(**cfg) + text_mel_datamodule.setup() + data_loader = text_mel_datamodule.train_dataloader() + log.info("Dataloader loaded! 
Now computing stats...") + params = compute_data_statistics(data_loader, cfg["n_feats"]) + print(params) + json.dump( + params, + open(output_file, "w"), + ) + + +if __name__ == "__main__": + main() diff --git a/egs/ljspeech/TTS/matcha_tts/utils/get_durations_from_trained_model.py b/egs/ljspeech/TTS/matcha_tts/utils/get_durations_from_trained_model.py new file mode 100644 index 0000000000..0fe2f35c42 --- /dev/null +++ b/egs/ljspeech/TTS/matcha_tts/utils/get_durations_from_trained_model.py @@ -0,0 +1,195 @@ +r""" +The file creates a pickle file where the values needed for loading of dataset is stored and the model can load it +when needed. + +Parameters from hparam.py will be used +""" +import argparse +import json +import os +import sys +from pathlib import Path + +import lightning +import numpy as np +import rootutils +import torch +from hydra import compose, initialize +from omegaconf import open_dict +from torch import nn +from tqdm.auto import tqdm + +from matcha.cli import get_device +from matcha.data.text_mel_datamodule import TextMelDataModule +from matcha.models.matcha_tts import MatchaTTS +from matcha.utils.logging_utils import pylogger +from matcha.utils.utils import get_phoneme_durations + +log = pylogger.get_pylogger(__name__) + + +def save_durations_to_folder( + attn: torch.Tensor, x_length: int, y_length: int, filepath: str, output_folder: Path, text: str +): + durations = attn.squeeze().sum(1)[:x_length].numpy() + durations_json = get_phoneme_durations(durations, text) + output = output_folder / Path(filepath).name.replace(".wav", ".npy") + with open(output.with_suffix(".json"), "w", encoding="utf-8") as f: + json.dump(durations_json, f, indent=4, ensure_ascii=False) + + np.save(output, durations) + + +@torch.inference_mode() +def compute_durations(data_loader: torch.utils.data.DataLoader, model: nn.Module, device: torch.device, output_folder): + """Generate durations from the model for each datapoint and save it in a folder + + Args: + data_loader (torch.utils.data.DataLoader): Dataloader + model (nn.Module): MatchaTTS model + device (torch.device): GPU or CPU + """ + + for batch in tqdm(data_loader, desc="🍵 Computing durations 🍵:"): + x, x_lengths = batch["x"], batch["x_lengths"] + y, y_lengths = batch["y"], batch["y_lengths"] + spks = batch["spks"] + x = x.to(device) + y = y.to(device) + x_lengths = x_lengths.to(device) + y_lengths = y_lengths.to(device) + spks = spks.to(device) if spks is not None else None + + _, _, _, attn = model( + x=x, + x_lengths=x_lengths, + y=y, + y_lengths=y_lengths, + spks=spks, + ) + attn = attn.cpu() + for i in range(attn.shape[0]): + save_durations_to_folder( + attn[i], + x_lengths[i].item(), + y_lengths[i].item(), + batch["filepaths"][i], + output_folder, + batch["x_texts"][i], + ) + + +def main(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "-i", + "--input-config", + type=str, + default="ljspeech.yaml", + help="The name of the yaml config file under configs/data", + ) + + parser.add_argument( + "-b", + "--batch-size", + type=int, + default="32", + help="Can have increased batch size for faster computation", + ) + + parser.add_argument( + "-f", + "--force", + action="store_true", + default=False, + required=False, + help="force overwrite the file", + ) + parser.add_argument( + "-c", + "--checkpoint_path", + type=str, + required=True, + help="Path to the checkpoint file to load the model from", + ) + + parser.add_argument( + "-o", + "--output-folder", + type=str, + default=None, + help="Output folder to save the data 
statistics", + ) + + parser.add_argument( + "--cpu", action="store_true", help="Use CPU for inference, not recommended (default: use GPU if available)" + ) + + args = parser.parse_args() + + with initialize(version_base="1.3", config_path="../../configs/data"): + cfg = compose(config_name=args.input_config, return_hydra_config=True, overrides=[]) + + root_path = rootutils.find_root(search_from=__file__, indicator=".project-root") + + with open_dict(cfg): + del cfg["hydra"] + del cfg["_target_"] + cfg["seed"] = 1234 + cfg["batch_size"] = args.batch_size + cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"])) + cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"])) + cfg["load_durations"] = False + + if args.output_folder is not None: + output_folder = Path(args.output_folder) + else: + output_folder = Path(cfg["train_filelist_path"]).parent / "durations" + + print(f"Output folder set to: {output_folder}") + + if os.path.exists(output_folder) and not args.force: + print("Folder already exists. Use -f to force overwrite") + sys.exit(1) + + output_folder.mkdir(parents=True, exist_ok=True) + + print(f"Preprocessing: {cfg['name']} from training filelist: {cfg['train_filelist_path']}") + print("Loading model...") + device = get_device(args) + model = MatchaTTS.load_from_checkpoint(args.checkpoint_path, map_location=device) + + text_mel_datamodule = TextMelDataModule(**cfg) + text_mel_datamodule.setup() + try: + print("Computing stats for training set if exists...") + train_dataloader = text_mel_datamodule.train_dataloader() + compute_durations(train_dataloader, model, device, output_folder) + except lightning.fabric.utilities.exceptions.MisconfigurationException: + print("No training set found") + + try: + print("Computing stats for validation set if exists...") + val_dataloader = text_mel_datamodule.val_dataloader() + compute_durations(val_dataloader, model, device, output_folder) + except lightning.fabric.utilities.exceptions.MisconfigurationException: + print("No validation set found") + + try: + print("Computing stats for test set if exists...") + test_dataloader = text_mel_datamodule.test_dataloader() + compute_durations(test_dataloader, model, device, output_folder) + except lightning.fabric.utilities.exceptions.MisconfigurationException: + print("No test set found") + + print(f"[+] Done! Data statistics saved to: {output_folder}") + + +if __name__ == "__main__": + # Helps with generating durations for the dataset to train other architectures + # that cannot learn to align due to limited size of dataset + # Example usage: + # python python matcha/utils/get_durations_from_trained_model.py -i ljspeech.yaml -c pretrained_model + # This will create a folder in data/processed_data/durations/ljspeech with the durations + main() diff --git a/egs/ljspeech/TTS/matcha_tts/utils/instantiators.py b/egs/ljspeech/TTS/matcha_tts/utils/instantiators.py new file mode 100644 index 0000000000..5547b4ed61 --- /dev/null +++ b/egs/ljspeech/TTS/matcha_tts/utils/instantiators.py @@ -0,0 +1,56 @@ +from typing import List + +import hydra +from lightning import Callback +from lightning.pytorch.loggers import Logger +from omegaconf import DictConfig + +from matcha.utils import pylogger + +log = pylogger.get_pylogger(__name__) + + +def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]: + """Instantiates callbacks from config. + + :param callbacks_cfg: A DictConfig object containing callback configurations. 
+ :return: A list of instantiated callbacks. + """ + callbacks: List[Callback] = [] + + if not callbacks_cfg: + log.warning("No callback configs found! Skipping..") + return callbacks + + if not isinstance(callbacks_cfg, DictConfig): + raise TypeError("Callbacks config must be a DictConfig!") + + for _, cb_conf in callbacks_cfg.items(): + if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf: + log.info(f"Instantiating callback <{cb_conf._target_}>") # pylint: disable=protected-access + callbacks.append(hydra.utils.instantiate(cb_conf)) + + return callbacks + + +def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]: + """Instantiates loggers from config. + + :param logger_cfg: A DictConfig object containing logger configurations. + :return: A list of instantiated loggers. + """ + logger: List[Logger] = [] + + if not logger_cfg: + log.warning("No logger configs found! Skipping...") + return logger + + if not isinstance(logger_cfg, DictConfig): + raise TypeError("Logger config must be a DictConfig!") + + for _, lg_conf in logger_cfg.items(): + if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf: + log.info(f"Instantiating logger <{lg_conf._target_}>") # pylint: disable=protected-access + logger.append(hydra.utils.instantiate(lg_conf)) + + return logger diff --git a/egs/ljspeech/TTS/matcha_tts/utils/logging_utils.py b/egs/ljspeech/TTS/matcha_tts/utils/logging_utils.py new file mode 100644 index 0000000000..1a12d1ddaf --- /dev/null +++ b/egs/ljspeech/TTS/matcha_tts/utils/logging_utils.py @@ -0,0 +1,53 @@ +from typing import Any, Dict + +from lightning.pytorch.utilities import rank_zero_only +from omegaconf import OmegaConf + +from matcha.utils import pylogger + +log = pylogger.get_pylogger(__name__) + + +@rank_zero_only +def log_hyperparameters(object_dict: Dict[str, Any]) -> None: + """Controls which config parts are saved by Lightning loggers. + + Additionally saves: + - Number of model parameters + + :param object_dict: A dictionary containing the following objects: + - `"cfg"`: A DictConfig object containing the main config. + - `"model"`: The Lightning model. + - `"trainer"`: The Lightning trainer. + """ + hparams = {} + + cfg = OmegaConf.to_container(object_dict["cfg"]) + model = object_dict["model"] + trainer = object_dict["trainer"] + + if not trainer.logger: + log.warning("Logger not found! 
Skipping hyperparameter logging...") + return + + hparams["model"] = cfg["model"] + + # save number of model parameters + hparams["model/params/total"] = sum(p.numel() for p in model.parameters()) + hparams["model/params/trainable"] = sum(p.numel() for p in model.parameters() if p.requires_grad) + hparams["model/params/non_trainable"] = sum(p.numel() for p in model.parameters() if not p.requires_grad) + + hparams["data"] = cfg["data"] + hparams["trainer"] = cfg["trainer"] + + hparams["callbacks"] = cfg.get("callbacks") + hparams["extras"] = cfg.get("extras") + + hparams["task_name"] = cfg.get("task_name") + hparams["tags"] = cfg.get("tags") + hparams["ckpt_path"] = cfg.get("ckpt_path") + hparams["seed"] = cfg.get("seed") + + # send hparams to all loggers + for logger in trainer.loggers: + logger.log_hyperparams(hparams) diff --git a/egs/ljspeech/TTS/matcha_tts/utils/model.py b/egs/ljspeech/TTS/matcha_tts/utils/model.py new file mode 100644 index 0000000000..869cc6092f --- /dev/null +++ b/egs/ljspeech/TTS/matcha_tts/utils/model.py @@ -0,0 +1,90 @@ +""" from https://github.com/jaywalnut310/glow-tts """ + +import numpy as np +import torch + + +def sequence_mask(length, max_length=None): + if max_length is None: + max_length = length.max() + x = torch.arange(max_length, dtype=length.dtype, device=length.device) + return x.unsqueeze(0) < length.unsqueeze(1) + + +def fix_len_compatibility(length, num_downsamplings_in_unet=2): + factor = torch.scalar_tensor(2).pow(num_downsamplings_in_unet) + length = (length / factor).ceil() * factor + if not torch.onnx.is_in_onnx_export(): + return length.int().item() + else: + return length + + +def convert_pad_shape(pad_shape): + inverted_shape = pad_shape[::-1] + pad_shape = [item for sublist in inverted_shape for item in sublist] + return pad_shape + + +def generate_path(duration, mask): + device = duration.device + + b, t_x, t_y = mask.shape + cum_duration = torch.cumsum(duration, 1) + path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device) + + cum_duration_flat = cum_duration.view(b * t_x) + path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) + path = path.view(b, t_x, t_y) + path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] + path = path * mask + return path + + +def duration_loss(logw, logw_, lengths): + loss = torch.sum((logw - logw_) ** 2) / torch.sum(lengths) + return loss + + +def normalize(data, mu, std): + if not isinstance(mu, (float, int)): + if isinstance(mu, list): + mu = torch.tensor(mu, dtype=data.dtype, device=data.device) + elif isinstance(mu, torch.Tensor): + mu = mu.to(data.device) + elif isinstance(mu, np.ndarray): + mu = torch.from_numpy(mu).to(data.device) + mu = mu.unsqueeze(-1) + + if not isinstance(std, (float, int)): + if isinstance(std, list): + std = torch.tensor(std, dtype=data.dtype, device=data.device) + elif isinstance(std, torch.Tensor): + std = std.to(data.device) + elif isinstance(std, np.ndarray): + std = torch.from_numpy(std).to(data.device) + std = std.unsqueeze(-1) + + return (data - mu) / std + + +def denormalize(data, mu, std): + if not isinstance(mu, float): + if isinstance(mu, list): + mu = torch.tensor(mu, dtype=data.dtype, device=data.device) + elif isinstance(mu, torch.Tensor): + mu = mu.to(data.device) + elif isinstance(mu, np.ndarray): + mu = torch.from_numpy(mu).to(data.device) + mu = mu.unsqueeze(-1) + + if not isinstance(std, float): + if isinstance(std, list): + std = torch.tensor(std, dtype=data.dtype, device=data.device) + elif 
isinstance(std, torch.Tensor): + std = std.to(data.device) + elif isinstance(std, np.ndarray): + std = torch.from_numpy(std).to(data.device) + std = std.unsqueeze(-1) + + return data * std + mu diff --git a/egs/ljspeech/TTS/matcha_tts/utils/monotonic_align/__init__.py b/egs/ljspeech/TTS/matcha_tts/utils/monotonic_align/__init__.py new file mode 100644 index 0000000000..eee6e0d47c --- /dev/null +++ b/egs/ljspeech/TTS/matcha_tts/utils/monotonic_align/__init__.py @@ -0,0 +1,22 @@ +import numpy as np +import torch + +from matcha.utils.monotonic_align.core import maximum_path_c + + +def maximum_path(value, mask): + """Cython optimised version. + value: [b, t_x, t_y] + mask: [b, t_x, t_y] + """ + value = value * mask + device = value.device + dtype = value.dtype + value = value.data.cpu().numpy().astype(np.float32) + path = np.zeros_like(value).astype(np.int32) + mask = mask.data.cpu().numpy() + + t_x_max = mask.sum(1)[:, 0].astype(np.int32) + t_y_max = mask.sum(2)[:, 0].astype(np.int32) + maximum_path_c(path, value, t_x_max, t_y_max) + return torch.from_numpy(path).to(device=device, dtype=dtype) diff --git a/egs/ljspeech/TTS/matcha_tts/utils/monotonic_align/core.pyx b/egs/ljspeech/TTS/matcha_tts/utils/monotonic_align/core.pyx new file mode 100644 index 0000000000..091fcc3a50 --- /dev/null +++ b/egs/ljspeech/TTS/matcha_tts/utils/monotonic_align/core.pyx @@ -0,0 +1,47 @@ +import numpy as np + +cimport cython +cimport numpy as np + +from cython.parallel import prange + + +@cython.boundscheck(False) +@cython.wraparound(False) +cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_x, int t_y, float max_neg_val) nogil: + cdef int x + cdef int y + cdef float v_prev + cdef float v_cur + cdef float tmp + cdef int index = t_x - 1 + + for y in range(t_y): + for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)): + if x == y: + v_cur = max_neg_val + else: + v_cur = value[x, y-1] + if x == 0: + if y == 0: + v_prev = 0. + else: + v_prev = max_neg_val + else: + v_prev = value[x-1, y-1] + value[x, y] = max(v_cur, v_prev) + value[x, y] + + for y in range(t_y - 1, -1, -1): + path[index, y] = 1 + if index != 0 and (index == y or value[index, y-1] < value[index-1, y-1]): + index = index - 1 + + +@cython.boundscheck(False) +@cython.wraparound(False) +cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_xs, int[::1] t_ys, float max_neg_val=-1e9) nogil: + cdef int b = values.shape[0] + + cdef int i + for i in prange(b, nogil=True): + maximum_path_each(paths[i], values[i], t_xs[i], t_ys[i], max_neg_val) diff --git a/egs/ljspeech/TTS/matcha_tts/utils/monotonic_align/setup.py b/egs/ljspeech/TTS/matcha_tts/utils/monotonic_align/setup.py new file mode 100644 index 0000000000..6092e20d26 --- /dev/null +++ b/egs/ljspeech/TTS/matcha_tts/utils/monotonic_align/setup.py @@ -0,0 +1,9 @@ +from distutils.core import setup +from Cython.Build import cythonize +import numpy + +setup( + name="monotonic_align", + ext_modules=cythonize("core.pyx"), + include_dirs=[numpy.get_include()], +) diff --git a/egs/ljspeech/TTS/matcha_tts/utils/pylogger.py b/egs/ljspeech/TTS/matcha_tts/utils/pylogger.py new file mode 100644 index 0000000000..6160067802 --- /dev/null +++ b/egs/ljspeech/TTS/matcha_tts/utils/pylogger.py @@ -0,0 +1,21 @@ +import logging + +from lightning.pytorch.utilities import rank_zero_only + + +def get_pylogger(name: str = __name__) -> logging.Logger: + """Initializes a multi-GPU-friendly python command line logger. + + :param name: The name of the logger, defaults to ``__name__``. 
+ + :return: A logger object. + """ + logger = logging.getLogger(name) + + # this ensures all logging levels get marked with the rank zero decorator + # otherwise logs would get multiplied for each GPU process in multi-GPU setup + logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical") + for level in logging_levels: + setattr(logger, level, rank_zero_only(getattr(logger, level))) + + return logger diff --git a/egs/ljspeech/TTS/matcha_tts/utils/rich_utils.py b/egs/ljspeech/TTS/matcha_tts/utils/rich_utils.py new file mode 100644 index 0000000000..f602f6e935 --- /dev/null +++ b/egs/ljspeech/TTS/matcha_tts/utils/rich_utils.py @@ -0,0 +1,101 @@ +from pathlib import Path +from typing import Sequence + +import rich +import rich.syntax +import rich.tree +from hydra.core.hydra_config import HydraConfig +from lightning.pytorch.utilities import rank_zero_only +from omegaconf import DictConfig, OmegaConf, open_dict +from rich.prompt import Prompt + +from matcha.utils import pylogger + +log = pylogger.get_pylogger(__name__) + + +@rank_zero_only +def print_config_tree( + cfg: DictConfig, + print_order: Sequence[str] = ( + "data", + "model", + "callbacks", + "logger", + "trainer", + "paths", + "extras", + ), + resolve: bool = False, + save_to_file: bool = False, +) -> None: + """Prints the contents of a DictConfig as a tree structure using the Rich library. + + :param cfg: A DictConfig composed by Hydra. + :param print_order: Determines in what order config components are printed. Default is ``("data", "model", + "callbacks", "logger", "trainer", "paths", "extras")``. + :param resolve: Whether to resolve reference fields of DictConfig. Default is ``False``. + :param save_to_file: Whether to export config to the hydra output folder. Default is ``False``. + """ + style = "dim" + tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) + + queue = [] + + # add fields from `print_order` to queue + for field in print_order: + _ = ( + queue.append(field) + if field in cfg + else log.warning(f"Field '{field}' not found in config. Skipping '{field}' config printing...") + ) + + # add all the other fields to queue (not specified in `print_order`) + for field in cfg: + if field not in queue: + queue.append(field) + + # generate config tree from queue + for field in queue: + branch = tree.add(field, style=style, guide_style=style) + + config_group = cfg[field] + if isinstance(config_group, DictConfig): + branch_content = OmegaConf.to_yaml(config_group, resolve=resolve) + else: + branch_content = str(config_group) + + branch.add(rich.syntax.Syntax(branch_content, "yaml")) + + # print config tree + rich.print(tree) + + # save config tree to file + if save_to_file: + with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file: + rich.print(tree, file=file) + + +@rank_zero_only +def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None: + """Prompts user to input tags from command line if no tags are provided in config. + + :param cfg: A DictConfig composed by Hydra. + :param save_to_file: Whether to export tags to the hydra output folder. Default is ``False``. + """ + if not cfg.get("tags"): + if "id" in HydraConfig().cfg.hydra.job: + raise ValueError("Specify tags before launching a multirun!") + + log.warning("No tags provided in config. 
Prompting user to input tags...") + tags = Prompt.ask("Enter a list of comma separated tags", default="dev") + tags = [t.strip() for t in tags.split(",") if t != ""] + + with open_dict(cfg): + cfg.tags = tags + + log.info(f"Tags: {cfg.tags}") + + if save_to_file: + with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file: + rich.print(cfg.tags, file=file) diff --git a/egs/ljspeech/TTS/matcha_tts/utils/utils.py b/egs/ljspeech/TTS/matcha_tts/utils/utils.py new file mode 100644 index 0000000000..fc3a48ec2b --- /dev/null +++ b/egs/ljspeech/TTS/matcha_tts/utils/utils.py @@ -0,0 +1,259 @@ +import os +import sys +import warnings +from importlib.util import find_spec +from math import ceil +from pathlib import Path +from typing import Any, Callable, Dict, Tuple + +import gdown +import matplotlib.pyplot as plt +import numpy as np +import torch +import wget +from omegaconf import DictConfig + +from matcha.utils import pylogger, rich_utils + +log = pylogger.get_pylogger(__name__) + + +def extras(cfg: DictConfig) -> None: + """Applies optional utilities before the task is started. + + Utilities: + - Ignoring python warnings + - Setting tags from command line + - Rich config printing + + :param cfg: A DictConfig object containing the config tree. + """ + # return if no `extras` config + if not cfg.get("extras"): + log.warning("Extras config not found! ") + return + + # disable python warnings + if cfg.extras.get("ignore_warnings"): + log.info("Disabling python warnings! ") + warnings.filterwarnings("ignore") + + # prompt user to input tags from command line if none are provided in the config + if cfg.extras.get("enforce_tags"): + log.info("Enforcing tags! ") + rich_utils.enforce_tags(cfg, save_to_file=True) + + # pretty print config tree using Rich library + if cfg.extras.get("print_config"): + log.info("Printing config tree with Rich! ") + rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True) + + +def task_wrapper(task_func: Callable) -> Callable: + """Optional decorator that controls the failure behavior when executing the task function. + + This wrapper can be used to: + - make sure loggers are closed even if the task function raises an exception (prevents multirun failure) + - save the exception to a `.log` file + - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later) + - etc. (adjust depending on your needs) + + Example: + ``` + @utils.task_wrapper + def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: + ... + return metric_dict, object_dict + ``` + + :param task_func: The task function to be wrapped. + + :return: The wrapped task function. 
+ """ + + def wrap(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: + # execute the task + try: + metric_dict, object_dict = task_func(cfg=cfg) + + # things to do if exception occurs + except Exception as ex: + # save exception to `.log` file + log.exception("") + + # some hyperparameter combinations might be invalid or cause out-of-memory errors + # so when using hparam search plugins like Optuna, you might want to disable + # raising the below exception to avoid multirun failure + raise ex + + # things to always do after either success or exception + finally: + # display output dir path in terminal + log.info(f"Output dir: {cfg.paths.output_dir}") + + # always close wandb run (even if exception occurs so multirun won't fail) + if find_spec("wandb"): # check if wandb is installed + import wandb + + if wandb.run: + log.info("Closing wandb!") + wandb.finish() + + return metric_dict, object_dict + + return wrap + + +def get_metric_value(metric_dict: Dict[str, Any], metric_name: str) -> float: + """Safely retrieves value of the metric logged in LightningModule. + + :param metric_dict: A dict containing metric values. + :param metric_name: The name of the metric to retrieve. + :return: The value of the metric. + """ + if not metric_name: + log.info("Metric name is None! Skipping metric value retrieval...") + return None + + if metric_name not in metric_dict: + raise ValueError( + f"Metric value not found! \n" + "Make sure metric name logged in LightningModule is correct!\n" + "Make sure `optimized_metric` name in `hparams_search` config is correct!" + ) + + metric_value = metric_dict[metric_name].item() + log.info(f"Retrieved metric value! <{metric_name}={metric_value}>") + + return metric_value + + +def intersperse(lst, item): + # Adds blank symbol + result = [item] * (len(lst) * 2 + 1) + result[1::2] = lst + return result + + +def save_figure_to_numpy(fig): + data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") + data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + return data + + +def plot_tensor(tensor): + plt.style.use("default") + fig, ax = plt.subplots(figsize=(12, 3)) + im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation="none") + plt.colorbar(im, ax=ax) + plt.tight_layout() + fig.canvas.draw() + data = save_figure_to_numpy(fig) + plt.close() + return data + + +def save_plot(tensor, savepath): + plt.style.use("default") + fig, ax = plt.subplots(figsize=(12, 3)) + im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation="none") + plt.colorbar(im, ax=ax) + plt.tight_layout() + fig.canvas.draw() + plt.savefig(savepath) + plt.close() + + +def to_numpy(tensor): + if isinstance(tensor, np.ndarray): + return tensor + elif isinstance(tensor, torch.Tensor): + return tensor.detach().cpu().numpy() + elif isinstance(tensor, list): + return np.array(tensor) + else: + raise TypeError("Unsupported type for conversion to numpy array") + + +def get_user_data_dir(appname="matcha_tts"): + """ + Args: + appname (str): Name of application + + Returns: + Path: path to user data directory + """ + + MATCHA_HOME = os.environ.get("MATCHA_HOME") + if MATCHA_HOME is not None: + ans = Path(MATCHA_HOME).expanduser().resolve(strict=False) + elif sys.platform == "win32": + import winreg # pylint: disable=import-outside-toplevel + + key = winreg.OpenKey( + winreg.HKEY_CURRENT_USER, + r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders", + ) + dir_, _ = winreg.QueryValueEx(key, "Local AppData") + ans = Path(dir_).resolve(strict=False) + 
elif sys.platform == "darwin": + ans = Path("~/Library/Application Support/").expanduser() + else: + ans = Path.home().joinpath(".local/share") + + final_path = ans.joinpath(appname) + final_path.mkdir(parents=True, exist_ok=True) + return final_path + + +def assert_model_downloaded(checkpoint_path, url, use_wget=True): + if Path(checkpoint_path).exists(): + log.debug(f"[+] Model already present at {checkpoint_path}!") + print(f"[+] Model already present at {checkpoint_path}!") + return + log.info(f"[-] Model not found at {checkpoint_path}! Will download it") + print(f"[-] Model not found at {checkpoint_path}! Will download it") + checkpoint_path = str(checkpoint_path) + if not use_wget: + gdown.download(url=url, output=checkpoint_path, quiet=False, fuzzy=True) + else: + wget.download(url=url, out=checkpoint_path) + + +def get_phoneme_durations(durations, phones): + prev = durations[0] + merged_durations = [] + # Convolve with stride 2 + for i in range(1, len(durations), 2): + if i == len(durations) - 2: + # if it is last take full value + next_half = durations[i + 1] + else: + next_half = ceil(durations[i + 1] / 2) + + curr = prev + durations[i] + next_half + prev = durations[i + 1] - next_half + merged_durations.append(curr) + + assert len(phones) == len(merged_durations) + assert len(merged_durations) == (len(durations) - 1) // 2 + + merged_durations = torch.cumsum(torch.tensor(merged_durations), 0, dtype=torch.long) + start = torch.tensor(0) + duration_json = [] + for i, duration in enumerate(merged_durations): + duration_json.append( + { + phones[i]: { + "starttime": start.item(), + "endtime": duration.item(), + "duration": duration.item() - start.item(), + } + } + ) + start = duration + + assert list(duration_json[-1].values())[0]["endtime"] == sum( + durations + ), f"{list(duration_json[-1].values())[0]['endtime'], sum(durations)}" + return duration_json From ac1125e1bb80306c16da53aec090844a4524e6d4 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 15 Oct 2024 15:50:06 +0800 Subject: [PATCH 02/27] rename --- egs/ljspeech/TTS/matcha_tts/{train.py => train-orig.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename egs/ljspeech/TTS/matcha_tts/{train.py => train-orig.py} (100%) diff --git a/egs/ljspeech/TTS/matcha_tts/train.py b/egs/ljspeech/TTS/matcha_tts/train-orig.py similarity index 100% rename from egs/ljspeech/TTS/matcha_tts/train.py rename to egs/ljspeech/TTS/matcha_tts/train-orig.py From f95ac12d70737b9149f29fd391a491feb2f33281 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 15 Oct 2024 17:12:10 +0800 Subject: [PATCH 03/27] rename --- egs/ljspeech/TTS/{matcha_tts => matcha}/__init__.py | 0 egs/ljspeech/TTS/{matcha_tts => matcha}/hifigan/LICENSE | 0 egs/ljspeech/TTS/{matcha_tts => matcha}/hifigan/README.md | 0 egs/ljspeech/TTS/{matcha_tts => matcha}/hifigan/__init__.py | 0 egs/ljspeech/TTS/{matcha_tts => matcha}/hifigan/config.py | 0 egs/ljspeech/TTS/{matcha_tts => matcha}/hifigan/denoiser.py | 0 egs/ljspeech/TTS/{matcha_tts => matcha}/hifigan/env.py | 0 egs/ljspeech/TTS/{matcha_tts => matcha}/hifigan/meldataset.py | 0 egs/ljspeech/TTS/{matcha_tts => matcha}/hifigan/models.py | 0 egs/ljspeech/TTS/{matcha_tts => matcha}/hifigan/xutils.py | 0 egs/ljspeech/TTS/{matcha_tts => matcha}/models/__init__.py | 0 .../TTS/{matcha_tts => matcha}/models/baselightningmodule.py | 0 .../TTS/{matcha_tts => matcha}/models/components/__init__.py | 0 .../TTS/{matcha_tts => matcha}/models/components/decoder.py | 0 .../TTS/{matcha_tts => matcha}/models/components/flow_matching.py 
| 0 .../TTS/{matcha_tts => matcha}/models/components/text_encoder.py | 0 .../TTS/{matcha_tts => matcha}/models/components/transformer.py | 0 egs/ljspeech/TTS/{matcha_tts => matcha}/models/matcha_tts.py | 0 egs/ljspeech/TTS/{matcha_tts => matcha}/text/__init__.py | 0 egs/ljspeech/TTS/{matcha_tts => matcha}/text/cleaners.py | 0 egs/ljspeech/TTS/{matcha_tts => matcha}/text/numbers.py | 0 egs/ljspeech/TTS/{matcha_tts => matcha}/text/symbols.py | 0 egs/ljspeech/TTS/{matcha_tts => matcha}/train-orig.py | 0 egs/ljspeech/TTS/{matcha_tts => matcha}/utils/__init__.py | 0 egs/ljspeech/TTS/{matcha_tts => matcha}/utils/audio.py | 0 .../TTS/{matcha_tts => matcha}/utils/generate_data_statistics.py | 0 .../utils/get_durations_from_trained_model.py | 0 egs/ljspeech/TTS/{matcha_tts => matcha}/utils/instantiators.py | 0 egs/ljspeech/TTS/{matcha_tts => matcha}/utils/logging_utils.py | 0 egs/ljspeech/TTS/{matcha_tts => matcha}/utils/model.py | 0 .../TTS/{matcha_tts => matcha}/utils/monotonic_align/__init__.py | 0 .../TTS/{matcha_tts => matcha}/utils/monotonic_align/core.pyx | 0 .../TTS/{matcha_tts => matcha}/utils/monotonic_align/setup.py | 0 egs/ljspeech/TTS/{matcha_tts => matcha}/utils/pylogger.py | 0 egs/ljspeech/TTS/{matcha_tts => matcha}/utils/rich_utils.py | 0 egs/ljspeech/TTS/{matcha_tts => matcha}/utils/utils.py | 0 36 files changed, 0 insertions(+), 0 deletions(-) rename egs/ljspeech/TTS/{matcha_tts => matcha}/__init__.py (100%) rename egs/ljspeech/TTS/{matcha_tts => matcha}/hifigan/LICENSE (100%) rename egs/ljspeech/TTS/{matcha_tts => matcha}/hifigan/README.md (100%) rename egs/ljspeech/TTS/{matcha_tts => matcha}/hifigan/__init__.py (100%) rename egs/ljspeech/TTS/{matcha_tts => matcha}/hifigan/config.py (100%) rename egs/ljspeech/TTS/{matcha_tts => matcha}/hifigan/denoiser.py (100%) rename egs/ljspeech/TTS/{matcha_tts => matcha}/hifigan/env.py (100%) rename egs/ljspeech/TTS/{matcha_tts => matcha}/hifigan/meldataset.py (100%) rename egs/ljspeech/TTS/{matcha_tts => matcha}/hifigan/models.py (100%) rename egs/ljspeech/TTS/{matcha_tts => matcha}/hifigan/xutils.py (100%) rename egs/ljspeech/TTS/{matcha_tts => matcha}/models/__init__.py (100%) rename egs/ljspeech/TTS/{matcha_tts => matcha}/models/baselightningmodule.py (100%) rename egs/ljspeech/TTS/{matcha_tts => matcha}/models/components/__init__.py (100%) rename egs/ljspeech/TTS/{matcha_tts => matcha}/models/components/decoder.py (100%) rename egs/ljspeech/TTS/{matcha_tts => matcha}/models/components/flow_matching.py (100%) rename egs/ljspeech/TTS/{matcha_tts => matcha}/models/components/text_encoder.py (100%) rename egs/ljspeech/TTS/{matcha_tts => matcha}/models/components/transformer.py (100%) rename egs/ljspeech/TTS/{matcha_tts => matcha}/models/matcha_tts.py (100%) rename egs/ljspeech/TTS/{matcha_tts => matcha}/text/__init__.py (100%) rename egs/ljspeech/TTS/{matcha_tts => matcha}/text/cleaners.py (100%) rename egs/ljspeech/TTS/{matcha_tts => matcha}/text/numbers.py (100%) rename egs/ljspeech/TTS/{matcha_tts => matcha}/text/symbols.py (100%) rename egs/ljspeech/TTS/{matcha_tts => matcha}/train-orig.py (100%) rename egs/ljspeech/TTS/{matcha_tts => matcha}/utils/__init__.py (100%) rename egs/ljspeech/TTS/{matcha_tts => matcha}/utils/audio.py (100%) rename egs/ljspeech/TTS/{matcha_tts => matcha}/utils/generate_data_statistics.py (100%) rename egs/ljspeech/TTS/{matcha_tts => matcha}/utils/get_durations_from_trained_model.py (100%) rename egs/ljspeech/TTS/{matcha_tts => matcha}/utils/instantiators.py (100%) rename egs/ljspeech/TTS/{matcha_tts => 
matcha}/utils/logging_utils.py (100%) rename egs/ljspeech/TTS/{matcha_tts => matcha}/utils/model.py (100%) rename egs/ljspeech/TTS/{matcha_tts => matcha}/utils/monotonic_align/__init__.py (100%) rename egs/ljspeech/TTS/{matcha_tts => matcha}/utils/monotonic_align/core.pyx (100%) rename egs/ljspeech/TTS/{matcha_tts => matcha}/utils/monotonic_align/setup.py (100%) rename egs/ljspeech/TTS/{matcha_tts => matcha}/utils/pylogger.py (100%) rename egs/ljspeech/TTS/{matcha_tts => matcha}/utils/rich_utils.py (100%) rename egs/ljspeech/TTS/{matcha_tts => matcha}/utils/utils.py (100%) diff --git a/egs/ljspeech/TTS/matcha_tts/__init__.py b/egs/ljspeech/TTS/matcha/__init__.py similarity index 100% rename from egs/ljspeech/TTS/matcha_tts/__init__.py rename to egs/ljspeech/TTS/matcha/__init__.py diff --git a/egs/ljspeech/TTS/matcha_tts/hifigan/LICENSE b/egs/ljspeech/TTS/matcha/hifigan/LICENSE similarity index 100% rename from egs/ljspeech/TTS/matcha_tts/hifigan/LICENSE rename to egs/ljspeech/TTS/matcha/hifigan/LICENSE diff --git a/egs/ljspeech/TTS/matcha_tts/hifigan/README.md b/egs/ljspeech/TTS/matcha/hifigan/README.md similarity index 100% rename from egs/ljspeech/TTS/matcha_tts/hifigan/README.md rename to egs/ljspeech/TTS/matcha/hifigan/README.md diff --git a/egs/ljspeech/TTS/matcha_tts/hifigan/__init__.py b/egs/ljspeech/TTS/matcha/hifigan/__init__.py similarity index 100% rename from egs/ljspeech/TTS/matcha_tts/hifigan/__init__.py rename to egs/ljspeech/TTS/matcha/hifigan/__init__.py diff --git a/egs/ljspeech/TTS/matcha_tts/hifigan/config.py b/egs/ljspeech/TTS/matcha/hifigan/config.py similarity index 100% rename from egs/ljspeech/TTS/matcha_tts/hifigan/config.py rename to egs/ljspeech/TTS/matcha/hifigan/config.py diff --git a/egs/ljspeech/TTS/matcha_tts/hifigan/denoiser.py b/egs/ljspeech/TTS/matcha/hifigan/denoiser.py similarity index 100% rename from egs/ljspeech/TTS/matcha_tts/hifigan/denoiser.py rename to egs/ljspeech/TTS/matcha/hifigan/denoiser.py diff --git a/egs/ljspeech/TTS/matcha_tts/hifigan/env.py b/egs/ljspeech/TTS/matcha/hifigan/env.py similarity index 100% rename from egs/ljspeech/TTS/matcha_tts/hifigan/env.py rename to egs/ljspeech/TTS/matcha/hifigan/env.py diff --git a/egs/ljspeech/TTS/matcha_tts/hifigan/meldataset.py b/egs/ljspeech/TTS/matcha/hifigan/meldataset.py similarity index 100% rename from egs/ljspeech/TTS/matcha_tts/hifigan/meldataset.py rename to egs/ljspeech/TTS/matcha/hifigan/meldataset.py diff --git a/egs/ljspeech/TTS/matcha_tts/hifigan/models.py b/egs/ljspeech/TTS/matcha/hifigan/models.py similarity index 100% rename from egs/ljspeech/TTS/matcha_tts/hifigan/models.py rename to egs/ljspeech/TTS/matcha/hifigan/models.py diff --git a/egs/ljspeech/TTS/matcha_tts/hifigan/xutils.py b/egs/ljspeech/TTS/matcha/hifigan/xutils.py similarity index 100% rename from egs/ljspeech/TTS/matcha_tts/hifigan/xutils.py rename to egs/ljspeech/TTS/matcha/hifigan/xutils.py diff --git a/egs/ljspeech/TTS/matcha_tts/models/__init__.py b/egs/ljspeech/TTS/matcha/models/__init__.py similarity index 100% rename from egs/ljspeech/TTS/matcha_tts/models/__init__.py rename to egs/ljspeech/TTS/matcha/models/__init__.py diff --git a/egs/ljspeech/TTS/matcha_tts/models/baselightningmodule.py b/egs/ljspeech/TTS/matcha/models/baselightningmodule.py similarity index 100% rename from egs/ljspeech/TTS/matcha_tts/models/baselightningmodule.py rename to egs/ljspeech/TTS/matcha/models/baselightningmodule.py diff --git a/egs/ljspeech/TTS/matcha_tts/models/components/__init__.py 
b/egs/ljspeech/TTS/matcha/models/components/__init__.py similarity index 100% rename from egs/ljspeech/TTS/matcha_tts/models/components/__init__.py rename to egs/ljspeech/TTS/matcha/models/components/__init__.py diff --git a/egs/ljspeech/TTS/matcha_tts/models/components/decoder.py b/egs/ljspeech/TTS/matcha/models/components/decoder.py similarity index 100% rename from egs/ljspeech/TTS/matcha_tts/models/components/decoder.py rename to egs/ljspeech/TTS/matcha/models/components/decoder.py diff --git a/egs/ljspeech/TTS/matcha_tts/models/components/flow_matching.py b/egs/ljspeech/TTS/matcha/models/components/flow_matching.py similarity index 100% rename from egs/ljspeech/TTS/matcha_tts/models/components/flow_matching.py rename to egs/ljspeech/TTS/matcha/models/components/flow_matching.py diff --git a/egs/ljspeech/TTS/matcha_tts/models/components/text_encoder.py b/egs/ljspeech/TTS/matcha/models/components/text_encoder.py similarity index 100% rename from egs/ljspeech/TTS/matcha_tts/models/components/text_encoder.py rename to egs/ljspeech/TTS/matcha/models/components/text_encoder.py diff --git a/egs/ljspeech/TTS/matcha_tts/models/components/transformer.py b/egs/ljspeech/TTS/matcha/models/components/transformer.py similarity index 100% rename from egs/ljspeech/TTS/matcha_tts/models/components/transformer.py rename to egs/ljspeech/TTS/matcha/models/components/transformer.py diff --git a/egs/ljspeech/TTS/matcha_tts/models/matcha_tts.py b/egs/ljspeech/TTS/matcha/models/matcha_tts.py similarity index 100% rename from egs/ljspeech/TTS/matcha_tts/models/matcha_tts.py rename to egs/ljspeech/TTS/matcha/models/matcha_tts.py diff --git a/egs/ljspeech/TTS/matcha_tts/text/__init__.py b/egs/ljspeech/TTS/matcha/text/__init__.py similarity index 100% rename from egs/ljspeech/TTS/matcha_tts/text/__init__.py rename to egs/ljspeech/TTS/matcha/text/__init__.py diff --git a/egs/ljspeech/TTS/matcha_tts/text/cleaners.py b/egs/ljspeech/TTS/matcha/text/cleaners.py similarity index 100% rename from egs/ljspeech/TTS/matcha_tts/text/cleaners.py rename to egs/ljspeech/TTS/matcha/text/cleaners.py diff --git a/egs/ljspeech/TTS/matcha_tts/text/numbers.py b/egs/ljspeech/TTS/matcha/text/numbers.py similarity index 100% rename from egs/ljspeech/TTS/matcha_tts/text/numbers.py rename to egs/ljspeech/TTS/matcha/text/numbers.py diff --git a/egs/ljspeech/TTS/matcha_tts/text/symbols.py b/egs/ljspeech/TTS/matcha/text/symbols.py similarity index 100% rename from egs/ljspeech/TTS/matcha_tts/text/symbols.py rename to egs/ljspeech/TTS/matcha/text/symbols.py diff --git a/egs/ljspeech/TTS/matcha_tts/train-orig.py b/egs/ljspeech/TTS/matcha/train-orig.py similarity index 100% rename from egs/ljspeech/TTS/matcha_tts/train-orig.py rename to egs/ljspeech/TTS/matcha/train-orig.py diff --git a/egs/ljspeech/TTS/matcha_tts/utils/__init__.py b/egs/ljspeech/TTS/matcha/utils/__init__.py similarity index 100% rename from egs/ljspeech/TTS/matcha_tts/utils/__init__.py rename to egs/ljspeech/TTS/matcha/utils/__init__.py diff --git a/egs/ljspeech/TTS/matcha_tts/utils/audio.py b/egs/ljspeech/TTS/matcha/utils/audio.py similarity index 100% rename from egs/ljspeech/TTS/matcha_tts/utils/audio.py rename to egs/ljspeech/TTS/matcha/utils/audio.py diff --git a/egs/ljspeech/TTS/matcha_tts/utils/generate_data_statistics.py b/egs/ljspeech/TTS/matcha/utils/generate_data_statistics.py similarity index 100% rename from egs/ljspeech/TTS/matcha_tts/utils/generate_data_statistics.py rename to egs/ljspeech/TTS/matcha/utils/generate_data_statistics.py diff --git 
a/egs/ljspeech/TTS/matcha_tts/utils/get_durations_from_trained_model.py b/egs/ljspeech/TTS/matcha/utils/get_durations_from_trained_model.py similarity index 100% rename from egs/ljspeech/TTS/matcha_tts/utils/get_durations_from_trained_model.py rename to egs/ljspeech/TTS/matcha/utils/get_durations_from_trained_model.py diff --git a/egs/ljspeech/TTS/matcha_tts/utils/instantiators.py b/egs/ljspeech/TTS/matcha/utils/instantiators.py similarity index 100% rename from egs/ljspeech/TTS/matcha_tts/utils/instantiators.py rename to egs/ljspeech/TTS/matcha/utils/instantiators.py diff --git a/egs/ljspeech/TTS/matcha_tts/utils/logging_utils.py b/egs/ljspeech/TTS/matcha/utils/logging_utils.py similarity index 100% rename from egs/ljspeech/TTS/matcha_tts/utils/logging_utils.py rename to egs/ljspeech/TTS/matcha/utils/logging_utils.py diff --git a/egs/ljspeech/TTS/matcha_tts/utils/model.py b/egs/ljspeech/TTS/matcha/utils/model.py similarity index 100% rename from egs/ljspeech/TTS/matcha_tts/utils/model.py rename to egs/ljspeech/TTS/matcha/utils/model.py diff --git a/egs/ljspeech/TTS/matcha_tts/utils/monotonic_align/__init__.py b/egs/ljspeech/TTS/matcha/utils/monotonic_align/__init__.py similarity index 100% rename from egs/ljspeech/TTS/matcha_tts/utils/monotonic_align/__init__.py rename to egs/ljspeech/TTS/matcha/utils/monotonic_align/__init__.py diff --git a/egs/ljspeech/TTS/matcha_tts/utils/monotonic_align/core.pyx b/egs/ljspeech/TTS/matcha/utils/monotonic_align/core.pyx similarity index 100% rename from egs/ljspeech/TTS/matcha_tts/utils/monotonic_align/core.pyx rename to egs/ljspeech/TTS/matcha/utils/monotonic_align/core.pyx diff --git a/egs/ljspeech/TTS/matcha_tts/utils/monotonic_align/setup.py b/egs/ljspeech/TTS/matcha/utils/monotonic_align/setup.py similarity index 100% rename from egs/ljspeech/TTS/matcha_tts/utils/monotonic_align/setup.py rename to egs/ljspeech/TTS/matcha/utils/monotonic_align/setup.py diff --git a/egs/ljspeech/TTS/matcha_tts/utils/pylogger.py b/egs/ljspeech/TTS/matcha/utils/pylogger.py similarity index 100% rename from egs/ljspeech/TTS/matcha_tts/utils/pylogger.py rename to egs/ljspeech/TTS/matcha/utils/pylogger.py diff --git a/egs/ljspeech/TTS/matcha_tts/utils/rich_utils.py b/egs/ljspeech/TTS/matcha/utils/rich_utils.py similarity index 100% rename from egs/ljspeech/TTS/matcha_tts/utils/rich_utils.py rename to egs/ljspeech/TTS/matcha/utils/rich_utils.py diff --git a/egs/ljspeech/TTS/matcha_tts/utils/utils.py b/egs/ljspeech/TTS/matcha/utils/utils.py similarity index 100% rename from egs/ljspeech/TTS/matcha_tts/utils/utils.py rename to egs/ljspeech/TTS/matcha/utils/utils.py From 6fac3a3143197c116b16d60d258725da469d3bfa Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 15 Oct 2024 17:57:10 +0800 Subject: [PATCH 04/27] create model from parameters --- .../matcha/models/components/flow_matching.py | 4 +- .../matcha/models/components/text_encoder.py | 4 +- egs/ljspeech/TTS/matcha/models/matcha_tts.py | 14 +-- egs/ljspeech/TTS/matcha/train.py | 98 +++++++++++++++++++ egs/ljspeech/TTS/matcha/utils/__init__.py | 10 +- .../matcha/utils/monotonic_align/.gitignore | 3 + 6 files changed, 117 insertions(+), 16 deletions(-) create mode 100755 egs/ljspeech/TTS/matcha/train.py create mode 100644 egs/ljspeech/TTS/matcha/utils/monotonic_align/.gitignore diff --git a/egs/ljspeech/TTS/matcha/models/components/flow_matching.py b/egs/ljspeech/TTS/matcha/models/components/flow_matching.py index 5cad7431ef..552c4b3834 100644 --- a/egs/ljspeech/TTS/matcha/models/components/flow_matching.py +++ 
b/egs/ljspeech/TTS/matcha/models/components/flow_matching.py @@ -4,9 +4,9 @@ import torch.nn.functional as F from matcha.models.components.decoder import Decoder -from matcha.utils.pylogger import get_pylogger +# from matcha.utils.pylogger import get_pylogger -log = get_pylogger(__name__) +# log = get_pylogger(__name__) class BASECFM(torch.nn.Module, ABC): diff --git a/egs/ljspeech/TTS/matcha/models/components/text_encoder.py b/egs/ljspeech/TTS/matcha/models/components/text_encoder.py index a388d05d63..efd2253562 100644 --- a/egs/ljspeech/TTS/matcha/models/components/text_encoder.py +++ b/egs/ljspeech/TTS/matcha/models/components/text_encoder.py @@ -6,10 +6,10 @@ import torch.nn as nn from einops import rearrange -import matcha.utils as utils +# import matcha.utils as utils from matcha.utils.model import sequence_mask -log = utils.get_pylogger(__name__) +# log = utils.get_pylogger(__name__) class LayerNorm(nn.Module): diff --git a/egs/ljspeech/TTS/matcha/models/matcha_tts.py b/egs/ljspeech/TTS/matcha/models/matcha_tts.py index 07f95ad2e3..d4b1c57ab6 100644 --- a/egs/ljspeech/TTS/matcha/models/matcha_tts.py +++ b/egs/ljspeech/TTS/matcha/models/matcha_tts.py @@ -5,8 +5,8 @@ import torch import matcha.utils.monotonic_align as monotonic_align -from matcha import utils -from matcha.models.baselightningmodule import BaseLightningClass +# from matcha import utils +# from matcha.models.baselightningmodule import BaseLightningClass from matcha.models.components.flow_matching import CFM from matcha.models.components.text_encoder import TextEncoder from matcha.utils.model import ( @@ -17,10 +17,10 @@ sequence_mask, ) -log = utils.get_pylogger(__name__) +# log = utils.get_pylogger(__name__) -class MatchaTTS(BaseLightningClass): # 🍵 +class MatchaTTS(torch.nn.Module): # 🍵 def __init__( self, n_vocab, @@ -30,7 +30,7 @@ def __init__( encoder, decoder, cfm, - data_statistics, + # data_statistics, out_size, optimizer=None, scheduler=None, @@ -39,7 +39,7 @@ def __init__( ): super().__init__() - self.save_hyperparameters(logger=False) + # self.save_hyperparameters(logger=False) self.n_vocab = n_vocab self.n_spks = n_spks @@ -70,7 +70,7 @@ def __init__( spk_emb_dim=spk_emb_dim, ) - self.update_data_statistics(data_statistics) + # self.update_data_statistics(data_statistics) @torch.inference_mode() def synthesise(self, x, x_lengths, n_timesteps, temperature=1.0, spks=None, length_scale=1.0): diff --git a/egs/ljspeech/TTS/matcha/train.py b/egs/ljspeech/TTS/matcha/train.py new file mode 100755 index 0000000000..1c5084204e --- /dev/null +++ b/egs/ljspeech/TTS/matcha/train.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) + + +import torch + + +from icefall.utils import AttributeDict +from matcha.models.matcha_tts import MatchaTTS + + +def get_model(params): + m = MatchaTTS(**params.model) + return m + + +def main(): + n_feats = 80 + filter_channels_dp = 256 + encoder_params_p_dropout = 0.1 + params = AttributeDict( + { + "model": AttributeDict( + { + "n_vocab": 178, + "n_spks": 1, # for ljspeech. 
+ "spk_emb_dim": 64, + "n_feats": n_feats, + "out_size": None, # or use 172 + "prior_loss": True, + "use_precomputed_durations": False, + "encoder": AttributeDict( + { + "encoder_type": "RoPE Encoder", # not used + "encoder_params": AttributeDict( + { + "n_feats": n_feats, + "n_channels": 192, + "filter_channels": 768, + "filter_channels_dp": filter_channels_dp, + "n_heads": 2, + "n_layers": 6, + "kernel_size": 3, + "p_dropout": encoder_params_p_dropout, + "spk_emb_dim": 64, + "n_spks": 1, + "prenet": True, + } + ), + "duration_predictor_params": AttributeDict( + { + "filter_channels_dp": filter_channels_dp, + "kernel_size": 3, + "p_dropout": encoder_params_p_dropout, + } + ), + } + ), + "decoder": AttributeDict( + { + "channels": [256, 256], + "dropout": 0.05, + "attention_head_dim": 64, + "n_blocks": 1, + "num_mid_blocks": 2, + "num_heads": 2, + "act_fn": "snakebeta", + } + ), + "cfm": AttributeDict( + { + "name": "CFM", + "solver": "euler", + "sigma_min": 1e-4, + } + ), + "optimizer": AttributeDict( + { + "lr": 1e-4, + "weight_decay": 0.0, + } + ), + } + ) + } + ) + m = get_model(params) + print(m) + + num_param = sum([p.numel() for p in m.parameters()]) + print(f"Number of parameters: {num_param}") + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/ljspeech/TTS/matcha/utils/__init__.py b/egs/ljspeech/TTS/matcha/utils/__init__.py index 074db64611..2b74b40f50 100644 --- a/egs/ljspeech/TTS/matcha/utils/__init__.py +++ b/egs/ljspeech/TTS/matcha/utils/__init__.py @@ -1,5 +1,5 @@ -from matcha.utils.instantiators import instantiate_callbacks, instantiate_loggers -from matcha.utils.logging_utils import log_hyperparameters -from matcha.utils.pylogger import get_pylogger -from matcha.utils.rich_utils import enforce_tags, print_config_tree -from matcha.utils.utils import extras, get_metric_value, task_wrapper +# from matcha.utils.instantiators import instantiate_callbacks, instantiate_loggers +# from matcha.utils.logging_utils import log_hyperparameters +# from matcha.utils.pylogger import get_pylogger +# from matcha.utils.rich_utils import enforce_tags, print_config_tree +# from matcha.utils.utils import extras, get_metric_value, task_wrapper diff --git a/egs/ljspeech/TTS/matcha/utils/monotonic_align/.gitignore b/egs/ljspeech/TTS/matcha/utils/monotonic_align/.gitignore new file mode 100644 index 0000000000..28bdad6b84 --- /dev/null +++ b/egs/ljspeech/TTS/matcha/utils/monotonic_align/.gitignore @@ -0,0 +1,3 @@ +build +core.c +*.so From ccd2dcc9f9919567839af714762ca5458815ba82 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 15 Oct 2024 22:48:35 +0800 Subject: [PATCH 05/27] add dataset --- egs/ljspeech/TTS/matcha/train.py | 177 +++++++++++++++++-------- egs/ljspeech/TTS/matcha/utils/utils.py | 12 +- 2 files changed, 125 insertions(+), 64 deletions(-) diff --git a/egs/ljspeech/TTS/matcha/train.py b/egs/ljspeech/TTS/matcha/train.py index 1c5084204e..f41ee4eae1 100755 --- a/egs/ljspeech/TTS/matcha/train.py +++ b/egs/ljspeech/TTS/matcha/train.py @@ -7,83 +7,144 @@ from icefall.utils import AttributeDict from matcha.models.matcha_tts import MatchaTTS +from matcha.data.text_mel_datamodule import TextMelDataModule -def get_model(params): - m = MatchaTTS(**params.model) - return m +def _get_data_params() -> AttributeDict: + params = AttributeDict( + { + "name": "ljspeech", + "train_filelist_path": "./filelists/ljs_audio_text_train_filelist.txt", + "valid_filelist_path": "./filelists/ljs_audio_text_val_filelist.txt", + 
"batch_size": 32, + "num_workers": 3, + "pin_memory": False, + "cleaners": ["english_cleaners2"], + "add_blank": True, + "n_spks": 1, + "n_fft": 1024, + "n_feats": 80, + "sample_rate": 22050, + "hop_length": 256, + "win_length": 1024, + "f_min": 0, + "f_max": 8000, + "seed": 1234, + "load_durations": False, + "data_statistics": AttributeDict( + { + "mel_mean": -5.517028331756592, + "mel_std": 2.0643954277038574, + } + ), + } + ) + return params -def main(): +def _get_model_params() -> AttributeDict: n_feats = 80 filter_channels_dp = 256 encoder_params_p_dropout = 0.1 params = AttributeDict( { - "model": AttributeDict( + "n_vocab": 178, + "n_spks": 1, # for ljspeech. + "spk_emb_dim": 64, + "n_feats": n_feats, + "out_size": None, # or use 172 + "prior_loss": True, + "use_precomputed_durations": False, + "encoder": AttributeDict( { - "n_vocab": 178, - "n_spks": 1, # for ljspeech. - "spk_emb_dim": 64, - "n_feats": n_feats, - "out_size": None, # or use 172 - "prior_loss": True, - "use_precomputed_durations": False, - "encoder": AttributeDict( + "encoder_type": "RoPE Encoder", # not used + "encoder_params": AttributeDict( { - "encoder_type": "RoPE Encoder", # not used - "encoder_params": AttributeDict( - { - "n_feats": n_feats, - "n_channels": 192, - "filter_channels": 768, - "filter_channels_dp": filter_channels_dp, - "n_heads": 2, - "n_layers": 6, - "kernel_size": 3, - "p_dropout": encoder_params_p_dropout, - "spk_emb_dim": 64, - "n_spks": 1, - "prenet": True, - } - ), - "duration_predictor_params": AttributeDict( - { - "filter_channels_dp": filter_channels_dp, - "kernel_size": 3, - "p_dropout": encoder_params_p_dropout, - } - ), + "n_feats": n_feats, + "n_channels": 192, + "filter_channels": 768, + "filter_channels_dp": filter_channels_dp, + "n_heads": 2, + "n_layers": 6, + "kernel_size": 3, + "p_dropout": encoder_params_p_dropout, + "spk_emb_dim": 64, + "n_spks": 1, + "prenet": True, } ), - "decoder": AttributeDict( + "duration_predictor_params": AttributeDict( { - "channels": [256, 256], - "dropout": 0.05, - "attention_head_dim": 64, - "n_blocks": 1, - "num_mid_blocks": 2, - "num_heads": 2, - "act_fn": "snakebeta", - } - ), - "cfm": AttributeDict( - { - "name": "CFM", - "solver": "euler", - "sigma_min": 1e-4, - } - ), - "optimizer": AttributeDict( - { - "lr": 1e-4, - "weight_decay": 0.0, + "filter_channels_dp": filter_channels_dp, + "kernel_size": 3, + "p_dropout": encoder_params_p_dropout, } ), } - ) + ), + "decoder": AttributeDict( + { + "channels": [256, 256], + "dropout": 0.05, + "attention_head_dim": 64, + "n_blocks": 1, + "num_mid_blocks": 2, + "num_heads": 2, + "act_fn": "snakebeta", + } + ), + "cfm": AttributeDict( + { + "name": "CFM", + "solver": "euler", + "sigma_min": 1e-4, + } + ), + "optimizer": AttributeDict( + { + "lr": 1e-4, + "weight_decay": 0.0, + } + ), + } + ) + + return params + + +def get_params(): + params = AttributeDict( + { + "model": _get_model_params(), + "data": _get_data_params(), } ) + return params + + +def get_model(params): + m = MatchaTTS(**params.model) + return m + + +def main(): + params = get_params() + + data_module = TextMelDataModule(hparams=params.data) + if False: + for b in data_module.train_dataloader(): + assert isinstance(b, dict) + # b.keys() + # ['x', 'x_lengths', 'y', 'y_lengths', 'spks', 'filepaths', 'x_texts', 'durations'] + # x: [batch_size, 289], torch.int64 + # x_lengths: [batch_size], torch.int64 + # y: [batch_size, n_feats, num_frames], torch.float32 + # y_lengths: [batch_size], torch.int64 + # spks: None + # filepaths: list, 
(batch_size,) + # x_texts: list, (batch_size,) + # durations: None + m = get_model(params) print(m) diff --git a/egs/ljspeech/TTS/matcha/utils/utils.py b/egs/ljspeech/TTS/matcha/utils/utils.py index fc3a48ec2b..bc81c316ea 100644 --- a/egs/ljspeech/TTS/matcha/utils/utils.py +++ b/egs/ljspeech/TTS/matcha/utils/utils.py @@ -6,19 +6,17 @@ from pathlib import Path from typing import Any, Callable, Dict, Tuple -import gdown import matplotlib.pyplot as plt import numpy as np import torch -import wget -from omegaconf import DictConfig +# from omegaconf import DictConfig -from matcha.utils import pylogger, rich_utils +# from matcha.utils import pylogger, rich_utils -log = pylogger.get_pylogger(__name__) +# log = pylogger.get_pylogger(__name__) -def extras(cfg: DictConfig) -> None: +def extras(cfg: 'DictConfig') -> None: """Applies optional utilities before the task is started. Utilities: @@ -207,6 +205,8 @@ def get_user_data_dir(appname="matcha_tts"): def assert_model_downloaded(checkpoint_path, url, use_wget=True): + import gdown + import wget if Path(checkpoint_path).exists(): log.debug(f"[+] Model already present at {checkpoint_path}!") print(f"[+] Model already present at {checkpoint_path}!") From 56d3b92f3f4f8ada92f4f708ef990ff9b107a0bb Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 16 Oct 2024 19:35:35 +0800 Subject: [PATCH 06/27] First working version. --- egs/ljspeech/TTS/matcha/inference.py | 178 +++++++ egs/ljspeech/TTS/matcha/models/matcha_tts.py | 74 ++- egs/ljspeech/TTS/matcha/test-train.py | 159 ++++++ egs/ljspeech/TTS/matcha/train.py | 484 +++++++++++++++++-- egs/ljspeech/TTS/matcha/utils2.py | 1 + 5 files changed, 854 insertions(+), 42 deletions(-) create mode 100755 egs/ljspeech/TTS/matcha/inference.py create mode 100644 egs/ljspeech/TTS/matcha/test-train.py create mode 120000 egs/ljspeech/TTS/matcha/utils2.py diff --git a/egs/ljspeech/TTS/matcha/inference.py b/egs/ljspeech/TTS/matcha/inference.py new file mode 100755 index 0000000000..29a0f53a83 --- /dev/null +++ b/egs/ljspeech/TTS/matcha/inference.py @@ -0,0 +1,178 @@ +#!/usr/bin/env python3 + +import argparse +import datetime as dt +import logging +from pathlib import Path + +import numpy as np +import soundfile as sf +import torch +from matcha.hifigan.config import v1 +from matcha.hifigan.denoiser import Denoiser +from matcha.hifigan.models import Generator as HiFiGAN +from matcha.text import sequence_to_text, text_to_sequence +from matcha.utils.utils import intersperse +from tqdm.auto import tqdm +from train import get_model, get_params + +from icefall.checkpoint import load_checkpoint +from icefall.utils import AttributeDict + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=140, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + """, + ) + + parser.add_argument( + "--exp-dir", + type=Path, + default="matcha/exp", + help="""The experiment dir. 
+ It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + return parser + + +def load_vocoder(checkpoint_path): + h = AttributeDict(v1) + hifigan = HiFiGAN(h).to("cpu") + hifigan.load_state_dict( + torch.load(checkpoint_path, map_location="cpu")["generator"] + ) + _ = hifigan.eval() + hifigan.remove_weight_norm() + return hifigan + + +def to_waveform(mel, vocoder, denoiser): + audio = vocoder(mel).clamp(-1, 1) + audio = denoiser(audio.squeeze(0), strength=0.00025).cpu().squeeze() + return audio.cpu().squeeze() + + +def save_to_folder(filename: str, output: dict, folder: str): + folder = Path(folder) + folder.mkdir(exist_ok=True, parents=True) + np.save(folder / f"{filename}", output["mel"].cpu().numpy()) + sf.write(folder / f"{filename}.wav", output["waveform"], 22050, "PCM_24") + + +def process_text(text: str): + x = torch.tensor( + intersperse(text_to_sequence(text, ["english_cleaners2"])[0], 0), + dtype=torch.long, + device="cpu", + )[None] + x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device="cpu") + x_phones = sequence_to_text(x.squeeze(0).tolist()) + return {"x_orig": text, "x": x, "x_lengths": x_lengths, "x_phones": x_phones} + + +def synthesise(model, n_timesteps, text, length_scale, temperature, spks=None): + text_processed = process_text(text) + start_t = dt.datetime.now() + output = model.synthesise( + text_processed["x"], + text_processed["x_lengths"], + n_timesteps=n_timesteps, + temperature=temperature, + spks=spks, + length_scale=length_scale, + ) + print("output.shape", list(output.keys()), output["mel"].shape) + # merge everything to one dict + output.update({"start_t": start_t, **text_processed}) + return output + + +@torch.inference_mode() +def main(): + parser = get_parser() + args = parser.parse_args() + params = get_params() + + params.update(vars(args)) + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + model.eval() + + vocoder = load_vocoder("/star-fj/fangjun/open-source/Matcha-TTS/generator_v1") + denoiser = Denoiser(vocoder, mode="zeros") + + texts = [ + "The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.", + "Today as always, men fall into two groups: slaves and free men. 
Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar.", + ] + + # Number of ODE Solver steps + n_timesteps = 2 + + # Changes to the speaking rate + length_scale = 1.0 + + # Sampling temperature + temperature = 0.667 + + outputs, rtfs = [], [] + rtfs_w = [] + for i, text in enumerate(tqdm(texts)): + output = synthesise( + model=model, + n_timesteps=n_timesteps, + text=text, + length_scale=length_scale, + temperature=temperature, + ) # , torch.tensor([15], device=device, dtype=torch.long).unsqueeze(0)) + output["waveform"] = to_waveform(output["mel"], vocoder, denoiser) + + # Compute Real Time Factor (RTF) with HiFi-GAN + t = (dt.datetime.now() - output["start_t"]).total_seconds() + rtf_w = t * 22050 / (output["waveform"].shape[-1]) + + # Pretty print + print(f"{'*' * 53}") + print(f"Input text - {i}") + print(f"{'-' * 53}") + print(output["x_orig"]) + print(f"{'*' * 53}") + print(f"Phonetised text - {i}") + print(f"{'-' * 53}") + print(output["x_phones"]) + print(f"{'*' * 53}") + print(f"RTF:\t\t{output['rtf']:.6f}") + print(f"RTF Waveform:\t{rtf_w:.6f}") + rtfs.append(output["rtf"]) + rtfs_w.append(rtf_w) + + # Save the generated waveform + save_to_folder(i, output, folder="./my-output") + + print(f"Number of ODE steps: {n_timesteps}") + print(f"Mean RTF:\t\t\t\t{np.mean(rtfs):.6f} ± {np.std(rtfs):.6f}") + print( + f"Mean RTF Waveform (incl. vocoder):\t{np.mean(rtfs_w):.6f} ± {np.std(rtfs_w):.6f}" + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/ljspeech/TTS/matcha/models/matcha_tts.py b/egs/ljspeech/TTS/matcha/models/matcha_tts.py index d4b1c57ab6..d5d78c6196 100644 --- a/egs/ljspeech/TTS/matcha/models/matcha_tts.py +++ b/egs/ljspeech/TTS/matcha/models/matcha_tts.py @@ -5,6 +5,7 @@ import torch import matcha.utils.monotonic_align as monotonic_align + # from matcha import utils # from matcha.models.baselightningmodule import BaseLightningClass from matcha.models.components.flow_matching import CFM @@ -30,7 +31,7 @@ def __init__( encoder, decoder, cfm, - # data_statistics, + data_statistics, out_size, optimizer=None, scheduler=None, @@ -71,9 +72,13 @@ def __init__( ) # self.update_data_statistics(data_statistics) + self.register_buffer("mel_mean", torch.tensor(data_statistics["mel_mean"])) + self.register_buffer("mel_std", torch.tensor(data_statistics["mel_std"])) @torch.inference_mode() - def synthesise(self, x, x_lengths, n_timesteps, temperature=1.0, spks=None, length_scale=1.0): + def synthesise( + self, x, x_lengths, n_timesteps, temperature=1.0, spks=None, length_scale=1.0 + ): """ Generates mel-spectrogram from text. Returns: 1. encoder outputs @@ -149,7 +154,17 @@ def synthesise(self, x, x_lengths, n_timesteps, temperature=1.0, spks=None, leng "rtf": rtf, } - def forward(self, x, x_lengths, y, y_lengths, spks=None, out_size=None, cond=None, durations=None): + def forward( + self, + x, + x_lengths, + y, + y_lengths, + spks=None, + out_size=None, + cond=None, + durations=None, + ): """ Computes 3 losses: 1. duration loss: loss between predicted token durations and those extracted by Monotinic Alignment Search (MAS). 
@@ -187,7 +202,9 @@ def forward(self, x, x_lengths, y, y_lengths, spks=None, out_size=None, cond=Non # Use MAS to find most likely alignment `attn` between text and mel-spectrogram with torch.no_grad(): const = -0.5 * math.log(2 * math.pi) * self.n_feats - factor = -0.5 * torch.ones(mu_x.shape, dtype=mu_x.dtype, device=mu_x.device) + factor = -0.5 * torch.ones( + mu_x.shape, dtype=mu_x.dtype, device=mu_x.device + ) y_square = torch.matmul(factor.transpose(1, 2), y**2) y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y) mu_square = torch.sum(factor * (mu_x**2), 1).unsqueeze(-1) @@ -206,12 +223,25 @@ def forward(self, x, x_lengths, y, y_lengths, spks=None, out_size=None, cond=Non # - Do not need this hack for Matcha-TTS, but it works with it as well if not isinstance(out_size, type(None)): max_offset = (y_lengths - out_size).clamp(0) - offset_ranges = list(zip([0] * max_offset.shape[0], max_offset.cpu().numpy())) + offset_ranges = list( + zip([0] * max_offset.shape[0], max_offset.cpu().numpy()) + ) out_offset = torch.LongTensor( - [torch.tensor(random.choice(range(start, end)) if end > start else 0) for start, end in offset_ranges] + [ + torch.tensor(random.choice(range(start, end)) if end > start else 0) + for start, end in offset_ranges + ] ).to(y_lengths) - attn_cut = torch.zeros(attn.shape[0], attn.shape[1], out_size, dtype=attn.dtype, device=attn.device) - y_cut = torch.zeros(y.shape[0], self.n_feats, out_size, dtype=y.dtype, device=y.device) + attn_cut = torch.zeros( + attn.shape[0], + attn.shape[1], + out_size, + dtype=attn.dtype, + device=attn.device, + ) + y_cut = torch.zeros( + y.shape[0], self.n_feats, out_size, dtype=y.dtype, device=y.device + ) y_cut_lengths = [] for i, (y_, out_offset_) in enumerate(zip(y, out_offset)): @@ -233,12 +263,36 @@ def forward(self, x, x_lengths, y, y_lengths, spks=None, out_size=None, cond=Non mu_y = mu_y.transpose(1, 2) # Compute loss of the decoder - diff_loss, _ = self.decoder.compute_loss(x1=y, mask=y_mask, mu=mu_y, spks=spks, cond=cond) + diff_loss, _ = self.decoder.compute_loss( + x1=y, mask=y_mask, mu=mu_y, spks=spks, cond=cond + ) if self.prior_loss: - prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask) + prior_loss = torch.sum( + 0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask + ) prior_loss = prior_loss / (torch.sum(y_mask) * self.n_feats) else: prior_loss = 0 return dur_loss, prior_loss, diff_loss, attn + + def get_losses(self, batch): + x, x_lengths = batch["x"], batch["x_lengths"] + y, y_lengths = batch["y"], batch["y_lengths"] + spks = batch["spks"] + + dur_loss, prior_loss, diff_loss, *_ = self( + x=x, + x_lengths=x_lengths, + y=y, + y_lengths=y_lengths, + spks=spks, + out_size=self.out_size, + durations=batch["durations"], + ) + return { + "dur_loss": dur_loss, + "prior_loss": prior_loss, + "diff_loss": diff_loss, + } diff --git a/egs/ljspeech/TTS/matcha/test-train.py b/egs/ljspeech/TTS/matcha/test-train.py new file mode 100644 index 0000000000..f41ee4eae1 --- /dev/null +++ b/egs/ljspeech/TTS/matcha/test-train.py @@ -0,0 +1,159 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. 
(authors: Fangjun Kuang) + + +import torch + + +from icefall.utils import AttributeDict +from matcha.models.matcha_tts import MatchaTTS +from matcha.data.text_mel_datamodule import TextMelDataModule + + +def _get_data_params() -> AttributeDict: + params = AttributeDict( + { + "name": "ljspeech", + "train_filelist_path": "./filelists/ljs_audio_text_train_filelist.txt", + "valid_filelist_path": "./filelists/ljs_audio_text_val_filelist.txt", + "batch_size": 32, + "num_workers": 3, + "pin_memory": False, + "cleaners": ["english_cleaners2"], + "add_blank": True, + "n_spks": 1, + "n_fft": 1024, + "n_feats": 80, + "sample_rate": 22050, + "hop_length": 256, + "win_length": 1024, + "f_min": 0, + "f_max": 8000, + "seed": 1234, + "load_durations": False, + "data_statistics": AttributeDict( + { + "mel_mean": -5.517028331756592, + "mel_std": 2.0643954277038574, + } + ), + } + ) + return params + + +def _get_model_params() -> AttributeDict: + n_feats = 80 + filter_channels_dp = 256 + encoder_params_p_dropout = 0.1 + params = AttributeDict( + { + "n_vocab": 178, + "n_spks": 1, # for ljspeech. + "spk_emb_dim": 64, + "n_feats": n_feats, + "out_size": None, # or use 172 + "prior_loss": True, + "use_precomputed_durations": False, + "encoder": AttributeDict( + { + "encoder_type": "RoPE Encoder", # not used + "encoder_params": AttributeDict( + { + "n_feats": n_feats, + "n_channels": 192, + "filter_channels": 768, + "filter_channels_dp": filter_channels_dp, + "n_heads": 2, + "n_layers": 6, + "kernel_size": 3, + "p_dropout": encoder_params_p_dropout, + "spk_emb_dim": 64, + "n_spks": 1, + "prenet": True, + } + ), + "duration_predictor_params": AttributeDict( + { + "filter_channels_dp": filter_channels_dp, + "kernel_size": 3, + "p_dropout": encoder_params_p_dropout, + } + ), + } + ), + "decoder": AttributeDict( + { + "channels": [256, 256], + "dropout": 0.05, + "attention_head_dim": 64, + "n_blocks": 1, + "num_mid_blocks": 2, + "num_heads": 2, + "act_fn": "snakebeta", + } + ), + "cfm": AttributeDict( + { + "name": "CFM", + "solver": "euler", + "sigma_min": 1e-4, + } + ), + "optimizer": AttributeDict( + { + "lr": 1e-4, + "weight_decay": 0.0, + } + ), + } + ) + + return params + + +def get_params(): + params = AttributeDict( + { + "model": _get_model_params(), + "data": _get_data_params(), + } + ) + return params + + +def get_model(params): + m = MatchaTTS(**params.model) + return m + + +def main(): + params = get_params() + + data_module = TextMelDataModule(hparams=params.data) + if False: + for b in data_module.train_dataloader(): + assert isinstance(b, dict) + # b.keys() + # ['x', 'x_lengths', 'y', 'y_lengths', 'spks', 'filepaths', 'x_texts', 'durations'] + # x: [batch_size, 289], torch.int64 + # x_lengths: [batch_size], torch.int64 + # y: [batch_size, n_feats, num_frames], torch.float32 + # y_lengths: [batch_size], torch.int64 + # spks: None + # filepaths: list, (batch_size,) + # x_texts: list, (batch_size,) + # durations: None + + m = get_model(params) + print(m) + + num_param = sum([p.numel() for p in m.parameters()]) + print(f"Number of parameters: {num_param}") + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/ljspeech/TTS/matcha/train.py b/egs/ljspeech/TTS/matcha/train.py index f41ee4eae1..385dcba23e 100755 --- a/egs/ljspeech/TTS/matcha/train.py +++ b/egs/ljspeech/TTS/matcha/train.py @@ -2,12 +2,111 @@ # Copyright 2023 Xiaomi Corp. 
(authors: Fangjun Kuang) +import argparse +import logging +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Union + import torch +import torch.nn as nn +from lhotse.utils import fix_random_seed +from matcha.data.text_mel_datamodule import TextMelDataModule +from icefall.env import get_env_info +from matcha.models.matcha_tts import MatchaTTS +from torch.cuda.amp import GradScaler, autocast +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim import Optimizer +from torch.utils.tensorboard import SummaryWriter +from utils2 import MetricsTracker, plot_feature +from icefall.checkpoint import load_checkpoint, save_checkpoint +from icefall.dist import cleanup_dist, setup_dist +from icefall.utils import AttributeDict, setup_logger, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=1000, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=Path, + default="matcha/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=10, + help="""Save checkpoint after processing this number of epochs" + periodically. We save checkpoint to exp-dir/ whenever + params.cur_epoch % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/epoch-{params.cur_epoch}.pt'. + Since it will take around 1000 epochs, we suggest using a large + save_every_n to save disk space. 
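To make the checkpoint-related options above concrete, here is a small sketch of the naming convention they describe: resuming loads exp-dir/epoch-{start_epoch-1}.pt, and a new epoch-{cur_epoch}.pt is written whenever the epoch index is divisible by --save-every-n or at the final epoch. The helper names below are illustrative only, not functions in train.py.

from pathlib import Path
from typing import Optional

def resume_checkpoint(exp_dir: Path, start_epoch: int) -> Optional[Path]:
    # --start-epoch 1 means training from scratch; larger values resume.
    return exp_dir / f"epoch-{start_epoch - 1}.pt" if start_epoch > 1 else None

def should_save(cur_epoch: int, save_every_n: int, num_epochs: int) -> bool:
    return cur_epoch % save_every_n == 0 or cur_epoch == num_epochs

# e.g. --start-epoch 101 with the default --exp-dir loads matcha/exp/epoch-100.pt.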
+ """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--batch-size", + type=int, + default=32, + ) + + return parser -from icefall.utils import AttributeDict -from matcha.models.matcha_tts import MatchaTTS -from matcha.data.text_mel_datamodule import TextMelDataModule + +def get_data_statistics(): + return AttributeDict( + { + "mel_mean": -5.517028331756592, + "mel_std": 2.0643954277038574, + } + ) def _get_data_params() -> AttributeDict: @@ -16,7 +115,6 @@ def _get_data_params() -> AttributeDict: "name": "ljspeech", "train_filelist_path": "./filelists/ljs_audio_text_train_filelist.txt", "valid_filelist_path": "./filelists/ljs_audio_text_val_filelist.txt", - "batch_size": 32, "num_workers": 3, "pin_memory": False, "cleaners": ["english_cleaners2"], @@ -31,12 +129,7 @@ def _get_data_params() -> AttributeDict: "f_max": 8000, "seed": 1234, "load_durations": False, - "data_statistics": AttributeDict( - { - "mel_mean": -5.517028331756592, - "mel_std": 2.0643954277038574, - } - ), + "data_statistics": get_data_statistics(), } ) return params @@ -55,6 +148,7 @@ def _get_model_params() -> AttributeDict: "out_size": None, # or use 172 "prior_loss": True, "use_precomputed_durations": False, + "data_statistics": get_data_statistics(), "encoder": AttributeDict( { "encoder_type": "RoPE Encoder", # not used @@ -115,42 +209,368 @@ def _get_model_params() -> AttributeDict: def get_params(): params = AttributeDict( { - "model": _get_model_params(), - "data": _get_data_params(), + "model_args": _get_model_params(), + "data_args": _get_data_params(), + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": -1, # 0 + "log_interval": 50, + "valid_interval": 2000, + "env_info": get_env_info(), } ) return params def get_model(params): - m = MatchaTTS(**params.model) + m = MatchaTTS(**params.model_args) return m +def load_checkpoint_if_available( + params: AttributeDict, model: nn.Module +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" 
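The data_statistics above are global CMVN-style statistics of the training mel features; MatchaTTS keeps them as the mel_mean/mel_std buffers so that ground-truth features can be normalized before computing losses and generated mels can be mapped back to the original scale before vocoding. A minimal sketch of that convention, assuming scalar statistics as used in this recipe (the helper names are illustrative):

import torch

def normalize_mel(mel: torch.Tensor, mel_mean: float, mel_std: float) -> torch.Tensor:
    # Applied to ground-truth features before training.
    return (mel - mel_mean) / mel_std

def denormalize_mel(mel: torch.Tensor, mel_mean: float, mel_std: float) -> torch.Tensor:
    # Inverse mapping, applied to generated features before the vocoder.
    return mel * mel_std + mel_mean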
+ + saved_params = load_checkpoint(filename, model=model) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + return saved_params + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, + rank: int = 0, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + + # used to summary the stats over iterations + tot_loss = MetricsTracker() + + with torch.no_grad(): + for batch_idx, batch in enumerate(valid_dl): + for key, value in batch.items(): + if isinstance(value, torch.Tensor): + batch[key] = value.to(device) + losses = model.get_losses(batch) + loss = sum(losses.values()) + + batch_size = batch["x"].shape[0] + + loss_info = MetricsTracker() + loss_info["samples"] = batch_size + + s = 0 + + for key, value in losses.items(): + v = value.detach().item() + loss_info[key] = v * batch_size + s += v * batch_size + + loss_info["tot_loss"] = s + + # summary stats + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(device) + + loss_value = tot_loss["tot_loss"] / tot_loss["samples"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: Optimizer, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + tb_writer: + Writer to write log messages to tensorboard. 
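The validation loop above accumulates a sample-weighted average: each per-batch loss is multiplied by the batch size when added to the tracker, and the final number divides the accumulated total by the accumulated sample count. A toy sketch with a plain dict standing in for MetricsTracker:

def accumulate(tot: dict, losses: dict, batch_size: int) -> dict:
    tot["samples"] = tot.get("samples", 0) + batch_size
    batch_total = 0.0
    for name, value in losses.items():
        v = float(value)
        tot[name] = tot.get(name, 0.0) + v * batch_size
        batch_total += v * batch_size
    tot["tot_loss"] = tot.get("tot_loss", 0.0) + batch_total
    return tot

# After the loop: avg_loss = tot["tot_loss"] / tot["samples"], which is the value
# compared against best_valid_loss.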
+ """ + model.train() + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + + # used to track the stats over iterations in one epoch + tot_loss = MetricsTracker() + + saved_bad_model = False + + # used to track the stats over iterations in one epoch + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + params=params, + optimizer=optimizer, + scaler=scaler, + rank=rank, + ) + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + for key, value in batch.items(): + if isinstance(value, torch.Tensor): + batch[key] = value.to(device) + + batch_size = batch["x"].shape[0] + + try: + with autocast(enabled=params.use_fp16): + losses = model.get_losses(batch) + + loss = sum(losses.values()) + + optimizer.zero_grad() + scaler.scale(loss).backward() + scaler.step(optimizer) + + loss_info = MetricsTracker() + loss_info["samples"] = batch_size + + s = 0 + + for key, value in losses.items(): + v = value.detach().item() + loss_info[key] = v * batch_size + s += v * batch_size + + loss_info["tot_loss"] = s + + tot_loss = tot_loss + loss_info + except: # noqa + save_bad_model() + raise + + if params.batch_idx_train % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or ( + cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0 + ): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if params.batch_idx_train % params.log_interval == 0: + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, batch {batch_idx}, " + f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, " + f"loss[{loss_info}], tot_loss[{tot_loss}], " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if params.batch_idx_train % params.valid_interval == 1: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + valid_dl=valid_dl, + world_size=world_size, + rank=rank, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["tot_loss"] / tot_loss["samples"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + def main(): + parser = 
get_parser() + args = parser.parse_args() params = get_params() - data_module = TextMelDataModule(hparams=params.data) - if False: - for b in data_module.train_dataloader(): - assert isinstance(b, dict) - # b.keys() - # ['x', 'x_lengths', 'y', 'y_lengths', 'spks', 'filepaths', 'x_texts', 'durations'] - # x: [batch_size, 289], torch.int64 - # x_lengths: [batch_size], torch.int64 - # y: [batch_size, n_feats, num_frames], torch.float32 - # y_lengths: [batch_size], torch.int64 - # spks: None - # filepaths: list, (batch_size,) - # x_texts: list, (batch_size,) - # durations: None - - m = get_model(params) - print(m) - - num_param = sum([p.numel() for p in m.parameters()]) + params.update(vars(args)) + + params.data_args.batch_size = params.batch_size + del params.batch_size + + fix_random_seed(params.seed) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + logging.info(f"Device: {device}") + print(f"Device: {device}") + print(f"Device: {device}") + + logging.info(params) + print(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of parameters: {num_param}") print(f"Number of parameters: {num_param}") + logging.info("About to create datamodule") + data_module = TextMelDataModule(hparams=params.data_args) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available(params=params, model=model) + + model.to(device) + optimizer = torch.optim.Adam(model.parameters(), **params.model_args.optimizer) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + train_dl = data_module.train_dataloader() + valid_dl = data_module.val_dataloader() + + rank = 0 + + for epoch in range(params.start_epoch, params.num_epochs + 1): + logging.info(f"Start epoch {epoch}") + fix_random_seed(params.seed + epoch - 1) + + params.cur_epoch = epoch + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + train_one_epoch( + params=params, + model=model, + optimizer=optimizer, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + ) + + if epoch % params.save_every_n == 0 or epoch == params.num_epochs: + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint( + filename=filename, + params=params, + model=model, + optimizer=optimizer, + scaler=scaler, + rank=rank, + ) + if rank == 0: + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + logging.info("Done!") + torch.set_num_threads(1) torch.set_num_interop_threads(1) diff --git a/egs/ljspeech/TTS/matcha/utils2.py b/egs/ljspeech/TTS/matcha/utils2.py new file mode 120000 index 0000000000..c2144f8e07 --- /dev/null +++ b/egs/ljspeech/TTS/matcha/utils2.py @@ -0,0 +1 @@ +../vits/utils.py \ No newline at end of file From 7077b4f99aef2bbdbf2dd873a837a05157c00f25 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang 
Date: Fri, 18 Oct 2024 22:14:14 +0800 Subject: [PATCH 07/27] switch to piper-phonemize --- egs/ljspeech/TTS/.gitignore | 4 + .../TTS/local/compute_fbank_ljspeech.py | 141 ++++++++ .../TTS/local/prepare_tokens_ljspeech.py | 28 +- egs/ljspeech/TTS/matcha/models/matcha_tts.py | 9 +- egs/ljspeech/TTS/matcha/tokenizer.py | 1 + egs/ljspeech/TTS/matcha/train.py | 202 ++++++++--- egs/ljspeech/TTS/matcha/tts_datamodule.py | 341 ++++++++++++++++++ egs/ljspeech/TTS/matcha/utils/__init__.py | 1 + egs/ljspeech/TTS/prepare.sh | 79 +++- 9 files changed, 746 insertions(+), 60 deletions(-) create mode 100644 egs/ljspeech/TTS/.gitignore create mode 100755 egs/ljspeech/TTS/local/compute_fbank_ljspeech.py create mode 120000 egs/ljspeech/TTS/matcha/tokenizer.py create mode 100644 egs/ljspeech/TTS/matcha/tts_datamodule.py diff --git a/egs/ljspeech/TTS/.gitignore b/egs/ljspeech/TTS/.gitignore new file mode 100644 index 0000000000..1eef06a289 --- /dev/null +++ b/egs/ljspeech/TTS/.gitignore @@ -0,0 +1,4 @@ +build +core.c +*.so +my-output* diff --git a/egs/ljspeech/TTS/local/compute_fbank_ljspeech.py b/egs/ljspeech/TTS/local/compute_fbank_ljspeech.py new file mode 100755 index 0000000000..3aeb6add7d --- /dev/null +++ b/egs/ljspeech/TTS/local/compute_fbank_ljspeech.py @@ -0,0 +1,141 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of the LJSpeech dataset. +It looks for manifests in the directory data/manifests. + +The generated fbank features are saved in data/fbank. +""" + +import argparse +import logging +import os +from pathlib import Path + +import torch +from lhotse import ( + CutSet, + Fbank, + FbankConfig, + LilcomChunkyWriter, + load_manifest, + load_manifest_lazy, +) +from lhotse.audio import RecordingSet +from lhotse.supervision import SupervisionSet + +from icefall.utils import get_executor + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--num-jobs", + type=int, + default=4, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. 
+ """, + ) + return parser + + +def compute_fbank_ljspeech(num_jobs: int): + src_dir = Path("data/manifests") + output_dir = Path("data/fbank") + + if num_jobs < 1: + num_jobs = os.cpu_count() + + logging.info(f"num_jobs: {num_jobs}") + logging.info(f"src_dir: {src_dir}") + logging.info(f"output_dir: {output_dir}") + + sampling_rate = 22050 + frame_length = 1024 / sampling_rate # (in second) + frame_shift = 256 / sampling_rate # (in second) + + prefix = "ljspeech" + suffix = "jsonl.gz" + partition = "all" + + recordings = load_manifest( + src_dir / f"{prefix}_recordings_{partition}.{suffix}", RecordingSet + ) + supervisions = load_manifest( + src_dir / f"{prefix}_supervisions_{partition}.{suffix}", SupervisionSet + ) + + # Differences with matcha-tts + # 1. we use pre-emphasis + # 2. we remove dc offset + # 3. we use a different window + # 4. we use a different mel filter bank matrix + # 5. we don't normalize features + config = FbankConfig( + sampling_rate=sampling_rate, + frame_length=frame_length, + frame_shift=frame_shift, + use_fft_mag=True, + low_freq=0, + high_freq=8000, + # should be identical to n_feats in ../matcha/train.py + num_filters=80, + ) + extractor = Fbank(config) + + with get_executor() as ex: # Initialize the executor only once. + cuts_filename = f"{prefix}_cuts_{partition}.{suffix}" + if (output_dir / cuts_filename).is_file(): + logging.info(f"{cuts_filename} already exists - skipping.") + return + logging.info(f"Processing {partition}") + cut_set = CutSet.from_manifests( + recordings=recordings, supervisions=supervisions + ) + + cut_set = cut_set.compute_and_store_features( + extractor=extractor, + storage_path=f"{output_dir}/{prefix}_feats_{partition}", + # when an executor is specified, make more partitions + num_jobs=num_jobs if ex is None else 80, + executor=ex, + storage_type=LilcomChunkyWriter, + ) + cut_set.to_file(output_dir / cuts_filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + args = get_parser().parse_args() + compute_fbank_ljspeech(args.num_jobs) diff --git a/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py b/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py index 4ba88604ce..33a8ac2ab7 100755 --- a/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py +++ b/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py @@ -28,17 +28,33 @@ except ModuleNotFoundError as ex: raise RuntimeError(f"{ex}\nPlease run\n pip install espnet_tts_frontend\n") +import argparse + from lhotse import CutSet, load_manifest from piper_phonemize import phonemize_espeak -def prepare_tokens_ljspeech(): - output_dir = Path("data/spectrogram") +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--in-out-dir", + type=Path, + required=True, + help="Input and output directory", + ) + + return parser + + +def prepare_tokens_ljspeech(in_out_dir): prefix = "ljspeech" suffix = "jsonl.gz" partition = "all" - cut_set = load_manifest(output_dir / f"{prefix}_cuts_{partition}.{suffix}") + cut_set = load_manifest(in_out_dir / f"{prefix}_cuts_{partition}.{suffix}") new_cuts = [] for cut in cut_set: @@ -56,11 +72,13 @@ def prepare_tokens_ljspeech(): new_cuts.append(cut) new_cut_set = CutSet.from_cuts(new_cuts) - new_cut_set.to_file(output_dir / f"{prefix}_cuts_with_tokens_{partition}.{suffix}") + new_cut_set.to_file(in_out_dir / f"{prefix}_cuts_with_tokens_{partition}.{suffix}") 
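For reference, the FbankConfig above implies the following frame geometry at the LJSpeech sampling rate; these numbers are simple arithmetic from the config values, not extra configuration:

sampling_rate = 22050
frame_length = 1024 / sampling_rate      # ~0.0464 s (46.4 ms) analysis window
frame_shift = 256 / sampling_rate        # ~0.0116 s (11.6 ms) hop
frames_per_second = sampling_rate / 256  # ~86.1 frames per second of audio

# So a 6.5 s LJSpeech clip yields roughly 6.5 * 86 ≈ 560 fbank frames, each with
# num_filters = 80 mel bins, matching n_feats in ../matcha/train.py.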
if __name__ == "__main__": formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) - prepare_tokens_ljspeech() + args = get_parser().parse_args() + + prepare_tokens_ljspeech(args.in_out_dir) diff --git a/egs/ljspeech/TTS/matcha/models/matcha_tts.py b/egs/ljspeech/TTS/matcha/models/matcha_tts.py index d5d78c6196..b1525695f2 100644 --- a/egs/ljspeech/TTS/matcha/models/matcha_tts.py +++ b/egs/ljspeech/TTS/matcha/models/matcha_tts.py @@ -71,9 +71,12 @@ def __init__( spk_emb_dim=spk_emb_dim, ) - # self.update_data_statistics(data_statistics) - self.register_buffer("mel_mean", torch.tensor(data_statistics["mel_mean"])) - self.register_buffer("mel_std", torch.tensor(data_statistics["mel_std"])) + if data_statistics is not None: + self.register_buffer("mel_mean", torch.tensor(data_statistics["mel_mean"])) + self.register_buffer("mel_std", torch.tensor(data_statistics["mel_std"])) + else: + self.register_buffer("mel_mean", torch.tensor(0.0)) + self.register_buffer("mel_std", torch.tensor(1.0)) @torch.inference_mode() def synthesise( diff --git a/egs/ljspeech/TTS/matcha/tokenizer.py b/egs/ljspeech/TTS/matcha/tokenizer.py new file mode 120000 index 0000000000..44a19b0f4a --- /dev/null +++ b/egs/ljspeech/TTS/matcha/tokenizer.py @@ -0,0 +1 @@ +../vits/tokenizer.py \ No newline at end of file diff --git a/egs/ljspeech/TTS/matcha/train.py b/egs/ljspeech/TTS/matcha/train.py index 385dcba23e..94e089d7eb 100755 --- a/egs/ljspeech/TTS/matcha/train.py +++ b/egs/ljspeech/TTS/matcha/train.py @@ -8,20 +8,24 @@ from shutil import copyfile from typing import Any, Dict, Optional, Union +import k2 import torch +import torch.multiprocessing as mp import torch.nn as nn from lhotse.utils import fix_random_seed -from matcha.data.text_mel_datamodule import TextMelDataModule -from icefall.env import get_env_info from matcha.models.matcha_tts import MatchaTTS +from matcha.tokenizer import Tokenizer +from matcha.utils.model import fix_len_compatibility from torch.cuda.amp import GradScaler, autocast from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Optimizer from torch.utils.tensorboard import SummaryWriter -from utils2 import MetricsTracker, plot_feature +from tts_datamodule import LJSpeechTtsDataModule +from utils2 import MetricsTracker from icefall.checkpoint import load_checkpoint, save_checkpoint from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info from icefall.utils import AttributeDict, setup_logger, str2bool @@ -30,6 +34,20 @@ def get_parser(): formatter_class=argparse.ArgumentDefaultsHelpFormatter ) + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12335, + help="Master port to use for DDP training.", + ) + parser.add_argument( "--tensorboard", type=str2bool, @@ -64,6 +82,13 @@ def get_parser(): """, ) + parser.add_argument( + "--tokens", + type=str, + default="data/tokens.txt", + help="""Path to vocabulary.""", + ) + parser.add_argument( "--seed", type=int, @@ -91,20 +116,14 @@ def get_parser(): help="Whether to use half precision training.", ) - parser.add_argument( - "--batch-size", - type=int, - default=32, - ) - return parser def get_data_statistics(): return AttributeDict( { - "mel_mean": -5.517028331756592, - "mel_std": 2.0643954277038574, + "mel_mean": 0.0, + "mel_std": 1.0, } ) @@ -141,7 +160,6 @@ def _get_model_params() -> 
AttributeDict: encoder_params_p_dropout = 0.1 params = AttributeDict( { - "n_vocab": 178, "n_spks": 1, # for ljspeech. "spk_emb_dim": 64, "n_feats": n_feats, @@ -216,8 +234,8 @@ def get_params(): "best_train_epoch": -1, "best_valid_epoch": -1, "batch_idx_train": -1, # 0 - "log_interval": 50, - "valid_interval": 2000, + "log_interval": 10, + "valid_interval": 1500, "env_info": get_env_info(), } ) @@ -271,9 +289,39 @@ def load_checkpoint_if_available( return saved_params +def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device): + """Parse batch data""" + audio = batch["audio"].to(device) + features = batch["features"].to(device) + audio_lens = batch["audio_lens"].to(device) + features_lens = batch["features_lens"].to(device) + tokens = batch["tokens"] + + tokens = tokenizer.tokens_to_token_ids( + tokens, intersperse_blank=True, add_sos=True, add_eos=True + ) + tokens = k2.RaggedTensor(tokens) + row_splits = tokens.shape.row_splits(1) + tokens_lens = row_splits[1:] - row_splits[:-1] + tokens = tokens.to(device) + tokens_lens = tokens_lens.to(device) + # a tensor of shape (B, T) + tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id) + + max_feature_length = fix_len_compatibility(features.shape[1]) + if max_feature_length > features.shape[1]: + pad = max_feature_length - features.shape[1] + features = torch.nn.functional.pad(features, (0, 0, 0, pad)) + + # features_lens[features_lens.argmax()] += pad + + return audio, audio_lens, features, features_lens, tokens, tokens_lens + + def compute_validation_loss( params: AttributeDict, model: Union[nn.Module, DDP], + tokenizer: Tokenizer, valid_dl: torch.utils.data.DataLoader, world_size: int = 1, rank: int = 0, @@ -281,19 +329,35 @@ def compute_validation_loss( """Run the validation process.""" model.eval() device = model.device if isinstance(model, DDP) else next(model.parameters()).device + get_losses = model.module.get_losses if isinstance(model, DDP) else model.get_losses # used to summary the stats over iterations tot_loss = MetricsTracker() with torch.no_grad(): for batch_idx, batch in enumerate(valid_dl): - for key, value in batch.items(): - if isinstance(value, torch.Tensor): - batch[key] = value.to(device) - losses = model.get_losses(batch) - loss = sum(losses.values()) - batch_size = batch["x"].shape[0] + ( + audio, + audio_lens, + features, + features_lens, + tokens, + tokens_lens, + ) = prepare_input(batch, tokenizer, device) + + losses = get_losses( + { + "x": tokens, + "x_lengths": tokens_lens, + "y": features.permute(0, 2, 1), + "y_lengths": features_lens, + "spks": None, # should change it for multi-speakers + "durations": None, + } + ) + + batch_size = len(batch["tokens"]) loss_info = MetricsTracker() loss_info["samples"] = batch_size @@ -324,6 +388,7 @@ def compute_validation_loss( def train_one_epoch( params: AttributeDict, model: Union[nn.Module, DDP], + tokenizer: Tokenizer, optimizer: Optimizer, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, @@ -356,6 +421,7 @@ def train_one_epoch( """ model.train() device = model.device if isinstance(model, DDP) else next(model.parameters()).device + get_losses = model.module.get_losses if isinstance(model, DDP) else model.get_losses # used to track the stats over iterations in one epoch tot_loss = MetricsTracker() @@ -374,20 +440,35 @@ def save_bad_model(suffix: str = ""): params=params, optimizer=optimizer, scaler=scaler, - rank=rank, + rank=0, ) for batch_idx, batch in enumerate(train_dl): params.batch_idx_train += 1 - for 
key, value in batch.items(): - if isinstance(value, torch.Tensor): - batch[key] = value.to(device) + # audio: (N, T), float32 + # features: (N, T, C), float32 + # audio_lens, (N,), int32 + # features_lens, (N,), int32 + # tokens: List[List[str]], len(tokens) == N - batch_size = batch["x"].shape[0] + batch_size = len(batch["tokens"]) + + audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input( + batch, tokenizer, device + ) try: with autocast(enabled=params.use_fp16): - losses = model.get_losses(batch) + losses = get_losses( + { + "x": tokens, + "x_lengths": tokens_lens, + "y": features.permute(0, 2, 1), + "y_lengths": features_lens, + "spks": None, # should change it for multi-speakers + "durations": None, + } + ) loss = sum(losses.values()) @@ -458,6 +539,7 @@ def save_bad_model(suffix: str = ""): valid_info = compute_validation_loss( params=params, model=model, + tokenizer=tokenizer, valid_dl=valid_dl, world_size=world_size, rank=rank, @@ -479,28 +561,31 @@ def save_bad_model(suffix: str = ""): params.best_train_loss = params.train_loss -def main(): - parser = get_parser() - args = parser.parse_args() +def run(rank, world_size, args): params = get_params() - params.update(vars(args)) - params.data_args.batch_size = params.batch_size - del params.batch_size - fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) setup_logger(f"{params.exp_dir}/log/log-train") logging.info("Training started") - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None device = torch.device("cpu") if torch.cuda.is_available(): - device = torch.device("cuda", 0) + device = torch.device("cuda", rank) logging.info(f"Device: {device}") - print(f"Device: {device}") - print(f"Device: {device}") + + tokenizer = Tokenizer(params.tokens) + params.blank_id = tokenizer.pad_id + params.vocab_size = tokenizer.vocab_size + params.model_args.n_vocab = params.vocab_size logging.info(params) print(params) @@ -512,28 +597,35 @@ def main(): logging.info(f"Number of parameters: {num_param}") print(f"Number of parameters: {num_param}") - logging.info("About to create datamodule") - data_module = TextMelDataModule(hparams=params.data_args) - assert params.start_epoch > 0, params.start_epoch checkpoints = load_checkpoint_if_available(params=params, model=model) model.to(device) + + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + optimizer = torch.optim.Adam(model.parameters(), **params.model_args.optimizer) + logging.info("About to create datamodule") + ljspeech = LJSpeechTtsDataModule(args) + + train_cuts = ljspeech.train_cuts() + train_dl = ljspeech.train_dataloaders(train_cuts) + + valid_cuts = ljspeech.valid_cuts() + valid_dl = ljspeech.valid_dataloaders(valid_cuts) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) - train_dl = data_module.train_dataloader() - valid_dl = data_module.val_dataloader() - - rank = 0 - for epoch in range(params.start_epoch, params.num_epochs + 1): logging.info(f"Start epoch {epoch}") fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) params.cur_epoch = epoch @@ -543,11 +635,14 @@ def main(): train_one_epoch( params=params, 
model=model, + tokenizer=tokenizer, optimizer=optimizer, train_dl=train_dl, valid_dl=valid_dl, scaler=scaler, tb_writer=tb_writer, + world_size=world_size, + rank=rank, ) if epoch % params.save_every_n == 0 or epoch == params.num_epochs: @@ -571,6 +666,23 @@ def main(): logging.info("Done!") + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def main(): + parser = get_parser() + LJSpeechTtsDataModule.add_arguments(parser) + args = parser.parse_args() + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + torch.set_num_threads(1) torch.set_num_interop_threads(1) diff --git a/egs/ljspeech/TTS/matcha/tts_datamodule.py b/egs/ljspeech/TTS/matcha/tts_datamodule.py new file mode 100644 index 0000000000..c2be815d9e --- /dev/null +++ b/egs/ljspeech/TTS/matcha/tts_datamodule.py @@ -0,0 +1,341 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022-2023 Xiaomi Corporation (Authors: Mingshuang Luo, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from lhotse import CutSet, Fbank, FbankConfig, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + PrecomputedFeatures, + SimpleCutSampler, + SpecAugment, + SpeechSynthesisDataset, +) +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class LJSpeechTtsDataModule: + """ + DataModule for tts experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. 
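As a usage sketch (mirroring how run() in train.py drives this class), the datamodule registers its CLI options, is constructed from the parsed args, and then hands back dataloaders built from the cut manifests; the option values shown here are illustrative:

import argparse

parser = argparse.ArgumentParser()
LJSpeechTtsDataModule.add_arguments(parser)
args = parser.parse_args(["--manifest-dir", "data/fbank", "--max-duration", "200"])

ljspeech = LJSpeechTtsDataModule(args)
train_dl = ljspeech.train_dataloaders(ljspeech.train_cuts())
valid_dl = ljspeech.valid_dataloaders(ljspeech.valid_cuts())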
+ """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="TTS data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=False, + help="When enabled, each batch will have the " + "field: batch['cut'] with the cuts that " + "were used to construct it.", + ) + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. 
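The --max-duration option above controls dynamic batching: instead of a fixed batch size, cuts are grouped so that the total audio duration per batch stays under the budget. The toy function below illustrates the idea only; the real grouping is done by lhotse's DynamicBucketingSampler, which additionally buckets cuts by length before packing them.

def pack_by_duration(durations, max_duration=200.0):
    batches, cur, cur_dur = [], [], 0.0
    for d in durations:
        if cur and cur_dur + d > max_duration:
            batches.append(cur)
            cur, cur_dur = [], 0.0
        cur.append(d)
        cur_dur += d
    if cur:
        batches.append(cur)
    return batches

# With ~6.5 s LJSpeech utterances and the default 200 s budget, each batch packs
# roughly 30 cuts; shorter utterances yield larger batches and vice versa.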
+ """ + logging.info("About to create train dataset") + train = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + sampling_rate = 22050 + config = FbankConfig( + sampling_rate=sampling_rate, + frame_length=1024 / sampling_rate, # (in second), + frame_shift=256 / sampling_rate, # (in second) + use_fft_mag=True, + low_freq=0, + high_freq=8000, + # should be identical to n_feats in ./train.py + num_filters=80, + ) + train = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=OnTheFlyFeatures(Fbank(config)), + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + sampling_rate = 22050 + config = FbankConfig( + sampling_rate=sampling_rate, + frame_length=1024 / sampling_rate, # (in second), + frame_shift=256 / sampling_rate, # (in second) + use_fft_mag=True, + low_freq=0, + high_freq=8000, + # should be identical to n_feats in ./train.py + num_filters=80, + ) + validate = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=OnTheFlyFeatures(Fbank(config)), + return_cuts=self.args.return_cuts, + ) + else: + validate = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + num_buckets=self.args.num_buckets, + shuffle=False, + ) + logging.info("About to create valid dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.info("About to create test dataset") + if self.args.on_the_fly_feats: + sampling_rate = 22050 + config = FbankConfig( + sampling_rate=sampling_rate, + frame_length=1024 / sampling_rate, # (in second), + frame_shift=256 / sampling_rate, # (in second) + use_fft_mag=True, + low_freq=0, + high_freq=8000, + # should be identical to n_feats in ./train.py + num_filters=80, + ) + test = SpeechSynthesisDataset( + return_text=False, + 
return_tokens=True, + feature_input_strategy=OnTheFlyFeatures(Fbank(config)), + return_cuts=self.args.return_cuts, + ) + else: + test = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + test_sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + num_buckets=self.args.num_buckets, + shuffle=False, + ) + logging.info("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=test_sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info("About to get train cuts") + return load_manifest_lazy( + self.args.manifest_dir / "ljspeech_cuts_train.jsonl.gz" + ) + + @lru_cache() + def valid_cuts(self) -> CutSet: + logging.info("About to get validation cuts") + return load_manifest_lazy( + self.args.manifest_dir / "ljspeech_cuts_valid.jsonl.gz" + ) + + @lru_cache() + def test_cuts(self) -> CutSet: + logging.info("About to get test cuts") + return load_manifest_lazy( + self.args.manifest_dir / "ljspeech_cuts_test.jsonl.gz" + ) diff --git a/egs/ljspeech/TTS/matcha/utils/__init__.py b/egs/ljspeech/TTS/matcha/utils/__init__.py index 2b74b40f50..311744a786 100644 --- a/egs/ljspeech/TTS/matcha/utils/__init__.py +++ b/egs/ljspeech/TTS/matcha/utils/__init__.py @@ -3,3 +3,4 @@ # from matcha.utils.pylogger import get_pylogger # from matcha.utils.rich_utils import enforce_tags, print_config_tree # from matcha.utils.utils import extras, get_metric_value, task_wrapper +from matcha.utils.utils import intersperse diff --git a/egs/ljspeech/TTS/prepare.sh b/egs/ljspeech/TTS/prepare.sh index 9ed0f93fde..e1cd0897e9 100755 --- a/egs/ljspeech/TTS/prepare.sh +++ b/egs/ljspeech/TTS/prepare.sh @@ -5,7 +5,7 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python set -eou pipefail -stage=0 +stage=-1 stop_stage=100 dl_dir=$PWD/download @@ -31,7 +31,19 @@ if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then python3 setup.py build_ext --inplace cd ../../ else - log "monotonic_align lib already built" + log "monotonic_align lib for vits already built" + fi + + if [ ! -f ./matcha/utils/monotonic_align/core.cpython-38-x86_64-linux-gnu.so ]; then + pushd matcha/utils/monotonic_align + python3 setup.py build_ext --inplace + mv -v matcha/utils/monotonic_align/core.cpython-38-x86_64-linux-gnu.so ./ + rm -rf matcha + rm -rf build + rm core.c + popd + else + log "monotonic_align lib for matcha-tts already built" fi fi @@ -63,7 +75,7 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then fi if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then - log "Stage 2: Compute spectrogram for LJSpeech" + log "Stage 2: Compute spectrogram for LJSpeech (used by ./vits)" mkdir -p data/spectrogram if [ ! -e data/spectrogram/.ljspeech.done ]; then ./local/compute_spectrogram_ljspeech.py @@ -71,7 +83,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then fi if [ ! 
-e data/spectrogram/.ljspeech-validated.done ]; then - log "Validating data/spectrogram for LJSpeech" + log "Validating data/spectrogram for LJSpeech (used by ./vits)" python3 ./local/validate_manifest.py \ data/spectrogram/ljspeech_cuts_all.jsonl.gz touch data/spectrogram/.ljspeech-validated.done @@ -79,13 +91,13 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then fi if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - log "Stage 3: Prepare phoneme tokens for LJSpeech" + log "Stage 3: Prepare phoneme tokens for LJSpeech (used by ./vits)" # We assume you have installed piper_phonemize and espnet_tts_frontend. # If not, please install them with: # - piper_phonemize: pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html, # - espnet_tts_frontend, `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/ if [ ! -e data/spectrogram/.ljspeech_with_token.done ]; then - ./local/prepare_tokens_ljspeech.py + ./local/prepare_tokens_ljspeech.py --in-out-dir ./data/spectrogram mv data/spectrogram/ljspeech_cuts_with_tokens_all.jsonl.gz \ data/spectrogram/ljspeech_cuts_all.jsonl.gz touch data/spectrogram/.ljspeech_with_token.done @@ -93,7 +105,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then fi if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - log "Stage 4: Split the LJSpeech cuts into train, valid and test sets" + log "Stage 4: Split the LJSpeech cuts into train, valid and test sets (used by vits)" if [ ! -e data/spectrogram/.ljspeech_split.done ]; then lhotse subset --last 600 \ data/spectrogram/ljspeech_cuts_all.jsonl.gz \ @@ -126,3 +138,56 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then ./local/prepare_token_file.py --tokens data/tokens.txt fi fi + +if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then + log "Stage 6: Generate fbank (used by ./matcha)" + mkdir -p data/fbank + if [ ! -e data/fbank/.ljspeech.done ]; then + ./local/compute_fbank_ljspeech.py + touch data/fbank/.ljspeech.done + fi + + if [ ! -e data/fbank/.ljspeech-validated.done ]; then + log "Validating data/fbank for LJSpeech (used by ./matcha)" + python3 ./local/validate_manifest.py \ + data/fbank/ljspeech_cuts_all.jsonl.gz + touch data/fbank/.ljspeech-validated.done + fi +fi + +if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then + log "Stage 7: Prepare phoneme tokens for LJSpeech (used by ./matcha)" + # We assume you have installed piper_phonemize and espnet_tts_frontend. + # If not, please install them with: + # - piper_phonemize: pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html, + # - espnet_tts_frontend, `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/ + if [ ! -e data/fbank/.ljspeech_with_token.done ]; then + ./local/prepare_tokens_ljspeech.py --in-out-dir ./data/fbank + mv data/fbank/ljspeech_cuts_with_tokens_all.jsonl.gz \ + data/fbank/ljspeech_cuts_all.jsonl.gz + touch data/fbank/.ljspeech_with_token.done + fi +fi + +if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then + log "Stage 8: Split the LJSpeech cuts into train, valid and test sets (used by ./matcha)" + if [ ! 
-e data/fbank/.ljspeech_split.done ]; then + lhotse subset --last 600 \ + data/fbank/ljspeech_cuts_all.jsonl.gz \ + data/fbank/ljspeech_cuts_validtest.jsonl.gz + lhotse subset --first 100 \ + data/fbank/ljspeech_cuts_validtest.jsonl.gz \ + data/fbank/ljspeech_cuts_valid.jsonl.gz + lhotse subset --last 500 \ + data/fbank/ljspeech_cuts_validtest.jsonl.gz \ + data/fbank/ljspeech_cuts_test.jsonl.gz + + rm data/fbank/ljspeech_cuts_validtest.jsonl.gz + + n=$(( $(gunzip -c data/fbank/ljspeech_cuts_all.jsonl.gz | wc -l) - 600 )) + lhotse subset --first $n \ + data/fbank/ljspeech_cuts_all.jsonl.gz \ + data/fbank/ljspeech_cuts_train.jsonl.gz + touch data/fbank/.ljspeech_split.done + fi +fi From 6a4cb112dd486c1140883c965d4119c6a685edcc Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sun, 20 Oct 2024 10:14:10 +0800 Subject: [PATCH 08/27] use CMVN --- .../TTS/local/compute_fbank_ljspeech.py | 2 + .../TTS/local/compute_fbank_statistics.py | 84 +++++++++++ egs/ljspeech/TTS/matcha/inference.py | 46 +++--- egs/ljspeech/TTS/matcha/train.py | 139 ++++++++++++------ egs/ljspeech/TTS/matcha/tts_datamodule.py | 6 + egs/ljspeech/TTS/prepare.sh | 7 + 6 files changed, 220 insertions(+), 64 deletions(-) create mode 100755 egs/ljspeech/TTS/local/compute_fbank_statistics.py diff --git a/egs/ljspeech/TTS/local/compute_fbank_ljspeech.py b/egs/ljspeech/TTS/local/compute_fbank_ljspeech.py index 3aeb6add7d..5c25c3cf4a 100755 --- a/egs/ljspeech/TTS/local/compute_fbank_ljspeech.py +++ b/egs/ljspeech/TTS/local/compute_fbank_ljspeech.py @@ -106,6 +106,8 @@ def compute_fbank_ljspeech(num_jobs: int): use_fft_mag=True, low_freq=0, high_freq=8000, + remove_dc_offset=False, + preemph_coeff=0, # should be identical to n_feats in ../matcha/train.py num_filters=80, ) diff --git a/egs/ljspeech/TTS/local/compute_fbank_statistics.py b/egs/ljspeech/TTS/local/compute_fbank_statistics.py new file mode 100755 index 0000000000..d0232c9832 --- /dev/null +++ b/egs/ljspeech/TTS/local/compute_fbank_statistics.py @@ -0,0 +1,84 @@ +#!/usr/bin/env python3 +# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script compute the mean and std of the fbank features. 
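The statistics are accumulated as running sums so the full feature matrix never has to be held in memory; the standard deviation then follows from the identity var(x) = E[x^2] - (E[x])^2. A NumPy stand-in for the same computation the script performs with torch tensors:

import numpy as np

def running_mean_std(feature_matrices):
    n, s, sq = 0, 0.0, 0.0
    for f in feature_matrices:       # each f has shape (num_frames, num_bins)
        n += f.size
        s += float(f.sum())
        sq += float(np.square(f).sum())
    mean = s / n
    std = (sq / n - mean * mean) ** 0.5
    return mean, std

# Sanity check: running_mean_std([np.array([[1.0, 3.0]])]) returns (2.0, 1.0).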
+""" + +import argparse +import json +import logging +from pathlib import Path + +import torch +from lhotse import CutSet, load_manifest_lazy + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "manifest", + type=Path, + help="Path to the manifest file", + ) + + parser.add_argument( + "cmvn", + type=Path, + help="Path to the cmvn.json", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + + manifest = args.manifest + logging.info( + f"Computing fbank mean and std for {manifest} and saving to {args.cmvn}" + ) + + assert manifest.is_file(), f"{manifest} does not exist" + cut_set = load_manifest_lazy(manifest) + assert isinstance(cut_set, CutSet), type(cut_set) + + feat_dim = cut_set[0].features.num_features + num_frames = 0 + s = 0 + sq = 0 + for c in cut_set: + f = torch.from_numpy(c.load_features()) + num_frames += f.shape[0] + s += f.sum() + sq += f.square().sum() + + fbank_mean = s / (num_frames * feat_dim) + fbank_var = sq / (num_frames * feat_dim) - fbank_mean * fbank_mean + print("fbank var", fbank_var) + fbank_std = fbank_var.sqrt() + with open(args.cmvn, "w") as f: + json.dump({"fbank_mean": fbank_mean.item(), "fbank_std": fbank_std.item()}, f) + f.write("\n") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/ljspeech/TTS/matcha/inference.py b/egs/ljspeech/TTS/matcha/inference.py index 29a0f53a83..45d73bf4fb 100755 --- a/egs/ljspeech/TTS/matcha/inference.py +++ b/egs/ljspeech/TTS/matcha/inference.py @@ -10,6 +10,7 @@ import torch from matcha.hifigan.config import v1 from matcha.hifigan.denoiser import Denoiser +from tokenizer import Tokenizer from matcha.hifigan.models import Generator as HiFiGAN from matcha.text import sequence_to_text, text_to_sequence from matcha.utils.utils import intersperse @@ -28,7 +29,7 @@ def get_parser(): parser.add_argument( "--epoch", type=int, - default=140, + default=1320, help="""It specifies the checkpoint to use for decoding. Note: Epoch counts from 1. """, @@ -37,13 +38,19 @@ def get_parser(): parser.add_argument( "--exp-dir", type=Path, - default="matcha/exp", + default="matcha/exp-fbank", help="""The experiment dir. 
It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved """, ) + parser.add_argument( + "--tokens", + type=Path, + default="data/tokens.txt", + ) + return parser @@ -71,19 +78,17 @@ def save_to_folder(filename: str, output: dict, folder: str): sf.write(folder / f"{filename}.wav", output["waveform"], 22050, "PCM_24") -def process_text(text: str): - x = torch.tensor( - intersperse(text_to_sequence(text, ["english_cleaners2"])[0], 0), - dtype=torch.long, - device="cpu", - )[None] +def process_text(text: str, tokenizer): + x = tokenizer.texts_to_token_ids([text], add_sos=True, add_eos=True) + x = torch.tensor(x, dtype=torch.long) x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device="cpu") - x_phones = sequence_to_text(x.squeeze(0).tolist()) - return {"x_orig": text, "x": x, "x_lengths": x_lengths, "x_phones": x_phones} + return {"x_orig": text, "x": x, "x_lengths": x_lengths} -def synthesise(model, n_timesteps, text, length_scale, temperature, spks=None): - text_processed = process_text(text) +def synthesise( + model, tokenizer, n_timesteps, text, length_scale, temperature, spks=None +): + text_processed = process_text(text, tokenizer) start_t = dt.datetime.now() output = model.synthesise( text_processed["x"], @@ -108,6 +113,11 @@ def main(): params.update(vars(args)) logging.info(params) + tokenizer = Tokenizer(params.tokens) + params.blank_id = tokenizer.pad_id + params.vocab_size = tokenizer.vocab_size + params.model_args.n_vocab = params.vocab_size + logging.info("About to create model") model = get_model(params) load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) @@ -117,12 +127,13 @@ def main(): denoiser = Denoiser(vocoder, mode="zeros") texts = [ - "The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.", - "Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar.", + "How are you doing, my friend", + # "The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.", + # "Today as always, men fall into two groups: slaves and free men. 
Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar.", ] # Number of ODE Solver steps - n_timesteps = 2 + n_timesteps = 3 # Changes to the speaking rate length_scale = 1.0 @@ -135,6 +146,7 @@ def main(): for i, text in enumerate(tqdm(texts)): output = synthesise( model=model, + tokenizer=tokenizer, n_timesteps=n_timesteps, text=text, length_scale=length_scale, @@ -154,7 +166,7 @@ def main(): print(f"{'*' * 53}") print(f"Phonetised text - {i}") print(f"{'-' * 53}") - print(output["x_phones"]) + print(output["x"]) print(f"{'*' * 53}") print(f"RTF:\t\t{output['rtf']:.6f}") print(f"RTF Waveform:\t{rtf_w:.6f}") @@ -162,7 +174,7 @@ def main(): rtfs_w.append(rtf_w) # Save the generated waveform - save_to_folder(i, output, folder="./my-output") + save_to_folder(i, output, folder="./my-output-1320") print(f"Number of ODE steps: {n_timesteps}") print(f"Mean RTF:\t\t\t\t{np.mean(rtfs):.6f} ± {np.std(rtfs):.6f}") diff --git a/egs/ljspeech/TTS/matcha/train.py b/egs/ljspeech/TTS/matcha/train.py index 94e089d7eb..edf7e1eef7 100755 --- a/egs/ljspeech/TTS/matcha/train.py +++ b/egs/ljspeech/TTS/matcha/train.py @@ -13,6 +13,7 @@ import torch.multiprocessing as mp import torch.nn as nn from lhotse.utils import fix_random_seed +from matcha.data.text_mel_datamodule import TextMelDataModule from matcha.models.matcha_tts import MatchaTTS from matcha.tokenizer import Tokenizer from matcha.utils.model import fix_len_compatibility @@ -122,8 +123,11 @@ def get_parser(): def get_data_statistics(): return AttributeDict( { - "mel_mean": 0.0, - "mel_std": 1.0, + # "mel_mean": -5.517028331756592, # matcha-tts + # "mel_std": 2.0643954277038574, + # ours + "mel_mean": -1.168782114982605, + "mel_std": 1.9283572435379028, } ) @@ -134,7 +138,8 @@ def _get_data_params() -> AttributeDict: "name": "ljspeech", "train_filelist_path": "./filelists/ljs_audio_text_train_filelist.txt", "valid_filelist_path": "./filelists/ljs_audio_text_val_filelist.txt", - "num_workers": 3, + "batch_size": 64, + "num_workers": 1, "pin_memory": False, "cleaners": ["english_cleaners2"], "add_blank": True, @@ -289,8 +294,17 @@ def load_checkpoint_if_available( return saved_params -def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device): +def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device, params): """Parse batch data""" + mel_mean = params.data_args.data_statistics.mel_mean + mel_std_inv = 1 / params.data_args.data_statistics.mel_std + for i in range(batch["features"].shape[0]): + n = batch["features_lens"][i] + batch["features"][i : i + 1, :n, :] = ( + batch["features"][i : i + 1, :n, :] - mel_mean + ) * mel_std_inv + batch["features"][i : i + 1, n:, :] = 0 + audio = batch["audio"].to(device) features = batch["features"].to(device) audio_lens = batch["audio_lens"].to(device) @@ -298,7 +312,7 @@ def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device): tokens = batch["tokens"] tokens = tokenizer.tokens_to_token_ids( - tokens, intersperse_blank=True, add_sos=True, add_eos=True + tokens, intersperse_blank=True, add_sos=False, add_eos=False ) tokens = k2.RaggedTensor(tokens) row_splits = tokens.shape.row_splits(1) @@ -315,7 +329,7 @@ def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device): # features_lens[features_lens.argmax()] += pad - return audio, audio_lens, features, features_lens, tokens, tokens_lens + return audio, audio_lens, features, features_lens.long(), tokens, 
tokens_lens.long() def compute_validation_loss( @@ -336,28 +350,36 @@ def compute_validation_loss( with torch.no_grad(): for batch_idx, batch in enumerate(valid_dl): + if "tokens" in batch: - ( - audio, - audio_lens, - features, - features_lens, - tokens, - tokens_lens, - ) = prepare_input(batch, tokenizer, device) + ( + audio, + audio_lens, + features, + features_lens, + tokens, + tokens_lens, + ) = prepare_input(batch, tokenizer, device, params) - losses = get_losses( - { - "x": tokens, - "x_lengths": tokens_lens, - "y": features.permute(0, 2, 1), - "y_lengths": features_lens, - "spks": None, # should change it for multi-speakers - "durations": None, - } - ) + losses = get_losses( + { + "x": tokens, + "x_lengths": tokens_lens, + "y": features.permute(0, 2, 1), + "y_lengths": features_lens, + "spks": None, # should change it for multi-speakers + "durations": None, + } + ) - batch_size = len(batch["tokens"]) + batch_size = len(batch["tokens"]) + else: + batch_size = batch["x"].shape[0] + batch["x"] = batch["x"].to(device) + batch["x_lengths"] = batch["x_lengths"].to(device) + batch["y"] = batch["y"].to(device) + batch["y_lengths"] = batch["y_lengths"].to(device) + losses = get_losses(batch) loss_info = MetricsTracker() loss_info["samples"] = batch_size @@ -451,24 +473,38 @@ def save_bad_model(suffix: str = ""): # features_lens, (N,), int32 # tokens: List[List[str]], len(tokens) == N - batch_size = len(batch["tokens"]) - - audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input( - batch, tokenizer, device - ) + if "tokens" in batch: + batch_size = len(batch["tokens"]) + ( + audio, + audio_lens, + features, + features_lens, + tokens, + tokens_lens, + ) = prepare_input(batch, tokenizer, device, params) + else: + batch_size = batch["x"].shape[0] try: with autocast(enabled=params.use_fp16): - losses = get_losses( - { - "x": tokens, - "x_lengths": tokens_lens, - "y": features.permute(0, 2, 1), - "y_lengths": features_lens, - "spks": None, # should change it for multi-speakers - "durations": None, - } - ) + if "tokens" in batch: + losses = get_losses( + { + "x": tokens, + "x_lengths": tokens_lens, + "y": features.permute(0, 2, 1), + "y_lengths": features_lens, + "spks": None, # should change it for multi-speakers + "durations": None, + } + ) + else: + batch["x"] = batch["x"].to(device) + batch["x_lengths"] = batch["x_lengths"].to(device) + batch["y"] = batch["y"].to(device) + batch["y_lengths"] = batch["y_lengths"].to(device) + losses = get_losses(batch) loss = sum(losses.values()) @@ -586,6 +622,7 @@ def run(rank, world_size, args): params.blank_id = tokenizer.pad_id params.vocab_size = tokenizer.vocab_size params.model_args.n_vocab = params.vocab_size + params.model_args.n_vocab = 178 logging.info(params) print(params) @@ -595,7 +632,6 @@ def run(rank, world_size, args): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of parameters: {num_param}") - print(f"Number of parameters: {num_param}") assert params.start_epoch > 0, params.start_epoch checkpoints = load_checkpoint_if_available(params=params, model=model) @@ -609,13 +645,21 @@ def run(rank, world_size, args): optimizer = torch.optim.Adam(model.parameters(), **params.model_args.optimizer) logging.info("About to create datamodule") - ljspeech = LJSpeechTtsDataModule(args) - train_cuts = ljspeech.train_cuts() - train_dl = ljspeech.train_dataloaders(train_cuts) + if False: + params.data_args.tokenizer = tokenizer + data_module = TextMelDataModule(hparams=params.data_args) + del 
params.data_args.tokenizer + train_dl = data_module.train_dataloader() + valid_dl = data_module.val_dataloader() + else: + ljspeech = LJSpeechTtsDataModule(args) + + train_cuts = ljspeech.train_cuts() + train_dl = ljspeech.train_dataloaders(train_cuts) - valid_cuts = ljspeech.valid_cuts() - valid_dl = ljspeech.valid_dataloaders(valid_cuts) + valid_cuts = ljspeech.valid_cuts() + valid_dl = ljspeech.valid_dataloaders(valid_cuts) scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: @@ -625,7 +669,8 @@ def run(rank, world_size, args): for epoch in range(params.start_epoch, params.num_epochs + 1): logging.info(f"Start epoch {epoch}") fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) + if "sampler" in train_dl: + train_dl.sampler.set_epoch(epoch - 1) params.cur_epoch = epoch diff --git a/egs/ljspeech/TTS/matcha/tts_datamodule.py b/egs/ljspeech/TTS/matcha/tts_datamodule.py index c2be815d9e..0fc16366e3 100644 --- a/egs/ljspeech/TTS/matcha/tts_datamodule.py +++ b/egs/ljspeech/TTS/matcha/tts_datamodule.py @@ -181,6 +181,8 @@ def train_dataloaders( frame_length=1024 / sampling_rate, # (in second), frame_shift=256 / sampling_rate, # (in second) use_fft_mag=True, + remove_dc_offset=False, + preemph_coeff=0, low_freq=0, high_freq=8000, # should be identical to n_feats in ./train.py @@ -242,6 +244,8 @@ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: frame_length=1024 / sampling_rate, # (in second), frame_shift=256 / sampling_rate, # (in second) use_fft_mag=True, + remove_dc_offset=False, + preemph_coeff=0, low_freq=0, high_freq=8000, # should be identical to n_feats in ./train.py @@ -286,6 +290,8 @@ def test_dataloaders(self, cuts: CutSet) -> DataLoader: frame_length=1024 / sampling_rate, # (in second), frame_shift=256 / sampling_rate, # (in second) use_fft_mag=True, + remove_dc_offset=False, + preemph_coeff=0, low_freq=0, high_freq=8000, # should be identical to n_feats in ./train.py diff --git a/egs/ljspeech/TTS/prepare.sh b/egs/ljspeech/TTS/prepare.sh index e1cd0897e9..b140e6f010 100755 --- a/egs/ljspeech/TTS/prepare.sh +++ b/egs/ljspeech/TTS/prepare.sh @@ -191,3 +191,10 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then touch data/fbank/.ljspeech_split.done fi fi + +if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then + log "Stage 9: Compute fbank mean and std (used by ./matcha)" + if [ ! 
-f ./data/fbank/cmvn.json ]; then + ./local/compute_fbank_statistics.py ./data/fbank/ljspeech_cuts_train.jsonl.gz ./data/fbank/cmvn.json + fi +fi From 748557febab5ebac4efca5fee2c587a87f7584da Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 21 Oct 2024 21:24:29 +0800 Subject: [PATCH 09/27] add onnx export --- .../TTS/matcha/compute_fbank_ljspeech.py | 1 + egs/ljspeech/TTS/matcha/export_onnx.py | 119 ++++++++++++++++++ egs/ljspeech/TTS/matcha/inference.py | 30 +++-- egs/ljspeech/TTS/matcha/onnx_pretrained.py | 84 +++++++++++++ egs/ljspeech/TTS/matcha/train.py | 34 +++-- egs/ljspeech/TTS/matcha/tts_datamodule.py | 66 +++++----- 6 files changed, 280 insertions(+), 54 deletions(-) create mode 120000 egs/ljspeech/TTS/matcha/compute_fbank_ljspeech.py create mode 100755 egs/ljspeech/TTS/matcha/export_onnx.py create mode 100755 egs/ljspeech/TTS/matcha/onnx_pretrained.py diff --git a/egs/ljspeech/TTS/matcha/compute_fbank_ljspeech.py b/egs/ljspeech/TTS/matcha/compute_fbank_ljspeech.py new file mode 120000 index 0000000000..85255ba0c0 --- /dev/null +++ b/egs/ljspeech/TTS/matcha/compute_fbank_ljspeech.py @@ -0,0 +1 @@ +../local/compute_fbank_ljspeech.py \ No newline at end of file diff --git a/egs/ljspeech/TTS/matcha/export_onnx.py b/egs/ljspeech/TTS/matcha/export_onnx.py new file mode 100755 index 0000000000..c56e2da894 --- /dev/null +++ b/egs/ljspeech/TTS/matcha/export_onnx.py @@ -0,0 +1,119 @@ +#!/usr/bin/env python3 + +import json +import logging + +import torch +from inference import get_parser +from tokenizer import Tokenizer +from train import get_model, get_params +from icefall.checkpoint import load_checkpoint +from onnxruntime.quantization import QuantType, quantize_dynamic + + +class ModelWrapper(torch.nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + + def forward( + self, + x: torch.Tensor, + x_lengths: torch.Tensor, + temperature: torch.Tensor, + length_scale: torch.Tensor, + ) -> torch.Tensor: + """ + Args: : + x: (batch_size, num_tokens), torch.int64 + x_lengths: (batch_size,), torch.int64 + temperature: (1,), torch.float32 + length_scale (1,), torch.float32 + Returns: + mel: (batch_size, feat_dim, num_frames) + + """ + mel = self.model.synthesise( + x=x, + x_lengths=x_lengths, + n_timesteps=3, + temperature=temperature, + length_scale=length_scale, + )["mel"] + + # mel: (batch_size, feat_dim, num_frames) + + return mel + + +@torch.inference_mode +def main(): + parser = get_parser() + args = parser.parse_args() + params = get_params() + + params.update(vars(args)) + + tokenizer = Tokenizer(params.tokens) + params.blank_id = tokenizer.pad_id + params.vocab_size = tokenizer.vocab_size + params.model_args.n_vocab = params.vocab_size + + with open(params.cmvn) as f: + stats = json.load(f) + params.data_args.data_statistics.mel_mean = stats["fbank_mean"] + params.data_args.data_statistics.mel_std = stats["fbank_std"] + + params.model_args.data_statistics.mel_mean = stats["fbank_mean"] + params.model_args.data_statistics.mel_std = stats["fbank_std"] + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + + wrapper = ModelWrapper(model) + wrapper.eval() + + # Use a large value so the the rotary position embedding in the text + # encoder has a large initial length + x = torch.ones(1, 2000, dtype=torch.int64) + x_lengths = torch.tensor([x.shape[1]], dtype=torch.int64) + temperature = torch.tensor([1.0]) + length_scale = torch.tensor([1.0]) + mel 
= wrapper(x, x_lengths, temperature, length_scale) + print("mel", mel.shape) + + opset_version = 14 + filename = "model.onnx" + torch.onnx.export( + wrapper, + (x, x_lengths, temperature, length_scale), + filename, + opset_version=opset_version, + input_names=["x", "x_length", "temperature", "length_scale"], + output_names=["mel"], + dynamic_axes={ + "x": {0: "N", 1: "L"}, + "x_length": {0: "N"}, + "mel": {0: "N", 2: "L"}, + }, + ) + + print("Generate int8 quantization models") + + filename_int8 = "model.int8.onnx" + quantize_dynamic( + model_input=filename, + model_output=filename_int8, + weight_type=QuantType.QInt8, + ) + + print(f"Saved to {filename} and {filename_int8}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/ljspeech/TTS/matcha/inference.py b/egs/ljspeech/TTS/matcha/inference.py index 45d73bf4fb..49c9c708aa 100755 --- a/egs/ljspeech/TTS/matcha/inference.py +++ b/egs/ljspeech/TTS/matcha/inference.py @@ -5,6 +5,7 @@ import logging from pathlib import Path +import json import numpy as np import soundfile as sf import torch @@ -29,7 +30,7 @@ def get_parser(): parser.add_argument( "--epoch", type=int, - default=1320, + default=2810, help="""It specifies the checkpoint to use for decoding. Note: Epoch counts from 1. """, @@ -38,7 +39,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=Path, - default="matcha/exp-fbank", + default="matcha/exp-new-3", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved @@ -51,6 +52,13 @@ def get_parser(): default="data/tokens.txt", ) + parser.add_argument( + "--cmvn", + type=str, + default="data/fbank/cmvn.json", + help="""Path to vocabulary.""", + ) + return parser @@ -111,13 +119,21 @@ def main(): params = get_params() params.update(vars(args)) - logging.info(params) tokenizer = Tokenizer(params.tokens) params.blank_id = tokenizer.pad_id params.vocab_size = tokenizer.vocab_size params.model_args.n_vocab = params.vocab_size + with open(params.cmvn) as f: + stats = json.load(f) + params.data_args.data_statistics.mel_mean = stats["fbank_mean"] + params.data_args.data_statistics.mel_std = stats["fbank_std"] + + params.model_args.data_statistics.mel_mean = stats["fbank_mean"] + params.model_args.data_statistics.mel_std = stats["fbank_std"] + logging.info(params) + logging.info("About to create model") model = get_model(params) load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) @@ -127,9 +143,9 @@ def main(): denoiser = Denoiser(vocoder, mode="zeros") texts = [ - "How are you doing, my friend", - # "The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.", - # "Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar.", + "How are you doing? my friend.", + "The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.", + "Today as always, men fall into two groups: slaves and free men. 
Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar.", ] # Number of ODE Solver steps @@ -174,7 +190,7 @@ def main(): rtfs_w.append(rtf_w) # Save the generated waveform - save_to_folder(i, output, folder="./my-output-1320") + save_to_folder(i, output, folder=f"./my-output-{params.epoch}") print(f"Number of ODE steps: {n_timesteps}") print(f"Mean RTF:\t\t\t\t{np.mean(rtfs):.6f} ± {np.std(rtfs):.6f}") diff --git a/egs/ljspeech/TTS/matcha/onnx_pretrained.py b/egs/ljspeech/TTS/matcha/onnx_pretrained.py new file mode 100755 index 0000000000..1a973bcffc --- /dev/null +++ b/egs/ljspeech/TTS/matcha/onnx_pretrained.py @@ -0,0 +1,84 @@ +#!/usr/bin/env python3 +import logging + +import onnxruntime as ort +import torch +from tokenizer import Tokenizer + +from inference import load_vocoder +import soundfile as sf + + +class OnnxModel: + def __init__( + self, + filename: str, + ): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + self.session_opts = session_opts + self.tokenizer = Tokenizer("./data/tokens.txt") + self.model = ort.InferenceSession( + filename, + sess_options=self.session_opts, + providers=["CPUExecutionProvider"], + ) + + for i in self.model.get_inputs(): + print(i) + + print("-----") + + for i in self.model.get_outputs(): + print(i) + + def __call__(self, x: torch.tensor): + assert x.ndim == 2, x.shape + assert x.shape[0] == 1, x.shape + + x_lengths = torch.tensor([x.shape[1]], dtype=torch.int64) + print("x_lengths", x_lengths) + print("x", x.shape) + + temperature = torch.tensor([1.0], dtype=torch.float32) + length_scale = torch.tensor([1.0], dtype=torch.float32) + + mel = self.model.run( + [self.model.get_outputs()[0].name], + { + self.model.get_inputs()[0].name: x.numpy(), + self.model.get_inputs()[1].name: x_lengths.numpy(), + self.model.get_inputs()[2].name: temperature.numpy(), + self.model.get_inputs()[3].name: length_scale.numpy(), + }, + )[0] + + return torch.from_numpy(mel) + + +@torch.inference_mode() +def main(): + model = OnnxModel("./model.onnx") + text = "hello, how are you doing?" + text = "Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar." 
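+    # The steps below tokenize the input text, run the acoustic model to
+    # obtain a mel spectrogram, and then run the HiFi-GAN vocoder to turn
+    # the mel spectrogram into a waveform.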
+ x = model.tokenizer.texts_to_token_ids([text], add_sos=True, add_eos=True) + x = torch.tensor(x, dtype=torch.int64) + mel = model(x) + print("mel", mel.shape) # (1, 80, 170) + + vocoder = load_vocoder("/star-fj/fangjun/open-source/Matcha-TTS/generator_v1") + audio = vocoder(mel).clamp(-1, 1) + print("audio", audio.shape) # (1, 1, num_samples) + audio = audio.squeeze() + + # skip denoiser + sf.write("onnx.wav", audio, 22050, "PCM_16") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/ljspeech/TTS/matcha/train.py b/egs/ljspeech/TTS/matcha/train.py index edf7e1eef7..bb9307864d 100755 --- a/egs/ljspeech/TTS/matcha/train.py +++ b/egs/ljspeech/TTS/matcha/train.py @@ -7,6 +7,7 @@ from pathlib import Path from shutil import copyfile from typing import Any, Dict, Optional, Union +import json import k2 import torch @@ -90,6 +91,13 @@ def get_parser(): help="""Path to vocabulary.""", ) + parser.add_argument( + "--cmvn", + type=str, + default="data/fbank/cmvn.json", + help="""Path to vocabulary.""", + ) + parser.add_argument( "--seed", type=int, @@ -123,11 +131,8 @@ def get_parser(): def get_data_statistics(): return AttributeDict( { - # "mel_mean": -5.517028331756592, # matcha-tts - # "mel_std": 2.0643954277038574, - # ours - "mel_mean": -1.168782114982605, - "mel_std": 1.9283572435379028, + "mel_mean": 0, + "mel_std": 1, } ) @@ -138,9 +143,9 @@ def _get_data_params() -> AttributeDict: "name": "ljspeech", "train_filelist_path": "./filelists/ljs_audio_text_train_filelist.txt", "valid_filelist_path": "./filelists/ljs_audio_text_val_filelist.txt", - "batch_size": 64, - "num_workers": 1, - "pin_memory": False, + # "batch_size": 64, + # "num_workers": 1, + # "pin_memory": False, "cleaners": ["english_cleaners2"], "add_blank": True, "n_spks": 1, @@ -312,7 +317,7 @@ def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device, param tokens = batch["tokens"] tokens = tokenizer.tokens_to_token_ids( - tokens, intersperse_blank=True, add_sos=False, add_eos=False + tokens, intersperse_blank=True, add_sos=True, add_eos=True ) tokens = k2.RaggedTensor(tokens) row_splits = tokens.shape.row_splits(1) @@ -619,10 +624,17 @@ def run(rank, world_size, args): logging.info(f"Device: {device}") tokenizer = Tokenizer(params.tokens) - params.blank_id = tokenizer.pad_id + params.pad_id = tokenizer.pad_id params.vocab_size = tokenizer.vocab_size params.model_args.n_vocab = params.vocab_size - params.model_args.n_vocab = 178 + + with open(params.cmvn) as f: + stats = json.load(f) + params.data_args.data_statistics.mel_mean = stats["fbank_mean"] + params.data_args.data_statistics.mel_std = stats["fbank_std"] + + params.model_args.data_statistics.mel_mean = stats["fbank_mean"] + params.model_args.data_statistics.mel_std = stats["fbank_std"] logging.info(params) print(params) diff --git a/egs/ljspeech/TTS/matcha/tts_datamodule.py b/egs/ljspeech/TTS/matcha/tts_datamodule.py index 0fc16366e3..0227d9fdbe 100644 --- a/egs/ljspeech/TTS/matcha/tts_datamodule.py +++ b/egs/ljspeech/TTS/matcha/tts_datamodule.py @@ -24,7 +24,8 @@ from typing import Any, Dict, Optional import torch -from lhotse import CutSet, Fbank, FbankConfig, load_manifest_lazy +from lhotse import CutSet, load_manifest_lazy +from compute_fbank_ljspeech import MyFbank, MyFbankConfig from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures CutConcatenate, CutMix, @@ -176,22 +177,19 @@ def 
train_dataloaders( if self.args.on_the_fly_feats: sampling_rate = 22050 - config = FbankConfig( + config = MyFbankConfig( + n_fft=1024, + n_mels=80, sampling_rate=sampling_rate, - frame_length=1024 / sampling_rate, # (in second), - frame_shift=256 / sampling_rate, # (in second) - use_fft_mag=True, - remove_dc_offset=False, - preemph_coeff=0, - low_freq=0, - high_freq=8000, - # should be identical to n_feats in ./train.py - num_filters=80, + hop_length=256, + win_length=1024, + f_min=0, + f_max=8000, ) train = SpeechSynthesisDataset( return_text=False, return_tokens=True, - feature_input_strategy=OnTheFlyFeatures(Fbank(config)), + feature_input_strategy=OnTheFlyFeatures(MyFbank(config)), return_cuts=self.args.return_cuts, ) @@ -229,7 +227,8 @@ def train_dataloaders( sampler=train_sampler, batch_size=None, num_workers=self.args.num_workers, - persistent_workers=False, + persistent_workers=True, + pin_memory=True, worker_init_fn=worker_init_fn, ) @@ -239,22 +238,19 @@ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: logging.info("About to create dev dataset") if self.args.on_the_fly_feats: sampling_rate = 22050 - config = FbankConfig( + config = MyFbankConfig( + n_fft=1024, + n_mels=80, sampling_rate=sampling_rate, - frame_length=1024 / sampling_rate, # (in second), - frame_shift=256 / sampling_rate, # (in second) - use_fft_mag=True, - remove_dc_offset=False, - preemph_coeff=0, - low_freq=0, - high_freq=8000, - # should be identical to n_feats in ./train.py - num_filters=80, + hop_length=256, + win_length=1024, + f_min=0, + f_max=8000, ) validate = SpeechSynthesisDataset( return_text=False, return_tokens=True, - feature_input_strategy=OnTheFlyFeatures(Fbank(config)), + feature_input_strategy=OnTheFlyFeatures(MyFbank(config)), return_cuts=self.args.return_cuts, ) else: @@ -276,7 +272,8 @@ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: sampler=valid_sampler, batch_size=None, num_workers=2, - persistent_workers=False, + persistent_workers=True, + pin_memory=True, ) return valid_dl @@ -285,22 +282,19 @@ def test_dataloaders(self, cuts: CutSet) -> DataLoader: logging.info("About to create test dataset") if self.args.on_the_fly_feats: sampling_rate = 22050 - config = FbankConfig( + config = MyFbankConfig( + n_fft=1024, + n_mels=80, sampling_rate=sampling_rate, - frame_length=1024 / sampling_rate, # (in second), - frame_shift=256 / sampling_rate, # (in second) - use_fft_mag=True, - remove_dc_offset=False, - preemph_coeff=0, - low_freq=0, - high_freq=8000, - # should be identical to n_feats in ./train.py - num_filters=80, + hop_length=256, + win_length=1024, + f_min=0, + f_max=8000, ) test = SpeechSynthesisDataset( return_text=False, return_tokens=True, - feature_input_strategy=OnTheFlyFeatures(Fbank(config)), + feature_input_strategy=OnTheFlyFeatures(MyFbank(config)), return_cuts=self.args.return_cuts, ) else: From a67d4b9a8084364013a8e048fe2d40a184cebab1 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 28 Oct 2024 17:51:45 +0800 Subject: [PATCH 10/27] support all hifigan versions --- egs/ljspeech/TTS/matcha/export_onnx.py | 125 ++++++++++++------ .../TTS/matcha/export_onnx_hifigan.py | 106 +++++++++++++++ egs/ljspeech/TTS/matcha/hifigan/config.py | 74 ++++++++++- egs/ljspeech/TTS/matcha/inference.py | 17 ++- egs/ljspeech/TTS/matcha/onnx_pretrained.py | 95 +++++++++++-- egs/ljspeech/TTS/matcha/train.py | 5 +- 6 files changed, 360 insertions(+), 62 deletions(-) create mode 100755 egs/ljspeech/TTS/matcha/export_onnx_hifigan.py diff --git 
a/egs/ljspeech/TTS/matcha/export_onnx.py b/egs/ljspeech/TTS/matcha/export_onnx.py index c56e2da894..cf5069b113 100755 --- a/egs/ljspeech/TTS/matcha/export_onnx.py +++ b/egs/ljspeech/TTS/matcha/export_onnx.py @@ -1,20 +1,51 @@ #!/usr/bin/env python3 +""" +This script exports a Matcha-TTS model to ONNX. +Note that the model outputs fbank. You need to use a vocoder to convert +it to audio. See also ./export_onnx_hifigan.py +""" + import json import logging +from typing import Any, Dict +import onnx import torch from inference import get_parser from tokenizer import Tokenizer from train import get_model, get_params + from icefall.checkpoint import load_checkpoint -from onnxruntime.quantization import QuantType, quantize_dynamic + + +def add_meta_data(filename: str, meta_data: Dict[str, Any]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + + while len(model.metadata_props): + model.metadata_props.pop() + + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = str(value) + + onnx.save(model, filename) class ModelWrapper(torch.nn.Module): - def __init__(self, model): + def __init__(self, model, num_steps: int = 5): super().__init__() self.model = model + self.num_steps = num_steps def forward( self, @@ -30,23 +61,24 @@ def forward( temperature: (1,), torch.float32 length_scale (1,), torch.float32 Returns: - mel: (batch_size, feat_dim, num_frames) + audio: (batch_size, num_samples) """ mel = self.model.synthesise( x=x, x_lengths=x_lengths, - n_timesteps=3, + n_timesteps=self.num_steps, temperature=temperature, length_scale=length_scale, )["mel"] - # mel: (batch_size, feat_dim, num_frames) + # audio = self.vocoder(mel).clamp(-1, 1).squeeze(1) + return mel -@torch.inference_mode +@torch.inference_mode() def main(): parser = get_parser() args = parser.parse_args() @@ -72,44 +104,49 @@ def main(): model = get_model(params) load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - wrapper = ModelWrapper(model) - wrapper.eval() - - # Use a large value so the the rotary position embedding in the text - # encoder has a large initial length - x = torch.ones(1, 2000, dtype=torch.int64) - x_lengths = torch.tensor([x.shape[1]], dtype=torch.int64) - temperature = torch.tensor([1.0]) - length_scale = torch.tensor([1.0]) - mel = wrapper(x, x_lengths, temperature, length_scale) - print("mel", mel.shape) - - opset_version = 14 - filename = "model.onnx" - torch.onnx.export( - wrapper, - (x, x_lengths, temperature, length_scale), - filename, - opset_version=opset_version, - input_names=["x", "x_length", "temperature", "length_scale"], - output_names=["mel"], - dynamic_axes={ - "x": {0: "N", 1: "L"}, - "x_length": {0: "N"}, - "mel": {0: "N", 2: "L"}, - }, - ) - - print("Generate int8 quantization models") - - filename_int8 = "model.int8.onnx" - quantize_dynamic( - model_input=filename, - model_output=filename_int8, - weight_type=QuantType.QInt8, - ) - - print(f"Saved to {filename} and {filename_int8}") + for num_steps in [2, 3, 4, 5, 6]: + logging.info(f"num_steps: {num_steps}") + wrapper = ModelWrapper(model, num_steps=num_steps) + wrapper.eval() + + # Use a large value so the rotary position embedding in the text + # encoder has a large initial length + x = torch.ones(1, 1000, dtype=torch.int64) + x_lengths = torch.tensor([x.shape[1]], dtype=torch.int64) + temperature = torch.tensor([1.0]) + length_scale = 
torch.tensor([1.0]) + + opset_version = 14 + filename = f"model-steps-{num_steps}.onnx" + torch.onnx.export( + wrapper, + (x, x_lengths, temperature, length_scale), + filename, + opset_version=opset_version, + input_names=["x", "x_length", "temperature", "length_scale"], + output_names=["mel"], + dynamic_axes={ + "x": {0: "N", 1: "L"}, + "x_length": {0: "N"}, + "mel": {0: "N", 2: "L"}, + }, + ) + + meta_data = { + "model_type": "matcha-tts", + "language": "English", + "voice": "en-us", + "has_espeak": 1, + "n_speakers": 1, + "sample_rate": 22050, + "version": 1, + "model_author": "icefall", + "maintainer": "k2-fsa", + "dataset": "LJ Speech", + "num_ode_steps": num_steps, + } + add_meta_data(filename=filename, meta_data=meta_data) + print(meta_data) if __name__ == "__main__": diff --git a/egs/ljspeech/TTS/matcha/export_onnx_hifigan.py b/egs/ljspeech/TTS/matcha/export_onnx_hifigan.py new file mode 100755 index 0000000000..3b2ebf5025 --- /dev/null +++ b/egs/ljspeech/TTS/matcha/export_onnx_hifigan.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python3 + +import logging +from typing import Any, Dict + +import onnx +import torch + +from inference import load_vocoder + + +def add_meta_data(filename: str, meta_data: Dict[str, Any]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + + while len(model.metadata_props): + model.metadata_props.pop() + + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = str(value) + + onnx.save(model, filename) + + +class ModelWrapper(torch.nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + + def forward( + self, + mel: torch.Tensor, + ) -> torch.Tensor: + """ + Args: : + mel: (batch_size, feat_dim, num_frames), torch.float32 + Returns: + audio: (batch_size, num_samples), torch.float32 + """ + audio = self.model(mel).clamp(-1, 1).squeeze(1) + return audio + + +@torch.inference_mode() +def main(): + # Please go to + # https://github.com/csukuangfj/models/tree/master/hifigan + # to download the following files + model_filenames = ["./generator_v1", "./generator_v2", "./generator_v3"] + + for f in model_filenames: + logging.info(f) + model = load_vocoder(f) + wrapper = ModelWrapper(model) + wrapper.eval() + num_param = sum([p.numel() for p in wrapper.parameters()]) + logging.info(f"{f}: Number of parameters: {num_param}") + + # Use a large value so the rotary position embedding in the text + # encoder has a large initial length + x = torch.ones(1, 80, 100000, dtype=torch.float32) + opset_version = 14 + suffix = f.split("_")[-1] + filename = f"hifigan_{suffix}.onnx" + torch.onnx.export( + wrapper, + x, + filename, + opset_version=opset_version, + input_names=["mel"], + output_names=["audio"], + dynamic_axes={ + "mel": {0: "N", 2: "L"}, + "audio": {0: "N", 1: "L"}, + }, + ) + + meta_data = { + "model_type": "hifigan", + "model_filename": f.split("/")[-1], + "sample_rate": 22050, + "version": 1, + "model_author": "jik876", + "maintainer": "k2-fsa", + "dataset": "LJ Speech", + "url1": "https://github.com/jik876/hifi-gan", + "url2": "https://github.com/csukuangfj/models/tree/master/hifigan", + } + add_meta_data(filename=filename, meta_data=meta_data) + print(meta_data) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff 
--git a/egs/ljspeech/TTS/matcha/hifigan/config.py b/egs/ljspeech/TTS/matcha/hifigan/config.py index b3abea9e15..ecba62fd42 100644 --- a/egs/ljspeech/TTS/matcha/hifigan/config.py +++ b/egs/ljspeech/TTS/matcha/hifigan/config.py @@ -24,5 +24,77 @@ "fmax": 8000, "fmax_loss": None, "num_workers": 4, - "dist_config": {"dist_backend": "nccl", "dist_url": "tcp://localhost:54321", "world_size": 1}, + "dist_config": { + "dist_backend": "nccl", + "dist_url": "tcp://localhost:54321", + "world_size": 1, + }, +} + +# See https://drive.google.com/drive/folders/1bB1tnGIxRN-edlf6k2Rmi1gNCK9Cpcvf +v2 = { + "resblock": "1", + "num_gpus": 0, + "batch_size": 16, + "learning_rate": 0.0002, + "adam_b1": 0.8, + "adam_b2": 0.99, + "lr_decay": 0.999, + "seed": 1234, + "upsample_rates": [8, 8, 2, 2], + "upsample_kernel_sizes": [16, 16, 4, 4], + "upsample_initial_channel": 128, + "resblock_kernel_sizes": [3, 7, 11], + "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + "resblock_initial_channel": 64, + "segment_size": 8192, + "num_mels": 80, + "num_freq": 1025, + "n_fft": 1024, + "hop_size": 256, + "win_size": 1024, + "sampling_rate": 22050, + "fmin": 0, + "fmax": 8000, + "fmax_loss": None, + "num_workers": 4, + "dist_config": { + "dist_backend": "nccl", + "dist_url": "tcp://localhost:54321", + "world_size": 1, + }, +} + +# See https://drive.google.com/drive/folders/1KKvuJTLp_gZXC8lug7H_lSXct38_3kx1 +v3 = { + "resblock": "2", + "num_gpus": 0, + "batch_size": 16, + "learning_rate": 0.0002, + "adam_b1": 0.8, + "adam_b2": 0.99, + "lr_decay": 0.999, + "seed": 1234, + "upsample_rates": [8, 8, 4], + "upsample_kernel_sizes": [16, 16, 8], + "upsample_initial_channel": 256, + "resblock_kernel_sizes": [3, 5, 7], + "resblock_dilation_sizes": [[1, 2], [2, 6], [3, 12]], + "resblock_initial_channel": 128, + "segment_size": 8192, + "num_mels": 80, + "num_freq": 1025, + "n_fft": 1024, + "hop_size": 256, + "win_size": 1024, + "sampling_rate": 22050, + "fmin": 0, + "fmax": 8000, + "fmax_loss": None, + "num_workers": 4, + "dist_config": { + "dist_backend": "nccl", + "dist_url": "tcp://localhost:54321", + "world_size": 1, + }, } diff --git a/egs/ljspeech/TTS/matcha/inference.py b/egs/ljspeech/TTS/matcha/inference.py index 49c9c708aa..250c38f200 100755 --- a/egs/ljspeech/TTS/matcha/inference.py +++ b/egs/ljspeech/TTS/matcha/inference.py @@ -9,7 +9,7 @@ import numpy as np import soundfile as sf import torch -from matcha.hifigan.config import v1 +from matcha.hifigan.config import v1, v2, v3 from matcha.hifigan.denoiser import Denoiser from tokenizer import Tokenizer from matcha.hifigan.models import Generator as HiFiGAN @@ -63,7 +63,15 @@ def get_parser(): def load_vocoder(checkpoint_path): - h = AttributeDict(v1) + if checkpoint_path.endswith("v1"): + h = AttributeDict(v1) + elif checkpoint_path.endswith("v2"): + h = AttributeDict(v2) + elif checkpoint_path.endswith("v3"): + h = AttributeDict(v3) + else: + raise ValueError(f"supports only v1, v2, and v3, given {checkpoint_path}") + hifigan = HiFiGAN(h).to("cpu") hifigan.load_state_dict( torch.load(checkpoint_path, map_location="cpu")["generator"] @@ -143,13 +151,12 @@ def main(): denoiser = Denoiser(vocoder, mode="zeros") texts = [ - "How are you doing? my friend.", "The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.", "Today as always, men fall into two groups: slaves and free men. 
Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar.", ] # Number of ODE Solver steps - n_timesteps = 3 + n_timesteps = 2 # Changes to the speaking rate length_scale = 1.0 @@ -203,4 +210,6 @@ def main(): formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) + torch.set_num_threads(1) + torch.set_num_interop_threads(1) main() diff --git a/egs/ljspeech/TTS/matcha/onnx_pretrained.py b/egs/ljspeech/TTS/matcha/onnx_pretrained.py index 1a973bcffc..24955e881c 100755 --- a/egs/ljspeech/TTS/matcha/onnx_pretrained.py +++ b/egs/ljspeech/TTS/matcha/onnx_pretrained.py @@ -4,12 +4,13 @@ import onnxruntime as ort import torch from tokenizer import Tokenizer +import datetime as dt -from inference import load_vocoder import soundfile as sf +from inference import load_vocoder -class OnnxModel: +class OnnxHifiGANModel: def __init__( self, filename: str, @@ -18,6 +19,44 @@ def __init__( session_opts.inter_op_num_threads = 1 session_opts.intra_op_num_threads = 1 + self.session_opts = session_opts + self.model = ort.InferenceSession( + filename, + sess_options=self.session_opts, + providers=["CPUExecutionProvider"], + ) + + for i in self.model.get_inputs(): + print(i) + + print("-----") + + for i in self.model.get_outputs(): + print(i) + + def __call__(self, x: torch.tensor): + assert x.ndim == 3, x.shape + assert x.shape[0] == 1, x.shape + + audio = self.model.run( + [self.model.get_outputs()[0].name], + { + self.model.get_inputs()[0].name: x.numpy(), + }, + )[0] + + return torch.from_numpy(audio) + + +class OnnxModel: + def __init__( + self, + filename: str, + ): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 2 + self.session_opts = session_opts self.tokenizer = Tokenizer("./data/tokens.txt") self.model = ort.InferenceSession( @@ -58,27 +97,63 @@ def __call__(self, x: torch.tensor): return torch.from_numpy(mel) -@torch.inference_mode() +@torch.no_grad() def main(): - model = OnnxModel("./model.onnx") - text = "hello, how are you doing?" - text = "Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar." + model = OnnxModel("./model-steps-6.onnx") + vocoder = OnnxHifiGANModel("./hifigan_v1.onnx") + text = "Today as always, men fall into two groups: slaves and free men." + text += "hello, how are you doing?" 
x = model.tokenizer.texts_to_token_ids([text], add_sos=True, add_eos=True) x = torch.tensor(x, dtype=torch.int64) + + start_t = dt.datetime.now() mel = model(x) - print("mel", mel.shape) # (1, 80, 170) + end_t = dt.datetime.now() + + for i in range(3): + audio = vocoder(mel) + + start_t2 = dt.datetime.now() + audio = vocoder(mel) + end_t2 = dt.datetime.now() - vocoder = load_vocoder("/star-fj/fangjun/open-source/Matcha-TTS/generator_v1") - audio = vocoder(mel).clamp(-1, 1) print("audio", audio.shape) # (1, 1, num_samples) audio = audio.squeeze() + t = (end_t - start_t).total_seconds() + t2 = (end_t2 - start_t2).total_seconds() + rtf = t * 22050 / audio.shape[-1] + rtf2 = t2 * 22050 / audio.shape[-1] + print("RTF", rtf) + print("RTF", rtf2) + # skip denoiser - sf.write("onnx.wav", audio, 22050, "PCM_16") + sf.write("onnx2.wav", audio, 22050, "PCM_16") if __name__ == "__main__": + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() + +""" + +|HifiGAN |RTF |#Parameters (M)| +|----------|-----|---------------| +|v1 |0.818| 13.926 | +|v2 |0.101| 0.925 | +|v3 |0.118| 1.462 | + +|Num steps|Acoustic Model RTF| +|---------|------------------| +| 2 | 0.039 | +| 3 | 0.047 | +| 4 | 0.071 | +| 5 | 0.076 | +| 6 | 0.103 | + +""" diff --git a/egs/ljspeech/TTS/matcha/train.py b/egs/ljspeech/TTS/matcha/train.py index bb9307864d..747292197a 100755 --- a/egs/ljspeech/TTS/matcha/train.py +++ b/egs/ljspeech/TTS/matcha/train.py @@ -741,8 +741,7 @@ def main(): run(rank=0, world_size=1, args=args) -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - if __name__ == "__main__": + torch.set_num_threads(1) + torch.set_num_interop_threads(1) main() From 7994684bf4d35054d00463403759ab5b08e3f624 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 28 Oct 2024 19:06:44 +0800 Subject: [PATCH 11/27] Reformat code --- .../TTS/matcha/models/baselightningmodule.py | 23 +++++-- .../TTS/matcha/models/components/decoder.py | 29 ++++++-- .../matcha/models/components/flow_matching.py | 25 +++++-- .../matcha/models/components/text_encoder.py | 69 +++++++++++++++---- .../matcha/models/components/transformer.py | 63 +++++++++++++---- egs/ljspeech/TTS/matcha/text/__init__.py | 10 ++- egs/ljspeech/TTS/matcha/text/cleaners.py | 9 +-- egs/ljspeech/TTS/matcha/text/numbers.py | 4 +- egs/ljspeech/TTS/matcha/text/symbols.py | 4 +- egs/ljspeech/TTS/matcha/utils/audio.py | 16 +++-- .../matcha/utils/generate_data_statistics.py | 20 ++++-- .../utils/get_durations_from_trained_model.py | 34 +++++++-- .../TTS/matcha/utils/instantiators.py | 8 ++- .../TTS/matcha/utils/logging_utils.py | 8 ++- egs/ljspeech/TTS/matcha/utils/model.py | 7 +- egs/ljspeech/TTS/matcha/utils/pylogger.py | 10 ++- egs/ljspeech/TTS/matcha/utils/rich_utils.py | 4 +- egs/ljspeech/TTS/matcha/utils/utils.py | 4 +- 18 files changed, 268 insertions(+), 79 deletions(-) diff --git a/egs/ljspeech/TTS/matcha/models/baselightningmodule.py b/egs/ljspeech/TTS/matcha/models/baselightningmodule.py index f8abe7b44f..e80d2a5c97 100644 --- a/egs/ljspeech/TTS/matcha/models/baselightningmodule.py +++ b/egs/ljspeech/TTS/matcha/models/baselightningmodule.py @@ -32,7 +32,10 @@ def configure_optimizers(self) -> Any: if self.hparams.scheduler not in (None, {}): scheduler_args = {} # Manage last epoch for exponential schedulers - if "last_epoch" in inspect.signature(self.hparams.scheduler.scheduler).parameters: + if ( + "last_epoch" + in 
inspect.signature(self.hparams.scheduler.scheduler).parameters + ): if hasattr(self, "ckpt_loaded_epoch"): current_epoch = self.ckpt_loaded_epoch - 1 else: @@ -74,7 +77,9 @@ def get_losses(self, batch): } def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - self.ckpt_loaded_epoch = checkpoint["epoch"] # pylint: disable=attribute-defined-outside-init + self.ckpt_loaded_epoch = checkpoint[ + "epoch" + ] # pylint: disable=attribute-defined-outside-init def training_step(self, batch: Any, batch_idx: int): loss_dict = self.get_losses(batch) @@ -183,8 +188,14 @@ def on_validation_end(self) -> None: for i in range(2): x = one_batch["x"][i].unsqueeze(0).to(self.device) x_lengths = one_batch["x_lengths"][i].unsqueeze(0).to(self.device) - spks = one_batch["spks"][i].unsqueeze(0).to(self.device) if one_batch["spks"] is not None else None - output = self.synthesise(x[:, :x_lengths], x_lengths, n_timesteps=10, spks=spks) + spks = ( + one_batch["spks"][i].unsqueeze(0).to(self.device) + if one_batch["spks"] is not None + else None + ) + output = self.synthesise( + x[:, :x_lengths], x_lengths, n_timesteps=10, spks=spks + ) y_enc, y_dec = output["encoder_outputs"], output["decoder_outputs"] attn = output["attn"] self.logger.experiment.add_image( @@ -207,4 +218,6 @@ def on_validation_end(self) -> None: ) def on_before_optimizer_step(self, optimizer): - self.log_dict({f"grad_norm/{k}": v for k, v in grad_norm(self, norm_type=2).items()}) + self.log_dict( + {f"grad_norm/{k}": v for k, v in grad_norm(self, norm_type=2).items()} + ) diff --git a/egs/ljspeech/TTS/matcha/models/components/decoder.py b/egs/ljspeech/TTS/matcha/models/components/decoder.py index 1137cd7008..5850f2639b 100644 --- a/egs/ljspeech/TTS/matcha/models/components/decoder.py +++ b/egs/ljspeech/TTS/matcha/models/components/decoder.py @@ -46,7 +46,9 @@ def forward(self, x, mask): class ResnetBlock1D(torch.nn.Module): def __init__(self, dim, dim_out, time_emb_dim, groups=8): super().__init__() - self.mlp = torch.nn.Sequential(nn.Mish(), torch.nn.Linear(time_emb_dim, dim_out)) + self.mlp = torch.nn.Sequential( + nn.Mish(), torch.nn.Linear(time_emb_dim, dim_out) + ) self.block1 = Block1D(dim, dim_out, groups=groups) self.block2 = Block1D(dim_out, dim_out, groups=groups) @@ -131,7 +133,14 @@ class Upsample1D(nn.Module): number of output channels. Defaults to `channels`. 
""" - def __init__(self, channels, use_conv=False, use_conv_transpose=True, out_channels=None, name="conv"): + def __init__( + self, + channels, + use_conv=False, + use_conv_transpose=True, + out_channels=None, + name="conv", + ): super().__init__() self.channels = channels self.out_channels = out_channels or channels @@ -235,7 +244,9 @@ def __init__( input_channel = output_channel output_channel = channels[i] is_last = i == len(channels) - 1 - resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) + resnet = ResnetBlock1D( + dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim + ) transformer_blocks = nn.ModuleList( [ self.get_block( @@ -250,16 +261,22 @@ def __init__( ] ) downsample = ( - Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1) + Downsample1D(output_channel) + if not is_last + else nn.Conv1d(output_channel, output_channel, 3, padding=1) ) - self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample])) + self.down_blocks.append( + nn.ModuleList([resnet, transformer_blocks, downsample]) + ) for i in range(num_mid_blocks): input_channel = channels[-1] out_channels = channels[-1] - resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) + resnet = ResnetBlock1D( + dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim + ) transformer_blocks = nn.ModuleList( [ diff --git a/egs/ljspeech/TTS/matcha/models/components/flow_matching.py b/egs/ljspeech/TTS/matcha/models/components/flow_matching.py index 552c4b3834..5a7226b4f7 100644 --- a/egs/ljspeech/TTS/matcha/models/components/flow_matching.py +++ b/egs/ljspeech/TTS/matcha/models/components/flow_matching.py @@ -4,6 +4,7 @@ import torch.nn.functional as F from matcha.models.components.decoder import Decoder + # from matcha.utils.pylogger import get_pylogger # log = get_pylogger(__name__) @@ -50,7 +51,9 @@ def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None): """ z = torch.randn_like(mu) * temperature t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device) - return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond) + return self.solve_euler( + z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond + ) def solve_euler(self, x, t_span, mu, mask, spks, cond): """ @@ -112,14 +115,22 @@ def compute_loss(self, x1, mask, mu, spks=None, cond=None): y = (1 - (1 - self.sigma_min) * t) * z + t * x1 u = x1 - (1 - self.sigma_min) * z - loss = F.mse_loss(self.estimator(y, mask, mu, t.squeeze(), spks), u, reduction="sum") / ( - torch.sum(mask) * u.shape[1] - ) + loss = F.mse_loss( + self.estimator(y, mask, mu, t.squeeze(), spks), u, reduction="sum" + ) / (torch.sum(mask) * u.shape[1]) return loss, y class CFM(BASECFM): - def __init__(self, in_channels, out_channel, cfm_params, decoder_params, n_spks=1, spk_emb_dim=64): + def __init__( + self, + in_channels, + out_channel, + cfm_params, + decoder_params, + n_spks=1, + spk_emb_dim=64, + ): super().__init__( n_feats=in_channels, cfm_params=cfm_params, @@ -129,4 +140,6 @@ def __init__(self, in_channels, out_channel, cfm_params, decoder_params, n_spks= in_channels = in_channels + (spk_emb_dim if n_spks > 1 else 0) # Just change the architecture of the estimator here - self.estimator = Decoder(in_channels=in_channels, out_channels=out_channel, **decoder_params) + self.estimator = Decoder( + in_channels=in_channels, out_channels=out_channel, **decoder_params + ) 
diff --git a/egs/ljspeech/TTS/matcha/models/components/text_encoder.py b/egs/ljspeech/TTS/matcha/models/components/text_encoder.py index efd2253562..68f8ad864e 100644 --- a/egs/ljspeech/TTS/matcha/models/components/text_encoder.py +++ b/egs/ljspeech/TTS/matcha/models/components/text_encoder.py @@ -34,7 +34,15 @@ def forward(self, x): class ConvReluNorm(nn.Module): - def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout): + def __init__( + self, + in_channels, + hidden_channels, + out_channels, + kernel_size, + n_layers, + p_dropout, + ): super().__init__() self.in_channels = in_channels self.hidden_channels = hidden_channels @@ -45,12 +53,23 @@ def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_la self.conv_layers = torch.nn.ModuleList() self.norm_layers = torch.nn.ModuleList() - self.conv_layers.append(torch.nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2)) + self.conv_layers.append( + torch.nn.Conv1d( + in_channels, hidden_channels, kernel_size, padding=kernel_size // 2 + ) + ) self.norm_layers.append(LayerNorm(hidden_channels)) - self.relu_drop = torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Dropout(p_dropout)) + self.relu_drop = torch.nn.Sequential( + torch.nn.ReLU(), torch.nn.Dropout(p_dropout) + ) for _ in range(n_layers - 1): self.conv_layers.append( - torch.nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2) + torch.nn.Conv1d( + hidden_channels, + hidden_channels, + kernel_size, + padding=kernel_size // 2, + ) ) self.norm_layers.append(LayerNorm(hidden_channels)) self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1) @@ -75,9 +94,13 @@ def __init__(self, in_channels, filter_channels, kernel_size, p_dropout): self.p_dropout = p_dropout self.drop = torch.nn.Dropout(p_dropout) - self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2) + self.conv_1 = torch.nn.Conv1d( + in_channels, filter_channels, kernel_size, padding=kernel_size // 2 + ) self.norm_1 = LayerNorm(filter_channels) - self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2) + self.conv_2 = torch.nn.Conv1d( + filter_channels, filter_channels, kernel_size, padding=kernel_size // 2 + ) self.norm_2 = LayerNorm(filter_channels) self.proj = torch.nn.Conv1d(filter_channels, 1, 1) @@ -128,7 +151,9 @@ def _build_cache(self, x: torch.Tensor): seq_len = x.shape[0] # $\Theta = {\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ - theta = 1.0 / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device) + theta = 1.0 / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to( + x.device + ) # Create position indexes `[0, 1, ..., seq_len - 1]` seq_idx = torch.arange(seq_len, device=x.device).float().to(x.device) @@ -167,7 +192,9 @@ def forward(self, x: torch.Tensor): # $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$ neg_half_x = self._neg_half(x_rope) - x_rope = (x_rope * self.cos_cached[: x.shape[0]]) + (neg_half_x * self.sin_cached[: x.shape[0]]) + x_rope = (x_rope * self.cos_cached[: x.shape[0]]) + ( + neg_half_x * self.sin_cached[: x.shape[0]] + ) return rearrange(torch.cat((x_rope, x_pass), dim=-1), "t b h d -> b h t d") @@ -236,7 +263,9 @@ def attention(self, query, key, value, mask=None): if self.proximal_bias: assert t_s == t_t, "Proximal bias is only available for self-attention." 
- scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype) + scores = scores + self._attention_bias_proximal(t_s).to( + device=scores.device, dtype=scores.dtype + ) if mask is not None: scores = scores.masked_fill(mask == 0, -1e4) p_attn = torch.nn.functional.softmax(scores, dim=-1) @@ -253,7 +282,9 @@ def _attention_bias_proximal(length): class FFN(nn.Module): - def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0): + def __init__( + self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0 + ): super().__init__() self.in_channels = in_channels self.out_channels = out_channels @@ -261,8 +292,12 @@ def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dr self.kernel_size = kernel_size self.p_dropout = p_dropout - self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2) - self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size, padding=kernel_size // 2) + self.conv_1 = torch.nn.Conv1d( + in_channels, filter_channels, kernel_size, padding=kernel_size // 2 + ) + self.conv_2 = torch.nn.Conv1d( + filter_channels, out_channels, kernel_size, padding=kernel_size // 2 + ) self.drop = torch.nn.Dropout(p_dropout) def forward(self, x, x_mask): @@ -298,7 +333,11 @@ def __init__( self.ffn_layers = torch.nn.ModuleList() self.norm_layers_2 = torch.nn.ModuleList() for _ in range(self.n_layers): - self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout)) + self.attn_layers.append( + MultiHeadAttention( + hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout + ) + ) self.norm_layers_1.append(LayerNorm(hidden_channels)) self.ffn_layers.append( FFN( @@ -367,7 +406,9 @@ def __init__( encoder_params.p_dropout, ) - self.proj_m = torch.nn.Conv1d(self.n_channels + (spk_emb_dim if n_spks > 1 else 0), self.n_feats, 1) + self.proj_m = torch.nn.Conv1d( + self.n_channels + (spk_emb_dim if n_spks > 1 else 0), self.n_feats, 1 + ) self.proj_w = DurationPredictor( self.n_channels + (spk_emb_dim if n_spks > 1 else 0), duration_predictor_params.filter_channels_dp, diff --git a/egs/ljspeech/TTS/matcha/models/components/transformer.py b/egs/ljspeech/TTS/matcha/models/components/transformer.py index dd1afa3aff..a82e560bca 100644 --- a/egs/ljspeech/TTS/matcha/models/components/transformer.py +++ b/egs/ljspeech/TTS/matcha/models/components/transformer.py @@ -32,7 +32,14 @@ class SnakeBeta(nn.Module): >>> x = a1(x) """ - def __init__(self, in_features, out_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True): + def __init__( + self, + in_features, + out_features, + alpha=1.0, + alpha_trainable=True, + alpha_logscale=True, + ): """ Initialization. INPUT: @@ -44,7 +51,9 @@ def __init__(self, in_features, out_features, alpha=1.0, alpha_trainable=True, a alpha will be trained along with the rest of your model. 
""" super().__init__() - self.in_features = out_features if isinstance(out_features, list) else [out_features] + self.in_features = ( + out_features if isinstance(out_features, list) else [out_features] + ) self.proj = LoRACompatibleLinear(in_features, out_features) # initialize alpha @@ -75,7 +84,9 @@ def forward(self, x): alpha = self.alpha beta = self.beta - x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(torch.sin(x * alpha), 2) + x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow( + torch.sin(x * alpha), 2 + ) return x @@ -176,8 +187,12 @@ def __init__( super().__init__() self.only_cross_attention = only_cross_attention - self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" - self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + self.use_ada_layer_norm_zero = ( + num_embeds_ada_norm is not None + ) and norm_type == "ada_norm_zero" + self.use_ada_layer_norm = ( + num_embeds_ada_norm is not None + ) and norm_type == "ada_norm" if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: raise ValueError( @@ -215,7 +230,9 @@ def __init__( ) self.attn2 = Attention( query_dim=dim, - cross_attention_dim=cross_attention_dim if not double_self_attention else None, + cross_attention_dim=cross_attention_dim + if not double_self_attention + else None, heads=num_attention_heads, dim_head=attention_head_dim, dropout=dropout, @@ -229,7 +246,12 @@ def __init__( # 3. Feed-forward self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) - self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + ) # let chunk size default to None self._chunk_size = None @@ -261,12 +283,18 @@ def forward( else: norm_hidden_states = self.norm1(hidden_states) - cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + cross_attention_kwargs = ( + cross_attention_kwargs if cross_attention_kwargs is not None else {} + ) attn_output = self.attn1( norm_hidden_states, - encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, - attention_mask=encoder_attention_mask if self.only_cross_attention else attention_mask, + encoder_hidden_states=encoder_hidden_states + if self.only_cross_attention + else None, + attention_mask=encoder_attention_mask + if self.only_cross_attention + else attention_mask, **cross_attention_kwargs, ) if self.use_ada_layer_norm_zero: @@ -276,7 +304,9 @@ def forward( # 2. 
Cross-Attention if self.attn2 is not None: norm_hidden_states = ( - self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + self.norm2(hidden_states, timestep) + if self.use_ada_layer_norm + else self.norm2(hidden_states) ) attn_output = self.attn2( @@ -291,7 +321,9 @@ def forward( norm_hidden_states = self.norm3(hidden_states) if self.use_ada_layer_norm_zero: - norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + norm_hidden_states = ( + norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + ) if self._chunk_size is not None: # "feed_forward_chunk_size" can be used to save memory @@ -302,7 +334,12 @@ def forward( num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size ff_output = torch.cat( - [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)], + [ + self.ff(hid_slice) + for hid_slice in norm_hidden_states.chunk( + num_chunks, dim=self._chunk_dim + ) + ], dim=self._chunk_dim, ) else: diff --git a/egs/ljspeech/TTS/matcha/text/__init__.py b/egs/ljspeech/TTS/matcha/text/__init__.py index dc3427f0b4..78c8b1f18e 100644 --- a/egs/ljspeech/TTS/matcha/text/__init__.py +++ b/egs/ljspeech/TTS/matcha/text/__init__.py @@ -4,7 +4,9 @@ # Mappings from symbol to numeric ID and vice versa: _symbol_to_id = {s: i for i, s in enumerate(symbols)} -_id_to_symbol = {i: s for i, s in enumerate(symbols)} # pylint: disable=unnecessary-comprehension +_id_to_symbol = { + i: s for i, s in enumerate(symbols) +} # pylint: disable=unnecessary-comprehension def text_to_sequence(text, cleaner_names): @@ -20,13 +22,15 @@ def text_to_sequence(text, cleaner_names): clean_text = _clean_text(text, cleaner_names) for symbol in clean_text: try: - if symbol in '_()[]# ̃': + if symbol in "_()[]# ̃": continue symbol_id = _symbol_to_id[symbol] except Exception as ex: print(text) print(clean_text) - raise RuntimeError(f'text: {text}, clean_text: {clean_text}, ex: {ex}, symbol: {symbol}') + raise RuntimeError( + f"text: {text}, clean_text: {clean_text}, ex: {ex}, symbol: {symbol}" + ) sequence += [symbol_id] return sequence, clean_text diff --git a/egs/ljspeech/TTS/matcha/text/cleaners.py b/egs/ljspeech/TTS/matcha/text/cleaners.py index 33cdc9fc61..0a1979afe1 100644 --- a/egs/ljspeech/TTS/matcha/text/cleaners.py +++ b/egs/ljspeech/TTS/matcha/text/cleaners.py @@ -75,11 +75,12 @@ def lowercase(text): def collapse_whitespace(text): return re.sub(_whitespace_re, " ", text) + def remove_parentheses(text): - text = text.replace("(", "") - text = text.replace(")", "") - text = text.replace("[", "") - text = text.replace("]", "") + text = text.replace("(", "") + text = text.replace(")", "") + text = text.replace("[", "") + text = text.replace("]", "") return text diff --git a/egs/ljspeech/TTS/matcha/text/numbers.py b/egs/ljspeech/TTS/matcha/text/numbers.py index f99a8686dc..49c21d4e99 100644 --- a/egs/ljspeech/TTS/matcha/text/numbers.py +++ b/egs/ljspeech/TTS/matcha/text/numbers.py @@ -56,7 +56,9 @@ def _expand_number(m): elif num % 100 == 0: return _inflect.number_to_words(num // 100) + " hundred" else: - return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ") + return _inflect.number_to_words( + num, andword="", zero="oh", group=2 + ).replace(", ", " ") else: return _inflect.number_to_words(num, andword="") diff --git a/egs/ljspeech/TTS/matcha/text/symbols.py b/egs/ljspeech/TTS/matcha/text/symbols.py index 7018df549a..b32c124302 100644 --- 
a/egs/ljspeech/TTS/matcha/text/symbols.py +++ b/egs/ljspeech/TTS/matcha/text/symbols.py @@ -5,9 +5,7 @@ _pad = "_" _punctuation = ';:,.!?¡¿—…"«»“” ' _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" -_letters_ipa = ( - "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ" -) +_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ" # Export all symbols: diff --git a/egs/ljspeech/TTS/matcha/utils/audio.py b/egs/ljspeech/TTS/matcha/utils/audio.py index 0bcd74df47..0a9b8db2a9 100644 --- a/egs/ljspeech/TTS/matcha/utils/audio.py +++ b/egs/ljspeech/TTS/matcha/utils/audio.py @@ -42,7 +42,9 @@ def spectral_de_normalize_torch(magnitudes): hann_window = {} -def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): +def mel_spectrogram( + y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False +): if torch.min(y) < -1.0: print("min value is ", torch.min(y)) if torch.max(y) > 1.0: @@ -50,12 +52,18 @@ def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, global mel_basis, hann_window # pylint: disable=global-statement if f"{str(fmax)}_{str(y.device)}" not in mel_basis: - mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) - mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device) + mel = librosa_mel_fn( + sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax + ) + mel_basis[str(fmax) + "_" + str(y.device)] = ( + torch.from_numpy(mel).float().to(y.device) + ) hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) y = torch.nn.functional.pad( - y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect" + y.unsqueeze(1), + (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), + mode="reflect", ) y = y.squeeze(1) diff --git a/egs/ljspeech/TTS/matcha/utils/generate_data_statistics.py b/egs/ljspeech/TTS/matcha/utils/generate_data_statistics.py index 3b8cd67c91..3028e76959 100644 --- a/egs/ljspeech/TTS/matcha/utils/generate_data_statistics.py +++ b/egs/ljspeech/TTS/matcha/utils/generate_data_statistics.py @@ -22,7 +22,9 @@ log = pylogger.get_pylogger(__name__) -def compute_data_statistics(data_loader: torch.utils.data.DataLoader, out_channels: int): +def compute_data_statistics( + data_loader: torch.utils.data.DataLoader, out_channels: int +): """Generate data mean and standard deviation helpful in data normalisation Args: @@ -42,7 +44,9 @@ def compute_data_statistics(data_loader: torch.utils.data.DataLoader, out_channe total_mel_sq_sum += torch.sum(torch.pow(mels, 2)) data_mean = total_mel_sum / (total_mel_len * out_channels) - data_std = torch.sqrt((total_mel_sq_sum / (total_mel_len * out_channels)) - torch.pow(data_mean, 2)) + data_std = torch.sqrt( + (total_mel_sq_sum / (total_mel_len * out_channels)) - torch.pow(data_mean, 2) + ) return {"mel_mean": data_mean.item(), "mel_std": data_std.item()} @@ -82,7 +86,9 @@ def main(): sys.exit(1) with initialize(version_base="1.3", config_path="../../configs/data"): - cfg = compose(config_name=args.input_config, return_hydra_config=True, overrides=[]) + cfg = compose( + config_name=args.input_config, return_hydra_config=True, overrides=[] + ) root_path = rootutils.find_root(search_from=__file__, indicator=".project-root") @@ -93,8 +99,12 @@ def main(): cfg["data_statistics"] = None cfg["seed"] = 
1234 cfg["batch_size"] = args.batch_size - cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"])) - cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"])) + cfg["train_filelist_path"] = str( + os.path.join(root_path, cfg["train_filelist_path"]) + ) + cfg["valid_filelist_path"] = str( + os.path.join(root_path, cfg["valid_filelist_path"]) + ) cfg["load_durations"] = False text_mel_datamodule = TextMelDataModule(**cfg) diff --git a/egs/ljspeech/TTS/matcha/utils/get_durations_from_trained_model.py b/egs/ljspeech/TTS/matcha/utils/get_durations_from_trained_model.py index 0fe2f35c42..acc7eabd9b 100644 --- a/egs/ljspeech/TTS/matcha/utils/get_durations_from_trained_model.py +++ b/egs/ljspeech/TTS/matcha/utils/get_durations_from_trained_model.py @@ -29,7 +29,12 @@ def save_durations_to_folder( - attn: torch.Tensor, x_length: int, y_length: int, filepath: str, output_folder: Path, text: str + attn: torch.Tensor, + x_length: int, + y_length: int, + filepath: str, + output_folder: Path, + text: str, ): durations = attn.squeeze().sum(1)[:x_length].numpy() durations_json = get_phoneme_durations(durations, text) @@ -41,7 +46,12 @@ def save_durations_to_folder( @torch.inference_mode() -def compute_durations(data_loader: torch.utils.data.DataLoader, model: nn.Module, device: torch.device, output_folder): +def compute_durations( + data_loader: torch.utils.data.DataLoader, + model: nn.Module, + device: torch.device, + output_folder, +): """Generate durations from the model for each datapoint and save it in a folder Args: @@ -123,13 +133,17 @@ def main(): ) parser.add_argument( - "--cpu", action="store_true", help="Use CPU for inference, not recommended (default: use GPU if available)" + "--cpu", + action="store_true", + help="Use CPU for inference, not recommended (default: use GPU if available)", ) args = parser.parse_args() with initialize(version_base="1.3", config_path="../../configs/data"): - cfg = compose(config_name=args.input_config, return_hydra_config=True, overrides=[]) + cfg = compose( + config_name=args.input_config, return_hydra_config=True, overrides=[] + ) root_path = rootutils.find_root(search_from=__file__, indicator=".project-root") @@ -138,8 +152,12 @@ def main(): del cfg["_target_"] cfg["seed"] = 1234 cfg["batch_size"] = args.batch_size - cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"])) - cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"])) + cfg["train_filelist_path"] = str( + os.path.join(root_path, cfg["train_filelist_path"]) + ) + cfg["valid_filelist_path"] = str( + os.path.join(root_path, cfg["valid_filelist_path"]) + ) cfg["load_durations"] = False if args.output_folder is not None: @@ -155,7 +173,9 @@ def main(): output_folder.mkdir(parents=True, exist_ok=True) - print(f"Preprocessing: {cfg['name']} from training filelist: {cfg['train_filelist_path']}") + print( + f"Preprocessing: {cfg['name']} from training filelist: {cfg['train_filelist_path']}" + ) print("Loading model...") device = get_device(args) model = MatchaTTS.load_from_checkpoint(args.checkpoint_path, map_location=device) diff --git a/egs/ljspeech/TTS/matcha/utils/instantiators.py b/egs/ljspeech/TTS/matcha/utils/instantiators.py index 5547b4ed61..bde0c0d757 100644 --- a/egs/ljspeech/TTS/matcha/utils/instantiators.py +++ b/egs/ljspeech/TTS/matcha/utils/instantiators.py @@ -27,7 +27,9 @@ def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]: for _, cb_conf in 
callbacks_cfg.items(): if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf: - log.info(f"Instantiating callback <{cb_conf._target_}>") # pylint: disable=protected-access + log.info( + f"Instantiating callback <{cb_conf._target_}>" + ) # pylint: disable=protected-access callbacks.append(hydra.utils.instantiate(cb_conf)) return callbacks @@ -50,7 +52,9 @@ def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]: for _, lg_conf in logger_cfg.items(): if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf: - log.info(f"Instantiating logger <{lg_conf._target_}>") # pylint: disable=protected-access + log.info( + f"Instantiating logger <{lg_conf._target_}>" + ) # pylint: disable=protected-access logger.append(hydra.utils.instantiate(lg_conf)) return logger diff --git a/egs/ljspeech/TTS/matcha/utils/logging_utils.py b/egs/ljspeech/TTS/matcha/utils/logging_utils.py index 1a12d1ddaf..2d2377eb2b 100644 --- a/egs/ljspeech/TTS/matcha/utils/logging_utils.py +++ b/egs/ljspeech/TTS/matcha/utils/logging_utils.py @@ -34,8 +34,12 @@ def log_hyperparameters(object_dict: Dict[str, Any]) -> None: # save number of model parameters hparams["model/params/total"] = sum(p.numel() for p in model.parameters()) - hparams["model/params/trainable"] = sum(p.numel() for p in model.parameters() if p.requires_grad) - hparams["model/params/non_trainable"] = sum(p.numel() for p in model.parameters() if not p.requires_grad) + hparams["model/params/trainable"] = sum( + p.numel() for p in model.parameters() if p.requires_grad + ) + hparams["model/params/non_trainable"] = sum( + p.numel() for p in model.parameters() if not p.requires_grad + ) hparams["data"] = cfg["data"] hparams["trainer"] = cfg["trainer"] diff --git a/egs/ljspeech/TTS/matcha/utils/model.py b/egs/ljspeech/TTS/matcha/utils/model.py index 869cc6092f..a488ab4e8b 100644 --- a/egs/ljspeech/TTS/matcha/utils/model.py +++ b/egs/ljspeech/TTS/matcha/utils/model.py @@ -36,7 +36,12 @@ def generate_path(duration, mask): cum_duration_flat = cum_duration.view(b * t_x) path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) path = path.view(b, t_x, t_y) - path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] + path = ( + path + - torch.nn.functional.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[ + :, :-1 + ] + ) path = path * mask return path diff --git a/egs/ljspeech/TTS/matcha/utils/pylogger.py b/egs/ljspeech/TTS/matcha/utils/pylogger.py index 6160067802..a7ed7a961e 100644 --- a/egs/ljspeech/TTS/matcha/utils/pylogger.py +++ b/egs/ljspeech/TTS/matcha/utils/pylogger.py @@ -14,7 +14,15 @@ def get_pylogger(name: str = __name__) -> logging.Logger: # this ensures all logging levels get marked with the rank zero decorator # otherwise logs would get multiplied for each GPU process in multi-GPU setup - logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical") + logging_levels = ( + "debug", + "info", + "warning", + "error", + "exception", + "fatal", + "critical", + ) for level in logging_levels: setattr(logger, level, rank_zero_only(getattr(logger, level))) diff --git a/egs/ljspeech/TTS/matcha/utils/rich_utils.py b/egs/ljspeech/TTS/matcha/utils/rich_utils.py index f602f6e935..d7fcd1aae9 100644 --- a/egs/ljspeech/TTS/matcha/utils/rich_utils.py +++ b/egs/ljspeech/TTS/matcha/utils/rich_utils.py @@ -47,7 +47,9 @@ def print_config_tree( _ = ( queue.append(field) if field in cfg - else log.warning(f"Field '{field}' not found in config. 
Skipping '{field}' config printing...") + else log.warning( + f"Field '{field}' not found in config. Skipping '{field}' config printing..." + ) ) # add all the other fields to queue (not specified in `print_order`) diff --git a/egs/ljspeech/TTS/matcha/utils/utils.py b/egs/ljspeech/TTS/matcha/utils/utils.py index bc81c316ea..a545542632 100644 --- a/egs/ljspeech/TTS/matcha/utils/utils.py +++ b/egs/ljspeech/TTS/matcha/utils/utils.py @@ -9,6 +9,7 @@ import matplotlib.pyplot as plt import numpy as np import torch + # from omegaconf import DictConfig # from matcha.utils import pylogger, rich_utils @@ -16,7 +17,7 @@ # log = pylogger.get_pylogger(__name__) -def extras(cfg: 'DictConfig') -> None: +def extras(cfg: "DictConfig") -> None: """Applies optional utilities before the task is started. Utilities: @@ -207,6 +208,7 @@ def get_user_data_dir(appname="matcha_tts"): def assert_model_downloaded(checkpoint_path, url, use_wget=True): import gdown import wget + if Path(checkpoint_path).exists(): log.debug(f"[+] Model already present at {checkpoint_path}!") print(f"[+] Model already present at {checkpoint_path}!") From c558328dc571fd26cc09e6913f9d9aa40c56966d Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 28 Oct 2024 19:18:21 +0800 Subject: [PATCH 12/27] remove unused code --- .../TTS/local/compute_fbank_ljspeech.py | 139 +++++++++++++----- egs/ljspeech/TTS/local/validate_manifest.py | 1 + egs/ljspeech/TTS/matcha/train.py | 126 +++++++--------- 3 files changed, 154 insertions(+), 112 deletions(-) diff --git a/egs/ljspeech/TTS/local/compute_fbank_ljspeech.py b/egs/ljspeech/TTS/local/compute_fbank_ljspeech.py index 5c25c3cf4a..fee66da480 100755 --- a/egs/ljspeech/TTS/local/compute_fbank_ljspeech.py +++ b/egs/ljspeech/TTS/local/compute_fbank_ljspeech.py @@ -27,28 +27,100 @@ import argparse import logging import os +from dataclasses import dataclass from pathlib import Path +from typing import Union +import numpy as np import torch -from lhotse import ( - CutSet, - Fbank, - FbankConfig, - LilcomChunkyWriter, - load_manifest, - load_manifest_lazy, -) +from lhotse import CutSet, LilcomChunkyWriter, load_manifest from lhotse.audio import RecordingSet +from lhotse.features.base import FeatureExtractor, register_extractor from lhotse.supervision import SupervisionSet +from lhotse.utils import Seconds, compute_num_frames +from matcha.utils.audio import mel_spectrogram from icefall.utils import get_executor -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). -torch.set_num_threads(1) -torch.set_num_interop_threads(1) + +@dataclass +class MyFbankConfig: + n_fft: int + n_mels: int + sampling_rate: int + hop_length: int + win_length: int + f_min: float + f_max: float + + +@register_extractor +class MyFbank(FeatureExtractor): + + name = "MyFbank" + config_type = MyFbankConfig + + def __init__(self, config): + super().__init__(config=config) + + @property + def device(self) -> Union[str, torch.device]: + return self.config.device + + def feature_dim(self, sampling_rate: int) -> int: + return self.config.n_mels + + def extract( + self, + samples: np.ndarray, + sampling_rate: int, + ) -> torch.Tensor: + # Check for sampling rate compatibility. 
+ expected_sr = self.config.sampling_rate + assert sampling_rate == expected_sr, ( + f"Mismatched sampling rate: extractor expects {expected_sr}, " + f"got {sampling_rate}" + ) + samples = torch.from_numpy(samples) + assert samples.ndim == 2, samples.shape + assert samples.shape[0] == 1, samples.shape + + mel = ( + mel_spectrogram( + samples, + self.config.n_fft, + self.config.n_mels, + self.config.sampling_rate, + self.config.hop_length, + self.config.win_length, + self.config.f_min, + self.config.f_max, + center=False, + ) + .squeeze() + .t() + ) + + assert mel.ndim == 2, mel.shape + assert mel.shape[1] == self.config.n_mels, mel.shape + + num_frames = compute_num_frames( + samples.shape[1] / sampling_rate, self.frame_shift, sampling_rate + ) + + if mel.shape[0] > num_frames: + mel = mel[:num_frames] + elif mel.shape[0] < num_frames: + mel = mel.unsqueeze(0) + mel = torch.nn.functional.pad( + mel, (0, 0, 0, num_frames - mel.shape[1]), mode="replicate" + ).squeeze(0) + + return mel.numpy() + + @property + def frame_shift(self) -> Seconds: + return self.config.hop_length / self.config.sampling_rate def get_parser(): @@ -77,10 +149,15 @@ def compute_fbank_ljspeech(num_jobs: int): logging.info(f"num_jobs: {num_jobs}") logging.info(f"src_dir: {src_dir}") logging.info(f"output_dir: {output_dir}") - - sampling_rate = 22050 - frame_length = 1024 / sampling_rate # (in second) - frame_shift = 256 / sampling_rate # (in second) + config = MyFbankConfig( + n_fft=1024, + n_mels=80, + sampling_rate=22050, + hop_length=256, + win_length=1024, + f_min=0, + f_max=8000, + ) prefix = "ljspeech" suffix = "jsonl.gz" @@ -93,25 +170,7 @@ def compute_fbank_ljspeech(num_jobs: int): src_dir / f"{prefix}_supervisions_{partition}.{suffix}", SupervisionSet ) - # Differences with matcha-tts - # 1. we use pre-emphasis - # 2. we remove dc offset - # 3. we use a different window - # 4. we use a different mel filter bank matrix - # 5. we don't normalize features - config = FbankConfig( - sampling_rate=sampling_rate, - frame_length=frame_length, - frame_shift=frame_shift, - use_fft_mag=True, - low_freq=0, - high_freq=8000, - remove_dc_offset=False, - preemph_coeff=0, - # should be identical to n_feats in ../matcha/train.py - num_filters=80, - ) - extractor = Fbank(config) + extractor = MyFbank(config) with get_executor() as ex: # Initialize the executor only once. cuts_filename = f"{prefix}_cuts_{partition}.{suffix}" @@ -135,6 +194,12 @@ def compute_fbank_ljspeech(num_jobs: int): if __name__ == "__main__": + # Torch's multithreaded behavior needs to be disabled or + # it wastes a lot of CPU and slow things down. + # Do this outside of main() in case it needs to take effect + # even when we are not invoking the main (e.g. when spawning subprocesses). 
+ torch.set_num_threads(1) + torch.set_num_interop_threads(1) formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/ljspeech/TTS/local/validate_manifest.py b/egs/ljspeech/TTS/local/validate_manifest.py index 68159ae036..bbd1bfe9d6 100755 --- a/egs/ljspeech/TTS/local/validate_manifest.py +++ b/egs/ljspeech/TTS/local/validate_manifest.py @@ -35,6 +35,7 @@ from lhotse import CutSet, load_manifest_lazy from lhotse.dataset.speech_synthesis import validate_for_tts +from compute_fbank_ljspeech import MyFbank def get_args(): diff --git a/egs/ljspeech/TTS/matcha/train.py b/egs/ljspeech/TTS/matcha/train.py index 747292197a..7f41ab1012 100755 --- a/egs/ljspeech/TTS/matcha/train.py +++ b/egs/ljspeech/TTS/matcha/train.py @@ -3,18 +3,17 @@ import argparse +import json import logging from pathlib import Path from shutil import copyfile from typing import Any, Dict, Optional, Union -import json import k2 import torch import torch.multiprocessing as mp import torch.nn as nn from lhotse.utils import fix_random_seed -from matcha.data.text_mel_datamodule import TextMelDataModule from matcha.models.matcha_tts import MatchaTTS from matcha.tokenizer import Tokenizer from matcha.utils.model import fix_len_compatibility @@ -355,36 +354,27 @@ def compute_validation_loss( with torch.no_grad(): for batch_idx, batch in enumerate(valid_dl): - if "tokens" in batch: - - ( - audio, - audio_lens, - features, - features_lens, - tokens, - tokens_lens, - ) = prepare_input(batch, tokenizer, device, params) + ( + audio, + audio_lens, + features, + features_lens, + tokens, + tokens_lens, + ) = prepare_input(batch, tokenizer, device, params) - losses = get_losses( - { - "x": tokens, - "x_lengths": tokens_lens, - "y": features.permute(0, 2, 1), - "y_lengths": features_lens, - "spks": None, # should change it for multi-speakers - "durations": None, - } - ) + losses = get_losses( + { + "x": tokens, + "x_lengths": tokens_lens, + "y": features.permute(0, 2, 1), + "y_lengths": features_lens, + "spks": None, # should change it for multi-speakers + "durations": None, + } + ) - batch_size = len(batch["tokens"]) - else: - batch_size = batch["x"].shape[0] - batch["x"] = batch["x"].to(device) - batch["x_lengths"] = batch["x_lengths"].to(device) - batch["y"] = batch["y"].to(device) - batch["y_lengths"] = batch["y_lengths"].to(device) - losses = get_losses(batch) + batch_size = len(batch["tokens"]) loss_info = MetricsTracker() loss_info["samples"] = batch_size @@ -478,38 +468,28 @@ def save_bad_model(suffix: str = ""): # features_lens, (N,), int32 # tokens: List[List[str]], len(tokens) == N - if "tokens" in batch: - batch_size = len(batch["tokens"]) + batch_size = len(batch["tokens"]) - ( - audio, - audio_lens, - features, - features_lens, - tokens, - tokens_lens, - ) = prepare_input(batch, tokenizer, device, params) - else: - batch_size = batch["x"].shape[0] + ( + audio, + audio_lens, + features, + features_lens, + tokens, + tokens_lens, + ) = prepare_input(batch, tokenizer, device, params) try: with autocast(enabled=params.use_fp16): - if "tokens" in batch: - losses = get_losses( - { - "x": tokens, - "x_lengths": tokens_lens, - "y": features.permute(0, 2, 1), - "y_lengths": features_lens, - "spks": None, # should change it for multi-speakers - "durations": None, - } - ) - else: - batch["x"] = batch["x"].to(device) - batch["x_lengths"] = batch["x_lengths"].to(device) - batch["y"] = batch["y"].to(device) - batch["y_lengths"] = 
batch["y_lengths"].to(device) - losses = get_losses(batch) + losses = get_losses( + { + "x": tokens, + "x_lengths": tokens_lens, + "y": features.permute(0, 2, 1), + "y_lengths": features_lens, + "spks": None, # should change it for multi-speakers + "durations": None, + } + ) loss = sum(losses.values()) @@ -535,8 +515,9 @@ def save_bad_model(suffix: str = ""): raise if params.batch_idx_train % 100 == 0 and params.use_fp16: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different + # If the grad scale was less than 1, try increasing it. + # The _growth_interval of the grad scaler is configurable, + # but we can't configure it to have different # behavior depending on the current grad scale. cur_grad_scale = scaler._scale.item() @@ -560,7 +541,8 @@ def save_bad_model(suffix: str = ""): logging.info( f"Epoch {params.cur_epoch}, batch {batch_idx}, " - f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, " + f"global_batch_idx: {params.batch_idx_train}, " + f"batch size: {batch_size}, " f"loss[{loss_info}], tot_loss[{tot_loss}], " + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") ) @@ -588,7 +570,8 @@ def save_bad_model(suffix: str = ""): model.train() logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + "Maximum memory allocated so far is " + f"{torch.cuda.max_memory_allocated()//1000000}MB" ) if tb_writer is not None: valid_info.write_summary( @@ -658,20 +641,13 @@ def run(rank, world_size, args): logging.info("About to create datamodule") - if False: - params.data_args.tokenizer = tokenizer - data_module = TextMelDataModule(hparams=params.data_args) - del params.data_args.tokenizer - train_dl = data_module.train_dataloader() - valid_dl = data_module.val_dataloader() - else: - ljspeech = LJSpeechTtsDataModule(args) + ljspeech = LJSpeechTtsDataModule(args) - train_cuts = ljspeech.train_cuts() - train_dl = ljspeech.train_dataloaders(train_cuts) + train_cuts = ljspeech.train_cuts() + train_dl = ljspeech.train_dataloaders(train_cuts) - valid_cuts = ljspeech.valid_cuts() - valid_dl = ljspeech.valid_dataloaders(valid_cuts) + valid_cuts = ljspeech.valid_cuts() + valid_dl = ljspeech.valid_dataloaders(valid_cuts) scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: From ed569a938ac578e436f2b433e6f4dbde07fe91b9 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 28 Oct 2024 19:20:21 +0800 Subject: [PATCH 13/27] remove more unused code --- egs/ljspeech/TTS/matcha/test-train.py | 159 -------------------------- egs/ljspeech/TTS/matcha/train-orig.py | 122 -------------------- 2 files changed, 281 deletions(-) delete mode 100644 egs/ljspeech/TTS/matcha/test-train.py delete mode 100644 egs/ljspeech/TTS/matcha/train-orig.py diff --git a/egs/ljspeech/TTS/matcha/test-train.py b/egs/ljspeech/TTS/matcha/test-train.py deleted file mode 100644 index f41ee4eae1..0000000000 --- a/egs/ljspeech/TTS/matcha/test-train.py +++ /dev/null @@ -1,159 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. 
(authors: Fangjun Kuang) - - -import torch - - -from icefall.utils import AttributeDict -from matcha.models.matcha_tts import MatchaTTS -from matcha.data.text_mel_datamodule import TextMelDataModule - - -def _get_data_params() -> AttributeDict: - params = AttributeDict( - { - "name": "ljspeech", - "train_filelist_path": "./filelists/ljs_audio_text_train_filelist.txt", - "valid_filelist_path": "./filelists/ljs_audio_text_val_filelist.txt", - "batch_size": 32, - "num_workers": 3, - "pin_memory": False, - "cleaners": ["english_cleaners2"], - "add_blank": True, - "n_spks": 1, - "n_fft": 1024, - "n_feats": 80, - "sample_rate": 22050, - "hop_length": 256, - "win_length": 1024, - "f_min": 0, - "f_max": 8000, - "seed": 1234, - "load_durations": False, - "data_statistics": AttributeDict( - { - "mel_mean": -5.517028331756592, - "mel_std": 2.0643954277038574, - } - ), - } - ) - return params - - -def _get_model_params() -> AttributeDict: - n_feats = 80 - filter_channels_dp = 256 - encoder_params_p_dropout = 0.1 - params = AttributeDict( - { - "n_vocab": 178, - "n_spks": 1, # for ljspeech. - "spk_emb_dim": 64, - "n_feats": n_feats, - "out_size": None, # or use 172 - "prior_loss": True, - "use_precomputed_durations": False, - "encoder": AttributeDict( - { - "encoder_type": "RoPE Encoder", # not used - "encoder_params": AttributeDict( - { - "n_feats": n_feats, - "n_channels": 192, - "filter_channels": 768, - "filter_channels_dp": filter_channels_dp, - "n_heads": 2, - "n_layers": 6, - "kernel_size": 3, - "p_dropout": encoder_params_p_dropout, - "spk_emb_dim": 64, - "n_spks": 1, - "prenet": True, - } - ), - "duration_predictor_params": AttributeDict( - { - "filter_channels_dp": filter_channels_dp, - "kernel_size": 3, - "p_dropout": encoder_params_p_dropout, - } - ), - } - ), - "decoder": AttributeDict( - { - "channels": [256, 256], - "dropout": 0.05, - "attention_head_dim": 64, - "n_blocks": 1, - "num_mid_blocks": 2, - "num_heads": 2, - "act_fn": "snakebeta", - } - ), - "cfm": AttributeDict( - { - "name": "CFM", - "solver": "euler", - "sigma_min": 1e-4, - } - ), - "optimizer": AttributeDict( - { - "lr": 1e-4, - "weight_decay": 0.0, - } - ), - } - ) - - return params - - -def get_params(): - params = AttributeDict( - { - "model": _get_model_params(), - "data": _get_data_params(), - } - ) - return params - - -def get_model(params): - m = MatchaTTS(**params.model) - return m - - -def main(): - params = get_params() - - data_module = TextMelDataModule(hparams=params.data) - if False: - for b in data_module.train_dataloader(): - assert isinstance(b, dict) - # b.keys() - # ['x', 'x_lengths', 'y', 'y_lengths', 'spks', 'filepaths', 'x_texts', 'durations'] - # x: [batch_size, 289], torch.int64 - # x_lengths: [batch_size], torch.int64 - # y: [batch_size, n_feats, num_frames], torch.float32 - # y_lengths: [batch_size], torch.int64 - # spks: None - # filepaths: list, (batch_size,) - # x_texts: list, (batch_size,) - # durations: None - - m = get_model(params) - print(m) - - num_param = sum([p.numel() for p in m.parameters()]) - print(f"Number of parameters: {num_param}") - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/ljspeech/TTS/matcha/train-orig.py b/egs/ljspeech/TTS/matcha/train-orig.py deleted file mode 100644 index d1d64c6c44..0000000000 --- a/egs/ljspeech/TTS/matcha/train-orig.py +++ /dev/null @@ -1,122 +0,0 @@ -from typing import Any, Dict, List, Optional, Tuple - -import hydra -import lightning as L -import rootutils -from lightning 
import Callback, LightningDataModule, LightningModule, Trainer -from lightning.pytorch.loggers import Logger -from omegaconf import DictConfig - -from matcha import utils - -rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) -# ------------------------------------------------------------------------------------ # -# the setup_root above is equivalent to: -# - adding project root dir to PYTHONPATH -# (so you don't need to force user to install project as a package) -# (necessary before importing any local modules e.g. `from src import utils`) -# - setting up PROJECT_ROOT environment variable -# (which is used as a base for paths in "configs/paths/default.yaml") -# (this way all filepaths are the same no matter where you run the code) -# - loading environment variables from ".env" in root dir -# -# you can remove it if you: -# 1. either install project as a package or move entry files to project root dir -# 2. set `root_dir` to "." in "configs/paths/default.yaml" -# -# more info: https://github.com/ashleve/rootutils -# ------------------------------------------------------------------------------------ # - - -log = utils.get_pylogger(__name__) - - -@utils.task_wrapper -def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: - """Trains the model. Can additionally evaluate on a testset, using best weights obtained during - training. - - This method is wrapped in optional @task_wrapper decorator, that controls the behavior during - failure. Useful for multiruns, saving info about the crash, etc. - - :param cfg: A DictConfig configuration composed by Hydra. - :return: A tuple with metrics and dict with all instantiated objects. - """ - # set seed for random number generators in pytorch, numpy and python.random - if cfg.get("seed"): - L.seed_everything(cfg.seed, workers=True) - - log.info(f"Instantiating datamodule <{cfg.data._target_}>") # pylint: disable=protected-access - datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data) - - log.info(f"Instantiating model <{cfg.model._target_}>") # pylint: disable=protected-access - model: LightningModule = hydra.utils.instantiate(cfg.model) - - log.info("Instantiating callbacks...") - callbacks: List[Callback] = utils.instantiate_callbacks(cfg.get("callbacks")) - - log.info("Instantiating loggers...") - logger: List[Logger] = utils.instantiate_loggers(cfg.get("logger")) - - log.info(f"Instantiating trainer <{cfg.trainer._target_}>") # pylint: disable=protected-access - trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger) - - object_dict = { - "cfg": cfg, - "datamodule": datamodule, - "model": model, - "callbacks": callbacks, - "logger": logger, - "trainer": trainer, - } - - if logger: - log.info("Logging hyperparameters!") - utils.log_hyperparameters(object_dict) - - if cfg.get("train"): - log.info("Starting training!") - trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path")) - - train_metrics = trainer.callback_metrics - - if cfg.get("test"): - log.info("Starting testing!") - ckpt_path = trainer.checkpoint_callback.best_model_path - if ckpt_path == "": - log.warning("Best ckpt not found! 
Using current weights for testing...") - ckpt_path = None - trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path) - log.info(f"Best ckpt path: {ckpt_path}") - - test_metrics = trainer.callback_metrics - - # merge train and test metrics - metric_dict = {**train_metrics, **test_metrics} - - return metric_dict, object_dict - - -@hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml") -def main(cfg: DictConfig) -> Optional[float]: - """Main entry point for training. - - :param cfg: DictConfig configuration composed by Hydra. - :return: Optional[float] with optimized metric value. - """ - # apply extra utilities - # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.) - utils.extras(cfg) - - # train the model - metric_dict, _ = train(cfg) - - # safely retrieve metric value for hydra-based hyperparameter optimization - metric_value = utils.get_metric_value(metric_dict=metric_dict, metric_name=cfg.get("optimized_metric")) - - # return optimized metric - return metric_value - - -if __name__ == "__main__": - main() # pylint: disable=no-value-for-parameter From ba4df1922404d9bf2e3a6419a06c81490cac6840 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 28 Oct 2024 19:24:09 +0800 Subject: [PATCH 14/27] fix inference --- egs/ljspeech/TTS/matcha/inference.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/egs/ljspeech/TTS/matcha/inference.py b/egs/ljspeech/TTS/matcha/inference.py index 250c38f200..209fb86b44 100755 --- a/egs/ljspeech/TTS/matcha/inference.py +++ b/egs/ljspeech/TTS/matcha/inference.py @@ -13,8 +13,6 @@ from matcha.hifigan.denoiser import Denoiser from tokenizer import Tokenizer from matcha.hifigan.models import Generator as HiFiGAN -from matcha.text import sequence_to_text, text_to_sequence -from matcha.utils.utils import intersperse from tqdm.auto import tqdm from train import get_model, get_params @@ -151,8 +149,13 @@ def main(): denoiser = Denoiser(vocoder, mode="zeros") texts = [ - "The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.", - "Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar.", + "The Secret Service believed that it was very doubtful that any " + "President would ride regularly in a vehicle with a fixed top, even " + "though transparent.", + "Today as always, men fall into two groups: slaves and free men. " + "Whoever does not have two-thirds of his day for himself, is a slave, " + "whatever he may be: a statesman, a businessman, an official, or a " + "scholar.", ] # Number of ODE Solver steps @@ -164,7 +167,7 @@ def main(): # Sampling temperature temperature = 0.667 - outputs, rtfs = [], [] + rtfs = [] rtfs_w = [] for i, text in enumerate(tqdm(texts)): output = synthesise( @@ -202,7 +205,8 @@ def main(): print(f"Number of ODE steps: {n_timesteps}") print(f"Mean RTF:\t\t\t\t{np.mean(rtfs):.6f} ± {np.std(rtfs):.6f}") print( - f"Mean RTF Waveform (incl. vocoder):\t{np.mean(rtfs_w):.6f} ± {np.std(rtfs_w):.6f}" + "Mean RTF Waveform " + f"(incl. 
vocoder):\t{np.mean(rtfs_w):.6f} ± {np.std(rtfs_w):.6f}" ) From f6328edf5b86148cbf91dc61544c953b9b1ec32d Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 28 Oct 2024 19:26:25 +0800 Subject: [PATCH 15/27] remove the text folder --- egs/ljspeech/TTS/matcha/text/__init__.py | 64 ----------- egs/ljspeech/TTS/matcha/text/cleaners.py | 130 ----------------------- egs/ljspeech/TTS/matcha/text/numbers.py | 73 ------------- egs/ljspeech/TTS/matcha/text/symbols.py | 15 --- 4 files changed, 282 deletions(-) delete mode 100644 egs/ljspeech/TTS/matcha/text/__init__.py delete mode 100644 egs/ljspeech/TTS/matcha/text/cleaners.py delete mode 100644 egs/ljspeech/TTS/matcha/text/numbers.py delete mode 100644 egs/ljspeech/TTS/matcha/text/symbols.py diff --git a/egs/ljspeech/TTS/matcha/text/__init__.py b/egs/ljspeech/TTS/matcha/text/__init__.py deleted file mode 100644 index 78c8b1f18e..0000000000 --- a/egs/ljspeech/TTS/matcha/text/__init__.py +++ /dev/null @@ -1,64 +0,0 @@ -""" from https://github.com/keithito/tacotron """ -from matcha.text import cleaners -from matcha.text.symbols import symbols - -# Mappings from symbol to numeric ID and vice versa: -_symbol_to_id = {s: i for i, s in enumerate(symbols)} -_id_to_symbol = { - i: s for i, s in enumerate(symbols) -} # pylint: disable=unnecessary-comprehension - - -def text_to_sequence(text, cleaner_names): - """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. - Args: - text: string to convert to a sequence - cleaner_names: names of the cleaner functions to run the text through - Returns: - List of integers corresponding to the symbols in the text - """ - sequence = [] - - clean_text = _clean_text(text, cleaner_names) - for symbol in clean_text: - try: - if symbol in "_()[]# ̃": - continue - symbol_id = _symbol_to_id[symbol] - except Exception as ex: - print(text) - print(clean_text) - raise RuntimeError( - f"text: {text}, clean_text: {clean_text}, ex: {ex}, symbol: {symbol}" - ) - sequence += [symbol_id] - return sequence, clean_text - - -def cleaned_text_to_sequence(cleaned_text): - """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. - Args: - text: string to convert to a sequence - Returns: - List of integers corresponding to the symbols in the text - """ - sequence = [_symbol_to_id[symbol] for symbol in cleaned_text] - return sequence - - -def sequence_to_text(sequence): - """Converts a sequence of IDs back to a string""" - result = "" - for symbol_id in sequence: - s = _id_to_symbol[symbol_id] - result += s - return result - - -def _clean_text(text, cleaner_names): - for name in cleaner_names: - cleaner = getattr(cleaners, name) - if not cleaner: - raise Exception("Unknown cleaner: %s" % name) - text = cleaner(text) - return text diff --git a/egs/ljspeech/TTS/matcha/text/cleaners.py b/egs/ljspeech/TTS/matcha/text/cleaners.py deleted file mode 100644 index 0a1979afe1..0000000000 --- a/egs/ljspeech/TTS/matcha/text/cleaners.py +++ /dev/null @@ -1,130 +0,0 @@ -""" from https://github.com/keithito/tacotron - -Cleaners are transformations that run over the input text at both training and eval time. - -Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" -hyperparameter. Some cleaners are English-specific. You'll typically want to use: - 1. "english_cleaners" for English text - 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using - the Unidecode library (https://pypi.python.org/pypi/Unidecode) - 3. 
"basic_cleaners" if you do not want to transliterate (in this case, you should also update - the symbols in symbols.py to match your data). -""" - -import logging -import re - -import phonemizer -from unidecode import unidecode - -# To avoid excessive logging we set the log level of the phonemizer package to Critical -critical_logger = logging.getLogger("phonemizer") -critical_logger.setLevel(logging.CRITICAL) - -# Intializing the phonemizer globally significantly reduces the speed -# now the phonemizer is not initialising at every call -# Might be less flexible, but it is much-much faster -global_phonemizer = phonemizer.backend.EspeakBackend( - language="en-us", - preserve_punctuation=True, - with_stress=True, - language_switch="remove-flags", - logger=critical_logger, -) - - -# Regular expression matching whitespace: -_whitespace_re = re.compile(r"\s+") - -# List of (regular expression, replacement) pairs for abbreviations: -_abbreviations = [ - (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) - for x in [ - ("mrs", "misess"), - ("mr", "mister"), - ("dr", "doctor"), - ("st", "saint"), - ("co", "company"), - ("jr", "junior"), - ("maj", "major"), - ("gen", "general"), - ("drs", "doctors"), - ("rev", "reverend"), - ("lt", "lieutenant"), - ("hon", "honorable"), - ("sgt", "sergeant"), - ("capt", "captain"), - ("esq", "esquire"), - ("ltd", "limited"), - ("col", "colonel"), - ("ft", "fort"), - ] -] - - -def expand_abbreviations(text): - for regex, replacement in _abbreviations: - text = re.sub(regex, replacement, text) - return text - - -def lowercase(text): - return text.lower() - - -def collapse_whitespace(text): - return re.sub(_whitespace_re, " ", text) - - -def remove_parentheses(text): - text = text.replace("(", "") - text = text.replace(")", "") - text = text.replace("[", "") - text = text.replace("]", "") - return text - - -def convert_to_ascii(text): - return unidecode(text) - - -def basic_cleaners(text): - """Basic pipeline that lowercases and collapses whitespace without transliteration.""" - text = lowercase(text) - text = collapse_whitespace(text) - return text - - -def transliteration_cleaners(text): - """Pipeline for non-English text that transliterates to ASCII.""" - text = convert_to_ascii(text) - text = lowercase(text) - text = collapse_whitespace(text) - return text - - -def english_cleaners2(text): - """Pipeline for English text, including abbreviation expansion. + punctuation + stress""" - text = convert_to_ascii(text) - text = lowercase(text) - text = expand_abbreviations(text) - text = remove_parentheses(text) - phonemes = global_phonemizer.phonemize([text], strip=True, njobs=1)[0] - phonemes = collapse_whitespace(phonemes) - return phonemes - - -# I am removing this due to incompatibility with several version of python -# However, if you want to use it, you can uncomment it -# and install piper-phonemize with the following command: -# pip install piper-phonemize - -# import piper_phonemize -# def english_cleaners_piper(text): -# """Pipeline for English text, including abbreviation expansion. 
+ punctuation + stress""" -# text = convert_to_ascii(text) -# text = lowercase(text) -# text = expand_abbreviations(text) -# phonemes = "".join(piper_phonemize.phonemize_espeak(text=text, voice="en-US")[0]) -# phonemes = collapse_whitespace(phonemes) -# return phonemes diff --git a/egs/ljspeech/TTS/matcha/text/numbers.py b/egs/ljspeech/TTS/matcha/text/numbers.py deleted file mode 100644 index 49c21d4e99..0000000000 --- a/egs/ljspeech/TTS/matcha/text/numbers.py +++ /dev/null @@ -1,73 +0,0 @@ -""" from https://github.com/keithito/tacotron """ - -import re - -import inflect - -_inflect = inflect.engine() -_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])") -_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)") -_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)") -_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)") -_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)") -_number_re = re.compile(r"[0-9]+") - - -def _remove_commas(m): - return m.group(1).replace(",", "") - - -def _expand_decimal_point(m): - return m.group(1).replace(".", " point ") - - -def _expand_dollars(m): - match = m.group(1) - parts = match.split(".") - if len(parts) > 2: - return match + " dollars" - dollars = int(parts[0]) if parts[0] else 0 - cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 - if dollars and cents: - dollar_unit = "dollar" if dollars == 1 else "dollars" - cent_unit = "cent" if cents == 1 else "cents" - return f"{dollars} {dollar_unit}, {cents} {cent_unit}" - elif dollars: - dollar_unit = "dollar" if dollars == 1 else "dollars" - return f"{dollars} {dollar_unit}" - elif cents: - cent_unit = "cent" if cents == 1 else "cents" - return f"{cents} {cent_unit}" - else: - return "zero dollars" - - -def _expand_ordinal(m): - return _inflect.number_to_words(m.group(0)) - - -def _expand_number(m): - num = int(m.group(0)) - if num > 1000 and num < 3000: - if num == 2000: - return "two thousand" - elif num > 2000 and num < 2010: - return "two thousand " + _inflect.number_to_words(num % 100) - elif num % 100 == 0: - return _inflect.number_to_words(num // 100) + " hundred" - else: - return _inflect.number_to_words( - num, andword="", zero="oh", group=2 - ).replace(", ", " ") - else: - return _inflect.number_to_words(num, andword="") - - -def normalize_numbers(text): - text = re.sub(_comma_number_re, _remove_commas, text) - text = re.sub(_pounds_re, r"\1 pounds", text) - text = re.sub(_dollars_re, _expand_dollars, text) - text = re.sub(_decimal_number_re, _expand_decimal_point, text) - text = re.sub(_ordinal_re, _expand_ordinal, text) - text = re.sub(_number_re, _expand_number, text) - return text diff --git a/egs/ljspeech/TTS/matcha/text/symbols.py b/egs/ljspeech/TTS/matcha/text/symbols.py deleted file mode 100644 index b32c124302..0000000000 --- a/egs/ljspeech/TTS/matcha/text/symbols.py +++ /dev/null @@ -1,15 +0,0 @@ -""" from https://github.com/keithito/tacotron - -Defines the set of symbols used in text input to the model. 
-""" -_pad = "_" -_punctuation = ';:,.!?¡¿—…"«»“” ' -_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" -_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ" - - -# Export all symbols: -symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa) - -# Special symbol ids -SPACE_ID = symbols.index(" ") From 10c099ac909f6406d35dae5a9db5c709e81c3c91 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 28 Oct 2024 19:51:47 +0800 Subject: [PATCH 16/27] remove more unused code --- egs/ljspeech/TTS/matcha/{utils => }/audio.py | 0 .../TTS/matcha/export_onnx_hifigan.py | 1 - egs/ljspeech/TTS/matcha/inference.py | 4 +- egs/ljspeech/TTS/matcha/{utils => }/model.py | 0 .../TTS/matcha/models/baselightningmodule.py | 223 --------------- .../TTS/matcha/models/components/decoder.py | 1 - .../matcha/models/components/flow_matching.py | 5 - .../matcha/models/components/text_encoder.py | 6 +- egs/ljspeech/TTS/matcha/models/matcha_tts.py | 14 +- .../{utils => }/monotonic_align/.gitignore | 0 .../{utils => }/monotonic_align/__init__.py | 2 +- .../{utils => }/monotonic_align/core.pyx | 0 .../{utils => }/monotonic_align/setup.py | 0 egs/ljspeech/TTS/matcha/onnx_pretrained.py | 7 +- egs/ljspeech/TTS/matcha/train.py | 4 +- .../TTS/matcha/{utils2.py => utils.py} | 0 egs/ljspeech/TTS/matcha/utils/__init__.py | 6 - .../matcha/utils/generate_data_statistics.py | 123 --------- .../utils/get_durations_from_trained_model.py | 215 --------------- .../TTS/matcha/utils/instantiators.py | 60 ---- .../TTS/matcha/utils/logging_utils.py | 57 ---- egs/ljspeech/TTS/matcha/utils/pylogger.py | 29 -- egs/ljspeech/TTS/matcha/utils/rich_utils.py | 103 ------- egs/ljspeech/TTS/matcha/utils/utils.py | 261 ------------------ 24 files changed, 13 insertions(+), 1108 deletions(-) rename egs/ljspeech/TTS/matcha/{utils => }/audio.py (100%) rename egs/ljspeech/TTS/matcha/{utils => }/model.py (100%) delete mode 100644 egs/ljspeech/TTS/matcha/models/baselightningmodule.py rename egs/ljspeech/TTS/matcha/{utils => }/monotonic_align/.gitignore (100%) rename egs/ljspeech/TTS/matcha/{utils => }/monotonic_align/__init__.py (90%) rename egs/ljspeech/TTS/matcha/{utils => }/monotonic_align/core.pyx (100%) rename egs/ljspeech/TTS/matcha/{utils => }/monotonic_align/setup.py (100%) rename egs/ljspeech/TTS/matcha/{utils2.py => utils.py} (100%) delete mode 100644 egs/ljspeech/TTS/matcha/utils/__init__.py delete mode 100644 egs/ljspeech/TTS/matcha/utils/generate_data_statistics.py delete mode 100644 egs/ljspeech/TTS/matcha/utils/get_durations_from_trained_model.py delete mode 100644 egs/ljspeech/TTS/matcha/utils/instantiators.py delete mode 100644 egs/ljspeech/TTS/matcha/utils/logging_utils.py delete mode 100644 egs/ljspeech/TTS/matcha/utils/pylogger.py delete mode 100644 egs/ljspeech/TTS/matcha/utils/rich_utils.py delete mode 100644 egs/ljspeech/TTS/matcha/utils/utils.py diff --git a/egs/ljspeech/TTS/matcha/utils/audio.py b/egs/ljspeech/TTS/matcha/audio.py similarity index 100% rename from egs/ljspeech/TTS/matcha/utils/audio.py rename to egs/ljspeech/TTS/matcha/audio.py diff --git a/egs/ljspeech/TTS/matcha/export_onnx_hifigan.py b/egs/ljspeech/TTS/matcha/export_onnx_hifigan.py index 3b2ebf5025..af54f4e896 100755 --- a/egs/ljspeech/TTS/matcha/export_onnx_hifigan.py +++ b/egs/ljspeech/TTS/matcha/export_onnx_hifigan.py @@ -5,7 +5,6 @@ import onnx import torch - from inference import load_vocoder diff --git a/egs/ljspeech/TTS/matcha/inference.py 
b/egs/ljspeech/TTS/matcha/inference.py index 209fb86b44..89a6b33ae3 100755 --- a/egs/ljspeech/TTS/matcha/inference.py +++ b/egs/ljspeech/TTS/matcha/inference.py @@ -2,17 +2,17 @@ import argparse import datetime as dt +import json import logging from pathlib import Path -import json import numpy as np import soundfile as sf import torch from matcha.hifigan.config import v1, v2, v3 from matcha.hifigan.denoiser import Denoiser -from tokenizer import Tokenizer from matcha.hifigan.models import Generator as HiFiGAN +from tokenizer import Tokenizer from tqdm.auto import tqdm from train import get_model, get_params diff --git a/egs/ljspeech/TTS/matcha/utils/model.py b/egs/ljspeech/TTS/matcha/model.py similarity index 100% rename from egs/ljspeech/TTS/matcha/utils/model.py rename to egs/ljspeech/TTS/matcha/model.py diff --git a/egs/ljspeech/TTS/matcha/models/baselightningmodule.py b/egs/ljspeech/TTS/matcha/models/baselightningmodule.py deleted file mode 100644 index e80d2a5c97..0000000000 --- a/egs/ljspeech/TTS/matcha/models/baselightningmodule.py +++ /dev/null @@ -1,223 +0,0 @@ -""" -This is a base lightning module that can be used to train a model. -The benefit of this abstraction is that all the logic outside of model definition can be reused for different models. -""" -import inspect -from abc import ABC -from typing import Any, Dict - -import torch -from lightning import LightningModule -from lightning.pytorch.utilities import grad_norm - -from matcha import utils -from matcha.utils.utils import plot_tensor - -log = utils.get_pylogger(__name__) - - -class BaseLightningClass(LightningModule, ABC): - def update_data_statistics(self, data_statistics): - if data_statistics is None: - data_statistics = { - "mel_mean": 0.0, - "mel_std": 1.0, - } - - self.register_buffer("mel_mean", torch.tensor(data_statistics["mel_mean"])) - self.register_buffer("mel_std", torch.tensor(data_statistics["mel_std"])) - - def configure_optimizers(self) -> Any: - optimizer = self.hparams.optimizer(params=self.parameters()) - if self.hparams.scheduler not in (None, {}): - scheduler_args = {} - # Manage last epoch for exponential schedulers - if ( - "last_epoch" - in inspect.signature(self.hparams.scheduler.scheduler).parameters - ): - if hasattr(self, "ckpt_loaded_epoch"): - current_epoch = self.ckpt_loaded_epoch - 1 - else: - current_epoch = -1 - - scheduler_args.update({"optimizer": optimizer}) - scheduler = self.hparams.scheduler.scheduler(**scheduler_args) - scheduler.last_epoch = current_epoch - return { - "optimizer": optimizer, - "lr_scheduler": { - "scheduler": scheduler, - "interval": self.hparams.scheduler.lightning_args.interval, - "frequency": self.hparams.scheduler.lightning_args.frequency, - "name": "learning_rate", - }, - } - - return {"optimizer": optimizer} - - def get_losses(self, batch): - x, x_lengths = batch["x"], batch["x_lengths"] - y, y_lengths = batch["y"], batch["y_lengths"] - spks = batch["spks"] - - dur_loss, prior_loss, diff_loss, *_ = self( - x=x, - x_lengths=x_lengths, - y=y, - y_lengths=y_lengths, - spks=spks, - out_size=self.out_size, - durations=batch["durations"], - ) - return { - "dur_loss": dur_loss, - "prior_loss": prior_loss, - "diff_loss": diff_loss, - } - - def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - self.ckpt_loaded_epoch = checkpoint[ - "epoch" - ] # pylint: disable=attribute-defined-outside-init - - def training_step(self, batch: Any, batch_idx: int): - loss_dict = self.get_losses(batch) - self.log( - "step", - float(self.global_step), - on_step=True, - 
prog_bar=True, - logger=True, - sync_dist=True, - ) - - self.log( - "sub_loss/train_dur_loss", - loss_dict["dur_loss"], - on_step=True, - on_epoch=True, - logger=True, - sync_dist=True, - ) - self.log( - "sub_loss/train_prior_loss", - loss_dict["prior_loss"], - on_step=True, - on_epoch=True, - logger=True, - sync_dist=True, - ) - self.log( - "sub_loss/train_diff_loss", - loss_dict["diff_loss"], - on_step=True, - on_epoch=True, - logger=True, - sync_dist=True, - ) - - total_loss = sum(loss_dict.values()) - self.log( - "loss/train", - total_loss, - on_step=True, - on_epoch=True, - logger=True, - prog_bar=True, - sync_dist=True, - ) - - return {"loss": total_loss, "log": loss_dict} - - def validation_step(self, batch: Any, batch_idx: int): - loss_dict = self.get_losses(batch) - self.log( - "sub_loss/val_dur_loss", - loss_dict["dur_loss"], - on_step=True, - on_epoch=True, - logger=True, - sync_dist=True, - ) - self.log( - "sub_loss/val_prior_loss", - loss_dict["prior_loss"], - on_step=True, - on_epoch=True, - logger=True, - sync_dist=True, - ) - self.log( - "sub_loss/val_diff_loss", - loss_dict["diff_loss"], - on_step=True, - on_epoch=True, - logger=True, - sync_dist=True, - ) - - total_loss = sum(loss_dict.values()) - self.log( - "loss/val", - total_loss, - on_step=True, - on_epoch=True, - logger=True, - prog_bar=True, - sync_dist=True, - ) - - return total_loss - - def on_validation_end(self) -> None: - if self.trainer.is_global_zero: - one_batch = next(iter(self.trainer.val_dataloaders)) - if self.current_epoch == 0: - log.debug("Plotting original samples") - for i in range(2): - y = one_batch["y"][i].unsqueeze(0).to(self.device) - self.logger.experiment.add_image( - f"original/{i}", - plot_tensor(y.squeeze().cpu()), - self.current_epoch, - dataformats="HWC", - ) - - log.debug("Synthesising...") - for i in range(2): - x = one_batch["x"][i].unsqueeze(0).to(self.device) - x_lengths = one_batch["x_lengths"][i].unsqueeze(0).to(self.device) - spks = ( - one_batch["spks"][i].unsqueeze(0).to(self.device) - if one_batch["spks"] is not None - else None - ) - output = self.synthesise( - x[:, :x_lengths], x_lengths, n_timesteps=10, spks=spks - ) - y_enc, y_dec = output["encoder_outputs"], output["decoder_outputs"] - attn = output["attn"] - self.logger.experiment.add_image( - f"generated_enc/{i}", - plot_tensor(y_enc.squeeze().cpu()), - self.current_epoch, - dataformats="HWC", - ) - self.logger.experiment.add_image( - f"generated_dec/{i}", - plot_tensor(y_dec.squeeze().cpu()), - self.current_epoch, - dataformats="HWC", - ) - self.logger.experiment.add_image( - f"alignment/{i}", - plot_tensor(attn.squeeze().cpu()), - self.current_epoch, - dataformats="HWC", - ) - - def on_before_optimizer_step(self, optimizer): - self.log_dict( - {f"grad_norm/{k}": v for k, v in grad_norm(self, norm_type=2).items()} - ) diff --git a/egs/ljspeech/TTS/matcha/models/components/decoder.py b/egs/ljspeech/TTS/matcha/models/components/decoder.py index 5850f2639b..14d19f5d4e 100644 --- a/egs/ljspeech/TTS/matcha/models/components/decoder.py +++ b/egs/ljspeech/TTS/matcha/models/components/decoder.py @@ -7,7 +7,6 @@ from conformer import ConformerBlock from diffusers.models.activations import get_activation from einops import pack, rearrange, repeat - from matcha.models.components.transformer import BasicTransformerBlock diff --git a/egs/ljspeech/TTS/matcha/models/components/flow_matching.py b/egs/ljspeech/TTS/matcha/models/components/flow_matching.py index 5a7226b4f7..997689b1cb 100644 --- 
a/egs/ljspeech/TTS/matcha/models/components/flow_matching.py +++ b/egs/ljspeech/TTS/matcha/models/components/flow_matching.py @@ -2,13 +2,8 @@ import torch import torch.nn.functional as F - from matcha.models.components.decoder import Decoder -# from matcha.utils.pylogger import get_pylogger - -# log = get_pylogger(__name__) - class BASECFM(torch.nn.Module, ABC): def __init__( diff --git a/egs/ljspeech/TTS/matcha/models/components/text_encoder.py b/egs/ljspeech/TTS/matcha/models/components/text_encoder.py index 68f8ad864e..ca77cba51c 100644 --- a/egs/ljspeech/TTS/matcha/models/components/text_encoder.py +++ b/egs/ljspeech/TTS/matcha/models/components/text_encoder.py @@ -5,11 +5,7 @@ import torch import torch.nn as nn from einops import rearrange - -# import matcha.utils as utils -from matcha.utils.model import sequence_mask - -# log = utils.get_pylogger(__name__) +from matcha.model import sequence_mask class LayerNorm(nn.Module): diff --git a/egs/ljspeech/TTS/matcha/models/matcha_tts.py b/egs/ljspeech/TTS/matcha/models/matcha_tts.py index b1525695f2..330d1dc472 100644 --- a/egs/ljspeech/TTS/matcha/models/matcha_tts.py +++ b/egs/ljspeech/TTS/matcha/models/matcha_tts.py @@ -2,23 +2,17 @@ import math import random +import matcha.monotonic_align as monotonic_align import torch - -import matcha.utils.monotonic_align as monotonic_align - -# from matcha import utils -# from matcha.models.baselightningmodule import BaseLightningClass -from matcha.models.components.flow_matching import CFM -from matcha.models.components.text_encoder import TextEncoder -from matcha.utils.model import ( +from matcha.model import ( denormalize, duration_loss, fix_len_compatibility, generate_path, sequence_mask, ) - -# log = utils.get_pylogger(__name__) +from matcha.models.components.flow_matching import CFM +from matcha.models.components.text_encoder import TextEncoder class MatchaTTS(torch.nn.Module): # 🍵 diff --git a/egs/ljspeech/TTS/matcha/utils/monotonic_align/.gitignore b/egs/ljspeech/TTS/matcha/monotonic_align/.gitignore similarity index 100% rename from egs/ljspeech/TTS/matcha/utils/monotonic_align/.gitignore rename to egs/ljspeech/TTS/matcha/monotonic_align/.gitignore diff --git a/egs/ljspeech/TTS/matcha/utils/monotonic_align/__init__.py b/egs/ljspeech/TTS/matcha/monotonic_align/__init__.py similarity index 90% rename from egs/ljspeech/TTS/matcha/utils/monotonic_align/__init__.py rename to egs/ljspeech/TTS/matcha/monotonic_align/__init__.py index eee6e0d47c..58286bdd42 100644 --- a/egs/ljspeech/TTS/matcha/utils/monotonic_align/__init__.py +++ b/egs/ljspeech/TTS/matcha/monotonic_align/__init__.py @@ -1,7 +1,7 @@ import numpy as np import torch -from matcha.utils.monotonic_align.core import maximum_path_c +from matcha.monotonic_align.core import maximum_path_c def maximum_path(value, mask): diff --git a/egs/ljspeech/TTS/matcha/utils/monotonic_align/core.pyx b/egs/ljspeech/TTS/matcha/monotonic_align/core.pyx similarity index 100% rename from egs/ljspeech/TTS/matcha/utils/monotonic_align/core.pyx rename to egs/ljspeech/TTS/matcha/monotonic_align/core.pyx diff --git a/egs/ljspeech/TTS/matcha/utils/monotonic_align/setup.py b/egs/ljspeech/TTS/matcha/monotonic_align/setup.py similarity index 100% rename from egs/ljspeech/TTS/matcha/utils/monotonic_align/setup.py rename to egs/ljspeech/TTS/matcha/monotonic_align/setup.py diff --git a/egs/ljspeech/TTS/matcha/onnx_pretrained.py b/egs/ljspeech/TTS/matcha/onnx_pretrained.py index 24955e881c..3953d5d0ae 100755 --- a/egs/ljspeech/TTS/matcha/onnx_pretrained.py +++ 
b/egs/ljspeech/TTS/matcha/onnx_pretrained.py @@ -1,13 +1,12 @@ #!/usr/bin/env python3 +import datetime as dt import logging import onnxruntime as ort -import torch -from tokenizer import Tokenizer -import datetime as dt - import soundfile as sf +import torch from inference import load_vocoder +from tokenizer import Tokenizer class OnnxHifiGANModel: diff --git a/egs/ljspeech/TTS/matcha/train.py b/egs/ljspeech/TTS/matcha/train.py index 7f41ab1012..ce13e7e429 100755 --- a/egs/ljspeech/TTS/matcha/train.py +++ b/egs/ljspeech/TTS/matcha/train.py @@ -14,15 +14,15 @@ import torch.multiprocessing as mp import torch.nn as nn from lhotse.utils import fix_random_seed +from matcha.model import fix_len_compatibility from matcha.models.matcha_tts import MatchaTTS from matcha.tokenizer import Tokenizer -from matcha.utils.model import fix_len_compatibility from torch.cuda.amp import GradScaler, autocast from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Optimizer from torch.utils.tensorboard import SummaryWriter from tts_datamodule import LJSpeechTtsDataModule -from utils2 import MetricsTracker +from utils import MetricsTracker from icefall.checkpoint import load_checkpoint, save_checkpoint from icefall.dist import cleanup_dist, setup_dist diff --git a/egs/ljspeech/TTS/matcha/utils2.py b/egs/ljspeech/TTS/matcha/utils.py similarity index 100% rename from egs/ljspeech/TTS/matcha/utils2.py rename to egs/ljspeech/TTS/matcha/utils.py diff --git a/egs/ljspeech/TTS/matcha/utils/__init__.py b/egs/ljspeech/TTS/matcha/utils/__init__.py deleted file mode 100644 index 311744a786..0000000000 --- a/egs/ljspeech/TTS/matcha/utils/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# from matcha.utils.instantiators import instantiate_callbacks, instantiate_loggers -# from matcha.utils.logging_utils import log_hyperparameters -# from matcha.utils.pylogger import get_pylogger -# from matcha.utils.rich_utils import enforce_tags, print_config_tree -# from matcha.utils.utils import extras, get_metric_value, task_wrapper -from matcha.utils.utils import intersperse diff --git a/egs/ljspeech/TTS/matcha/utils/generate_data_statistics.py b/egs/ljspeech/TTS/matcha/utils/generate_data_statistics.py deleted file mode 100644 index 3028e76959..0000000000 --- a/egs/ljspeech/TTS/matcha/utils/generate_data_statistics.py +++ /dev/null @@ -1,123 +0,0 @@ -r""" -The file creates a pickle file where the values needed for loading of dataset is stored and the model can load it -when needed. 
- -Parameters from hparam.py will be used -""" -import argparse -import json -import os -import sys -from pathlib import Path - -import rootutils -import torch -from hydra import compose, initialize -from omegaconf import open_dict -from tqdm.auto import tqdm - -from matcha.data.text_mel_datamodule import TextMelDataModule -from matcha.utils.logging_utils import pylogger - -log = pylogger.get_pylogger(__name__) - - -def compute_data_statistics( - data_loader: torch.utils.data.DataLoader, out_channels: int -): - """Generate data mean and standard deviation helpful in data normalisation - - Args: - data_loader (torch.utils.data.Dataloader): _description_ - out_channels (int): mel spectrogram channels - """ - total_mel_sum = 0 - total_mel_sq_sum = 0 - total_mel_len = 0 - - for batch in tqdm(data_loader, leave=False): - mels = batch["y"] - mel_lengths = batch["y_lengths"] - - total_mel_len += torch.sum(mel_lengths) - total_mel_sum += torch.sum(mels) - total_mel_sq_sum += torch.sum(torch.pow(mels, 2)) - - data_mean = total_mel_sum / (total_mel_len * out_channels) - data_std = torch.sqrt( - (total_mel_sq_sum / (total_mel_len * out_channels)) - torch.pow(data_mean, 2) - ) - - return {"mel_mean": data_mean.item(), "mel_std": data_std.item()} - - -def main(): - parser = argparse.ArgumentParser() - - parser.add_argument( - "-i", - "--input-config", - type=str, - default="vctk.yaml", - help="The name of the yaml config file under configs/data", - ) - - parser.add_argument( - "-b", - "--batch-size", - type=int, - default="256", - help="Can have increased batch size for faster computation", - ) - - parser.add_argument( - "-f", - "--force", - action="store_true", - default=False, - required=False, - help="force overwrite the file", - ) - args = parser.parse_args() - output_file = Path(args.input_config).with_suffix(".json") - - if os.path.exists(output_file) and not args.force: - print("File already exists. Use -f to force overwrite") - sys.exit(1) - - with initialize(version_base="1.3", config_path="../../configs/data"): - cfg = compose( - config_name=args.input_config, return_hydra_config=True, overrides=[] - ) - - root_path = rootutils.find_root(search_from=__file__, indicator=".project-root") - - with open_dict(cfg): - print(cfg) - del cfg["hydra"] - del cfg["_target_"] - cfg["data_statistics"] = None - cfg["seed"] = 1234 - cfg["batch_size"] = args.batch_size - cfg["train_filelist_path"] = str( - os.path.join(root_path, cfg["train_filelist_path"]) - ) - cfg["valid_filelist_path"] = str( - os.path.join(root_path, cfg["valid_filelist_path"]) - ) - cfg["load_durations"] = False - - text_mel_datamodule = TextMelDataModule(**cfg) - text_mel_datamodule.setup() - data_loader = text_mel_datamodule.train_dataloader() - log.info("Dataloader loaded! Now computing stats...") - params = compute_data_statistics(data_loader, cfg["n_feats"]) - print(params) - json.dump( - params, - open(output_file, "w"), - ) - - -if __name__ == "__main__": - main() diff --git a/egs/ljspeech/TTS/matcha/utils/get_durations_from_trained_model.py b/egs/ljspeech/TTS/matcha/utils/get_durations_from_trained_model.py deleted file mode 100644 index acc7eabd9b..0000000000 --- a/egs/ljspeech/TTS/matcha/utils/get_durations_from_trained_model.py +++ /dev/null @@ -1,215 +0,0 @@ -r""" -The file creates a pickle file where the values needed for loading of dataset is stored and the model can load it -when needed. 
- -Parameters from hparam.py will be used -""" -import argparse -import json -import os -import sys -from pathlib import Path - -import lightning -import numpy as np -import rootutils -import torch -from hydra import compose, initialize -from omegaconf import open_dict -from torch import nn -from tqdm.auto import tqdm - -from matcha.cli import get_device -from matcha.data.text_mel_datamodule import TextMelDataModule -from matcha.models.matcha_tts import MatchaTTS -from matcha.utils.logging_utils import pylogger -from matcha.utils.utils import get_phoneme_durations - -log = pylogger.get_pylogger(__name__) - - -def save_durations_to_folder( - attn: torch.Tensor, - x_length: int, - y_length: int, - filepath: str, - output_folder: Path, - text: str, -): - durations = attn.squeeze().sum(1)[:x_length].numpy() - durations_json = get_phoneme_durations(durations, text) - output = output_folder / Path(filepath).name.replace(".wav", ".npy") - with open(output.with_suffix(".json"), "w", encoding="utf-8") as f: - json.dump(durations_json, f, indent=4, ensure_ascii=False) - - np.save(output, durations) - - -@torch.inference_mode() -def compute_durations( - data_loader: torch.utils.data.DataLoader, - model: nn.Module, - device: torch.device, - output_folder, -): - """Generate durations from the model for each datapoint and save it in a folder - - Args: - data_loader (torch.utils.data.DataLoader): Dataloader - model (nn.Module): MatchaTTS model - device (torch.device): GPU or CPU - """ - - for batch in tqdm(data_loader, desc="🍵 Computing durations 🍵:"): - x, x_lengths = batch["x"], batch["x_lengths"] - y, y_lengths = batch["y"], batch["y_lengths"] - spks = batch["spks"] - x = x.to(device) - y = y.to(device) - x_lengths = x_lengths.to(device) - y_lengths = y_lengths.to(device) - spks = spks.to(device) if spks is not None else None - - _, _, _, attn = model( - x=x, - x_lengths=x_lengths, - y=y, - y_lengths=y_lengths, - spks=spks, - ) - attn = attn.cpu() - for i in range(attn.shape[0]): - save_durations_to_folder( - attn[i], - x_lengths[i].item(), - y_lengths[i].item(), - batch["filepaths"][i], - output_folder, - batch["x_texts"][i], - ) - - -def main(): - parser = argparse.ArgumentParser() - - parser.add_argument( - "-i", - "--input-config", - type=str, - default="ljspeech.yaml", - help="The name of the yaml config file under configs/data", - ) - - parser.add_argument( - "-b", - "--batch-size", - type=int, - default="32", - help="Can have increased batch size for faster computation", - ) - - parser.add_argument( - "-f", - "--force", - action="store_true", - default=False, - required=False, - help="force overwrite the file", - ) - parser.add_argument( - "-c", - "--checkpoint_path", - type=str, - required=True, - help="Path to the checkpoint file to load the model from", - ) - - parser.add_argument( - "-o", - "--output-folder", - type=str, - default=None, - help="Output folder to save the data statistics", - ) - - parser.add_argument( - "--cpu", - action="store_true", - help="Use CPU for inference, not recommended (default: use GPU if available)", - ) - - args = parser.parse_args() - - with initialize(version_base="1.3", config_path="../../configs/data"): - cfg = compose( - config_name=args.input_config, return_hydra_config=True, overrides=[] - ) - - root_path = rootutils.find_root(search_from=__file__, indicator=".project-root") - - with open_dict(cfg): - del cfg["hydra"] - del cfg["_target_"] - cfg["seed"] = 1234 - cfg["batch_size"] = args.batch_size - cfg["train_filelist_path"] = str( - 
os.path.join(root_path, cfg["train_filelist_path"]) - ) - cfg["valid_filelist_path"] = str( - os.path.join(root_path, cfg["valid_filelist_path"]) - ) - cfg["load_durations"] = False - - if args.output_folder is not None: - output_folder = Path(args.output_folder) - else: - output_folder = Path(cfg["train_filelist_path"]).parent / "durations" - - print(f"Output folder set to: {output_folder}") - - if os.path.exists(output_folder) and not args.force: - print("Folder already exists. Use -f to force overwrite") - sys.exit(1) - - output_folder.mkdir(parents=True, exist_ok=True) - - print( - f"Preprocessing: {cfg['name']} from training filelist: {cfg['train_filelist_path']}" - ) - print("Loading model...") - device = get_device(args) - model = MatchaTTS.load_from_checkpoint(args.checkpoint_path, map_location=device) - - text_mel_datamodule = TextMelDataModule(**cfg) - text_mel_datamodule.setup() - try: - print("Computing stats for training set if exists...") - train_dataloader = text_mel_datamodule.train_dataloader() - compute_durations(train_dataloader, model, device, output_folder) - except lightning.fabric.utilities.exceptions.MisconfigurationException: - print("No training set found") - - try: - print("Computing stats for validation set if exists...") - val_dataloader = text_mel_datamodule.val_dataloader() - compute_durations(val_dataloader, model, device, output_folder) - except lightning.fabric.utilities.exceptions.MisconfigurationException: - print("No validation set found") - - try: - print("Computing stats for test set if exists...") - test_dataloader = text_mel_datamodule.test_dataloader() - compute_durations(test_dataloader, model, device, output_folder) - except lightning.fabric.utilities.exceptions.MisconfigurationException: - print("No test set found") - - print(f"[+] Done! Data statistics saved to: {output_folder}") - - -if __name__ == "__main__": - # Helps with generating durations for the dataset to train other architectures - # that cannot learn to align due to limited size of dataset - # Example usage: - # python python matcha/utils/get_durations_from_trained_model.py -i ljspeech.yaml -c pretrained_model - # This will create a folder in data/processed_data/durations/ljspeech with the durations - main() diff --git a/egs/ljspeech/TTS/matcha/utils/instantiators.py b/egs/ljspeech/TTS/matcha/utils/instantiators.py deleted file mode 100644 index bde0c0d757..0000000000 --- a/egs/ljspeech/TTS/matcha/utils/instantiators.py +++ /dev/null @@ -1,60 +0,0 @@ -from typing import List - -import hydra -from lightning import Callback -from lightning.pytorch.loggers import Logger -from omegaconf import DictConfig - -from matcha.utils import pylogger - -log = pylogger.get_pylogger(__name__) - - -def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]: - """Instantiates callbacks from config. - - :param callbacks_cfg: A DictConfig object containing callback configurations. - :return: A list of instantiated callbacks. - """ - callbacks: List[Callback] = [] - - if not callbacks_cfg: - log.warning("No callback configs found! 
Skipping..") - return callbacks - - if not isinstance(callbacks_cfg, DictConfig): - raise TypeError("Callbacks config must be a DictConfig!") - - for _, cb_conf in callbacks_cfg.items(): - if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf: - log.info( - f"Instantiating callback <{cb_conf._target_}>" - ) # pylint: disable=protected-access - callbacks.append(hydra.utils.instantiate(cb_conf)) - - return callbacks - - -def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]: - """Instantiates loggers from config. - - :param logger_cfg: A DictConfig object containing logger configurations. - :return: A list of instantiated loggers. - """ - logger: List[Logger] = [] - - if not logger_cfg: - log.warning("No logger configs found! Skipping...") - return logger - - if not isinstance(logger_cfg, DictConfig): - raise TypeError("Logger config must be a DictConfig!") - - for _, lg_conf in logger_cfg.items(): - if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf: - log.info( - f"Instantiating logger <{lg_conf._target_}>" - ) # pylint: disable=protected-access - logger.append(hydra.utils.instantiate(lg_conf)) - - return logger diff --git a/egs/ljspeech/TTS/matcha/utils/logging_utils.py b/egs/ljspeech/TTS/matcha/utils/logging_utils.py deleted file mode 100644 index 2d2377eb2b..0000000000 --- a/egs/ljspeech/TTS/matcha/utils/logging_utils.py +++ /dev/null @@ -1,57 +0,0 @@ -from typing import Any, Dict - -from lightning.pytorch.utilities import rank_zero_only -from omegaconf import OmegaConf - -from matcha.utils import pylogger - -log = pylogger.get_pylogger(__name__) - - -@rank_zero_only -def log_hyperparameters(object_dict: Dict[str, Any]) -> None: - """Controls which config parts are saved by Lightning loggers. - - Additionally saves: - - Number of model parameters - - :param object_dict: A dictionary containing the following objects: - - `"cfg"`: A DictConfig object containing the main config. - - `"model"`: The Lightning model. - - `"trainer"`: The Lightning trainer. - """ - hparams = {} - - cfg = OmegaConf.to_container(object_dict["cfg"]) - model = object_dict["model"] - trainer = object_dict["trainer"] - - if not trainer.logger: - log.warning("Logger not found! Skipping hyperparameter logging...") - return - - hparams["model"] = cfg["model"] - - # save number of model parameters - hparams["model/params/total"] = sum(p.numel() for p in model.parameters()) - hparams["model/params/trainable"] = sum( - p.numel() for p in model.parameters() if p.requires_grad - ) - hparams["model/params/non_trainable"] = sum( - p.numel() for p in model.parameters() if not p.requires_grad - ) - - hparams["data"] = cfg["data"] - hparams["trainer"] = cfg["trainer"] - - hparams["callbacks"] = cfg.get("callbacks") - hparams["extras"] = cfg.get("extras") - - hparams["task_name"] = cfg.get("task_name") - hparams["tags"] = cfg.get("tags") - hparams["ckpt_path"] = cfg.get("ckpt_path") - hparams["seed"] = cfg.get("seed") - - # send hparams to all loggers - for logger in trainer.loggers: - logger.log_hyperparams(hparams) diff --git a/egs/ljspeech/TTS/matcha/utils/pylogger.py b/egs/ljspeech/TTS/matcha/utils/pylogger.py deleted file mode 100644 index a7ed7a961e..0000000000 --- a/egs/ljspeech/TTS/matcha/utils/pylogger.py +++ /dev/null @@ -1,29 +0,0 @@ -import logging - -from lightning.pytorch.utilities import rank_zero_only - - -def get_pylogger(name: str = __name__) -> logging.Logger: - """Initializes a multi-GPU-friendly python command line logger. 
- - :param name: The name of the logger, defaults to ``__name__``. - - :return: A logger object. - """ - logger = logging.getLogger(name) - - # this ensures all logging levels get marked with the rank zero decorator - # otherwise logs would get multiplied for each GPU process in multi-GPU setup - logging_levels = ( - "debug", - "info", - "warning", - "error", - "exception", - "fatal", - "critical", - ) - for level in logging_levels: - setattr(logger, level, rank_zero_only(getattr(logger, level))) - - return logger diff --git a/egs/ljspeech/TTS/matcha/utils/rich_utils.py b/egs/ljspeech/TTS/matcha/utils/rich_utils.py deleted file mode 100644 index d7fcd1aae9..0000000000 --- a/egs/ljspeech/TTS/matcha/utils/rich_utils.py +++ /dev/null @@ -1,103 +0,0 @@ -from pathlib import Path -from typing import Sequence - -import rich -import rich.syntax -import rich.tree -from hydra.core.hydra_config import HydraConfig -from lightning.pytorch.utilities import rank_zero_only -from omegaconf import DictConfig, OmegaConf, open_dict -from rich.prompt import Prompt - -from matcha.utils import pylogger - -log = pylogger.get_pylogger(__name__) - - -@rank_zero_only -def print_config_tree( - cfg: DictConfig, - print_order: Sequence[str] = ( - "data", - "model", - "callbacks", - "logger", - "trainer", - "paths", - "extras", - ), - resolve: bool = False, - save_to_file: bool = False, -) -> None: - """Prints the contents of a DictConfig as a tree structure using the Rich library. - - :param cfg: A DictConfig composed by Hydra. - :param print_order: Determines in what order config components are printed. Default is ``("data", "model", - "callbacks", "logger", "trainer", "paths", "extras")``. - :param resolve: Whether to resolve reference fields of DictConfig. Default is ``False``. - :param save_to_file: Whether to export config to the hydra output folder. Default is ``False``. - """ - style = "dim" - tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) - - queue = [] - - # add fields from `print_order` to queue - for field in print_order: - _ = ( - queue.append(field) - if field in cfg - else log.warning( - f"Field '{field}' not found in config. Skipping '{field}' config printing..." - ) - ) - - # add all the other fields to queue (not specified in `print_order`) - for field in cfg: - if field not in queue: - queue.append(field) - - # generate config tree from queue - for field in queue: - branch = tree.add(field, style=style, guide_style=style) - - config_group = cfg[field] - if isinstance(config_group, DictConfig): - branch_content = OmegaConf.to_yaml(config_group, resolve=resolve) - else: - branch_content = str(config_group) - - branch.add(rich.syntax.Syntax(branch_content, "yaml")) - - # print config tree - rich.print(tree) - - # save config tree to file - if save_to_file: - with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file: - rich.print(tree, file=file) - - -@rank_zero_only -def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None: - """Prompts user to input tags from command line if no tags are provided in config. - - :param cfg: A DictConfig composed by Hydra. - :param save_to_file: Whether to export tags to the hydra output folder. Default is ``False``. - """ - if not cfg.get("tags"): - if "id" in HydraConfig().cfg.hydra.job: - raise ValueError("Specify tags before launching a multirun!") - - log.warning("No tags provided in config. 
Prompting user to input tags...") - tags = Prompt.ask("Enter a list of comma separated tags", default="dev") - tags = [t.strip() for t in tags.split(",") if t != ""] - - with open_dict(cfg): - cfg.tags = tags - - log.info(f"Tags: {cfg.tags}") - - if save_to_file: - with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file: - rich.print(cfg.tags, file=file) diff --git a/egs/ljspeech/TTS/matcha/utils/utils.py b/egs/ljspeech/TTS/matcha/utils/utils.py deleted file mode 100644 index a545542632..0000000000 --- a/egs/ljspeech/TTS/matcha/utils/utils.py +++ /dev/null @@ -1,261 +0,0 @@ -import os -import sys -import warnings -from importlib.util import find_spec -from math import ceil -from pathlib import Path -from typing import Any, Callable, Dict, Tuple - -import matplotlib.pyplot as plt -import numpy as np -import torch - -# from omegaconf import DictConfig - -# from matcha.utils import pylogger, rich_utils - -# log = pylogger.get_pylogger(__name__) - - -def extras(cfg: "DictConfig") -> None: - """Applies optional utilities before the task is started. - - Utilities: - - Ignoring python warnings - - Setting tags from command line - - Rich config printing - - :param cfg: A DictConfig object containing the config tree. - """ - # return if no `extras` config - if not cfg.get("extras"): - log.warning("Extras config not found! ") - return - - # disable python warnings - if cfg.extras.get("ignore_warnings"): - log.info("Disabling python warnings! ") - warnings.filterwarnings("ignore") - - # prompt user to input tags from command line if none are provided in the config - if cfg.extras.get("enforce_tags"): - log.info("Enforcing tags! ") - rich_utils.enforce_tags(cfg, save_to_file=True) - - # pretty print config tree using Rich library - if cfg.extras.get("print_config"): - log.info("Printing config tree with Rich! ") - rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True) - - -def task_wrapper(task_func: Callable) -> Callable: - """Optional decorator that controls the failure behavior when executing the task function. - - This wrapper can be used to: - - make sure loggers are closed even if the task function raises an exception (prevents multirun failure) - - save the exception to a `.log` file - - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later) - - etc. (adjust depending on your needs) - - Example: - ``` - @utils.task_wrapper - def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: - ... - return metric_dict, object_dict - ``` - - :param task_func: The task function to be wrapped. - - :return: The wrapped task function. 
- """ - - def wrap(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: - # execute the task - try: - metric_dict, object_dict = task_func(cfg=cfg) - - # things to do if exception occurs - except Exception as ex: - # save exception to `.log` file - log.exception("") - - # some hyperparameter combinations might be invalid or cause out-of-memory errors - # so when using hparam search plugins like Optuna, you might want to disable - # raising the below exception to avoid multirun failure - raise ex - - # things to always do after either success or exception - finally: - # display output dir path in terminal - log.info(f"Output dir: {cfg.paths.output_dir}") - - # always close wandb run (even if exception occurs so multirun won't fail) - if find_spec("wandb"): # check if wandb is installed - import wandb - - if wandb.run: - log.info("Closing wandb!") - wandb.finish() - - return metric_dict, object_dict - - return wrap - - -def get_metric_value(metric_dict: Dict[str, Any], metric_name: str) -> float: - """Safely retrieves value of the metric logged in LightningModule. - - :param metric_dict: A dict containing metric values. - :param metric_name: The name of the metric to retrieve. - :return: The value of the metric. - """ - if not metric_name: - log.info("Metric name is None! Skipping metric value retrieval...") - return None - - if metric_name not in metric_dict: - raise ValueError( - f"Metric value not found! \n" - "Make sure metric name logged in LightningModule is correct!\n" - "Make sure `optimized_metric` name in `hparams_search` config is correct!" - ) - - metric_value = metric_dict[metric_name].item() - log.info(f"Retrieved metric value! <{metric_name}={metric_value}>") - - return metric_value - - -def intersperse(lst, item): - # Adds blank symbol - result = [item] * (len(lst) * 2 + 1) - result[1::2] = lst - return result - - -def save_figure_to_numpy(fig): - data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") - data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) - return data - - -def plot_tensor(tensor): - plt.style.use("default") - fig, ax = plt.subplots(figsize=(12, 3)) - im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation="none") - plt.colorbar(im, ax=ax) - plt.tight_layout() - fig.canvas.draw() - data = save_figure_to_numpy(fig) - plt.close() - return data - - -def save_plot(tensor, savepath): - plt.style.use("default") - fig, ax = plt.subplots(figsize=(12, 3)) - im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation="none") - plt.colorbar(im, ax=ax) - plt.tight_layout() - fig.canvas.draw() - plt.savefig(savepath) - plt.close() - - -def to_numpy(tensor): - if isinstance(tensor, np.ndarray): - return tensor - elif isinstance(tensor, torch.Tensor): - return tensor.detach().cpu().numpy() - elif isinstance(tensor, list): - return np.array(tensor) - else: - raise TypeError("Unsupported type for conversion to numpy array") - - -def get_user_data_dir(appname="matcha_tts"): - """ - Args: - appname (str): Name of application - - Returns: - Path: path to user data directory - """ - - MATCHA_HOME = os.environ.get("MATCHA_HOME") - if MATCHA_HOME is not None: - ans = Path(MATCHA_HOME).expanduser().resolve(strict=False) - elif sys.platform == "win32": - import winreg # pylint: disable=import-outside-toplevel - - key = winreg.OpenKey( - winreg.HKEY_CURRENT_USER, - r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders", - ) - dir_, _ = winreg.QueryValueEx(key, "Local AppData") - ans = Path(dir_).resolve(strict=False) - 
elif sys.platform == "darwin": - ans = Path("~/Library/Application Support/").expanduser() - else: - ans = Path.home().joinpath(".local/share") - - final_path = ans.joinpath(appname) - final_path.mkdir(parents=True, exist_ok=True) - return final_path - - -def assert_model_downloaded(checkpoint_path, url, use_wget=True): - import gdown - import wget - - if Path(checkpoint_path).exists(): - log.debug(f"[+] Model already present at {checkpoint_path}!") - print(f"[+] Model already present at {checkpoint_path}!") - return - log.info(f"[-] Model not found at {checkpoint_path}! Will download it") - print(f"[-] Model not found at {checkpoint_path}! Will download it") - checkpoint_path = str(checkpoint_path) - if not use_wget: - gdown.download(url=url, output=checkpoint_path, quiet=False, fuzzy=True) - else: - wget.download(url=url, out=checkpoint_path) - - -def get_phoneme_durations(durations, phones): - prev = durations[0] - merged_durations = [] - # Convolve with stride 2 - for i in range(1, len(durations), 2): - if i == len(durations) - 2: - # if it is last take full value - next_half = durations[i + 1] - else: - next_half = ceil(durations[i + 1] / 2) - - curr = prev + durations[i] + next_half - prev = durations[i + 1] - next_half - merged_durations.append(curr) - - assert len(phones) == len(merged_durations) - assert len(merged_durations) == (len(durations) - 1) // 2 - - merged_durations = torch.cumsum(torch.tensor(merged_durations), 0, dtype=torch.long) - start = torch.tensor(0) - duration_json = [] - for i, duration in enumerate(merged_durations): - duration_json.append( - { - phones[i]: { - "starttime": start.item(), - "endtime": duration.item(), - "duration": duration.item() - start.item(), - } - } - ) - start = duration - - assert list(duration_json[-1].values())[0]["endtime"] == sum( - durations - ), f"{list(duration_json[-1].values())[0]['endtime'], sum(durations)}" - return duration_json From 8cb1cda040aeeb7caaebb830f6e966b756c3334a Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 28 Oct 2024 19:59:38 +0800 Subject: [PATCH 17/27] refacotring --- .../TTS/local/compute_fbank_ljspeech.py | 2 +- egs/ljspeech/TTS/matcha/export_onnx.py | 2 -- egs/ljspeech/TTS/matcha/inference.py | 19 +++++++++++++++++-- icefall/checkpoint.py | 2 +- 4 files changed, 19 insertions(+), 6 deletions(-) diff --git a/egs/ljspeech/TTS/local/compute_fbank_ljspeech.py b/egs/ljspeech/TTS/local/compute_fbank_ljspeech.py index fee66da480..5152ae675c 100755 --- a/egs/ljspeech/TTS/local/compute_fbank_ljspeech.py +++ b/egs/ljspeech/TTS/local/compute_fbank_ljspeech.py @@ -38,7 +38,7 @@ from lhotse.features.base import FeatureExtractor, register_extractor from lhotse.supervision import SupervisionSet from lhotse.utils import Seconds, compute_num_frames -from matcha.utils.audio import mel_spectrogram +from matcha.audio import mel_spectrogram from icefall.utils import get_executor diff --git a/egs/ljspeech/TTS/matcha/export_onnx.py b/egs/ljspeech/TTS/matcha/export_onnx.py index cf5069b113..c0eebcde09 100755 --- a/egs/ljspeech/TTS/matcha/export_onnx.py +++ b/egs/ljspeech/TTS/matcha/export_onnx.py @@ -73,8 +73,6 @@ def forward( )["mel"] # mel: (batch_size, feat_dim, num_frames) - # audio = self.vocoder(mel).clamp(-1, 1).squeeze(1) - return mel diff --git a/egs/ljspeech/TTS/matcha/inference.py b/egs/ljspeech/TTS/matcha/inference.py index 89a6b33ae3..8fc0ec3ace 100755 --- a/egs/ljspeech/TTS/matcha/inference.py +++ b/egs/ljspeech/TTS/matcha/inference.py @@ -28,7 +28,7 @@ def get_parser(): parser.add_argument( "--epoch", 
type=int, - default=2810, + default=4000, help="""It specifies the checkpoint to use for decoding. Note: Epoch counts from 1. """, @@ -44,6 +44,13 @@ def get_parser(): """, ) + parser.add_argument( + "--vocoder", + type=Path, + default="./generator_v1", + help="Path to the vocoder", + ) + parser.add_argument( "--tokens", type=Path, @@ -61,6 +68,7 @@ def get_parser(): def load_vocoder(checkpoint_path): + checkpoint_path = str(checkpoint_path) if checkpoint_path.endswith("v1"): h = AttributeDict(v1) elif checkpoint_path.endswith("v2"): @@ -142,10 +150,17 @@ def main(): logging.info("About to create model") model = get_model(params) + + if not Path(f"{params.exp_dir}/epoch-{params.epoch}.pt").is_file(): + raise ValueError(f"{params.exp_dir}/epoch-{params.epoch}.pt does not exist") + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) model.eval() - vocoder = load_vocoder("/star-fj/fangjun/open-source/Matcha-TTS/generator_v1") + if not Path(params.vocoder).is_file(): + raise ValueError(f"{params.vocoder} does not exist") + + vocoder = load_vocoder(params.vocoder) denoiser = Denoiser(vocoder, mode="zeros") texts = [ diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py index 308a06b1f7..d31ce13019 100644 --- a/icefall/checkpoint.py +++ b/icefall/checkpoint.py @@ -90,7 +90,7 @@ def save_checkpoint( if params: for k, v in params.items(): - assert k not in checkpoint + assert k not in checkpoint, k checkpoint[k] = v torch.save(checkpoint, filename) From 908da44978d698ff7eea2a707d67ed69bff6f991 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 28 Oct 2024 22:49:14 +0800 Subject: [PATCH 18/27] Update README --- .github/scripts/ljspeech/TTS/run-matcha.sh | 0 egs/ljspeech/TTS/.gitignore | 3 + egs/ljspeech/TTS/README.md | 114 +++++++++++++++++++++ egs/ljspeech/TTS/matcha/inference.py | 84 +++++---------- egs/ljspeech/TTS/matcha/onnx_pretrained.py | 68 +++++++++--- egs/ljspeech/TTS/prepare.sh | 6 +- 6 files changed, 200 insertions(+), 75 deletions(-) create mode 100755 .github/scripts/ljspeech/TTS/run-matcha.sh diff --git a/.github/scripts/ljspeech/TTS/run-matcha.sh b/.github/scripts/ljspeech/TTS/run-matcha.sh new file mode 100755 index 0000000000..e69de29bb2 diff --git a/egs/ljspeech/TTS/.gitignore b/egs/ljspeech/TTS/.gitignore index 1eef06a289..d5c19797ab 100644 --- a/egs/ljspeech/TTS/.gitignore +++ b/egs/ljspeech/TTS/.gitignore @@ -2,3 +2,6 @@ build core.c *.so my-output* +*.wav +*.onnx +generator_v* diff --git a/egs/ljspeech/TTS/README.md b/egs/ljspeech/TTS/README.md index 7b112c12c8..fe613024ae 100644 --- a/egs/ljspeech/TTS/README.md +++ b/egs/ljspeech/TTS/README.md @@ -101,3 +101,117 @@ export CUDA_VISIBLE_DEVICES=4,5,6,7 # (Note it is killed after `epoch-820.pt`) ``` +# matcha + +[./matcha](./matcha) contains the code for training [Matcha-TTS](https://github.com/shivammehta25/Matcha-TTS). + +This recipe provides a Matcha-TTS model trained on the LJSpeech dataset. + +The pretrained model can be found [here](https://huggingface.co/csukuangfj/icefall-tts-ljspeech-matcha-en-2024-10-28). + +The training command is given below: +```bash +export CUDA_VISIBLE_DEVICES=0,1,2,3 + +python3 ./matcha/train.py \ + --exp-dir ./matcha/exp-new-3/ \ + --num-workers 4 \ + --world-size 4 \ + --num-epochs 4000 \ + --max-duration 1000 \ + --bucketing-sampler 1 \ + --start-epoch 1 +``` + +To run inference, use: + +```bash +# Download Hifigan vocoder. We use Hifigan v1 below.
You can select from v1, v2, or v3 + +wget https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1 + +./matcha/inference.py \ + --exp-dir ./matcha/exp-new-3 \ + --epoch 4000 \ + --tokens ./data/tokens.txt \ + --vocoder ./generator_v1 \ + --input-text "how are you doing?" \ + --output-wav ./generated.wav +``` + +```bash +soxi ./generated.wav +``` +prints: +``` +Input File : './generated.wav' +Channels : 1 +Sample Rate : 22050 +Precision : 16-bit +Duration : 00:00:01.29 = 28416 samples ~ 96.6531 CDDA sectors +File Size : 56.9k +Bit Rate : 353k +Sample Encoding: 16-bit Signed Integer PCM +``` + +To export the checkpoint to onnx: + +```bash +# export the acoustic model to onnx + +./matcha/export_onnx.py \ + --exp-dir ./matcha/exp-new-3 \ + --epoch 4000 \ + --tokens ./data/tokens.txt +``` + +The above command generates the following files: + + - model-steps-2.onnx + - model-steps-3.onnx + - model-steps-4.onnx + - model-steps-5.onnx + - model-steps-6.onnx + +where the 2 in `model-steps-2.onnx` means it uses 2 steps for the ODE solver. + + +To export the Hifigan vocoder to onnx, please use: + +```bash +wget https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1 +wget https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v2 +wget https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v3 + +python3 ./matcha/export_onnx_hifigan.py +``` + +The above command generates 3 files: + + - hifigan_v1.onnx + - hifigan_v2.onnx + - hifigan_v3.onnx + +To use the generated onnx files to generate speech from text, please run: + +```bash +python3 ./matcha/onnx_pretrained.py \ + --acoustic-model ./model-steps-6.onnx \ + --vocoder ./hifigan_v2.onnx \ + --tokens ./data/tokens.txt \ + --input-text "how are you doing?"
\ + --output-wav ./generated-2.wav +``` + +```bash +soxi ./generated-2.wav + +Input File : './generated-2.wav' +Channels : 1 +Sample Rate : 22050 +Precision : 16-bit +Duration : 00:00:01.25 = 27648 samples ~ 94.0408 CDDA sectors +File Size : 55.3k +Bit Rate : 353k +Sample Encoding: 16-bit Signed Integer PCM +``` diff --git a/egs/ljspeech/TTS/matcha/inference.py b/egs/ljspeech/TTS/matcha/inference.py index 8fc0ec3ace..1189160f64 100755 --- a/egs/ljspeech/TTS/matcha/inference.py +++ b/egs/ljspeech/TTS/matcha/inference.py @@ -6,14 +6,12 @@ import logging from pathlib import Path -import numpy as np import soundfile as sf import torch from matcha.hifigan.config import v1, v2, v3 from matcha.hifigan.denoiser import Denoiser from matcha.hifigan.models import Generator as HiFiGAN from tokenizer import Tokenizer -from tqdm.auto import tqdm from train import get_model, get_params from icefall.checkpoint import load_checkpoint @@ -64,6 +62,20 @@ def get_parser(): help="""Path to vocabulary.""", ) + parser.add_argument( + "--input-text", + type=str, + required=True, + help="The text to generate speech for", + ) + + parser.add_argument( + "--output-wav", + type=str, + required=True, + help="The filename of the wave to save the generated speech", + ) + return parser @@ -93,13 +105,6 @@ def to_waveform(mel, vocoder, denoiser): return audio.cpu().squeeze() -def save_to_folder(filename: str, output: dict, folder: str): - folder = Path(folder) - folder.mkdir(exist_ok=True, parents=True) - np.save(folder / f"{filename}", output["mel"].cpu().numpy()) - sf.write(folder / f"{filename}.wav", output["waveform"], 22050, "PCM_24") - - def process_text(text: str, tokenizer): x = tokenizer.texts_to_token_ids([text], add_sos=True, add_eos=True) x = torch.tensor(x, dtype=torch.long) @@ -120,7 +125,6 @@ def synthesise( spks=spks, length_scale=length_scale, ) - print("output.shape", list(output.keys()), output["mel"].shape) # merge everything to one dict output.update({"start_t": start_t, **text_processed}) return output @@ -163,16 +167,6 @@ def main(): vocoder = load_vocoder(params.vocoder) denoiser = Denoiser(vocoder, mode="zeros") - texts = [ - "The Secret Service believed that it was very doubtful that any " - "President would ride regularly in a vehicle with a fixed top, even " - "though transparent.", - "Today as always, men fall into two groups: slaves and free men. 
" - "Whoever does not have two-thirds of his day for himself, is a slave, " - "whatever he may be: a statesman, a businessman, an official, or a " - "scholar.", - ] - # Number of ODE Solver steps n_timesteps = 2 @@ -182,47 +176,17 @@ def main(): # Sampling temperature temperature = 0.667 - rtfs = [] - rtfs_w = [] - for i, text in enumerate(tqdm(texts)): - output = synthesise( - model=model, - tokenizer=tokenizer, - n_timesteps=n_timesteps, - text=text, - length_scale=length_scale, - temperature=temperature, - ) # , torch.tensor([15], device=device, dtype=torch.long).unsqueeze(0)) - output["waveform"] = to_waveform(output["mel"], vocoder, denoiser) - - # Compute Real Time Factor (RTF) with HiFi-GAN - t = (dt.datetime.now() - output["start_t"]).total_seconds() - rtf_w = t * 22050 / (output["waveform"].shape[-1]) - - # Pretty print - print(f"{'*' * 53}") - print(f"Input text - {i}") - print(f"{'-' * 53}") - print(output["x_orig"]) - print(f"{'*' * 53}") - print(f"Phonetised text - {i}") - print(f"{'-' * 53}") - print(output["x"]) - print(f"{'*' * 53}") - print(f"RTF:\t\t{output['rtf']:.6f}") - print(f"RTF Waveform:\t{rtf_w:.6f}") - rtfs.append(output["rtf"]) - rtfs_w.append(rtf_w) - - # Save the generated waveform - save_to_folder(i, output, folder=f"./my-output-{params.epoch}") - - print(f"Number of ODE steps: {n_timesteps}") - print(f"Mean RTF:\t\t\t\t{np.mean(rtfs):.6f} ± {np.std(rtfs):.6f}") - print( - "Mean RTF Waveform " - f"(incl. vocoder):\t{np.mean(rtfs_w):.6f} ± {np.std(rtfs_w):.6f}" + output = synthesise( + model=model, + tokenizer=tokenizer, + n_timesteps=n_timesteps, + text=params.input_text, + length_scale=length_scale, + temperature=temperature, ) + output["waveform"] = to_waveform(output["mel"], vocoder, denoiser) + + sf.write(params.output_wav, output["waveform"], 22050, "PCM_16") if __name__ == "__main__": diff --git a/egs/ljspeech/TTS/matcha/onnx_pretrained.py b/egs/ljspeech/TTS/matcha/onnx_pretrained.py index 3953d5d0ae..6a37f3c177 100755 --- a/egs/ljspeech/TTS/matcha/onnx_pretrained.py +++ b/egs/ljspeech/TTS/matcha/onnx_pretrained.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +import argparse import datetime as dt import logging @@ -9,6 +10,49 @@ from tokenizer import Tokenizer +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--acoustic-model", + type=str, + required=True, + help="Path to the acoustic model", + ) + + parser.add_argument( + "--tokens", + type=str, + required=True, + help="Path to the tokens.txt", + ) + + parser.add_argument( + "--vocoder", + type=str, + required=True, + help="Path to the vocoder", + ) + + parser.add_argument( + "--input-text", + type=str, + required=True, + help="The text to generate speech for", + ) + + parser.add_argument( + "--output-wav", + type=str, + required=True, + help="The filename of the wave to save the generated speech", + ) + + return parser + + class OnnxHifiGANModel: def __init__( self, @@ -98,10 +142,12 @@ def __call__(self, x: torch.tensor): @torch.no_grad() def main(): - model = OnnxModel("./model-steps-6.onnx") - vocoder = OnnxHifiGANModel("./hifigan_v1.onnx") - text = "Today as always, men fall into two groups: slaves and free men." - text += "hello, how are you doing?" 
+ params = get_parser().parse_args() + logging.info(vars(params)) + + model = OnnxModel(params.acoustic_model) + vocoder = OnnxHifiGANModel(params.vocoder) + text = params.input_text x = model.tokenizer.texts_to_token_ids([text], add_sos=True, add_eos=True) x = torch.tensor(x, dtype=torch.int64) @@ -109,9 +155,6 @@ def main(): mel = model(x) end_t = dt.datetime.now() - for i in range(3): - audio = vocoder(mel) - start_t2 = dt.datetime.now() audio = vocoder(mel) end_t2 = dt.datetime.now() @@ -121,13 +164,14 @@ def main(): t = (end_t - start_t).total_seconds() t2 = (end_t2 - start_t2).total_seconds() - rtf = t * 22050 / audio.shape[-1] - rtf2 = t2 * 22050 / audio.shape[-1] - print("RTF", rtf) - print("RTF", rtf2) + rtf_am = t * 22050 / audio.shape[-1] + rtf_vocoder = t2 * 22050 / audio.shape[-1] + print("RTF for acoustic model ", rtf_am) + print("RTF for vocoder", rtf_vocoder) # skip denoiser - sf.write("onnx2.wav", audio, 22050, "PCM_16") + sf.write(params.output_wav, audio, 22050, "PCM_16") + logging.info(f"Saved to {params.output_wav}") if __name__ == "__main__": diff --git a/egs/ljspeech/TTS/prepare.sh b/egs/ljspeech/TTS/prepare.sh index b140e6f010..dfc2b35405 100755 --- a/egs/ljspeech/TTS/prepare.sh +++ b/egs/ljspeech/TTS/prepare.sh @@ -34,10 +34,10 @@ if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then log "monotonic_align lib for vits already built" fi - if [ ! -f ./matcha/utils/monotonic_align/core.cpython-38-x86_64-linux-gnu.so ]; then - pushd matcha/utils/monotonic_align + if [ ! -f ./matcha/monotonic_align/core.cpython-38-x86_64-linux-gnu.so ]; then + pushd matcha/monotonic_align python3 setup.py build_ext --inplace - mv -v matcha/utils/monotonic_align/core.cpython-38-x86_64-linux-gnu.so ./ + mv -v matcha/monotonic_align/core.cpython-38-x86_64-linux-gnu.so ./ rm -rf matcha rm -rf build rm core.c From 761bc76c4521b3d83510529d64a0b9366873767f Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 28 Oct 2024 23:03:31 +0800 Subject: [PATCH 19/27] Add CI for matcha-tts --- .github/scripts/ljspeech/TTS/run-matcha.sh | 116 +++++++++++++++++++++ .github/workflows/ljspeech.yml | 1 + 2 files changed, 117 insertions(+) diff --git a/.github/scripts/ljspeech/TTS/run-matcha.sh b/.github/scripts/ljspeech/TTS/run-matcha.sh index e69de29bb2..b1da5ff137 100755 --- a/.github/scripts/ljspeech/TTS/run-matcha.sh +++ b/.github/scripts/ljspeech/TTS/run-matcha.sh @@ -0,0 +1,116 @@ +#!/usr/bin/env bash + +set -ex + +sudo apt-get install sox + +python3 -m pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html +python3 -m pip install espnet_tts_frontend +python3 -m pip install numba + +pytnon3 -m pip install conformer==0.3.2 diffusers + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +cd egs/ljspeech/TTS + +sed -i.bak s/600/8/g ./prepare.sh +sed -i.bak s/"first 100"/"first 3"/g ./prepare.sh +sed -i.bak s/500/5/g ./prepare.sh +git diff + +function prepare_data() { + # We have created a subset of the data for testing + # + mkdir download + pushd download + wget -q https://huggingface.co/csukuangfj/ljspeech-subset-for-ci-test/resolve/main/LJSpeech-1.1.tar.bz2 + tar xvf LJSpeech-1.1.tar.bz2 + popd + + ./prepare.sh + tree . +} + +function train() { + pushd ./vits + sed -i.bak s/1500/3/g ./train.py + git diff . 
+ popd + + ./vits/train.py \ + --exp-dir matcha/exp \ + --num-epochs 1 \ + --save-every-n 1 \ + --num-buckets 2 \ + --tokens data/tokens.txt \ + --max-duration 20 + + ls -lh match/exp + done +} + +function infer() { + + curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1 + + ./matcha/inference.py \ + --epoch 1 \ + --exp-dir ./matcha/exp \ + --tokens data/tokens.txt \ + --vocoder ./generator_v1 \ + --input-text "how are you doing?" + --output-wav ./generated.wav + + ls -lh *.wav + soxi ./generated.wav + rm -v ./generated.wav + rm -v generator_v1 +} + +function export_onnx() { + pushd matcha/exp + + curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-ljspeech-matcha-en-2024-10-28/resolve/main/exp/epoch-4000.pt + curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1 + curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v2 + curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v3 + + popd + + pushd data/fbank + rm -v *.json + curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-ljspeech-matcha-en-2024-10-28/resolve/main/data/cmvn.json + popd + + ./matcha/export_onnx.py \ + --exp-dir ./matcha/exp-new-3 \ + --epoch 4000 \ + --tokens ./data/tokens.txt \ + --cmvn ./data/fbank/cmvn.json + + ls -lh *.onnx + + python3 ./matcha/export_onnx_hifigan.py + + ls -lh *.onnx + + python3 ./matcha/onnx_pretrained.py \ + --acoustic-model ./model-steps-6.onnx \ + --vocoder ./hifigan_v1.onnx \ + --tokens ./data/tokens.txt \ + --input-text "how are you doing?" \ + --output-wav /icefall/generated-matcha-tts-6.wav + + ls -lh /icefall/*.wav + soxi /icefall/generated-matcha-tts-6.wav +} + +prepare_data +train +infer +export_onnx diff --git a/.github/workflows/ljspeech.yml b/.github/workflows/ljspeech.yml index e202d21b58..aaca730321 100644 --- a/.github/workflows/ljspeech.yml +++ b/.github/workflows/ljspeech.yml @@ -70,6 +70,7 @@ jobs: cd /icefall git config --global --add safe.directory /icefall + .github/scripts/ljspeech/TTS/run-matcha.sh .github/scripts/ljspeech/TTS/run.sh - name: display files From 908da44978d698ff7eea2a707d67ed69bff6f991 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 28 Oct 2024 23:08:54 +0800 Subject: [PATCH 20/27] fix building monotonic alignment --- egs/ljspeech/TTS/prepare.sh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/egs/ljspeech/TTS/prepare.sh b/egs/ljspeech/TTS/prepare.sh index dfc2b35405..6f16f8d473 100755 --- a/egs/ljspeech/TTS/prepare.sh +++ b/egs/ljspeech/TTS/prepare.sh @@ -36,11 +36,11 @@ if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then if [ ! -f ./matcha/monotonic_align/core.cpython-38-x86_64-linux-gnu.so ]; then pushd matcha/monotonic_align - python3 setup.py build_ext --inplace - mv -v matcha/monotonic_align/core.cpython-38-x86_64-linux-gnu.so ./ - rm -rf matcha + python3 setup.py build + mv -v build/lib.*/matcha/monotonic_align/core.*.so . 
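This prepare.sh stage compiles the Cython extension that backs `matcha.monotonic_align`. Once `core.*.so` has been built, the module is consumed through the `maximum_path` helper shown earlier in this series (`matcha/monotonic_align/__init__.py`). A small sketch of that call follows; the tensor shapes and toy inputs are illustrative assumptions, not part of this patch.

```python
# Illustrative sketch only: the shapes below are assumptions for the example.
import torch

# maximum_path is defined in matcha/monotonic_align/__init__.py and wraps the
# compiled maximum_path_c from the core.*.so built by the prepare.sh stage above.
from matcha.monotonic_align import maximum_path

scores = torch.randn(1, 5, 40)     # (batch, num_tokens, num_frames) alignment scores
mask = torch.ones(1, 5, 40)        # 1 for valid (token, frame) positions
path = maximum_path(scores, mask)  # hard monotonic alignment, same shape as scores
durations = path.sum(dim=2)        # per-token durations implied by the alignment
print(durations)
```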
rm -rf build rm core.c + ls -lh popd else log "monotonic_align lib for matcha-tts already built" From a6d018acecc9d3a800c9349e510cd4c6412c3800 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 28 Oct 2024 23:12:16 +0800 Subject: [PATCH 21/27] install missing deps --- .github/scripts/ljspeech/TTS/run-matcha.sh | 8 ++++---- egs/ljspeech/TTS/matcha/requirements.txt | 3 +++ 2 files changed, 7 insertions(+), 4 deletions(-) create mode 100644 egs/ljspeech/TTS/matcha/requirements.txt diff --git a/.github/scripts/ljspeech/TTS/run-matcha.sh b/.github/scripts/ljspeech/TTS/run-matcha.sh index b1da5ff137..26ce17b23d 100755 --- a/.github/scripts/ljspeech/TTS/run-matcha.sh +++ b/.github/scripts/ljspeech/TTS/run-matcha.sh @@ -2,13 +2,13 @@ set -ex -sudo apt-get install sox +apt-get install sox python3 -m pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html python3 -m pip install espnet_tts_frontend python3 -m pip install numba -pytnon3 -m pip install conformer==0.3.2 diffusers +python3 -m pip install conformer==0.3.2 diffusers librosa log() { # This function is from espnet @@ -37,12 +37,12 @@ function prepare_data() { } function train() { - pushd ./vits + pushd ./matcha sed -i.bak s/1500/3/g ./train.py git diff . popd - ./vits/train.py \ + ./matcha/train.py \ --exp-dir matcha/exp \ --num-epochs 1 \ --save-every-n 1 \ diff --git a/egs/ljspeech/TTS/matcha/requirements.txt b/egs/ljspeech/TTS/matcha/requirements.txt new file mode 100644 index 0000000000..5aadc89844 --- /dev/null +++ b/egs/ljspeech/TTS/matcha/requirements.txt @@ -0,0 +1,3 @@ +conformer==0.3.2 +diffusers # developed using version ==0.25.0 +librosa From fa9f4d58fb70b82b9b6848fda3678616bb64da70 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 28 Oct 2024 23:25:11 +0800 Subject: [PATCH 22/27] fix typos --- .github/scripts/ljspeech/TTS/run-matcha.sh | 46 ++++++++++--------- .github/scripts/ljspeech/TTS/run.sh | 2 +- egs/ljspeech/TTS/matcha/export_onnx.py | 43 ++++++++++++++++- .../TTS/matcha/export_onnx_hifigan.py | 4 ++ 4 files changed, 72 insertions(+), 23 deletions(-) diff --git a/.github/scripts/ljspeech/TTS/run-matcha.sh b/.github/scripts/ljspeech/TTS/run-matcha.sh index 26ce17b23d..5da9fac577 100755 --- a/.github/scripts/ljspeech/TTS/run-matcha.sh +++ b/.github/scripts/ljspeech/TTS/run-matcha.sh @@ -2,13 +2,12 @@ set -ex -apt-get install sox +apt-get update +apt-get install -y sox python3 -m pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html python3 -m pip install espnet_tts_frontend -python3 -m pip install numba - -python3 -m pip install conformer==0.3.2 diffusers librosa +python3 -m pip install numba conformer==0.3.2 diffusers librosa log() { # This function is from espnet @@ -26,7 +25,7 @@ git diff function prepare_data() { # We have created a subset of the data for testing # - mkdir download + mkdir -p download pushd download wget -q https://huggingface.co/csukuangfj/ljspeech-subset-for-ci-test/resolve/main/LJSpeech-1.1.tar.bz2 tar xvf LJSpeech-1.1.tar.bz2 @@ -50,8 +49,7 @@ function train() { --tokens data/tokens.txt \ --max-duration 20 - ls -lh match/exp - done + ls -lh matcha/exp } function infer() { @@ -63,7 +61,7 @@ function infer() { --exp-dir ./matcha/exp \ --tokens data/tokens.txt \ --vocoder ./generator_v1 \ - --input-text "how are you doing?" + --input-text "how are you doing?" 
\ --output-wav ./generated.wav ls -lh *.wav @@ -74,12 +72,7 @@ function infer() { function export_onnx() { pushd matcha/exp - curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-ljspeech-matcha-en-2024-10-28/resolve/main/exp/epoch-4000.pt - curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1 - curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v2 - curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v3 - popd pushd data/fbank @@ -87,24 +80,33 @@ function export_onnx() { curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-ljspeech-matcha-en-2024-10-28/resolve/main/data/cmvn.json popd + curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v2 + curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v3 + ./matcha/export_onnx.py \ - --exp-dir ./matcha/exp-new-3 \ + --exp-dir ./matcha/exp \ --epoch 4000 \ --tokens ./data/tokens.txt \ --cmvn ./data/fbank/cmvn.json ls -lh *.onnx - python3 ./matcha/export_onnx_hifigan.py + if false; then + # THe CI machine does not have enough memory to run it + python3 ./matcha/export_onnx_hifigan.py + else + curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-ljspeech-matcha-en-2024-10-28/resolve/main/exp/hifigan_v2.onnx + fi ls -lh *.onnx - python3 ./matcha/onnx_pretrained.py \ - --acoustic-model ./model-steps-6.onnx \ - --vocoder ./hifigan_v1.onnx \ - --tokens ./data/tokens.txt \ - --input-text "how are you doing?" \ - --output-wav /icefall/generated-matcha-tts-6.wav + + python3 ./matcha/onnx_pretrained.py \ + --acoustic-model ./model-steps-6.onnx \ + --vocoder ./hifigan_v2.onnx \ + --tokens ./data/tokens.txt \ + --input-text "how are you doing?" \ + --output-wav /icefall/generated-matcha-tts-6.wav ls -lh /icefall/*.wav soxi /icefall/generated-matcha-tts-6.wav @@ -114,3 +116,5 @@ prepare_data train infer export_onnx + +rm -rfv generator_v* matcha/exp diff --git a/.github/scripts/ljspeech/TTS/run.sh b/.github/scripts/ljspeech/TTS/run.sh index 707361782f..733a12c47b 100755 --- a/.github/scripts/ljspeech/TTS/run.sh +++ b/.github/scripts/ljspeech/TTS/run.sh @@ -22,7 +22,7 @@ git diff function prepare_data() { # We have created a subset of the data for testing # - mkdir download + mkdir -p download pushd download wget -q https://huggingface.co/csukuangfj/ljspeech-subset-for-ci-test/resolve/main/LJSpeech-1.1.tar.bz2 tar xvf LJSpeech-1.1.tar.bz2 diff --git a/egs/ljspeech/TTS/matcha/export_onnx.py b/egs/ljspeech/TTS/matcha/export_onnx.py index c0eebcde09..f7dc38c1bd 100755 --- a/egs/ljspeech/TTS/matcha/export_onnx.py +++ b/egs/ljspeech/TTS/matcha/export_onnx.py @@ -6,19 +6,60 @@ it to audio. See also ./export_onnx_hifigan.py """ +import argparse import json import logging +from pathlib import Path from typing import Any, Dict import onnx import torch -from inference import get_parser from tokenizer import Tokenizer from train import get_model, get_params from icefall.checkpoint import load_checkpoint +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=4000, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + """, + ) + + parser.add_argument( + "--exp-dir", + type=Path, + default="matcha/exp-new-3", + help="""The experiment dir. 
+ It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--tokens", + type=Path, + default="data/tokens.txt", + ) + + parser.add_argument( + "--cmvn", + type=str, + default="data/fbank/cmvn.json", + help="""Path to vocabulary.""", + ) + + return parser + + def add_meta_data(filename: str, meta_data: Dict[str, Any]): """Add meta data to an ONNX model. It is changed in-place. diff --git a/egs/ljspeech/TTS/matcha/export_onnx_hifigan.py b/egs/ljspeech/TTS/matcha/export_onnx_hifigan.py index af54f4e896..ea4435479c 100755 --- a/egs/ljspeech/TTS/matcha/export_onnx_hifigan.py +++ b/egs/ljspeech/TTS/matcha/export_onnx_hifigan.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 import logging +from pathlib import Path from typing import Any, Dict import onnx @@ -58,6 +59,9 @@ def main(): for f in model_filenames: logging.info(f) + if not Path(f).is_file(): + logging.info(f"Skipping {f} since {f} does not exist") + continue model = load_vocoder(f) wrapper = ModelWrapper(model) wrapper.eval() From ab883a71ec20fc984a0cebd8a68fd93951129993 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 29 Oct 2024 07:24:36 +0800 Subject: [PATCH 23/27] Fix CI --- .github/workflows/audioset.yml | 6 +++--- .github/workflows/ljspeech.yml | 7 +++---- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/.github/workflows/audioset.yml b/.github/workflows/audioset.yml index 280ef8f8e4..9c9446239e 100644 --- a/.github/workflows/audioset.yml +++ b/.github/workflows/audioset.yml @@ -83,7 +83,7 @@ jobs: ls -lh ./model-onnx/* - name: Upload model to huggingface - if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0' && github.event_name == 'push' + if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0' && github.event_name == 'push' env: HF_TOKEN: ${{ secrets.HF_TOKEN }} uses: nick-fields/retry@v3 @@ -116,7 +116,7 @@ jobs: rm -rf huggingface - name: Prepare for release - if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0' && github.event_name == 'push' + if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0' && github.event_name == 'push' shell: bash run: | d=sherpa-onnx-zipformer-audio-tagging-2024-04-09 @@ -125,7 +125,7 @@ jobs: ls -lh - name: Release exported onnx models - if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0' && github.event_name == 'push' + if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0' && github.event_name == 'push' uses: svenstaro/upload-release-action@v2 with: file_glob: true diff --git a/.github/workflows/ljspeech.yml b/.github/workflows/ljspeech.yml index aaca730321..34a3797faa 100644 --- a/.github/workflows/ljspeech.yml +++ b/.github/workflows/ljspeech.yml @@ -79,19 +79,19 @@ jobs: ls -lh - uses: actions/upload-artifact@v4 - if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0' + if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0' with: name: generated-test-files-${{ matrix.python-version }}-${{ matrix.torch-version }} path: ./*.wav - uses: actions/upload-artifact@v4 - if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0' + if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0' with: name: generated-models-py${{ matrix.python-version }}-torch${{ matrix.torch-version }} path: ./*.wav - name: Release exported onnx models - if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0' && github.event_name == 'push' + if: matrix.python-version 
== '3.9' && matrix.torch-version == '2.3.0' && github.event_name == 'push' uses: svenstaro/upload-release-action@v2 with: file_glob: true @@ -100,4 +100,3 @@ jobs: repo_name: k2-fsa/sherpa-onnx repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }} tag: tts-models - From 3a986335d7a25780dab3408d3be33ddc1d71e55f Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 29 Oct 2024 10:04:47 +0800 Subject: [PATCH 24/27] Add copyright info --- .github/scripts/ljspeech/TTS/run-matcha.sh | 15 ++++++------- egs/ljspeech/TTS/matcha/LICENSE | 21 +++++++++++++++++++ egs/ljspeech/TTS/matcha/audio.py | 2 ++ egs/ljspeech/TTS/matcha/export_onnx.py | 1 + .../TTS/matcha/export_onnx_hifigan.py | 1 + egs/ljspeech/TTS/matcha/inference.py | 1 + egs/ljspeech/TTS/matcha/model.py | 2 ++ egs/ljspeech/TTS/matcha/models/README.md | 3 +++ .../TTS/matcha/monotonic_align/__init__.py | 2 ++ .../TTS/matcha/monotonic_align/core.pyx | 2 ++ .../TTS/matcha/monotonic_align/setup.py | 2 ++ egs/ljspeech/TTS/matcha/onnx_pretrained.py | 2 ++ egs/ljspeech/TTS/matcha/train.py | 2 +- 13 files changed, 48 insertions(+), 8 deletions(-) create mode 100644 egs/ljspeech/TTS/matcha/LICENSE create mode 100644 egs/ljspeech/TTS/matcha/models/README.md diff --git a/.github/scripts/ljspeech/TTS/run-matcha.sh b/.github/scripts/ljspeech/TTS/run-matcha.sh index 5da9fac577..b6eb81020b 100755 --- a/.github/scripts/ljspeech/TTS/run-matcha.sh +++ b/.github/scripts/ljspeech/TTS/run-matcha.sh @@ -80,9 +80,6 @@ function export_onnx() { curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-ljspeech-matcha-en-2024-10-28/resolve/main/data/cmvn.json popd - curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v2 - curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v3 - ./matcha/export_onnx.py \ --exp-dir ./matcha/exp \ --epoch 4000 \ @@ -93,9 +90,13 @@ function export_onnx() { if false; then # THe CI machine does not have enough memory to run it + # + curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1 + curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v2 + curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v3 python3 ./matcha/export_onnx_hifigan.py else - curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-ljspeech-matcha-en-2024-10-28/resolve/main/exp/hifigan_v2.onnx + curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-ljspeech-matcha-en-2024-10-28/resolve/main/exp/hifigan_v1.onnx fi ls -lh *.onnx @@ -103,13 +104,13 @@ function export_onnx() { python3 ./matcha/onnx_pretrained.py \ --acoustic-model ./model-steps-6.onnx \ - --vocoder ./hifigan_v2.onnx \ + --vocoder ./hifigan_v1.onnx \ --tokens ./data/tokens.txt \ --input-text "how are you doing?" 
\ - --output-wav /icefall/generated-matcha-tts-6.wav + --output-wav /icefall/generated-matcha-tts-steps-6-v1.wav ls -lh /icefall/*.wav - soxi /icefall/generated-matcha-tts-6.wav + soxi /icefall/generated-matcha-tts-steps-6-v1.wav } prepare_data diff --git a/egs/ljspeech/TTS/matcha/LICENSE b/egs/ljspeech/TTS/matcha/LICENSE new file mode 100644 index 0000000000..858018e750 --- /dev/null +++ b/egs/ljspeech/TTS/matcha/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Shivam Mehta + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/egs/ljspeech/TTS/matcha/audio.py b/egs/ljspeech/TTS/matcha/audio.py index 0a9b8db2a9..534331e596 100644 --- a/egs/ljspeech/TTS/matcha/audio.py +++ b/egs/ljspeech/TTS/matcha/audio.py @@ -1,3 +1,5 @@ +# This file is copied from +# https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/utils/audio.py import numpy as np import torch import torch.utils.data diff --git a/egs/ljspeech/TTS/matcha/export_onnx.py b/egs/ljspeech/TTS/matcha/export_onnx.py index f7dc38c1bd..487ea29952 100755 --- a/egs/ljspeech/TTS/matcha/export_onnx.py +++ b/egs/ljspeech/TTS/matcha/export_onnx.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) """ This script exports a Matcha-TTS model to ONNX. diff --git a/egs/ljspeech/TTS/matcha/export_onnx_hifigan.py b/egs/ljspeech/TTS/matcha/export_onnx_hifigan.py index ea4435479c..63d1fac205 100755 --- a/egs/ljspeech/TTS/matcha/export_onnx_hifigan.py +++ b/egs/ljspeech/TTS/matcha/export_onnx_hifigan.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) import logging from pathlib import Path diff --git a/egs/ljspeech/TTS/matcha/inference.py b/egs/ljspeech/TTS/matcha/inference.py index 1189160f64..64abd8e50b 100755 --- a/egs/ljspeech/TTS/matcha/inference.py +++ b/egs/ljspeech/TTS/matcha/inference.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# Copyright 2024 Xiaomi Corp. 
(authors: Fangjun Kuang) import argparse import datetime as dt diff --git a/egs/ljspeech/TTS/matcha/model.py b/egs/ljspeech/TTS/matcha/model.py index a488ab4e8b..6539ffc24c 100644 --- a/egs/ljspeech/TTS/matcha/model.py +++ b/egs/ljspeech/TTS/matcha/model.py @@ -1,3 +1,5 @@ +# This file is copied from +# https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/utils/model.py """ from https://github.com/jaywalnut310/glow-tts """ import numpy as np diff --git a/egs/ljspeech/TTS/matcha/models/README.md b/egs/ljspeech/TTS/matcha/models/README.md new file mode 100644 index 0000000000..1099ef3c83 --- /dev/null +++ b/egs/ljspeech/TTS/matcha/models/README.md @@ -0,0 +1,3 @@ +# Introduction +Files in this folder are copied from +https://github.com/shivammehta25/Matcha-TTS/tree/main/matcha/models diff --git a/egs/ljspeech/TTS/matcha/monotonic_align/__init__.py b/egs/ljspeech/TTS/matcha/monotonic_align/__init__.py index 58286bdd42..85e275fd05 100644 --- a/egs/ljspeech/TTS/matcha/monotonic_align/__init__.py +++ b/egs/ljspeech/TTS/matcha/monotonic_align/__init__.py @@ -1,3 +1,5 @@ +# Copied from +# https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/utils/monotonic_align/__init__.py import numpy as np import torch diff --git a/egs/ljspeech/TTS/matcha/monotonic_align/core.pyx b/egs/ljspeech/TTS/matcha/monotonic_align/core.pyx index 091fcc3a50..eabc7f2736 100644 --- a/egs/ljspeech/TTS/matcha/monotonic_align/core.pyx +++ b/egs/ljspeech/TTS/matcha/monotonic_align/core.pyx @@ -1,3 +1,5 @@ +# Copied from +# https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/utils/monotonic_align/core.pyx import numpy as np cimport cython diff --git a/egs/ljspeech/TTS/matcha/monotonic_align/setup.py b/egs/ljspeech/TTS/matcha/monotonic_align/setup.py index 6092e20d26..e406d67862 100644 --- a/egs/ljspeech/TTS/matcha/monotonic_align/setup.py +++ b/egs/ljspeech/TTS/matcha/monotonic_align/setup.py @@ -1,3 +1,5 @@ +# Copied from +# https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/utils/monotonic_align/setup.py from distutils.core import setup from Cython.Build import cythonize import numpy diff --git a/egs/ljspeech/TTS/matcha/onnx_pretrained.py b/egs/ljspeech/TTS/matcha/onnx_pretrained.py index 6a37f3c177..be34343d3b 100755 --- a/egs/ljspeech/TTS/matcha/onnx_pretrained.py +++ b/egs/ljspeech/TTS/matcha/onnx_pretrained.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 +# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) + import argparse import datetime as dt import logging diff --git a/egs/ljspeech/TTS/matcha/train.py b/egs/ljspeech/TTS/matcha/train.py index ce13e7e429..5e713fdfdb 100755 --- a/egs/ljspeech/TTS/matcha/train.py +++ b/egs/ljspeech/TTS/matcha/train.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) +# Copyright 2024 Xiaomi Corp. 
(authors: Fangjun Kuang) import argparse From a531c92711cbbddc9beaed8f4c23f53d5dc2f3df Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 29 Oct 2024 10:14:09 +0800 Subject: [PATCH 25/27] fix style issues --- .github/scripts/ljspeech/TTS/run-matcha.sh | 13 +++-- .github/workflows/ljspeech.yml | 6 --- egs/ljspeech/TTS/matcha/hifigan/denoiser.py | 13 +++-- egs/ljspeech/TTS/matcha/hifigan/meldataset.py | 50 ++++++++++++++---- egs/ljspeech/TTS/matcha/hifigan/models.py | 52 ++++++++++++++++--- 5 files changed, 100 insertions(+), 34 deletions(-) diff --git a/.github/scripts/ljspeech/TTS/run-matcha.sh b/.github/scripts/ljspeech/TTS/run-matcha.sh index b6eb81020b..37e1bc3204 100755 --- a/.github/scripts/ljspeech/TTS/run-matcha.sh +++ b/.github/scripts/ljspeech/TTS/run-matcha.sh @@ -101,13 +101,12 @@ function export_onnx() { ls -lh *.onnx - - python3 ./matcha/onnx_pretrained.py \ - --acoustic-model ./model-steps-6.onnx \ - --vocoder ./hifigan_v1.onnx \ - --tokens ./data/tokens.txt \ - --input-text "how are you doing?" \ - --output-wav /icefall/generated-matcha-tts-steps-6-v1.wav + python3 ./matcha/onnx_pretrained.py \ + --acoustic-model ./model-steps-6.onnx \ + --vocoder ./hifigan_v1.onnx \ + --tokens ./data/tokens.txt \ + --input-text "how are you doing?" \ + --output-wav /icefall/generated-matcha-tts-steps-6-v1.wav ls -lh /icefall/*.wav soxi /icefall/generated-matcha-tts-steps-6-v1.wav diff --git a/.github/workflows/ljspeech.yml b/.github/workflows/ljspeech.yml index 34a3797faa..7dca96b37e 100644 --- a/.github/workflows/ljspeech.yml +++ b/.github/workflows/ljspeech.yml @@ -84,12 +84,6 @@ jobs: name: generated-test-files-${{ matrix.python-version }}-${{ matrix.torch-version }} path: ./*.wav - - uses: actions/upload-artifact@v4 - if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0' - with: - name: generated-models-py${{ matrix.python-version }}-torch${{ matrix.torch-version }} - path: ./*.wav - - name: Release exported onnx models if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0' && github.event_name == 'push' uses: svenstaro/upload-release-action@v2 diff --git a/egs/ljspeech/TTS/matcha/hifigan/denoiser.py b/egs/ljspeech/TTS/matcha/hifigan/denoiser.py index 9fd33312a0..b9aea61b8e 100644 --- a/egs/ljspeech/TTS/matcha/hifigan/denoiser.py +++ b/egs/ljspeech/TTS/matcha/hifigan/denoiser.py @@ -7,13 +7,18 @@ class Denoiser(torch.nn.Module): """Removes model bias from audio produced with waveglow""" - def __init__(self, vocoder, filter_length=1024, n_overlap=4, win_length=1024, mode="zeros"): + def __init__( + self, vocoder, filter_length=1024, n_overlap=4, win_length=1024, mode="zeros" + ): super().__init__() self.filter_length = filter_length self.hop_length = int(filter_length / n_overlap) self.win_length = win_length - dtype, device = next(vocoder.parameters()).dtype, next(vocoder.parameters()).device + dtype, device = ( + next(vocoder.parameters()).dtype, + next(vocoder.parameters()).device, + ) self.device = device if mode == "zeros": mel_input = torch.zeros((1, 80, 88), dtype=dtype, device=device) @@ -32,7 +37,9 @@ def stft_fn(audio, n_fft, hop_length, win_length, window): return_complex=True, ) spec = torch.view_as_real(spec) - return torch.sqrt(spec.pow(2).sum(-1)), torch.atan2(spec[..., -1], spec[..., 0]) + return torch.sqrt(spec.pow(2).sum(-1)), torch.atan2( + spec[..., -1], spec[..., 0] + ) self.stft = lambda x: stft_fn( audio=x, diff --git a/egs/ljspeech/TTS/matcha/hifigan/meldataset.py b/egs/ljspeech/TTS/matcha/hifigan/meldataset.py index 
8b43ea7965..6eb15a326c 100644 --- a/egs/ljspeech/TTS/matcha/hifigan/meldataset.py +++ b/egs/ljspeech/TTS/matcha/hifigan/meldataset.py @@ -49,7 +49,9 @@ def spectral_de_normalize_torch(magnitudes): hann_window = {} -def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): +def mel_spectrogram( + y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False +): if torch.min(y) < -1.0: print("min value is ", torch.min(y)) if torch.max(y) > 1.0: @@ -58,11 +60,15 @@ def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, global mel_basis, hann_window # pylint: disable=global-statement if fmax not in mel_basis: mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) - mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device) + mel_basis[str(fmax) + "_" + str(y.device)] = ( + torch.from_numpy(mel).float().to(y.device) + ) hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) y = torch.nn.functional.pad( - y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect" + y.unsqueeze(1), + (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), + mode="reflect", ) y = y.squeeze(1) @@ -92,12 +98,16 @@ def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, def get_dataset_filelist(a): with open(a.input_training_file, encoding="utf-8") as fi: training_files = [ - os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0 + os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") + for x in fi.read().split("\n") + if len(x) > 0 ] with open(a.input_validation_file, encoding="utf-8") as fi: validation_files = [ - os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0 + os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") + for x in fi.read().split("\n") + if len(x) > 0 ] return training_files, validation_files @@ -152,7 +162,9 @@ def __getitem__(self, index): audio = normalize(audio) * 0.95 self.cached_wav = audio if sampling_rate != self.sampling_rate: - raise ValueError(f"{sampling_rate} SR doesn't match target {self.sampling_rate} SR") + raise ValueError( + f"{sampling_rate} SR doesn't match target {self.sampling_rate} SR" + ) self._cache_ref_count = self.n_cache_reuse else: audio = self.cached_wav @@ -168,7 +180,9 @@ def __getitem__(self, index): audio_start = random.randint(0, max_audio_start) audio = audio[:, audio_start : audio_start + self.segment_size] else: - audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), "constant") + audio = torch.nn.functional.pad( + audio, (0, self.segment_size - audio.size(1)), "constant" + ) mel = mel_spectrogram( audio, @@ -182,7 +196,12 @@ def __getitem__(self, index): center=False, ) else: - mel = np.load(os.path.join(self.base_mels_path, os.path.splitext(os.path.split(filename)[-1])[0] + ".npy")) + mel = np.load( + os.path.join( + self.base_mels_path, + os.path.splitext(os.path.split(filename)[-1])[0] + ".npy", + ) + ) mel = torch.from_numpy(mel) if len(mel.shape) < 3: @@ -194,10 +213,19 @@ def __getitem__(self, index): if audio.size(1) >= self.segment_size: mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1) mel = mel[:, :, mel_start : mel_start + frames_per_seg] - audio = audio[:, mel_start * self.hop_size : (mel_start + frames_per_seg) * self.hop_size] + audio = audio[ + :, + mel_start + * self.hop_size : (mel_start + frames_per_seg) 
+ * self.hop_size, + ] else: - mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), "constant") - audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), "constant") + mel = torch.nn.functional.pad( + mel, (0, frames_per_seg - mel.size(2)), "constant" + ) + audio = torch.nn.functional.pad( + audio, (0, self.segment_size - audio.size(1)), "constant" + ) mel_loss = mel_spectrogram( audio, diff --git a/egs/ljspeech/TTS/matcha/hifigan/models.py b/egs/ljspeech/TTS/matcha/hifigan/models.py index d209d9a4e9..e6da206108 100644 --- a/egs/ljspeech/TTS/matcha/hifigan/models.py +++ b/egs/ljspeech/TTS/matcha/hifigan/models.py @@ -151,7 +151,9 @@ def __init__(self, h): self.h = h self.num_kernels = len(h.resblock_kernel_sizes) self.num_upsamples = len(h.upsample_rates) - self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3)) + self.conv_pre = weight_norm( + Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3) + ) resblock = ResBlock1 if h.resblock == "1" else ResBlock2 self.ups = nn.ModuleList() @@ -171,7 +173,9 @@ def __init__(self, h): self.resblocks = nn.ModuleList() for i in range(len(self.ups)): ch = h.upsample_initial_channel // (2 ** (i + 1)) - for _, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)): + for _, (k, d) in enumerate( + zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes) + ): self.resblocks.append(resblock(h, ch, k, d)) self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) @@ -213,10 +217,42 @@ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): norm_f = weight_norm if use_spectral_norm is False else spectral_norm self.convs = nn.ModuleList( [ - norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), - norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), - norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), - norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f( + Conv2d( + 1, + 32, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(5, 1), 0), + ) + ), + norm_f( + Conv2d( + 32, + 128, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(5, 1), 0), + ) + ), + norm_f( + Conv2d( + 128, + 512, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(5, 1), 0), + ) + ), + norm_f( + Conv2d( + 512, + 1024, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(5, 1), 0), + ) + ), norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), ] ) @@ -313,7 +349,9 @@ def __init__(self): DiscriminatorS(), ] ) - self.meanpools = nn.ModuleList([AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)]) + self.meanpools = nn.ModuleList( + [AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)] + ) def forward(self, y, y_hat): y_d_rs = [] From 0db831910a17abebb5adccf41766c3187afd510e Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 29 Oct 2024 10:41:09 +0800 Subject: [PATCH 26/27] fix isort --- egs/ljspeech/TTS/local/validate_manifest.py | 2 +- egs/ljspeech/TTS/matcha/monotonic_align/__init__.py | 1 - egs/ljspeech/TTS/matcha/monotonic_align/setup.py | 3 ++- egs/ljspeech/TTS/matcha/tts_datamodule.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/egs/ljspeech/TTS/local/validate_manifest.py b/egs/ljspeech/TTS/local/validate_manifest.py index bbd1bfe9d6..9535ba9f41 100755 --- a/egs/ljspeech/TTS/local/validate_manifest.py +++ b/egs/ljspeech/TTS/local/validate_manifest.py @@ 
-33,9 +33,9 @@ import logging from pathlib import Path +from compute_fbank_ljspeech import MyFbank from lhotse import CutSet, load_manifest_lazy from lhotse.dataset.speech_synthesis import validate_for_tts -from compute_fbank_ljspeech import MyFbank def get_args(): diff --git a/egs/ljspeech/TTS/matcha/monotonic_align/__init__.py b/egs/ljspeech/TTS/matcha/monotonic_align/__init__.py index 85e275fd05..5b26fe4743 100644 --- a/egs/ljspeech/TTS/matcha/monotonic_align/__init__.py +++ b/egs/ljspeech/TTS/matcha/monotonic_align/__init__.py @@ -2,7 +2,6 @@ # https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/utils/monotonic_align/__init__.py import numpy as np import torch - from matcha.monotonic_align.core import maximum_path_c diff --git a/egs/ljspeech/TTS/matcha/monotonic_align/setup.py b/egs/ljspeech/TTS/matcha/monotonic_align/setup.py index e406d67862..df26c633e0 100644 --- a/egs/ljspeech/TTS/matcha/monotonic_align/setup.py +++ b/egs/ljspeech/TTS/matcha/monotonic_align/setup.py @@ -1,8 +1,9 @@ # Copied from # https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/utils/monotonic_align/setup.py from distutils.core import setup -from Cython.Build import cythonize + import numpy +from Cython.Build import cythonize setup( name="monotonic_align", diff --git a/egs/ljspeech/TTS/matcha/tts_datamodule.py b/egs/ljspeech/TTS/matcha/tts_datamodule.py index 0227d9fdbe..8e37fc0308 100644 --- a/egs/ljspeech/TTS/matcha/tts_datamodule.py +++ b/egs/ljspeech/TTS/matcha/tts_datamodule.py @@ -24,8 +24,8 @@ from typing import Any, Dict, Optional import torch -from lhotse import CutSet, load_manifest_lazy from compute_fbank_ljspeech import MyFbank, MyFbankConfig +from lhotse import CutSet, load_manifest_lazy from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures CutConcatenate, CutMix, From 74925e65380c16e5e2d943e500692c6ddad16474 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 29 Oct 2024 12:36:03 +0800 Subject: [PATCH 27/27] Add generated wave --- egs/ljspeech/TTS/README.md | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/egs/ljspeech/TTS/README.md b/egs/ljspeech/TTS/README.md index fe613024ae..1cd6e8fd73 100644 --- a/egs/ljspeech/TTS/README.md +++ b/egs/ljspeech/TTS/README.md @@ -107,7 +107,8 @@ export CUDA_VISIBLE_DEVICES=4,5,6,7 This recipe provides a Matcha-TTS model trained on the LJSpeech dataset. -Pretrained model can be found [here](https://huggingface.co/csukuangfj/icefall-tts-ljspeech-matcha-en-2024-10-28). +Checkpoints and training logs can be found [here](https://huggingface.co/csukuangfj/icefall-tts-ljspeech-matcha-en-2024-10-28). +The pull-request for this recipe can be found at The training command is given below: ```bash @@ -197,21 +198,24 @@ To use the generated onnx files to generate speech from text, please run: ```bash python3 ./matcha/onnx_pretrained.py \ --acoustic-model ./model-steps-6.onnx \ - --vocoder ./hifigan_v2.onnx \ + --vocoder ./hifigan_v1.onnx \ --tokens ./data/tokens.txt \ - --input-text "how are you doing?" \ - --output-wav ./generated-2.wav + --input-text "Ask not what your country can do for you; ask what you can do for your country." 
\ + --output-wav ./matcha-epoch-4000-step6-hifigan-v1.wav ``` ```bash -soxi ./generated-2.wav +soxi ./matcha-epoch-4000-step6-hifigan-v1.wav -Input File : './generated-2.wav' +Input File : './matcha-epoch-4000-step6-hifigan-v1.wav' Channels : 1 Sample Rate : 22050 Precision : 16-bit -Duration : 00:00:01.25 = 27648 samples ~ 94.0408 CDDA sectors -File Size : 55.3k +Duration : 00:00:05.46 = 120320 samples ~ 409.252 CDDA sectors +File Size : 241k Bit Rate : 353k Sample Encoding: 16-bit Signed Integer PCM ``` + +https://github.com/user-attachments/assets/b7c197a6-3870-49c6-90ca-db4d3776869b
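
A small optional check, not part of the patches above: before running `./matcha/onnx_pretrained.py`, the exported acoustic model can be inspected with onnxruntime. The sketch below is illustrative only; it assumes the `model-steps-6.onnx` file name used in the README and that `./matcha/export_onnx.py` stored its configuration through `add_meta_data()`, and the exact metadata keys depend on the checkpoint.

```python
# Illustrative sketch, assuming onnxruntime is installed and that
# model-steps-6.onnx was produced by ./matcha/export_onnx.py.
import onnxruntime as ort

session = ort.InferenceSession("model-steps-6.onnx")

# Custom metadata written by add_meta_data() in export_onnx.py, if any.
for key, value in session.get_modelmeta().custom_metadata_map.items():
    print(f"{key}: {value}")

# Tensor names and shapes that onnx_pretrained.py is expected to feed and read.
for node in session.get_inputs():
    print("input:", node.name, node.shape, node.type)
for node in session.get_outputs():
    print("output:", node.name, node.shape, node.type)
```

If the metadata map is empty or the input names look unexpected, the export step, rather than the vocoder or the tokens file, is the first thing to re-check.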