Whisper Fine-tuning Recipe on Aishell1 #1466

Merged
merged 34 commits into from
Jan 26, 2024
34 commits
f99f4d7
add decode seamlessm4t
yuekaizhang Sep 5, 2023
363c3f1
update finetuning codes
yuekaizhang Sep 7, 2023
3a7ad27
add requirements
yuekaizhang Sep 7, 2023
cbc3852
add fairseq2 require
yuekaizhang Sep 7, 2023
0d6d8f9
update fine-tuning lr
yuekaizhang Sep 7, 2023
e815457
update decoding from checkpoint
yuekaizhang Sep 7, 2023
5f399dc
load checkpoint to decode
yuekaizhang Sep 7, 2023
cc64324
add decoding with avg model
yuekaizhang Sep 7, 2023
72e9a43
fix typo
yuekaizhang Sep 8, 2023
7e387dd
change vocab table
yuekaizhang Sep 8, 2023
22ee287
add token files
yuekaizhang Sep 8, 2023
2a288fb
add custom tokenizer
yuekaizhang Sep 8, 2023
d926585
fix loading
yuekaizhang Sep 8, 2023
bb1c446
rename train, train2, add support to fine-tune embedding table
yuekaizhang Sep 12, 2023
6c2cd5b
support whisper ft
yuekaizhang Sep 26, 2023
5bf3a9c
using audio with any length
yuekaizhang Sep 26, 2023
8b832f1
update lhotse version
yuekaizhang Jan 9, 2024
07cefa8
change scaleadam to adamw
yuekaizhang Jan 11, 2024
98d11ab
remove padding to 30s, compute validation loss once
yuekaizhang Jan 11, 2024
92895f7
clean up codes
yuekaizhang Jan 11, 2024
b6418ac
support deepspeed to finetune large model
yuekaizhang Jan 12, 2024
fa7ad4d
update deepspeed model loading
yuekaizhang Jan 12, 2024
2ce0980
support large-v3
yuekaizhang Jan 14, 2024
ac53222
add model saving
yuekaizhang Jan 15, 2024
e883bb6
remove seamless for next PR
yuekaizhang Jan 15, 2024
eea4645
revert asr data module
yuekaizhang Jan 15, 2024
557b35c
clean codes
yuekaizhang Jan 15, 2024
84e4af9
add whisper fine-tuning results
yuekaizhang Jan 17, 2024
bda4829
using monkey patch to replace models
yuekaizhang Jan 22, 2024
b623c3b
fix requirements
yuekaizhang Jan 22, 2024
8d9ab30
fix lint
yuekaizhang Jan 22, 2024
ab08201
remove model file
yuekaizhang Jan 22, 2024
46605ea
fix wrong order of token slice
yuekaizhang Jan 22, 2024
fd4ebf3
add manifest dir option
yuekaizhang Jan 25, 2024
7 changes: 7 additions & 0 deletions egs/aishell/ASR/README.md
Original file line number Diff line number Diff line change
@@ -24,3 +24,10 @@ The following table lists the differences among them.
The decoder in `transducer_stateless` is modified from the paper
[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
We place an additional Conv1d layer right after the input embedding layer.

# Whisper

Recipe to fine-tune large pretrained models

|           | Encoder     | Decoder     | Comment                             |
|-----------|-------------|-------------|-------------------------------------|
| `whisper` | Transformer | Transformer | Supports fine-tuning with DeepSpeed |
67 changes: 62 additions & 5 deletions egs/aishell/ASR/RESULTS.md
@@ -1,5 +1,63 @@
## Results

### Aishell training results (Fine-tuning Pretrained Models)
#### Whisper
[./whisper](./whisper)
##### Fine-tuning results on the Aishell test set for Whisper medium, large-v2, and large-v3

| Model    | test CER % (before fine-tuning) | test CER % (after fine-tuning) | comment                                   |
|----------|---------------------------------|--------------------------------|-------------------------------------------|
| medium   | 7.23                            | 3.27                           | --epoch 10 --avg 4, ddp                   |
| large-v2 | 6.56                            | 2.47                           | --epoch 10 --avg 6, deepspeed zero stage1 |
| large-v3 | 6.06                            | 2.84                           | --epoch 5 --avg 3, deepspeed zero stage1  |
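
To put the table in perspective, the relative error reduction from fine-tuning can be computed directly from its two numeric columns. The sketch below assumes the figures are character error rates in percent; the helper name `relative_reduction` is illustrative and not part of the recipe.

```python
# (before, after) fine-tuning values copied from the results table.
# Assumption: they are character error rates (CER) in percent.
results = {
    "medium": (7.23, 3.27),
    "large-v2": (6.56, 2.47),
    "large-v3": (6.06, 2.84),
}


def relative_reduction(before: float, after: float) -> float:
    """Return the relative error reduction in percent."""
    return (before - after) / before * 100.0


for name, (before, after) in results.items():
    print(f"{name}: {relative_reduction(before, after):.1f}% relative reduction")
```

Under that assumption, each of the three models loses more than half of its original error after fine-tuning.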

Command for training is:
```bash
pip install -r whisper/requirements.txt

./prepare.sh --stage 30 --stop_stage 30

# fine-tuning with deepspeed zero stage 1
torchrun --nproc-per-node 8 ./whisper/train.py \
  --max-duration 200 \
  --exp-dir whisper/exp_large_v2 \
  --model-name large-v2 \
  --deepspeed \
  --deepspeed_config ./whisper/ds_config_zero1.json

# fine-tuning with ddp
torchrun --nproc-per-node 8 ./whisper/train.py \
  --max-duration 200 \
  --exp-dir whisper/exp_medium \
  --base-lr 1e-5 \
  --model-name medium
```
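
The `--deepspeed_config` flag in the training command points at `./whisper/ds_config_zero1.json`, whose contents are not shown in this section. As a rough sketch only, a ZeRO stage 1 configuration could be generated as follows. The keys used are standard DeepSpeed configuration keys, but the chosen values are assumptions and the file shipped with the PR may differ.

```python
import json

# Hypothetical ZeRO stage 1 settings; the PR's ds_config_zero1.json may differ.
ds_config = {
    "fp16": {"enabled": True},            # assumed: mixed-precision training
    "zero_optimization": {"stage": 1},    # stage 1 partitions optimizer states
    "train_micro_batch_size_per_gpu": 1,  # assumed placeholder value
}

# Serialize to the JSON text that would be written to the config file.
config_text = json.dumps(ds_config, indent=2)
print(config_text)
```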

Command for decoding using fine-tuned models:
```bash
git lfs install
git clone https://huggingface.co/yuekai/icefall_asr_aishell_whisper
# Use an absolute target: a relative target is resolved against the link's own directory.
mkdir -p whisper/exp_large_v2
ln -s "$(pwd)/icefall_asr_aishell_whisper/exp_large_v2/epoch-10-avg6.pt" whisper/exp_large_v2/epoch-999.pt

python3 ./whisper/decode.py \
  --exp-dir whisper/exp_large_v2 \
  --model-name large-v2 \
  --epoch 999 --avg 1 \
  --beam-size 10 --max-duration 50
```
Command for decoding using pretrained models (before fine-tuning):
```bash
python3 ./whisper/decode.py \
  --exp-dir whisper/exp_large_v2 \
  --model-name large-v2 \
  --epoch -1 --avg 1 \
  --remove-whisper-encoder-input-length-restriction False \
  --beam-size 10 --max-duration 50
```
Fine-tuned models, training logs, decoding logs, tensorboard and decoding results
are available at
<https://huggingface.co/yuekai/icefall_asr_aishell_whisper>

### Aishell training result (Stateless Transducer)

#### Zipformer (Non-streaming)
@@ -19,7 +77,7 @@ It's reworked Zipformer with Pruned RNNT loss.

Command for training is:
```bash
./prepare.sh

export CUDA_VISIBLE_DEVICES="0,1"

@@ -84,7 +142,7 @@ export CUDA_VISIBLE_DEVICES="0,1"
--feedforward-dim 512,768,768,768,768,768 \
--encoder-dim 192,256,256,256,256,256 \
--encoder-unmasked-dim 192,192,192,192,192,192 \
--max-duration 1200
```

Command for decoding is:
@@ -134,7 +192,7 @@ export CUDA_VISIBLE_DEVICES="0,1"
--feedforward-dim 512,768,1536,2048,1536,768 \
--encoder-dim 192,256,512,768,512,256 \
--encoder-unmasked-dim 192,192,256,320,256,192 \
--max-duration 800
```

Command for decoding is:
@@ -150,7 +208,7 @@ for m in greedy_search modified_beam_search fast_beam_search ; do
--num-encoder-layers 2,2,4,5,4,2 \
--feedforward-dim 512,768,1536,2048,1536,768 \
--encoder-dim 192,256,512,768,512,256 \
--encoder-unmasked-dim 192,192,256,320,256,192
done
```

@@ -703,7 +761,6 @@ python3 ./transducer_stateless/decode.py \
--max-sym-per-frame 3
```

### Aishell training results (Transducer-stateless)
#### 2022-02-18
(Pingfeng Luo) : The tensorboard log for training is available at <https://tensorboard.dev/experiment/k3QL6QMhRbCwCKYKM9po9w/>
And pretrained model is available at <https://huggingface.co/pfluo/icefall-aishell-transducer-stateless-char-2021-12-29>
45 changes: 38 additions & 7 deletions egs/aishell/ASR/local/compute_fbank_aishell.py
@@ -29,7 +29,14 @@
from pathlib import Path

import torch
-from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
+from lhotse import (
+    CutSet,
+    Fbank,
+    FbankConfig,
+    LilcomChunkyWriter,
+    WhisperFbank,
+    WhisperFbankConfig,
+)
from lhotse.recipes.utils import read_manifests_if_cached

from icefall.utils import get_executor, str2bool
@@ -42,9 +49,14 @@
torch.set_num_interop_threads(1)


-def compute_fbank_aishell(num_mel_bins: int = 80, perturb_speed: bool = False):
+def compute_fbank_aishell(
+    num_mel_bins: int = 80,
+    perturb_speed: bool = False,
+    whisper_fbank: bool = False,
+    output_dir: str = "data/fbank",
+):
     src_dir = Path("data/manifests")
-    output_dir = Path("data/fbank")
+    output_dir = Path(output_dir)
     num_jobs = min(15, os.cpu_count())

     dataset_parts = (
Expand All @@ -68,8 +80,12 @@ def compute_fbank_aishell(num_mel_bins: int = 80, perturb_speed: bool = False):
         list(manifests.keys()),
         dataset_parts,
     )
-
-    extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
+    if whisper_fbank:
+        extractor = WhisperFbank(
+            WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
+        )
+    else:
+        extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))

     with get_executor() as ex:  # Initialize the executor only once.
         for partition, m in manifests.items():
Expand All @@ -82,7 +98,7 @@ def compute_fbank_aishell(num_mel_bins: int = 80, perturb_speed: bool = False):
                 supervisions=m["supervisions"],
             )
             if "train" in partition and perturb_speed:
-                logging.info(f"Doing speed perturb")
+                logging.info("Doing speed perturb")
                 cut_set = (
                     cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
                 )
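
The speed-perturbation branch in this hunk concatenates the original cuts with copies at speed factors 0.9 and 1.1, so the training set ends up with three times as many cuts. The sketch below illustrates the effect on total audio duration using plain numbers instead of Lhotse objects. The helper `perturbed_durations` and the example durations are hypothetical, and it assumes a speed factor `f` changes a duration `d` to `d / f`.

```python
def perturbed_durations(durations, factors=(0.9, 1.1)):
    """Return the original durations followed by one copy per speed factor.

    Assumption: playing audio at speed factor f changes its duration to d / f.
    """
    out = list(durations)
    for f in factors:
        out.extend(d / f for d in durations)
    return out


original = [4.5, 9.0]  # hypothetical utterance lengths in seconds
augmented = perturbed_durations(original)

# Three times as many cuts, and slightly more than three times the audio.
print(len(augmented), round(sum(augmented) / sum(original), 3))
```

Because the slowed-down copy gains more time than the sped-up copy loses, the total duration grows by a little more than a factor of three, which matters when estimating feature storage and epoch length.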
@@ -111,6 +127,18 @@ def get_args():
         default=False,
         help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
     )
+    parser.add_argument(
+        "--whisper-fbank",
+        type=str2bool,
+        default=False,
+        help="Use WhisperFbank instead of Fbank. Default: False.",
+    )
+    parser.add_argument(
+        "--output-dir",
+        type=str,
+        default="data/fbank",
+        help="Output directory. Default: data/fbank.",
+    )
     return parser.parse_args()
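
The two options added in this hunk can be exercised in isolation with a small stand-alone parser. This is a sketch, not the recipe's code: the local `str2bool` is a stand-in whose behaviour is assumed to resemble `icefall.utils.str2bool`, and `build_parser` is a hypothetical helper.

```python
import argparse


def str2bool(value: str) -> bool:
    """Stand-in for icefall.utils.str2bool (assumed behaviour, not the library code)."""
    if value.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if value.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}")


def build_parser() -> argparse.ArgumentParser:
    """Build a parser with the options relevant to Whisper feature extraction."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--num-mel-bins", type=int, default=80)
    parser.add_argument("--whisper-fbank", type=str2bool, default=False)
    parser.add_argument("--output-dir", type=str, default="data/fbank")
    return parser


# Same flags that stage 30 of prepare.sh passes to the script.
args = build_parser().parse_args(
    ["--num-mel-bins", "80", "--whisper-fbank", "true", "--output-dir", "data/fbank_whisper"]
)
print(args.whisper_fbank, args.output_dir)
```

Keeping both defaults at their previous values (`False` and `data/fbank`) means existing callers of the script see no change in behaviour.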


@@ -121,5 +149,8 @@ def get_args():

     args = get_args()
     compute_fbank_aishell(
-        num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed
+        num_mel_bins=args.num_mel_bins,
+        perturb_speed=args.perturb_speed,
+        whisper_fbank=args.whisper_fbank,
+        output_dir=args.output_dir,
     )
13 changes: 13 additions & 0 deletions egs/aishell/ASR/prepare.sh
@@ -376,3 +376,16 @@ if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
--vocab-size 4336 \
--master-port 12345
fi

# Whisper large-v3 uses 128 mel bins; the other Whisper models use 80.
whisper_mel_bins=80
output_dir=data/fbank_whisper
if [ $stage -le 30 ] && [ $stop_stage -ge 30 ]; then
  log "Stage 30: Compute ${whisper_mel_bins} dim fbank for whisper model fine-tuning"
  if [ ! -f $output_dir/.aishell.whisper.done ]; then
    mkdir -p $output_dir
    ./local/compute_fbank_aishell.py --perturb-speed ${perturb_speed} --num-mel-bins ${whisper_mel_bins} --whisper-fbank true --output-dir $output_dir
    ./local/compute_fbank_musan.py --num-mel-bins ${whisper_mel_bins} --whisper-fbank true --output-dir $output_dir
    touch $output_dir/.aishell.whisper.done
  fi
fi
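
Stage 30 hard-codes `whisper_mel_bins=80`, so it has to be edited by hand before fine-tuning large-v3. The rule stated in the comment can be captured in a small helper. The function `whisper_mel_bins` below is hypothetical and only encodes what the comment says; it is not part of the recipe.

```python
def whisper_mel_bins(model_name: str) -> int:
    """Return the mel-bin count per the prepare.sh comment: 128 for large-v3, else 80."""
    return 128 if model_name == "large-v3" else 80


for name in ("medium", "large-v2", "large-v3"):
    print(name, whisper_mel_bins(name))
```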
1 change: 1 addition & 0 deletions egs/aishell/ASR/whisper/asr_datamodule.py