Skip to content

Commit

Permalink
add smooth-regularized CTC
Browse files Browse the repository at this point in the history
  • Loading branch information
yaozengwei committed Oct 10, 2024
1 parent ae59e5d commit a85592d
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 68 deletions.
143 changes: 77 additions & 66 deletions egs/librispeech/ASR/zipformer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple
from typing import List, Optional, Tuple

import k2
import torch
import torch.nn as nn
import torch.nn.functional as F
from encoder_interface import EncoderInterface
from scaling import ScaledLinear

Expand Down Expand Up @@ -111,9 +112,8 @@ def __init__(
if use_ctc:
# Modules for CTC head
self.ctc_output = nn.Sequential(
nn.Dropout(p=0.1),
nn.Dropout(p=0.1), # TODO: test removing this
nn.Linear(encoder_dim, vocab_size),
nn.LogSoftmax(dim=-1),
)

self.use_attention_decoder = use_attention_decoder
Expand Down Expand Up @@ -159,71 +159,82 @@ def forward_ctc(
encoder_out_lens: torch.Tensor,
targets: torch.Tensor,
target_lengths: torch.Tensor,
) -> torch.Tensor:
"""Compute CTC loss.
Args:
encoder_out:
Encoder output, of shape (N, T, C).
encoder_out_lens:
Encoder output lengths, of shape (N,).
targets:
Target Tensor of shape (sum(target_lengths)). The targets are assumed
to be un-padded and concatenated within 1 dimension.
"""
# Compute CTC log-prob
ctc_output = self.ctc_output(encoder_out) # (N, T, C)

ctc_loss = torch.nn.functional.ctc_loss(
log_probs=ctc_output.permute(1, 0, 2), # (T, N, C)
targets=targets.cpu(),
input_lengths=encoder_out_lens.cpu(),
target_lengths=target_lengths.cpu(),
reduction="sum",
)
return ctc_loss

def forward_cr_ctc(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
targets: torch.Tensor,
target_lengths: torch.Tensor,
use_consistency_reg: bool = False,
use_smooth_reg: bool = False,
smooth_kernel: List[float] = [0.25, 0.5, 0.25],
eps: float = 1e-6,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute CTC loss with consistency regularization loss.
Args:
encoder_out:
Encoder output, of shape (2 * N, T, C).
Encoder output, of shape (N or 2 * N, T, C).
encoder_out_lens:
Encoder output lengths, of shape (2 * N,).
Encoder output lengths, of shape (N or 2 * N,).
targets:
Target Tensor of shape (2 * sum(target_lengths)). The targets are assumed
to be un-padded and concatenated within 1 dimension.
use_consistency_reg:
Whether use consistency regularization.
use_smooth_reg:
Whether use smooth regularization.
"""
ctc_output = self.ctc_output(encoder_out) # (N or 2 * N, T, C)
length_mask = make_pad_mask(encoder_out_lens).unsqueeze(-1)

if not use_smooth_reg:
ctc_log_probs = F.log_softmax(ctc_output, dim=-1)
else:
ctc_probs = ctc_output.softmax(dim=-1) # Used in sr_loss
ctc_log_probs = (ctc_probs + eps).log()

# Compute CTC loss
ctc_output = self.ctc_output(encoder_out) # (2 * N, T, C)
ctc_loss = torch.nn.functional.ctc_loss(
log_probs=ctc_output.permute(1, 0, 2), # (T, 2 * N, C)
log_probs=ctc_log_probs.permute(1, 0, 2), # (T, N or 2 * N, C)
targets=targets.cpu(),
input_lengths=encoder_out_lens.cpu(),
target_lengths=target_lengths.cpu(),
reduction="sum",
)

# Compute consistency regularization loss
exchanged_targets = ctc_output.detach().chunk(2, dim=0)
exchanged_targets = torch.cat(
[exchanged_targets[1], exchanged_targets[0]], dim=0
) # exchange: [x1, x2] -> [x2, x1]
cr_loss = nn.functional.kl_div(
input=ctc_output,
target=exchanged_targets,
reduction="none",
log_target=True,
) # (2 * N, T, C)
length_mask = make_pad_mask(encoder_out_lens).unsqueeze(-1)
cr_loss = cr_loss.masked_fill(length_mask, 0.0).sum()
if use_consistency_reg:
assert ctc_log_probs.shape[0] % 2 == 0
# Compute cr_loss
exchanged_targets = ctc_log_probs.detach().chunk(2, dim=0)
exchanged_targets = torch.cat(
[exchanged_targets[1], exchanged_targets[0]], dim=0
) # exchange: [x1, x2] -> [x2, x1]
cr_loss = nn.functional.kl_div(
input=ctc_log_probs,
target=exchanged_targets,
reduction="none",
log_target=True,
) # (2 * N, T, C)
cr_loss = cr_loss.masked_fill(length_mask, 0.0).sum()
else:
cr_loss = torch.empty(0)

return ctc_loss, cr_loss
if use_smooth_reg:
# Hard code the kernel here, could try other values
assert len(smooth_kernel) == 3 and sum(smooth_kernel) == 1.0, smooth_kernel
smooth_kernel = torch.tensor(smooth_kernel, dtype=ctc_probs.dtype,
device=ctc_probs.device, requires_grad=False)
smooth_kernel = smooth_kernel.unsqueeze(0).unsqueeze(1).expand(ctc_probs.shape[-1], 1, 3)
# Now kernel: (C, 1, 3)
smoothed_ctc_probs = F.conv1d(
ctc_probs.detach().permute(0, 2, 1), # (N or 2 * N, C, T)
weight=smooth_kernel, stride=1, padding=0, groups=ctc_probs.shape[-1]
).permute(0, 2, 1) # (N or 2 * N, T - 2, C)
sr_loss = nn.functional.kl_div(
input=ctc_log_probs[:, 1:-1],
target=(smoothed_ctc_probs + eps).log(),
reduction="none",
log_target=True,
) # (N, T - 1 , C)
sr_loss = sr_loss.masked_fill(length_mask[:, 1:-1], 0.0).sum()
else:
sr_loss = torch.empty(0)

return ctc_loss, cr_loss, sr_loss

def forward_transducer(
self,
Expand Down Expand Up @@ -341,6 +352,7 @@ def forward(
am_scale: float = 0.0,
lm_scale: float = 0.0,
use_cr_ctc: bool = False,
use_sr_ctc: bool = False,
use_spec_aug: bool = False,
spec_augment: Optional[SpecAugment] = None,
supervision_segments: Optional[torch.Tensor] = None,
Expand All @@ -367,6 +379,8 @@ def forward(
part
use_cr_ctc:
Whether use consistency-regularized CTC.
use_sr_ctc:
Whether use smooth-regularized CTC.
use_spec_aug:
Whether apply spec-augment manually, used only if use_cr_ctc is True.
spec_augment:
Expand Down Expand Up @@ -445,26 +459,23 @@ def forward(
if self.use_ctc:
# Compute CTC loss
targets = y.values
if not use_cr_ctc:
ctc_loss = self.forward_ctc(
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
targets=targets,
target_lengths=y_lens,
)
cr_loss = torch.empty(0)
else:
ctc_loss, cr_loss = self.forward_cr_ctc(
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
targets=targets,
target_lengths=y_lens,
)
ctc_loss, cr_loss, sr_loss = self.forward_ctc(
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
targets=targets,
target_lengths=y_lens,
use_consistency_reg=use_cr_ctc,
use_smooth_reg=use_sr_ctc,
)
if use_cr_ctc:
# We duplicate the batch when use_cr_ctc is True
ctc_loss = ctc_loss * 0.5
cr_loss = cr_loss * 0.5
sr_loss = sr_loss * 0.5
else:
ctc_loss = torch.empty(0)
cr_loss = torch.empty(0)
sr_loss = torch.empty(0)

if self.use_attention_decoder:
attention_decoder_loss = self.attention_decoder.calc_att_loss(
Expand All @@ -478,4 +489,4 @@ def forward(
else:
attention_decoder_loss = torch.empty(0)

return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss
return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss, sr_loss
24 changes: 22 additions & 2 deletions egs/librispeech/ASR/zipformer/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,13 @@ def add_model_arguments(parser: argparse.ArgumentParser):
help="If True, use consistency-regularized CTC.",
)

parser.add_argument(
"--use-sr-ctc",
type=str2bool,
default=False,
help="If True, use smooth-regularized CTC.",
)


def get_parser():
parser = argparse.ArgumentParser(
Expand Down Expand Up @@ -464,6 +471,13 @@ def get_parser():
help="Scale for consistency-regularization loss.",
)

parser.add_argument(
"--sr-loss-scale",
type=float,
default=0.2,
help="Scale for smooth-regularization loss.",
)

parser.add_argument(
"--time-mask-ratio",
type=float,
Expand Down Expand Up @@ -916,6 +930,7 @@ def compute_loss(
y = k2.RaggedTensor(y)

use_cr_ctc = params.use_cr_ctc
use_sr_ctc = params.use_sr_ctc
use_spec_aug = use_cr_ctc and is_training
if use_spec_aug:
supervision_intervals = batch["supervisions"]
Expand All @@ -931,14 +946,15 @@ def compute_loss(
supervision_segments = None

with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss = model(
simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss, sr_loss = model(
x=feature,
x_lens=feature_lens,
y=y,
prune_range=params.prune_range,
am_scale=params.am_scale,
lm_scale=params.lm_scale,
use_cr_ctc=use_cr_ctc,
use_sr_ctc=use_sr_ctc,
use_spec_aug=use_spec_aug,
spec_augment=spec_augment,
supervision_segments=supervision_segments,
Expand Down Expand Up @@ -967,6 +983,8 @@ def compute_loss(
loss += params.ctc_loss_scale * ctc_loss
if use_cr_ctc:
loss += params.cr_loss_scale * cr_loss
if use_sr_ctc:
loss += params.sr_loss_scale * sr_loss

if params.use_attention_decoder:
loss += params.attention_decoder_loss_scale * attention_decoder_loss
Expand All @@ -985,8 +1003,10 @@ def compute_loss(
info["pruned_loss"] = pruned_loss.detach().cpu().item()
if params.use_ctc:
info["ctc_loss"] = ctc_loss.detach().cpu().item()
if params.use_cr_ctc:
if use_cr_ctc:
info["cr_loss"] = cr_loss.detach().cpu().item()
if use_sr_ctc:
info["sr_loss"] = sr_loss.detach().cpu().item()
if params.use_attention_decoder:
info["attn_decoder_loss"] = attention_decoder_loss.detach().cpu().item()

Expand Down

0 comments on commit a85592d

Please sign in to comment.