From af7652d8dbd8e25a8343c1a3f332bdc219e524f9 Mon Sep 17 00:00:00 2001 From: pkufool Date: Wed, 30 Oct 2024 10:29:57 +0800 Subject: [PATCH 1/2] Add hifigan --- egs/ljspeech/TTS/hifigan/models.py | 457 ++++++++++ egs/ljspeech/TTS/hifigan/train.py | 993 +++++++++++++++++++++ egs/ljspeech/TTS/hifigan/tts_datamodule.py | 372 ++++++++ egs/ljspeech/TTS/hifigan/utils.py | 145 +++ 4 files changed, 1967 insertions(+) create mode 100644 egs/ljspeech/TTS/hifigan/models.py create mode 100755 egs/ljspeech/TTS/hifigan/train.py create mode 100644 egs/ljspeech/TTS/hifigan/tts_datamodule.py create mode 100644 egs/ljspeech/TTS/hifigan/utils.py diff --git a/egs/ljspeech/TTS/hifigan/models.py b/egs/ljspeech/TTS/hifigan/models.py new file mode 100644 index 0000000000..bdb060cf0f --- /dev/null +++ b/egs/ljspeech/TTS/hifigan/models.py @@ -0,0 +1,457 @@ +import logging + +from typing import List +import torch +import torch.nn.functional as F +import torch.nn as nn +from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d +from torch.nn.utils import remove_weight_norm, spectral_norm +from torch.nn.utils.parametrizations import weight_norm + +LRELU_SLOPE = 0.1 + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def apply_weight_norm(m): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + weight_norm(m) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +class ResBlock1(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ResBlock1, self).__init__() + self.convs1 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + ] + ) + self.convs2.apply(init_weights) + + def forward(self, x): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class ResBlock2(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3)): + super(ResBlock2, self).__init__() + self.convs = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + ] + ) + self.convs.apply(init_weights) + + def forward(self, x): + for c in 
self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +class Generator(torch.nn.Module): + def __init__( + self, + in_channels: int = 80, + upsample_initial_channel: int = 512, + upsample_rates: List[int] = [8, 8, 2, 2], + upsample_kernel_sizes: List[int] = [16, 16, 4, 4], + resblock_version: str = "1", + resblock_kernel_sizes: List[int] = [3, 7, 11], + resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + ): + super(Generator, self).__init__() + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_rates) + self.conv_pre = weight_norm( + Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3) + ) + resblock = ResBlock1 if resblock_version == "1" else ResBlock2 + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): + self.ups.append( + weight_norm( + ConvTranspose1d( + upsample_initial_channel // (2**i), + upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel // (2 ** (i + 1)) + for j, (k, d) in enumerate( + zip(resblock_kernel_sizes, resblock_dilation_sizes) + ): + self.resblocks.append(resblock(channels=ch, kernel_size=k, dilation=d)) + + self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + + def forward(self, x): + x = self.conv_pre(x) + for i in range(self.num_upsamples): + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + return x + + def remove_weight_norm(self): + logging.info("Removing weight norm...") + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + + +class DiscriminatorP(torch.nn.Module): + def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): + super(DiscriminatorP, self).__init__() + self.period = period + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList( + [ + norm_f( + Conv2d( + 1, + 32, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(5, 1), 0), + ) + ), + norm_f( + Conv2d( + 32, + 128, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(5, 1), 0), + ) + ), + norm_f( + Conv2d( + 128, + 512, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(5, 1), 0), + ) + ), + norm_f( + Conv2d( + 512, + 1024, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(5, 1), 0), + ) + ), + norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), + ] + ) + self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x): + fmap = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class 
MultiPeriodDiscriminator(torch.nn.Module): + def __init__(self): + super(MultiPeriodDiscriminator, self).__init__() + self.discriminators = nn.ModuleList( + [ + DiscriminatorP(2), + DiscriminatorP(3), + DiscriminatorP(5), + DiscriminatorP(7), + DiscriminatorP(11), + ] + ) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class DiscriminatorS(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super(DiscriminatorS, self).__init__() + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList( + [ + norm_f(Conv1d(1, 128, 15, 1, padding=7)), + norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)), + norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)), + norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)), + norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), + ] + ) + self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x): + fmap = [] + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiScaleDiscriminator(torch.nn.Module): + def __init__(self): + super(MultiScaleDiscriminator, self).__init__() + self.discriminators = nn.ModuleList( + [ + DiscriminatorS(use_spectral_norm=True), + DiscriminatorS(), + DiscriminatorS(), + ] + ) + self.meanpools = nn.ModuleList( + [AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)] + ) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + if i != 0: + y = self.meanpools[i - 1](y) + y_hat = self.meanpools[i - 1](y_hat) + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +def feature_loss(fmap_r, fmap_g): + loss = 0 + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + if rl.shape[2] < gl.shape[2]: + gl = gl[:, :, 0 : rl.shape[2], ...] + elif gl.shape[2] < rl.shape[2]: + rl = rl[:, :, 0 : gl.shape[2], ...] 
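+                # Editor's note: the real and generated feature maps can differ slightly
+                # in length along dim 2 (time), so both are truncated to the shorter one
+                # before taking the L1 difference below.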
+ loss += torch.mean(torch.abs(rl - gl)) + return loss * 2 + + +def discriminator_loss(disc_real_outputs, disc_generated_outputs): + loss = 0 + r_losses = [] + g_losses = [] + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + r_loss = torch.mean((1 - dr) ** 2) + g_loss = torch.mean(dg**2) + loss += r_loss + g_loss + r_losses.append(r_loss.item()) + g_losses.append(g_loss.item()) + + return loss, r_losses, g_losses + + +def generator_loss(disc_outputs): + loss = 0 + gen_losses = [] + for dg in disc_outputs: + l = torch.mean((1 - dg) ** 2) + gen_losses.append(l) + loss += l + return loss, gen_losses + + +class HiFiGAN(torch.nn.Module): + def __init__( + self, + in_channels: int = 80, + upsample_initial_channel: int = 512, + upsample_rates: List[int] = [8, 8, 2, 2], + upsample_kernel_sizes: List[int] = [16, 16, 4, 4], + resblock_version: str = "1", + resblock_kernel_sizes: List[int] = [3, 7, 11], + resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + ): + super(HiFiGAN, self).__init__() + self.generator = Generator( + in_channels=in_channels, + upsample_initial_channel=upsample_initial_channel, + upsample_rates=upsample_rates, + upsample_kernel_sizes=upsample_kernel_sizes, + resblock_version=resblock_version, + resblock_kernel_sizes=resblock_kernel_sizes, + resblock_dilation_sizes=resblock_dilation_sizes, + ) + self.mpd = MultiPeriodDiscriminator() + self.msd = MultiScaleDiscriminator() + + def forward(self, x): + return self.generator(x) diff --git a/egs/ljspeech/TTS/hifigan/train.py b/egs/ljspeech/TTS/hifigan/train.py new file mode 100755 index 0000000000..79badef29b --- /dev/null +++ b/egs/ljspeech/TTS/hifigan/train.py @@ -0,0 +1,993 @@ +#!/usr/bin/env python3 +# Copyright 2023-2024 Xiaomi Corp. (authors: Zengwei Yao, +# Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
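# Editor's note: the sketch below is illustrative only and not part of the
# patch; it assumes models.py (added above) is importable. It shows that with
# the v1 defaults the generator upsamples by 8 * 8 * 2 * 2 = 256, i.e. it emits
# 256 waveform samples per mel frame, matching the 256-sample frame shift used
# throughout this recipe.
import torch

from models import HiFiGAN

model = HiFiGAN(in_channels=80)  # v1 defaults declared in models.py
mel = torch.randn(2, 80, 32)     # (batch, feature_dim, num_frames)
with torch.no_grad():
    wav = model(mel)             # HiFiGAN.forward() runs only the generator
assert wav.shape == (2, 1, 32 * 256)  # (batch, 1, num_frames * 256)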
+ + +import argparse +import logging +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple, Union +import itertools +import json +import copy +import math +import os +import random + +import numpy as np +import torch +import torch.nn.functional as F +from torch import Tensor +import torch.multiprocessing as mp +import torch.nn as nn +from lhotse.cut import Cut +from lhotse.utils import fix_random_seed +from torch.cuda.amp import GradScaler, autocast +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim import Optimizer +from torch.utils.tensorboard import SummaryWriter +from tts_datamodule import LJSpeechTtsDataModule + +from torch.optim.lr_scheduler import ExponentialLR, LRScheduler +from torch.optim import Optimizer + +from utils import load_checkpoint, save_checkpoint, plot_spectrogram + +from icefall import diagnostics +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + setup_logger, + str2bool, + get_parameter_groups_with_lrs, +) +from models import ( + HiFiGAN, + feature_loss, + generator_loss, + discriminator_loss, +) +from lhotse import Fbank, FbankConfig +from lhotse.utils import fix_random_seed + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="hifigan/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--learning-rate", type=float, default=0.0002, help="The learning rate." + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. 
The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--hifigan-version", + type=str, + default="v1", + choices=["v1", "v2", "v3"], + help="Version of hifigan.", + ) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. 
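+
+ - segment_size: Number of raw audio samples in each randomly cropped
+ training segment; 8192 samples correspond to 29 feature frames at
+ frame_length 1024 and frame_shift 256.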
+ """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 500, + "feature_dim": 80, + "segment_size": 8192, + "adam_b1": 0.8, + "adam_b2": 0.99, + "lr_decay": 0.999, + "v1": { + "upsample_initial_channel": 512, + "resblock_version": "1", + "upsample_rates": [8, 8, 2, 2], + "upsample_kernel_sizes": [16, 16, 4, 4], + "resblock_kernel_sizes": [3, 7, 11], + "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + }, + "v2": { + "upsample_initial_channel": 128, + "resblock_version": "1", + "upsample_rates": [8, 8, 2, 2], + "upsample_kernel_sizes": [16, 16, 4, 4], + "resblock_kernel_sizes": [3, 7, 11], + "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + }, + "v3": { + "upsample_initial_channel": 256, + "resblock_version": "2", + "upsample_rates": [8, 8, 4], + "upsample_kernel_sizes": [16, 16, 8], + "resblock_kernel_sizes": [3, 5, 7], + "resblock_dilation_sizes": [[1, 2], [2, 6], [3, 12]], + }, + "env_info": get_env_info(), + } + ) + + return params + + +def fbank( + audio: torch.Tensor, + lengths: Optional[torch.Tensor] = None, + sampling_rate: int = 22050, + frame_length: int = 1024, + frame_shift: int = 256, + use_fft_mag: bool = True, +): + sampling_rate = sampling_rate + config = FbankConfig( + sampling_rate=sampling_rate, + frame_length=frame_length / sampling_rate, # (in second), + frame_shift=frame_shift / sampling_rate, # (in second) + use_fft_mag=use_fft_mag, + ) + fb = Fbank(config) + feat = fb.extract_batch(audio, sampling_rate=sampling_rate, lengths=lengths) + if feat.dim() == 2: + feat = feat.unsqueeze(0) + return feat + + +def get_model(params: AttributeDict) -> nn.Module: + device = params.device + model = HiFiGAN( + in_channels=params.feature_dim, + upsample_initial_channel=params[params.hifigan_version][ + "upsample_initial_channel" + ], + upsample_rates=params[params.hifigan_version]["upsample_rates"], + upsample_kernel_sizes=params[params.hifigan_version]["upsample_kernel_sizes"], + resblock_version=params[params.hifigan_version]["resblock_version"], + resblock_kernel_sizes=params[params.hifigan_version]["resblock_kernel_sizes"], + resblock_dilation_sizes=params[params.hifigan_version][ + "resblock_dilation_sizes" + ], + ).to(device) + num_param_g = sum([p.numel() for p in model.generator.parameters()]) + logging.info(f"Number of Generator parameters : {num_param_g}") + num_param_mpd = sum([p.numel() for p in model.mpd.parameters()]) + logging.info(f"Number of MultiPeriodDiscriminator parameters : {num_param_mpd}") + num_param_msd = sum([p.numel() for p in model.msd.parameters()]) + logging.info(f"Number of MultiScaleDiscriminator parameters : {num_param_msd}") + logging.info( + f"Number of model parameters : {num_param_g + num_param_mpd + num_param_msd}" + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer_g: Optional[Optimizer] = None, + optimizer_d: Optional[Optimizer] = None, + scheduler_g: Optional[LRScheduler] = None, + scheduler_d: Optional[LRScheduler] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. 
+ + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def compute_generator_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + features: Tensor, + audios: Tensor, +) -> Tuple[Tensor, MetricsTracker]: + device = params.device + model = model.module if isinstance(model, DDP) else model + + audios = audios.unsqueeze(1) # (B, 1, T) + + gen_audios = model(features) # (B, 1, T) + + gen_features = fbank(gen_audios.squeeze(1)).permute(0, 2, 1).to(device) # (B, F, T) + + # L1 Mel-Spectrogram Loss + loss_mel = F.l1_loss(features, gen_features) * 45 + + y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = model.mpd(audios, gen_audios) + y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = model.msd(audios, gen_audios) + + loss_fm_f = feature_loss(fmap_f_r, fmap_f_g) + loss_fm_s = feature_loss(fmap_s_r, fmap_s_g) + + loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g) + loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g) + + loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel + + assert loss_gen_all.requires_grad == True + + info = MetricsTracker() + info["frames"] = 1 + info["loss_gen"] = loss_gen_all.detach().cpu().item() + info["loss_mel"] = loss_mel.detach().cpu().item() + info["loss_mel_error"] = loss_mel.detach().cpu().item() / 45 + info["loss_feature_msd"] = loss_fm_s.detach().cpu().item() + info["loss_feature_mpd"] = loss_fm_f.detach().cpu().item() + info["loss_gen_msd"] = loss_gen_s.detach().cpu().item() + info["loss_gen_mpd"] = loss_gen_f.detach().cpu().item() + + return loss_gen_all, info + + +def compute_discriminator_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + features: Tensor, + audios: Tensor, +) -> Tuple[Tensor, MetricsTracker]: + device = params.device + model = model.module if isinstance(model, DDP) else model + + audios = audios.unsqueeze(1) + + gen_audios = model(features) # (B, 1, T) + + # MPD + y_df_hat_r, y_df_hat_g, _, _ = model.mpd(audios, gen_audios.detach()) + loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss( + y_df_hat_r, y_df_hat_g + ) + + # MSD + y_ds_hat_r, y_ds_hat_g, _, _ = model.msd(audios, gen_audios.detach()) + loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss( + y_ds_hat_r, y_ds_hat_g + ) + + loss_disc_all = loss_disc_s + loss_disc_f + + info = MetricsTracker() + 
# MetricsTracker will norm the loss value with "frames", set it to 1 here to + # make tot_loss look normal. + info["frames"] = 1 + info["loss_disc"] = loss_disc_all.detach().cpu().item() + info["loss_disc_msd"] = loss_disc_s.detach().cpu().item() + info["loss_disc_mpd"] = loss_disc_f.detach().cpu().item() + + return loss_disc_all, info + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer_g: Optimizer, + optimizer_d: Optimizer, + scheduler_g: ExponentialLR, + scheduler_d: ExponentialLR, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer. + scheduler: + The learning rate scheduler, we call step() every epoch. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + + # used to track the stats over iterations in one epoch + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + params=params, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + batch_size = batch["features_lens"].size(0) + + features = batch["features"].to(device) # (B, T, F) + features_lens = batch["features_lens"].to(device) + audios = batch["audio"].to(device) + + # 8192 samples is 29 frames + segment_frames = ( + params.segment_size - params.frame_length + ) // params.frame_shift + 1 + start_p = random.randint(0, features_lens.min() - (segment_frames + 1)) + + features = features[:, start_p : start_p + segment_frames, :].permute( + 0, 2, 1 + ) # (B, F, T) + + audios = audios[ + :, + start_p * params.frame_shift : start_p * params.frame_shift + + params.segment_size, + ] # (B, T) + + try: + + optimizer_d.zero_grad() + + loss_disc, loss_disc_info = compute_discriminator_loss( + params=params, + model=model, + features=features, + audios=audios, + ) + + loss_disc.backward() + optimizer_d.step() + + optimizer_g.zero_grad() + loss_gen, loss_gen_info = compute_generator_loss( + params=params, + model=model, + features=features, + audios=audios, + ) + + loss_gen.backward() + optimizer_g.step() + + loss_info = loss_gen_info + loss_disc_info + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_gen_info + + except Exception as e: + logging.info(f"Caught exception : {e}.") + save_bad_model() + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if params.batch_idx_train % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, 
try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or ( + cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0 + ): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if params.batch_idx_train % params.log_interval == 0: + cur_lr_g = max(scheduler_g.get_last_lr()) + cur_lr_d = max(scheduler_d.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, batch {batch_idx}, " + f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, " + f"loss[{loss_info}], tot_loss[{tot_loss}], " + f"cur_lr_g: {cur_lr_g:.2e}, " + f"cur_lr_d: {cur_lr_d:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate_gen", cur_lr_g, params.batch_idx_train + ) + tb_writer.add_scalar( + "train/learning_rate_disc", cur_lr_d, params.batch_idx_train + ) + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if ( + params.batch_idx_train % params.valid_interval == 0 + and not params.print_diagnostics + ): + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + valid_dl=valid_dl, + world_size=world_size, + rank=rank, + tb_writer=tb_writer, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + scheduler_g.step() + scheduler_d.step() + loss_value = tot_loss["loss_gen"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, + rank: int = 0, + tb_writer: Optional[SummaryWriter] = None, +) -> MetricsTracker: + """Run the validation process.""" + + model.eval() + torch.cuda.empty_cache() + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + + # used to summary the stats over iterations + tot_loss = MetricsTracker() + + with torch.no_grad(): + for batch_idx, batch in enumerate(valid_dl): + features = batch["features"] # (B, T, F) + audios = batch["audio"] + + x = features.permute(0, 2, 1) # (B, F, T) + y = batch["audio"] # (B, T) + y_mel = x.clone().to(device) # (B, F, T) + + y_g_hat = model(x.to(device)) # (B, 1, T) + + y_g_hat_mel = ( + fbank(y_g_hat.squeeze(1)).permute(0, 2, 1).to(device) + ) # (B, F, T) + + loss_mel_error = F.l1_loss(y_mel, y_g_hat_mel) + + loss_info = MetricsTracker() + # MetricsTracker will norm the 
loss value with "frames", set it to 1 here to + # make tot_loss look normal. + loss_info["frames"] = 1 + loss_info["loss_mel_error"] = loss_mel_error.item() + + tot_loss = tot_loss + loss_info + + if batch_idx <= 5 and rank == 0 and tb_writer is not None: + if params.batch_idx_train == params.valid_interval: + tb_writer.add_audio( + "gt/y_{}".format(batch_idx), + y[0], + params.batch_idx_train, + params.sampling_rate, + ) + tb_writer.add_figure( + "gt/y_spec_{}".format(batch_idx), + plot_spectrogram(x[0].cpu().numpy()), + params.batch_idx_train, + ) + tb_writer.add_audio( + "generated/y_hat_{}".format(batch_idx), + y_g_hat[0], + params.batch_idx_train, + params.sampling_rate, + ) + + tb_writer.add_figure( + "generated/y_hat_spec_{}".format(batch_idx), + plot_spectrogram(y_g_hat_mel[0].detach().cpu().numpy()), + params.batch_idx_train, + ) + + if world_size > 1: + tot_loss.reduce(device) + + loss_value = tot_loss["loss_mel_error"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + torch.autograd.set_detect_anomaly(True) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + params.device = device + logging.info(params) + logging.info("About to create model") + + model = get_model(params) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available(params=params, model=model) + + model = model.to(device) + generator = model.generator + msd = model.msd + mpd = model.mpd + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer_g = torch.optim.AdamW( + generator.parameters(), + params.learning_rate, + betas=[params.adam_b1, params.adam_b2], + ) + optimizer_d = torch.optim.AdamW( + itertools.chain(msd.parameters(), mpd.parameters()), + params.learning_rate, + betas=[params.adam_b1, params.adam_b2], + ) + + scheduler_g = torch.optim.lr_scheduler.ExponentialLR( + optimizer_g, gamma=params.lr_decay + ) + scheduler_d = torch.optim.lr_scheduler.ExponentialLR( + optimizer_d, gamma=params.lr_decay + ) + + if checkpoints is not None: + # load state_dict for optimizers + if "optimizer_g" in checkpoints: + logging.info("Loading generator optimizer state dict") + optimizer_g.load_state_dict(checkpoints["optimizer_g"]) + if "optimizer_d" in checkpoints: + logging.info("Loading discriminator optimizer state dict") + optimizer_d.load_state_dict(checkpoints["optimizer_d"]) + + # load state_dict for schedulers + if "scheduler_g" in checkpoints: + logging.info("Loading generator scheduler state dict") + scheduler_g.load_state_dict(checkpoints["scheduler_g"]) + if "scheduler_d" in checkpoints: + 
logging.info("Loading discriminator scheduler state dict") + scheduler_d.load_state_dict(checkpoints["scheduler_d"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + ljspeech = LJSpeechTtsDataModule(args) + + train_cuts = ljspeech.train_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + return False + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + train_dl = ljspeech.train_dataloaders(train_cuts) + + valid_cuts = ljspeech.valid_cuts() + valid_dl = ljspeech.valid_dataloaders(valid_cuts) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + logging.info(f"Start epoch {epoch}") + + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + params.cur_epoch = epoch + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + train_one_epoch( + params=params, + model=model, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint( + filename=filename, + params=params, + model=model, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + if params.batch_idx_train % params.save_every_n == 0: + filename = params.exp_dir / f"checkpoint-{params.batch_idx_train}.pt" + save_checkpoint( + filename=filename, + params=params, + model=model, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + if rank == 0: + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def main(): + parser = get_parser() + LJSpeechTtsDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +if __name__ == "__main__": + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + main() diff --git a/egs/ljspeech/TTS/hifigan/tts_datamodule.py b/egs/ljspeech/TTS/hifigan/tts_datamodule.py new 
file mode 100644 index 0000000000..44b8052096 --- /dev/null +++ b/egs/ljspeech/TTS/hifigan/tts_datamodule.py @@ -0,0 +1,372 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022-2024 Xiaomi Corporation (Authors: Mingshuang Luo, +# Zengwei Yao, +# Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from lhotse import CutSet, Fbank, FbankConfig, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + PrecomputedFeatures, + SimpleCutSampler, + SpecAugment, + SpeechSynthesisDataset, +) +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class LJSpeechTtsDataModule: + """ + DataModule for tts experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="TTS data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. 
Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['cut'] with the cuts that " + "were used to construct it.", + ) + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--sampling-rate", + type=int, + default=22050, + help="The sampleing rate of ljspeech dataset", + ) + + group.add_argument( + "--frame-shift", + type=int, + default=256, + help="Frame shift.", + ) + + group.add_argument( + "--frame-length", + type=int, + default=1024, + help="Frame shift.", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + group.add_argument( + "--use-fft-mag", + type=str2bool, + default=True, + help="Whether to use magnitude of fbank, false to use power energy.", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + logging.info("About to create train dataset") + train = SpeechSynthesisDataset( + return_text=True, + return_tokens=False, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + sampling_rate = self.args.sampling_rate + config = FbankConfig( + sampling_rate=sampling_rate, + frame_length=self.args.frame_length / sampling_rate, # (in second), + frame_shift=self.args.frame_shift / sampling_rate, # (in second) + use_fft_mag=self.args.use_fft_mag, + ) + train = SpeechSynthesisDataset( + return_text=True, + return_tokens=False, + feature_input_strategy=OnTheFlyFeatures(Fbank(config)), + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. 
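+ # Each dataloader worker is then re-seeded with `seed + worker_id` via
+ # _SeedWorkers, so different workers draw different random numbers while
+ # the run as a whole stays reproducible.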
+ seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + sampling_rate = self.args.sampling_rate + config = FbankConfig( + sampling_rate=sampling_rate, + frame_length=self.args.frame_length / sampling_rate, # (in second), + frame_shift=self.args.frame_shift / sampling_rate, # (in second) + use_fft_mag=self.args.use_fft_mag, + ) + validate = SpeechSynthesisDataset( + return_text=True, + return_tokens=False, + feature_input_strategy=OnTheFlyFeatures(Fbank(config)), + return_cuts=self.args.return_cuts, + ) + else: + validate = SpeechSynthesisDataset( + return_text=True, + return_tokens=False, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + num_buckets=self.args.num_buckets, + shuffle=False, + ) + logging.info("About to create valid dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.info("About to create test dataset") + if self.args.on_the_fly_feats: + sampling_rate = self.args.sampling_rate + config = FbankConfig( + sampling_rate=sampling_rate, + frame_length=self.args.frame_length / sampling_rate, # (in second), + frame_shift=self.args.frame_shift / sampling_rate, # (in second) + use_fft_mag=self.args.use_fft_mag, + ) + test = SpeechSynthesisDataset( + return_text=True, + return_tokens=False, + feature_input_strategy=OnTheFlyFeatures(Fbank(config)), + return_cuts=self.args.return_cuts, + ) + else: + test = SpeechSynthesisDataset( + return_text=True, + return_tokens=False, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + test_sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + num_buckets=self.args.num_buckets, + shuffle=False, + ) + logging.info("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=test_sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info("About to get train cuts") + return load_manifest_lazy( + self.args.manifest_dir / "ljspeech_cuts_train.jsonl.gz" + ) + + @lru_cache() + def valid_cuts(self) -> CutSet: + logging.info("About to get validation cuts") + return load_manifest_lazy( + self.args.manifest_dir / "ljspeech_cuts_valid.jsonl.gz" + ) + + @lru_cache() + def train_cuts_finetune(self) -> CutSet: + logging.info("About to get train cuts finetune") + return load_manifest_lazy( + self.args.manifest_dir / "ljspeech_cuts_train_finetune.jsonl.gz" + ) + + @lru_cache() + def valid_cuts_finetune(self) -> CutSet: + logging.info("About to get validation cuts finetune") + return load_manifest_lazy( + self.args.manifest_dir / "ljspeech_cuts_valid_finetune.jsonl.gz" + ) + + @lru_cache() + def test_cuts(self) -> CutSet: + logging.info("About to get test cuts") + return load_manifest_lazy( + self.args.manifest_dir / "ljspeech_cuts_test.jsonl.gz" + ) diff --git 
a/egs/ljspeech/TTS/hifigan/utils.py b/egs/ljspeech/TTS/hifigan/utils.py new file mode 100644 index 0000000000..90b4f1aa56 --- /dev/null +++ b/egs/ljspeech/TTS/hifigan/utils.py @@ -0,0 +1,145 @@ +import glob +import os +import logging +import matplotlib +import torch +import torch.nn as nn +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Union +from torch.nn.utils import weight_norm +from torch.optim.lr_scheduler import LRScheduler +from torch.optim import Optimizer +from torch.cuda.amp import GradScaler +from lhotse.dataset.sampling.base import CutSampler +from torch import Tensor +from torch.nn.parallel import DistributedDataParallel as DDP + + +matplotlib.use("Agg") +import matplotlib.pylab as plt + + +def plot_spectrogram(spectrogram): + fig, ax = plt.subplots(figsize=(10, 2)) + im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none") + plt.colorbar(im, ax=ax) + + fig.canvas.draw() + plt.close() + + return fig + + +def load_checkpoint( + filename: Path, + model: nn.Module, + model_avg: Optional[nn.Module] = None, + optimizer_g: Optional[Optimizer] = None, + optimizer_d: Optional[Optimizer] = None, + scheduler_g: Optional[LRScheduler] = None, + scheduler_d: Optional[LRScheduler] = None, + scaler: Optional[GradScaler] = None, + sampler: Optional[CutSampler] = None, + strict: bool = False, +) -> Dict[str, Any]: + logging.info(f"Loading checkpoint from {filename}") + checkpoint = torch.load(filename, map_location="cpu") + + if next(iter(checkpoint["model"])).startswith("module."): + logging.info("Loading checkpoint saved by DDP") + + dst_state_dict = model.state_dict() + src_state_dict = checkpoint["model"] + for key in dst_state_dict.keys(): + src_key = "{}.{}".format("module", key) + dst_state_dict[key] = src_state_dict.pop(src_key) + assert len(src_state_dict) == 0 + model.load_state_dict(dst_state_dict, strict=strict) + else: + model.load_state_dict(checkpoint["model"], strict=strict) + + checkpoint.pop("model") + + if model_avg is not None and "model_avg" in checkpoint: + logging.info("Loading averaged model") + model_avg.load_state_dict(checkpoint["model_avg"], strict=strict) + checkpoint.pop("model_avg") + + def load(name, obj): + s = checkpoint.get(name, None) + if obj and s: + obj.load_state_dict(s) + checkpoint.pop(name) + + load("optimizer_g", optimizer_g) + load("optimizer_d", optimizer_d) + load("scheduler_g", scheduler_g) + load("scheduler_d", scheduler_d) + load("grad_scaler", scaler) + load("sampler", sampler) + + return checkpoint + + +def save_checkpoint( + filename: Path, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + params: Optional[Dict[str, Any]] = None, + optimizer_g: Optional[Optimizer] = None, + optimizer_d: Optional[Optimizer] = None, + scheduler_g: Optional[LRScheduler] = None, + scheduler_d: Optional[LRScheduler] = None, + scaler: Optional[GradScaler] = None, + sampler: Optional[CutSampler] = None, + rank: int = 0, +) -> None: + """Save training information to a file. + + Args: + filename: + The checkpoint filename. + model: + The model to be saved. We only save its `state_dict()`. + model_avg: + The stored model averaged from the start of training. + params: + User defined parameters, e.g., epoch, loss. + optimizer: + The optimizer to be saved. We only save its `state_dict()`. + scheduler: + The scheduler to be saved. We only save its `state_dict()`. + scalar: + The GradScaler to be saved. We only save its `state_dict()`. + rank: + Used in DDP. 
We save checkpoint only for the node whose rank is 0. + Returns: + Return None. + """ + if rank != 0: + return + + logging.info(f"Saving checkpoint to {filename}") + + if isinstance(model, DDP): + model = model.module + + checkpoint = { + "model": model.state_dict(), + "optimizer_g": optimizer_g.state_dict() if optimizer_g is not None else None, + "optimizer_d": optimizer_d.state_dict() if optimizer_d is not None else None, + "scheduler_g": scheduler_g.state_dict() if scheduler_g is not None else None, + "scheduler_d": scheduler_d.state_dict() if scheduler_d is not None else None, + "grad_scaler": scaler.state_dict() if scaler is not None else None, + "sampler": sampler.state_dict() if sampler is not None else None, + } + + if model_avg is not None: + checkpoint["model_avg"] = model_avg.to(torch.float32).state_dict() + + if params: + for k, v in params.items(): + assert k not in checkpoint + checkpoint[k] = v + + torch.save(checkpoint, filename) From 7b548898a9e1d1c580d539c9646cd3d64d974c16 Mon Sep 17 00:00:00 2001 From: pkufool Date: Fri, 1 Nov 2024 15:11:05 +0800 Subject: [PATCH 2/2] Add libriheavy recipe --- egs/libriheavy/TTS/hifigan/models.py | 1 + egs/libriheavy/TTS/hifigan/train.py | 993 ++++++++++++++++++ egs/libriheavy/TTS/hifigan/tts_datamodule.py | 378 +++++++ egs/libriheavy/TTS/hifigan/utils.py | 1 + .../TTS/local/compute_fbank_libriheavy.py | 286 +++++ 5 files changed, 1659 insertions(+) create mode 120000 egs/libriheavy/TTS/hifigan/models.py create mode 100755 egs/libriheavy/TTS/hifigan/train.py create mode 100644 egs/libriheavy/TTS/hifigan/tts_datamodule.py create mode 120000 egs/libriheavy/TTS/hifigan/utils.py create mode 100755 egs/libriheavy/TTS/local/compute_fbank_libriheavy.py diff --git a/egs/libriheavy/TTS/hifigan/models.py b/egs/libriheavy/TTS/hifigan/models.py new file mode 120000 index 0000000000..d7cd306dc8 --- /dev/null +++ b/egs/libriheavy/TTS/hifigan/models.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/hifigan/models.py \ No newline at end of file diff --git a/egs/libriheavy/TTS/hifigan/train.py b/egs/libriheavy/TTS/hifigan/train.py new file mode 100755 index 0000000000..d57c3567ae --- /dev/null +++ b/egs/libriheavy/TTS/hifigan/train.py @@ -0,0 +1,993 @@ +#!/usr/bin/env python3 +# Copyright 2023-2024 Xiaomi Corp. (authors: Zengwei Yao, +# Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
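# Editor's note: the sketch below is illustrative only and not part of the
# patch. It shows the checkpoint save/resume round trip provided by utils.py
# above (shared by both recipes through the symlink). The exp directory is
# assumed to already exist, and only one optimizer is shown for brevity.
from pathlib import Path

import torch

from models import HiFiGAN
from utils import load_checkpoint, save_checkpoint

model = HiFiGAN()
optimizer_g = torch.optim.AdamW(model.generator.parameters(), lr=2e-4)

save_checkpoint(
    filename=Path("hifigan/exp/epoch-1.pt"),
    model=model,
    params={"cur_epoch": 1, "batch_idx_train": 1000},
    optimizer_g=optimizer_g,
    rank=0,  # save_checkpoint only writes the file on rank 0
)

# Resuming restores the model weights and optimizer state in place and returns
# the remaining entries (e.g. cur_epoch, batch_idx_train) as a dict.
saved = load_checkpoint(
    Path("hifigan/exp/epoch-1.pt"), model=model, optimizer_g=optimizer_g
)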
+ + +import argparse +import logging +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple, Union +import itertools +import json +import copy +import math +import os +import random + +import numpy as np +import torch +import torch.nn.functional as F +from torch import Tensor +import torch.multiprocessing as mp +import torch.nn as nn +from lhotse.cut import Cut +from lhotse.utils import fix_random_seed +from torch.cuda.amp import GradScaler, autocast +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim import Optimizer +from torch.utils.tensorboard import SummaryWriter +from tts_datamodule import LibriheavyTtsDataModule + +from torch.optim.lr_scheduler import ExponentialLR, LRScheduler +from torch.optim import Optimizer + +from utils import load_checkpoint, save_checkpoint, plot_spectrogram + +from icefall import diagnostics +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + setup_logger, + str2bool, + get_parameter_groups_with_lrs, +) +from models import ( + HiFiGAN, + feature_loss, + generator_loss, + discriminator_loss, +) +from lhotse import Fbank, FbankConfig +from lhotse.utils import fix_random_seed + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="hifigan/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--learning-rate", type=float, default=0.0002, help="The learning rate." + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. 
The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--hifigan-version", + type=str, + default="v1", + choices=["v1", "v2", "v3"], + help="Version of hifigan.", + ) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. 
+ """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 500, + "feature_dim": 80, + "segment_size": 8192, + "adam_b1": 0.8, + "adam_b2": 0.99, + "lr_decay": 0.999, + "v1": { + "upsample_initial_channel": 512, + "resblock_version": "1", + "upsample_rates": [8, 8, 2, 2], + "upsample_kernel_sizes": [16, 16, 4, 4], + "resblock_kernel_sizes": [3, 7, 11], + "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + }, + "v2": { + "upsample_initial_channel": 128, + "resblock_version": "1", + "upsample_rates": [8, 8, 2, 2], + "upsample_kernel_sizes": [16, 16, 4, 4], + "resblock_kernel_sizes": [3, 7, 11], + "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + }, + "v3": { + "upsample_initial_channel": 256, + "resblock_version": "2", + "upsample_rates": [8, 8, 4], + "upsample_kernel_sizes": [16, 16, 8], + "resblock_kernel_sizes": [3, 5, 7], + "resblock_dilation_sizes": [[1, 2], [2, 6], [3, 12]], + }, + "env_info": get_env_info(), + } + ) + + return params + + +def fbank( + audio: torch.Tensor, + lengths: Optional[torch.Tensor] = None, + sampling_rate: int = 16000, + frame_length: int = 1024, + frame_shift: int = 256, + use_fft_mag: bool = True, +): + sampling_rate = sampling_rate + config = FbankConfig( + sampling_rate=sampling_rate, + frame_length=frame_length / sampling_rate, # (in second), + frame_shift=frame_shift / sampling_rate, # (in second) + use_fft_mag=use_fft_mag, + ) + fb = Fbank(config) + feat = fb.extract_batch(audio, sampling_rate=sampling_rate, lengths=lengths) + if feat.dim() == 2: + feat = feat.unsqueeze(0) + return feat + + +def get_model(params: AttributeDict) -> nn.Module: + device = params.device + model = HiFiGAN( + in_channels=params.feature_dim, + upsample_initial_channel=params[params.hifigan_version][ + "upsample_initial_channel" + ], + upsample_rates=params[params.hifigan_version]["upsample_rates"], + upsample_kernel_sizes=params[params.hifigan_version]["upsample_kernel_sizes"], + resblock_version=params[params.hifigan_version]["resblock_version"], + resblock_kernel_sizes=params[params.hifigan_version]["resblock_kernel_sizes"], + resblock_dilation_sizes=params[params.hifigan_version][ + "resblock_dilation_sizes" + ], + ).to(device) + num_param_g = sum([p.numel() for p in model.generator.parameters()]) + logging.info(f"Number of Generator parameters : {num_param_g}") + num_param_mpd = sum([p.numel() for p in model.mpd.parameters()]) + logging.info(f"Number of MultiPeriodDiscriminator parameters : {num_param_mpd}") + num_param_msd = sum([p.numel() for p in model.msd.parameters()]) + logging.info(f"Number of MultiScaleDiscriminator parameters : {num_param_msd}") + logging.info( + f"Number of model parameters : {num_param_g + num_param_mpd + num_param_msd}" + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer_g: Optional[Optimizer] = None, + optimizer_d: Optional[Optimizer] = None, + scheduler_g: Optional[LRScheduler] = None, + scheduler_d: Optional[LRScheduler] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. 
+ + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def compute_generator_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + features: Tensor, + audios: Tensor, +) -> Tuple[Tensor, MetricsTracker]: + device = params.device + model = model.module if isinstance(model, DDP) else model + + audios = audios.unsqueeze(1) # (B, 1, T) + + gen_audios = model(features) # (B, 1, T) + + gen_features = fbank(gen_audios.squeeze(1)).permute(0, 2, 1).to(device) # (B, F, T) + + # L1 Mel-Spectrogram Loss + loss_mel = F.l1_loss(features, gen_features) * 45 + + y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = model.mpd(audios, gen_audios) + y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = model.msd(audios, gen_audios) + + loss_fm_f = feature_loss(fmap_f_r, fmap_f_g) + loss_fm_s = feature_loss(fmap_s_r, fmap_s_g) + + loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g) + loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g) + + loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel + + assert loss_gen_all.requires_grad == True + + info = MetricsTracker() + info["frames"] = 1 + info["loss_gen"] = loss_gen_all.detach().cpu().item() + info["loss_mel"] = loss_mel.detach().cpu().item() + info["loss_mel_error"] = loss_mel.detach().cpu().item() / 45 + info["loss_feature_msd"] = loss_fm_s.detach().cpu().item() + info["loss_feature_mpd"] = loss_fm_f.detach().cpu().item() + info["loss_gen_msd"] = loss_gen_s.detach().cpu().item() + info["loss_gen_mpd"] = loss_gen_f.detach().cpu().item() + + return loss_gen_all, info + + +def compute_discriminator_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + features: Tensor, + audios: Tensor, +) -> Tuple[Tensor, MetricsTracker]: + device = params.device + model = model.module if isinstance(model, DDP) else model + + audios = audios.unsqueeze(1) + + gen_audios = model(features) # (B, 1, T) + + # MPD + y_df_hat_r, y_df_hat_g, _, _ = model.mpd(audios, gen_audios.detach()) + loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss( + y_df_hat_r, y_df_hat_g + ) + + # MSD + y_ds_hat_r, y_ds_hat_g, _, _ = model.msd(audios, gen_audios.detach()) + loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss( + y_ds_hat_r, y_ds_hat_g + ) + + loss_disc_all = loss_disc_s + loss_disc_f + + info = MetricsTracker() + 
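+    # The total discriminator loss is simply the sum of the MultiScaleDiscriminator
+    # (loss_disc_s) and MultiPeriodDiscriminator (loss_disc_f) terms, both computed
+    # on real audio versus detached generated audio. The generator counterpart in
+    # compute_generator_loss() additionally adds the two feature-matching losses and
+    # the L1 mel loss scaled by 45, the weighting used in the original HiFi-GAN recipe.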
# MetricsTracker will norm the loss value with "frames", set it to 1 here to + # make tot_loss look normal. + info["frames"] = 1 + info["loss_disc"] = loss_disc_all.detach().cpu().item() + info["loss_disc_msd"] = loss_disc_s.detach().cpu().item() + info["loss_disc_mpd"] = loss_disc_f.detach().cpu().item() + + return loss_disc_all, info + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer_g: Optimizer, + optimizer_d: Optimizer, + scheduler_g: ExponentialLR, + scheduler_d: ExponentialLR, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer. + scheduler: + The learning rate scheduler, we call step() every epoch. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + + # used to track the stats over iterations in one epoch + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + params=params, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + batch_size = batch["features_lens"].size(0) + + features = batch["features"].to(device) # (B, T, F) + features_lens = batch["features_lens"].to(device) + audios = batch["audio"].to(device) + + # 8192 samples is 29 frames + segment_frames = ( + params.segment_size - params.frame_length + ) // params.frame_shift + 1 + start_p = random.randint(0, features_lens.min() - (segment_frames + 1)) + + features = features[:, start_p : start_p + segment_frames, :].permute( + 0, 2, 1 + ) # (B, F, T) + + audios = audios[ + :, + start_p * params.frame_shift : start_p * params.frame_shift + + params.segment_size, + ] # (B, T) + + try: + + optimizer_d.zero_grad() + + loss_disc, loss_disc_info = compute_discriminator_loss( + params=params, + model=model, + features=features, + audios=audios, + ) + + loss_disc.backward() + optimizer_d.step() + + optimizer_g.zero_grad() + loss_gen, loss_gen_info = compute_generator_loss( + params=params, + model=model, + features=features, + audios=audios, + ) + + loss_gen.backward() + optimizer_g.step() + + loss_info = loss_gen_info + loss_disc_info + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_gen_info + + except Exception as e: + logging.info(f"Caught exception : {e}.") + save_bad_model() + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if params.batch_idx_train % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, 
try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or ( + cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0 + ): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if params.batch_idx_train % params.log_interval == 0: + cur_lr_g = max(scheduler_g.get_last_lr()) + cur_lr_d = max(scheduler_d.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, batch {batch_idx}, " + f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, " + f"loss[{loss_info}], tot_loss[{tot_loss}], " + f"cur_lr_g: {cur_lr_g:.2e}, " + f"cur_lr_d: {cur_lr_d:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate_gen", cur_lr_g, params.batch_idx_train + ) + tb_writer.add_scalar( + "train/learning_rate_disc", cur_lr_d, params.batch_idx_train + ) + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if ( + params.batch_idx_train % params.valid_interval == 0 + and not params.print_diagnostics + ): + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + valid_dl=valid_dl, + world_size=world_size, + rank=rank, + tb_writer=tb_writer, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + scheduler_g.step() + scheduler_d.step() + loss_value = tot_loss["loss_gen"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, + rank: int = 0, + tb_writer: Optional[SummaryWriter] = None, +) -> MetricsTracker: + """Run the validation process.""" + + model.eval() + torch.cuda.empty_cache() + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + + # used to summary the stats over iterations + tot_loss = MetricsTracker() + + with torch.no_grad(): + for batch_idx, batch in enumerate(valid_dl): + features = batch["features"] # (B, T, F) + audios = batch["audio"] + + x = features.permute(0, 2, 1) # (B, F, T) + y = batch["audio"] # (B, T) + y_mel = x.clone().to(device) # (B, F, T) + + y_g_hat = model(x.to(device)) # (B, 1, T) + + y_g_hat_mel = ( + fbank(y_g_hat.squeeze(1)).permute(0, 2, 1).to(device) + ) # (B, F, T) + + loss_mel_error = F.l1_loss(y_mel, y_g_hat_mel) + + loss_info = MetricsTracker() + # MetricsTracker will norm the 
loss value with "frames", set it to 1 here to + # make tot_loss look normal. + loss_info["frames"] = 1 + loss_info["loss_mel_error"] = loss_mel_error.item() + + tot_loss = tot_loss + loss_info + + if batch_idx <= 5 and rank == 0 and tb_writer is not None: + if params.batch_idx_train == params.valid_interval: + tb_writer.add_audio( + "gt/y_{}".format(batch_idx), + y[0], + params.batch_idx_train, + params.sampling_rate, + ) + tb_writer.add_figure( + "gt/y_spec_{}".format(batch_idx), + plot_spectrogram(x[0].cpu().numpy()), + params.batch_idx_train, + ) + tb_writer.add_audio( + "generated/y_hat_{}".format(batch_idx), + y_g_hat[0], + params.batch_idx_train, + params.sampling_rate, + ) + + tb_writer.add_figure( + "generated/y_hat_spec_{}".format(batch_idx), + plot_spectrogram(y_g_hat_mel[0].detach().cpu().numpy()), + params.batch_idx_train, + ) + + if world_size > 1: + tot_loss.reduce(device) + + loss_value = tot_loss["loss_mel_error"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + torch.autograd.set_detect_anomaly(True) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + params.device = device + logging.info(params) + logging.info("About to create model") + + model = get_model(params) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available(params=params, model=model) + + model = model.to(device) + generator = model.generator + msd = model.msd + mpd = model.mpd + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer_g = torch.optim.AdamW( + generator.parameters(), + params.learning_rate, + betas=[params.adam_b1, params.adam_b2], + ) + optimizer_d = torch.optim.AdamW( + itertools.chain(msd.parameters(), mpd.parameters()), + params.learning_rate, + betas=[params.adam_b1, params.adam_b2], + ) + + scheduler_g = torch.optim.lr_scheduler.ExponentialLR( + optimizer_g, gamma=params.lr_decay + ) + scheduler_d = torch.optim.lr_scheduler.ExponentialLR( + optimizer_d, gamma=params.lr_decay + ) + + if checkpoints is not None: + # load state_dict for optimizers + if "optimizer_g" in checkpoints: + logging.info("Loading generator optimizer state dict") + optimizer_g.load_state_dict(checkpoints["optimizer_g"]) + if "optimizer_d" in checkpoints: + logging.info("Loading discriminator optimizer state dict") + optimizer_d.load_state_dict(checkpoints["optimizer_d"]) + + # load state_dict for schedulers + if "scheduler_g" in checkpoints: + logging.info("Loading generator scheduler state dict") + scheduler_g.load_state_dict(checkpoints["scheduler_g"]) + if "scheduler_d" in checkpoints: + 
logging.info("Loading discriminator scheduler state dict") + scheduler_d.load_state_dict(checkpoints["scheduler_d"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + libriheavy = LibriheavyTtsDataModule(args) + + train_cuts = libriheavy.train_cuts_small() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + return False + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + train_dl = libriheavy.train_dataloaders(train_cuts) + + valid_cuts = libriheavy.valid_cuts() + valid_dl = libriheavy.valid_dataloaders(valid_cuts) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + logging.info(f"Start epoch {epoch}") + + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + params.cur_epoch = epoch + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + train_one_epoch( + params=params, + model=model, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint( + filename=filename, + params=params, + model=model, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + if params.batch_idx_train % params.save_every_n == 0: + filename = params.exp_dir / f"checkpoint-{params.batch_idx_train}.pt" + save_checkpoint( + filename=filename, + params=params, + model=model, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + if rank == 0: + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def main(): + parser = get_parser() + LibriheavyTtsDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +if __name__ == "__main__": + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + main() diff --git a/egs/libriheavy/TTS/hifigan/tts_datamodule.py 
b/egs/libriheavy/TTS/hifigan/tts_datamodule.py new file mode 100644 index 0000000000..f2a630665f --- /dev/null +++ b/egs/libriheavy/TTS/hifigan/tts_datamodule.py @@ -0,0 +1,378 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022-2024 Xiaomi Corporation (Authors: Mingshuang Luo, +# Zengwei Yao, +# Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from lhotse import CutSet, Fbank, FbankConfig, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + PrecomputedFeatures, + SimpleCutSampler, + SpecAugment, + SpeechSynthesisDataset, +) +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class LibriheavyTtsDataModule: + """ + DataModule for tts experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="TTS data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. 
You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['cut'] with the cuts that " + "were used to construct it.", + ) + group.add_argument( + "--return-text", + type=str2bool, + default=True, + help="Whether to return the transcript of the audio.", + ) + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--sampling-rate", + type=int, + default=16000, + help="The sampleing rate of ljspeech dataset", + ) + + group.add_argument( + "--frame-shift", + type=int, + default=256, + help="Frame shift.", + ) + + group.add_argument( + "--frame-length", + type=int, + default=1024, + help="Frame shift.", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + group.add_argument( + "--use-fft-mag", + type=str2bool, + default=True, + help="Whether to use magnitude of fbank, false to use power energy.", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. 
+ """ + logging.info("About to create train dataset") + train = SpeechSynthesisDataset( + return_text=self.args.return_text, + return_tokens=False, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + sampling_rate = self.args.sampling_rate + config = FbankConfig( + sampling_rate=sampling_rate, + frame_length=self.args.frame_length / sampling_rate, # (in second), + frame_shift=self.args.frame_shift / sampling_rate, # (in second) + use_fft_mag=self.args.use_fft_mag, + ) + train = SpeechSynthesisDataset( + return_text=self.args.return_text, + return_tokens=False, + feature_input_strategy=OnTheFlyFeatures(Fbank(config)), + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + sampling_rate = self.args.sampling_rate + config = FbankConfig( + sampling_rate=sampling_rate, + frame_length=self.args.frame_length / sampling_rate, # (in second), + frame_shift=self.args.frame_shift / sampling_rate, # (in second) + use_fft_mag=self.args.use_fft_mag, + ) + validate = SpeechSynthesisDataset( + return_text=self.args.return_text, + return_tokens=False, + feature_input_strategy=OnTheFlyFeatures(Fbank(config)), + return_cuts=self.args.return_cuts, + ) + else: + validate = SpeechSynthesisDataset( + return_text=self.args.return_text, + return_tokens=False, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + num_buckets=self.args.num_buckets, + shuffle=False, + ) + logging.info("About to create valid dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.info("About to create test dataset") + if self.args.on_the_fly_feats: + sampling_rate = self.args.sampling_rate + config = FbankConfig( + sampling_rate=sampling_rate, + frame_length=self.args.frame_length / sampling_rate, # (in second), + frame_shift=self.args.frame_shift / sampling_rate, # (in second) + use_fft_mag=self.args.use_fft_mag, + ) + test = SpeechSynthesisDataset( + return_text=self.args.return_text, + 
return_tokens=False, + feature_input_strategy=OnTheFlyFeatures(Fbank(config)), + return_cuts=self.args.return_cuts, + ) + else: + test = SpeechSynthesisDataset( + return_text=self.args.return_text, + return_tokens=False, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + test_sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + num_buckets=self.args.num_buckets, + shuffle=False, + ) + logging.info("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=test_sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_cuts_small(self) -> CutSet: + logging.info("About to get train cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libriheavy_cuts_small.jsonl.gz" + ) + + @lru_cache() + def valid_cuts(self) -> CutSet: + logging.info("About to get validation cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libriheavy_cuts_dev.jsonl.gz" + ) + + @lru_cache() + def train_cuts_finetune(self) -> CutSet: + logging.info("About to get train cuts finetune") + return load_manifest_lazy( + self.args.manifest_dir / "libriheavy_cuts_small_finetune.jsonl.gz" + ) + + @lru_cache() + def valid_cuts_finetune(self) -> CutSet: + logging.info("About to get validation cuts finetune") + return load_manifest_lazy( + self.args.manifest_dir / "libriheavy_cuts_dev_finetune.jsonl.gz" + ) + + @lru_cache() + def test_cuts(self) -> CutSet: + logging.info("About to get test cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libriheavy_cuts_test_clean.jsonl.gz" + ) diff --git a/egs/libriheavy/TTS/hifigan/utils.py b/egs/libriheavy/TTS/hifigan/utils.py new file mode 120000 index 0000000000..81ca818c1a --- /dev/null +++ b/egs/libriheavy/TTS/hifigan/utils.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/hifigan/utils.py \ No newline at end of file diff --git a/egs/libriheavy/TTS/local/compute_fbank_libriheavy.py b/egs/libriheavy/TTS/local/compute_fbank_libriheavy.py new file mode 100755 index 0000000000..16b3d668dd --- /dev/null +++ b/egs/libriheavy/TTS/local/compute_fbank_libriheavy.py @@ -0,0 +1,286 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang, +# Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of the Libriheavy dataset. +It looks for manifests in the directory data/manifests. + +The generated fbank features are saved in data/fbank. +""" + +import argparse +import logging +import os +from pathlib import Path +from typing import Optional + +import torch +from lhotse import ( + CutSet, + Fbank, + FbankConfig, + KaldifeatFbank, + KaldifeatFbankConfig, + LilcomChunkyWriter, +) + +from icefall.utils import get_executor, str2bool + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. 
+# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--manifest-dir", + type=str, + help="""The source directory that contains raw manifests. + """, + default="data/manifests", + ) + + parser.add_argument( + "--fbank-dir", + type=str, + help="""Fbank output dir + """, + default="data/fbank", + ) + + parser.add_argument( + "--sampling-rate", + type=int, + default=16000, + ) + + parser.add_argument( + "--num-mel-bins", + type=int, + default=80, + ) + + parser.add_argument( + "--frame-length", + type=int, + default=1024, + ) + + parser.add_argument( + "--frame-shift", + type=int, + default=256, + ) + + parser.add_argument( + "--use-fft-mag", + type=str2bool, + default=True, + ) + + parser.add_argument( + "--subset", + type=str, + help="""Dataset parts to compute fbank. If None, we will use all""", + ) + + parser.add_argument( + "--num-workers", + type=int, + default=20, + help="Number of dataloading workers used for reading the audio.", + ) + + parser.add_argument( + "--batch-duration", + type=float, + default=600.0, + help="The maximum number of audio seconds in a batch." + "Determines batch size dynamically.", + ) + + parser.add_argument( + "--use-splits", + type=str2bool, + default=False, + help="Whether to compute fbank on splits.", + ) + + parser.add_argument( + "--num-splits", + type=int, + help="""The number of splits of the medium and large subset. + Only needed when --use-splits is true.""", + ) + + parser.add_argument( + "--start", + type=int, + default=0, + help="""Process pieces starting from this number (inclusive). + Only needed when --use-splits is true.""", + ) + + parser.add_argument( + "--stop", + type=int, + default=-1, + help="""Stop processing pieces until this number (exclusive). + Only needed when --use-splits is true.""", + ) + + return parser.parse_args() + + +def compute_fbank_libriheavy(args): + src_dir = Path(args.manifest_dir) + output_dir = Path(args.fbank_dir) + num_jobs = min(15, os.cpu_count()) + num_mel_bins = args.num_mel_bins + subset = args.subset + + sampling_rate = args.sampling_rate + frame_length = args.frame_length / sampling_rate # (in second) + frame_shift = args.frame_shift / sampling_rate # (in second) + use_fft_mag = args.use_fft_mag + + config = FbankConfig( + sampling_rate=sampling_rate, + frame_length=frame_length, + frame_shift=frame_shift, + use_fft_mag=use_fft_mag, + num_mel_bins=num_mel_bins, + ) + extractor = Fbank(config) + + with get_executor() as ex: # Initialize the executor only once. + output_cuts_path = output_dir / f"libriheavy_cuts_{subset}.jsonl.gz" + if output_cuts_path.exists(): + logging.info(f"{output_cuts_path} exists - skipping") + return + + input_cuts_path = src_dir / f"libriheavy_cuts_{subset}.jsonl.gz" + assert input_cuts_path.exists(), f"{input_cuts_path} does not exist!" 
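+        # The manifest is loaded into a CutSet below, then compute_and_store_features()
+        # runs the Fbank extractor configured above (frame_length/frame_shift were
+        # converted from samples to seconds), stores the feature arrays with
+        # LilcomChunkyWriter under output_dir, and returns a cut set referencing the
+        # stored features, which is finally written to output_cuts_path.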
+ logging.info(f"Loading {input_cuts_path}") + cut_set = CutSet.from_file(input_cuts_path) + + logging.info("Computing features") + + cut_set = cut_set.compute_and_store_features( + extractor=extractor, + storage_path=f"{output_dir}/libriheavy_feats_{subset}", + # when an executor is specified, make more partitions + num_jobs=num_jobs if ex is None else 80, + executor=ex, + storage_type=LilcomChunkyWriter, + ) + + logging.info(f"Saving to {output_cuts_path}") + cut_set.to_file(output_cuts_path) + + +def compute_fbank_libriheavy_splits(args): + num_splits = args.num_splits + subset = args.subset + src_dir = f"{args.manifest_dir}/libriheavy_{subset}_split" + src_dir = Path(src_dir) + output_dir = f"{args.fbank_dir}/libriheavy_{subset}_split" + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + start = args.start + stop = args.stop + if stop < start: + stop = num_splits + + stop = min(stop, num_splits) + + num_mel_bins = args.num_mel_bins + sampling_rate = args.sampling_rate + frame_length = args.frame_length / sampling_rate # (in second) + frame_shift = args.frame_shift / sampling_rate # (in second) + use_fft_mag = args.use_fft_mag + + config = FbankConfig( + sampling_rate=sampling_rate, + frame_length=frame_length, + frame_shift=frame_shift, + use_fft_mag=use_fft_mag, + num_mel_bins=num_mel_bins, + ) + extractor = Fbank(config) + + num_digits = 8 # num_digits is fixed by lhotse split-lazy + for i in range(start, stop): + idx = f"{i + 1}".zfill(num_digits) + logging.info(f"Processing {idx}/{num_splits}") + + cuts_path = output_dir / f"libriheavy_cuts_{subset}.{idx}.jsonl.gz" + if cuts_path.is_file(): + logging.info(f"{cuts_path} exists - skipping") + continue + + raw_cuts_path = src_dir / f"libriheavy_cuts_{subset}.{idx}.jsonl.gz" + if not raw_cuts_path.is_file(): + logging.info(f"{raw_cuts_path} does not exist - skipping it") + continue + + logging.info(f"Loading {raw_cuts_path}") + cut_set = CutSet.from_file(raw_cuts_path) + + logging.info("Computing features") + if (output_dir / f"libriheavy_feats_{subset}_{idx}.lca").exists(): + logging.info(f"Removing {output_dir}/libriheavy_feats_{subset}_{idx}.lca") + os.remove(output_dir / f"libriheavy_feats_{subset}_{idx}.lca") + + cut_set = cut_set.compute_and_store_features_batch( + extractor=extractor, + storage_path=f"{output_dir}/libriheavy_feats_{subset}_{idx}", + num_workers=args.num_workers, + batch_duration=args.batch_duration, + overwrite=True, + ) + + logging.info("About to split cuts into smaller chunks.") + cut_set = cut_set.trim_to_supervisions( + keep_overlapping=False, min_duration=None + ) + + logging.info(f"Saving to {cuts_path}") + cut_set.to_file(cuts_path) + logging.info(f"Saved to {cuts_path}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + args = get_args() + logging.info(vars(args)) + + if args.use_splits: + assert args.num_splits is not None, "Please provide num_splits" + compute_fbank_libriheavy_splits(args) + else: + compute_fbank_libriheavy(args)