diff --git a/models/nsf_HiFigan/models.py b/models/nsf_HiFigan/models.py
index 1bbd8bf..fcc44ee 100644
--- a/models/nsf_HiFigan/models.py
+++ b/models/nsf_HiFigan/models.py
@@ -3,9 +3,21 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d
-from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
+
+# from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
 
 LRELU_SLOPE = 0.1
+_OLD_WEIGHT_NORM = False
+try:
+    from torch.nn.utils.parametrizations import weight_norm
+    # from torch.nn.utils.parametrizations import spectral_norm
+except ImportError:
+    from torch.nn.utils import weight_norm
+    from torch.nn.utils import remove_weight_norm
+
+
+    _OLD_WEIGHT_NORM = True
+from torch.nn.utils import spectral_norm
 
 
 class AttrDict(dict):
@@ -23,6 +35,7 @@ def init_weights(m, mean=0.0, std=0.01):
 def get_padding(kernel_size, dilation=1):
     return int((kernel_size * dilation - dilation) / 2)
 
+
 class ResBlock1(torch.nn.Module):
     def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
         super(ResBlock1, self).__init__()
@@ -57,10 +70,17 @@ def forward(self, x):
         return x
 
     def remove_weight_norm(self):
-        for l in self.convs1:
-            remove_weight_norm(l)
-        for l in self.convs2:
-            remove_weight_norm(l)
+        global _OLD_WEIGHT_NORM
+        if _OLD_WEIGHT_NORM:
+            for l in self.convs1:
+                remove_weight_norm(l)
+            for l in self.convs2:
+                remove_weight_norm(l)
+        else:
+            for l in self.convs1:
+                torch.nn.utils.parametrize.remove_parametrizations(l, 'weight')
+            for l in self.convs2:
+                torch.nn.utils.parametrize.remove_parametrizations(l, 'weight')
 
 
 class ResBlock2(torch.nn.Module):
@@ -83,8 +103,15 @@ def forward(self, x):
         return x
 
     def remove_weight_norm(self):
-        for l in self.convs:
-            remove_weight_norm(l)
+
+        global _OLD_WEIGHT_NORM
+        if _OLD_WEIGHT_NORM:
+            for l in self.convs:
+                remove_weight_norm(l)
+
+        else:
+            for l in self.convs:
+                torch.nn.utils.parametrize.remove_parametrizations(l, 'weight')
 
 
 class SineGen(torch.nn.Module):
@@ -277,12 +304,23 @@ def forward(self, x, f0):
     def remove_weight_norm(self):
         # rank_zero_info('Removing weight norm...')
         print('Removing weight norm...')
-        for l in self.ups:
-            remove_weight_norm(l)
-        for l in self.resblocks:
-            l.remove_weight_norm()
-        remove_weight_norm(self.conv_pre)
-        remove_weight_norm(self.conv_post)
+        global _OLD_WEIGHT_NORM
+        if _OLD_WEIGHT_NORM:
+            for l in self.ups:
+                remove_weight_norm(l)
+            for l in self.resblocks:
+                l.remove_weight_norm()
+
+            remove_weight_norm(self.conv_pre)
+            remove_weight_norm(self.conv_post)
+        else:
+            for l in self.ups:
+                torch.nn.utils.parametrize.remove_parametrizations(l, 'weight')
+            for l in self.resblocks:
+                l.remove_weight_norm()
+
+            torch.nn.utils.parametrize.remove_parametrizations(self.conv_pre, 'weight')
+            torch.nn.utils.parametrize.remove_parametrizations(self.conv_post, 'weight')
 
 
 class DiscriminatorP(torch.nn.Module):
@@ -372,14 +410,12 @@ def forward(self, y):
         fmap_rs = []
 
-
         for i, d in enumerate(self.discriminators):
             y_d_r, fmap_r = d(y)
             y_d_rs.append(y_d_r)
             fmap_rs.append(fmap_r)
 
-
         return y_d_rs, fmap_rs,
@@ -444,7 +480,6 @@ def forward(self, y):
             y_d_rs.append(y_d_r)
             fmap_rs.append(fmap_r)
 
-
         return y_d_rs, fmap_rs,
@@ -464,7 +499,7 @@ def discriminator_loss(disc_real_outputs, disc_generated_outputs):
     for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
         r_loss = torch.mean((1 - dr) ** 2)
-        g_loss = torch.mean(dg**2)
+        g_loss = torch.mean(dg ** 2)
         loss += r_loss + g_loss
         r_losses.append(r_loss.item())
        g_losses.append(g_loss.item())
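
Note on the hunks above: PyTorch 2.1 moved weight normalization to torch.nn.utils.parametrizations.weight_norm and deprecated the hook-based torch.nn.utils.weight_norm; under the new API the norm is undone with torch.nn.utils.parametrize.remove_parametrizations(module, 'weight'), which requires the tensor name as its second argument. A minimal standalone sketch of the compatibility shim this patch applies (the Conv1d shapes are illustrative; only torch is assumed):

    import torch
    import torch.nn as nn

    _OLD_WEIGHT_NORM = False
    try:
        # PyTorch >= 2.1: weight norm is implemented as a parametrization
        from torch.nn.utils.parametrizations import weight_norm
    except ImportError:
        # Older PyTorch: legacy hook-based API
        from torch.nn.utils import weight_norm, remove_weight_norm
        _OLD_WEIGHT_NORM = True

    conv = weight_norm(nn.Conv1d(4, 4, 3))

    if _OLD_WEIGHT_NORM:
        remove_weight_norm(conv)
    else:
        # leave_parametrized defaults to True, so the effective weight is
        # baked back into conv.weight, matching remove_weight_norm()
        torch.nn.utils.parametrize.remove_parametrizations(conv, 'weight')
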
diff --git a/tes1t_vocoer.py b/tes1t_vocoer.py
new file mode 100644
index 0000000..4193cf7
--- /dev/null
+++ b/tes1t_vocoer.py
@@ -0,0 +1,64 @@
+import pathlib
+import json
+import click
+import torch
+import torchaudio
+from tqdm import tqdm
+
+from training.nsf_HiFigan_task import nsf_HiFigan, dynamic_range_compression_torch
+from utils import get_latest_checkpoint_path
+from utils.config_utils import read_full_config, print_config
+from utils.wav2F0 import get_pitch
+from utils.wav2mel import PitchAdjustableMelSpectrogram
+
+
+@click.command(help='Run vocoder inference on a single wav file')
+@click.option('--exp_name', required=False, metavar='EXP', help='Name of the experiment')
+@click.option('--ckpt_path', required=False, metavar='FILE', help='Path to the checkpoint file')
+@click.option('--save_path', required=True, metavar='FILE', help='Path to save the synthesized audio')
+@click.option('--work_dir', required=False, metavar='DIR', help='Working directory containing the experiments')
+@click.option('--wav_path', required=True, metavar='FILE', help='Path to the input wav file')
+@click.option('--key', required=False, metavar='KEY', help='Pitch shift in semitones', default=0)
+def export(exp_name, ckpt_path, save_path, work_dir, wav_path, key):
+    # print_config(config)
+    if exp_name is None and ckpt_path is None:
+        raise RuntimeError('Either --exp_name or --ckpt_path should be specified.')
+    if ckpt_path is None:
+        if work_dir is None:
+            work_dir = pathlib.Path(__file__).parent / 'experiments'
+        else:
+            work_dir = pathlib.Path(work_dir)
+        work_dir = work_dir / exp_name
+        assert not work_dir.exists() or work_dir.is_dir(), f'Path \'{work_dir}\' is not a directory.'
+        ckpt_path = get_latest_checkpoint_path(work_dir)
+
+    config_file = pathlib.Path(ckpt_path).with_name('config.yaml')
+    config = read_full_config(config_file)
+    temp_dict = torch.load(ckpt_path)['state_dict']
+    model = nsf_HiFigan(config)
+    model.build_model()
+    model.load_state_dict(temp_dict)
+    mel_spec_transform = PitchAdjustableMelSpectrogram(sample_rate=config['audio_sample_rate'],
+                                                       n_fft=config['fft_size'],
+                                                       win_length=config['win_size'],
+                                                       hop_length=config['hop_size'],
+                                                       f_min=config['fmin'],
+                                                       f_max=config['fmax'],
+                                                       n_mels=config['audio_num_mel_bins'], )
+    audio, sr = torchaudio.load(wav_path)
+    if sr != config['audio_sample_rate']:
+        # Resample takes (orig_freq, new_freq) and is applied to the waveform
+        audio = torchaudio.transforms.Resample(sr, config['audio_sample_rate'])(audio)
+    mel = dynamic_range_compression_torch(mel_spec_transform(audio, key_shift=key))
+    f0, uv = get_pitch(audio[0].numpy(), hparams=config, speed=1, interp_uv=True, length=len(mel[0].T))
+    f0 *= 2 ** (key / 12)
+    f0 = torch.from_numpy(f0).float()[None, :]
+    with torch.no_grad():
+        aout = model.Gforward(sample={'mel': mel, 'f0': f0})['audio']
+    torchaudio.save(save_path, aout[0], sample_rate=config['audio_sample_rate'])
+
+
+
+
+
+if __name__ == '__main__':
+    export()
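
Example invocation of the new script (the experiment name and file paths are placeholders):

    python tes1t_vocoer.py --exp_name my_vocoder_exp --wav_path input.wav --save_path output.wav --key 2

--key transposes the rendering by the given number of semitones: the mel is extracted with key_shift=key and f0 is scaled by 2 ** (key / 12) to match, so the vocoder re-synthesizes the audio at the shifted pitch.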