
Commit 2e8b321

Add fine-tuned whisper model on aishell (#565)
See also k2-fsa/icefall#1466
1 parent 0b18ccf commit 2e8b321

File tree: 3 files changed, +36 −7 lines changed

.github/workflows/export-whisper-to-onnx.yaml (+17 −3)
@@ -15,9 +15,9 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        os: [macos-latest]
+        os: [ubuntu-latest]
         # model: ["distil-medium.en", "distil-small.en", "tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium", "large", "large-v1", "large-v2", "distil-large-v2"]
-        model: ["distil-medium.en", "distil-small.en", "tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium"]
+        model: ["distil-medium.en", "distil-small.en", "tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium", "medium-aishell"]
         python-version: ["3.8"]

     steps:
@@ -49,16 +49,27 @@ jobs:
           elif [[ $model == distil-small.en ]]; then
             wget -q -O distil-small-en-original-model.bin https://huggingface.co/distil-whisper/distil-small.en/resolve/main/original-model.bin
             ls -lh
+          elif [[ $model == medium-aishell ]]; then
+            wget -q -O medium-aishell.pt https://huggingface.co/yuekai/icefall_asr_aishell_whisper/resolve/main/exp_medium/whisper-medium-aishell1-epoch-10-avg-4.pt
+            ls -lh
           fi
           python3 ./export-onnx.py --model ${{ matrix.model }}
           # python3 -m onnxruntime.tools.convert_onnx_models_to_ort --optimization_style=Fixed ./
+          #
+          if [[ $model == medium-aishell ]]; then
+            ls -lh *.onnx
+            rm -fv medium-aishell-encoder.onnx
+            rm -fv medium-aishell-decoder.onnx
+          fi
+

           ls -lh

           ls -lh ~/.cache/whisper || true
           ls -lh distil*original-model.bin || true
           rm -rf ~/.cache/whisper
           rm -f distil*original-model.bin
+          rm -f medium-aishell.pt

           src=sherpa-onnx-whisper-${{ matrix.model }}
@@ -132,7 +143,10 @@ jobs:
           git config --global user.name "Fangjun Kuang"

           GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-whisper-${{ matrix.model }} huggingface
-          rm -rf huggingface/*
+
+          if [[ $model != medium-aishell ]]; then
+            rm -rf huggingface/*
+          fi

           if [[ $model == large || $model == large-v1 || $model == large-v2 || $model == distil-large-v2 ]]; then
             mv $src.tar* ./huggingface
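
Taken together, the medium-aishell path in this workflow does three things: download the fine-tuned checkpoint from Hugging Face, run the ONNX export, and then delete the float32 encoder/decoder models (presumably leaving only the quantized variants the export also produces). A minimal sketch of the same flow in Python, assuming export-onnx.py is in the working directory; the URL and file names are taken from the diff above, everything else is illustrative:

# Hypothetical local replay of the CI steps above.
import subprocess
from pathlib import Path
from urllib.request import urlretrieve

CKPT_URL = (
    "https://huggingface.co/yuekai/icefall_asr_aishell_whisper/"
    "resolve/main/exp_medium/whisper-medium-aishell1-epoch-10-avg-4.pt"
)

ckpt = Path("medium-aishell.pt")
if not ckpt.is_file():
    urlretrieve(CKPT_URL, ckpt)  # same download as the wget step

# Export the encoder/decoder ONNX models, as the workflow does.
subprocess.run(
    ["python3", "./export-onnx.py", "--model", "medium-aishell"],
    check=True,
)

# Mirror the rm -fv step: drop the float32 models after export.
for f in ("medium-aishell-encoder.onnx", "medium-aishell-decoder.onnx"):
    Path(f).unlink(missing_ok=True)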

scripts/whisper/export-onnx.py (+16 −1)
@@ -44,7 +44,9 @@ def get_args():
             "tiny", "tiny.en", "base", "base.en",
             "small", "small.en", "medium", "medium.en",
             "large", "large-v1", "large-v2",
-            "distil-medium.en", "distil-small.en", "distil-large-v2"
+            "distil-medium.en", "distil-small.en", "distil-large-v2",
+            # for fine-tuned models from icefall
+            "medium-aishell",
         ],
         # fmt: on
     )
@@ -340,6 +342,19 @@ def main():
             """
         )
         model = whisper.load_model(filename)
+    elif name == "medium-aishell":
+        filename = "./medium-aishell.pt"
+        if not Path(filename).is_file():
+            raise ValueError(
+                """
+                Please go to https://huggingface.co/yuekai/icefall_asr_aishell_whisper/tree/main/exp_medium
+                to download whisper-medium-aishell1-epoch-10-avg-4.pt
+                You can use the following command to do that:
+
+                wget -O medium-aishell.pt https://huggingface.co/yuekai/icefall_asr_aishell_whisper/resolve/main/exp_medium/whisper-medium-aishell1-epoch-10-avg-4.pt
+                """
+            )
+        model = whisper.load_model(filename)
     else:
         model = whisper.load_model(name)
     print(model.dims)
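
The added branch relies on the fact that openai-whisper's load_model() accepts a local checkpoint path as well as an official model name, so once medium-aishell.pt is on disk it loads like any stock model. A condensed sketch of the control flow (load_whisper() is a hypothetical helper; the file names come from the diff):

from pathlib import Path

import whisper  # pip install openai-whisper


def load_whisper(name: str):
    # load_model() takes an official model name or a checkpoint path;
    # the fine-tuned AIShell model goes through the path route.
    if name == "medium-aishell":
        filename = "./medium-aishell.pt"
        if not Path(filename).is_file():
            raise ValueError(f"{filename} not found; download it first")
        return whisper.load_model(filename)
    return whisper.load_model(name)


model = load_whisper("medium-aishell")
print(model.dims)  # same sanity print as in the script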

scripts/whisper/test.py (+3 −3)
@@ -257,9 +257,9 @@ def compute_features(filename: str) -> torch.Tensor:
     mel = (log_spec + 4.0) / 4.0
     # mel (T, 80)

-    # We pad 50 frames at the end so that it is able to detect eot
-    # You can use another value instead of 50.
-    mel = torch.nn.functional.pad(mel, (0, 0, 0, 1000), "constant", 0)
+    # We pad 1500 frames at the end so that it is able to detect eot
+    # You can use another value instead of 1500.
+    mel = torch.nn.functional.pad(mel, (0, 0, 0, 1500), "constant", 0)
     # Note that if it throws for a multilingual model,
     # please use a larger value, say 300
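
To see concretely what the padding change does: with mel of shape (T, 80), the pad spec (0, 0, 0, 1500) leaves the 80 feature bins untouched and appends 1500 zero frames along the time axis, giving the decoder trailing silence in which to emit the end-of-text token. A small self-contained check (the frame count is illustrative):

import torch

T = 700                    # illustrative number of frames
mel = torch.randn(T, 80)   # (T, 80) log-mel features, as in test.py
padded = torch.nn.functional.pad(mel, (0, 0, 0, 1500), "constant", 0)
assert padded.shape == (T + 1500, 80)  # 1500 zero frames appended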
