@@ -27,47 +27,54 @@ def waveform_to_log_mel_spectrogram_patches(waveform, params):
27
27
# Convert waveform into spectrogram using a Short-Time Fourier Transform.
28
28
# Note that tf.signal.stft() uses a periodic Hann window by default.
29
29
window_length_samples = int (
30
- round (params .SAMPLE_RATE * params .STFT_WINDOW_SECONDS ))
30
+ round (params .sample_rate * params .stft_window_seconds ))
31
31
hop_length_samples = int (
32
- round (params .SAMPLE_RATE * params .STFT_HOP_SECONDS ))
32
+ round (params .sample_rate * params .stft_hop_seconds ))
33
33
fft_length = 2 ** int (np .ceil (np .log (window_length_samples ) / np .log (2.0 )))
34
34
num_spectrogram_bins = fft_length // 2 + 1
35
- magnitude_spectrogram = tf .abs (tf .signal .stft (
36
- signals = waveform ,
37
- frame_length = window_length_samples ,
38
- frame_step = hop_length_samples ,
39
- fft_length = fft_length ))
35
+ if params .tflite_compatible :
36
+ magnitude_spectrogram = _tflite_stft_magnitude (
37
+ signal = waveform ,
38
+ frame_length = window_length_samples ,
39
+ frame_step = hop_length_samples ,
40
+ fft_length = fft_length )
41
+ else :
42
+ magnitude_spectrogram = tf .abs (tf .signal .stft (
43
+ signals = waveform ,
44
+ frame_length = window_length_samples ,
45
+ frame_step = hop_length_samples ,
46
+ fft_length = fft_length ))
40
47
# magnitude_spectrogram has shape [<# STFT frames>, num_spectrogram_bins]
41
48
42
49
# Convert spectrogram into log mel spectrogram.
43
50
linear_to_mel_weight_matrix = tf .signal .linear_to_mel_weight_matrix (
44
- num_mel_bins = params .MEL_BANDS ,
51
+ num_mel_bins = params .mel_bands ,
45
52
num_spectrogram_bins = num_spectrogram_bins ,
46
- sample_rate = params .SAMPLE_RATE ,
47
- lower_edge_hertz = params .MEL_MIN_HZ ,
48
- upper_edge_hertz = params .MEL_MAX_HZ )
53
+ sample_rate = params .sample_rate ,
54
+ lower_edge_hertz = params .mel_min_hz ,
55
+ upper_edge_hertz = params .mel_max_hz )
49
56
mel_spectrogram = tf .matmul (
50
57
magnitude_spectrogram , linear_to_mel_weight_matrix )
51
- log_mel_spectrogram = tf .math .log (mel_spectrogram + params .LOG_OFFSET )
52
- # log_mel_spectrogram has shape [<# STFT frames>, MEL_BANDS ]
58
+ log_mel_spectrogram = tf .math .log (mel_spectrogram + params .log_offset )
59
+ # log_mel_spectrogram has shape [<# STFT frames>, params.mel_bands ]
53
60
54
- # Frame spectrogram (shape [<# STFT frames>, MEL_BANDS ]) into patches (the
55
- # input examples). Only complete frames are emitted, so if there is less
56
- # than PATCH_WINDOW_SECONDS of waveform then nothing is emitted (to avoid
57
- # this, zero-pad before processing).
61
+ # Frame spectrogram (shape [<# STFT frames>, params.mel_bands ]) into patches
62
+ # (the input examples). Only complete frames are emitted, so if there is
63
+ # less than params.patch_window_seconds of waveform then nothing is emitted
64
+ # (to avoid this, zero-pad before processing).
58
65
spectrogram_hop_length_samples = int (
59
- round (params .SAMPLE_RATE * params .STFT_HOP_SECONDS ))
60
- spectrogram_sample_rate = params .SAMPLE_RATE / spectrogram_hop_length_samples
66
+ round (params .sample_rate * params .stft_hop_seconds ))
67
+ spectrogram_sample_rate = params .sample_rate / spectrogram_hop_length_samples
61
68
patch_window_length_samples = int (
62
- round (spectrogram_sample_rate * params .PATCH_WINDOW_SECONDS ))
69
+ round (spectrogram_sample_rate * params .patch_window_seconds ))
63
70
patch_hop_length_samples = int (
64
- round (spectrogram_sample_rate * params .PATCH_HOP_SECONDS ))
71
+ round (spectrogram_sample_rate * params .patch_hop_seconds ))
65
72
features = tf .signal .frame (
66
73
signal = log_mel_spectrogram ,
67
74
frame_length = patch_window_length_samples ,
68
75
frame_step = patch_hop_length_samples ,
69
76
axis = 0 )
70
- # features has shape [<# patches>, <# STFT frames in an patch>, MEL_BANDS ]
77
+ # features has shape [<# patches>, <# STFT frames in an patch>, params.mel_bands ]
71
78
72
79
return log_mel_spectrogram , features
73
80
@@ -78,23 +85,81 @@ def pad_waveform(waveform, params):
78
85
# need at least one patch window length of waveform plus enough extra samples
79
86
# to complete the final STFT analysis window.
80
87
min_waveform_seconds = (
81
- params .PATCH_WINDOW_SECONDS +
82
- params .STFT_WINDOW_SECONDS - params .STFT_HOP_SECONDS )
83
- min_num_samples = tf .cast (min_waveform_seconds * params .SAMPLE_RATE , tf .int32 )
84
- num_samples = tf .size (waveform )
88
+ params .patch_window_seconds +
89
+ params .stft_window_seconds - params .stft_hop_seconds )
90
+ min_num_samples = tf .cast (min_waveform_seconds * params .sample_rate , tf .int32 )
91
+ num_samples = tf .shape (waveform )[ 0 ]
85
92
num_padding_samples = tf .maximum (0 , min_num_samples - num_samples )
86
93
87
94
# In addition, there might be enough waveform for one or more additional
88
95
# patches formed by hopping forward. If there are more samples than one patch,
89
96
# round up to an integral number of hops.
90
97
num_samples = tf .maximum (num_samples , min_num_samples )
91
98
num_samples_after_first_patch = num_samples - min_num_samples
92
- hop_samples = tf .cast (params .PATCH_HOP_SECONDS * params .SAMPLE_RATE , tf .int32 )
99
+ hop_samples = tf .cast (params .patch_hop_seconds * params .sample_rate , tf .int32 )
93
100
num_hops_after_first_patch = tf .cast (tf .math .ceil (
94
- tf .math .divide (num_samples_after_first_patch , hop_samples )), tf .int32 )
101
+ tf .cast (num_samples_after_first_patch , tf .float32 ) /
102
+ tf .cast (hop_samples , tf .float32 )), tf .int32 )
95
103
num_padding_samples += (
96
104
hop_samples * num_hops_after_first_patch - num_samples_after_first_patch )
97
105
98
106
padded_waveform = tf .pad (waveform , [[0 , num_padding_samples ]],
99
107
mode = 'CONSTANT' , constant_values = 0.0 )
100
108
return padded_waveform
109
+
110
+
111
+ def _tflite_stft_magnitude (signal , frame_length , frame_step , fft_length ):
112
+ """TF-Lite-compatible version of tf.abs(tf.signal.stft())."""
113
+ def _hann_window ():
114
+ return tf .reshape (
115
+ tf .constant (
116
+ (0.5 - 0.5 * np .cos (2 * np .pi * np .arange (0 , 1.0 , 1.0 / frame_length ))
117
+ ).astype (np .float32 ),
118
+ name = 'hann_window' ), [1 , frame_length ])
119
+
120
+ def _dft_matrix (dft_length ):
121
+ """Calculate the full DFT matrix in NumPy."""
122
+ # See https://en.wikipedia.org/wiki/DFT_matrix
123
+ omega = (0 + 1j ) * 2.0 * np .pi / float (dft_length )
124
+ # Don't include 1/sqrt(N) scaling, tf.signal.rfft doesn't apply it.
125
+ return np .exp (omega * np .outer (np .arange (dft_length ), np .arange (dft_length )))
126
+
127
+ def _rdft (framed_signal , fft_length ):
128
+ """Implement real-input Discrete Fourier Transform by matmul."""
129
+ # We are right-multiplying by the DFT matrix, and we are keeping only the
130
+ # first half ("positive frequencies"). So discard the second half of rows,
131
+ # but transpose the array for right-multiplication. The DFT matrix is
132
+ # symmetric, so we could have done it more directly, but this reflects our
133
+ # intention better.
134
+ complex_dft_matrix_kept_values = _dft_matrix (fft_length )[:(
135
+ fft_length // 2 + 1 ), :].transpose ()
136
+ real_dft_matrix = tf .constant (
137
+ np .real (complex_dft_matrix_kept_values ).astype (np .float32 ),
138
+ name = 'real_dft_matrix' )
139
+ imag_dft_matrix = tf .constant (
140
+ np .imag (complex_dft_matrix_kept_values ).astype (np .float32 ),
141
+ name = 'imaginary_dft_matrix' )
142
+ signal_frame_length = tf .shape (framed_signal )[- 1 ]
143
+ half_pad = (fft_length - signal_frame_length ) // 2
144
+ padded_frames = tf .pad (
145
+ framed_signal ,
146
+ [
147
+ # Don't add any padding in the frame dimension.
148
+ [0 , 0 ],
149
+ # Pad before and after the signal within each frame.
150
+ [half_pad , fft_length - signal_frame_length - half_pad ]
151
+ ],
152
+ mode = 'CONSTANT' ,
153
+ constant_values = 0.0 )
154
+ real_stft = tf .matmul (padded_frames , real_dft_matrix )
155
+ imag_stft = tf .matmul (padded_frames , imag_dft_matrix )
156
+ return real_stft , imag_stft
157
+
158
+ def _complex_abs (real , imag ):
159
+ return tf .sqrt (tf .add (real * real , imag * imag ))
160
+
161
+ framed_signal = tf .signal .frame (signal , frame_length , frame_step )
162
+ windowed_signal = framed_signal * _hann_window ()
163
+ real_stft , imag_stft = _rdft (windowed_signal , fft_length )
164
+ stft_magnitude = _complex_abs (real_stft , imag_stft )
165
+ return stft_magnitude
0 commit comments