12
12
# names of its contributors may be used to endorse or promote products
13
13
# derived from this software without specific prior written permission.
14
14
#
15
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
16
- # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
17
- # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18
- # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
15
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
16
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
18
+ # ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
19
19
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20
20
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21
21
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
28
28
from scipy .io .wavfile import write
29
29
import torch
30
30
from mel2samp import files_to_list , MAX_WAV_VALUE
31
+ from denoiser import Denoiser
31
32
32
33
33
- def main (mel_files , waveglow_path , sigma , output_dir , sampling_rate , is_fp16 ):
34
+ def main (mel_files , waveglow_path , sigma , output_dir , sampling_rate , is_fp16 ,
35
+ denoiser_strength ):
34
36
mel_files = files_to_list (mel_files )
35
37
waveglow = torch .load (waveglow_path )['model' ]
36
38
waveglow = waveglow .remove_weightnorm (waveglow )
@@ -40,21 +42,29 @@ def main(mel_files, waveglow_path, sigma, output_dir, sampling_rate, is_fp16):
40
42
for k in waveglow .convinv :
41
43
k .float ()
42
44
45
+ if denoiser_strength > 0 :
46
+ denoiser = Denoiser (waveglow ).cuda ()
47
+
43
48
for i , file_path in enumerate (mel_files ):
44
49
file_name = os .path .splitext (os .path .basename (file_path ))[0 ]
45
50
mel = torch .load (file_path )
46
51
mel = torch .autograd .Variable (mel .cuda ())
47
52
mel = torch .unsqueeze (mel , 0 )
48
53
mel = mel .half () if is_fp16 else mel
49
54
with torch .no_grad ():
50
- audio = MAX_WAV_VALUE * waveglow .infer (mel , sigma = sigma )[0 ]
55
+ audio = waveglow .infer (mel , sigma = sigma )
56
+ if denoiser_strength > 0 :
57
+ audio = denoiser (audio , denoiser_strength )
58
+ audio = audio * MAX_WAV_VALUE
59
+ audio = audio .squeeze ()
51
60
audio = audio .cpu ().numpy ()
52
61
audio = audio .astype ('int16' )
53
62
audio_path = os .path .join (
54
63
output_dir , "{}_synthesis.wav" .format (file_name ))
55
64
write (audio_path , sampling_rate , audio )
56
65
print (audio_path )
57
66
67
+
58
68
if __name__ == "__main__" :
59
69
import argparse
60
70
@@ -66,8 +76,10 @@ def main(mel_files, waveglow_path, sigma, output_dir, sampling_rate, is_fp16):
66
76
parser .add_argument ("-s" , "--sigma" , default = 1.0 , type = float )
67
77
parser .add_argument ("--sampling_rate" , default = 22050 , type = int )
68
78
parser .add_argument ("--is_fp16" , action = "store_true" )
79
+ parser .add_argument ("-d" , "--denoiser_strength" , default = 0.0 , type = float ,
80
+ help = 'Removes model bias. Start with 0.1 and adjust' )
69
81
70
82
args = parser .parse_args ()
71
83
72
- main (args .filelist_path , args .waveglow_path , args .sigma ,
73
- args .output_dir , args .sampling_rate , args .is_fp16 )
84
+ main (args .filelist_path , args .waveglow_path , args .sigma , args . output_dir ,
85
+ args .sampling_rate , args .is_fp16 , args .denoiser_strength )
0 commit comments