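"""Run inference for a trained msmctts task: load the model from a checkpoint,
iterate over the test dataset, and save the generated features (e.g. .npy,
.wav) to an output directory."""
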
from tqdm import tqdm
from scipy.io.wavfile import write
from torch.multiprocessing import set_start_method
from torch.utils.data import DataLoader

import argparse
import numpy as np
import os
import re
import torch

from msmctts.datasets import build_dataset
from msmctts.tasks import build_task
from msmctts.utils.plot import plot_matrix
from msmctts.utils.utils import to_model, feature_normalize

try:
    set_start_method('spawn')
except RuntimeError:
    pass

def get_output_base_path(checkpoint_path):
    base_dir = os.path.dirname(checkpoint_path)
    match = re.compile(r'.*_([0-9]+)').match(checkpoint_path)
    name = 'eval-%d' % int(match.group(1)) if match else 'eval'
    return os.path.join(base_dir, name)
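# Example (hypothetical paths): get_output_base_path('exp/model_100000')
# returns 'exp/eval-100000'; a checkpoint without a trailing '_<step>'
# suffix falls back to 'exp/eval'.
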
def save_feature(path, feat, format, sample_rate):
    if format == '.npy':
        np.save(path, feat)
    elif format == '.png':
        plot_matrix(feat, path)
    elif format == '.txt':
        np.savetxt(path, feat, fmt="%.6f")
    elif format == '.dat':
        feat.astype(np.float32).tofile(path)
    elif format == '.wav':
        peak = max(abs(feat))
        feat = feat / peak if peak > 1 else feat
        write(path, sample_rate, (feat * 32767.0).astype(np.int16))
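# Note: '.wav' output is peak-normalized only when the signal exceeds the
# [-1, 1] range, then scaled to 16-bit PCM; the other formats are written
# without rescaling.
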
def test(task, testset, output_dir, n_jobs=1):
    dataloader = DataLoader(testset, batch_size=n_jobs,
        num_workers=0, shuffle=False, pin_memory=False, drop_last=False,
        sampler=torch.utils.data.SequentialSampler(testset),
        collate_fn=(testset.collate_fn
                    if hasattr(testset, 'collate_fn') else None)
    )

    # Start up the task
    if torch.cuda.is_available():
        task = task.cuda()
    task.eval()

    # Build output directories
    if not hasattr(task.config, 'save_features'):
        raise ValueError("No saved features")
    feat_dir = {}
    for name, _, _ in task.config.save_features:
        feat_dir[name] = os.path.join(output_dir, name)
        os.makedirs(feat_dir[name], exist_ok=True)

    # Iterate over test batches
    for batch_i, features in tqdm(enumerate(dataloader)):
        # Get IDs if the batch is sorted in collate_fn
        test_ids = [testset.id_list[x] for x in features.pop('_id')]

        # Model inference
        features = to_model(features)
        saved_features = task(features)

        # Save output features
        for i, test_id in enumerate(test_ids):
            for name, fmt, sample_rate in task.config.save_features:
                # Convert feature to numpy
                feat = saved_features[name][i]
                if isinstance(feat, torch.Tensor):
                    feat = feat.detach().cpu().numpy()

                # Denormalize (optional)
                if name in testset.feature_stat:
                    stat = testset.feature_stat[name]
                    feat = feature_normalize(feat, stat, True)

                # Save feature
                path = "{}/{}{}".format(feat_dir[name], test_id, fmt)
                save_feature(path, feat, fmt, sample_rate=sample_rate)
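# The task.config.save_features entries consumed above are assumed to be
# (name, extension, sample_rate) triples, e.g. ('mel', '.npy', None) or
# ('wav', '.wav', 22050); the example values are illustrative only.
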
def main():
    parser = argparse.ArgumentParser()
    # At least one of -m/--model and -c/--config must be given (checked below)
    parser.add_argument('-m', "--model", default=None)
    parser.add_argument('-c', "--config", default=None)
    parser.add_argument('-t', "--test_config", default=None)
    parser.add_argument('-j', "--jobs", type=int, default=1)
    parser.add_argument('-o', "--output_dir", default=None)
    parser.add_argument("--debug", action='store_true')
    args = parser.parse_args()

    # Check arguments
    if args.model is None and args.config is None:
        parser.error('at least one argument should be given: -m/--model, -c/--config')

    # Load task from checkpoint file
    task = build_task(args.config,
                      mode='debug' if args.debug else 'infer',
                      checkpoint=args.model)

    # Build Test Dataset
    dataset_config = task.config.testset \
        if hasattr(task.config, 'testset') else \
        task.config.dataset
    dataset_config['training'] = False
    if args.test_config is not None:
        dataset_config['id_list'] = args.test_config
    dataset = build_dataset(dataset_config)

    # Auto-generate output directory
    if args.output_dir is None:
        args.output_dir = get_output_base_path(args.model) \
            if args.model is not None else os.path.dirname(args.config)
    os.makedirs(args.output_dir, exist_ok=True)

    # Inference
    test(task, dataset, args.output_dir, n_jobs=args.jobs)


if __name__ == '__main__':
    main()
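
# Example usage (hypothetical paths):
#   python infer.py -m exp/checkpoint_100000 -o outputs/
#   python infer.py -c conf/infer.yaml -t filelists/test.txt -j 4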