Commit 4b33563 1 parent 2addc6c commit 4b33563 Copy full SHA for 4b33563
File tree 7 files changed +75
-887
lines changed
librispeech/ASR/pruned_transducer_stateless2
7 files changed +75
-887
lines changed Original file line number Diff line number Diff line change @@ -416,6 +416,17 @@ def get_parser():
416
416
help = "Accumulate stats on activations, print them and exit." ,
417
417
)
418
418
419
+ parser .add_argument (
420
+ "--scan-for-oom-batches" ,
421
+ type = str2bool ,
422
+ default = False ,
423
+ help = """
424
+ Whether to scan for oom batches before training, this is helpful for
425
+ finding the suitable max_duration, you only need to run it once.
426
+ Caution: a little time consuming.
427
+ """ ,
428
+ )
429
+
419
430
parser .add_argument (
420
431
"--inf-check" ,
421
432
type = str2bool ,
@@ -1197,14 +1208,14 @@ def remove_short_utt(c: Cut):
1197
1208
valid_cuts = valid_cuts .filter (remove_short_utt )
1198
1209
valid_dl = gigaspeech .valid_dataloaders (valid_cuts )
1199
1210
1200
- # if not params.print_diagnostics:
1201
- # scan_pessimistic_batches_for_oom(
1202
- # model=model,
1203
- # train_dl=train_dl,
1204
- # optimizer=optimizer,
1205
- # sp=sp,
1206
- # params=params,
1207
- # )
1211
+ if not params .print_diagnostics and params . scan_for_oom_batches :
1212
+ scan_pessimistic_batches_for_oom (
1213
+ model = model ,
1214
+ train_dl = train_dl ,
1215
+ optimizer = optimizer ,
1216
+ sp = sp ,
1217
+ params = params ,
1218
+ )
1208
1219
1209
1220
scaler = GradScaler (enabled = params .use_fp16 , init_scale = 1.0 )
1210
1221
if checkpoints and "grad_scaler" in checkpoints :
Original file line number Diff line number Diff line change 1
1
# Copyright 2021 Piotr Żelasko
2
- # Copyright 2023 Xiaomi Corporation (Author: Yifan Yang )
2
+ # Copyright 2024 Xiaomi Corporation (Author: Wei Kang )
3
3
#
4
4
# See ../../../../LICENSE for clarification regarding multiple authors
5
5
#
@@ -448,13 +448,6 @@ def test_cuts(self) -> CutSet:
448
448
self .args .manifest_dir / "gigaspeech_cuts_TEST.jsonl.gz"
449
449
)
450
450
451
- @lru_cache ()
452
- def libri_100_cuts (self ) -> CutSet :
453
- logging .info ("About to get libri100 cuts" )
454
- return load_manifest_lazy (
455
- self .args .manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz"
456
- )
457
-
458
451
@lru_cache ()
459
452
def fsc_train_cuts (self ) -> CutSet :
460
453
logging .info ("About to get fluent speech commands train cuts" )
Original file line number Diff line number Diff line change @@ -274,7 +274,7 @@ def decode_one_batch(
274
274
model = model ,
275
275
encoder_out = encoder_out ,
276
276
encoder_out_lens = encoder_out_lens ,
277
- context_graph = kws_graph ,
277
+ keywords_graph = kws_graph ,
278
278
beam = params .beam ,
279
279
num_tailing_blanks = params .num_tailing_blanks ,
280
280
blank_penalty = params .blank_penalty ,
Original file line number Diff line number Diff line change 1
1
#!/usr/bin/env python3
2
2
#
3
- # Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
4
- # Zengwei Yao)
3
+ # Copyright 2021-2024 Xiaomi Corporation (Author: Fangjun Kuang,
4
+ # Zengwei Yao,
5
+ # Wei Kang)
5
6
#
6
7
# See ../../../../LICENSE for clarification regarding multiple authors
7
8
#
You can’t perform that action at this time.
0 commit comments