Skip to content

Commit 9644c17

Browse files
committed
minor fixes
1 parent b9bbdfa commit 9644c17

File tree

3 files changed

+11
-11
lines changed

3 files changed

+11
-11
lines changed

egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx-zh.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
2. Export the model to ONNX
2929
3030
./lstm_transducer_stateless2/export-onnx-zh.py \
31-
--lang-dir ./icefall-asr-wenetspeech-lstm-transducer-stateless-2022-10-14/data/lang_char \
31+
--tokens ./icefall-asr-wenetspeech-lstm-transducer-stateless-2022-10-14/data/lang_char/tokens.txt \
3232
--use-averaged-model 1 \
3333
--epoch 11 \
3434
--avg 1 \
@@ -55,6 +55,7 @@
5555
from pathlib import Path
5656
from typing import Dict, Optional, Tuple
5757

58+
import k2
5859
import onnx
5960
import torch
6061
import torch.nn as nn
@@ -70,8 +71,7 @@
7071
find_checkpoints,
7172
load_checkpoint,
7273
)
73-
from icefall.lexicon import Lexicon
74-
from icefall.utils import setup_logger, str2bool
74+
from icefall.utils import num_tokens, setup_logger, str2bool
7575

7676

7777
def get_parser():
@@ -128,10 +128,10 @@ def get_parser():
128128
)
129129

130130
parser.add_argument(
131-
"--lang-dir",
131+
"--tokens",
132132
type=str,
133-
default="data/lang_char",
134-
help="The lang dir",
133+
default="data/lang_char/tokens.txt",
134+
help="Path to the tokens.txt.",
135135
)
136136

137137
parser.add_argument(
@@ -441,9 +441,9 @@ def main():
441441

442442
logging.info(f"device: {device}")
443443

444-
lexicon = Lexicon(params.lang_dir)
445-
params.blank_id = 0
446-
params.vocab_size = max(lexicon.tokens) + 1
444+
token_table = k2.SymbolTable.from_file(params.tokens)
445+
params.blank_id = token_table["<blk>"]
446+
params.vocab_size = num_tokens(token_table) + 1
447447

448448
logging.info(params)
449449

egs/swbd/ASR/conformer_ctc/export.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def main():
118118
num_features=params.feature_dim,
119119
nhead=params.nhead,
120120
d_model=params.attention_dim,
121-
num_classes=num_classes,
121+
num_classes=params.vocab_size,
122122
subsampling_factor=params.subsampling_factor,
123123
num_decoder_layers=params.num_decoder_layers,
124124
vgg_frontend=False,

egs/tedlium3/ASR/conformer_ctc2/export.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ def main():
182182

183183
model = Conformer(
184184
num_features=params.feature_dim,
185-
num_classes=num_classes,
185+
num_classes=params.vocab_size,
186186
subsampling_factor=params.subsampling_factor,
187187
d_model=params.dim_model,
188188
nhead=params.nhead,

0 commit comments

Comments
 (0)