Skip to content

Commit 60674be

Browse files
committed
inference.py: adding model bias removal (denoiser) to the inference pipeline
1 parent a7168f3 commit 60674be

File tree

1 file changed

+20
-8
lines changed

1 file changed

+20
-8
lines changed

inference.py

+20 -8
Original file line number | Diff line number | Diff line change
@@ -12,10 +12,10 @@
1212
# names of its contributors may be used to endorse or promote products
1313
# derived from this software without specific prior written permission.
1414
#
15-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
16-
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
17-
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18-
# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
15+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
16+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
18+
# ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
1919
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
2020
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
2121
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
@@ -28,9 +28,11 @@
2828
from scipy.io.wavfile import write
2929
import torch
3030
from mel2samp import files_to_list, MAX_WAV_VALUE
31+
from denoiser import Denoiser
3132

3233

33-
def main(mel_files, waveglow_path, sigma, output_dir, sampling_rate, is_fp16):
34+
def main(mel_files, waveglow_path, sigma, output_dir, sampling_rate, is_fp16,
35+
denoiser_strength):
3436
mel_files = files_to_list(mel_files)
3537
waveglow = torch.load(waveglow_path)['model']
3638
waveglow = waveglow.remove_weightnorm(waveglow)
@@ -40,21 +42,29 @@ def main(mel_files, waveglow_path, sigma, output_dir, sampling_rate, is_fp16):
4042
for k in waveglow.convinv:
4143
k.float()
4244

45+
if denoiser_strength > 0:
46+
denoiser = Denoiser(waveglow).cuda()
47+
4348
for i, file_path in enumerate(mel_files):
4449
file_name = os.path.splitext(os.path.basename(file_path))[0]
4550
mel = torch.load(file_path)
4651
mel = torch.autograd.Variable(mel.cuda())
4752
mel = torch.unsqueeze(mel, 0)
4853
mel = mel.half() if is_fp16 else mel
4954
with torch.no_grad():
50-
audio = MAX_WAV_VALUE*waveglow.infer(mel, sigma=sigma)[0]
55+
audio = waveglow.infer(mel, sigma=sigma)
56+
if denoiser_strength > 0:
57+
audio = denoiser(audio, denoiser_strength)
58+
audio = audio * MAX_WAV_VALUE
59+
audio = audio.squeeze()
5160
audio = audio.cpu().numpy()
5261
audio = audio.astype('int16')
5362
audio_path = os.path.join(
5463
output_dir, "{}_synthesis.wav".format(file_name))
5564
write(audio_path, sampling_rate, audio)
5665
print(audio_path)
5766

67+
5868
if __name__ == "__main__":
5969
import argparse
6070

@@ -66,8 +76,10 @@ def main(mel_files, waveglow_path, sigma, output_dir, sampling_rate, is_fp16):
6676
parser.add_argument("-s", "--sigma", default=1.0, type=float)
6777
parser.add_argument("--sampling_rate", default=22050, type=int)
6878
parser.add_argument("--is_fp16", action="store_true")
79+
parser.add_argument("-d", "--denoiser_strength", default=0.0, type=float,
80+
help='Removes model bias. Start with 0.1 and adjust')
6981

7082
args = parser.parse_args()
7183

72-
main(args.filelist_path, args.waveglow_path, args.sigma,
73-
args.output_dir, args.sampling_rate, args.is_fp16)
84+
main(args.filelist_path, args.waveglow_path, args.sigma, args.output_dir,
85+
args.sampling_rate, args.is_fp16, args.denoiser_strength)

0 commit comments

Comments
 (0)