Commit

refactor codes
yaozengwei committed Oct 7, 2024
1 parent a6eead6 commit ae59e5d
Showing 4 changed files with 47 additions and 378 deletions.
egs/librispeech/ASR/zipformer/model.py (56 changes: 3 additions & 53 deletions)
@@ -24,8 +24,8 @@
 from encoder_interface import EncoderInterface
 from scaling import ScaledLinear
 
-from icefall.utils import add_sos, make_pad_mask
-from spec_augment import SpecAugment, time_warp
+from icefall.utils import add_sos, make_pad_mask, time_warp
+from lhotse.dataset import SpecAugment
 
 
 class AsrModel(nn.Module):
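
Note: after this hunk, time_warp is imported from icefall.utils and SpecAugment from lhotse.dataset, replacing the recipe-local spec_augment module. A minimal sketch of how a caller might construct the Lhotse transform (the keyword values are illustrative, not taken from this commit):

    from lhotse.dataset import SpecAugment

    # Warping is handled separately by icefall.utils.time_warp, so it is disabled
    # here; the remaining keywords control frequency and time masking.
    spec_augment = SpecAugment(
        time_warp_factor=0,     # disable Lhotse's built-in time warping
        num_feature_masks=2,    # frequency masks per utterance
        features_mask_size=27,  # maximum width of each frequency mask
        num_frame_masks=10,     # time masks per utterance
        frames_mask_size=100,   # maximum width of each time mask
    )
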
@@ -188,8 +188,6 @@ def forward_cr_ctc(
         encoder_out_lens: torch.Tensor,
         targets: torch.Tensor,
         target_lengths: torch.Tensor,
-        time_mask: Optional[torch.Tensor] = None,
-        cr_loss_masked_scale: float = 1.0,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Compute CTC loss with consistency regularization loss.
         Args:
@@ -200,10 +198,6 @@ def forward_cr_ctc(
           targets:
             Target Tensor of shape (2 * sum(target_lengths)). The targets are assumed
             to be un-padded and concatenated within 1 dimension.
-          time_mask:
-            Downsampled time masks of shape (2 * N, T, 1).
-          cr_loss_masked_scale:
-            The loss scale used to scale up the cr_loss at masked positions.
         """
         # Compute CTC loss
         ctc_output = self.ctc_output(encoder_out)  # (2 * N, T, C)
@@ -226,14 +220,6 @@ def forward_cr_ctc(
             reduction="none",
             log_target=True,
         )  # (2 * N, T, C)
-        if time_mask is not None:
-            assert time_mask.shape[:-1] == ctc_output.shape[:-1], (
-                time_mask.shape, ctc_output.shape
-            )
-            masked_scale = time_mask * (cr_loss_masked_scale - 1) + 1
-            # e.g., if cr_loss_masked_scale = 3, scales at masked positions are 3,
-            # scales at unmasked positions are 1
-            cr_loss = cr_loss * masked_scale  # scaling up masked positions
         length_mask = make_pad_mask(encoder_out_lens).unsqueeze(-1)
         cr_loss = cr_loss.masked_fill(length_mask, 0.0).sum()

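Note: the cr_loss above is a KL divergence between the CTC posteriors of the two augmented copies stacked along the batch dimension. A rough sketch of how the exchanged targets are built (paraphrased from the surrounding recipe code, which is not part of this hunk; treat it as an approximation rather than the exact implementation):

    # ctc_output: (2 * N, T, C) log-probs; rows [0, N) and [N, 2N) are the two
    # independently augmented copies of the same N utterances.
    exchanged_targets = ctc_output.detach().chunk(2, dim=0)
    exchanged_targets = torch.cat(
        [exchanged_targets[1], exchanged_targets[0]], dim=0
    )  # swap halves so each copy is supervised by the other one
    cr_loss = nn.functional.kl_div(
        input=ctc_output,
        target=exchanged_targets,
        reduction="none",
        log_target=True,  # targets are log-probs as well
    )  # (2 * N, T, C)
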
@@ -359,7 +345,6 @@ def forward(
         spec_augment: Optional[SpecAugment] = None,
         supervision_segments: Optional[torch.Tensor] = None,
         time_warp_factor: Optional[int] = 80,
-        cr_loss_masked_scale: float = 1.0,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Args:
@@ -395,8 +380,6 @@ def forward(
             Parameter for the time warping; larger values mean more warping.
             Set to ``None``, or less than ``1``, to disable.
             Used only if use_cr_ctc is True.
-          cr_loss_masked_scale:
-            The loss scale used to scale up the cr_loss at masked positions.
         Returns:
           Return the transducer losses, CTC loss, AED loss,
@@ -429,12 +412,9 @@ def forward(
                     supervision_segments=supervision_segments,
                 )
                 # Independently apply frequency masking and time masking to the two copies
-                x, time_mask = spec_augment(x.repeat(2, 1, 1))
-                # time_mask: 1 for masked, 0 for unmasked
-                time_mask = downsample_time_mask(time_mask, x.dtype)
+                x = spec_augment(x.repeat(2, 1, 1))
             else:
                 x = x.repeat(2, 1, 1)
-                time_mask = None
             x_lens = x_lens.repeat(2)
             y = k2.ragged.cat([y, y], axis=0)

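Note: the recipe-local spec_augment returned both the augmented features and a time mask, which was then downsampled and used to re-weight cr_loss; Lhotse's SpecAugment returns only the augmented features, so the call site shrinks to a single assignment and the time_mask plumbing disappears. For clarity, x.repeat(2, 1, 1) stacks two identical copies along the batch dimension before the copies are augmented independently:

    import torch

    x = torch.randn(4, 100, 80)          # (N, T, F) fbank features
    x2 = x.repeat(2, 1, 1)               # (2 * N, T, F)
    assert torch.equal(x2[:4], x2[4:])   # identical copies, prior to augmentation
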
@@ -479,8 +459,6 @@ def forward(
                     encoder_out_lens=encoder_out_lens,
                     targets=targets,
                     target_lengths=y_lens,
-                    time_mask=time_mask,
-                    cr_loss_masked_scale=cr_loss_masked_scale,
                 )
                 ctc_loss = ctc_loss * 0.5
                 cr_loss = cr_loss * 0.5
@@ -501,31 +479,3 @@ def forward(
             attention_decoder_loss = torch.empty(0)
 
         return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss
-
-
-def downsample_time_mask(time_mask: torch.Tensor, dtype: torch.dtype):
-    """Downsample the time masks as in Zipformer.
-    Args:
-      time_mask: shape of (N, T)
-    Returns:
-      The downsampled time masks of shape (N, T', 1),
-      where T' = ((T - 7) // 2 + 1) // 2
-    """
-    # Downsample the time masks as in Zipformer
-    time_mask = time_mask.to(dtype).unsqueeze(dim=1)
-    # as in conv-embed
-    time_mask = nn.functional.max_pool1d(
-        time_mask, kernel_size=3, stride=1, padding=0
-    )  # T - 2
-    time_mask = nn.functional.max_pool1d(
-        time_mask, kernel_size=3, stride=2, padding=0
-    )  # (T - 3) // 2
-    time_mask = nn.functional.max_pool1d(
-        time_mask, kernel_size=3, stride=1, padding=0
-    )  # (T - 7) // 2
-    # as in output-downsampling
-    time_mask = nn.functional.max_pool1d(
-        time_mask, kernel_size=2, stride=2, padding=0, ceil_mode=True
-    )
-    time_mask = time_mask.transpose(1, 2)  # (N * 2, T', 1)
-    return time_mask
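
Note: the deleted downsample_time_mask helper mirrored Zipformer's frame-rate reduction so the SpecAugment time mask could be aligned with encoder frames; with the time_mask re-weighting gone, it has no remaining callers. A quick runnable check of the shape arithmetic it documented, using T = 100:

    import torch
    import torch.nn.functional as F

    t = torch.zeros(1, 1, 100)  # (N, C, T) with T = 100
    t = F.max_pool1d(t, kernel_size=3, stride=1)                  # 98 = T - 2
    t = F.max_pool1d(t, kernel_size=3, stride=2)                  # 48 = (T - 3) // 2
    t = F.max_pool1d(t, kernel_size=3, stride=1)                  # 46 = (T - 7) // 2
    t = F.max_pool1d(t, kernel_size=2, stride=2, ceil_mode=True)  # 23 = ((T - 7) // 2 + 1) // 2
    print(t.shape[-1])  # 23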