Skip to content

Commit

Permalink
Added text normalization for the other decoding scripts
Browse files: browse the repository at this point in the history
  • Loading branch information
JinZr committed Oct 12, 2024
1 parent 5492a6a commit cd96f63
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 14 deletions.
12 changes: 6 additions & 6 deletions egs/libritts/ASR/zipformer/ctc_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@
import torch.nn as nn
from asr_datamodule import LibriTTSAsrDataModule
from lhotse import set_caching_enabled
from train import add_model_arguments, get_model, get_params
from train import add_model_arguments, get_model, get_params, normalize_text

from icefall.checkpoint import (
average_checkpoints,
Expand Down Expand Up @@ -949,13 +949,13 @@ def main():

# we need cut ids to display recognition results.
args.return_cuts = True
librispeech = LibriTTSAsrDataModule(args)
libritts = LibriTTSAsrDataModule(args)

test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()
test_clean_cuts = libritts.test_clean_cuts().map(normalize_text)
test_other_cuts = libritts.test_other_cuts().map(normalize_text)

test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
test_clean_dl = libritts.test_dataloaders(test_clean_cuts)
test_other_dl = libritts.test_dataloaders(test_other_cuts)

test_sets = ["test-clean", "test-other"]
test_dl = [test_clean_dl, test_other_dl]
Expand Down
11 changes: 6 additions & 5 deletions egs/libritts/ASR/zipformer/onnx_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
from asr_datamodule import LibriTTSAsrDataModule
from k2 import SymbolTable
from onnx_pretrained import OnnxModel, greedy_search
from train import normalize_text

from icefall.utils import setup_logger, store_transcripts, write_error_stats

Expand Down Expand Up @@ -290,13 +291,13 @@ def main():

# we need cut ids to display recognition results.
args.return_cuts = True
librispeech = LibriTTSAsrDataModule(args)
libritts = LibriTTSAsrDataModule(args)

test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()
test_clean_cuts = libritts.test_clean_cuts().map(normalize_text)
test_other_cuts = libritts.test_other_cuts().map(normalize_text)

test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
test_clean_dl = libritts.test_dataloaders(test_clean_cuts)
test_other_dl = libritts.test_dataloaders(test_other_cuts)

test_sets = ["test-clean", "test-other"]
test_dl = [test_clean_dl, test_other_dl]
Expand Down
6 changes: 3 additions & 3 deletions egs/libritts/ASR/zipformer/streaming_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
)
from torch import Tensor, nn
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_model, get_params
from train import add_model_arguments, get_model, get_params, normalize_text

from icefall.checkpoint import (
average_checkpoints,
Expand Down Expand Up @@ -866,8 +866,8 @@ def main():

libritts = LibriTTSAsrDataModule(args)

test_clean_cuts = libritts.test_clean_cuts()
test_other_cuts = libritts.test_other_cuts()
test_clean_cuts = libritts.test_clean_cuts().map(normalize_text)
test_other_cuts = libritts.test_other_cuts().map(normalize_text)

test_sets = ["test-clean", "test-other"]
test_cuts = [test_clean_cuts, test_other_cuts]
Expand Down

0 comments on commit cd96f63

Please sign in to comment.