Skip to content

Commit 4b33563

Browse files
committed
More fixes to gigaspeech recipe
1 parent 2addc6c commit 4b33563

File tree

7 files changed

+75
-887
lines changed

7 files changed

+75
-887
lines changed

egs/gigaspeech/ASR/zipformer/train.py

+19-8
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,17 @@ def get_parser():
416416
help="Accumulate stats on activations, print them and exit.",
417417
)
418418

419+
parser.add_argument(
420+
"--scan-for-oom-batches",
421+
type=str2bool,
422+
default=False,
423+
help="""
424+
Whether to scan for oom batches before training, this is helpful for
425+
finding the suitable max_duration, you only need to run it once.
426+
Caution: a little time consuming.
427+
""",
428+
)
429+
419430
parser.add_argument(
420431
"--inf-check",
421432
type=str2bool,
@@ -1197,14 +1208,14 @@ def remove_short_utt(c: Cut):
11971208
valid_cuts = valid_cuts.filter(remove_short_utt)
11981209
valid_dl = gigaspeech.valid_dataloaders(valid_cuts)
11991210

1200-
# if not params.print_diagnostics:
1201-
# scan_pessimistic_batches_for_oom(
1202-
# model=model,
1203-
# train_dl=train_dl,
1204-
# optimizer=optimizer,
1205-
# sp=sp,
1206-
# params=params,
1207-
# )
1211+
if not params.print_diagnostics and params.scan_for_oom_batches:
1212+
scan_pessimistic_batches_for_oom(
1213+
model=model,
1214+
train_dl=train_dl,
1215+
optimizer=optimizer,
1216+
sp=sp,
1217+
params=params,
1218+
)
12081219

12091220
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
12101221
if checkpoints and "grad_scaler" in checkpoints:

egs/gigaspeech/KWS/zipformer/asr_datamodule.py

+1-8
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Copyright 2021 Piotr Żelasko
2-
# Copyright 2023 Xiaomi Corporation (Author: Yifan Yang)
2+
# Copyright 2024 Xiaomi Corporation (Author: Wei Kang)
33
#
44
# See ../../../../LICENSE for clarification regarding multiple authors
55
#
@@ -448,13 +448,6 @@ def test_cuts(self) -> CutSet:
448448
self.args.manifest_dir / "gigaspeech_cuts_TEST.jsonl.gz"
449449
)
450450

451-
@lru_cache()
452-
def libri_100_cuts(self) -> CutSet:
453-
logging.info("About to get libri100 cuts")
454-
return load_manifest_lazy(
455-
self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz"
456-
)
457-
458451
@lru_cache()
459452
def fsc_train_cuts(self) -> CutSet:
460453
logging.info("About to get fluent speech commands train cuts")

egs/gigaspeech/KWS/zipformer/decode.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,7 @@ def decode_one_batch(
274274
model=model,
275275
encoder_out=encoder_out,
276276
encoder_out_lens=encoder_out_lens,
277-
context_graph=kws_graph,
277+
keywords_graph=kws_graph,
278278
beam=params.beam,
279279
num_tailing_blanks=params.num_tailing_blanks,
280280
blank_penalty=params.blank_penalty,

egs/gigaspeech/KWS/zipformer/decode-asr.py → egs/gigaspeech/KWS/zipformer/decode_asr.py (renamed)

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
#!/usr/bin/env python3
22
#
3-
# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
4-
# Zengwei Yao)
3+
# Copyright 2021-2024 Xiaomi Corporation (Author: Fangjun Kuang,
4+
# Zengwei Yao,
5+
# Wei Kang)
56
#
67
# See ../../../../LICENSE for clarification regarding multiple authors
78
#

0 commit comments

Comments (0)