diff --git a/egs/ljspeech/TTS/matcha/inference.py b/egs/ljspeech/TTS/matcha/inference.py
new file mode 100755
index 0000000000..29a0f53a83
--- /dev/null
+++ b/egs/ljspeech/TTS/matcha/inference.py
@@ -0,0 +1,178 @@
+#!/usr/bin/env python3
+
+import argparse
+import datetime as dt
+import logging
+from pathlib import Path
+
+import numpy as np
+import soundfile as sf
+import torch
+from matcha.hifigan.config import v1
+from matcha.hifigan.denoiser import Denoiser
+from matcha.hifigan.models import Generator as HiFiGAN
+from matcha.text import sequence_to_text, text_to_sequence
+from matcha.utils.utils import intersperse
+from tqdm.auto import tqdm
+from train import get_model, get_params
+
+from icefall.checkpoint import load_checkpoint
+from icefall.utils import AttributeDict
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=140,
+        help="""It specifies the checkpoint to use for decoding.
+        Note: Epoch counts from 1.
+        """,
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=Path,
+        default="matcha/exp",
+        help="""The experiment dir.
+        It specifies the directory where all training related
+        files, e.g., checkpoints, logs, etc., are saved
+        """,
+    )
+
+    return parser
+
+
+def load_vocoder(checkpoint_path):
+    h = AttributeDict(v1)
+    hifigan = HiFiGAN(h).to("cpu")
+    hifigan.load_state_dict(
+        torch.load(checkpoint_path, map_location="cpu")["generator"]
+    )
+    _ = hifigan.eval()
+    hifigan.remove_weight_norm()
+    return hifigan
+
+
+def to_waveform(mel, vocoder, denoiser):
+    audio = vocoder(mel).clamp(-1, 1)
+    # The denoised audio is already moved to CPU and squeezed here.
+    return denoiser(audio.squeeze(0), strength=0.00025).cpu().squeeze()
+
+
+def save_to_folder(filename: str, output: dict, folder: str):
+    folder = Path(folder)
+    folder.mkdir(exist_ok=True, parents=True)
+    np.save(folder / f"{filename}", output["mel"].cpu().numpy())
+    sf.write(folder / f"{filename}.wav", output["waveform"], 22050, "PCM_24")
+
+
+def process_text(text: str):
+    x = torch.tensor(
+        intersperse(text_to_sequence(text, ["english_cleaners2"])[0], 0),
+        dtype=torch.long,
+        device="cpu",
+    )[None]
+    x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device="cpu")
+    x_phones = sequence_to_text(x.squeeze(0).tolist())
+    return {"x_orig": text, "x": x, "x_lengths": x_lengths, "x_phones": x_phones}
+
+
+def synthesise(model, n_timesteps, text, length_scale, temperature, spks=None):
+    text_processed = process_text(text)
+    start_t = dt.datetime.now()
+    output = model.synthesise(
+        text_processed["x"],
+        text_processed["x_lengths"],
+        n_timesteps=n_timesteps,
+        temperature=temperature,
+        spks=spks,
+        length_scale=length_scale,
+    )
+    # merge everything into one dict
+    output.update({"start_t": start_t, **text_processed})
+    return output
+
+
+@torch.inference_mode()
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+    params = get_params()
+
+    params.update(vars(args))
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_model(params)
+    load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+    model.eval()
+
+    vocoder = load_vocoder("/star-fj/fangjun/open-source/Matcha-TTS/generator_v1")
+    denoiser = Denoiser(vocoder, mode="zeros")
+
+    texts = [
+        "The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
+        "Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar.",
+    ]
+
+    # Number of ODE solver steps
+    n_timesteps = 2
+
+    # Changes to the speaking rate
+    length_scale = 1.0
+
+    # Sampling temperature
+    temperature = 0.667
+
+    rtfs = []
+    rtfs_w = []
+    for i, text in enumerate(tqdm(texts)):
+        output = synthesise(
+            model=model,
+            n_timesteps=n_timesteps,
+            text=text,
+            length_scale=length_scale,
+            temperature=temperature,
+        )  # , torch.tensor([15], device=device, dtype=torch.long).unsqueeze(0))
+        output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)
+
+        # Compute the real-time factor (RTF) including HiFi-GAN vocoding
+        t = (dt.datetime.now() - output["start_t"]).total_seconds()
+        rtf_w = t * 22050 / (output["waveform"].shape[-1])
+
+        # Pretty print
+        print(f"{'*' * 53}")
+        print(f"Input text - {i}")
+        print(f"{'-' * 53}")
+        print(output["x_orig"])
+        print(f"{'*' * 53}")
+        print(f"Phonetised text - {i}")
+        print(f"{'-' * 53}")
+        print(output["x_phones"])
+        print(f"{'*' * 53}")
+        print(f"RTF:\t\t{output['rtf']:.6f}")
+        print(f"RTF Waveform:\t{rtf_w:.6f}")
+        rtfs.append(output["rtf"])
+        rtfs_w.append(rtf_w)
+
+        # Save the generated waveform
+        save_to_folder(str(i), output, folder="./my-output")
+
+    print(f"Number of ODE steps: {n_timesteps}")
+    print(f"Mean RTF:\t\t\t\t{np.mean(rtfs):.6f} ± {np.std(rtfs):.6f}")
+    print(
+        f"Mean RTF Waveform (incl. vocoder):\t{np.mean(rtfs_w):.6f} ± {np.std(rtfs_w):.6f}"
+    )
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
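A note on the RTF numbers printed by inference.py: `output["rtf"]`, reported by `model.synthesise`, covers the acoustic model alone, while `rtf_w` additionally includes HiFi-GAN vocoding and the denoiser. Both divide wall-clock synthesis time by the duration of the generated audio, so values below 1.0 mean faster than real time. A minimal sketch of the arithmetic, with made-up numbers (22050 is the LJSpeech sample rate used throughout this recipe):

    import math

    elapsed_seconds = 1.2    # wall-clock time spent synthesising (illustrative)
    num_samples = 6 * 22050  # a 6-second waveform at 22050 Hz
    audio_seconds = num_samples / 22050    # 6.0
    rtf = elapsed_seconds / audio_seconds  # 0.2, i.e. 5x faster than real time
    # The same quantity, written the way inference.py computes rtf_w:
    assert math.isclose(rtf, elapsed_seconds * 22050 / num_samples)
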
diff --git a/egs/ljspeech/TTS/matcha/models/matcha_tts.py b/egs/ljspeech/TTS/matcha/models/matcha_tts.py
index d4b1c57ab6..d5d78c6196 100644
--- a/egs/ljspeech/TTS/matcha/models/matcha_tts.py
+++ b/egs/ljspeech/TTS/matcha/models/matcha_tts.py
@@ -5,6 +5,7 @@
 import torch
 
 import matcha.utils.monotonic_align as monotonic_align
+
 # from matcha import utils
 # from matcha.models.baselightningmodule import BaseLightningClass
 from matcha.models.components.flow_matching import CFM
@@ -30,7 +31,7 @@ def __init__(
         encoder,
         decoder,
         cfm,
-        # data_statistics,
+        data_statistics,
         out_size,
         optimizer=None,
         scheduler=None,
@@ -71,9 +72,13 @@ def __init__(
         )
 
         # self.update_data_statistics(data_statistics)
+        self.register_buffer("mel_mean", torch.tensor(data_statistics["mel_mean"]))
+        self.register_buffer("mel_std", torch.tensor(data_statistics["mel_std"]))
 
     @torch.inference_mode()
-    def synthesise(self, x, x_lengths, n_timesteps, temperature=1.0, spks=None, length_scale=1.0):
+    def synthesise(
+        self, x, x_lengths, n_timesteps, temperature=1.0, spks=None, length_scale=1.0
+    ):
         """
         Generates mel-spectrogram from text. Returns:
         1. encoder outputs
@@ -149,7 +154,17 @@ def synthesise(self, x, x_lengths, n_timesteps, temperature=1.0, spks=None, leng
             "rtf": rtf,
         }
 
-    def forward(self, x, x_lengths, y, y_lengths, spks=None, out_size=None, cond=None, durations=None):
+    def forward(
+        self,
+        x,
+        x_lengths,
+        y,
+        y_lengths,
+        spks=None,
+        out_size=None,
+        cond=None,
+        durations=None,
+    ):
         """
         Computes 3 losses:
             1. duration loss: loss between predicted token durations and those extracted by Monotonic Alignment Search (MAS).
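The hunk that follows only reflows the alignment-likelihood computation, but what it computes is worth spelling out: for every text position i and mel frame j it evaluates log N(y_j; mu_i, I), expanding -0.5 * ||y_j - mu_i||^2 into a sum of squares so that all (i, j) pairs fall out of two matrix multiplications. A self-contained sketch of that identity (shapes and values below are illustrative, not the recipe's API):

    import math

    import torch

    n_feats, T_text, T_mel = 80, 5, 7
    mu_x = torch.randn(1, n_feats, T_text)  # text-side Gaussian means
    y = torch.randn(1, n_feats, T_mel)      # target mel frames

    const = -0.5 * math.log(2 * math.pi) * n_feats
    factor = -0.5 * torch.ones_like(mu_x)
    log_prior = (
        torch.matmul(factor.transpose(1, 2), y**2)                # -0.5 * sum(y^2)
        - torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y)  # + sum(mu * y)
        + torch.sum(factor * (mu_x**2), 1).unsqueeze(-1)          # -0.5 * sum(mu^2)
        + const
    )  # shape: (1, T_text, T_mel)

    # The same quantity computed naively, one (i, j) pair at a time:
    diff = y.transpose(1, 2)[:, None, :, :] - mu_x.transpose(1, 2)[:, :, None, :]
    naive = -0.5 * (diff**2).sum(-1) + const
    assert torch.allclose(log_prior, naive, atol=1e-4)
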
@@ -187,7 +202,9 @@ def forward(self, x, x_lengths, y, y_lengths, spks=None, out_size=None, cond=Non
         # Use MAS to find most likely alignment `attn` between text and mel-spectrogram
         with torch.no_grad():
             const = -0.5 * math.log(2 * math.pi) * self.n_feats
-            factor = -0.5 * torch.ones(mu_x.shape, dtype=mu_x.dtype, device=mu_x.device)
+            factor = -0.5 * torch.ones(
+                mu_x.shape, dtype=mu_x.dtype, device=mu_x.device
+            )
             y_square = torch.matmul(factor.transpose(1, 2), y**2)
             y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y)
             mu_square = torch.sum(factor * (mu_x**2), 1).unsqueeze(-1)
@@ -206,12 +223,25 @@ def forward(self, x, x_lengths, y, y_lengths, spks=None, out_size=None, cond=Non
         #   - Do not need this hack for Matcha-TTS, but it works with it as well
         if not isinstance(out_size, type(None)):
             max_offset = (y_lengths - out_size).clamp(0)
-            offset_ranges = list(zip([0] * max_offset.shape[0], max_offset.cpu().numpy()))
+            offset_ranges = list(
+                zip([0] * max_offset.shape[0], max_offset.cpu().numpy())
+            )
             out_offset = torch.LongTensor(
-                [torch.tensor(random.choice(range(start, end)) if end > start else 0) for start, end in offset_ranges]
+                [
+                    torch.tensor(random.choice(range(start, end)) if end > start else 0)
+                    for start, end in offset_ranges
+                ]
             ).to(y_lengths)
-            attn_cut = torch.zeros(attn.shape[0], attn.shape[1], out_size, dtype=attn.dtype, device=attn.device)
-            y_cut = torch.zeros(y.shape[0], self.n_feats, out_size, dtype=y.dtype, device=y.device)
+            attn_cut = torch.zeros(
+                attn.shape[0],
+                attn.shape[1],
+                out_size,
+                dtype=attn.dtype,
+                device=attn.device,
+            )
+            y_cut = torch.zeros(
+                y.shape[0], self.n_feats, out_size, dtype=y.dtype, device=y.device
+            )
 
             y_cut_lengths = []
             for i, (y_, out_offset_) in enumerate(zip(y, out_offset)):
@@ -233,12 +263,36 @@ def forward(self, x, x_lengths, y, y_lengths, spks=None, out_size=None, cond=Non
         mu_y = mu_y.transpose(1, 2)
 
         # Compute loss of the decoder
-        diff_loss, _ = self.decoder.compute_loss(x1=y, mask=y_mask, mu=mu_y, spks=spks, cond=cond)
+        diff_loss, _ = self.decoder.compute_loss(
+            x1=y, mask=y_mask, mu=mu_y, spks=spks, cond=cond
+        )
 
         if self.prior_loss:
-            prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask)
+            prior_loss = torch.sum(
+                0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask
+            )
             prior_loss = prior_loss / (torch.sum(y_mask) * self.n_feats)
         else:
             prior_loss = 0
 
         return dur_loss, prior_loss, diff_loss, attn
+
+    def get_losses(self, batch):
+        x, x_lengths = batch["x"], batch["x_lengths"]
+        y, y_lengths = batch["y"], batch["y_lengths"]
+        spks = batch["spks"]
+
+        dur_loss, prior_loss, diff_loss, *_ = self(
+            x=x,
+            x_lengths=x_lengths,
+            y=y,
+            y_lengths=y_lengths,
+            spks=spks,
+            out_size=self.out_size,
+            durations=batch["durations"],
+        )
+        return {
+            "dur_loss": dur_loss,
+            "prior_loss": prior_loss,
+            "diff_loss": diff_loss,
+        }
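On the `data_statistics` change in matcha_tts.py above: the upstream LightningModule helper (`self.update_data_statistics(...)`, left commented out) is replaced by registering `mel_mean` and `mel_std` as buffers, so the statistics follow `model.to(device)` and are persisted in `state_dict()`. A minimal illustration of why a buffer is preferable to a plain attribute; the `MelNorm` module below is hypothetical, not part of the recipe:

    import torch


    class MelNorm(torch.nn.Module):
        """Toy module holding dataset-level mel statistics."""

        def __init__(self, mel_mean: float, mel_std: float):
            super().__init__()
            # Buffers are not trained, but they move with .to(device) and are
            # saved/restored via state_dict(), unlike plain Python attributes.
            self.register_buffer("mel_mean", torch.tensor(mel_mean))
            self.register_buffer("mel_std", torch.tensor(mel_std))

        def denormalize(self, mel: torch.Tensor) -> torch.Tensor:
            # Assumed convention: inverts (mel - mean) / std normalization.
            return mel * self.mel_std + self.mel_mean


    norm = MelNorm(mel_mean=-5.517028331756592, mel_std=2.0643954277038574)
    print(norm.state_dict())  # contains both mel_mean and mel_std
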
diff --git a/egs/ljspeech/TTS/matcha/test-train.py b/egs/ljspeech/TTS/matcha/test-train.py
new file mode 100644
index 0000000000..f41ee4eae1
--- /dev/null
+++ b/egs/ljspeech/TTS/matcha/test-train.py
@@ -0,0 +1,159 @@
+#!/usr/bin/env python3
+# Copyright    2023  Xiaomi Corp.        (authors: Fangjun Kuang)
+
+
+import torch
+
+
+from icefall.utils import AttributeDict
+from matcha.models.matcha_tts import MatchaTTS
+from matcha.data.text_mel_datamodule import TextMelDataModule
+
+
+def _get_data_params() -> AttributeDict:
+    params = AttributeDict(
+        {
+            "name": "ljspeech",
+            "train_filelist_path": "./filelists/ljs_audio_text_train_filelist.txt",
+            "valid_filelist_path": "./filelists/ljs_audio_text_val_filelist.txt",
+            "batch_size": 32,
+            "num_workers": 3,
+            "pin_memory": False,
+            "cleaners": ["english_cleaners2"],
+            "add_blank": True,
+            "n_spks": 1,
+            "n_fft": 1024,
+            "n_feats": 80,
+            "sample_rate": 22050,
+            "hop_length": 256,
+            "win_length": 1024,
+            "f_min": 0,
+            "f_max": 8000,
+            "seed": 1234,
+            "load_durations": False,
+            "data_statistics": AttributeDict(
+                {
+                    "mel_mean": -5.517028331756592,
+                    "mel_std": 2.0643954277038574,
+                }
+            ),
+        }
+    )
+    return params
+
+
+def _get_model_params() -> AttributeDict:
+    n_feats = 80
+    filter_channels_dp = 256
+    encoder_params_p_dropout = 0.1
+    params = AttributeDict(
+        {
+            "n_vocab": 178,
+            "n_spks": 1,  # for ljspeech.
+            "spk_emb_dim": 64,
+            "n_feats": n_feats,
+            "out_size": None,  # or use 172
+            "prior_loss": True,
+            "use_precomputed_durations": False,
+            "encoder": AttributeDict(
+                {
+                    "encoder_type": "RoPE Encoder",  # not used
+                    "encoder_params": AttributeDict(
+                        {
+                            "n_feats": n_feats,
+                            "n_channels": 192,
+                            "filter_channels": 768,
+                            "filter_channels_dp": filter_channels_dp,
+                            "n_heads": 2,
+                            "n_layers": 6,
+                            "kernel_size": 3,
+                            "p_dropout": encoder_params_p_dropout,
+                            "spk_emb_dim": 64,
+                            "n_spks": 1,
+                            "prenet": True,
+                        }
+                    ),
+                    "duration_predictor_params": AttributeDict(
+                        {
+                            "filter_channels_dp": filter_channels_dp,
+                            "kernel_size": 3,
+                            "p_dropout": encoder_params_p_dropout,
+                        }
+                    ),
+                }
+            ),
+            "decoder": AttributeDict(
+                {
+                    "channels": [256, 256],
+                    "dropout": 0.05,
+                    "attention_head_dim": 64,
+                    "n_blocks": 1,
+                    "num_mid_blocks": 2,
+                    "num_heads": 2,
+                    "act_fn": "snakebeta",
+                }
+            ),
+            "cfm": AttributeDict(
+                {
+                    "name": "CFM",
+                    "solver": "euler",
+                    "sigma_min": 1e-4,
+                }
+            ),
+            "optimizer": AttributeDict(
+                {
+                    "lr": 1e-4,
+                    "weight_decay": 0.0,
+                }
+            ),
+        }
+    )
+
+    return params
+
+
+def get_params():
+    params = AttributeDict(
+        {
+            "model": _get_model_params(),
+            "data": _get_data_params(),
+        }
+    )
+    return params
+
+
+def get_model(params):
+    m = MatchaTTS(**params.model)
+    return m
+
+
+def main():
+    params = get_params()
+
+    data_module = TextMelDataModule(hparams=params.data)
+    if False:
+        for b in data_module.train_dataloader():
+            assert isinstance(b, dict)
+            # b.keys()
+            # ['x', 'x_lengths', 'y', 'y_lengths', 'spks', 'filepaths', 'x_texts', 'durations']
+            # x: [batch_size, 289], torch.int64
+            # x_lengths: [batch_size], torch.int64
+            # y: [batch_size, n_feats, num_frames], torch.float32
+            # y_lengths: [batch_size], torch.int64
+            # spks: None
+            # filepaths: list, (batch_size,)
+            # x_texts: list, (batch_size,)
+            # durations: None
+
+    m = get_model(params)
+    print(m)
+
+    num_param = sum([p.numel() for p in m.parameters()])
+    print(f"Number of parameters: {num_param}")
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/ljspeech/TTS/matcha/train.py b/egs/ljspeech/TTS/matcha/train.py
index f41ee4eae1..385dcba23e 100755
--- a/egs/ljspeech/TTS/matcha/train.py
+++ b/egs/ljspeech/TTS/matcha/train.py
@@ -2,12 +2,111 @@
 # Copyright    2023  Xiaomi Corp.        (authors: Fangjun Kuang)
+
+import argparse
+import logging
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Union
+
 import torch
+import torch.nn as nn
+from lhotse.utils import fix_random_seed
+from matcha.data.text_mel_datamodule import TextMelDataModule
+from matcha.models.matcha_tts import MatchaTTS
+from torch.cuda.amp import GradScaler, autocast
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.optim import Optimizer
+from torch.utils.tensorboard import SummaryWriter
+from utils2 import MetricsTracker, plot_feature
+
+from icefall.checkpoint import load_checkpoint, save_checkpoint
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.utils import AttributeDict, setup_logger, str2bool
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--tensorboard",
+        type=str2bool,
+        default=True,
+        help="Should various information be logged in tensorboard.",
+    )
+
+    parser.add_argument(
+        "--num-epochs",
+        type=int,
+        default=1000,
+        help="Number of epochs to train.",
+    )
+
+    parser.add_argument(
+        "--start-epoch",
+        type=int,
+        default=1,
+        help="""Resume training from this epoch. It should be positive.
+        If larger than 1, it will load checkpoint from
+        exp-dir/epoch-{start_epoch-1}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=Path,
+        default="matcha/exp",
+        help="""The experiment dir.
+        It specifies the directory where all training related
+        files, e.g., checkpoints, logs, etc., are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=42,
+        help="The seed for random generators intended for reproducibility",
+    )
+
+    parser.add_argument(
+        "--save-every-n",
+        type=int,
+        default=10,
+        help="""Save a checkpoint periodically after processing this number
+        of epochs. We save checkpoints to exp-dir/ whenever
+        params.cur_epoch % save_every_n == 0. The checkpoint filename
+        has the form: f'exp-dir/epoch-{params.cur_epoch}.pt'.
+        Since training takes around 1000 epochs, we suggest using a large
+        save_every_n to save disk space.
+        """,
+    )
+
+    parser.add_argument(
+        "--use-fp16",
+        type=str2bool,
+        default=False,
+        help="Whether to use half precision training.",
+    )
+
+    parser.add_argument(
+        "--batch-size",
+        type=int,
+        default=32,
+        help="The batch size used for training.",
+    )
+
+    return parser
+
 
-from icefall.utils import AttributeDict
-from matcha.models.matcha_tts import MatchaTTS
-from matcha.data.text_mel_datamodule import TextMelDataModule
+
+def get_data_statistics():
+    return AttributeDict(
+        {
+            "mel_mean": -5.517028331756592,
+            "mel_std": 2.0643954277038574,
+        }
+    )
 
 
 def _get_data_params() -> AttributeDict:
@@ -16,7 +115,6 @@
             "name": "ljspeech",
             "train_filelist_path": "./filelists/ljs_audio_text_train_filelist.txt",
             "valid_filelist_path": "./filelists/ljs_audio_text_val_filelist.txt",
-            "batch_size": 32,
             "num_workers": 3,
             "pin_memory": False,
             "cleaners": ["english_cleaners2"],
@@ -31,12 +129,7 @@
             "f_max": 8000,
             "seed": 1234,
             "load_durations": False,
-            "data_statistics": AttributeDict(
-                {
-                    "mel_mean": -5.517028331756592,
-                    "mel_std": 2.0643954277038574,
-                }
-            ),
+            "data_statistics": get_data_statistics(),
         }
     )
     return params
@@ -55,6 +148,7 @@
             "out_size": None,  # or use 172
             "prior_loss": True,
             "use_precomputed_durations": False,
+            "data_statistics": get_data_statistics(),
             "encoder": AttributeDict(
                 {
                     "encoder_type": "RoPE Encoder",  # not used
@@ -115,42 +209,368 @@ def get_params():
     params = AttributeDict(
         {
-            "model": _get_model_params(),
-            "data": _get_data_params(),
+            "model_args": _get_model_params(),
+            "data_args": _get_data_params(),
+            "best_train_loss": float("inf"),
+            "best_valid_loss": float("inf"),
+            "best_train_epoch": -1,
+            "best_valid_epoch": -1,
+            "batch_idx_train": -1,  # 0
+            "log_interval": 50,
+            "valid_interval": 2000,
+            "env_info": get_env_info(),
         }
     )
     return params
 
 
 def get_model(params):
-    m = MatchaTTS(**params.model)
+    m = MatchaTTS(**params.model_args)
     return m
 
 
+def load_checkpoint_if_available(
+    params: AttributeDict, model: nn.Module
+) -> Optional[Dict[str, Any]]:
+    """Load checkpoint from file.
+
+    If params.start_epoch is larger than 1, it will load the checkpoint from
+    `params.start_epoch - 1`.
+
+    Apart from loading the state dict for `model`, it also updates
+    `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+    and `best_valid_loss` in `params`.
+
+    Args:
+      params:
+        The return value of :func:`get_params`.
+      model:
+        The training model.
+    Returns:
+      Return a dict containing previously saved training info.
+    """
+    if params.start_epoch > 1:
+        filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+    else:
+        return None
+
+    assert filename.is_file(), f"{filename} does not exist!"
+
+    saved_params = load_checkpoint(filename, model=model)
+
+    keys = [
+        "best_train_epoch",
+        "best_valid_epoch",
+        "batch_idx_train",
+        "best_train_loss",
+        "best_valid_loss",
+    ]
+    for k in keys:
+        params[k] = saved_params[k]
+
+    return saved_params
+
+
+def compute_validation_loss(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    valid_dl: torch.utils.data.DataLoader,
+    world_size: int = 1,
+    rank: int = 0,
+) -> MetricsTracker:
+    """Run the validation process."""
+    model.eval()
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+
+    # used to summarize the stats over iterations
+    tot_loss = MetricsTracker()
+
+    with torch.no_grad():
+        for batch_idx, batch in enumerate(valid_dl):
+            for key, value in batch.items():
+                if isinstance(value, torch.Tensor):
+                    batch[key] = value.to(device)
+            losses = model.get_losses(batch)
+
+            batch_size = batch["x"].shape[0]
+
+            loss_info = MetricsTracker()
+            loss_info["samples"] = batch_size
+
+            s = 0
+
+            for key, value in losses.items():
+                v = value.detach().item()
+                loss_info[key] = v * batch_size
+                s += v * batch_size
+
+            loss_info["tot_loss"] = s
+
+            # summary stats
+            tot_loss = tot_loss + loss_info
+
+    if world_size > 1:
+        tot_loss.reduce(device)
+
+    loss_value = tot_loss["tot_loss"] / tot_loss["samples"]
+    if loss_value < params.best_valid_loss:
+        params.best_valid_epoch = params.cur_epoch
+        params.best_valid_loss = loss_value
+
+    return tot_loss
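Both `compute_validation_loss` above and `train_one_epoch` below follow the same MetricsTracker convention: each per-batch loss is weighted by the batch size, and an extra `samples` key accumulates the denominator, so `tot_loss["tot_loss"] / tot_loss["samples"]` is a per-utterance mean. A small self-contained sketch of that bookkeeping (MetricsTracker itself comes from `utils2.py`, i.e. `../vits/utils.py`; the plain dict below is a hypothetical stand-in):

    from collections import defaultdict

    tot_loss = defaultdict(float)  # stand-in for MetricsTracker
    batches = [
        (32, {"dur_loss": 0.50, "prior_loss": 1.00, "diff_loss": 2.00}),
        (16, {"dur_loss": 0.70, "prior_loss": 0.90, "diff_loss": 1.80}),
    ]
    for batch_size, losses in batches:
        tot_loss["samples"] += batch_size
        for key, value in losses.items():
            tot_loss[key] += value * batch_size
        tot_loss["tot_loss"] += sum(losses.values()) * batch_size

    # Batch-size-weighted mean, as used for best-train/valid-loss tracking:
    mean_loss = tot_loss["tot_loss"] / tot_loss["samples"]  # (32*3.5 + 16*3.4) / 48 ≈ 3.467
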
+
+
+def train_one_epoch(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    optimizer: Optimizer,
+    train_dl: torch.utils.data.DataLoader,
+    valid_dl: torch.utils.data.DataLoader,
+    scaler: GradScaler,
+    tb_writer: Optional[SummaryWriter] = None,
+    world_size: int = 1,
+    rank: int = 0,
+) -> None:
+    """Train the model for one epoch.
+
+    The training loss from the mean of all frames is saved in
+    `params.train_loss`. It runs the validation process every
+    `params.valid_interval` batches.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The model for training.
+      optimizer:
+        The optimizer.
+      train_dl:
+        Dataloader for the training dataset.
+      valid_dl:
+        Dataloader for the validation dataset.
+      scaler:
+        The scaler used for mixed precision training.
+      tb_writer:
+        Writer to write log messages to tensorboard.
+    """
+    model.train()
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+
+    # used to track the stats over iterations in one epoch
+    tot_loss = MetricsTracker()
+
+    saved_bad_model = False
+
+    def save_bad_model(suffix: str = ""):
+        save_checkpoint(
+            filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
+            model=model,
+            params=params,
+            optimizer=optimizer,
+            scaler=scaler,
+            rank=rank,
+        )
+
+    for batch_idx, batch in enumerate(train_dl):
+        params.batch_idx_train += 1
+        for key, value in batch.items():
+            if isinstance(value, torch.Tensor):
+                batch[key] = value.to(device)
+
+        batch_size = batch["x"].shape[0]
+
+        try:
+            with autocast(enabled=params.use_fp16):
+                losses = model.get_losses(batch)
+
+            loss = sum(losses.values())
+
+            optimizer.zero_grad()
+            scaler.scale(loss).backward()
+            scaler.step(optimizer)
+            scaler.update()  # adjust the loss scale for the next iteration
+
+            loss_info = MetricsTracker()
+            loss_info["samples"] = batch_size
+
+            s = 0
+
+            for key, value in losses.items():
+                v = value.detach().item()
+                loss_info[key] = v * batch_size
+                s += v * batch_size
+
+            loss_info["tot_loss"] = s
+
+            tot_loss = tot_loss + loss_info
+        except:  # noqa
+            save_bad_model()
+            raise
+
+        if params.batch_idx_train % 100 == 0 and params.use_fp16:
+            # If the grad scale was less than 1, try increasing it. The _growth_interval
+            # of the grad scaler is configurable, but we can't configure it to have
+            # different behavior depending on the current grad scale.
+            cur_grad_scale = scaler._scale.item()
+
+            if cur_grad_scale < 8.0 or (
+                cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0
+            ):
+                scaler.update(cur_grad_scale * 2.0)
+            if cur_grad_scale < 0.01:
+                if not saved_bad_model:
+                    save_bad_model(suffix="-first-warning")
+                    saved_bad_model = True
+                logging.warning(f"Grad scale is small: {cur_grad_scale}")
+                if cur_grad_scale < 1.0e-05:
+                    save_bad_model()
+                    raise RuntimeError(
+                        f"grad_scale is too small, exiting: {cur_grad_scale}"
+                    )
+
+        if params.batch_idx_train % params.log_interval == 0:
+            cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+            logging.info(
+                f"Epoch {params.cur_epoch}, batch {batch_idx}, "
+                f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, "
+                f"loss[{loss_info}], tot_loss[{tot_loss}], "
+                + (f"grad_scale: {cur_grad_scale}" if params.use_fp16 else "")
+            )
+
+            if tb_writer is not None:
+                loss_info.write_summary(
+                    tb_writer, "train/current_", params.batch_idx_train
+                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+                if params.use_fp16:
+                    tb_writer.add_scalar(
+                        "train/grad_scale", cur_grad_scale, params.batch_idx_train
+                    )
+
+        if params.batch_idx_train % params.valid_interval == 1:
+            logging.info("Computing validation loss")
+            valid_info = compute_validation_loss(
+                params=params,
+                model=model,
+                valid_dl=valid_dl,
+                world_size=world_size,
+                rank=rank,
+            )
+            model.train()
+            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+            logging.info(
+                f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+            )
+            if tb_writer is not None:
+                valid_info.write_summary(
+                    tb_writer, "train/valid_", params.batch_idx_train
+                )
+
+    loss_value = tot_loss["tot_loss"] / tot_loss["samples"]
+    params.train_loss = loss_value
+    if params.train_loss < params.best_train_loss:
+        params.best_train_epoch = params.cur_epoch
+        params.best_train_loss = params.train_loss
+
+
 def main():
+    parser = get_parser()
+    args = parser.parse_args()
 
     params = get_params()
-
-    data_module = TextMelDataModule(hparams=params.data)
-    if False:
-        for b in data_module.train_dataloader():
-            assert isinstance(b, dict)
-            # b.keys()
-            # ['x', 'x_lengths', 'y', 'y_lengths', 'spks', 'filepaths', 'x_texts', 'durations']
-            # x: [batch_size, 289], torch.int64
-            # x_lengths: [batch_size], torch.int64
-            # y: [batch_size, n_feats, num_frames], torch.float32
-            # y_lengths: [batch_size], torch.int64
-            # spks: None
-            # filepaths: list, (batch_size,)
-            # x_texts: list, (batch_size,)
-            # durations: None
-
-    m = get_model(params)
-    print(m)
-
-    num_param = sum([p.numel() for p in m.parameters()])
+    params.update(vars(args))
+
+    params.data_args.batch_size = params.batch_size
+    del params.batch_size
+
+    fix_random_seed(params.seed)
+
+    setup_logger(f"{params.exp_dir}/log/log-train")
+    logging.info("Training started")
+    tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+    logging.info(f"Device: {device}")
+    print(f"Device: {device}")
+
+    logging.info(params)
+    print(params)
+
+    logging.info("About to create model")
+    model = get_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of parameters: {num_param}")
     print(f"Number of parameters: {num_param}")
 
+    logging.info("About to create datamodule")
+    data_module = TextMelDataModule(hparams=params.data_args)
+
+    assert params.start_epoch > 0, params.start_epoch
+    checkpoints = load_checkpoint_if_available(params=params, model=model)
+
+    model.to(device)
+    optimizer = torch.optim.Adam(model.parameters(), **params.model_args.optimizer)
+
+    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+    if checkpoints and "grad_scaler" in checkpoints:
+        logging.info("Loading grad scaler state dict")
+        scaler.load_state_dict(checkpoints["grad_scaler"])
+
+    train_dl = data_module.train_dataloader()
+    valid_dl = data_module.val_dataloader()
+
+    rank = 0
+
+    for epoch in range(params.start_epoch, params.num_epochs + 1):
+        logging.info(f"Start epoch {epoch}")
+        fix_random_seed(params.seed + epoch - 1)
+
+        params.cur_epoch = epoch
+
+        if tb_writer is not None:
+            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+        train_one_epoch(
+            params=params,
+            model=model,
+            optimizer=optimizer,
+            train_dl=train_dl,
+            valid_dl=valid_dl,
+            scaler=scaler,
+            tb_writer=tb_writer,
+        )
+
+        if epoch % params.save_every_n == 0 or epoch == params.num_epochs:
+            filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+            save_checkpoint(
+                filename=filename,
+                params=params,
+                model=model,
+                optimizer=optimizer,
+                scaler=scaler,
+                rank=rank,
+            )
+            if rank == 0:
+                if params.best_train_epoch == params.cur_epoch:
+                    best_train_filename = params.exp_dir / "best-train-loss.pt"
+                    copyfile(src=filename, dst=best_train_filename)
+
+                if params.best_valid_epoch == params.cur_epoch:
+                    best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+                    copyfile(src=filename, dst=best_valid_filename)
+
+    logging.info("Done!")
+
 
 torch.set_num_threads(1)
 torch.set_num_interop_threads(1)
diff --git a/egs/ljspeech/TTS/matcha/utils2.py b/egs/ljspeech/TTS/matcha/utils2.py
new file mode 120000
index 0000000000..c2144f8e07
--- /dev/null
+++ b/egs/ljspeech/TTS/matcha/utils2.py
@@ -0,0 +1 @@
+../vits/utils.py
\ No newline at end of file