-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathtest.py
76 lines (67 loc) · 2.17 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
# -*- coding: utf-8 -*-
"""
Author: Andrej Leban
Created on Sun May 29 13:05:27 2022
"""
import argparse
import os
import time
import sounddevice as sd
# import soundfile as sf
import torch
from infowavegan import WaveGANGenerator
from utils import get_continuation_fname
# cf: https://github.com/pytorch/pytorch/issues/16797
# class CPU_Unpickler(pk.Unpickler):
# def find_class(self, module, name):
# if module == 'torch.storage' and name == '_load_from_bytes':
# return lambda b: torch.load(io.BytesIO(b), map_location='cpu')
# else:
# return super().find_class(module, name)
def build_parser():
    """Build the command-line parser for the generator playback script.

    Returns:
        argparse.ArgumentParser: parser with the required options ``--dir``
        and ``--epoch`` and the optional ``--sample_rate`` (default 16000)
        and ``--slice_len`` (default 16384).
    """
    parser = argparse.ArgumentParser(
        description='Load a generator checkpoint and play random samples.'
    )
    parser.add_argument(
        '--dir',
        type=str,
        required=True,
        help='Directory where checkpoints are saved'
    )
    parser.add_argument(
        '--epoch',
        type=int,
        required=True,
        help='Epoch of the checkpoint to load'
    )
    parser.add_argument(
        '--sample_rate',
        type=int,
        default=16000,
        help='Sample rate (Hz) used for audio playback'
    )
    parser.add_argument(
        '--slice_len',
        type=int,
        default=16384,
        help='Slice length passed to the generator'
    )
    return parser


def main(argv=None):
    """Load a generator checkpoint and play 100 randomly generated clips.

    Args:
        argv: Optional list of command-line arguments. ``None`` makes
            argparse read ``sys.argv[1:]``, which is the script's original
            behavior.
    """
    args = build_parser().parse_args(argv)
    # Deliberately not named ``dir``: that would shadow the builtin.
    checkpoint_dir = args.dir
    sample_rate = args.sample_rate

    # Load generator from checkpoint
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE(review): get_continuation_fname is defined in utils; presumably it
    # resolves the checkpoint file prefix for the given epoch - confirm there.
    fname, _ = get_continuation_fname(args.epoch, checkpoint_dir)
    G = WaveGANGenerator(slice_len=args.slice_len)
    G.load_state_dict(
        torch.load(
            os.path.join(checkpoint_dir, fname + "_G.pt"),
            map_location=device,
        )
    )
    G.to(device)
    G.eval()

    # Generate from random noise. no_grad() avoids building an autograd graph
    # that would be thrown away immediately (so no .detach() is needed).
    with torch.no_grad():
        for _ in range(100):
            # Latent vector drawn uniformly from [-1, 1); the width of 100 is
            # hard-coded here and is assumed to match the generator's input.
            z = torch.empty(1, 100).uniform_(-1, 1).to(device)
            gen_data = G(z)[0, 0, :].cpu().numpy()
            sd.play(gen_data, sample_rate)
            # sd.play() returns immediately; wait for the clip to finish so
            # the next play() call (or interpreter exit) does not cut it off.
            sd.wait()


if __name__ == "__main__":
    main()