|
| 1 | +#!/usr/bin/python3 |
| 2 | + |
| 3 | +## Copyright (C) 2016 D S Pavan Kumar |
| 4 | +## dspavankumar [at] gmail [dot] com |
| 5 | +## |
| 6 | +## This program is free software: you can redistribute it and/or modify |
| 7 | +## it under the terms of the GNU General Public License as published by |
| 8 | +## the Free Software Foundation, either version 3 of the License, or |
| 9 | +## (at your option) any later version. |
| 10 | +## |
| 11 | +## This program is distributed in the hope that it will be useful, |
| 12 | +## but WITHOUT ANY WARRANTY; without even the implied warranty of |
| 13 | +## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
| 14 | +## GNU General Public License for more details. |
| 15 | +## |
| 16 | +## You should have received a copy of the GNU General Public License |
| 17 | +## along with this program. If not, see <http://www.gnu.org/licenses/>. |
| 18 | + |
| 19 | + |
| 20 | +## NOTE: This script has limited functionality. It currently converts |
| 21 | +## feedforward networks with relu and softmax layers in HDF5 format |
| 22 | +## to the standard Kaldi's nnet3 "raw" format. Call this script from |
| 23 | +## steps_kt/saveModelNnet3.sh to get a complete model. |
| 24 | + |
| 25 | +import keras |
| 26 | +import numpy |
| 27 | +import sys |
| 28 | + |
def saveModel (model, fileName):
    """Write a feedforward Keras model to fileName in Kaldi nnet3 "raw" text format.

    Only layers whose name starts with 'dense' are supported, with a
    'linear', 'relu' or 'softmax' activation.  Each layer becomes a
    NaturalGradientAffineComponent named '<layer>.affine', followed (unless
    the activation is linear) by a RectifiedLinearComponent named
    '<layer>.relu' or a LogSoftmaxComponent named '<layer>.softmax'.

    Args:
        model: object exposing input_shape and layers; every layer must
            provide name, get_config(), get_weights() and output_shape.
        fileName: path of the text file to create (overwritten if present).

    Raises:
        TypeError: if a layer is not a dense layer or uses an unhandled
            activation.  Raised before fileName is opened, so no partial
            file is left behind.
    """
    ## Validate every layer up front so that an unsupported model never
    ## produces a truncated output file
    layers = []
    for layer in model.layers:
        if not layer.name.startswith ('dense'):
            raise TypeError ('Unknown layer type: ' + layer.name)
        activation_text = layer.get_config()['activation']
        if activation_text not in ('linear', 'relu', 'softmax'):
            raise TypeError ('Unknown/unhandled activation: %s' % (activation_text,))
        layers.append ((layer, activation_text))

    with open (fileName, 'w') as f:
        f.write ('<Nnet3> \n')

        ## Write the component descriptions
        f.write ('input-node name=input dim=%d\n' % model.input_shape[-1])
        prevNodeName = 'input'
        num_components = 0
        for layer, activation_text in layers:
            affineName = layer.name + '.affine'
            f.write ('component-node name=%s component=%s input=%s\n' % (affineName, affineName, prevNodeName))
            num_components += 1
            prevNodeName = affineName
            ## A linear activation adds no component: the affine node
            ## itself feeds the next layer (or the output node)
            if activation_text != 'linear':
                activationName = layer.name + '.' + activation_text
                f.write ('component-node name=%s component=%s input=%s\n' % (activationName, activationName, affineName))
                num_components += 1
                prevNodeName = activationName
        f.write ('output-node name=output input=%s objective=linear\n' % prevNodeName)

        f.write ('\n<NumComponents> %d\n' % num_components)

        ## Write the layer values
        for layer, activation_text in layers:
            ## Fetch the parameters once per layer: [kernel, bias]
            weights = layer.get_weights()

            f.write ('<ComponentName> %s.affine <NaturalGradientAffineComponent> <MaxChange> 2.0 <LearningRate> 0.001 <LinearParams> [ \n ' % (layer.name))
            ## The kernel is transposed so that each written row holds the
            ## weights of one output unit (assumes the Keras kernel is laid
            ## out as inputs x outputs)
            for row in weights[0].T:
                row.tofile (f, format="%e", sep=' ')
                f.write (' \n ')
            f.write ('] \n <BiasParams> [ ')
            weights[1].tofile (f, format="%e", sep=' ')
            f.write (' ] \n')
            ## Fixed natural-gradient settings, written verbatim for every layer
            f.write ('<RankIn> 20 <RankOut> 80 <UpdatePeriod> 4 <NumSamplesHistory> 2000 <Alpha> 4 <IsGradient> F </NaturalGradientAffineComponent>\n')

            ## Deal with the activation
            if activation_text == 'relu':
                f.write ('<ComponentName> %s.relu <RectifiedLinearComponent> <Dim> %d <ValueAvg> [ ] <DerivAvg> [ ] <Count> 0 <NumDimsSelfRepaired> 0 <NumDimsProcessed> 0 </RectifiedLinearComponent>\n' % (layer.name, layer.output_shape[-1]))
            elif activation_text == 'softmax':
                ## The Keras softmax is exported as a log-softmax component
                f.write ('<ComponentName> %s.softmax <LogSoftmaxComponent> <Dim> %d <ValueAvg> [ ] <DerivAvg> [ ] <Count> 0 <NumDimsSelfRepaired> 0 <NumDimsProcessed> 0 </LogSoftmaxComponent>\n' % (layer.name, layer.output_shape[-1]))
        f.write ('</Nnet3> \n')
| 73 | + |
## Save h5 model in nnet3 format
## Usage: <script> <keras-model.h5> <nnet3-raw-output>
if __name__ == '__main__':
    ## Fail with a readable usage message instead of an IndexError when
    ## either path is missing; extra arguments are ignored as before
    if len (sys.argv) < 3:
        sys.exit ('Usage: %s <keras-model.h5> <nnet3-raw-output>' % sys.argv[0])
    h5model = sys.argv[1]
    nnet3 = sys.argv[2]
    m = keras.models.load_model (h5model)
    saveModel (m, nnet3)
0 commit comments