From cd96f635c3b405b5ae6dcee5fb243e5fc5648a00 Mon Sep 17 00:00:00 2001
From: JinZr
Date: Sat, 12 Oct 2024 16:08:07 +0800
Subject: [PATCH] added text norm for other decoding scripts

---
 egs/libritts/ASR/zipformer/ctc_decode.py       | 12 ++++++------
 egs/libritts/ASR/zipformer/onnx_decode.py      | 11 ++++++-----
 egs/libritts/ASR/zipformer/streaming_decode.py |  6 +++---
 3 files changed, 15 insertions(+), 14 deletions(-)

diff --git a/egs/libritts/ASR/zipformer/ctc_decode.py b/egs/libritts/ASR/zipformer/ctc_decode.py
index 177f2e392d..d77aa59626 100755
--- a/egs/libritts/ASR/zipformer/ctc_decode.py
+++ b/egs/libritts/ASR/zipformer/ctc_decode.py
@@ -122,7 +122,7 @@
 import torch.nn as nn
 from asr_datamodule import LibriTTSAsrDataModule
 from lhotse import set_caching_enabled
-from train import add_model_arguments, get_model, get_params
+from train import add_model_arguments, get_model, get_params, normalize_text
 
 from icefall.checkpoint import (
     average_checkpoints,
@@ -949,13 +949,13 @@ def main():
 
     # we need cut ids to display recognition results.
     args.return_cuts = True
-    librispeech = LibriTTSAsrDataModule(args)
+    libritts = LibriTTSAsrDataModule(args)
 
-    test_clean_cuts = librispeech.test_clean_cuts()
-    test_other_cuts = librispeech.test_other_cuts()
+    test_clean_cuts = libritts.test_clean_cuts().map(normalize_text)
+    test_other_cuts = libritts.test_other_cuts().map(normalize_text)
 
-    test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
-    test_other_dl = librispeech.test_dataloaders(test_other_cuts)
+    test_clean_dl = libritts.test_dataloaders(test_clean_cuts)
+    test_other_dl = libritts.test_dataloaders(test_other_cuts)
 
     test_sets = ["test-clean", "test-other"]
     test_dl = [test_clean_dl, test_other_dl]
diff --git a/egs/libritts/ASR/zipformer/onnx_decode.py b/egs/libritts/ASR/zipformer/onnx_decode.py
index 99a02c5cf3..6f09cc8f7b 100755
--- a/egs/libritts/ASR/zipformer/onnx_decode.py
+++ b/egs/libritts/ASR/zipformer/onnx_decode.py
@@ -80,6 +80,7 @@
 from asr_datamodule import LibriTTSAsrDataModule
 from k2 import SymbolTable
 from onnx_pretrained import OnnxModel, greedy_search
+from train import normalize_text
 
 from icefall.utils import setup_logger, store_transcripts, write_error_stats
 
@@ -290,13 +291,13 @@ def main():
 
     # we need cut ids to display recognition results.
     args.return_cuts = True
-    librispeech = LibriTTSAsrDataModule(args)
+    libritts = LibriTTSAsrDataModule(args)
 
-    test_clean_cuts = librispeech.test_clean_cuts()
-    test_other_cuts = librispeech.test_other_cuts()
+    test_clean_cuts = libritts.test_clean_cuts().map(normalize_text)
+    test_other_cuts = libritts.test_other_cuts().map(normalize_text)
 
-    test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
-    test_other_dl = librispeech.test_dataloaders(test_other_cuts)
+    test_clean_dl = libritts.test_dataloaders(test_clean_cuts)
+    test_other_dl = libritts.test_dataloaders(test_other_cuts)
 
     test_sets = ["test-clean", "test-other"]
     test_dl = [test_clean_dl, test_other_dl]
diff --git a/egs/libritts/ASR/zipformer/streaming_decode.py b/egs/libritts/ASR/zipformer/streaming_decode.py
index 3ecc5c94f1..b210187886 100755
--- a/egs/libritts/ASR/zipformer/streaming_decode.py
+++ b/egs/libritts/ASR/zipformer/streaming_decode.py
@@ -52,7 +52,7 @@
 )
 from torch import Tensor, nn
 from torch.nn.utils.rnn import pad_sequence
-from train import add_model_arguments, get_model, get_params
+from train import add_model_arguments, get_model, get_params, normalize_text
 
 from icefall.checkpoint import (
     average_checkpoints,
@@ -866,8 +866,8 @@ def main():
 
     libritts = LibriTTSAsrDataModule(args)
 
-    test_clean_cuts = libritts.test_clean_cuts()
-    test_other_cuts = libritts.test_other_cuts()
+    test_clean_cuts = libritts.test_clean_cuts().map(normalize_text)
+    test_other_cuts = libritts.test_other_cuts().map(normalize_text)
 
     test_sets = ["test-clean", "test-other"]
     test_cuts = [test_clean_cuts, test_other_cuts]
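
Note for reviewers: normalize_text is imported from train.py, which this patch does not touch, and is applied per cut via lhotse's CutSet.map so that decoding references go through the same text normalization as training. The sketch below illustrates that pattern only; the normalization body shown (uppercasing and stripping punctuation) is an assumption for illustration, not the actual implementation in train.py.

    # Sketch only: the real normalize_text lives in
    # egs/libritts/ASR/zipformer/train.py and may differ in detail.
    import re

    from lhotse.cut import Cut


    def normalize_text(cut: Cut) -> Cut:
        # Assumed behaviour: uppercase each supervision's transcript and
        # strip punctuation so decoding references match training-time text.
        for sup in cut.supervisions:
            text = sup.text.upper()
            text = re.sub(r"[^A-Z' ]", " ", text)  # drop punctuation/digits
            sup.text = re.sub(r"\s+", " ", text).strip()
        return cut


    # CutSet.map applies the transform to every cut, which is the pattern
    # used by the decoding scripts in this patch, e.g.:
    #   test_clean_cuts = libritts.test_clean_cuts().map(normalize_text)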