-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathphoneme_stratify.py
76 lines (57 loc) · 2.62 KB
/
phoneme_stratify.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import os
import sox
import torch
import yaml
import shutil
import operator
import numpy as np
from util.preprocess_functions import preprocess_dataset,normalize,set_type
from util.timit_dataset import create_dataloader
from util.functions import test_file
from six.moves import cPickle
# if this is removed, pytorch will output harmless warning messages that may be irritating
import warnings
warnings.filterwarnings("ignore")
# Reference count of every phoneme, read from config/phn_occurrence.txt.
# Each non-blank line is expected to look like "<phoneme> <count>".
phn_occurrence = {}
# load phoneme information
with open('config/phn_occurrence.txt') as f:
    for line in f:
        fields = line.split()
        if not fields:
            # tolerate blank lines (e.g. an empty line at the end of the file)
            continue
        phn_occurrence[fields[0]] = int(fields[1])

# Phoneme labels to process; each one names a sub-directory of phoneme_set/.
phonemes = ['ih', 'n', 'iy', 'l', 's', 'r', 'ah', 'aa', 'er', 'k', 'm', 't', 'eh', 'ae', 'z', 'd', 'q', 'w', 'dh', 'p',
            'dx', 'f', 'b', 'sh', 'ay', 'ey', 'ow', 'g', 'uw', 'hh', 'v', 'y', 'ng', 'jh', 'th', 'oy', 'ch', 'uh', 'aw']

# Number of files kept so far for each phoneme; filled in by the main loop.
strat_phn_count = {}
def get_pred(path):
    """Run the loaded listener/speller model on *path* and return its prediction.

    The input is preprocessed with ``preprocess_dataset``, normalized with the
    statistics stored in ``config/mean_val.txt`` and ``config/std_val.txt``,
    cast to float32 and fed through ``test_file`` batch by batch.

    Relies on the module-level ``conf``, ``listener``, ``speller`` and
    ``optimizer`` objects, which must exist before the first call.

    Args:
        path: Location handed to ``preprocess_dataset``.

    Returns:
        The first element of the pair returned by ``test_file`` for the last
        batch produced by the dataloader.

    Raises:
        ValueError: If the dataloader yields no batch for *path*.
    """
    data_type = 'float32'
    # NOTE(review): both statistics files are re-read on every call; they could
    # be loaded once if normalize() does not modify them -- confirm.
    mean_val = np.loadtxt('config/mean_val.txt')
    std_val = np.loadtxt('config/std_val.txt')

    x, y = preprocess_dataset(path)
    x = normalize(x, mean_val, std_val)
    x = set_type(x, data_type)
    test_set = create_dataloader(x, y, **conf['model_parameter'], **conf['training_parameter'], shuffle=False)

    # NOTE(review): assumes the return follows the loop (last batch wins) -- confirm.
    no_batch = object()
    pred = no_batch
    for batch_data, batch_label in test_set:
        pred, _ = test_file(batch_data, batch_label, listener, speller, optimizer, **conf['model_parameter'])

    if pred is no_batch:
        # Fail with a clear message instead of an UnboundLocalError on `pred`.
        raise ValueError('no batches were produced for {!r}'.format(path))
    return pred
# load LAS model
config_path = 'config/las_example_config.yaml'
# yaml.load() without an explicit Loader is deprecated since PyYAML 5.1 and
# raises TypeError on PyYAML 6; safe_load works on every version. The `with`
# block also closes the config file, which the bare open() call never did.
with open(config_path, 'r') as config_file:
    conf = yaml.safe_load(config_file)

# NOTE: torch.load unpickles the checkpoint, so these paths must only point at
# trusted files. The map_location callable returns each storage unchanged
# instead of moving it to the device it was saved from.
listener = torch.load(conf['training_parameter']['pretrained_listener_path'],
                      map_location=lambda storage, loc: storage)
speller = torch.load(conf['training_parameter']['pretrained_speller_path'],
                     map_location=lambda storage, loc: storage)

# A single Adam optimizer over both networks; it is passed to test_file() by get_pred().
optimizer = torch.optim.Adam(
    [{'params': listener.parameters()}, {'params': speller.parameters()}],
    lr=conf['training_parameter']['learning_rate'],
)
# For every phoneme, run the model on each file in phoneme_set/<phn>/ and copy
# the files whose prediction contains that phoneme to strat_phoneme_set/<phn>/,
# renamed to <phn><running count>.wav.
for phn in phonemes:
    phn_dir = os.path.join('phoneme_set', phn)
    out_dir = os.path.join('strat_phoneme_set', phn)
    # os.listdir() returns entries in arbitrary order; sorting makes the
    # numbering of the copied files reproducible between runs.
    for i, file in enumerate(sorted(os.listdir(phn_dir))):
        # '\r' keeps the progress report on a single terminal line
        print('Testing {} {} out of {}'.format(phn, i, phn_occurrence[phn]), end='\r')
        test = os.path.join(phn_dir, file)
        pred = get_pred(test)
        if phn not in pred:
            continue
        strat_phn_count[phn] = strat_phn_count.get(phn, 0) + 1
        # created lazily so phonemes without any match get no empty directory
        os.makedirs(out_dir, exist_ok=True)
        shutil.copy(test, os.path.join(out_dir, phn + str(strat_phn_count[phn]) + '.wav'))
    # move past the progress line of the phoneme that just finished
    print()
# Write one "<phoneme> <count>" line per kept phoneme, highest count first.
sorted_phn = sorted(strat_phn_count.items(), key=operator.itemgetter(1), reverse=True)
with open('config/strat_phn_occurrence.txt', 'w') as f:
    # plain loop: the writes are side effects, not a list worth building
    for name, count in sorted_phn:
        f.write('{} {}\n'.format(name, count))