diff --git a/icefall/decode.py b/icefall/decode.py index b17de0ba79..dd3af1e99b 100644 --- a/icefall/decode.py +++ b/icefall/decode.py @@ -1475,21 +1475,10 @@ def rescore_with_rnn_lm( return ans -def remove_duplicates_and_blank(hyp: List[int]) -> List[int]: - # from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/common.py - new_hyp: List[int] = [] - cur = 0 - while cur < len(hyp): - if hyp[cur] != 0: - new_hyp.append(hyp[cur]) - prev = cur - while cur < len(hyp) and hyp[cur] == hyp[prev]: - cur += 1 - return new_hyp - - def ctc_greedy_search( - ctc_output: torch.Tensor, encoder_out_lens: torch.Tensor + ctc_output: torch.Tensor, + encoder_out_lens: torch.Tensor, + blank_id: int = 0, ) -> List[List[int]]: """CTC greedy search. @@ -1501,6 +1490,10 @@ def ctc_greedy_search( """ batch = ctc_output.shape[0] index = ctc_output.argmax(dim=-1) # (batch, seq_len) - hyps = [index[i].tolist()[:encoder_out_lens[i]] for i in range(batch)] - hyps = [remove_duplicates_and_blank(hyp) for hyp in hyps] + hyps = [ + torch.unique_consecutive(index[i, : encoder_out_lens[i]]) for i in range(batch) + ] + + hyps = [h[h != blank_id].tolist() for h in hyps] + return hyps diff --git a/test/test_ctc_greedy_search.py b/test/test_ctc_greedy_search.py new file mode 100755 index 0000000000..a82b2d8f13 --- /dev/null +++ b/test/test_ctc_greedy_search.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 + +import torch + +from icefall.decode import ctc_greedy_search + + +def test(): + log_probs = torch.tensor( + [ + [ + [10, 1, 2, 1, 1, 3, 2, 3], + [10, 3, 2, 2, 1, 3, 2, 3], + [1, 10, 2, 2, 1, 3, 2, 3], + [1, 10, 2, 2, 1, 3, 2, 3], + [1, 1, 10, 1, 1, 3, 2, 3], + [10, 1, 1, 1, 1, 3, 2, 3], + [1, 1, 1, 10, 1, 3, 2, 3], + ], + [ + [10, 1, 2, 1, 1, 3, 2, 3], + [10, 3, 2, 2, 1, 3, 2, 3], + [1, 10, 2, 2, 1, 3, 2, 3], + [1, 10, 2, 2, 1, 3, 2, 3], + [1, 1, 10, 1, 1, 3, 2, 3], + [10, 1, 1, 1, 1, 3, 2, 3], + [1, 1, 1, 10, 1, 3, 2, 3], + ], + ], + dtype=torch.float32, + ).log_softmax(dim=-1) + + log_probs_length = torch.tensor([7, 6]) + + hyps = ctc_greedy_search(log_probs, log_probs_length) + + assert hyps[0] == [1, 2, 3], hyps[0] + assert hyps[1] == [1, 2], hyps[1] + + +if __name__ == "__main__": + test()