|
| 1 | +import sys |
| 2 | +sys.path.append('tacotron2') |
| 3 | +import torch |
| 4 | +from layers import STFT |
| 5 | + |
| 6 | + |
| 7 | +class Denoiser(torch.nn.Module): |
| 8 | + """ Removes model bias from audio produced with waveglow """ |
| 9 | + |
| 10 | + def __init__(self, waveglow, filter_length=1024, n_overlap=4, |
| 11 | + win_length=1024, mode='zeros'): |
| 12 | + super(Denoiser, self).__init__() |
| 13 | + self.stft = STFT(filter_length=filter_length, |
| 14 | + hop_length=int(filter_length/n_overlap), |
| 15 | + win_length=win_length).cuda() |
| 16 | + if mode == 'zeros': |
| 17 | + mel_input = torch.zeros( |
| 18 | + (1, 80, 88), |
| 19 | + dtype=waveglow.upsample.weight.dtype, |
| 20 | + device=waveglow.upsample.weight.device) |
| 21 | + elif mode == 'normal': |
| 22 | + mel_input = torch.randn( |
| 23 | + (1, 80, 88), |
| 24 | + dtype=waveglow.upsample.weight.dtype, |
| 25 | + device=waveglow.upsample.weight.device) |
| 26 | + else: |
| 27 | + raise Exception("Mode {} if not supported".format(mode)) |
| 28 | + |
| 29 | + with torch.no_grad(): |
| 30 | + bias_audio = waveglow.infer(mel_input, sigma=0.0).float() |
| 31 | + bias_spec, _ = self.stft.transform(bias_audio) |
| 32 | + |
| 33 | + self.register_buffer('bias_spec', bias_spec[:, :, 0][:, :, None]) |
| 34 | + |
| 35 | + def forward(self, audio, strength=0.1): |
| 36 | + audio_spec, audio_angles = self.stft.transform(audio.cuda().float()) |
| 37 | + audio_spec_denoised = audio_spec - self.bias_spec * strength |
| 38 | + audio_spec_denoised = torch.clamp(audio_spec_denoised, 0.0) |
| 39 | + audio_denoised = self.stft.inverse(audio_spec_denoised, audio_angles) |
| 40 | + return audio_denoised |
0 commit comments