Skip to content

Commit

Permalink
Fix torch.jit.script() export for pruned_transducer_stateless2 (#1410)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj authored Dec 10, 2023
1 parent df56aff commit b0f70c9
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 0 deletions.
2 changes: 2 additions & 0 deletions egs/librispeech/ASR/pruned_transducer_stateless2/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@

import k2
import torch
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model

from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
Expand Down Expand Up @@ -198,6 +199,7 @@ def main():
model.eval()

if params.jit:
convert_scaled_to_non_scaled(model, inplace=True)
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
Expand Down
1 change: 1 addition & 0 deletions egs/librispeech/ASR/pruned_transducer_stateless2/lstmp.py

0 comments on commit b0f70c9

Please sign in to comment.