import tensorflow as tf
import numpy as np
from layerNormedGRU import layerNormedGRU

class model:

    def __init__(self, num_class, topk_paths=10):
        # Spectrogram input: [batch, 1000 time frames, 161 frequency bins]
        self.xs = tf.placeholder(tf.float32, [None, 1000, 161])
        # CTC labels are fed as a sparse tensor of int32 label indices
        self.ys = tf.sparse_placeholder(tf.int32)
        self.learning_rate = tf.placeholder(tf.float32)
        self.seq_len = tf.placeholder(tf.int32, [None])
        self.isTrain = tf.placeholder(tf.bool, name='phase')

        # Add a channel dimension for the 2-D convolutions
        xs_input = tf.expand_dims(self.xs, 3)

        conv1 = self._nn_conv_bn_layer(xs_input, 'conv_1', [11, 41, 1, 32], [3, 2])
        conv2 = self._nn_conv_bn_layer(conv1, 'conv_2', [11, 21, 32, 64], [1, 2])
        # After the two strided convolutions the feature map is [batch, 334, 41, 64];
        # flatten the frequency and channel axes before the recurrent stack.
        conv_out = tf.reshape(conv2, [-1, 334, 41*64])
        biRNN1 = self._biRNN_bn_layer(conv_out, 'biRNN_1', 1024)
        biRNN2 = self._biRNN_bn_layer(biRNN1, 'biRNN_2', 1024)
        biRNN3 = self._biRNN_bn_layer(biRNN2, 'biRNN_3', 1024)
        biRNN4 = self._biRNN_bn_layer(biRNN3, 'biRNN_4', 1024)
        biRNN5 = self._biRNN_bn_layer(biRNN4, 'biRNN_5', 1024)

        # Per-frame logits over num_class output symbols (the last index is the CTC blank)
        self.phonemes = tf.layers.dense(biRNN5, num_class)

        # Notes: tf.nn.ctc_loss performs the softmax operation for you, so
        # inputs should be e.g. linear projections of outputs by an LSTM.
        self.loss = tf.reduce_mean(tf.nn.ctc_loss(labels=self.ys, inputs=self.phonemes,
                                                  sequence_length=self.seq_len,
                                                  ignore_longer_outputs_than_inputs=True,
                                                  time_major=False))

        optimizer = tf.train.AdamOptimizer(self.learning_rate, beta1=0.7, beta2=0.9)
        # Run the batch-norm moving-average updates together with the training step
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            gvs = optimizer.compute_gradients(self.loss)
            capped_gvs = [(tf.clip_by_value(grad, -400., 400.), var) for grad, var in gvs if grad is not None]
            self.train_op = optimizer.apply_gradients(capped_gvs)

        # ctc_beam_search_decoder expects time-major logits: [max_time, batch, num_class]
        self.prediction, log_prob = tf.nn.ctc_beam_search_decoder(
            tf.transpose(self.phonemes, [1, 0, 2]), self.seq_len,
            top_paths=topk_paths, merge_repeated=False)

        self.loss_summary = tf.summary.scalar("loss", self.loss)
        self.merged = tf.summary.merge_all()

    def _nn_conv_bn_layer(self, inputs, scope, shape, strides):
        # 2-D convolution -> batch normalization -> ReLU6
        with tf.variable_scope(scope):
            W_conv = tf.get_variable("W", shape=shape, initializer=tf.contrib.layers.xavier_initializer())
            h_conv = tf.nn.conv2d(inputs, W_conv, strides=[1, strides[0], strides[1], 1], padding='SAME', name="conv2d")
            b = tf.get_variable("bias", shape=[shape[3]], initializer=tf.contrib.layers.xavier_initializer())
            h_bn = tf.layers.batch_normalization(h_conv + b, training=self.isTrain)
            h_relu = tf.nn.relu6(h_bn, name="relu6")
            return h_relu

    def _biRNN_bn_layer(self, inputs, scope, hidden_units, cell="LayerNormedGRU"):
        # Bidirectional recurrent layer; the cell type is selectable
        with tf.variable_scope(scope):
            if cell == 'GRU':
                fw_cell = tf.nn.rnn_cell.GRUCell(hidden_units, activation=tf.nn.relu, name='fw_cell')
                bw_cell = tf.nn.rnn_cell.GRUCell(hidden_units, activation=tf.nn.relu, name='bw_cell')
            elif cell == 'LSTM':
                fw_cell = tf.nn.rnn_cell.BasicLSTMCell(hidden_units, activation=tf.nn.relu, name='fw_cell')
                bw_cell = tf.nn.rnn_cell.BasicLSTMCell(hidden_units, activation=tf.nn.relu, name='bw_cell')
            elif cell == 'vanilla':
                fw_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_units, activation=tf.nn.relu, name='fw_cell')
                bw_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_units, activation=tf.nn.relu, name='bw_cell')
            elif cell == 'LayerNormedGRU':
                with tf.variable_scope('fw_cell'):
                    fw_cell = layerNormedGRU(hidden_units, activation=tf.nn.relu)
                with tf.variable_scope('bw_cell'):
                    bw_cell = layerNormedGRU(hidden_units, activation=tf.nn.relu)
            else:
                raise ValueError("Invalid cell type: " + str(cell))

            (output_fw, output_bw), _ = tf.nn.bidirectional_dynamic_rnn(fw_cell, bw_cell, inputs, dtype=tf.float32, scope="bi_dynamic_rnn")
            # output_fw_bn = tf.layers.batch_normalization(output_fw, training=self.isTrain, name='output_fw_bn')
            # output_bw_bn = tf.layers.batch_normalization(output_bw, training=self.isTrain, name='output_bw_bn')
            # bilstm_outputs_concat_1 = tf.concat([output_fw_bn, output_bw_bn], 2)
            bilstm_outputs_concat_1 = tf.concat([output_fw, output_bw], 2)
            return bilstm_outputs_concat_1

    def train(self, sess, learning_rate, xs, ys):
        # 334 is the number of time frames remaining after the stride-3 convolution
        _, loss, summary = sess.run([self.train_op, self.loss, self.merged],
                                    feed_dict={self.isTrain: True, self.learning_rate: learning_rate,
                                               self.seq_len: np.ones(xs.shape[0]) * 334,
                                               self.xs: xs, self.ys: ys})
        return loss, summary

    def get_loss(self, sess, xs, ys):
        loss = sess.run(self.loss,
                        feed_dict={self.isTrain: False, self.seq_len: np.ones(xs.shape[0]) * 334,
                                   self.xs: xs, self.ys: ys})
        return loss

    def predict(self, sess, xs):
        prediction = sess.run(self.prediction,
                              feed_dict={self.isTrain: False, self.seq_len: np.ones(xs.shape[0]) * 334,
                                         self.xs: xs})
        return prediction
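
# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original commit): a minimal illustration
# of how the class above could be driven with random data, assuming TF1-style
# sessions, 1000x161 spectrogram inputs, and labels encoded as a
# tf.SparseTensorValue. The batch size, label values, and num_class=29 are
# illustrative assumptions only.
if __name__ == '__main__':
    net = model(num_class=29)  # assumed: 28 label symbols + 1 CTC blank
    xs = np.random.randn(2, 1000, 161).astype(np.float32)  # fake batch of 2 utterances
    # CTC labels as a sparse value: (indices, values, dense_shape)
    ys = tf.SparseTensorValue(indices=np.array([[0, 0], [0, 1], [1, 0]], dtype=np.int64),
                              values=np.array([3, 7, 5], dtype=np.int32),
                              dense_shape=np.array([2, 2], dtype=np.int64))
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        loss, summary = net.train(sess, 1e-4, xs, ys)
        print("loss after one step:", loss)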