From f4223ee1109db2dbdd944a15fe9046579b587249 Mon Sep 17 00:00:00 2001
From: pkufool
Date: Tue, 24 Aug 2021 21:09:27 +0800
Subject: [PATCH] Add TDNN-LSTM-CTC Results (#25)

* Add tdnn-lstm pretrained model and results
* Add docs for TDNN-LSTM-CTC
* Minor fix
* Fix typo
* Fix style checking
---
 .flake8                                       |   1 +
 .../recipes/librispeech/tdnn_lstm_ctc.rst     | 322 +++++++++++++++++-
 egs/librispeech/ASR/RESULTS.md                |  25 +-
 egs/librispeech/ASR/conformer_ctc/README.md   |   1 -
 egs/librispeech/ASR/conformer_ctc/decode.py   |   2 +-
 .../ASR/tdnn_lstm_ctc/Pre-trained.md          | 270 +++++++++++++++
 egs/librispeech/ASR/tdnn_lstm_ctc/README.md   |   4 +-
 egs/librispeech/ASR/tdnn_lstm_ctc/decode.py   |  19 +-
 .../ASR/tdnn_lstm_ctc/pretrained.py           | 277 +++++++++++++++
 9 files changed, 915 insertions(+), 6 deletions(-)
 create mode 100644 egs/librispeech/ASR/tdnn_lstm_ctc/Pre-trained.md
 create mode 100644 egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py

diff --git a/.flake8 b/.flake8
index 3f1227b9b9..b2bbb0a44b 100644
--- a/.flake8
+++ b/.flake8
@@ -5,6 +5,7 @@ max-line-length = 80
 per-file-ignores =
     # line too long
     egs/librispeech/ASR/conformer_ctc/conformer.py: E501,
+    egs/librispeech/ASR/conformer_ctc/decode.py: E501,

 exclude =
     .git,

diff --git a/docs/source/recipes/librispeech/tdnn_lstm_ctc.rst b/docs/source/recipes/librispeech/tdnn_lstm_ctc.rst
index 373bb59059..a59f34db76 100644
--- a/docs/source/recipes/librispeech/tdnn_lstm_ctc.rst
+++ b/docs/source/recipes/librispeech/tdnn_lstm_ctc.rst
@@ -1,2 +1,322 @@
-TDNN LSTM CTC
+TDNN-LSTM-CTC
 =============
+
+This tutorial shows you how to run a TDNN-LSTM-CTC model with the `LibriSpeech <http://www.openslr.org/12>`_ dataset.
+
+
+.. HINT::
+
+  We assume you have read the page :ref:`install icefall` and have set up
+  the environment for ``icefall``.
+
+
+Data preparation
+----------------
+
+.. code-block:: bash
+
+  $ cd egs/librispeech/ASR
+  $ ./prepare.sh
+
+The script ``./prepare.sh`` handles the data preparation for you, **automagically**.
+All you need to do is to run it.
+
+The data preparation contains several stages. You can use the following two
+options:
+
+  - ``--stage``
+  - ``--stop-stage``
+
+to control which stage(s) should be run. By default, all stages are executed.
+
+
+For example,
+
+.. code-block:: bash
+
+  $ cd egs/librispeech/ASR
+  $ ./prepare.sh --stage 0 --stop-stage 0
+
+means to run only stage 0.
+
+To run stage 2 to stage 5, use:
+
+.. code-block:: bash
+
+  $ ./prepare.sh --stage 2 --stop-stage 5
+
+
+Training
+--------
+
+We now describe the training of the TDNN-LSTM-CTC model, which is contained in
+the `tdnn_lstm_ctc <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/tdnn_lstm_ctc>`_
+folder.
+
+The command to run the training part is:
+
+.. code-block:: bash
+
+  $ cd egs/librispeech/ASR
+  $ export CUDA_VISIBLE_DEVICES="0,1,2,3"
+  $ ./tdnn_lstm_ctc/train.py --world-size 4
+
+By default, it will run ``20`` epochs. Training logs and checkpoints are saved
+in ``tdnn_lstm_ctc/exp``.
+
+In ``tdnn_lstm_ctc/exp``, you will find the following files:
+
+  - ``epoch-0.pt``, ``epoch-1.pt``, ..., ``epoch-19.pt``
+
+    These are checkpoint files, containing the model ``state_dict`` and the
+    optimizer ``state_dict`` (see the sketch after this list for how to
+    inspect them). To resume training from some checkpoint, say ``epoch-10.pt``,
+    you can use:
+
+    .. code-block:: bash
+
+      $ ./tdnn_lstm_ctc/train.py --start-epoch 11
+
+  - ``tensorboard/``
+
+    This folder contains TensorBoard logs. Training loss, validation loss, learning
+    rate, etc, are recorded in these logs. You can visualize them by:
+
+    .. code-block:: bash
+
+      $ cd tdnn_lstm_ctc/exp/tensorboard
+      $ tensorboard dev upload --logdir . --description "TDNN LSTM training for librispeech with icefall"
+
+  - ``log/log-train-xxxx``
+
+    It is the detailed training log in text format, the same as the one
+    printed to the console during training.
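+
+Since these checkpoints are ordinary PyTorch files, you can peek inside them
+with ``torch.load``. Below is a minimal sketch; the top-level key names
+(``"model"``, ``"optimizer"``) are an assumption based on the description
+above, so check the actual keys on your own checkpoints:
+
+.. code-block:: python
+
+  import torch
+
+  # Load a checkpoint produced by ./tdnn_lstm_ctc/train.py (path is an example).
+  ckpt = torch.load("tdnn_lstm_ctc/exp/epoch-10.pt", map_location="cpu")
+  print(ckpt.keys())  # expected to include "model" and "optimizer"
+
+  # Count the parameters stored in the model state_dict.
+  num_params = sum(v.numel() for v in ckpt["model"].values())
+  print(f"number of parameters: {num_params}")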
+
+
+To see available training options, you can use:
+
+.. code-block:: bash
+
+  $ ./tdnn_lstm_ctc/train.py --help
+
+Other training options, e.g., learning rate, results dir, etc., are
+pre-configured in the function ``get_params()``
+in `tdnn_lstm_ctc/train.py <https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/tdnn_lstm_ctc/train.py>`_.
+Normally, you don't need to change them; if you do, change them by
+modifying the code.
+
+Decoding
+--------
+
+The decoding part uses checkpoints saved by the training part, so you have
+to run the training part first.
+
+The command for decoding is:
+
+.. code-block:: bash
+
+  $ export CUDA_VISIBLE_DEVICES="0"
+  $ ./tdnn_lstm_ctc/decode.py
+
+You will see the WER in the output log.
+
+Decoded results are saved in ``tdnn_lstm_ctc/exp``.
+
+.. code-block:: bash
+
+  $ ./tdnn_lstm_ctc/decode.py --help
+
+shows you the available decoding options.
+
+Some commonly used options are:
+
+  - ``--epoch``
+
+    You can select which checkpoint to use for decoding.
+    For instance, ``./tdnn_lstm_ctc/decode.py --epoch 10`` means to use
+    ``./tdnn_lstm_ctc/exp/epoch-10.pt`` for decoding.
+
+  - ``--avg``
+
+    It's related to model averaging. It specifies the number of checkpoints
+    to average; the averaged model is used for decoding (a sketch of what
+    the averaging does is given at the end of this section).
+    For example, the following command:
+
+    .. code-block:: bash
+
+      $ ./tdnn_lstm_ctc/decode.py --epoch 10 --avg 3
+
+    uses the average of ``epoch-8.pt``, ``epoch-9.pt`` and ``epoch-10.pt``
+    for decoding.
+
+  - ``--export``
+
+    If it is ``True``, i.e., ``./tdnn_lstm_ctc/decode.py --export 1``, the code
+    will save the averaged model to ``tdnn_lstm_ctc/exp/pretrained.pt``.
+    See :ref:`tdnn_lstm_ctc use a pre-trained model` for how to use it.
+
+.. HINT::
+
+  There are several decoding methods provided in `tdnn_lstm_ctc/decode.py <https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py>`_. You can change the decoding method by modifying the ``method`` parameter in the function ``get_params()``.
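+
+Model averaging is an elementwise average of the parameters of several
+checkpoints. ``decode.py`` uses ``icefall.checkpoint.average_checkpoints()``
+for this; the sketch below only illustrates the idea and is not the library
+implementation (it also assumes all averaged entries are floating point):
+
+.. code-block:: python
+
+  from typing import List
+
+  import torch
+
+  def average_models(filenames: List[str]) -> dict:
+      """Return the elementwise average of the "model" state_dicts
+      stored in the given checkpoint files (illustration only)."""
+      n = len(filenames)
+      avg = torch.load(filenames[0], map_location="cpu")["model"]
+      for f in filenames[1:]:
+          state = torch.load(f, map_location="cpu")["model"]
+          for k in avg:
+              avg[k] += state[k]
+      for k in avg:
+          avg[k] = avg[k] / n
+      return avg
+
+  # E.g., --epoch 10 --avg 3 corresponds roughly to:
+  # model.load_state_dict(average_models(
+  #     ["exp/epoch-8.pt", "exp/epoch-9.pt", "exp/epoch-10.pt"]))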
+
+.. _tdnn_lstm_ctc use a pre-trained model:
+
+Pre-trained Model
+-----------------
+
+We have uploaded the pre-trained model to
+`<https://huggingface.co/pkufool/icefall_asr_librispeech_tdnn-lstm_ctc>`_.
+
+The following shows you how to use the pre-trained model.
+
+Download the pre-trained model
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: bash
+
+  $ cd egs/librispeech/ASR
+  $ mkdir tmp
+  $ cd tmp
+  $ git lfs install
+  $ git clone https://huggingface.co/pkufool/icefall_asr_librispeech_tdnn-lstm_ctc
+
+.. CAUTION::
+
+  You have to use ``git lfs`` to download the pre-trained model.
+
+After downloading, you will have the following files:
+
+.. code-block:: bash
+
+  $ cd egs/librispeech/ASR
+  $ tree tmp
+
+.. code-block:: bash
+
+  tmp/
+  `-- icefall_asr_librispeech_tdnn-lstm_ctc
+      |-- README.md
+      |-- data
+      |   |-- lang_phone
+      |   |   |-- HLG.pt
+      |   |   |-- tokens.txt
+      |   |   `-- words.txt
+      |   `-- lm
+      |       `-- G_4_gram.pt
+      |-- exp
+      |   `-- pretrained.pt
+      `-- test_wavs
+          |-- 1089-134686-0001.flac
+          |-- 1221-135766-0001.flac
+          |-- 1221-135766-0002.flac
+          `-- trans.txt
+
+  6 directories, 10 files
+
+
+Download kaldifeat
+~~~~~~~~~~~~~~~~~~
+
+`kaldifeat <https://github.com/csukuangfj/kaldifeat>`_ is used for extracting
+features from single or multiple sound files. Please refer to
+`<https://github.com/csukuangfj/kaldifeat>`_ to install ``kaldifeat`` first.
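+
+Once installed, ``kaldifeat`` can compute the same 80-dimensional fbank
+features this recipe uses. A minimal sketch, mirroring the options used in
+``tdnn_lstm_ctc/pretrained.py`` (shown later in this patch); the file name
+is a placeholder:
+
+.. code-block:: python
+
+  import kaldifeat
+  import torchaudio
+
+  opts = kaldifeat.FbankOptions()
+  opts.frame_opts.dither = 0
+  opts.frame_opts.snip_edges = False
+  opts.frame_opts.samp_freq = 16000
+  opts.mel_opts.num_bins = 80
+  fbank = kaldifeat.Fbank(opts)
+
+  wave, sr = torchaudio.load("test.wav")  # must be 16 kHz
+  assert sr == 16000
+  features = fbank([wave[0]])[0]  # shape: (num_frames, 80)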
+
+Inference with a pre-trained model
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. code-block:: bash
+
+  $ cd egs/librispeech/ASR
+  $ ./tdnn_lstm_ctc/pretrained.py --help
+
+shows the usage information of ``./tdnn_lstm_ctc/pretrained.py``.
+
+To decode with the ``1best`` method, we can use:
+
+.. code-block:: bash
+
+  ./tdnn_lstm_ctc/pretrained.py \
+    --checkpoint ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/exp/pretrained.pt \
+    --words-file ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/data/lang_phone/words.txt \
+    --HLG ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/data/lang_phone/HLG.pt \
+    ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac \
+    ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac \
+    ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac
+
+The output is:
+
+.. code-block::
+
+  2021-08-24 16:57:13,315 INFO [pretrained.py:168] device: cuda:0
+  2021-08-24 16:57:13,315 INFO [pretrained.py:170] Creating model
+  2021-08-24 16:57:18,331 INFO [pretrained.py:182] Loading HLG from ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/data/lang_phone/HLG.pt
+  2021-08-24 16:57:27,581 INFO [pretrained.py:199] Constructing Fbank computer
+  2021-08-24 16:57:27,584 INFO [pretrained.py:209] Reading sound files: ['./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac', './tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac', './tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac']
+  2021-08-24 16:57:27,599 INFO [pretrained.py:215] Decoding started
+  2021-08-24 16:57:27,791 INFO [pretrained.py:245] Use HLG decoding
+  2021-08-24 16:57:28,098 INFO [pretrained.py:266]
+  ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac:
+  AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS
+
+  ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac:
+  GOD AS A DIRECT CONSEQUENCE OF THE SIN WHICH MAN THUS PUNISHED HAD GIVEN HER A LOVELY CHILD WHOSE PLACE WAS ON THAT SAME DISHONORED BOSOM TO CONNECT HER PARENT FOREVER WITH THE RACE AND DESCENT OF MORTALS AND TO BE FINALLY A BLESSED SOUL IN HEAVEN
+
+  ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac:
+  YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION
+
+
+  2021-08-24 16:57:28,099 INFO [pretrained.py:268] Decoding Done
+
+
+To decode with the ``whole-lattice-rescoring`` method, you can use:
+
+.. code-block:: bash
+
+  ./tdnn_lstm_ctc/pretrained.py \
+    --checkpoint ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/exp/pretrained.pt \
+    --words-file ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/data/lang_phone/words.txt \
+    --HLG ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/data/lang_phone/HLG.pt \
+    --method whole-lattice-rescoring \
+    --G ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/data/lm/G_4_gram.pt \
+    --ngram-lm-scale 0.8 \
+    ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac \
+    ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac \
+    ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac
+
+The decoding output is:
+
+.. code-block::
+
+  2021-08-24 16:39:24,725 INFO [pretrained.py:168] device: cuda:0
+  2021-08-24 16:39:24,725 INFO [pretrained.py:170] Creating model
+  2021-08-24 16:39:29,403 INFO [pretrained.py:182] Loading HLG from ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/data/lang_phone/HLG.pt
+  2021-08-24 16:39:40,631 INFO [pretrained.py:190] Loading G from ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/data/lm/G_4_gram.pt
+  2021-08-24 16:39:53,098 INFO [pretrained.py:199] Constructing Fbank computer
+  2021-08-24 16:39:53,107 INFO [pretrained.py:209] Reading sound files: ['./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac', './tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac', './tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac']
+  2021-08-24 16:39:53,121 INFO [pretrained.py:215] Decoding started
+  2021-08-24 16:39:53,443 INFO [pretrained.py:250] Use HLG decoding + LM rescoring
+  2021-08-24 16:39:54,010 INFO [pretrained.py:266]
+  ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac:
+  AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS
+
+  ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac:
+  GOD AS A DIRECT CONSEQUENCE OF THE SIN WHICH MAN THUS PUNISHED HAD GIVEN HER A LOVELY CHILD WHOSE PLACE WAS ON THAT SAME DISHONORED BOSOM TO CONNECT HER PARENT FOREVER WITH THE RACE AND DESCENT OF MORTALS AND TO BE FINALLY A BLESSED SOUL IN HEAVEN
+
+  ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac:
+  YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION
+
+
+  2021-08-24 16:39:54,010 INFO [pretrained.py:268] Decoding Done
+
+
+Colab notebook
+--------------
+
+We provide a colab notebook for decoding with the pre-trained model.
+
+|librispeech tdnn_lstm_ctc colab notebook|
+
+.. |librispeech tdnn_lstm_ctc colab notebook| image:: https://colab.research.google.com/assets/colab-badge.svg
+   :target: https://colab.research.google.com/drive/1kNmDXNMwREi0rZGAOIAOJo93REBuOTcd
+
+
+**Congratulations!** You have finished the TDNN-LSTM-CTC recipe on LibriSpeech in ``icefall``.
diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md
index 159147a3e5..dfc412672b 100644
--- a/egs/librispeech/ASR/RESULTS.md
+++ b/egs/librispeech/ASR/RESULTS.md
@@ -6,7 +6,7 @@

 TensorBoard log is available at https://tensorboard.dev/experiment/GnRzq8WWQW62dK4bklXBTg/#scalars

-Pretrained model is available at https://huggingface.co/pkufool/conformer_ctc
+Pretrained model is available at https://huggingface.co/pkufool/icefall_asr_librispeech_conformer_ctc

 The best decoding results (WER) are listed below, we got this results by averaging models from epoch 15 to 34, and using `attention-decoder` decoder with num_paths equals to 100.

@@ -21,3 +21,26 @@ To get more unique paths, we scaled the lattice.scores with 0.5 (see https://git
 |test-clean|1.3|1.2|
 |test-other|1.2|1.1|

+
+### LibriSpeech training results (TDNN-LSTM)
+#### 2021-08-24
+
+(Wei Kang): Results of the phone-based TDNN-LSTM model.
+
+Icefall version: https://github.com/k2-fsa/icefall/commit/caa0b9e9425af27e0c6211048acb55a76ed5d315
+
+Pretrained model is available at https://huggingface.co/pkufool/icefall_asr_librispeech_tdnn-lstm_ctc
+
+The best decoding results (WER) are listed below. We got these results by averaging models from epoch 14 to 19 and using the `whole-lattice-rescoring` decoding method.
+
+||test-clean|test-other|
+|--|--|--|
+|WER| 6.59% | 17.69% |
+
+We tuned the `lm_scale` for the best results; the scales that produced the WERs above are listed below.
+
+||lm_scale|
+|--|--|
+|test-clean|0.8|
+|test-other|0.9|
+
diff --git a/egs/librispeech/ASR/conformer_ctc/README.md b/egs/librispeech/ASR/conformer_ctc/README.md
index 0092fd14e9..23b51167b6 100644
--- a/egs/librispeech/ASR/conformer_ctc/README.md
+++ b/egs/librispeech/ASR/conformer_ctc/README.md
@@ -1,4 +1,3 @@
-
 Please visit

 for how to run this recipe.
diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py
index 321ce970e4..fc1d2bf2fb 100755
--- a/egs/librispeech/ASR/conformer_ctc/decode.py
+++ b/egs/librispeech/ASR/conformer_ctc/decode.py
@@ -83,7 +83,7 @@ def get_parser():
           - (3) nbest-rescoring. Extract n paths from the decoding lattice,
             rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
             the highest score is the decoding result.
-          - (4) whole-lattice. Rescore the decoding lattice with an n-gram LM
+          - (4) whole-lattice-rescoring. Rescore the decoding lattice with an n-gram LM
             (e.g., a 4-gram LM), the best path of rescored lattice is the
             decoding result.
           - (5) attention-decoder. Extract n paths from the LM rescored lattice,
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/Pre-trained.md b/egs/librispeech/ASR/tdnn_lstm_ctc/Pre-trained.md
new file mode 100644
index 0000000000..83e98b37c0
--- /dev/null
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/Pre-trained.md
@@ -0,0 +1,270 @@
+
+# How to use a pre-trained model to transcribe a sound file or multiple sound files
+
+(See the bottom of this document for the link to a colab notebook.)
+
+You need to prepare 4 files:
+
+  - a model checkpoint file, e.g., epoch-20.pt
+  - HLG.pt, the decoding graph
+  - words.txt, the word symbol table
+  - a sound file, whose sampling rate has to be 16 kHz.
+    Supported formats are those supported by `torchaudio.load()`,
+    e.g., wav and flac.
+
+Also, you need to install `kaldifeat`. Please refer to
+<https://github.com/csukuangfj/kaldifeat> for installation.
+
+```bash
+./tdnn_lstm_ctc/pretrained.py --help
+```
+
+displays the help information.
+
+## HLG decoding
+
+Once you have the above files ready and have `kaldifeat` installed,
+you can run:
+
+```bash
+./tdnn_lstm_ctc/pretrained.py \
+  --checkpoint /path/to/your/checkpoint.pt \
+  --words-file /path/to/words.txt \
+  --HLG /path/to/HLG.pt \
+  /path/to/your/sound.wav
+```
+
+and you will see the transcribed result.
+
+If you want to transcribe multiple files at the same time, you can use:
+
+```bash
+./tdnn_lstm_ctc/pretrained.py \
+  --checkpoint /path/to/your/checkpoint.pt \
+  --words-file /path/to/words.txt \
+  --HLG /path/to/HLG.pt \
+  /path/to/your/sound1.wav \
+  /path/to/your/sound2.wav \
+  /path/to/your/sound3.wav
+```
+
+**Note**: This is the fastest decoding method.
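+
+Under the hood, `pretrained.py` computes fbank features, runs the network,
+intersects the log-probs with HLG, and takes the best path of the resulting
+lattice. The condensed sketch below assumes `nnet_output` (shape `(N, T, C)`),
+`HLG`, and `supervision_segments` are prepared exactly as in `pretrained.py`
+(shown later in this patch); it is an illustration, not a standalone program:
+
+```python
+from icefall.decode import get_lattice, one_best_decoding
+from icefall.utils import get_texts
+
+# Intersect the network output with the decoding graph HLG.
+lattice = get_lattice(
+    nnet_output=nnet_output,
+    HLG=HLG,
+    supervision_segments=supervision_segments,
+    search_beam=20,
+    output_beam=5,
+    min_active_states=30,
+    max_active_states=10000,
+    subsampling_factor=3,
+)
+
+# "1best": take the single best path through the lattice.
+best_path = one_best_decoding(lattice=lattice, use_double_scores=True)
+hyps = get_texts(best_path)  # word IDs; map them to words via words.txt
+```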
+
+## HLG decoding + LM rescoring
+
+`./tdnn_lstm_ctc/pretrained.py` also supports whole-lattice LM rescoring.
+
+To use whole-lattice LM rescoring, you also need the following file:
+
+  - G.pt, e.g., `data/lm/G_4_gram.pt` if you have run `./prepare.sh`
+
+The command to run decoding with LM rescoring is:
+
+```bash
+./tdnn_lstm_ctc/pretrained.py \
+  --checkpoint /path/to/your/checkpoint.pt \
+  --words-file /path/to/words.txt \
+  --HLG /path/to/HLG.pt \
+  --method whole-lattice-rescoring \
+  --G data/lm/G_4_gram.pt \
+  --ngram-lm-scale 0.8 \
+  /path/to/your/sound1.wav \
+  /path/to/your/sound2.wav \
+  /path/to/your/sound3.wav
+```
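+
+Compared with plain HLG decoding, rescoring adds one step: the whole lattice
+is composed with G and the best path is taken from the rescored lattice. A
+sketch of that step, mirroring `pretrained.py` (it assumes `lattice` and `G`
+are prepared as shown there, i.e., G has epsilon self-loops and an
+`lm_scores` attribute):
+
+```python
+from icefall.decode import rescore_with_whole_lattice
+
+best_path_dict = rescore_with_whole_lattice(
+    lattice=lattice,
+    G_with_epsilon_loops=G,
+    lm_scale_list=[0.8],  # the value given via --ngram-lm-scale
+)
+# One entry per LM scale; with a single scale, take the only value.
+best_path = next(iter(best_path_dict.values()))
+```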
+
+# Decoding with a pre-trained model in action
+
+We have uploaded a pre-trained model to
+<https://huggingface.co/pkufool/icefall_asr_librispeech_tdnn-lstm_ctc>
+
+The following shows how to use the provided pre-trained model.
+
+### (1) Download the pre-trained model
+
+```bash
+sudo apt-get install git-lfs
+cd /path/to/icefall/egs/librispeech/ASR
+git lfs install
+mkdir tmp
+cd tmp
+git clone https://huggingface.co/pkufool/icefall_asr_librispeech_tdnn-lstm_ctc
+```
+
+**CAUTION**: You have to install `git-lfs` to download the pre-trained model.
+
+You will find the following files:
+
+```
+tmp/
+`-- icefall_asr_librispeech_tdnn-lstm_ctc
+    |-- README.md
+    |-- data
+    |   |-- lang_phone
+    |   |   |-- HLG.pt
+    |   |   |-- tokens.txt
+    |   |   `-- words.txt
+    |   `-- lm
+    |       `-- G_4_gram.pt
+    |-- exp
+    |   `-- pretrained.pt
+    `-- test_wavs
+        |-- 1089-134686-0001.flac
+        |-- 1221-135766-0001.flac
+        |-- 1221-135766-0002.flac
+        `-- trans.txt
+
+6 directories, 10 files
+```
+
+**File descriptions**:
+
+  - `data/lang_phone/HLG.pt`
+
+    It is the decoding graph.
+
+  - `data/lang_phone/tokens.txt`
+
+    It contains tokens and their IDs.
+
+  - `data/lang_phone/words.txt`
+
+    It contains words and their IDs.
+
+  - `data/lm/G_4_gram.pt`
+
+    It is a 4-gram LM, useful for LM rescoring.
+
+  - `exp/pretrained.pt`
+
+    It contains pre-trained model parameters, obtained by averaging
+    checkpoints from `epoch-14.pt` to `epoch-19.pt`.
+    Note: We have removed the optimizer `state_dict` to reduce file size.
+
+  - `test_wavs/*.flac`
+
+    It contains some test sound files from the LibriSpeech `test-clean` dataset.
+
+  - `test_wavs/trans.txt`
+
+    It contains the reference transcripts for the sound files in `test_wavs/`.
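+
+All of these test files are already sampled at 16 kHz, which is what
+`pretrained.py` expects. If you want to try your own recordings, check the
+sample rate first and resample if necessary. A minimal sketch using
+torchaudio (the file names are placeholders):
+
+```python
+import torchaudio
+
+wave, sr = torchaudio.load("your-recording.wav")
+if sr != 16000:
+    # Resample to the 16 kHz rate expected by pretrained.py.
+    resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)
+    torchaudio.save("your-recording-16k.wav", resampler(wave), sample_rate=16000)
+```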
+
+The information of the test sound files is listed below:
+
+```
+$ soxi tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/*.flac
+
+Input File     : 'tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac'
+Channels       : 1
+Sample Rate    : 16000
+Precision      : 16-bit
+Duration       : 00:00:06.62 = 106000 samples ~ 496.875 CDDA sectors
+File Size      : 116k
+Bit Rate       : 140k
+Sample Encoding: 16-bit FLAC
+
+
+Input File     : 'tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac'
+Channels       : 1
+Sample Rate    : 16000
+Precision      : 16-bit
+Duration       : 00:00:16.71 = 267440 samples ~ 1253.62 CDDA sectors
+File Size      : 343k
+Bit Rate       : 164k
+Sample Encoding: 16-bit FLAC
+
+
+Input File     : 'tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac'
+Channels       : 1
+Sample Rate    : 16000
+Precision      : 16-bit
+Duration       : 00:00:04.83 = 77200 samples ~ 361.875 CDDA sectors
+File Size      : 105k
+Bit Rate       : 174k
+Sample Encoding: 16-bit FLAC
+
+Total Duration of 3 files: 00:00:28.16
+```
+
+### (2) Use HLG decoding
+
+```bash
+cd /path/to/icefall/egs/librispeech/ASR
+
+./tdnn_lstm_ctc/pretrained.py \
+  --checkpoint ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/exp/pretrained.pt \
+  --words-file ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/data/lang_phone/words.txt \
+  --HLG ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/data/lang_phone/HLG.pt \
+  ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac \
+  ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac \
+  ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac
+```
+
+The output is given below:
+
+```
+2021-08-24 16:57:13,315 INFO [pretrained.py:168] device: cuda:0
+2021-08-24 16:57:13,315 INFO [pretrained.py:170] Creating model
+2021-08-24 16:57:18,331 INFO [pretrained.py:182] Loading HLG from ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/data/lang_phone/HLG.pt
+2021-08-24 16:57:27,581 INFO [pretrained.py:199] Constructing Fbank computer
+2021-08-24 16:57:27,584 INFO [pretrained.py:209] Reading sound files: ['./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac', './tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac', './tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac']
+2021-08-24 16:57:27,599 INFO [pretrained.py:215] Decoding started
+2021-08-24 16:57:27,791 INFO [pretrained.py:245] Use HLG decoding
+2021-08-24 16:57:28,098 INFO [pretrained.py:266]
+./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac:
+AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS
+
+./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac:
+GOD AS A DIRECT CONSEQUENCE OF THE SIN WHICH MAN THUS PUNISHED HAD GIVEN HER A LOVELY CHILD WHOSE PLACE WAS ON THAT SAME DISHONORED BOSOM TO CONNECT HER PARENT FOREVER WITH THE RACE AND DESCENT OF MORTALS AND TO BE FINALLY A BLESSED SOUL IN HEAVEN
+
+./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac:
+YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION
+
+
+2021-08-24 16:57:28,099 INFO [pretrained.py:268] Decoding Done
+```
+
+### (3) Use HLG decoding + LM rescoring
+
+```bash
+./tdnn_lstm_ctc/pretrained.py \
+  --checkpoint ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/exp/pretrained.pt \
+  --words-file ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/data/lang_phone/words.txt \
+  --HLG ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/data/lang_phone/HLG.pt \
+ --method whole-lattice-rescoring \ + --G ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/data/lm/G_4_gram.pt \ + --ngram-lm-scale 0.8 \ + ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac \ + ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac \ + ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac +``` + +The output is: + +``` +2021-08-24 16:39:24,725 INFO [pretrained.py:168] device: cuda:0 +2021-08-24 16:39:24,725 INFO [pretrained.py:170] Creating model +2021-08-24 16:39:29,403 INFO [pretrained.py:182] Loading HLG from ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/data/lang_phone/HLG.pt +2021-08-24 16:39:40,631 INFO [pretrained.py:190] Loading G from ./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/data/lm/G_4_gram.pt +2021-08-24 16:39:53,098 INFO [pretrained.py:199] Constructing Fbank computer +2021-08-24 16:39:53,107 INFO [pretrained.py:209] Reading sound files: ['./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac', './tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac', './tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac'] +2021-08-24 16:39:53,121 INFO [pretrained.py:215] Decoding started +2021-08-24 16:39:53,443 INFO [pretrained.py:250] Use HLG decoding + LM rescoring +2021-08-24 16:39:54,010 INFO [pretrained.py:266] +./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac: +AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS + +./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac: +GOD AS A DIRECT CONSEQUENCE OF THE SIN WHICH MAN THUS PUNISHED HAD GIVEN HER A LOVELY CHILD WHOSE PLACE WAS ON THAT SAME DISHONORED BOSOM TO CONNECT HER PARENT FOREVER WITH THE RACE AND DESCENT OF MORTALS AND TO BE FINALLY A BLESSED SOUL IN HEAVEN + +./tmp/icefall_asr_librispeech_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac: +YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION + + +2021-08-24 16:39:54,010 INFO [pretrained.py:268] Decoding Done +``` + +**NOTE**: We provide a colab notebook for demonstration. +[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1kNmDXNMwREi0rZGAOIAOJo93REBuOTcd?usp=sharing) + +Due to limited memory provided by Colab, you have to upgrade to Colab Pro to run `HLG decoding + LM rescoring`. +Otherwise, you can only run `HLG decoding` with Colab. diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/README.md b/egs/librispeech/ASR/tdnn_lstm_ctc/README.md index df98a0e118..94d4ed6a30 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/README.md +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/README.md @@ -1,2 +1,4 @@ -Will add results later. +Please visit + +for how to run this recipe. diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py index 27e0b96433..26946e0d02 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py @@ -43,6 +43,7 @@ setup_logger, store_transcripts, write_error_stats, + str2bool, ) @@ -54,7 +55,7 @@ def get_parser(): parser.add_argument( "--epoch", type=int, - default=9, + default=19, help="It specifies the checkpoint to use for decoding." "Note: Epoch counts from 0.", ) @@ -66,6 +67,16 @@ def get_parser(): "consecutive checkpoints before the checkpoint specified by " "'--epoch'. 
", ) + parser.add_argument( + "--export", + type=str2bool, + default=False, + help="""When enabled, the averaged model is saved to + tdnn/exp/pretrained.pt. Note: only model.state_dict() is saved. + pretrained.pt contains a dict {"model": model.state_dict()}, + which can be loaded by `icefall.checkpoint.load_checkpoint()`. + """, + ) return parser @@ -408,6 +419,12 @@ def main(): logging.info(f"averaging {filenames}") model.load_state_dict(average_checkpoints(filenames)) + if params.export: + logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") + torch.save( + {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt" + ) + model.to(device) model.eval() diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py new file mode 100644 index 0000000000..4f82a989c7 --- /dev/null +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py @@ -0,0 +1,277 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import torch +import torchaudio +from model import TdnnLstm +from torch.nn.utils.rnn import pad_sequence + +from icefall.decode import ( + get_lattice, + one_best_decoding, + rescore_with_whole_lattice, +) +from icefall.utils import AttributeDict, get_texts + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--words-file", + type=str, + required=True, + help="Path to words.txt", + ) + + parser.add_argument( + "--HLG", type=str, required=True, help="Path to HLG.pt." + ) + + parser.add_argument( + "--method", + type=str, + default="1best", + help="""Decoding method. + Possible values are: + (1) 1best - Use the best path as decoding output. Only + the transformer encoder output is used for decoding. + We call it HLG decoding. + (2) whole-lattice-rescoring - Use an LM to rescore the + decoding lattice and then use 1best to decode the + rescored lattice. + We call it HLG decoding + n-gram LM rescoring. + """, + ) + + parser.add_argument( + "--G", + type=str, + help="""An LM for rescoring. + Used only when method is + whole-lattice-rescoring. + It's usually a 4-gram LM. + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.8, + help=""" + Used only when method is whole-lattice-rescoring. + It specifies the scale for n-gram LM scores. + (Note: You need to tune it on a dataset.) + """, + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. 
" + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + "feature_dim": 80, + "subsampling_factor": 3, + "num_classes": 72, + "sample_rate": 16000, + "search_beam": 20, + "output_beam": 5, + "min_active_states": 30, + "max_active_states": 10000, + "use_double_scores": True, + } + ) + return params + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + return ans + + +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + params.update(vars(args)) + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = TdnnLstm( + num_features=params.feature_dim, + num_classes=params.num_classes, + subsampling_factor=params.subsampling_factor, + ) + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"]) + model.to(device) + model.eval() + + logging.info(f"Loading HLG from {params.HLG}") + HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = HLG.to(device) + if not hasattr(HLG, "lm_scores"): + # For whole-lattice-rescoring and attention-decoder + HLG.lm_scores = HLG.scores.clone() + + if params.method == "whole-lattice-rescoring": + logging.info(f"Loading G from {params.G}") + G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = G.to(device) + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + G.lm_scores = G.scores.clone() + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) + features = features.permute(0, 2, 1) # now features is [N, C, T] + + with torch.no_grad(): + nnet_output = model(features) + # nnet_output is [N, T, C] + + batch_size = nnet_output.shape[0] + supervision_segments = torch.tensor( + [[i, 0, nnet_output.shape[1]] for i in range(batch_size)], + dtype=torch.int32, + ) + + lattice = get_lattice( + nnet_output=nnet_output, + HLG=HLG, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + 
min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.method == "1best": + logging.info("Use HLG decoding") + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + elif params.method == "whole-lattice-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=[params.ngram_lm_scale], + ) + best_path = next(iter(best_path_dict.values())) + + hyps = get_texts(best_path) + word_sym_table = k2.SymbolTable.from_file(params.words_file) + hyps = [[word_sym_table[i] for i in ids] for ids in hyps] + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main()