Skip to content

Commit 277a201

Browse files
committed
inference.py: using amp for mixed precision inference
1 parent 091fbd1 commit 277a201

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

inference.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,8 @@ def main(mel_files, waveglow_path, sigma, output_dir, sampling_rate, is_fp16,
3838
waveglow = waveglow.remove_weightnorm(waveglow)
3939
waveglow.cuda().eval()
4040
if is_fp16:
41-
waveglow.half()
42-
for k in waveglow.convinv:
43-
k.float()
41+
from apex import amp
42+
waveglow, _ = amp.initialize(waveglow, [], opt_level="O3")
4443

4544
if denoiser_strength > 0:
4645
denoiser = Denoiser(waveglow).cuda()

0 commit comments

Comments
 (0)