69 changes: 52 additions & 17 deletions models/nsf_HiFigan/models.py
@@ -3,9 +3,21 @@
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d
from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm

# from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm

LRELU_SLOPE = 0.1
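# PyTorch >= 2.1 moved weight_norm to torch.nn.utils.parametrizations;
# fall back to the legacy torch.nn.utils API on older versions.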
_OLD_WEIGHT_NORM = False
try:
from torch.nn.utils.parametrizations import weight_norm
# from torch.nn.utils.parametrizations import spectral_norm
except ImportError:
from torch.nn.utils import weight_norm
    from torch.nn.utils import remove_weight_norm

    _OLD_WEIGHT_NORM = True
from torch.nn.utils import spectral_norm


class AttrDict(dict):
@@ -23,6 +35,7 @@ def init_weights(m, mean=0.0, std=0.01):
def get_padding(kernel_size, dilation=1):
return int((kernel_size * dilation - dilation) / 2)


class ResBlock1(torch.nn.Module):
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
super(ResBlock1, self).__init__()
@@ -57,10 +70,17 @@ def forward(self, x):
return x

def remove_weight_norm(self):
for l in self.convs1:
remove_weight_norm(l)
for l in self.convs2:
remove_weight_norm(l)
global _OLD_WEIGHT_NORM
if _OLD_WEIGHT_NORM:
for l in self.convs1:
remove_weight_norm(l)
for l in self.convs2:
remove_weight_norm(l)
else:
            for l in self.convs1:
                torch.nn.utils.parametrize.remove_parametrizations(l, 'weight')
            for l in self.convs2:
                torch.nn.utils.parametrize.remove_parametrizations(l, 'weight')


class ResBlock2(torch.nn.Module):
@@ -83,8 +103,15 @@ def forward(self, x):
return x

def remove_weight_norm(self):
for l in self.convs:
remove_weight_norm(l)

        global _OLD_WEIGHT_NORM
        if _OLD_WEIGHT_NORM:
            for l in self.convs:
                remove_weight_norm(l)
        else:
            for l in self.convs:
                torch.nn.utils.parametrize.remove_parametrizations(l, 'weight')


class SineGen(torch.nn.Module):
@@ -277,12 +304,23 @@ def forward(self, x, f0):
def remove_weight_norm(self):
# rank_zero_info('Removing weight norm...')
print('Removing weight norm...')
for l in self.ups:
remove_weight_norm(l)
for l in self.resblocks:
l.remove_weight_norm()
remove_weight_norm(self.conv_pre)
remove_weight_norm(self.conv_post)
global _OLD_WEIGHT_NORM
if _OLD_WEIGHT_NORM:
for l in self.ups:
remove_weight_norm(l)
for l in self.resblocks:
l.remove_weight_norm()

remove_weight_norm(self.conv_pre)
remove_weight_norm(self.conv_post)
        else:
            for l in self.ups:
                torch.nn.utils.parametrize.remove_parametrizations(l, 'weight')
            for l in self.resblocks:
                l.remove_weight_norm()

            torch.nn.utils.parametrize.remove_parametrizations(self.conv_pre, 'weight')
            torch.nn.utils.parametrize.remove_parametrizations(self.conv_post, 'weight')


class DiscriminatorP(torch.nn.Module):
@@ -372,14 +410,12 @@ def forward(self, y):

fmap_rs = []


for i, d in enumerate(self.discriminators):
y_d_r, fmap_r = d(y)

y_d_rs.append(y_d_r)
fmap_rs.append(fmap_r)


return y_d_rs, fmap_rs,


@@ -444,7 +480,6 @@ def forward(self, y):
y_d_rs.append(y_d_r)
fmap_rs.append(fmap_r)


return y_d_rs, fmap_rs,


@@ -464,7 +499,7 @@ def discriminator_loss(disc_real_outputs, disc_generated_outputs):

for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
r_loss = torch.mean((1 - dr) ** 2)
g_loss = torch.mean(dg**2)
g_loss = torch.mean(dg ** 2)
loss += r_loss + g_loss
r_losses.append(r_loss.item())
g_losses.append(g_loss.item())
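The loop above is the least-squares GAN discriminator objective: real scores are pushed toward 1 and generated scores toward 0. A quick worked example with illustrative scores:

import torch

dr = torch.tensor([0.9, 0.8])          # discriminator scores on real audio
dg = torch.tensor([0.2, 0.1])          # discriminator scores on generated audio
r_loss = torch.mean((1 - dr) ** 2)     # (0.01 + 0.04) / 2 = 0.025
g_loss = torch.mean(dg ** 2)           # (0.04 + 0.01) / 2 = 0.025
print((r_loss + g_loss).item())        # 0.05 total for this discriminator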
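For context on the models.py changes: a minimal standalone sketch of the compatibility pattern, assuming PyTorch >= 2.1 provides torch.nn.utils.parametrizations.weight_norm and removal via torch.nn.utils.parametrize.remove_parametrizations (the Conv1d here is illustrative):

import torch
import torch.nn as nn

try:
    # New API (PyTorch >= 2.1): weight norm implemented as a parametrization.
    from torch.nn.utils.parametrizations import weight_norm
    _OLD = False
except ImportError:
    # Legacy API: weight norm attached via forward pre-hooks.
    from torch.nn.utils import weight_norm, remove_weight_norm
    _OLD = True

conv = weight_norm(nn.Conv1d(1, 8, kernel_size=3))

if _OLD:
    remove_weight_norm(conv)  # strips the hook, restores a plain .weight
else:
    # leave_parametrized=True (the default) bakes the normalized weight back in.
    torch.nn.utils.parametrize.remove_parametrizations(conv, 'weight')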
64 changes: 64 additions & 0 deletions tes1t_vocoer.py
@@ -0,0 +1,64 @@
import pathlib
import json
import click
import torch
import torchaudio
from tqdm import tqdm

from training.nsf_HiFigan_task import nsf_HiFigan, dynamic_range_compression_torch
from utils import get_latest_checkpoint_path
from utils.config_utils import read_full_config, print_config
from utils.wav2F0 import get_pitch
from utils.wav2mel import PitchAdjustableMelSpectrogram


@click.command(help='Run vocoder inference on a wav file and save the re-synthesized audio')
@click.option('--exp_name', required=False, metavar='EXP', help='Name of the experiment')
@click.option('--ckpt_path', required=False, metavar='FILE', help='Path to the checkpoint file')
@click.option('--save_path', required=True, metavar='FILE', help='Path to save the output audio')
@click.option('--work_dir', required=False, metavar='DIR', help='Working directory containing the experiments')
@click.option('--wav_path', required=True, metavar='FILE', help='Path to the input wav file')
@click.option('--key', required=False, metavar='KEY', help='Key shift in semitones', default=0)
def export(exp_name, ckpt_path, save_path, work_dir, wav_path, key):
# print_config(config)
if exp_name is None and ckpt_path is None:
raise RuntimeError('Either --exp_name or --ckpt_path should be specified.')
if ckpt_path is None:
if work_dir is None:
work_dir = pathlib.Path(__file__).parent / 'experiments'
else:
work_dir = pathlib.Path(work_dir)
work_dir = work_dir / exp_name
assert not work_dir.exists() or work_dir.is_dir(), f'Path \'{work_dir}\' is not a directory.'
ckpt_path = get_latest_checkpoint_path(work_dir)

config_file = pathlib.Path(ckpt_path).with_name('config.yaml')
config = read_full_config(config_file)
    temp_dict = torch.load(ckpt_path, map_location='cpu')['state_dict']
    model = nsf_HiFigan(config)
    model.build_model()
    model.load_state_dict(temp_dict)
    mel_spec_transform = PitchAdjustableMelSpectrogram(
        sample_rate=config['audio_sample_rate'],
        n_fft=config['fft_size'],
        win_length=config['win_size'],
        hop_length=config['hop_size'],
        f_min=config['fmin'],
        f_max=config['fmax'],
        n_mels=config['audio_num_mel_bins'],
    )
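    # Note: these mel extraction settings must match the ones the vocoder
    # was trained with; they come from the checkpoint's config.yaml above.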
    audio, sr = torchaudio.load(wav_path)
    if sr != config['audio_sample_rate']:
        audio = torchaudio.functional.resample(audio, sr, config['audio_sample_rate'])
    mel = dynamic_range_compression_torch(mel_spec_transform(audio, key_shift=key))
    f0, uv = get_pitch(audio[0].numpy(), hparams=config, speed=1, interp_uv=True, length=len(mel[0].T))
    f0 *= 2 ** (key / 12)
    f0 = torch.from_numpy(f0).float()[None, :]
    with torch.no_grad():
        aout = model.Gforward(sample={'mel': mel, 'f0': f0})['audio']
    torchaudio.save(save_path, aout[0], sample_rate=config['audio_sample_rate'])

if __name__ == '__main__':
export()
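For reference, the key-shift math in export() applies the same semitone ratio to both vocoder inputs: the mel extractor warps by key_shift=key, and f0 is scaled by 2 ** (key / 12). A minimal sketch with illustrative values:

# Each semitone multiplies frequency by 2 ** (1 / 12); +12 semitones doubles f0.
key = 3                          # shift up 3 semitones (illustrative)
ratio = 2 ** (key / 12)          # ~1.189
f0_hz = 220.0                    # A3
print(f0_hz * ratio)             # ~261.63 Hz, i.e. C4, 3 semitones up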