
Commit 8da4857

Authored Aug 18, 2020
Added TF-Lite-compatible feature extractor and model exporter for YAMNet (tensorflow#9098)
* Added TF-Lite-compatible feature extractor and model exporter for YAMNet.
  - Added a TF-Lite-compatible feature extractor. With the latest TF-Lite,
    that involves a DFT-multiplication replacement for
    tf.abs(tf.signal.stft()) and not a lot else. Note that TF-Lite now
    allows variable-length inputs.
  - Added a YAMNet exporter that produces TF2 SavedModels, TF-Lite models,
    and TF-JS models.
  - Cleanups: switched hyperparameters to a dataclass, got rid of some
    lingering cruft in yamnet_test.

* Responded to DAn's comments in tensorflow#9098:
  - Switched some hparams to float.
  - Made the class map asset available on the exported model, and tested
    that it can be loaded from the various exports.
1 parent ea5fc64 commit 8da4857
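The DFT-multiplication replacement mentioned above can be sanity-checked against the stock op. A minimal sketch (illustrative only, not part of the commit; it assumes TF 2.x eager mode and mirrors YAMNet's 25 ms window / 10 ms hop / 512-point FFT at 16 kHz):

import numpy as np
import tensorflow as tf

frame_length, frame_step, fft_length = 400, 160, 512
waveform = tf.random.normal([16000])

# Stock path: complex STFT, then magnitude.
reference = tf.abs(tf.signal.stft(waveform, frame_length, frame_step, fft_length))

# TF-Lite-friendly path: frame, apply a periodic Hann window, then take the
# DFT as two real-valued matmuls instead of a complex FFT op.
frames = tf.signal.frame(waveform, frame_length, frame_step)
windowed = frames * tf.signal.hann_window(frame_length)
padded = tf.pad(windowed, [[0, 0], [0, fft_length - frame_length]])
n = np.arange(fft_length)[:, np.newaxis]
k = np.arange(fft_length // 2 + 1)[np.newaxis, :]  # keep positive frequencies
angle = 2.0 * np.pi * n * k / fft_length
real = tf.matmul(padded, tf.constant(np.cos(angle), tf.float32))
imag = tf.matmul(padded, tf.constant(-np.sin(angle), tf.float32))
magnitude = tf.sqrt(real * real + imag * imag)

np.testing.assert_allclose(reference.numpy(), magnitude.numpy(), atol=1e-2)

The magnitudes agree even though the commit's implementation pads frames in the center rather than at the end: moving the zero padding only shifts the DFT's phase, not its magnitude.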

File tree

7 files changed: +400 −107 lines

research/audioset/vggish/vggish_export_tfhub.py (+16 −14)
@@ -5,8 +5,10 @@
 range) and returns a 2-d float32 batch of 128-d VGGish embeddings, one per
 0.96s example generated from the waveform.
 
+Requires pip-installing tensorflow_hub.
+
 Usage:
-  export_tfhub.py <path/to/VGGish/checkpoint> <path/to/tfhub/export>
+  vggish_export_tfhub.py <path/to/VGGish/checkpoint> <path/to/tfhub/export>
 """
 
 import sys
@@ -41,19 +43,19 @@ def var_tracker(next_creator, **kwargs):
 
 def waveform_to_features(waveform):
   """Creates VGGish features using the YAMNet feature extractor."""
-  yamnet_params.SAMPLE_RATE = vggish_params.SAMPLE_RATE
-  yamnet_params.STFT_WINDOW_SECONDS = vggish_params.STFT_WINDOW_LENGTH_SECONDS
-  yamnet_params.STFT_HOP_SECONDS = vggish_params.STFT_HOP_LENGTH_SECONDS
-  yamnet_params.MEL_BANDS = vggish_params.NUM_MEL_BINS
-  yamnet_params.MEL_MIN_HZ = vggish_params.MEL_MIN_HZ
-  yamnet_params.MEL_MAX_HZ = vggish_params.MEL_MAX_HZ
-  yamnet_params.LOG_OFFSET = vggish_params.LOG_OFFSET
-  yamnet_params.PATCH_WINDOW_SECONDS = vggish_params.EXAMPLE_WINDOW_SECONDS
-  yamnet_params.PATCH_HOP_SECONDS = vggish_params.EXAMPLE_HOP_SECONDS
-  log_mel_spectrogram = yamnet_features.waveform_to_log_mel_spectrogram(
-      waveform, yamnet_params)
-  return yamnet_features.spectrogram_to_patches(
-      log_mel_spectrogram, yamnet_params)
+  params = yamnet_params.Params(
+      sample_rate=vggish_params.SAMPLE_RATE,
+      stft_window_seconds=vggish_params.STFT_WINDOW_LENGTH_SECONDS,
+      stft_hop_seconds=vggish_params.STFT_HOP_LENGTH_SECONDS,
+      mel_bands=vggish_params.NUM_MEL_BINS,
+      mel_min_hz=vggish_params.MEL_MIN_HZ,
+      mel_max_hz=vggish_params.MEL_MAX_HZ,
+      log_offset=vggish_params.LOG_OFFSET,
+      patch_window_seconds=vggish_params.EXAMPLE_WINDOW_SECONDS,
+      patch_hop_seconds=vggish_params.EXAMPLE_HOP_SECONDS)
+  log_mel_spectrogram, features = yamnet_features.waveform_to_log_mel_spectrogram_patches(
+      waveform, params)
+  return features
 
 def define_vggish(waveform):
   with tf.variable_creator_scope(var_tracker):
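The rewritten helper builds an immutable Params instead of mutating yamnet's module globals. A minimal sketch of the feature extractor it wraps (not part of the commit; assumes the yamnet source directory is importable and TF2 eager execution):

import numpy as np
import tensorflow as tf

import features as yamnet_features
import params as yamnet_params

# YAMNet defaults; the exporter above overrides these with VGGish's values.
params = yamnet_params.Params()

waveform = tf.constant(np.zeros(int(params.sample_rate), np.float32))  # 1 s
log_mel, patches = yamnet_features.waveform_to_log_mel_spectrogram_patches(
    waveform, params)
# With these settings each patch covers 0.96 s of log mel spectrogram, so
# patches has shape [num_patches, 96, 64]; the exporter keeps only the
# patches and discards the full spectrogram.
print(log_mel.shape, patches.shape)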

research/audioset/yamnet/export.py (+213 −0)
@@ -0,0 +1,213 @@
+"""Exports YAMNet as: TF2 SavedModel, TF-Lite model, TF-JS model.
+
+The exported models all accept as input:
+- a 1-d float32 Tensor of arbitrary shape containing an audio waveform
+  (assumed to be mono 16 kHz samples in the [-1, +1] range)
+and return as output:
+- a 2-d float32 Tensor of shape [num_frames, num_classes] containing
+  predicted class scores for each frame of audio extracted from the input.
+- a 2-d float32 Tensor of shape [num_frames, embedding_size] containing
+  embeddings of each frame of audio.
+- a 2-d float32 Tensor of shape [num_spectrogram_frames, num_mel_bins]
+  containing the log mel spectrogram of the entire waveform.
+The SavedModels will also contain (as an asset) a class map CSV file that maps
+class indices to AudioSet class names and Freebase MIDs. The path to the class
+map is available as the 'class_map_path()' method of the restored model.
+
+Requires pip-installing tensorflow_hub and tensorflowjs.
+
+Usage:
+  export.py <path/to/YAMNet/weights-hdf-file> <path/to/output/directory>
+and the various exports will be created in subdirectories of the output
+directory. Assumes that it will be run in the yamnet source directory from
+where it loads the class map. Skips an export if the corresponding directory
+already exists.
+"""
+
+import os
+import sys
+import tempfile
+import time
+
+import numpy as np
+import tensorflow as tf
+assert tf.version.VERSION >= '2.0.0', (
+    'Need at least TF 2.0, you have TF v{}'.format(tf.version.VERSION))
+import tensorflow_hub as tfhub
+from tensorflowjs.converters import tf_saved_model_conversion_v2 as tfjs_saved_model_converter
+
+import params as yamnet_params
+import yamnet
+
+
+def log(msg):
+  print('\n=====\n{} | {}\n=====\n'.format(time.asctime(), msg), flush=True)
+
+
+class YAMNet(tf.Module):
+  """A TF2 Module wrapper around YAMNet."""
+
+  def __init__(self, weights_path, params):
+    super().__init__()
+    self._yamnet = yamnet.yamnet_frames_model(params)
+    self._yamnet.load_weights(weights_path)
+    self._class_map_asset = tf.saved_model.Asset('yamnet_class_map.csv')
+
+  @tf.function
+  def class_map_path(self):
+    return self._class_map_asset.asset_path
+
+  @tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.float32),))
+  def __call__(self, waveform):
+    return self._yamnet(waveform)
+
+
+def check_model(model_fn, class_map_path, params):
+  """Applies yamnet_test's sanity checks to an instance of YAMNet."""
+  yamnet_classes = yamnet.class_names(class_map_path)
+
+  def clip_test(waveform, expected_class_name, top_n=10):
+    predictions, embeddings, log_mel_spectrogram = model_fn(waveform)
+    clip_predictions = np.mean(predictions, axis=0)
+    top_n_indices = np.argsort(clip_predictions)[-top_n:]
+    top_n_scores = clip_predictions[top_n_indices]
+    top_n_class_names = yamnet_classes[top_n_indices]
+    top_n_predictions = list(zip(top_n_class_names, top_n_scores))
+    assert expected_class_name in top_n_class_names, (
+        'Did not find expected class {} in top {} predictions: {}'.format(
+            expected_class_name, top_n, top_n_predictions))
+
+  clip_test(
+      waveform=np.zeros((int(3 * params.sample_rate),), dtype=np.float32),
+      expected_class_name='Silence')
+
+  np.random.seed(51773)  # Ensure repeatability.
+  clip_test(
+      waveform=np.random.uniform(-1.0, +1.0,
+                                 (int(3 * params.sample_rate),)).astype(np.float32),
+      expected_class_name='White noise')
+
+  clip_test(
+      waveform=np.sin(2 * np.pi * 440 *
+                      np.arange(0, 3, 1 / params.sample_rate), dtype=np.float32),
+      expected_class_name='Sine wave')
+
+
+def make_tf2_export(weights_path, export_dir):
+  if os.path.exists(export_dir):
+    log('TF2 export already exists in {}, skipping TF2 export'.format(
+        export_dir))
+    return
+
+  # Create a TF2 Module wrapper around YAMNet.
+  log('Building and checking TF2 Module ...')
+  params = yamnet_params.Params()
+  yamnet = YAMNet(weights_path, params)
+  check_model(yamnet, yamnet.class_map_path(), params)
+  log('Done')
+
+  # Make TF2 SavedModel export.
+  log('Making TF2 SavedModel export ...')
+  tf.saved_model.save(yamnet, export_dir)
+  log('Done')
+
+  # Check export with TF-Hub in TF2.
+  log('Checking TF2 SavedModel export in TF2 ...')
+  model = tfhub.load(export_dir)
+  check_model(model, model.class_map_path(), params)
+  log('Done')
+
+  # Check export with TF-Hub in TF1.
+  log('Checking TF2 SavedModel export in TF1 ...')
+  with tf.compat.v1.Graph().as_default(), tf.compat.v1.Session() as sess:
+    model = tfhub.load(export_dir)
+    sess.run(tf.compat.v1.global_variables_initializer())
+    def run_model(waveform):
+      return sess.run(model(waveform))
+    check_model(run_model, model.class_map_path().eval(), params)
+  log('Done')
+
+
+def make_tflite_export(weights_path, export_dir):
+  if os.path.exists(export_dir):
+    log('TF-Lite export already exists in {}, skipping TF-Lite export'.format(
+        export_dir))
+    return
+
+  # Create a TF-Lite compatible Module wrapper around YAMNet.
+  log('Building and checking TF-Lite Module ...')
+  params = yamnet_params.Params(tflite_compatible=True)
+  yamnet = YAMNet(weights_path, params)
+  check_model(yamnet, yamnet.class_map_path(), params)
+  log('Done')
+
+  # Make TF-Lite SavedModel export.
+  log('Making TF-Lite SavedModel export ...')
+  saved_model_dir = os.path.join(export_dir, 'saved_model')
+  os.makedirs(saved_model_dir)
+  tf.saved_model.save(yamnet, saved_model_dir)
+  log('Done')
+
+  # Check that the export can be loaded and works.
+  log('Checking TF-Lite SavedModel export in TF2 ...')
+  model = tf.saved_model.load(saved_model_dir)
+  check_model(model, model.class_map_path(), params)
+  log('Done')
+
+  # Make a TF-Lite model from the SavedModel.
+  log('Making TF-Lite model ...')
+  tflite_converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
+  tflite_model = tflite_converter.convert()
+  tflite_model_path = os.path.join(export_dir, 'yamnet.tflite')
+  with open(tflite_model_path, 'wb') as f:
+    f.write(tflite_model)
+  log('Done')
+
+  # Check the TF-Lite export.
+  log('Checking TF-Lite model ...')
+  interpreter = tf.lite.Interpreter(tflite_model_path)
+  audio_input_index = interpreter.get_input_details()[0]['index']
+  scores_output_index = interpreter.get_output_details()[0]['index']
+  embeddings_output_index = interpreter.get_output_details()[1]['index']
+  spectrogram_output_index = interpreter.get_output_details()[2]['index']
+  def run_model(waveform):
+    interpreter.resize_tensor_input(audio_input_index, [len(waveform)], strict=True)
+    interpreter.allocate_tensors()
+    interpreter.set_tensor(audio_input_index, waveform)
+    interpreter.invoke()
+    return (interpreter.get_tensor(scores_output_index),
+            interpreter.get_tensor(embeddings_output_index),
+            interpreter.get_tensor(spectrogram_output_index))
+  check_model(run_model, 'yamnet_class_map.csv', params)
+  log('Done')
+
+  return saved_model_dir
+
+
+def make_tfjs_export(tflite_saved_model_dir, export_dir):
+  if os.path.exists(export_dir):
+    log('TF-JS export already exists in {}, skipping TF-JS export'.format(
+        export_dir))
+    return
+
+  # Make a TF-JS model from the TF-Lite SavedModel export.
+  log('Making TF-JS model ...')
+  os.makedirs(export_dir)
+  tfjs_saved_model_converter.convert_tf_saved_model(
+      tflite_saved_model_dir, export_dir)
+  log('Done')
+
+
+def main(args):
+  weights_path = args[0]
+  output_dir = args[1]
+
+  tf2_export_dir = os.path.join(output_dir, 'tf2')
+  make_tf2_export(weights_path, tf2_export_dir)
+
+  tflite_export_dir = os.path.join(output_dir, 'tflite')
+  tflite_saved_model_dir = make_tflite_export(weights_path, tflite_export_dir)
+
+  tfjs_export_dir = os.path.join(output_dir, 'tfjs')
+  make_tfjs_export(tflite_saved_model_dir, tfjs_export_dir)
+
+
+if __name__ == '__main__':
+  main(sys.argv[1:])
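A quick smoke test of the resulting TF2 export, mirroring what check_model does above (a sketch, not part of the commit; '/tmp/yamnet_export' stands in for whatever output directory you passed to export.py, and tensorflow_hub must be installed):

import numpy as np
import tensorflow_hub as tfhub

export_dir = '/tmp/yamnet_export/tf2'  # hypothetical output path
model = tfhub.load(export_dir)

# The class map travels with the SavedModel as an asset.
print(model.class_map_path().numpy())

# Arbitrary-length mono 16 kHz waveform in [-1, +1]; here 2 s of silence.
waveform = np.zeros(2 * 16000, dtype=np.float32)
scores, embeddings, log_mel_spectrogram = model(waveform)
print(scores.shape, embeddings.shape, log_mel_spectrogram.shape)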

research/audioset/yamnet/features.py (+93 −28)
@@ -27,47 +27,54 @@ def waveform_to_log_mel_spectrogram_patches(waveform, params):
   # Convert waveform into spectrogram using a Short-Time Fourier Transform.
   # Note that tf.signal.stft() uses a periodic Hann window by default.
   window_length_samples = int(
-      round(params.SAMPLE_RATE * params.STFT_WINDOW_SECONDS))
+      round(params.sample_rate * params.stft_window_seconds))
   hop_length_samples = int(
-      round(params.SAMPLE_RATE * params.STFT_HOP_SECONDS))
+      round(params.sample_rate * params.stft_hop_seconds))
   fft_length = 2 ** int(np.ceil(np.log(window_length_samples) / np.log(2.0)))
   num_spectrogram_bins = fft_length // 2 + 1
-  magnitude_spectrogram = tf.abs(tf.signal.stft(
-      signals=waveform,
-      frame_length=window_length_samples,
-      frame_step=hop_length_samples,
-      fft_length=fft_length))
+  if params.tflite_compatible:
+    magnitude_spectrogram = _tflite_stft_magnitude(
+        signal=waveform,
+        frame_length=window_length_samples,
+        frame_step=hop_length_samples,
+        fft_length=fft_length)
+  else:
+    magnitude_spectrogram = tf.abs(tf.signal.stft(
+        signals=waveform,
+        frame_length=window_length_samples,
+        frame_step=hop_length_samples,
+        fft_length=fft_length))
   # magnitude_spectrogram has shape [<# STFT frames>, num_spectrogram_bins]
 
   # Convert spectrogram into log mel spectrogram.
   linear_to_mel_weight_matrix = tf.signal.linear_to_mel_weight_matrix(
-      num_mel_bins=params.MEL_BANDS,
+      num_mel_bins=params.mel_bands,
       num_spectrogram_bins=num_spectrogram_bins,
-      sample_rate=params.SAMPLE_RATE,
-      lower_edge_hertz=params.MEL_MIN_HZ,
-      upper_edge_hertz=params.MEL_MAX_HZ)
+      sample_rate=params.sample_rate,
+      lower_edge_hertz=params.mel_min_hz,
+      upper_edge_hertz=params.mel_max_hz)
   mel_spectrogram = tf.matmul(
       magnitude_spectrogram, linear_to_mel_weight_matrix)
-  log_mel_spectrogram = tf.math.log(mel_spectrogram + params.LOG_OFFSET)
-  # log_mel_spectrogram has shape [<# STFT frames>, MEL_BANDS]
+  log_mel_spectrogram = tf.math.log(mel_spectrogram + params.log_offset)
+  # log_mel_spectrogram has shape [<# STFT frames>, params.mel_bands]
 
-  # Frame spectrogram (shape [<# STFT frames>, MEL_BANDS]) into patches (the
-  # input examples). Only complete frames are emitted, so if there is less
-  # than PATCH_WINDOW_SECONDS of waveform then nothing is emitted (to avoid
-  # this, zero-pad before processing).
+  # Frame spectrogram (shape [<# STFT frames>, params.mel_bands]) into patches
+  # (the input examples). Only complete frames are emitted, so if there is
+  # less than params.patch_window_seconds of waveform then nothing is emitted
+  # (to avoid this, zero-pad before processing).
   spectrogram_hop_length_samples = int(
-      round(params.SAMPLE_RATE * params.STFT_HOP_SECONDS))
-  spectrogram_sample_rate = params.SAMPLE_RATE / spectrogram_hop_length_samples
+      round(params.sample_rate * params.stft_hop_seconds))
+  spectrogram_sample_rate = params.sample_rate / spectrogram_hop_length_samples
   patch_window_length_samples = int(
-      round(spectrogram_sample_rate * params.PATCH_WINDOW_SECONDS))
+      round(spectrogram_sample_rate * params.patch_window_seconds))
   patch_hop_length_samples = int(
-      round(spectrogram_sample_rate * params.PATCH_HOP_SECONDS))
+      round(spectrogram_sample_rate * params.patch_hop_seconds))
   features = tf.signal.frame(
       signal=log_mel_spectrogram,
       frame_length=patch_window_length_samples,
       frame_step=patch_hop_length_samples,
       axis=0)
-  # features has shape [<# patches>, <# STFT frames in a patch>, MEL_BANDS]
+  # features has shape [<# patches>, <# STFT frames in a patch>, params.mel_bands]
 
   return log_mel_spectrogram, features
@@ -78,23 +85,81 @@ def pad_waveform(waveform, params):
   # need at least one patch window length of waveform plus enough extra samples
   # to complete the final STFT analysis window.
   min_waveform_seconds = (
-      params.PATCH_WINDOW_SECONDS +
-      params.STFT_WINDOW_SECONDS - params.STFT_HOP_SECONDS)
-  min_num_samples = tf.cast(min_waveform_seconds * params.SAMPLE_RATE, tf.int32)
-  num_samples = tf.size(waveform)
+      params.patch_window_seconds +
+      params.stft_window_seconds - params.stft_hop_seconds)
+  min_num_samples = tf.cast(min_waveform_seconds * params.sample_rate, tf.int32)
+  num_samples = tf.shape(waveform)[0]
   num_padding_samples = tf.maximum(0, min_num_samples - num_samples)
 
   # In addition, there might be enough waveform for one or more additional
   # patches formed by hopping forward. If there are more samples than one patch,
   # round up to an integral number of hops.
   num_samples = tf.maximum(num_samples, min_num_samples)
   num_samples_after_first_patch = num_samples - min_num_samples
-  hop_samples = tf.cast(params.PATCH_HOP_SECONDS * params.SAMPLE_RATE, tf.int32)
+  hop_samples = tf.cast(params.patch_hop_seconds * params.sample_rate, tf.int32)
   num_hops_after_first_patch = tf.cast(tf.math.ceil(
-      tf.math.divide(num_samples_after_first_patch, hop_samples)), tf.int32)
+      tf.cast(num_samples_after_first_patch, tf.float32) /
+      tf.cast(hop_samples, tf.float32)), tf.int32)
   num_padding_samples += (
       hop_samples * num_hops_after_first_patch - num_samples_after_first_patch)
 
   padded_waveform = tf.pad(waveform, [[0, num_padding_samples]],
                            mode='CONSTANT', constant_values=0.0)
   return padded_waveform
+
+
+def _tflite_stft_magnitude(signal, frame_length, frame_step, fft_length):
+  """TF-Lite-compatible version of tf.abs(tf.signal.stft())."""
+  def _hann_window():
+    return tf.reshape(
+        tf.constant(
+            (0.5 - 0.5 * np.cos(2 * np.pi * np.arange(0, 1.0, 1.0 / frame_length))
+            ).astype(np.float32),
+            name='hann_window'), [1, frame_length])
+
+  def _dft_matrix(dft_length):
+    """Calculate the full DFT matrix in NumPy."""
+    # See https://en.wikipedia.org/wiki/DFT_matrix
+    omega = (0 + 1j) * 2.0 * np.pi / float(dft_length)
+    # Don't include 1/sqrt(N) scaling, tf.signal.rfft doesn't apply it.
+    return np.exp(omega * np.outer(np.arange(dft_length), np.arange(dft_length)))
+
+  def _rdft(framed_signal, fft_length):
+    """Implement real-input Discrete Fourier Transform by matmul."""
+    # We are right-multiplying by the DFT matrix, and we are keeping only the
+    # first half ("positive frequencies"). So discard the second half of rows,
+    # but transpose the array for right-multiplication. The DFT matrix is
+    # symmetric, so we could have done it more directly, but this reflects our
+    # intention better.
+    complex_dft_matrix_kept_values = _dft_matrix(fft_length)[:(
+        fft_length // 2 + 1), :].transpose()
+    real_dft_matrix = tf.constant(
+        np.real(complex_dft_matrix_kept_values).astype(np.float32),
+        name='real_dft_matrix')
+    imag_dft_matrix = tf.constant(
+        np.imag(complex_dft_matrix_kept_values).astype(np.float32),
+        name='imaginary_dft_matrix')
+    signal_frame_length = tf.shape(framed_signal)[-1]
+    half_pad = (fft_length - signal_frame_length) // 2
+    padded_frames = tf.pad(
+        framed_signal,
+        [
+            # Don't add any padding in the frame dimension.
+            [0, 0],
+            # Pad before and after the signal within each frame.
+            [half_pad, fft_length - signal_frame_length - half_pad]
+        ],
+        mode='CONSTANT',
+        constant_values=0.0)
+    real_stft = tf.matmul(padded_frames, real_dft_matrix)
+    imag_stft = tf.matmul(padded_frames, imag_dft_matrix)
+    return real_stft, imag_stft
+
+  def _complex_abs(real, imag):
+    return tf.sqrt(tf.add(real * real, imag * imag))
+
+  framed_signal = tf.signal.frame(signal, frame_length, frame_step)
+  windowed_signal = framed_signal * _hann_window()
+  real_stft, imag_stft = _rdft(windowed_signal, fft_length)
+  stft_magnitude = _complex_abs(real_stft, imag_stft)
+  return stft_magnitude
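The rounding in pad_waveform can be checked by hand; a worked sketch with the default hyperparameters (the arithmetic is mine, not part of the diff):

import math

sample_rate = 16000
# One patch window plus the tail needed to finish the last STFT window:
min_waveform_seconds = 0.96 + 0.025 - 0.010               # 0.975 s
min_num_samples = int(min_waveform_seconds * sample_rate)  # 15600

num_samples = 3 * sample_rate                  # a 3 s clip: 48000 samples
hop_samples = int(0.48 * sample_rate)          # 7680
extra = num_samples - min_num_samples          # 32400
extra_hops = math.ceil(extra / hop_samples)    # ceil(4.22) = 5

padded = min_num_samples + extra_hops * hop_samples  # 54000 samples
num_padding_samples = padded - num_samples           # 6000
num_patches = 1 + extra_hops                         # 6 patches of 0.96 s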

research/audioset/yamnet/inference.py (+4 −3)
@@ -23,13 +23,14 @@
 import soundfile as sf
 import tensorflow as tf
 
-import params
+import params as yamnet_params
 import yamnet as yamnet_model
 
 
 def main(argv):
   assert argv, 'Usage: inference.py <wav file> <wav file> ...'
 
+  params = yamnet_params.Params()
   yamnet = yamnet_model.yamnet_frames_model(params)
   yamnet.load_weights('yamnet.h5')
   yamnet_classes = yamnet_model.class_names('yamnet_class_map.csv')
@@ -44,8 +45,8 @@ def main(argv):
     # Convert to mono and the sample rate expected by YAMNet.
     if len(waveform.shape) > 1:
       waveform = np.mean(waveform, axis=1)
-    if sr != params.SAMPLE_RATE:
-      waveform = resampy.resample(waveform, sr, params.SAMPLE_RATE)
+    if sr != params.sample_rate:
+      waveform = resampy.resample(waveform, sr, params.sample_rate)
 
     # Predict YAMNet classes.
     scores, embeddings, spectrogram = yamnet(waveform)

research/audioset/yamnet/params.py (+30 −18)
@@ -15,25 +15,37 @@
 
 """Hyperparameters for YAMNet."""
 
-# The following hyperparameters (except PATCH_HOP_SECONDS) were used to train YAMNet,
+from dataclasses import dataclass
+
+# The following hyperparameters (except patch_hop_seconds) were used to train YAMNet,
 # so expect some variability in performance if you change these. The patch hop can
 # be changed arbitrarily: a smaller hop should give you more patches from the same
 # clip and possibly better performance at a larger computational cost.
-SAMPLE_RATE = 16000
-STFT_WINDOW_SECONDS = 0.025
-STFT_HOP_SECONDS = 0.010
-MEL_BANDS = 64
-MEL_MIN_HZ = 125
-MEL_MAX_HZ = 7500
-LOG_OFFSET = 0.001
-PATCH_WINDOW_SECONDS = 0.96
-PATCH_HOP_SECONDS = 0.48
+@dataclass(frozen=True)  # Instances of this class are immutable.
+class Params:
+  sample_rate: float = 16000.0
+  stft_window_seconds: float = 0.025
+  stft_hop_seconds: float = 0.010
+  mel_bands: int = 64
+  mel_min_hz: float = 125.0
+  mel_max_hz: float = 7500.0
+  log_offset: float = 0.001
+  patch_window_seconds: float = 0.96
+  patch_hop_seconds: float = 0.48
+
+  @property
+  def patch_frames(self):
+    return int(round(self.patch_window_seconds / self.stft_hop_seconds))
+
+  @property
+  def patch_bands(self):
+    return self.mel_bands
+
+  num_classes: int = 521
+  conv_padding: str = 'same'
+  batchnorm_center: bool = True
+  batchnorm_scale: bool = False
+  batchnorm_epsilon: float = 1e-4
+  classifier_activation: str = 'sigmoid'
 
-PATCH_FRAMES = int(round(PATCH_WINDOW_SECONDS / STFT_HOP_SECONDS))
-PATCH_BANDS = MEL_BANDS
-NUM_CLASSES = 521
-CONV_PADDING = 'same'
-BATCHNORM_CENTER = True
-BATCHNORM_SCALE = False
-BATCHNORM_EPSILON = 1e-4
-CLASSIFIER_ACTIVATION = 'sigmoid'
+  tflite_compatible: bool = False
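Since Params is now a frozen dataclass, configurations are expressed as constructor overrides (as in the VGGish exporter above) rather than module-global mutation. A small sketch (assumes the yamnet source directory is importable):

import dataclasses

import params as yamnet_params

defaults = yamnet_params.Params()
print(defaults.patch_frames, defaults.patch_bands)   # 96 64

# Variants are built by passing overrides to the constructor ...
tflite_params = yamnet_params.Params(tflite_compatible=True)

# ... or by copying with dataclasses.replace(); plain attribute assignment
# raises dataclasses.FrozenInstanceError because the instance is frozen.
dense_hop = dataclasses.replace(defaults, patch_hop_seconds=0.24)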

research/audioset/yamnet/yamnet.py (+25 −25)
@@ -22,53 +22,52 @@
 from tensorflow.keras import Model, layers
 
 import features as features_lib
-import params
 
 
-def _batch_norm(name):
+def _batch_norm(name, params):
   def _bn_layer(layer_input):
     return layers.BatchNormalization(
         name=name,
-        center=params.BATCHNORM_CENTER,
-        scale=params.BATCHNORM_SCALE,
-        epsilon=params.BATCHNORM_EPSILON)(layer_input)
+        center=params.batchnorm_center,
+        scale=params.batchnorm_scale,
+        epsilon=params.batchnorm_epsilon)(layer_input)
   return _bn_layer
 
 
-def _conv(name, kernel, stride, filters):
+def _conv(name, kernel, stride, filters, params):
   def _conv_layer(layer_input):
     output = layers.Conv2D(name='{}/conv'.format(name),
                            filters=filters,
                            kernel_size=kernel,
                            strides=stride,
-                           padding=params.CONV_PADDING,
+                           padding=params.conv_padding,
                            use_bias=False,
                            activation=None)(layer_input)
-    output = _batch_norm(name='{}/conv/bn'.format(name))(output)
+    output = _batch_norm('{}/conv/bn'.format(name), params)(output)
     output = layers.ReLU(name='{}/relu'.format(name))(output)
     return output
   return _conv_layer
 
 
-def _separable_conv(name, kernel, stride, filters):
+def _separable_conv(name, kernel, stride, filters, params):
   def _separable_conv_layer(layer_input):
     output = layers.DepthwiseConv2D(name='{}/depthwise_conv'.format(name),
                                     kernel_size=kernel,
                                     strides=stride,
                                     depth_multiplier=1,
-                                    padding=params.CONV_PADDING,
+                                    padding=params.conv_padding,
                                     use_bias=False,
                                     activation=None)(layer_input)
-    output = _batch_norm(name='{}/depthwise_conv/bn'.format(name))(output)
+    output = _batch_norm('{}/depthwise_conv/bn'.format(name), params)(output)
     output = layers.ReLU(name='{}/depthwise_conv/relu'.format(name))(output)
     output = layers.Conv2D(name='{}/pointwise_conv'.format(name),
                            filters=filters,
                            kernel_size=(1, 1),
                            strides=1,
-                           padding=params.CONV_PADDING,
+                           padding=params.conv_padding,
                            use_bias=False,
                            activation=None)(output)
-    output = _batch_norm(name='{}/pointwise_conv/bn'.format(name))(output)
+    output = _batch_norm('{}/pointwise_conv/bn'.format(name), params)(output)
     output = layers.ReLU(name='{}/pointwise_conv/relu'.format(name))(output)
     return output
   return _separable_conv_layer
@@ -93,25 +92,24 @@ def _separable_conv_layer(layer_input):
 ]
 
 
-def yamnet(features):
+def yamnet(features, params):
   """Defines the core YAMNet model in Keras."""
   net = layers.Reshape(
-      (params.PATCH_FRAMES, params.PATCH_BANDS, 1),
-      input_shape=(params.PATCH_FRAMES, params.PATCH_BANDS))(features)
+      (params.patch_frames, params.patch_bands, 1),
+      input_shape=(params.patch_frames, params.patch_bands))(features)
   for (i, (layer_fun, kernel, stride, filters)) in enumerate(_YAMNET_LAYER_DEFS):
-    net = layer_fun('layer{}'.format(i + 1), kernel, stride, filters)(net)
+    net = layer_fun('layer{}'.format(i + 1), kernel, stride, filters, params)(net)
   embeddings = layers.GlobalAveragePooling2D()(net)
-  logits = layers.Dense(units=params.NUM_CLASSES, use_bias=True)(embeddings)
-  predictions = layers.Activation(activation=params.CLASSIFIER_ACTIVATION)(logits)
+  logits = layers.Dense(units=params.num_classes, use_bias=True)(embeddings)
+  predictions = layers.Activation(activation=params.classifier_activation)(logits)
   return predictions, embeddings
 
 
-def yamnet_frames_model(feature_params):
+def yamnet_frames_model(params):
   """Defines the YAMNet waveform-to-class-scores model.
 
   Args:
-    feature_params: An object with parameter fields to control the feature
-      calculation.
+    params: An instance of Params containing hyperparameters.
 
   Returns:
     A model accepting (num_samples,) waveform input and emitting:
@@ -120,10 +118,10 @@ def yamnet_frames_model(feature_params):
     - log_mel_spectrogram: (num_spectrogram_frames, num_mel_bins) spectrogram feature matrix
   """
   waveform = layers.Input(batch_shape=(None,), dtype=tf.float32)
-  waveform_padded = features_lib.pad_waveform(waveform, feature_params)
+  waveform_padded = features_lib.pad_waveform(waveform, params)
   log_mel_spectrogram, features = features_lib.waveform_to_log_mel_spectrogram_patches(
-      waveform_padded, feature_params)
-  predictions, embeddings = yamnet(features)
+      waveform_padded, params)
+  predictions, embeddings = yamnet(features, params)
   frames_model = Model(
       name='yamnet_frames', inputs=waveform,
       outputs=[predictions, embeddings, log_mel_spectrogram])
@@ -132,6 +130,8 @@ def yamnet_frames_model(feature_params):
 
 def class_names(class_map_csv):
   """Read the class name definition file and return a list of strings."""
+  if tf.is_tensor(class_map_csv):
+    class_map_csv = class_map_csv.numpy()
   with open(class_map_csv) as csv_file:
     reader = csv.reader(csv_file)
     next(reader)   # Skip header
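End to end, the frames model now takes its hyperparameters explicitly. A sketch (not part of the commit; assumes the yamnet source directory and the released yamnet.h5 weights):

import numpy as np

import params as yamnet_params
import yamnet

params = yamnet_params.Params()
model = yamnet.yamnet_frames_model(params)
model.load_weights('yamnet.h5')

# 3 s of silence is padded to 54000 samples (see pad_waveform above),
# which yields 336 spectrogram frames and 6 patches:
waveform = np.zeros(int(3 * params.sample_rate), dtype=np.float32)
predictions, embeddings, log_mel_spectrogram = model(waveform)
print(predictions.shape)          # (6, 521)
print(log_mel_spectrogram.shape)  # (336, 64)
print(embeddings.shape)           # (6, <embedding size>)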

research/audioset/yamnet/yamnet_test.py (+19 −19)
@@ -23,46 +23,46 @@
 
 class YAMNetTest(tf.test.TestCase):
 
-  _yamnet_graph = None
+  _params = None
   _yamnet = None
   _yamnet_classes = None
 
   @classmethod
   def setUpClass(cls):
-    super(YAMNetTest, cls).setUpClass()
-    cls._yamnet_graph = tf.Graph()
-    with cls._yamnet_graph.as_default():
-      cls._yamnet = yamnet.yamnet_frames_model(params)
-      cls._yamnet.load_weights('yamnet.h5')
-      cls._yamnet_classes = yamnet.class_names('yamnet_class_map.csv')
+    super().setUpClass()
+    cls._params = params.Params()
+    cls._yamnet = yamnet.yamnet_frames_model(cls._params)
+    cls._yamnet.load_weights('yamnet.h5')
+    cls._yamnet_classes = yamnet.class_names('yamnet_class_map.csv')
 
   def clip_test(self, waveform, expected_class_name, top_n=10):
     """Run the model on the waveform, check that expected class is in top-n."""
-    with YAMNetTest._yamnet_graph.as_default():
-      prediction = np.mean(YAMNetTest._yamnet.predict(
-          np.reshape(waveform, [1, -1]), steps=1)[0], axis=0)
-      top_n_class_names = YAMNetTest._yamnet_classes[
-          np.argsort(prediction)[-top_n:]]
-      self.assertIn(expected_class_name, top_n_class_names)
+    predictions, embeddings, log_mel_spectrogram = YAMNetTest._yamnet(waveform)
+    clip_predictions = np.mean(predictions, axis=0)
+    top_n_indices = np.argsort(clip_predictions)[-top_n:]
+    top_n_scores = clip_predictions[top_n_indices]
+    top_n_class_names = YAMNetTest._yamnet_classes[top_n_indices]
+    top_n_predictions = list(zip(top_n_class_names, top_n_scores))
+    self.assertIn(expected_class_name, top_n_class_names,
+                  'Did not find expected class {} in top {} predictions: {}'.format(
+                      expected_class_name, top_n, top_n_predictions))
 
   def testZeros(self):
     self.clip_test(
-        waveform=np.zeros((1, int(3 * params.SAMPLE_RATE))),
+        waveform=np.zeros((int(3 * YAMNetTest._params.sample_rate),)),
         expected_class_name='Silence')
 
   def testRandom(self):
     np.random.seed(51773)  # Ensure repeatability.
     self.clip_test(
         waveform=np.random.uniform(-1.0, +1.0,
-                                   (1, int(3 * params.SAMPLE_RATE))),
+                                   (int(3 * YAMNetTest._params.sample_rate),)),
         expected_class_name='White noise')
 
   def testSine(self):
     self.clip_test(
-        waveform=np.reshape(
-            np.sin(2 * np.pi * 440 * np.linspace(
-                0, 3, int(3 * params.SAMPLE_RATE))),
-            [1, -1]),
+        waveform=np.sin(2 * np.pi * 440 *
+                        np.arange(0, 3, 1 / YAMNetTest._params.sample_rate)),
         expected_class_name='Sine wave')