test_extract_deepspeech.py
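"""Extract DeepSpeech audio features from a WAV file.

Loads a frozen DeepSpeech TensorFlow graph (output_graph.pb), converts the
audio into the windowed-MFCC input the network expects, runs the graph, and
returns the per-frame logits that are used as audio features in DINet-style
preprocessing.
"""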
import warnings
import numpy as np
import resampy
import tensorflow as tf
from python_speech_features import mfcc
from scipy.io import wavfile


# DINet DeepSpeech feature extractor
class DeepSpeech:
def __init__(self, model_path):
(
self.graph,
self.logits_ph,
self.input_node_ph,
self.input_lengths_ph,
) = self._prepare_deepspeech_net(model_path)
        self.target_sample_rate = 16000

    def _prepare_deepspeech_net(self, deepspeech_pb_path):
        # Load the frozen DeepSpeech graph definition.
        with tf.io.gfile.GFile(deepspeech_pb_path, "rb") as f:
            graph_def = tf.compat.v1.GraphDef()
            graph_def.ParseFromString(f.read())
        graph = tf.compat.v1.get_default_graph()
        # import_graph_def prepends the given name to every imported op,
        # so the placeholders and outputs live under the "deepspeech/" scope.
        tf.import_graph_def(graph_def, name="deepspeech")
        logits_ph = graph.get_tensor_by_name("deepspeech/logits:0")
        input_node_ph = graph.get_tensor_by_name("deepspeech/input_node:0")
        input_lengths_ph = graph.get_tensor_by_name("deepspeech/input_lengths:0")
        return graph, logits_ph, input_node_ph, input_lengths_ph

    def conv_audio_to_deepspeech_input_vector(
        self, audio, sample_rate, num_cepstrum, num_context
    ):
        # Pre-processing: turn raw audio into the windowed MFCC input the network expects.
        # Get MFCC coefficients:
        features = mfcc(signal=audio, samplerate=sample_rate, numcep=num_cepstrum)
        # We only keep every second feature (BiRNN stride = 2):
        features = features[::2]
        # One stride per time step in the input:
        num_strides = len(features)
        # Add empty initial and final contexts (keeps the output length consistent):
        empty_context = np.zeros((num_context, num_cepstrum), dtype=features.dtype)
        features = np.concatenate((empty_context, features, empty_context))
        # Create a view into the array with overlapping strides of size
        # num_context (past) + 1 (present) + num_context (future):
        window_size = 2 * num_context + 1
        train_inputs = np.lib.stride_tricks.as_strided(
            features,
            shape=(num_strides, window_size, num_cepstrum),
            strides=(features.strides[0], features.strides[0], features.strides[1]),
            writeable=False,
        )
        # Flatten the second and third dimensions:
        train_inputs = np.reshape(train_inputs, [num_strides, -1])
        # Copy (the strided view is read-only), then whiten the features:
        train_inputs = np.copy(train_inputs)
        train_inputs = (train_inputs - np.mean(train_inputs)) / np.std(train_inputs)
        return train_inputs
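
    # Note: each row returned above is one flattened context window of
    # num_cepstrum * (2 * num_context + 1) values (494 for the 26 cepstra and
    # 9-frame context used in compute_audio_feature below).
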
def compute_audio_feature(self, audio_path):
audio_sample_rate, audio = wavfile.read(audio_path)
if audio.ndim != 1:
warnings.warn("Audio has multiple channels, the first channel is used")
audio = audio[:, 0]
        if audio_sample_rate != self.target_sample_rate:
            # Resample to the target rate when the file's rate does not match it.
            resampled_audio = resampy.resample(
                x=audio.astype(float),
                sr_orig=audio_sample_rate,
                sr_new=self.target_sample_rate,
            )
        else:
            resampled_audio = audio.astype(float)
with tf.compat.v1.Session(graph=self.graph) as sess:
input_vector = self.conv_audio_to_deepspeech_input_vector(
audio=resampled_audio.astype(np.int16),
sample_rate=self.target_sample_rate,
num_cepstrum=26,
num_context=9,
)
network_output = sess.run(
self.logits_ph,
feed_dict={
self.input_node_ph: input_vector[np.newaxis, ...],
self.input_lengths_ph: [input_vector.shape[0]],
},
)
ds_features = network_output[::2, 0, :]
return ds_features
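
# Note on feature rate: MFCC frames are computed every 10 ms
# (python_speech_features default); they are halved once in
# conv_audio_to_deepspeech_input_vector and once more via network_output[::2],
# so compute_audio_feature yields roughly 25 feature rows per second of audio,
# which lines up with 25 fps video in DINet-style pipelines.
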
if __name__ == "__main__":
audio_path = r"/mnt/sdb/cxh/liwen/EAT_code/audio_temp/tmp.wav"
model_path = r"/mnt/sdb/cxh/liwen/EAT_code/preprocess/DS_model/output_graph.pb"
DSModel = DeepSpeech(model_path)
ds_feature = DSModel.compute_audio_feature(audio_path)
print(ds_feature.shape)
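
    # A minimal sketch of persisting the result for later preprocessing stages;
    # the filename below is illustrative only, not part of the original script.
    # np.save("tmp_ds_feature.npy", ds_feature)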