From 6072407669893663a9a9e7ae01a645a070185af3 Mon Sep 17 00:00:00 2001 From: Daniil Date: Thu, 17 Oct 2024 00:09:54 +0000 Subject: [PATCH] modified zipformer --- egs/librispeech/ASR/zipformer/decoder.py | 225 +- egs/librispeech/ASR/zipformer/joiner.py | 73 +- egs/librispeech/ASR/zipformer/scaling.py | 1740 +-------- egs/librispeech/ASR/zipformer/subsampling.py | 512 +-- egs/librispeech/ASR/zipformer/zipformer.py | 3676 ++++++++---------- 5 files changed, 2138 insertions(+), 4088 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/decoder.py b/egs/librispeech/ASR/zipformer/decoder.py index 7ce44495bf..6f0754c103 100644 --- a/egs/librispeech/ASR/zipformer/decoder.py +++ b/egs/librispeech/ASR/zipformer/decoder.py @@ -20,115 +20,162 @@ from scaling import Balancer -class Decoder(nn.Module): - """This class modifies the stateless decoder from the following paper: - - RNN-transducer with stateless prediction network - https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419 - - It removes the recurrent connection from the decoder, i.e., the prediction - network. Different from the above paper, it adds an extra Conv1d - right after the embedding layer. +class Decoder(torch.nn.Module): + """ + This class modifies the stateless decoder from the following paper: + RNN-transducer with stateless prediction network + https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419 - TODO: Implement https://arxiv.org/pdf/2109.07513.pdf + It removes the recurrent connection from the decoder, i.e., the prediction network. + Different from the above paper, it adds an extra Conv1d right after the embedding layer. """ def __init__( - self, - vocab_size: int, - decoder_dim: int, - blank_id: int, - context_size: int, - ): + self, vocab_size: int, decoder_dim: int, context_size: int, device: torch.device, + ) -> None: """ - Args: - vocab_size: - Number of tokens of the modeling unit including blank. - decoder_dim: - Dimension of the input embedding, and of the decoder output. - blank_id: - The ID of the blank symbol. - context_size: - Number of previous words to use to predict the next word. + Decoder initialization. + + Parameters + ---------- + vocab_size : int + A number of tokens or modeling units, includes blank. + decoder_dim : int + A dimension of the decoder embeddings, and the decoder output. + context_size : int + A number of previous words to use to predict the next word. 1 means bigram; 2 means trigram. n means (n+1)-gram. + device : torch.device + The device used to store the layer weights. Should be + either torch.device("cpu") or torch.device("cuda"). """ - super().__init__() - self.embedding = nn.Embedding( - num_embeddings=vocab_size, - embedding_dim=decoder_dim, - ) - # the balancers are to avoid any drift in the magnitude of the - # embeddings, which would interact badly with parameter averaging. 
- self.balancer = Balancer( - decoder_dim, - channel_dim=-1, - min_positive=0.0, - max_positive=1.0, - min_abs=0.5, - max_abs=1.0, - prob=0.05, - ) + super().__init__() - self.blank_id = blank_id + self.embedding = torch.nn.Embedding(vocab_size, decoder_dim) - assert context_size >= 1, context_size + if context_size < 1: + raise ValueError( + 'RNN-T decoder context size should be an integer greater ' + f'or equal than 1, but got {context_size}.', + ) self.context_size = context_size - self.vocab_size = vocab_size - if context_size > 1: - self.conv = nn.Conv1d( - in_channels=decoder_dim, - out_channels=decoder_dim, - kernel_size=context_size, - padding=0, - groups=decoder_dim // 4, # group size == 4 - bias=False, - ) - self.balancer2 = Balancer( - decoder_dim, - channel_dim=-1, - min_positive=0.0, - max_positive=1.0, - min_abs=0.5, - max_abs=1.0, - prob=0.05, - ) - else: - # To avoid `RuntimeError: Module 'Decoder' has no attribute 'conv'` - # when inference with torch.jit.script and context_size == 1 - self.conv = nn.Identity() - self.balancer2 = nn.Identity() + self.conv = torch.nn.Conv1d( + decoder_dim, + decoder_dim, + context_size, + groups=decoder_dim // 4, + bias=False, + device=device, + ) - def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: + def forward(self, y: torch.Tensor) -> torch.Tensor: """ - Args: - y: - A 2-D tensor of shape (N, U). - need_pad: - True to left pad the input. Should be True during training. - False to not pad the input. Should be False during inference. - Returns: - Return a tensor of shape (N, U, decoder_dim). + Does a forward pass of the stateless Decoder module. Returns an output decoder tensor. + + Parameters + ---------- + y : torch.Tensor[torch.int32] + The input integer tensor of shape (N, context_size). + The module input that corresponds to the last context_size decoded token indexes. + + Returns + ------- + torch.Tensor[torch.float32] + An output float tensor of shape (N, 1, decoder_dim). """ - y = y.to(torch.int64) - # this stuff about clamp() is a temporary fix for a mismatch - # at utterance start, we use negative ids in beam_search.py - embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1) - embedding_out = self.balancer(embedding_out) + # this stuff about clamp() is a fix for a mismatch at utterance start, + # we use negative ids in RNN-T decoding. + embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(2) if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) - if need_pad is True: - embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) - else: - # During inference time, there is no need to do extra padding - # as we only need one output - assert embedding_out.size(-1) == self.context_size embedding_out = self.conv(embedding_out) embedding_out = embedding_out.permute(0, 2, 1) - embedding_out = F.relu(embedding_out) - embedding_out = self.balancer2(embedding_out) + embedding_out = torch.nn.functional.relu(embedding_out) return embedding_out + + +class DecoderModule(torch.nn.Module): + """ + A helper module to combine decoder, decoder projection, and joiner inference together. + """ + + def __init__( + self, + vocab_size: int, + decoder_dim: int, + joiner_dim: int, + context_size: int, + beam: int, + device: torch.device, + ) -> None: + """ + DecoderModule initialization. + + Parameters + ---------- + vocab_size: + A number of tokens or modeling units, includes blank. + decoder_dim : int + A dimension of the decoder embeddings, and the decoder output. 
+ joiner_dim : int + Input joiner dimension. + context_size : int + A number of previous words to use to predict the next word. + 1 means bigram; 2 means trigram. n means (n+1)-gram. + beam : int + A decoder beam. + device : torch.device + The device used to store the layer weights. Should be + either torch.device("cpu") or torch.device("cuda"). + """ + + super().__init__() + + self.decoder = Decoder(vocab_size, decoder_dim, context_size, device) + self.decoder_proj = torch.nn.Linear(decoder_dim, joiner_dim, device=device) + self.joiner = Joiner(joiner_dim, vocab_size, device) + + self.vocab_size = vocab_size + self.beam = beam + + def forward( + self, decoder_input: torch.Tensor, encoder_out: torch.Tensor, hyps_log_prob: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Does a forward pass of the stateless Decoder module. Returns an output decoder tensor. + + Parameters + ---------- + decoder_input : torch.Tensor[torch.int32] + The input integer tensor of shape (num_hyps, context_size). + The module input that corresponds to the last context_size decoded token indexes. + encoder_out : torch.Tensor[torch.float32] + An output tensor from the encoder after projection of shape (num_hyps, joiner_dim). + hyps_log_prob : torch.Tensor[torch.float32] + Hypothesis probabilities in a logarithmic scale of shape (num_hyps, 1). + + Returns + ------- + torch.Tensor[torch.float32] + A float output tensor of logit token probabilities of shape (num_hyps, vocab_size). + """ + + decoder_out = self.decoder(decoder_input) + decoder_out = self.decoder_proj(decoder_out) + + logits = self.joiner(encoder_out, decoder_out[:, 0, :]) + + tokens_log_prob = torch.log_softmax(logits, dim=1) + log_probs = (tokens_log_prob + hyps_log_prob).reshape(-1) + + hyps_topk_log_prob, topk_indexes = log_probs.topk(self.beam) + topk_hyp_indexes = torch.floor_divide(topk_indexes, self.vocab_size).to(torch.int32) + topk_token_indexes = torch.remainder(topk_indexes, self.vocab_size).to(torch.int32) + tokens_topk_prob = torch.exp(tokens_log_prob.reshape(-1)[topk_indexes]) + + return hyps_topk_log_prob, tokens_topk_prob, topk_hyp_indexes, topk_token_indexes \ No newline at end of file diff --git a/egs/librispeech/ASR/zipformer/joiner.py b/egs/librispeech/ASR/zipformer/joiner.py index 0406efe834..23aa1ae531 100644 --- a/egs/librispeech/ASR/zipformer/joiner.py +++ b/egs/librispeech/ASR/zipformer/joiner.py @@ -19,49 +19,42 @@ from scaling import ScaledLinear -class Joiner(nn.Module): - def __init__( - self, - encoder_dim: int, - decoder_dim: int, - joiner_dim: int, - vocab_size: int, - ): - super().__init__() - - self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim, initial_scale=0.25) - self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim, initial_scale=0.25) - self.output_linear = nn.Linear(joiner_dim, vocab_size) +class Joiner(torch.nn.Module): - def forward( - self, - encoder_out: torch.Tensor, - decoder_out: torch.Tensor, - project_input: bool = True, - ) -> torch.Tensor: + def __init__(self, joiner_dim: int, vocab_size: int, device: torch.device) -> None: """ - Args: - encoder_out: - Output from the encoder. Its shape is (N, T, s_range, C). - decoder_out: - Output from the decoder. Its shape is (N, T, s_range, C). - project_input: - If true, apply input projections encoder_proj and decoder_proj. - If this is false, it is the user's responsibility to do this - manually. - Returns: - Return a tensor of shape (N, T, s_range, C). + Joiner initialization. 
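The flattened top-k in `DecoderModule.forward` above scores every (hypothesis, token) pair at once and then recovers both indexes arithmetically. A minimal sketch of that bookkeeping, with made-up sizes (`num_hyps=4`, `vocab_size=10`, `beam=4`) used purely for illustration:

```python
import torch

num_hyps, vocab_size, beam = 4, 10, 4

tokens_log_prob = torch.randn(num_hyps, vocab_size).log_softmax(dim=1)  # per-hypothesis token scores
hyps_log_prob = torch.randn(num_hyps, 1)                                # accumulated hypothesis scores

# Joint score of "hypothesis i extended with token j", flattened so that a single
# topk() call prunes over all (hypothesis, token) pairs at once.
log_probs = (tokens_log_prob + hyps_log_prob).reshape(-1)
hyps_topk_log_prob, topk_indexes = log_probs.topk(beam)

# The flattening was row-major over (num_hyps, vocab_size), so integer division
# recovers the hypothesis index and the remainder recovers the token index.
topk_hyp_indexes = torch.floor_divide(topk_indexes, vocab_size)
topk_token_indexes = torch.remainder(topk_indexes, vocab_size)

assert torch.equal(
    log_probs[topk_indexes],
    tokens_log_prob[topk_hyp_indexes, topk_token_indexes]
    + hyps_log_prob[topk_hyp_indexes, 0],
)
```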
+ + Parameters + ---------- + joiner_dim : int + Input joiner dimension. + vocab_size : int + Output joiner dimension, the vocabulary size, the number of BPEs of the model. + device : torch.device + The device used to store the layer weights. Should be + either torch.device("cpu") or torch.device("cuda"). """ - assert encoder_out.ndim == decoder_out.ndim, ( - encoder_out.shape, - decoder_out.shape, - ) - if project_input: - logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) - else: - logit = encoder_out + decoder_out + super().__init__() + + self.output_linear = torch.nn.Linear(joiner_dim, vocab_size, device=device) - logit = self.output_linear(torch.tanh(logit)) + def forward(self, encoder_out: torch.Tensor, decoder_out: torch.Tensor) -> torch.Tensor: + """ + Does a forward pass of the Joiner module. Returns an output tensor after a simple joining. + + Parameters + ---------- + encoder_out : torch.Tensor[torch.float32] + An output tensor from the encoder after projection of shape (N, joiner_dim). + decoder_out : torch.Tensor[torch.float32] + An output tensor from the decoder after projection of shape (N, joiner_dim). + + Returns + ------- + torch.Tensor[torch.float32] + A float output tensor of log token probabilities of shape (N, vocab_size). + """ - return logit + return self.output_linear(torch.tanh(encoder_out + decoder_out)) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index d345c29316..c1ef1e3e60 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -27,1641 +27,219 @@ from torch.cuda.amp import custom_bwd, custom_fwd -def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor: - max_value = torch.max(x, y) - diff = torch.abs(x - y) - return max_value + torch.log1p(torch.exp(-diff)) - - -# RuntimeError: Exporting the operator logaddexp to ONNX opset version -# 14 is not supported. Please feel free to request support or submit -# a pull request on PyTorch GitHub. -# -# The following function is to solve the above error when exporting -# models to ONNX via torch.jit.trace() -def logaddexp(x: Tensor, y: Tensor) -> Tensor: - # Caution(fangjun): Put torch.jit.is_scripting() before - # torch.onnx.is_in_onnx_export(); - # otherwise, it will cause errors for torch.jit.script(). - # - # torch.logaddexp() works for both torch.jit.script() and - # torch.jit.trace() but it causes errors for ONNX export. - # - if torch.jit.is_scripting(): - # Note: We cannot use torch.jit.is_tracing() here as it also - # matches torch.onnx.export(). - return torch.logaddexp(x, y) - elif torch.onnx.is_in_onnx_export(): - return logaddexp_onnx(x, y) - else: - # for torch.jit.trace() - return torch.logaddexp(x, y) - - -class PiecewiseLinear(object): +class BiasNorm(torch.nn.Module): """ - Piecewise linear function, from float to float, specified as nonempty list of (x,y) pairs with - the x values in order. x values <[initial x] or >[final x] are map to [initial y], [final y] - respectively. + This is intended to be a simpler replacement for LayerNorm. The observation this is based on, + is that Transformer-type networks, especially with pre-normalization, sometimes seem to set one + of the channel dimensions to a large constant value (e.g. 50), which "defeats" the LayerNorm + because the output magnitude is then not strongly dependent on the other (useful) channels. + Presumably the weight and bias of the LayerNorm are required to allow it to do this. 
Instead, + we give the BiasNorm a trainable bias that it can use when computing the scale for + normalization. We also give it a scalar trainable scale on the output. """ - def __init__(self, *args): - assert len(args) >= 1, len(args) - if len(args) == 1 and isinstance(args[0], PiecewiseLinear): - self.pairs = list(args[0].pairs) - else: - self.pairs = [(float(x), float(y)) for x, y in args] - for x, y in self.pairs: - assert isinstance(x, (float, int)), type(x) - assert isinstance(y, (float, int)), type(y) - - for i in range(len(self.pairs) - 1): - assert self.pairs[i + 1][0] > self.pairs[i][0], ( - i, - self.pairs[i], - self.pairs[i + 1], - ) - - def __str__(self): - # e.g. 'PiecewiseLinear((0., 10.), (100., 0.))' - return f"PiecewiseLinear({str(self.pairs)[1:-1]})" - - def __call__(self, x): - if x <= self.pairs[0][0]: - return self.pairs[0][1] - elif x >= self.pairs[-1][0]: - return self.pairs[-1][1] - else: - cur_x, cur_y = self.pairs[0] - for i in range(1, len(self.pairs)): - next_x, next_y = self.pairs[i] - if x >= cur_x and x <= next_x: - return cur_y + (next_y - cur_y) * (x - cur_x) / (next_x - cur_x) - cur_x, cur_y = next_x, next_y - assert False - - def __mul__(self, alpha): - return PiecewiseLinear(*[(x, y * alpha) for x, y in self.pairs]) - - def __add__(self, x): - if isinstance(x, (float, int)): - return PiecewiseLinear(*[(p[0], p[1] + x) for p in self.pairs]) - s, x = self.get_common_basis(x) - return PiecewiseLinear( - *[(sp[0], sp[1] + xp[1]) for sp, xp in zip(s.pairs, x.pairs)] - ) - - def max(self, x): - if isinstance(x, (float, int)): - x = PiecewiseLinear((0, x)) - s, x = self.get_common_basis(x, include_crossings=True) - return PiecewiseLinear( - *[(sp[0], max(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)] - ) - - def min(self, x): - if isinstance(x, float) or isinstance(x, int): - x = PiecewiseLinear((0, x)) - s, x = self.get_common_basis(x, include_crossings=True) - return PiecewiseLinear( - *[(sp[0], min(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)] - ) - - def __eq__(self, other): - return self.pairs == other.pairs - - def get_common_basis(self, p: "PiecewiseLinear", include_crossings: bool = False): + def __init__(self, num_channels: int, device: torch.device) -> None: """ - Returns (self_mod, p_mod) which are equivalent piecewise linear - functions to self and p, but with the same x values. - - p: the other piecewise linear function - include_crossings: if true, include in the x values positions - where the functions indicate by this and p cross. + BiasNorm initialization. + + Parameters + ---------- + num_channels : int + The number of input channels. + device : torch.device + The device used to store the layer weights. + Either torch.device("cpu") or torch.device("cuda"). """ - assert isinstance(p, PiecewiseLinear), type(p) - - # get sorted x-values without repetition. - x_vals = sorted(set([x for x, _ in self.pairs] + [x for x, _ in p.pairs])) - y_vals1 = [self(x) for x in x_vals] - y_vals2 = [p(x) for x in x_vals] - - if include_crossings: - extra_x_vals = [] - for i in range(len(x_vals) - 1): - if (y_vals1[i] > y_vals2[i]) != (y_vals1[i + 1] > y_vals2[i + 1]): - # if the two lines in this subsegment potentially cross each other.. - diff_cur = abs(y_vals1[i] - y_vals2[i]) - diff_next = abs(y_vals1[i + 1] - y_vals2[i + 1]) - # `pos`, between 0 and 1, gives the relative x position, - # with 0 being x_vals[i] and 1 being x_vals[i+1]. 
- pos = diff_cur / (diff_cur + diff_next) - extra_x_val = x_vals[i] + pos * (x_vals[i + 1] - x_vals[i]) - extra_x_vals.append(extra_x_val) - if len(extra_x_vals) > 0: - x_vals = sorted(set(x_vals + extra_x_vals)) - y_vals1 = [self(x) for x in x_vals] - y_vals2 = [p(x) for x in x_vals] - return ( - PiecewiseLinear(*zip(x_vals, y_vals1)), - PiecewiseLinear(*zip(x_vals, y_vals2)), - ) - - -class ScheduledFloat(torch.nn.Module): - """ - This object is a torch.nn.Module only because we want it to show up in [top_level module].modules(); - it does not have a working forward() function. You are supposed to cast it to float, as - in, float(parent_module.whatever), and use it as something like a dropout prob. - It is a floating point value whose value changes depending on the batch count of the - training loop. It is a piecewise linear function where you specify the (x,y) pairs - in sorted order on x; x corresponds to the batch index. For batch-index values before the - first x or after the last x, we just use the first or last y value. - - Example: - self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0) - - `default` is used when self.batch_count is not set or not in training mode or in - torch.jit scripting mode. - """ - - def __init__(self, *args, default: float = 0.0): super().__init__() - # self.batch_count and self.name will be written to in the training loop. - self.batch_count = None - self.name = None - self.default = default - self.schedule = PiecewiseLinear(*args) - - def extra_repr(self) -> str: - return ( - f"batch_count={self.batch_count}, schedule={str(self.schedule.pairs[1:-1])}" - ) - def __float__(self): - batch_count = self.batch_count - if ( - batch_count is None - or not self.training - or torch.jit.is_scripting() - or torch.jit.is_tracing() - ): - return float(self.default) - else: - ans = self.schedule(self.batch_count) - if random.random() < 0.0002: - logging.info( - f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}" - ) - return ans - - def __add__(self, x): - if isinstance(x, float) or isinstance(x, int): - return ScheduledFloat(self.schedule + x, default=self.default) - else: - return ScheduledFloat( - self.schedule + x.schedule, default=self.default + x.default - ) - - def max(self, x): - if isinstance(x, float) or isinstance(x, int): - return ScheduledFloat(self.schedule.max(x), default=self.default) - else: - return ScheduledFloat( - self.schedule.max(x.schedule), default=max(self.default, x.default) - ) - - -FloatLike = Union[float, ScheduledFloat] - - -def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor: - """ - A randomized way of casting a floating point value to half precision. - """ - if x.dtype == torch.float16: - return x - x_abs = x.abs() - is_too_small = x_abs < min_abs - # for elements where is_too_small is true, random_val will contain +-min_abs with - # probability (x.abs() / min_abs), and 0.0 otherwise. [so this preserves expectations, - # for those elements]. - random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs) - return torch.where(is_too_small, random_val, x).to(torch.float16) - - -class CutoffEstimator: - """ - Estimates cutoffs of an arbitrary numerical quantity such that a specified - proportion of items will be above the cutoff on average. - - p is the proportion of items that should be above the cutoff. 
- """ - - def __init__(self, p: float): - self.p = p - # total count of items - self.count = 0 - # total count of items that were above the cutoff - self.count_above = 0 - # initial cutoff value - self.cutoff = 0 + self.scale = torch.nn.Parameter(torch.tensor(0.0, dtype=torch.float32, device=device)) + self.bias = torch.nn.Parameter( + torch.zeros(num_channels, dtype=torch.float32, device=device), + ) - def __call__(self, x: float) -> bool: + def forward(self, x: torch.Tensor) -> torch.Tensor: """ - Returns true if x is above the cutoff. + Does a forward pass of the BiasNorm module. + + Parameters + ---------- + x : torch.Tensor[torch.float32] + A float tensor of shape (1, seq_len, num_channels). The module input. + + Returns + ------- + torch.Tensor[torch.float32] + A float tensor of shape (1, seq_len, num_channels). + A normalized output tensor of the same shape as input. """ - ans = x > self.cutoff - self.count += 1 - if ans: - self.count_above += 1 - cur_p = self.count_above / self.count - delta_p = cur_p - self.p - if (delta_p > 0) == ans: - q = abs(delta_p) - self.cutoff = x * q + self.cutoff * (1 - q) - return ans - - -class SoftmaxFunction(torch.autograd.Function): - """ - Tries to handle half-precision derivatives in a randomized way that should - be more accurate for training than the default behavior. - """ - - @staticmethod - def forward(ctx, x: Tensor, dim: int): - ans = x.softmax(dim=dim) - # if x dtype is float16, x.softmax() returns a float32 because - # (presumably) that op does not support float16, and autocast - # is enabled. - if torch.is_autocast_enabled(): - ans = ans.to(torch.get_autocast_gpu_dtype()) - ctx.save_for_backward(ans) - ctx.x_dtype = x.dtype - ctx.dim = dim - return ans - - @staticmethod - def backward(ctx, ans_grad: Tensor): - (ans,) = ctx.saved_tensors - with torch.cuda.amp.autocast(enabled=False): - ans_grad = ans_grad.to(torch.float32) - ans = ans.to(torch.float32) - x_grad = ans_grad * ans - x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True) - return x_grad, None - - -def softmax(x: Tensor, dim: int): - if not x.requires_grad or torch.jit.is_scripting() or torch.jit.is_tracing(): - return x.softmax(dim=dim) - - return SoftmaxFunction.apply(x, dim) - - -class MaxEigLimiterFunction(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x: Tensor, - coeffs: Tensor, - direction: Tensor, - channel_dim: int, - grad_scale: float, - ) -> Tensor: - ctx.channel_dim = channel_dim - ctx.grad_scale = grad_scale - ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach()) - return x - - @staticmethod - def backward(ctx, x_grad, *args): - with torch.enable_grad(): - (x_orig, coeffs, new_direction) = ctx.saved_tensors - x_orig.requires_grad = True - num_channels = x_orig.shape[ctx.channel_dim] - x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels) - new_direction.requires_grad = False - x = x - x.mean(dim=0) - x_var = (x**2).mean() - x_residual = x - coeffs * new_direction - x_residual_var = (x_residual**2).mean() - # `variance_proportion` is the proportion of the variance accounted for - # by the top eigen-direction. This is to be minimized. 
- variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20) - variance_proportion.backward() - x_orig_grad = x_orig.grad - x_extra_grad = ( - x_orig.grad - * ctx.grad_scale - * x_grad.norm() - / (x_orig_grad.norm() + 1.0e-20) - ) - return x_grad + x_extra_grad.detach(), None, None, None, None - - -class BiasNormFunction(torch.autograd.Function): - # This computes: - # scales = (torch.mean((x - bias) ** 2, keepdim=True)) ** -0.5 * log_scale.exp() - # return x * scales - # (after unsqueezing the bias), but it does it in a memory-efficient way so that - # it can just store the returned value (chances are, this will also be needed for - # some other reason, related to the next operation, so we can save memory). - @staticmethod - def forward( - ctx, - x: Tensor, - bias: Tensor, - log_scale: Tensor, - channel_dim: int, - store_output_for_backprop: bool, - ) -> Tensor: - assert bias.ndim == 1 - if channel_dim < 0: - channel_dim = channel_dim + x.ndim - ctx.store_output_for_backprop = store_output_for_backprop - ctx.channel_dim = channel_dim - for _ in range(channel_dim + 1, x.ndim): - bias = bias.unsqueeze(-1) - scales = ( - torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5 - ) * log_scale.exp() - ans = x * scales - ctx.save_for_backward( - ans.detach() if store_output_for_backprop else x, - scales.detach(), - bias.detach(), - log_scale.detach(), - ) - return ans - @staticmethod - def backward(ctx, ans_grad: Tensor) -> Tensor: - ans_or_x, scales, bias, log_scale = ctx.saved_tensors - if ctx.store_output_for_backprop: - x = ans_or_x / scales - else: - x = ans_or_x - x = x.detach() - x.requires_grad = True - bias.requires_grad = True - log_scale.requires_grad = True - with torch.enable_grad(): - # recompute scales from x, bias and log_scale. - scales = ( - torch.mean((x - bias) ** 2, dim=ctx.channel_dim, keepdim=True) ** -0.5 - ) * log_scale.exp() - ans = x * scales - ans.backward(gradient=ans_grad) - return x.grad, bias.grad.flatten(), log_scale.grad, None, None + return x * self.scale / torch.mean((x - self.bias)**2, dim=2, keepdim=True)**0.5 -class BiasNorm(torch.nn.Module): +class ChunkCausalDepthwiseConv1d(torch.nn.Module): """ - This is intended to be a simpler, and hopefully cheaper, replacement for - LayerNorm. The observation this is based on, is that Transformer-type - networks, especially with pre-norm, sometimes seem to set one of the - feature dimensions to a large constant value (e.g. 50), which "defeats" - the LayerNorm because the output magnitude is then not strongly dependent - on the other (useful) features. Presumably the weight and bias of the - LayerNorm are required to allow it to do this. - - Instead, we give the BiasNorm a trainable bias that it can use when - computing the scale for normalization. We also give it a (scalar) - trainable scale on the output. - - - Args: - num_channels: the number of channels, e.g. 512. - channel_dim: the axis/dimension corresponding to the channel, - interpreted as an offset from the input's ndim if negative. - This is NOT the num_channels; it should typically be one of - {-2, -1, 0, 1, 2, 3}. - log_scale: the initial log-scale that we multiply the output by; this - is learnable. 
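The one-line `BiasNorm.forward` above is the inference-time counterpart of the `BiasNormFunction` being removed. The sketch below checks that the two formulations agree when the new `scale` parameter holds `exp(log_scale)` of the original module; that correspondence is an assumption about how trained weights would be mapped (the 0.0 initialisation only makes sense if a checkpoint is loaded over it), and all tensor values are random illustration data:

```python
import torch

num_channels = 8
x = torch.randn(1, 5, num_channels)       # (batch, seq_len, num_channels)
bias = torch.randn(num_channels) * 1e-4   # per-channel bias used only inside the norm
log_scale = torch.tensor(0.3)

# Original training-time formulation (BiasNormFunction / BiasNorm removed by this patch):
scales = torch.mean((x - bias) ** 2, dim=2, keepdim=True) ** -0.5 * log_scale.exp()
y_old = x * scales

# Simplified formulation added by this patch, assuming scale == exp(log_scale):
scale = log_scale.exp()
y_new = x * scale / torch.mean((x - bias) ** 2, dim=2, keepdim=True) ** 0.5

assert torch.allclose(y_old, y_new)
```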
- log_scale_min: FloatLike, minimum allowed value of log_scale - log_scale_max: FloatLike, maximum allowed value of log_scale - store_output_for_backprop: only possibly affects memory use; recommend - to set to True if you think the output of this module is more likely - than the input of this module to be required to be stored for the - backprop. + Behaves like a depthwise 1D convolution, except that it is causal in a chunkwise way, as if we + had a block-triangular attention mask.The chunk size is provided at test time, it should be kept + in sync with the attention mask. + + This has a little more than twice the parameters of a conventional depthwise conv1d module: + we implement it by having one depthwise convolution, of half the width, that is causal (via + right padding), and one depthwise convolution that is applied only within chunks, that we + multiply by a scaling factor which depends on the position within the chunk. """ def __init__( - self, - num_channels: int, - channel_dim: int = -1, # CAUTION: see documentation. - log_scale: float = 1.0, - log_scale_min: float = -1.5, - log_scale_max: float = 1.5, - store_output_for_backprop: bool = False, + self, num_channels: int, kernel_size: int, right_context: int, device: torch.device, ) -> None: - super(BiasNorm, self).__init__() - self.num_channels = num_channels - self.channel_dim = channel_dim - self.log_scale = nn.Parameter(torch.tensor(log_scale)) - self.bias = nn.Parameter(torch.empty(num_channels).normal_(mean=0, std=1e-4)) - - self.log_scale_min = log_scale_min - self.log_scale_max = log_scale_max - - self.store_output_for_backprop = store_output_for_backprop - - def forward(self, x: Tensor) -> Tensor: - assert x.shape[self.channel_dim] == self.num_channels - - if torch.jit.is_scripting() or torch.jit.is_tracing(): - channel_dim = self.channel_dim - if channel_dim < 0: - channel_dim += x.ndim - bias = self.bias - for _ in range(channel_dim + 1, x.ndim): - bias = bias.unsqueeze(-1) - scales = ( - torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5 - ) * self.log_scale.exp() - return x * scales - - log_scale = limit_param_value( - self.log_scale, - min=float(self.log_scale_min), - max=float(self.log_scale_max), - training=self.training, - ) - - return BiasNormFunction.apply( - x, self.bias, log_scale, self.channel_dim, self.store_output_for_backprop - ) - - -def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear: - """ - Behaves like a constructor of a modified version of nn.Linear - that gives an easy way to set the default initial parameter scale. - - Args: - Accepts the standard args and kwargs that nn.Linear accepts - e.g. in_features, out_features, bias=False. - - initial_scale: you can override this if you want to increase - or decrease the initial magnitude of the module's output - (affects the initialization of weight_scale and bias_scale). - Another option, if you want to do something like this, is - to re-initialize the parameters. - """ - ans = nn.Linear(*args, **kwargs) - with torch.no_grad(): - ans.weight[:] *= initial_scale - if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) - return ans - - -def ScaledConv1d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv1d: - """ - Behaves like a constructor of a modified version of nn.Conv1d - that gives an easy way to set the default initial parameter scale. - - Args: - Accepts the standard args and kwargs that nn.Linear accepts - e.g. in_features, out_features, bias=False. 
- - initial_scale: you can override this if you want to increase - or decrease the initial magnitude of the module's output - (affects the initialization of weight_scale and bias_scale). - Another option, if you want to do something like this, is - to re-initialize the parameters. - """ - ans = nn.Conv1d(*args, **kwargs) - with torch.no_grad(): - ans.weight[:] *= initial_scale - if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) - return ans - + """ + ChunkCausalDepthwiseConv1d initialization. + + Parameters + ---------- + num_channels : int + The number of input channels. + kernel_size : int + The kernel size for chunkwise convolution. The causal convolution kernel size is + the half of this original value. Should be an odd number. + right_context : int + The module look ahead future context, used to update module left cache correctly. + device : torch.device + The device used to store the layer weights. + Either torch.device("cpu") or torch.device("cuda"). + """ -def ScaledConv2d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv2d: - """ - Behaves like a constructor of a modified version of nn.Conv2d - that gives an easy way to set the default initial parameter scale. - - Args: - Accepts the standard args and kwargs that nn.Linear accepts - e.g. in_features, out_features, bias=False, but: - NO PADDING-RELATED ARGS. - - initial_scale: you can override this if you want to increase - or decrease the initial magnitude of the module's output - (affects the initialization of weight_scale and bias_scale). - Another option, if you want to do something like this, is - to re-initialize the parameters. - """ - ans = nn.Conv2d(*args, **kwargs) - with torch.no_grad(): - ans.weight[:] *= initial_scale - if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) - return ans + super().__init__() + if kernel_size % 2 == 0: + raise ValueError( + 'Kernel size for ChunkCausalDepthwiseConv1d convolution ' + f'module should be an odd number, but got {kernel_size}.', + ) -class ChunkCausalDepthwiseConv1d(torch.nn.Module): - """ - Behaves like a depthwise 1d convolution, except that it is causal in - a chunkwise way, as if we had a block-triangular attention mask. - The chunk size is provided at test time (it should probably be - kept in sync with the attention mask). - - This has a little more than twice the parameters of a conventional - depthwise conv1d module: we implement it by having one - depthwise convolution, of half the width, that is causal (via - right-padding); and one depthwise convolution that is applied only - within chunks, that we multiply by a scaling factor which depends - on the position within the chunk. - - Args: - Accepts the standard args and kwargs that nn.Linear accepts - e.g. in_features, out_features, bias=False. - - initial_scale: you can override this if you want to increase - or decrease the initial magnitude of the module's output - (affects the initialization of weight_scale and bias_scale). - Another option, if you want to do something like this, is - to re-initialize the parameters. - """ + self.kernel_size = kernel_size + self.right_context = right_context - def __init__( - self, - channels: int, - kernel_size: int, - initial_scale: float = 1.0, - bias: bool = True, - ): - super().__init__() - assert kernel_size % 2 == 1 - - half_kernel_size = (kernel_size + 1) // 2 - # will pad manually, on one side. 
- self.causal_conv = nn.Conv1d( - in_channels=channels, - out_channels=channels, - groups=channels, - kernel_size=half_kernel_size, - padding=0, - bias=True, + self.causal_conv = torch.nn.Conv1d( + num_channels, num_channels, (kernel_size + 1) // 2, groups=num_channels, device=device, ) - self.chunkwise_conv = nn.Conv1d( - in_channels=channels, - out_channels=channels, - groups=channels, - kernel_size=kernel_size, + self.chunkwise_conv = torch.nn.Conv1d( + num_channels, + num_channels, + kernel_size, padding=kernel_size // 2, - bias=bias, + groups=num_channels, + device=device, ) - # first row is correction factors added to the scale near the left edge of the chunk, + # First row is correction factors added to the scale near the left edge of the chunk, # second row is correction factors added to the scale near the right edge of the chunk, # both of these are added to a default scale of 1.0. - self.chunkwise_conv_scale = nn.Parameter(torch.zeros(2, channels, kernel_size)) - self.kernel_size = kernel_size - - with torch.no_grad(): - self.causal_conv.weight[:] *= initial_scale - self.chunkwise_conv.weight[:] *= initial_scale - if bias: - torch.nn.init.uniform_( - self.causal_conv.bias, -0.1 * initial_scale, 0.1 * initial_scale - ) - - def forward(self, x: Tensor, chunk_size: int = -1) -> Tensor: - """Forward function. - - Args: - x: a Tensor of shape (batch_size, channels, seq_len) - chunk_size: the chunk size, in frames; does not have to divide seq_len exactly. - """ - (batch_size, num_channels, seq_len) = x.shape - - # half_kernel_size = self.kernel_size + 1 // 2 - # left_pad is half_kernel_size - 1 where half_kernel_size is the size used - # in the causal conv. It's the amount by which we must pad on the left, - # to make the convolution causal. - left_pad = self.kernel_size // 2 - - if chunk_size < 0 or chunk_size > seq_len: - chunk_size = seq_len - right_pad = -seq_len % chunk_size - - x = torch.nn.functional.pad(x, (left_pad, right_pad)) - - x_causal = self.causal_conv(x[..., : left_pad + seq_len]) - assert x_causal.shape == (batch_size, num_channels, seq_len) - - x_chunk = x[..., left_pad:] - num_chunks = x_chunk.shape[2] // chunk_size - x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks, chunk_size) - x_chunk = x_chunk.permute(0, 2, 1, 3).reshape( - batch_size * num_chunks, num_channels, chunk_size + self.chunkwise_conv_scale = torch.nn.Parameter( + torch.zeros(2, num_channels, kernel_size, dtype=torch.float32, device=device), ) - x_chunk = self.chunkwise_conv(x_chunk) # does not change shape - - chunk_scale = self._get_chunk_scale(chunk_size) - - x_chunk = x_chunk * chunk_scale - x_chunk = x_chunk.reshape( - batch_size, num_chunks, num_channels, chunk_size - ).permute(0, 2, 1, 3) - x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks * chunk_size)[ - ..., :seq_len - ] - - return x_chunk + x_causal - - def _get_chunk_scale(self, chunk_size: int): - """Returns tensor of shape (num_channels, chunk_size) that will be used to - scale the output of self.chunkwise_conv.""" - left_edge = self.chunkwise_conv_scale[0] - right_edge = self.chunkwise_conv_scale[1] - if chunk_size < self.kernel_size: - left_edge = left_edge[:, :chunk_size] - right_edge = right_edge[:, -chunk_size:] - else: - t = chunk_size - self.kernel_size - channels = left_edge.shape[0] - pad = torch.zeros( - channels, t, device=left_edge.device, dtype=left_edge.dtype - ) - left_edge = torch.cat((left_edge, pad), dim=-1) - right_edge = torch.cat((pad, right_edge), dim=-1) - return 1.0 + (left_edge + 
right_edge) - - def streaming_forward( - self, - x: Tensor, - cache: Tensor, - ) -> Tuple[Tensor, Tensor]: - """Streaming Forward function. - - Args: - x: a Tensor of shape (batch_size, channels, seq_len) - cache: cached left context of shape (batch_size, channels, left_pad) - """ - (batch_size, num_channels, seq_len) = x.shape - - # left_pad is half_kernel_size - 1 where half_kernel_size is the size used - # in the causal conv. It's the amount by which we must pad on the left, - # to make the convolution causal. - left_pad = self.kernel_size // 2 - - # Pad cache - assert cache.shape[-1] == left_pad, (cache.shape[-1], left_pad) - x = torch.cat([cache, x], dim=2) - # Update cache - cache = x[..., -left_pad:] - - x_causal = self.causal_conv(x) - assert x_causal.shape == (batch_size, num_channels, seq_len) - - x_chunk = x[..., left_pad:] - x_chunk = self.chunkwise_conv(x_chunk) # does not change shape - - chunk_scale = self._get_chunk_scale(chunk_size=seq_len) - x_chunk = x_chunk * chunk_scale - - return x_chunk + x_causal, cache - -class BalancerFunction(torch.autograd.Function): - @staticmethod def forward( - ctx, - x: Tensor, - min_mean: float, - max_mean: float, - min_rms: float, - max_rms: float, - grad_scale: float, - channel_dim: int, - ) -> Tensor: - if channel_dim < 0: - channel_dim += x.ndim - ctx.channel_dim = channel_dim - ctx.save_for_backward(x) - ctx.config = (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim) - return x - - @staticmethod - def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None]: - (x,) = ctx.saved_tensors - (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim) = ctx.config - - try: - with torch.enable_grad(): - with torch.cuda.amp.autocast(enabled=False): - x = x.to(torch.float32) - x = x.detach() - x.requires_grad = True - mean_dims = [i for i in range(x.ndim) if i != channel_dim] - uncentered_var = (x**2).mean(dim=mean_dims, keepdim=True) - mean = x.mean(dim=mean_dims, keepdim=True) - stddev = (uncentered_var - (mean * mean)).clamp(min=1.0e-20).sqrt() - rms = uncentered_var.clamp(min=1.0e-20).sqrt() - - m = mean / stddev - # part of loss that relates to mean / stddev - m_loss = (m - m.clamp(min=min_mean, max=max_mean)).abs() - - # put a much larger scale on the RMS-max-limit loss, so that if both it and the - # m_loss are violated we fix the RMS loss first. - rms_clamped = rms.clamp(min=min_rms, max=max_rms) - r_loss = (rms_clamped / rms).log().abs() - - loss = m_loss + r_loss - - loss.backward(gradient=torch.ones_like(loss)) - loss_grad = x.grad - loss_grad_rms = ( - (loss_grad**2) - .mean(dim=mean_dims, keepdim=True) - .sqrt() - .clamp(min=1.0e-20) - ) - - loss_grad = loss_grad * (grad_scale / loss_grad_rms) - - x_grad_float = x_grad.to(torch.float32) - # scale each element of loss_grad by the absolute value of the corresponding - # element of x_grad, which we view as a noisy estimate of its magnitude for that - # (frame and dimension). later we can consider factored versions. - x_grad_mod = x_grad_float + (x_grad_float.abs() * loss_grad) - x_grad = x_grad_mod.to(x_grad.dtype) - except Exception as e: - logging.info( - f"Caught exception in Balancer backward: {e}, size={list(x_grad.shape)}, will continue." - ) - - return x_grad, None, None, None, None, None, None - - -class Balancer(torch.nn.Module): - """ - Modifies the backpropped derivatives of a function to try to encourage, for - each channel, that it is positive at least a proportion `threshold` of the - time. 
It does this by multiplying negative derivative values by up to - (1+max_factor), and positive derivative values by up to (1-max_factor), - interpolated from 1 at the threshold to those extremal values when none - of the inputs are positive. - - Args: - num_channels: the number of channels - channel_dim: the dimension/axis corresponding to the channel, e.g. - -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative. - min_positive: the minimum, per channel, of the proportion of the time - that (x > 0), below which we start to modify the derivatives. - max_positive: the maximum, per channel, of the proportion of the time - that (x > 0), above which we start to modify the derivatives. - scale_gain_factor: determines the 'gain' with which we increase the - change in gradient once the constraints on min_abs and max_abs - are violated. - min_abs: the minimum average-absolute-value difference from the mean - value per channel, which we allow, before we start to modify - the derivatives to prevent this. - max_abs: the maximum average-absolute-value difference from the mean - value per channel, which we allow, before we start to modify - the derivatives to prevent this. - prob: determines the minimum probability with which we modify the - gradients for the {min,max}_positive and {min,max}_abs constraints, - on each forward(). This is done randomly to prevent all layers - from doing it at the same time. - """ - - def __init__( - self, - num_channels: int, - channel_dim: int, - min_positive: FloatLike = 0.05, - max_positive: FloatLike = 0.95, - min_abs: FloatLike = 0.2, - max_abs: FloatLike = 100.0, - grad_scale: FloatLike = 0.04, - prob: Optional[FloatLike] = None, - ): - super().__init__() - - if prob is None: - prob = ScheduledFloat((0.0, 0.5), (8000.0, 0.125), default=0.4) - self.prob = prob - # 5% of the time we will return and do nothing because memory usage is - # too high. - self.mem_cutoff = CutoffEstimator(0.05) - - # actually self.num_channels is no longer needed except for an assertion. - self.num_channels = num_channels - self.channel_dim = channel_dim - self.min_positive = min_positive - self.max_positive = max_positive - self.min_abs = min_abs - self.max_abs = max_abs - self.grad_scale = grad_scale - - def forward(self, x: Tensor) -> Tensor: - if ( - torch.jit.is_scripting() - or not x.requires_grad - or (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated())) - ): - return _no_op(x) - - prob = float(self.prob) - if random.random() < prob: - # The following inner-functions convert from the way we historically specified - # these limitations, as limits on the absolute value and the proportion of positive - # values, to limits on the RMS value and the (mean / stddev). - def _abs_to_rms(x): - # for normally distributed data, if the expected absolute value is x, the - # expected rms value will be sqrt(pi/2) * x. - return 1.25331413732 * x - - def _proportion_positive_to_mean(x): - def _atanh(x): - eps = 1.0e-10 - # eps is to prevent crashes if x is exactly 0 or 1. - # we'll just end up returning a fairly large value. - return (math.log(1 + x + eps) - math.log(1 - x + eps)) / 2.0 - - def _approx_inverse_erf(x): - # 1 / (sqrt(pi) * ln(2)), - # see https://math.stackexchange.com/questions/321569/approximating-the-error-function-erf-by-analytical-functions - # this approximation is extremely crude and gets progressively worse for - # x very close to -1 or +1, but we mostly care about the "middle" region - # e.g. 
_approx_inverse_erf(0.05) = 0.0407316414078772, - # and math.erf(0.0407316414078772) = 0.045935330944660666, - # which is pretty close to 0.05. - return 0.8139535143 * _atanh(x) - - # first convert x from the range 0..1 to the range -1..1 which the error - # function returns - x = -1 + (2 * x) - return _approx_inverse_erf(x) - - min_mean = _proportion_positive_to_mean(float(self.min_positive)) - max_mean = _proportion_positive_to_mean(float(self.max_positive)) - min_rms = _abs_to_rms(float(self.min_abs)) - max_rms = _abs_to_rms(float(self.max_abs)) - grad_scale = float(self.grad_scale) - - assert x.shape[self.channel_dim] == self.num_channels - - return BalancerFunction.apply( - x, min_mean, max_mean, min_rms, max_rms, grad_scale, self.channel_dim - ) - else: - return _no_op(x) - - -def penalize_abs_values_gt( - x: Tensor, limit: float, penalty: float, name: str = None -) -> Tensor: - """ - Returns x unmodified, but in backprop will put a penalty for the excess of - the absolute values of elements of x over the limit "limit". E.g. if - limit == 10.0, then if x has any values over 10 it will get a penalty. - - Caution: the value of this penalty will be affected by grad scaling used - in automatic mixed precision training. For this reasons we use this, - it shouldn't really matter, or may even be helpful; we just use this - to disallow really implausible values of scores to be given to softmax. - - The name is for randomly printed debug info. - """ - x_sign = x.sign() - over_limit = (x.abs() - limit) > 0 - # The following is a memory efficient way to penalize the absolute values of - # x that's over the limit. (The memory efficiency comes when you think - # about which items torch needs to cache for the autograd, and which ones it - # can throw away). The numerical value of aux_loss as computed here will - # actually be larger than it should be, by limit * over_limit.sum(), but it - # has the same derivative as the real aux_loss which is penalty * (x.abs() - - # limit).relu(). - aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x) - # note: we don't do sum() here on aux)_loss, but it's as if we had done - # sum() due to how with_loss() works. - x = with_loss(x, aux_loss, name) - # you must use x for something, or this will be ineffective. - return x - - -def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims. - if x.ndim == 2: - return x.diag() - else: - (batch, dim, dim) = x.shape - x = x.reshape(batch, dim * dim) - x = x[:, :: dim + 1] - assert x.shape == (batch, dim) - return x - - -def _whitening_metric(x: Tensor, num_groups: int): - """ - Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of - of the centered feature covariance are the same within each group's covariance matrix - and also between groups. - Args: - x: a Tensor of shape (*, num_channels) - num_groups: the number of groups of channels, a number >=1 that divides num_channels - Returns: - Returns a scalar Tensor that will be 1.0 if the data is "perfectly white" and - greater than 1.0 otherwise. - """ - assert x.dtype != torch.float16 - x = x.reshape(-1, x.shape[-1]) - (num_frames, num_channels) = x.shape - assert num_channels % num_groups == 0 - channels_per_group = num_channels // num_groups - x = x.reshape(num_frames, num_groups, channels_per_group).transpose(0, 1) - # x now has shape (num_groups, num_frames, channels_per_group) - # subtract the mean so we use the centered, not uncentered, covariance. 
- # My experience has been that when we "mess with the gradients" like this, - # it's better not do anything that tries to move the mean around, because - # that can easily cause instability. - x = x - x.mean(dim=1, keepdim=True) - # x_covar: (num_groups, channels_per_group, channels_per_group) - x_covar = torch.matmul(x.transpose(1, 2), x) - x_covar_mean_diag = _diag(x_covar).mean() - # the following expression is what we'd get if we took the matrix product - # of each covariance and measured the mean of its trace, i.e. - # the same as _diag(torch.matmul(x_covar, x_covar)).mean(). - x_covarsq_mean_diag = (x_covar**2).sum() / (num_groups * channels_per_group) - # this metric will be >= 1.0; the larger it is, the less 'white' the data was. - metric = x_covarsq_mean_diag / (x_covar_mean_diag**2 + 1.0e-20) - return metric - - -class WhiteningPenaltyFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, x: Tensor, module: nn.Module) -> Tensor: - ctx.save_for_backward(x) - ctx.module = module - return x - - @staticmethod - def backward(ctx, x_grad: Tensor): - (x_orig,) = ctx.saved_tensors - w = ctx.module - - try: - with torch.enable_grad(): - with torch.cuda.amp.autocast(enabled=False): - x_detached = x_orig.to(torch.float32).detach() - x_detached.requires_grad = True - - metric = _whitening_metric(x_detached, w.num_groups) - - if random.random() < 0.005 or __name__ == "__main__": - logging.info( - f"Whitening: name={w.name}, num_groups={w.num_groups}, num_channels={x_orig.shape[-1]}, " - f"metric={metric.item():.2f} vs. limit={float(w.whitening_limit)}" - ) - - if metric < float(w.whitening_limit): - w.prob = w.min_prob - return x_grad, None - else: - w.prob = w.max_prob - metric.backward() - penalty_grad = x_detached.grad - scale = float(w.grad_scale) * ( - x_grad.to(torch.float32).norm() - / (penalty_grad.norm() + 1.0e-20) - ) - penalty_grad = penalty_grad * scale - return x_grad + penalty_grad.to(x_grad.dtype), None - except Exception as e: - logging.info( - f"Caught exception in Whiten backward: {e}, size={list(x_grad.shape)}, will continue." - ) - return x_grad, None - - -class Whiten(nn.Module): - def __init__( - self, - num_groups: int, - whitening_limit: FloatLike, - prob: Union[float, Tuple[float, float]], - grad_scale: FloatLike, - ): + self, x: torch.Tensor, left_cache: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: """ - Args: - num_groups: the number of groups to divide the channel dim into before - whitening. We will attempt to make the feature covariance - within each group, after mean subtraction, as "white" as possible, - while having the same trace across all groups. - whitening_limit: a value greater than 1.0, that dictates how much - freedom we have to violate the constraints. 1.0 would mean perfectly - white, with exactly the same trace across groups; larger values - give more freedom. E.g. 2.0. - prob: the probability with which we apply the gradient modification - (also affects the grad scale). May be supplied as a float, - or as a pair (min_prob, max_prob) - - grad_scale: determines the scale on the gradient term from this object, - relative to the rest of the gradient on the attention weights. - E.g. 0.02 (you may want to use smaller values than this if prob is large) + Does a forward pass of the ChunkCausalDepthwiseConv1d module. Returns processed tensor + of the same shape as input and updated cached convolution tensor of the left context. 
+ + Parameters + ---------- + x : torch.Tensor[torch.float32] + The input float tensor of shape (1, num_channels, seq_len). The module input. + left_cache : torch.Tensor[torch.float32] + A cached convolution tensor of the left context + of shape (1, num_channels, left_cache_len). + + Returns + ------- + tuple[torch.Tensor[torch.float32], torch.Tensor[torch.float32]] + A tuple of two float tensors: + - module output of shape (1, num_channels, seq_len). + A tensor with the same shape as input x. + - updated cached convolution tensor of the left context + of shape (1, num_channels, left_cache_len). """ - super(Whiten, self).__init__() - assert num_groups >= 1 - assert float(whitening_limit) >= 1 - assert float(grad_scale) >= 0 - self.num_groups = num_groups - self.whitening_limit = whitening_limit - self.grad_scale = grad_scale - - if isinstance(prob, float): - prob = (prob, prob) - (self.min_prob, self.max_prob) = prob - assert 0 < self.min_prob <= self.max_prob <= 1 - self.prob = self.max_prob - self.name = None # will be set in training loop - - def forward(self, x: Tensor) -> Tensor: - """ - In the forward pass, this function just returns the input unmodified. - In the backward pass, it will modify the gradients to ensure that the - distribution in each group has close to (lambda times I) as the covariance - after mean subtraction, with the same lambda across groups. - For whitening_limit > 1, there will be more freedom to violate this - constraint. - - Args: - x: the input of shape (*, num_channels) - - Returns: - x, unmodified. You should make sure - you use the returned value, or the graph will be freed - and nothing will happen in backprop. - """ - grad_scale = float(self.grad_scale) - if not x.requires_grad or random.random() > self.prob or grad_scale == 0: - return _no_op(x) - else: - return WhiteningPenaltyFunction.apply(x, self) - - -class WithLoss(torch.autograd.Function): - @staticmethod - def forward(ctx, x: Tensor, y: Tensor, name: str): - ctx.y_shape = y.shape - if random.random() < 0.002 and name is not None: - loss_sum = y.sum().item() - logging.info(f"WithLoss: name={name}, loss-sum={loss_sum:.3e}") - return x - - @staticmethod - def backward(ctx, ans_grad: Tensor): - return ( - ans_grad, - torch.ones(ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device), - None, - ) - - -def with_loss(x, y, name): - # returns x but adds y.sum() to the loss function. - return WithLoss.apply(x, y, name) - - -class ScaleGradFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, x: Tensor, alpha: float) -> Tensor: - ctx.alpha = alpha - return x - - @staticmethod - def backward(ctx, grad: Tensor): - return grad * ctx.alpha, None - - -def scale_grad(x: Tensor, alpha: float): - return ScaleGradFunction.apply(x, alpha) - - -class ScaleGrad(nn.Module): - def __init__(self, alpha: float): - super().__init__() - self.alpha = alpha - - def forward(self, x: Tensor) -> Tensor: - if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training: - return x - return scale_grad(x, self.alpha) - - -class LimitParamValue(torch.autograd.Function): - @staticmethod - def forward(ctx, x: Tensor, min: float, max: float): - ctx.save_for_backward(x) - assert max >= min - ctx.min = min - ctx.max = max - return x - - @staticmethod - def backward(ctx, x_grad: Tensor): - (x,) = ctx.saved_tensors - # where x < ctx.min, ensure all grads are negative (this will tend to make - # x more positive). 
- x_grad = x_grad * torch.where( - torch.logical_and(x_grad > 0, x < ctx.min), -1.0, 1.0 - ) - # where x > ctx.max, ensure all grads are positive (this will tend to make - # x more negative). - x_grad *= torch.where(torch.logical_and(x_grad < 0, x > ctx.max), -1.0, 1.0) - return x_grad, None, None - -def limit_param_value( - x: Tensor, min: float, max: float, prob: float = 0.6, training: bool = True -): - # You apply this to (typically) an nn.Parameter during training to ensure that its - # (elements mostly) stays within a supplied range. This is done by modifying the - # gradients in backprop. - # It's not necessary to do this on every batch: do it only some of the time, - # to save a little time. - if training and random.random() < prob: - return LimitParamValue.apply(x, min, max) - else: - return x + seq_len = x.size(2) + x_chunk = self.chunkwise_conv(x) # does not change shape -def _no_op(x: Tensor) -> Tensor: - if torch.jit.is_scripting() or torch.jit.is_tracing(): - return x - else: - # a no-op function that will have a node in the autograd graph, - # to avoid certain bugs relating to backward hooks - return x.chunk(1, dim=-1)[0] + x = torch.cat((left_cache, x), dim=2) # Pad with left cache + left_cache = x[ + :, + :, + x.size(2) - self.right_context - left_cache.size(2): + x.size(2) - self.right_context, + ] # Update cache + x_causal = self.causal_conv(x) -class Identity(torch.nn.Module): - def __init__(self): - super(Identity, self).__init__() - - def forward(self, x): - return _no_op(x) - - -class DoubleSwishFunction(torch.autograd.Function): - """ - double_swish(x) = x * torch.sigmoid(x-1) - - This is a definition, originally motivated by its close numerical - similarity to swish(swish(x)), where swish(x) = x * sigmoid(x). - - Memory-efficient derivative computation: - double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1) - double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x). - Now, s'(x) = s(x) * (1-s(x)). - double_swish'(x) = x * s'(x) + s(x). - = x * s(x) * (1-s(x)) + s(x). - = double_swish(x) * (1-s(x)) + s(x) - ... so we just need to remember s(x) but not x itself. - """ - - @staticmethod - def forward(ctx, x: Tensor) -> Tensor: - requires_grad = x.requires_grad - if x.dtype == torch.float16 or x.dtype == torch.bfloat16: - x = x.to(torch.float32) - - s = torch.sigmoid(x - 1.0) - y = x * s - - if requires_grad: - deriv = y * (1 - s) + s - - # notes on derivative of x * sigmoid(x - 1): - # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29 - # min \simeq -0.043638. Take floor as -0.044 so it's a lower bund - # max \simeq 1.1990. Take ceil to be 1.2 so it's an upper bound. - # the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which - # floors), should be expectation-preserving. - floor = -0.044 - ceil = 1.2 - d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like( - deriv + if seq_len < self.kernel_size: + left_edge = self.chunkwise_conv_scale[:1, :, :seq_len] + right_edge = self.chunkwise_conv_scale[1:, :, self.kernel_size - seq_len:] + else: + pad = torch.zeros( + 1, self.chunkwise_conv_scale.size(1), seq_len - self.kernel_size, + dtype=torch.float32, + device=self.chunkwise_conv_scale.device, ) - if __name__ == "__main__": - # for self-testing only. 
- assert d_scaled.min() >= 0.0 - assert d_scaled.max() < 256.0 - d_int = d_scaled.to(torch.uint8) - ctx.save_for_backward(d_int) - if x.dtype == torch.float16 or torch.is_autocast_enabled(): - y = y.to(torch.float16) - return y - - @staticmethod - def backward(ctx, y_grad: Tensor) -> Tensor: - (d,) = ctx.saved_tensors - # the same constants as used in forward pass. - floor = -0.043637 - ceil = 1.2 - - d = d * ((ceil - floor) / 255.0) + floor - return y_grad * d - - -class DoubleSwish(torch.nn.Module): - def __init__(self): - super().__init__() + left_edge = torch.cat((self.chunkwise_conv_scale[:1], pad), dim=2) + right_edge = torch.cat((pad, self.chunkwise_conv_scale[1:]), dim=2) - def forward(self, x: Tensor) -> Tensor: - """Return double-swish activation function which is an approximation to Swish(Swish(x)), - that we approximate closely with x * sigmoid(x-1). - """ - if torch.jit.is_scripting() or torch.jit.is_tracing(): - return x * torch.sigmoid(x - 1.0) - return DoubleSwishFunction.apply(x) + chunk_scale = 1.0 + left_edge + right_edge + x = x_chunk * chunk_scale + x_causal -# Dropout2 is just like normal dropout, except it supports schedules on the dropout rates. -class Dropout2(nn.Module): - def __init__(self, p: FloatLike): - super().__init__() - self.p = p - - def forward(self, x: Tensor) -> Tensor: - return torch.nn.functional.dropout(x, p=float(self.p), training=self.training) - - -class MulForDropout3(torch.autograd.Function): - # returns (x * y * alpha) where alpha is a float and y doesn't require - # grad and is zero-or-one. - @staticmethod - @custom_fwd - def forward(ctx, x, y, alpha): - assert not y.requires_grad - ans = x * y * alpha - ctx.save_for_backward(ans) - ctx.alpha = alpha - return ans - - @staticmethod - @custom_bwd - def backward(ctx, ans_grad): - (ans,) = ctx.saved_tensors - x_grad = ctx.alpha * ans_grad * (ans != 0) - return x_grad, None, None - - -# Dropout3 is just like normal dropout, except it supports schedules on the dropout rates, -# and it lets you choose one dimension to share the dropout mask over -class Dropout3(nn.Module): - def __init__(self, p: FloatLike, shared_dim: int): - super().__init__() - self.p = p - self.shared_dim = shared_dim - - def forward(self, x: Tensor) -> Tensor: - p = float(self.p) - if not self.training or p == 0: - return _no_op(x) - scale = 1.0 / (1 - p) - rand_shape = list(x.shape) - rand_shape[self.shared_dim] = 1 - mask = torch.rand(*rand_shape, device=x.device) > p - ans = MulForDropout3.apply(x, mask, scale) - return ans - - -class SwooshLFunction(torch.autograd.Function): - """ - swoosh_l(x) = log(1 + exp(x-4)) - 0.08*x - 0.035 - """ - - @staticmethod - def forward(ctx, x: Tensor) -> Tensor: - requires_grad = x.requires_grad - if x.dtype == torch.float16 or x.dtype == torch.bfloat16: - x = x.to(torch.float32) - - zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - - coeff = -0.08 - - with torch.cuda.amp.autocast(enabled=False): - with torch.enable_grad(): - x = x.detach() - x.requires_grad = True - y = torch.logaddexp(zero, x - 4.0) + coeff * x - 0.035 - - if not requires_grad: - return y - - y.backward(gradient=torch.ones_like(y)) - - grad = x.grad - floor = coeff - ceil = 1.0 + coeff + 0.005 - - d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like( - grad - ) - if __name__ == "__main__": - # for self-testing only. 
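The chunk_scale construction above implies that chunkwise_conv_scale holds two learned edge corrections of shape (2, num_channels, kernel_size): one applied to the first kernel_size frames of the chunk and one to the last kernel_size frames, while interior frames keep a scale of exactly 1.0. That reading of the shape is inferred from the slicing, not stated explicitly; a small illustration of the seq_len >= kernel_size branch with assumed sizes:

import torch

# The edge corrections only touch the first and last kernel_size positions of the chunk;
# everything in between is scaled by 1.0.
num_channels, kernel_size, seq_len = 2, 4, 9
chunkwise_conv_scale = 0.1 * torch.ones(2, num_channels, kernel_size)

pad = torch.zeros(1, num_channels, seq_len - kernel_size)
left_edge = torch.cat((chunkwise_conv_scale[:1], pad), dim=2)   # non-zero on frames 0..3
right_edge = torch.cat((pad, chunkwise_conv_scale[1:]), dim=2)  # non-zero on frames 5..8
chunk_scale = 1.0 + left_edge + right_edge

print(chunk_scale[0, 0])  # 1.1 on the four edge frames at each end, 1.0 in the middle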
- assert d_scaled.min() >= 0.0 - assert d_scaled.max() < 256.0 - - d_int = d_scaled.to(torch.uint8) - ctx.save_for_backward(d_int) - if x.dtype == torch.float16 or torch.is_autocast_enabled(): - y = y.to(torch.get_autocast_gpu_dtype()) - return y - - @staticmethod - def backward(ctx, y_grad: Tensor) -> Tensor: - (d,) = ctx.saved_tensors - # the same constants as used in forward pass. - - coeff = -0.08 - floor = coeff - ceil = 1.0 + coeff + 0.005 - d = d * ((ceil - floor) / 255.0) + floor - return y_grad * d + return x, left_cache class SwooshL(torch.nn.Module): - def forward(self, x: Tensor) -> Tensor: - """Return Swoosh-L activation.""" - if torch.jit.is_scripting() or torch.jit.is_tracing(): - zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - return logaddexp(zero, x - 4.0) - 0.08 * x - 0.035 - if not x.requires_grad: - return k2.swoosh_l_forward(x) - else: - return k2.swoosh_l(x) - # return SwooshLFunction.apply(x) - - -class SwooshLOnnx(torch.nn.Module): - def forward(self, x: Tensor) -> Tensor: - """Return Swoosh-L activation.""" - zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - return logaddexp_onnx(zero, x - 4.0) - 0.08 * x - 0.035 - - -class SwooshRFunction(torch.autograd.Function): - """ - swoosh_r(x) = log(1 + exp(x-1)) - 0.08*x - 0.313261687 - - derivatives are between -0.08 and 0.92. - """ - - @staticmethod - def forward(ctx, x: Tensor) -> Tensor: - requires_grad = x.requires_grad - - if x.dtype == torch.float16 or x.dtype == torch.bfloat16: - x = x.to(torch.float32) - - zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - - with torch.cuda.amp.autocast(enabled=False): - with torch.enable_grad(): - x = x.detach() - x.requires_grad = True - y = torch.logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687 - - if not requires_grad: - return y - y.backward(gradient=torch.ones_like(y)) + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Does a forward pass and returns Swoosh-L activation. + + Parameters + ---------- + x : torch.Tensor[torch.float32] + A float tensor of an arbitrary shape (*). The module input. + + Returns + ------- + torch.Tensor[torch.float32] + A float tensor of an arbitrary shape (*). A Swoosh-L activation output tensor + of the same shape as input x. + """ - grad = x.grad - floor = -0.08 - ceil = 0.925 + logaddexp = torch.clamp(x - 4.0, min=0.0) + torch.log1p(torch.exp(-torch.abs(x - 4.0))) - d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like( - grad - ) - if __name__ == "__main__": - # for self-testing only. - assert d_scaled.min() >= 0.0 - assert d_scaled.max() < 256.0 - - d_int = d_scaled.to(torch.uint8) - ctx.save_for_backward(d_int) - if x.dtype == torch.float16 or torch.is_autocast_enabled(): - y = y.to(torch.get_autocast_gpu_dtype()) - return y - - @staticmethod - def backward(ctx, y_grad: Tensor) -> Tensor: - (d,) = ctx.saved_tensors - # the same constants as used in forward pass. 
- floor = -0.08 - ceil = 0.925 - d = d * ((ceil - floor) / 255.0) + floor - return y_grad * d + return logaddexp - 0.08 * x - 0.035 class SwooshR(torch.nn.Module): - def forward(self, x: Tensor) -> Tensor: - """Return Swoosh-R activation.""" - if torch.jit.is_scripting() or torch.jit.is_tracing(): - zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - return logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687 - if not x.requires_grad: - return k2.swoosh_r_forward(x) - else: - return k2.swoosh_r(x) - # return SwooshRFunction.apply(x) - - -class SwooshROnnx(torch.nn.Module): - def forward(self, x: Tensor) -> Tensor: - """Return Swoosh-R activation.""" - zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - return logaddexp_onnx(zero, x - 1.0) - 0.08 * x - 0.313261687 - - -# simple version of SwooshL that does not redefine the backprop, used in -# ActivationDropoutAndLinearFunction. -def SwooshLForward(x: Tensor): - x_offset = x - 4.0 - log_sum = (1.0 + x_offset.exp()).log().to(x.dtype) - log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum) - return log_sum - 0.08 * x - 0.035 - - -# simple version of SwooshR that does not redefine the backprop, used in -# ActivationDropoutAndLinearFunction. -def SwooshRForward(x: Tensor): - x_offset = x - 1.0 - log_sum = (1.0 + x_offset.exp()).log().to(x.dtype) - log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum) - return log_sum - 0.08 * x - 0.313261687 - - -class ActivationDropoutAndLinearFunction(torch.autograd.Function): - @staticmethod - @custom_fwd - def forward( - ctx, - x: Tensor, - weight: Tensor, - bias: Optional[Tensor], - activation: str, - dropout_p: float, - dropout_shared_dim: Optional[int], - ): - if dropout_p != 0.0: - dropout_shape = list(x.shape) - if dropout_shared_dim is not None: - dropout_shape[dropout_shared_dim] = 1 - # else it won't be very memory efficient. - dropout_mask = (1.0 / (1.0 - dropout_p)) * ( - torch.rand(*dropout_shape, device=x.device, dtype=x.dtype) > dropout_p - ) - else: - dropout_mask = None - - ctx.save_for_backward(x, weight, bias, dropout_mask) - - ctx.activation = activation - - forward_activation_dict = { - "SwooshL": k2.swoosh_l_forward, - "SwooshR": k2.swoosh_r_forward, - } - # it will raise a KeyError if this fails. This will be an error. We let it - # propagate to the user. - activation_func = forward_activation_dict[activation] - x = activation_func(x) - if dropout_mask is not None: - x = x * dropout_mask - x = torch.nn.functional.linear(x, weight, bias) - return x - - @staticmethod - @custom_bwd - def backward(ctx, ans_grad: Tensor): - saved = ctx.saved_tensors - (x, weight, bias, dropout_mask) = saved - - forward_and_deriv_activation_dict = { - "SwooshL": k2.swoosh_l_forward_and_deriv, - "SwooshR": k2.swoosh_r_forward_and_deriv, - } - # the following lines a KeyError if the activation is unrecognized. - # This will be an error. We let it propagate to the user. - func = forward_and_deriv_activation_dict[ctx.activation] - - y, func_deriv = func(x) - if dropout_mask is not None: - y = y * dropout_mask - # now compute derivative of y w.r.t. weight and bias.. 
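The new SwooshL.forward above replaces logaddexp(0, x - 4) = log(1 + exp(x - 4)) with the overflow-safe form clamp(x - 4, min=0) + log1p(exp(-|x - 4|)); SwooshR below applies the same rewrite with an offset of 1. A quick numerical check of the identity (illustrative only):

import torch

# Verify that the branch-free rewrite matches logaddexp(0, x - 4) over a wide range,
# then apply the same affine terms as the return statement of SwooshL.
x = torch.linspace(-60.0, 60.0, steps=1001)
stable = torch.clamp(x - 4.0, min=0.0) + torch.log1p(torch.exp(-torch.abs(x - 4.0)))
reference = torch.logaddexp(torch.zeros_like(x), x - 4.0)
assert torch.allclose(stable, reference, atol=1e-6)

swoosh_l = stable - 0.08 * x - 0.035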
- # y: (..., in_channels), ans_grad: (..., out_channels), - (out_channels, in_channels) = weight.shape - - in_channels = y.shape[-1] - g = ans_grad.reshape(-1, out_channels) - weight_deriv = torch.matmul(g.t(), y.reshape(-1, in_channels)) - y_deriv = torch.matmul(ans_grad, weight) - bias_deriv = None if bias is None else g.sum(dim=0) - x_deriv = y_deriv * func_deriv - if dropout_mask is not None: - # order versus func_deriv does not matter - x_deriv = x_deriv * dropout_mask - - return x_deriv, weight_deriv, bias_deriv, None, None, None - - -class ActivationDropoutAndLinear(torch.nn.Module): - """ - This merges an activation function followed by dropout and then a nn.Linear module; - it does so in a memory efficient way so that it only stores the input to the whole - module. If activation == SwooshL and dropout_shared_dim != None, this will be - equivalent to: - nn.Sequential(SwooshL(), - Dropout3(dropout_p, shared_dim=dropout_shared_dim), - ScaledLinear(in_channels, out_channels, bias=bias, - initial_scale=initial_scale)) - If dropout_shared_dim is None, the dropout would be equivalent to - Dropout2(dropout_p). Note: Dropout3 will be more memory efficient as the dropout - mask is smaller. - - Args: - in_channels: number of input channels, e.g. 256 - out_channels: number of output channels, e.g. 256 - bias: if true, have a bias - activation: the activation function, for now just support SwooshL. - dropout_p: the dropout probability or schedule (happens after nonlinearity). - dropout_shared_dim: the dimension, if any, across which the dropout mask is - shared (e.g. the time dimension). If None, this may be less memory - efficient if there are modules before this one that cache the input - for their backprop (e.g. Balancer or Whiten). - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - bias: bool = True, - activation: str = "SwooshL", - dropout_p: FloatLike = 0.0, - dropout_shared_dim: Optional[int] = -1, - initial_scale: float = 1.0, - ): - super().__init__() - # create a temporary module of nn.Linear that we'll steal the - # weights and bias from - l = ScaledLinear( - in_channels, out_channels, bias=bias, initial_scale=initial_scale - ) - - self.weight = l.weight - # register_parameter properly handles making it a parameter when l.bias - # is None. I think there is some reason for doing it this way rather - # than just setting it to None but I don't know what it is, maybe - # something to do with exporting the module.. - self.register_parameter("bias", l.bias) - - self.activation = activation - self.dropout_p = dropout_p - self.dropout_shared_dim = dropout_shared_dim - - def forward(self, x: Tensor): - if not self.training or torch.jit.is_scripting() or torch.jit.is_tracing(): - if self.activation == "SwooshL": - x = SwooshLForward(x) - elif self.activation == "SwooshR": - x = SwooshRForward(x) - else: - assert False, self.activation - return torch.nn.functional.linear(x, self.weight, self.bias) - - return ActivationDropoutAndLinearFunction.apply( - x, - self.weight, - self.bias, - self.activation, - float(self.dropout_p), - self.dropout_shared_dim, - ) + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Does a forward pass and returns Swoosh-R activation. + + Parameters + ---------- + x : torch.Tensor[torch.float32] + A float tensor of an arbitrary shape (*). The module input. + + Returns + ------- + torch.Tensor[torch.float32] + A float tensor of an arbitrary shape (*). A Swoosh-R activation output tensor + of the same shape as input x. 
+ """ + logaddexp = torch.clamp(x - 1.0, min=0.0) + torch.log1p(torch.exp(-torch.abs(x - 1.0))) -def convert_num_channels(x: Tensor, num_channels: int) -> Tensor: - if num_channels <= x.shape[-1]: - return x[..., :num_channels] - else: - shape = list(x.shape) - shape[-1] = num_channels - shape[-1] - zeros = torch.zeros(shape, dtype=x.dtype, device=x.device) - return torch.cat((x, zeros), dim=-1) + return logaddexp - 0.08 * x - 0.313261687 def _test_whiten(): diff --git a/egs/librispeech/ASR/zipformer/subsampling.py b/egs/librispeech/ASR/zipformer/subsampling.py index b2f769d3f6..3860f2f8c3 100644 --- a/egs/librispeech/ASR/zipformer/subsampling.py +++ b/egs/librispeech/ASR/zipformer/subsampling.py @@ -36,371 +36,259 @@ from torch import Tensor, nn -class ConvNeXt(nn.Module): +class ConvNeXt(torch.nn.Module): """ - Our interpretation of the ConvNeXt module as used in https://arxiv.org/pdf/2206.14747.pdf + The simplified ConvNeXt module interpretation based on https://arxiv.org/pdf/2206.14747.pdf. """ - def __init__( - self, - channels: int, - hidden_ratio: int = 3, - kernel_size: Tuple[int, int] = (7, 7), - layerdrop_rate: FloatLike = None, - ): - super().__init__() - self.padding = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2) - hidden_channels = channels * hidden_ratio - if layerdrop_rate is None: - layerdrop_rate = ScheduledFloat((0.0, 0.2), (20000.0, 0.015)) - self.layerdrop_rate = layerdrop_rate - - self.depthwise_conv = nn.Conv2d( - in_channels=channels, - out_channels=channels, - groups=channels, - kernel_size=kernel_size, - padding=self.padding, - ) - - self.pointwise_conv1 = nn.Conv2d( - in_channels=channels, out_channels=hidden_channels, kernel_size=1 - ) - - self.hidden_balancer = Balancer( - hidden_channels, - channel_dim=1, - min_positive=0.3, - max_positive=1.0, - min_abs=0.75, - max_abs=5.0, - ) - - self.activation = SwooshL() - self.pointwise_conv2 = ScaledConv2d( - in_channels=hidden_channels, - out_channels=channels, - kernel_size=1, - initial_scale=0.01, - ) - - self.out_balancer = Balancer( - channels, - channel_dim=1, - min_positive=0.4, - max_positive=0.6, - min_abs=1.0, - max_abs=6.0, - ) - self.out_whiten = Whiten( - num_groups=1, - whitening_limit=5.0, - prob=(0.025, 0.25), - grad_scale=0.01, - ) - - def forward(self, x: Tensor) -> Tensor: - if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training: - return self.forward_internal(x) - layerdrop_rate = float(self.layerdrop_rate) - - if layerdrop_rate != 0.0: - batch_size = x.shape[0] - mask = ( - torch.rand((batch_size, 1, 1, 1), dtype=x.dtype, device=x.device) - > layerdrop_rate - ) - else: - mask = None - # turns out this caching idea does not work with --world-size > 1 - # return caching_eval(self.forward_internal, x, mask) - return self.forward_internal(x, mask) - - def forward_internal( - self, x: Tensor, layer_skip_mask: Optional[Tensor] = None - ) -> Tensor: + def __init__(self, num_channels: int, device: torch.device) -> None: """ - x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs) - - The returned value has the same shape as x. + ConvNeXt initialization. + + Parameters + ---------- + num_channels : int + The number of input and output channels for ConvNeXt module. + device : torch.device + The device used to store the layer weights. + Either torch.device("cpu") or torch.device("cuda"). 
""" - bypass = x - x = self.depthwise_conv(x) - x = self.pointwise_conv1(x) - x = self.hidden_balancer(x) - x = self.activation(x) - x = self.pointwise_conv2(x) - if layer_skip_mask is not None: - x = x * layer_skip_mask + super().__init__() - x = bypass + x - x = self.out_balancer(x) + self.padding = 3 + hidden_channels = num_channels * 3 - if x.requires_grad: - x = x.transpose(1, 3) # (N, W, H, C); need channel dim to be last - x = self.out_whiten(x) - x = x.transpose(1, 3) # (N, C, H, W) + self.depthwise_conv = torch.nn.Conv2d( + num_channels, + num_channels, + 7, + groups=num_channels, + padding=(0, self.padding), # time, freq + device=device, + ) - return x + self.activation = SwooshL() + self.pointwise_conv1 = torch.nn.Conv2d(num_channels, hidden_channels, 1, device=device) + self.pointwise_conv2 = torch.nn.Conv2d(hidden_channels, num_channels, 1, device=device) - def streaming_forward( - self, - x: Tensor, - cached_left_pad: Tensor, - ) -> Tuple[Tensor, Tensor]: + def forward(self, x: torch.Tensor) -> torch.Tensor: """ - Args: - x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs) - cached_left_pad: (batch_size, num_channels, left_pad, num_freqs) - - Returns: - - The returned value has the same shape as x. - - Updated cached_left_pad. + Does a forward pass of the ConvNeXt module. + + Parameters + ---------- + x : torch.Tensor[torch.float32] + An input float tensor of shape (1, num_channels, num_input_frames, num_freqs). + + Returns + ------- + torch.Tensor[torch.float32] + A output float tensor of the same shape as input, + (1, num_channels, num_output_frames, num_freqs). """ - padding = self.padding - - # The length without right padding for depth-wise conv - T = x.size(2) - padding[0] - bypass = x[:, :, :T, :] + bypass = x[:, :, self.padding: x.size(2) - self.padding] - # Pad left side - assert cached_left_pad.size(2) == padding[0], ( - cached_left_pad.size(2), - padding[0], - ) - x = torch.cat([cached_left_pad, x], dim=2) - # Update cached left padding - cached_left_pad = x[:, :, T : padding[0] + T, :] - - # depthwise_conv - x = torch.nn.functional.conv2d( - x, - weight=self.depthwise_conv.weight, - bias=self.depthwise_conv.bias, - padding=(0, padding[1]), - groups=self.depthwise_conv.groups, - ) + x = self.depthwise_conv(x) x = self.pointwise_conv1(x) - x = self.hidden_balancer(x) x = self.activation(x) x = self.pointwise_conv2(x) x = bypass + x - return x, cached_left_pad - -class Conv2dSubsampling(nn.Module): - """Convolutional 2D subsampling (to 1/2 length). + return x - Convert an input of shape (N, T, idim) to an output - with shape (N, T', odim), where - T' = (T-3)//2 - 2 == (T-7)//2 - It is based on - https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa +class Conv2dSubsampling(torch.nn.Module): + """ + Convolutional 2D subsampling module. It performs the prior subsampling + (four times subsampling along the frequency axis and two times - along the time axis), + and low-level descriptor feature extraction from the log mel feature input before passing + it to zipformer encoder. """ def __init__( self, - in_channels: int, - out_channels: int, - layer1_channels: int = 8, - layer2_channels: int = 32, - layer3_channels: int = 128, - dropout: FloatLike = 0.1, + input_dim: int, + output_dim: int, + layer1_channels: int, + layer2_channels: int, + layer3_channels: int, + right_context: int, + device: torch.device, ) -> None: """ - Args: - in_channels: - Number of channels in. 
The input shape is (N, T, in_channels). - Caution: It requires: T >=7, in_channels >=7 - out_channels - Output dim. The output shape is (N, (T-3)//2, out_channels) - layer1_channels: - Number of channels in layer1 - layer1_channels: - Number of channels in layer2 - bottleneck: - bottleneck dimension for 1d squeeze-excite + Conv2dSubsampling initialization. + + Parameters + ---------- + input_dim : int + The number of input channels. Corresponds to the + number of features in the input feature tensor. + output_dim : int + The number of output channels. + layer1_channels : int + The number of output channels in the first Conv2d layer. + layer2_channels : int + The number of output channels in the second Conv2d layer. + layer3_channels : int + The number of output channels in the third Conv2d layer. + right_context: int + The look-ahead right context that is used to update the left cache. + device : torch.device + The device used to store the layer weights. Should be + either torch.device("cpu") or torch.device("cuda"). """ - assert in_channels >= 7 + super().__init__() - # The ScaleGrad module is there to prevent the gradients - # w.r.t. the weight or bias of the first Conv2d module in self.conv from - # exceeding the range of fp16 when using automatic mixed precision (amp) - # training. (The second one is necessary to stop its bias from getting - # a too-large gradient). + if input_dim < 7: + raise ValueError( + 'The input feature dimension of the Conv2dSubsampling layer, can not be less than ' + 'seven, otherwise the frequency subsampling will result with an empty output. ' + f'Expected input_dim to be at least 7 but got {input_dim}.', + ) + + self.right_context = right_context - self.conv = nn.Sequential( - nn.Conv2d( + # Assume batch size is 1 and the right padding is 10, + # see the forward method on why the right padding is 10. + self.right_pad = torch.full( + (1, 10, input_dim), ZERO_LOG_MEL, dtype=torch.float32, device=device, + ) + self.conv = torch.nn.Sequential( + torch.nn.Conv2d( in_channels=1, out_channels=layer1_channels, kernel_size=3, padding=(0, 1), # (time, freq) + device=device, ), - ScaleGrad(0.2), - Balancer(layer1_channels, channel_dim=1, max_abs=1.0), SwooshR(), - nn.Conv2d( - in_channels=layer1_channels, - out_channels=layer2_channels, - kernel_size=3, - stride=2, - padding=0, - ), - Balancer(layer2_channels, channel_dim=1, max_abs=4.0), + torch.nn.Conv2d(layer1_channels, layer2_channels, 3, stride=2, device=device), SwooshR(), - nn.Conv2d( - in_channels=layer2_channels, - out_channels=layer3_channels, - kernel_size=3, - stride=(1, 2), # (time, freq) - ), - Balancer(layer3_channels, channel_dim=1, max_abs=4.0), + torch.nn.Conv2d(layer2_channels, layer3_channels, 3, stride=(1, 2), device=device), SwooshR(), ) - # just one convnext layer - self.convnext = ConvNeXt(layer3_channels, kernel_size=(7, 7)) - - # (in_channels-3)//4 - self.out_width = (((in_channels - 1) // 2) - 1) // 2 - self.layer3_channels = layer3_channels - - self.out = nn.Linear(self.out_width * layer3_channels, out_channels) - # use a larger than normal grad_scale on this whitening module; there is - # only one such module, so there is not a concern about adding together - # many copies of this extra gradient term. 
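With the 10-frame right padding defined above, a 10-frame cached left context, and the kernel/stride configuration of self.conv followed by the ConvNeXt depthwise convolution (kernel 7, no time padding), a chunk of T input frames comes out as ceil(T / 2) frames, i.e. plain two-times time subsampling with no frames lost at chunk boundaries. A rough bookkeeping check under those assumptions (the helper name is illustrative):

def subsampled_len(num_frames: int) -> int:
    length = num_frames + 10 + 10   # prepend the left cache, append the right padding
    length = length - 2             # first Conv2d: kernel 3, time stride 1, no time padding
    length = (length - 3) // 2 + 1  # second Conv2d: kernel 3, stride 2, no padding
    length = length - 2             # third Conv2d: kernel 3, time stride 1, no padding
    length = length - 6             # ConvNeXt depthwise: kernel 7, no time padding
    return length

for num_frames in (8, 9, 16, 33):
    assert subsampled_len(num_frames) == (num_frames + 1) // 2  # ceil(num_frames / 2)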
- self.out_whiten = Whiten( - num_groups=1, - whitening_limit=ScheduledFloat((0.0, 4.0), (20000.0, 8.0), default=4.0), - prob=(0.025, 0.25), - grad_scale=0.02, - ) + self.convnext = ConvNeXt(layer3_channels, device=device) - # max_log_eps=0.0 is to prevent both eps and the output of self.out from - # getting large, there is an unnecessary degree of freedom. - self.out_norm = BiasNorm(out_channels) - self.dropout = Dropout3(dropout, shared_dim=1) + out_width = (((input_dim - 1) // 2) - 1) // 2 + self.out = torch.nn.Linear(out_width * layer3_channels, output_dim, device=device) + self.out_norm = BiasNorm(output_dim, device=device) def forward( - self, x: torch.Tensor, x_lens: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Subsample x. - - Args: - x: - Its shape is (N, T, idim). - x_lens: - A tensor of shape (batch_size,) containing the number of frames in - - Returns: - - a tensor of shape (N, (T-7)//2, odim) - - output lengths, of shape (batch_size,) + self, x: torch.Tensor, cached_left_pad: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: """ - # On entry, x is (N, T, idim) - x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) - # scaling x by 0.1 allows us to use a larger grad-scale in fp16 "amp" (automatic mixed precision) - # training, since the weights in the first convolution are otherwise the limiting factor for getting infinite - # gradients. - x = self.conv(x) - x = self.convnext(x) - - # Now x is of shape (N, odim, (T-7)//2, (idim-3)//4) - b, c, t, f = x.size() - - x = x.transpose(1, 2).reshape(b, t, c * f) - # now x: (N, (T-7)//2, out_width * layer3_channels)) - - x = self.out(x) - # Now x is of shape (N, (T-7)//2, odim) - x = self.out_whiten(x) - x = self.out_norm(x) - x = self.dropout(x) - - if torch.jit.is_scripting() or torch.jit.is_tracing(): - x_lens = (x_lens - 7) // 2 - else: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - x_lens = (x_lens - 7) // 2 - assert x.size(1) == x_lens.max().item(), (x.size(1), x_lens.max()) - - return x, x_lens - - def streaming_forward( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - cached_left_pad: Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Subsample x. - - Args: - x: - Its shape is (N, T, idim). - x_lens: - A tensor of shape (batch_size,) containing the number of frames in - - Returns: - - a tensor of shape (N, (T-7)//2, odim) - - output lengths, of shape (batch_size,) - - updated cache + Does a forward pass of the Conv2dSubsampling module. + + Parameters + ---------- + x : torch.Tensor[torch.float32] + An input float tensor of shape (1, num_frames, input_dim). An input feature tensor. + cached_left_pad : torch.Tensor[torch.float32] + A left cache float tensor of shape (1, 10, input_dim). Left cache is required + to preserve the "same" left padding to the output of the Conv2dSubsampling module. + See the get_init_states() documentation to understand why we need exactly ten frames + of left padding for the Conv2dSubsampling module. + + Returns + ------- + tuple[torch.Tensor[torch.float32], torch.Tensor[torch.float32]] + A tuple of two float tensors: + - The processing output of the Conv2dSubsampling module + of shape (1, subsampled_num_frames, output_dim). + - The udated left cache tensor of shape (1, 10, input_dim). 
""" - # On entry, x is (N, T, idim) - x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) - # T' = (T-7)//2 + x = torch.cat((cached_left_pad, x), dim=1) + new_cached_left_pad = x[ + :, + x.size(1) - self.right_context - cached_left_pad.size(1): + x.size(1) - self.right_context, + ] + + # Now when we concatenated the left cache with the input, we need to perform the right + # padding of the input in a way to preserve the "same" type of padding, so that the output + # of the module has the same duration as input (taking 2 times subsampling into account). + # There are two possible outcomes depending on whether the the number of input frames is + # even or odd, but both scenarios can be covered by 10 frames right padding. + + # x : right padding + # | | | | | | | | | | | |:| | | | | | | | | | input + # | | | | | | | | | | |:| | | | | | | | | first Conv2d output from self.conv + # | | | | | :| | | | second Conv2d output from self.conv + # | | | | :| | | third Conv2d output from self.conv + # | : Conv2d output from + # : self.convnext.depthwise_conv + # : + # x : right padding + # | | | | | | | | | | | | |:| | | | | | | | | | input + # | | | | | | | | | | | |:| | | | | | | | | first Conv2d output from self.conv + # | | | | | |: | | | | second Conv2d output from self.conv + # | | | | |: | | | third Conv2d output from self.conv + # | |: Conv2d output from + # : self.convnext.depthwise_conv + # : + + x = torch.cat((x, self.right_pad), dim=1) + + # (1, T, input_dim) -> (1, 1, T, input_dim) i.e., (N, C, H, W) + x = x.unsqueeze(1) x = self.conv(x) + x = self.convnext(x) - # T' = (T-7)//2-3 - x, cached_left_pad = self.convnext.streaming_forward( - x, cached_left_pad=cached_left_pad - ) - - # Now x is of shape (N, odim, T', ((idim-1)//2 - 1)//2) - b, c, t, f = x.size() - - x = x.transpose(1, 2).reshape(b, t, c * f) - # now x: (N, T', out_width * layer3_channels)) - + # Now x is of shape (1, output_dim, T', ((input_dim - 1) // 2 - 1) // 2) + b, c, t, f = x.size() # b is equal to 1 + x = x.permute(0, 2, 1, 3).reshape(b, t, c * f) + # Now x is of shape (T', output_dim * layer3_channels)) x = self.out(x) - # Now x is of shape (N, T', odim) + # Now x is of shape (T', output_dim) x = self.out_norm(x) - if torch.jit.is_scripting() or torch.jit.is_tracing(): - assert self.convnext.padding[0] == 3 - # The ConvNeXt module needs 3 frames of right padding after subsampling - x_lens = (x_lens - 7) // 2 - 3 - else: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - # The ConvNeXt module needs 3 frames of right padding after subsampling - assert self.convnext.padding[0] == 3 - x_lens = (x_lens - 7) // 2 - 3 + return x, new_cached_left_pad - assert x.size(1) == x_lens.max().item(), (x.shape, x_lens.max()) - - return x, x_lens, cached_left_pad +def get_init_states(input_dim: int, device: torch.device) -> torch.Tensor: + """ + Get initial states for Conv2dSubsampling module. The Conv2dSubsampling.conv consists of three + consecutive Conv2d layers with the kernel size 3 and no padding, also the middle Conv2d + has a stride 2, while the rest have the default stride 1. We want to pad the input from the + left side with cached_left_pad in the "same" way, so when we pass it through + the Conv2dSubsampling.conv and Conv2dSubsampling.convnext we end up with exactly zero padding + frames from the left. + + cached_left_pad : x + | | | | | | | | | |:| | | | | | | | | | | input + | | | | | | | | |:| | | | | | | | | | | first Conv2d output from Conv2dSubsampling.conv + | | | | :| | | | | | ... 
second Conv2d output from Conv2dSubsampling.conv + | | | :| | | | | | third Conv2d output from Conv2dSubsampling.conv + :| | | | | | Conv2d output from + : Conv2dSubsampling.convnext.depthwise_conv + + As we can see from the picture above, in order to preserve the "same" + padding from the left side we need + ((((pad - 1) - 1) // 2) - 1) - 3 = 0 --> pad = 10. + + Parameters + ---------- + input_dim : int + The number of input channels. + Corresponds to the number of features in the input of the Conv2dSubsampling module. + device : torch.device + The device used to store the left cache tensor. + Either torch.device("cpu") or torch.device("cuda"). + + Returns + ------- + torch.Tensor[torch.float32] + A left cache float tensor. The output shape is (1, 10, input_dim). + """ - @torch.jit.export - def get_init_states( - self, - batch_size: int = 1, - device: torch.device = torch.device("cpu"), - ) -> Tensor: - """Get initial states for Conv2dSubsampling module. - It is the cached left padding for ConvNeXt module, - of shape (batch_size, num_channels, left_pad, num_freqs) - """ - left_pad = self.convnext.padding[0] - freq = self.out_width - channels = self.layer3_channels - cached_embed_left_pad = torch.zeros(batch_size, channels, left_pad, freq).to( - device - ) + pad = 10 + cached_left_pad = torch.full( + (1, pad, input_dim), ZERO_LOG_MEL, dtype=torch.float32, device=device, + ) - return cached_embed_left_pad + return cached_left_pad diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 2a0ae01297..f6fa7ded7d 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -48,520 +48,723 @@ from torch import Tensor, nn -class Zipformer2(EncoderInterface): +class Zipformer2(torch.nn.Module): """ - Args: - - Note: all "int or Tuple[int]" arguments below will be treated as lists of the same length - as downsampling_factor if they are single ints or one-element tuples. The length of - downsampling_factor defines the number of stacks. - - output_downsampling_factor (int): how much to downsample at the output. Note: - we also downsample by a factor of 2 in the Conv2dSubsampling encoder. - You should probably leave this at 2. - downsampling_factor (Tuple[int]): downsampling factor for each encoder stack. - Note: this is in addition to the downsampling factor of 2 that is applied in - the frontend (self.encoder_embed). - encoder_dim (Tuple[int]): embedding dimension of each of the encoder stacks, one per - encoder stack. - num_encoder_layers (int or Tuple[int])): number of encoder layers for each stack - encoder_unmasked_dim (int or Tuple[int]): unmasked dimension in each of - the encoder stacks for purposes of per-frame dropout (recommend 256 for - now). - query_head_dim (int or Tuple[int]): dimension of query and key per attention - head: per stack, if a tuple.. - pos_head_dim (int or Tuple[int]): dimension of positional-encoding projection per - attention head - value_head_dim (int or Tuple[int]): dimension of value in each attention head - num_heads: (int or Tuple[int]): number of heads in the self-attention mechanism. - Must be at least 4. - feedforward_dim (int or Tuple[int]): hidden dimension in feedforward modules - cnn_module_kernel (int or Tuple[int])): Kernel size of convolution module - - pos_dim (int): the dimension of each positional-encoding vector prior to projection, - e.g. 128. 
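The left-padding condition quoted in the get_init_states docstring above, ((((pad - 1) - 1) // 2) - 1) - 3 = 0, can be checked directly; pad = 10 is the smallest value that satisfies it (pad = 11 also does, because of the stride-2 floor division).

# Quick check of the "same" left-padding condition from the get_init_states docstring.
for pad in range(4, 13):
    if ((((pad - 1) - 1) // 2) - 1) - 3 == 0:
        print(pad)  # prints 10 and 11; the recipe uses the smallest value, 10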
- - dropout (float): dropout rate - warmup_batches (float): number of batches to warm up over; this controls - dropout of encoder layers. - causal (bool): if True, support chunkwise causal convolution. This should - not hurt WER as no modeling power is lost, but the convolution modules will be - slightly slower and use more memory. Enables use of the chunk_size and - left_context_chunks options in forward(), which simulates streaming - decoding. - chunk_size: (list of int): only set this to other than [-1] if causal; - the chunk size will be randomly chosen from this list. -1 means no chunking. - left_context_frames: (list of int): determines the number of left- - context chunks for causal training; will be rounded to a number of - chunks. Must not be less than cnn_module_kernel (after factoring in - rounding and downsampling); an error will be thrown if this is violated. + Zipformer2 encoder. """ + # pylint: disable=too-many-instance-attributes def __init__( self, - output_downsampling_factor: int = 2, - downsampling_factor: Tuple[int] = (2, 4), - encoder_dim: Union[int, Tuple[int]] = 384, - num_encoder_layers: Union[int, Tuple[int]] = 4, - encoder_unmasked_dim: Union[int, Tuple[int]] = 256, - query_head_dim: Union[int, Tuple[int]] = 24, - pos_head_dim: Union[int, Tuple[int]] = 4, - value_head_dim: Union[int, Tuple[int]] = 12, - num_heads: Union[int, Tuple[int]] = 8, - feedforward_dim: Union[int, Tuple[int]] = 1536, - cnn_module_kernel: Union[int, Tuple[int]] = 31, - pos_dim: int = 192, - dropout: FloatLike = None, # see code below for default - warmup_batches: float = 4000.0, - causal: bool = False, - chunk_size: Tuple[int] = [-1], - left_context_frames: Tuple[int] = [-1], + input_dim: int, + subsample_output_dim: int, + subsample_layer1_channels: int, + subsample_layer2_channels: int, + subsample_layer3_channels: int, + encoder_dims: list[int], + num_encoder_layers: list[int], + downsampling_factors: list[int], + num_heads: list[int], + feedforward_dims: list[int], + cnn_module_kernels: list[int], + query_head_dim: int, + pos_head_dim: int, + value_head_dim: int, + pos_dim: int, + pos_max_len: int, + output_dim: int, + use_ctc: bool, + left_context_frames: int, + right_context_frames: int, + device: torch.device, ) -> None: - super(Zipformer2, self).__init__() - - if dropout is None: - dropout = ScheduledFloat((0.0, 0.3), (20000.0, 0.1)) - - def _to_tuple(x): - """Converts a single int or a 1-tuple of an int to a tuple with the same length - as downsampling_factor""" - if isinstance(x, int): - x = (x,) - if len(x) == 1: - x = x * len(downsampling_factor) - else: - assert len(x) == len(downsampling_factor) and isinstance(x[0], int) - return x + """ + Zipformer2 initialization. + + Parameters + ---------- + input_dim : int + The number of input features. + subsample_output_dim : int + The output dimension of the subsampling module represented by Conv2dSubsampling. + subsample_layer1_channels : int + The number of output channels in the first Conv2d layer of the + Conv2dSubsampling module. + subsample_layer2_channels : int + The number of output channels in the second Conv2d layer of the + Conv2dSubsampling module. + subsample_layer3_channels : int + The number of output channels in the third Conv2d layer of the + Conv2dSubsampling module. + encoder_dims : list[int] + A list of 5 integers, the embedding dimension of + Zipformer2EncoderLayer module in each Zipformer2Encoder stack. 
+        num_encoder_layers : list[int]
+            A list of 6 integers, the number of Zipformer2EncoderLayer
+            modules in each Zipformer2Encoder stack.
+        downsampling_factors : list[int]
+            A list of 6 integers, the downsampling factor of each Zipformer2Encoder stack.
+            Note: this is in addition to the downsampling factor of 2 that is applied in the
+            Conv2dSubsampling module.
+        num_heads : list[int]
+            A list of 6 integers, the number of heads for attention weights and self-attention of
+            the Zipformer2EncoderLayer module in each Zipformer2Encoder stack.
+        feedforward_dims : list[int]
+            A list of 6 integers, the hidden dimension of the feedforward module of
+            the Zipformer2EncoderLayer module in each Zipformer2Encoder stack.
+        cnn_module_kernels : list[int]
+            A list of 6 integers, the kernel size of the convolution module of
+            the Zipformer2EncoderLayer module in each Zipformer2Encoder stack.
+        query_head_dim : int
+            The dimension of the query and key per attention head in attention weights of the
+            Zipformer2EncoderLayer module in each Zipformer2Encoder stack.
+        pos_head_dim : int
+            The dimension of the projected positional encoding per attention head in attention
+            weights of the Zipformer2EncoderLayer module in each Zipformer2Encoder stack.
+        value_head_dim : int
+            The dimension of the value per attention head in self-attention of
+            the Zipformer2EncoderLayer module in each Zipformer2Encoder stack.
+        pos_dim : int
+            The dimension of the relative positional embeddings in each Zipformer2Encoder stack.
+        pos_max_len : int
+            The maximum input duration of the relative positional embeddings in each
+            Zipformer2Encoder stack. Note: if the input duration of any positional embedding
+            module exceeds this number, inference speed may degrade significantly.
+        output_dim : int
+            The output dimension after the final output projection.
+        use_ctc : bool
+            If True, a CTC head is assumed to be loaded into the output encoder projection.
+            In this case torch.nn.functional.log_softmax is applied to the output at the very end.
+        left_context_frames : int
+            The number of left context frames after the initial subsampling with the
+            Conv2dSubsampling module.
+        right_context_frames : int
+            The number of right (look-ahead) context frames.
+        device : torch.device
+            The device used to store the layer weights. Should be
+            either torch.device("cpu") or torch.device("cuda").
+ """ + # pylint: disable=too-many-arguments,too-many-locals - self.output_downsampling_factor = output_downsampling_factor # int - self.downsampling_factor = downsampling_factor # tuple - self.encoder_dim = encoder_dim = _to_tuple(encoder_dim) # tuple - self.encoder_unmasked_dim = encoder_unmasked_dim = _to_tuple( - encoder_unmasked_dim - ) # tuple - num_encoder_layers = _to_tuple(num_encoder_layers) - self.num_encoder_layers = num_encoder_layers - self.query_head_dim = query_head_dim = _to_tuple(query_head_dim) - self.value_head_dim = value_head_dim = _to_tuple(value_head_dim) - pos_head_dim = _to_tuple(pos_head_dim) - self.num_heads = num_heads = _to_tuple(num_heads) - feedforward_dim = _to_tuple(feedforward_dim) - self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel) - - self.causal = causal - self.chunk_size = chunk_size + super().__init__() + + if not ( + len(encoder_dims) + == len(num_encoder_layers) + == len(downsampling_factors) + == len(num_heads) + == len(feedforward_dims) + == len(cnn_module_kernels) + == 6 + ): + raise ValueError( + 'It is required that the length of encoder_dims, num_encoder_layers, ' + 'downsampling_factors, num_heads, feedforward_dims, and cnn_module_kernels is the ' + 'same and equal to 6, but got following list lengths:\n' + f'len(num_encoder_layers) == {len(num_encoder_layers)}\n' + f'len(downsampling_factors) == {len(downsampling_factors)}\n' + f'len(encoder_dims) == {len(encoder_dims)}\n' + f'len(num_heads) == {len(num_heads)}\n' + f'len(cnn_module_kernels) == {len(cnn_module_kernels)}\n' + f'len(feedforward_dims) == {len(feedforward_dims)}.', + ) + + self.encoder_dims = tuple(encoder_dims) + self.downsampling_factors = tuple(downsampling_factors) self.left_context_frames = left_context_frames + projection_dim = max(encoder_dims) + self.projection_dim = projection_dim + self.ctc = use_ctc - for u, d in zip(encoder_unmasked_dim, encoder_dim): - assert u <= d + self.subsampling = Conv2dSubsampling( + input_dim, + subsample_output_dim, + subsample_layer1_channels, + subsample_layer2_channels, + subsample_layer3_channels, + right_context_frames, + device, + ) - # each one will be Zipformer2Encoder or DownsampledZipformer2Encoder encoders = [] + for i, num_layers in enumerate(num_encoder_layers): - num_encoders = len(downsampling_factor) - for i in range(num_encoders): encoder_layer = Zipformer2EncoderLayer( - embed_dim=encoder_dim[i], - pos_dim=pos_dim, - num_heads=num_heads[i], - query_head_dim=query_head_dim[i], - pos_head_dim=pos_head_dim[i], - value_head_dim=value_head_dim[i], - feedforward_dim=feedforward_dim[i], - dropout=dropout, - cnn_module_kernel=cnn_module_kernel[i], - causal=causal, + encoder_dims[i], + pos_dim, + num_heads[i], + query_head_dim, + pos_head_dim, + value_head_dim, + feedforward_dims[i], + cnn_module_kernels[i], + left_context_frames // downsampling_factors[i], + right_context_frames // 2 // downsampling_factors[i], + device, ) - # For the segment of the warmup period, we let the Conv2dSubsampling - # layer learn something. Then we start to warm up the other encoders. 
encoder = Zipformer2Encoder( encoder_layer, - num_encoder_layers[i], - pos_dim=pos_dim, - dropout=dropout, - warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1), - warmup_end=warmup_batches * (i + 2) / (num_encoders + 1), - final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5), + num_layers, + encoder_dims[i], + pos_dim, + pos_max_len, + downsampling_factors[i], + device, ) - if downsampling_factor[i] != 1: - encoder = DownsampledZipformer2Encoder( - encoder, - dim=encoder_dim[i], - downsample=downsampling_factor[i], - dropout=dropout, - causal=causal, - ) - encoders.append(encoder) - self.encoders = nn.ModuleList(encoders) + self.encoder_1 = encoders[0] + self.encoder_2 = encoders[1] + self.encoder_3 = encoders[2] + self.encoder_4 = encoders[3] + self.encoder_5 = encoders[4] + self.encoder_6 = encoders[5] - self.downsample_output = SimpleDownsample( - max(encoder_dim), - downsample=output_downsampling_factor, - dropout=dropout, - causal=causal, - ) + self.downsample_output = SimpleDownsample(2, device) + self.projection_output = torch.nn.Linear(projection_dim, output_dim, device=device) - def get_feature_masks(self, x: Tensor) -> Union[List[float], List[Tensor]]: + def forward( + self, + x: torch.Tensor, + # We need to preserve this explicit arguments reference for the smooth + # TirchScript export with the following ONNX export. + left_cached_subsample_frames: torch.Tensor, + + left_cached_keys_encoder_1: torch.Tensor, + left_cached_nonlin_attentions_encoder_1: torch.Tensor, + left_cached_values_1_encoder_1: torch.Tensor, + left_cached_values_2_encoder_1: torch.Tensor, + left_cached_convolutions_1_encoder_1: torch.Tensor, + left_cached_convolutions_2_encoder_1: torch.Tensor, + + left_cached_keys_encoder_2: torch.Tensor, + left_cached_nonlin_attentions_encoder_2: torch.Tensor, + left_cached_values_1_encoder_2: torch.Tensor, + left_cached_values_2_encoder_2: torch.Tensor, + left_cached_convolutions_1_encoder_2: torch.Tensor, + left_cached_convolutions_2_encoder_2: torch.Tensor, + + left_cached_keys_encoder_3: torch.Tensor, + left_cached_nonlin_attentions_encoder_3: torch.Tensor, + left_cached_values_1_encoder_3: torch.Tensor, + left_cached_values_2_encoder_3: torch.Tensor, + left_cached_convolutions_1_encoder_3: torch.Tensor, + left_cached_convolutions_2_encoder_3: torch.Tensor, + + left_cached_keys_encoder_4: torch.Tensor, + left_cached_nonlin_attentions_encoder_4: torch.Tensor, + left_cached_values_1_encoder_4: torch.Tensor, + left_cached_values_2_encoder_4: torch.Tensor, + left_cached_convolutions_1_encoder_4: torch.Tensor, + left_cached_convolutions_2_encoder_4: torch.Tensor, + + left_cached_keys_encoder_5: torch.Tensor, + left_cached_nonlin_attentions_encoder_5: torch.Tensor, + left_cached_values_1_encoder_5: torch.Tensor, + left_cached_values_2_encoder_5: torch.Tensor, + left_cached_convolutions_1_encoder_5: torch.Tensor, + left_cached_convolutions_2_encoder_5: torch.Tensor, + + left_cached_keys_encoder_6: torch.Tensor, + left_cached_nonlin_attentions_encoder_6: torch.Tensor, + left_cached_values_1_encoder_6: torch.Tensor, + left_cached_values_2_encoder_6: torch.Tensor, + left_cached_convolutions_1_encoder_6: torch.Tensor, + left_cached_convolutions_2_encoder_6: torch.Tensor, + + processed_len: torch.Tensor, + ) -> tuple[ + torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, + torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, + torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, 
torch.Tensor, + torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, + torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, + torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, + torch.Tensor, torch.Tensor, torch.Tensor, + ]: """ - In eval mode, returns [1.0] * num_encoders; in training mode, returns a number of - randomized feature masks, one per encoder. - On e.g. 15% of frames, these masks will zero out all encoder dims larger than - some supplied number, e.g. >256, so in effect on those frames we are using - a smaller encoder dim. - - We generate the random masks at this level because we want the 2 masks to 'agree' - all the way up the encoder stack. This will mean that the 1st mask will have - mask values repeated self.zipformer_subsampling_factor times. - - Args: - x: the embeddings (needed for the shape and dtype and device), of shape - (1, batch_size, encoder_dims0) + Does a forward pass of the Zipformer2 module, which represents the whole acoustic encoder. + Returns a tuple with the output tensor, updated left cache feature tensor for subsampling + module, 36 left cache tensors for multiple attention and convolution modules within each of + 6 Zipformer2Encoder modules, and finally, the updated processed length single-element + tensor with the total number of processed frames after subsampling module. + + Parameters + ---------- + x : torch.Tensor[torch.float32] + The input float feature tensor of shape (1, num_frames, input_dim), + where the input_dim corresponds to the number of features. + left_cached_subsample_frames : torch.Tensor[torch.float32] + The subsampling module left cache tensor of shape (1, 10, input_dim). + left_cached_keys_encoder_1 : torch.Tensor[torch.float32] + The cached attention key tensor of the left context of each + Zipformer2EncoderLayer within the first Zipformer2Encoder. + The tensor is of shape (num_layers_1, 1, left_context_len_1, query_dim_1). + left_cached_nonlin_attentions_encoder_1 : torch.Tensor[torch.float32] + The left context cached attention tensor for the non-linear attention module of each + Zipformer2EncoderLayer within the first Zipformer2Encoder. + The tensor is of shape (num_layers_1, 1, left_context_len_1, head_dim_1). + left_cached_values_1_encoder_1 : torch.Tensor[torch.float32] + The cached left context tensor for the first self-attention module of each + Zipformer2EncoderLayer within the first Zipformer2Encoder. + The tensor is of shape (num_layers_1, 1, left_context_len_1, value_dim_1). + left_cached_values_2_encoder_1 : torch.Tensor[torch.float32] + The cached left context tensor for the second self-attention module of each + Zipformer2EncoderLayer within the first Zipformer2Encoder. + The tensor is of shape (num_layers_1, 1, left_context_len_1, value_dim_1). + left_cached_convolutions_1_encoder_1 : torch.Tensor[torch.float32] + The cached left context tensor for the first convolution module of each + Zipformer2EncoderLayer within the first Zipformer2Encoder. + The tensor is of shape (num_layers_1, 1, embed_dim_1, left_cache_len_1). + left_cached_convolutions_2_encoder_1 : torch.Tensor[torch.float32] + The cached left context tensor for the second convolution module of each + Zipformer2EncoderLayer within the first Zipformer2Encoder. + The tensor is of shape (num_layers_1, 1, embed_dim_1, left_cache_len_1). + . + . + . 
+ left_cached_convolutions_2_encoder_6 : torch.Tensor[torch.float32] + The cached left context tensor for the second convolution module of each + Zipformer2EncoderLayer within the sixth Zipformer2Encoder. + The tensor is of shape (num_layers_6, 1, embed_dim_6, left_cache_len_6). + processed_len : torch.Tensor[torch.int32] + The total processed length after subsampling, single-element integer tensor + of shape (1,). + + Returns + ------- + tuple[ + torch.Tensor[torch.float32], torch.Tensor[torch.float32], torch.Tensor[torch.float32], + torch.Tensor[torch.float32], torch.Tensor[torch.float32], torch.Tensor[torch.float32], + torch.Tensor[torch.float32], torch.Tensor[torch.float32], torch.Tensor[torch.float32], + torch.Tensor[torch.float32], torch.Tensor[torch.float32], torch.Tensor[torch.float32], + torch.Tensor[torch.float32], torch.Tensor[torch.float32], torch.Tensor[torch.float32], + torch.Tensor[torch.float32], torch.Tensor[torch.float32], torch.Tensor[torch.float32], + torch.Tensor[torch.float32], torch.Tensor[torch.float32], torch.Tensor[torch.float32], + torch.Tensor[torch.float32], torch.Tensor[torch.float32], torch.Tensor[torch.float32], + torch.Tensor[torch.float32], torch.Tensor[torch.float32], torch.Tensor[torch.float32], + torch.Tensor[torch.float32], torch.Tensor[torch.float32], torch.Tensor[torch.float32], + torch.Tensor[torch.float32], torch.Tensor[torch.float32], torch.Tensor[torch.float32], + torch.Tensor[torch.float32], torch.Tensor[torch.float32], torch.Tensor[torch.float32], + torch.Tensor[torch.float32], torch.Tensor[torch.float32], torch.Tensor[torch.int32], + ] + A tuple of 38 float tensors and 1 integer tensor: + - The module output of shape (1, seq_len, output_dim). + - The updated subsampling module left cache tensor of shape (1, 10, input_dim). + - The updated cached attention key tensor of the left context of each + Zipformer2EncoderLayer within the first Zipformer2Encoder. + The tensor is of shape (num_layers_1, 1, left_context_len_1, query_dim_1). + - The updated left context cached attention tensor for the non-linear attention + module of each Zipformer2EncoderLayer within the first Zipformer2Encoder. + The tensor is of shape (num_layers_1, 1, left_context_len_1, head_dim_1). + - The updated cached left context tensor for the first self-attention module of each + Zipformer2EncoderLayer within the first Zipformer2Encoder. + The tensor is of shape (num_layers_1, 1, left_context_len_1, value_dim_1). + - The updated cached left context tensor for the second + self-attention module of each Zipformer2EncoderLayer within the first + Zipformer2Encoder. + The tensor is of shape (num_layers_1, 1, left_context_len_1, value_dim_1). + - The updated cached left context tensor for the first convolution module of each + Zipformer2EncoderLayer within the first Zipformer2Encoder. + The tensor is of shape (num_layers_1, 1, embed_dim_1, left_cache_len_1). + - The updated cached left context tensor for the second convolution module of each + Zipformer2EncoderLayer within the first Zipformer2Encoder. + The tensor is of shape (num_layers_1, 1, embed_dim_1, left_cache_len_1). + . + . + . + - The updated cached left context tensor for the second convolution module of each + Zipformer2EncoderLayer within the sixth Zipformer2Encoder. + The tensor is of shape (num_layers_6, 1, embed_dim_6, left_cache_len_6). + - The updated total processed length tensor after subsampling of shape (1,). 
""" - num_encoders = len(self.encoder_dim) - if not self.training: - return [1.0] * num_encoders - - (num_frames0, batch_size, _encoder_dims0) = x.shape - - assert self.encoder_dim[0] == _encoder_dims0, ( - self.encoder_dim[0], - _encoder_dims0, - ) - - feature_mask_dropout_prob = 0.125 - - # mask1 shape: (1, batch_size, 1) - mask1 = ( - torch.rand(1, batch_size, 1, device=x.device) > feature_mask_dropout_prob - ).to(x.dtype) - - # mask2 has additional sequences masked, about twice the number. - mask2 = torch.logical_and( - mask1, - ( - torch.rand(1, batch_size, 1, device=x.device) - > feature_mask_dropout_prob - ).to(x.dtype), - ) - - # dim: (1, batch_size, 2) - mask = torch.cat((mask1, mask2), dim=-1) + # pylint: disable=too-many-arguments,too-many-locals + + x, new_left_cached_subsample_frames = self.subsampling(x, left_cached_subsample_frames) + + batch_size, seq_len, _ = x.size() + src_key_padding_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool, device=x.device) + + # processed_mask is used to mask out the initial self.states if left_context_frames == 6, + # then tensor will contain [5, 4, 3, 2, 1, 0] as if reversed + # torch.arange(left_context_frames). + processed_mask = torch.arange( + self.left_context_frames - 1, -1, -1, dtype=torch.int32, device=x.device, + ).expand(batch_size, self.left_context_frames) + + # (1, left_context_size) i.e. (batch_size, left_context_size) + processed_mask = processed_mask >= processed_len.expand(processed_mask.size()) + + # Update processed lengths + new_processed_len = processed_len + seq_len + + # (1, left_context_size + chunk_size) + src_key_padding_mask = torch.cat((processed_mask, src_key_padding_mask), dim=1) + + # If the last encoder 'x' has the largest dimension, then the 'output' will be just this + # last 'x' unchanged. Otherwise it will be concatenated from different pieces of 'x', + # taking each output channel dimension from the most recent x that has it present. + output = torch.empty( + batch_size, seq_len, self.projection_dim, dtype=torch.float32, device=x.device, + ) + + # We have a number of Zipformer2Encoder stacks fixed and equal to 6 for any Ziformer2 size + # including small, medium and large. For the sake of smoother model TorchScript export we + # engage sequential explicit forward call of each Zipformer2Encoder module instead of using + # torch.nn.ModuleList. 
+ + # Encoder 1 + + ( + x, + new_left_cached_keys_encoder_1, + new_left_cached_nonlin_attentions_encoder_1, + new_left_cached_values_1_encoder_1, + new_left_cached_values_2_encoder_1, + new_left_cached_convolutions_1_encoder_1, + new_left_cached_convolutions_2_encoder_1, + ) = self.encoder_1( + x, + left_cached_keys_encoder_1, + left_cached_nonlin_attentions_encoder_1, + left_cached_values_1_encoder_1, + left_cached_values_2_encoder_1, + left_cached_convolutions_1_encoder_1, + left_cached_convolutions_2_encoder_1, + src_key_padding_mask[:, ::self.downsampling_factors[0]], + ) + output[:, :, :x.size(2)] = x + + # Encoder 2 + + pad = torch.zeros( + x.size(0), x.size(1), self.encoder_dims[1] - x.size(2), + dtype=torch.float32, + device=x.device, + ) + x = torch.cat((x, pad), dim=2) + + ( + x, + new_left_cached_keys_encoder_2, + new_left_cached_nonlin_attentions_encoder_2, + new_left_cached_values_1_encoder_2, + new_left_cached_values_2_encoder_2, + new_left_cached_convolutions_1_encoder_2, + new_left_cached_convolutions_2_encoder_2, + ) = self.encoder_2( + x, + left_cached_keys_encoder_2, + left_cached_nonlin_attentions_encoder_2, + left_cached_values_1_encoder_2, + left_cached_values_2_encoder_2, + left_cached_convolutions_1_encoder_2, + left_cached_convolutions_2_encoder_2, + src_key_padding_mask[:, ::self.downsampling_factors[1]], + ) + output[:, :, :x.size(2)] = x + + # Encoder 3 + + pad = torch.zeros( + x.size(0), x.size(1), self.encoder_dims[2] - x.size(2), + dtype=torch.float32, + device=x.device, + ) + x = torch.cat((x, pad), dim=2) + + ( + x, + new_left_cached_keys_encoder_3, + new_left_cached_nonlin_attentions_encoder_3, + new_left_cached_values_1_encoder_3, + new_left_cached_values_2_encoder_3, + new_left_cached_convolutions_1_encoder_3, + new_left_cached_convolutions_2_encoder_3, + ) = self.encoder_3( + x, + left_cached_keys_encoder_3, + left_cached_nonlin_attentions_encoder_3, + left_cached_values_1_encoder_3, + left_cached_values_2_encoder_3, + left_cached_convolutions_1_encoder_3, + left_cached_convolutions_2_encoder_3, + src_key_padding_mask[:, ::self.downsampling_factors[2]], + ) + output[:, :, :x.size(2)] = x + + # Encoder 4 + + pad = torch.zeros( + x.size(0), x.size(1), self.encoder_dims[3] - x.size(2), + dtype=torch.float32, + device=x.device, + ) + x = torch.cat((x, pad), dim=2) + + ( + x, + new_left_cached_keys_encoder_4, + new_left_cached_nonlin_attentions_encoder_4, + new_left_cached_values_1_encoder_4, + new_left_cached_values_2_encoder_4, + new_left_cached_convolutions_1_encoder_4, + new_left_cached_convolutions_2_encoder_4, + ) = self.encoder_4( + x, + left_cached_keys_encoder_4, + left_cached_nonlin_attentions_encoder_4, + left_cached_values_1_encoder_4, + left_cached_values_2_encoder_4, + left_cached_convolutions_1_encoder_4, + left_cached_convolutions_2_encoder_4, + src_key_padding_mask[:, ::self.downsampling_factors[3]], + ) + output[:, :, :x.size(2)] = x + + # Encoder 5 + + x = x[:, :, :self.encoder_dims[4]] + ( + x, + new_left_cached_keys_encoder_5, + new_left_cached_nonlin_attentions_encoder_5, + new_left_cached_values_1_encoder_5, + new_left_cached_values_2_encoder_5, + new_left_cached_convolutions_1_encoder_5, + new_left_cached_convolutions_2_encoder_5, + ) = self.encoder_5( + x, + left_cached_keys_encoder_5, + left_cached_nonlin_attentions_encoder_5, + left_cached_values_1_encoder_5, + left_cached_values_2_encoder_5, + left_cached_convolutions_1_encoder_5, + left_cached_convolutions_2_encoder_5, + src_key_padding_mask[:, ::self.downsampling_factors[4]], + 
) + output[:, :, :x.size(2)] = x + + # Encoder 6 + + x = x[:, :, :self.encoder_dims[5]] + ( + x, + new_left_cached_keys_encoder_6, + new_left_cached_nonlin_attentions_encoder_6, + new_left_cached_values_1_encoder_6, + new_left_cached_values_2_encoder_6, + new_left_cached_convolutions_1_encoder_6, + new_left_cached_convolutions_2_encoder_6, + ) = self.encoder_6( + x, + left_cached_keys_encoder_6, + left_cached_nonlin_attentions_encoder_6, + left_cached_values_1_encoder_6, + left_cached_values_2_encoder_6, + left_cached_convolutions_1_encoder_6, + left_cached_convolutions_2_encoder_6, + src_key_padding_mask[:, ::self.downsampling_factors[5]], + ) + output[:, :, :x.size(2)] = x + + output = self.downsample_output(output) + output = self.projection_output(output) + if self.ctc: + output = torch.nn.functional.log_softmax(output, dim=2) - feature_masks = [] - for i in range(num_encoders): - channels = self.encoder_dim[i] - feature_mask = torch.ones( - 1, batch_size, channels, dtype=x.dtype, device=x.device - ) - u1 = self.encoder_unmasked_dim[i] - u2 = u1 + (channels - u1) // 2 - - feature_mask[:, :, u1:u2] *= mask[..., 0:1] - feature_mask[:, :, u2:] *= mask[..., 1:2] + return ( + output, + # Because of the reasons mentioned in previous comments, + # for the sake of easier TorchScript and ONNX export we + # preserve the explicit listing of each left cache tensor. + new_left_cached_subsample_frames, + + new_left_cached_keys_encoder_1, + new_left_cached_nonlin_attentions_encoder_1, + new_left_cached_values_1_encoder_1, + new_left_cached_values_2_encoder_1, + new_left_cached_convolutions_1_encoder_1, + new_left_cached_convolutions_2_encoder_1, + + new_left_cached_keys_encoder_2, + new_left_cached_nonlin_attentions_encoder_2, + new_left_cached_values_1_encoder_2, + new_left_cached_values_2_encoder_2, + new_left_cached_convolutions_1_encoder_2, + new_left_cached_convolutions_2_encoder_2, + + new_left_cached_keys_encoder_3, + new_left_cached_nonlin_attentions_encoder_3, + new_left_cached_values_1_encoder_3, + new_left_cached_values_2_encoder_3, + new_left_cached_convolutions_1_encoder_3, + new_left_cached_convolutions_2_encoder_3, + + new_left_cached_keys_encoder_4, + new_left_cached_nonlin_attentions_encoder_4, + new_left_cached_values_1_encoder_4, + new_left_cached_values_2_encoder_4, + new_left_cached_convolutions_1_encoder_4, + new_left_cached_convolutions_2_encoder_4, + + new_left_cached_keys_encoder_5, + new_left_cached_nonlin_attentions_encoder_5, + new_left_cached_values_1_encoder_5, + new_left_cached_values_2_encoder_5, + new_left_cached_convolutions_1_encoder_5, + new_left_cached_convolutions_2_encoder_5, + + new_left_cached_keys_encoder_6, + new_left_cached_nonlin_attentions_encoder_6, + new_left_cached_values_1_encoder_6, + new_left_cached_values_2_encoder_6, + new_left_cached_convolutions_1_encoder_6, + new_left_cached_convolutions_2_encoder_6, + + new_processed_len, + ) + + +def get_init_states( + input_dim: int, + num_encoder_layers: list[int], + downsample_left_pad_frames: list[int], + encoder_dims: list[int], + query_dims: list[int], + value_dims: list[int], + head_dims: list[int], + convolution_left_pad_frames: list[int], + device: torch.device, +) -> list[torch.Tensor]: + """ + Get initial states for the Zipformer2 encoder. The method generates a list of torch tensors, + where the first tensor corresponds to a subsampling module left cache. Next, for each + Zipformer2Encoder module we add six cache tensors that are essential for multi-head attention + and convolution modules. 
Finally, at the end we append a total processed frames tensor, + initialized with zero. + + Parameters + ---------- + input_dim : int + The number of input features. + num_encoder_layers : list[int] + The number of Zipformer2EncoderLayer modules for each Zipformer2Encoder stack. + downsample_left_pad_frames : list[int] + The multi-head attention left context cache frames after downsampling. + encoder_dims : list[int] + The embedding dimension for each Zipformer2Encoder stack. + query_dims : list[int] + The multi-head attention query dimension for each Zipformer2Encoder stack. + value_dims : list[int] + The multi-head attention value dimension for each Zipformer2Encoder stack. + head_dims : list[int] + The non-linear attention head dimension for each Zipformer2Encoder stack. + convolution_left_pad_frames : list[int] + The convolution modules left padding number of frames for each Zipformer2Encoder stack. + device : torch.device + The device used to store cache tensors. Should be + either torch.device("cpu") or torch.device("cuda"). + + Returns + ------- + list[torch.Tensor[torch.float32 | torch.int32]] + A list of left cache tensors. + - A subsampling module left cache tensor of shape (1, 10, input_dim) + - The first Zipformer2Encoder cached attention key tensor of the left context in each + Zipformer2EncoderLayer of the stack. + The tensor is of shape (num_layers_1, 1, left_context_len_1, query_dim_1). + - The first Zipformer2Encoder left context cached attention tensor for the non-linear + attention module in each Zipformer2EncoderLayer of the stack. + The tensor is of shape (num_layers_1, 1, left_context_len_1, head_dim_1). + - The first Zipformer2Encoder cached left context tensor for the first self-attention + module in each Zipformer2EncoderLayer of the stack. + The tensor is of shape (num_layers_1, 1, left_context_len_1, value_dim_1). + - The first Zipformer2Encoder cached left context tensor for the second self-attention + module in each Zipformer2EncoderLayer of the stack. + The tensor is of shape (num_layers_1, 1, left_context_len_1, value_dim_1). + - The first Zipformer2Encoder cached left context tensor for the first convolution module + in each Zipformer2EncoderLayer of the stack. + The tensor is of shape (num_layers_1, 1, encoder_dim_1, conv_left_pad_1). + - The first Zipformer2Encoder cached left context tensor for the second convolution module + in each Zipformer2EncoderLayer of the stack. + The tensor is of shape (num_layers_1, 1, encoder_dim_1, conv_left_pad_1). + . + . + . + - The sixth Zipformer2Encoder cached left context tensor for the second convolution module + in each Zipformer2EncoderLayer of the stack. + The tensor is of shape (num_layers_6, 1, encoder_dim_6, conv_left_pad_6). + - The processed length integer tensor initialized with a single zero element. + The tensor is of shape (1,). 
+ """ + # pylint: disable=too-many-locals + + if not ( + len(num_encoder_layers) + == len(downsample_left_pad_frames) + == len(encoder_dims) + == len(query_dims) + == len(value_dims) + == len(head_dims) + == len(convolution_left_pad_frames) + ): + raise ValueError( + 'It is required that all encoder parameter lists have the same ' + 'length, but got following parameter list lengths:\n' + f'len(num_encoder_layers) == {len(num_encoder_layers)}\n' + f'len(downsample_left_pad_frames) == {len(downsample_left_pad_frames)}\n' + f'len(encoder_dims) == {len(encoder_dims)}\n' + f'len(query_dims) == {len(query_dims)}\n' + f'len(value_dims) == {len(value_dims)}\n' + f'len(nonlin_attn_head_dims) == {len(head_dims)}\n' + f'len(convolution_left_pad_frames) == {len(convolution_left_pad_frames)}.', + ) + + states = [subsampling_get_init_states(input_dim, device)] + for i, num_layers in enumerate(num_encoder_layers): + + encoder_dim = encoder_dims[i] + query_dim = query_dims[i] + value_dim = value_dims[i] + head_dim = head_dims[i] + left_context_len = downsample_left_pad_frames[i] + left_cache_len = convolution_left_pad_frames[i] + + # batch size is 1 + states += [ + torch.zeros( + num_layers, 1, left_context_len, query_dim, dtype=torch.float32, device=device, + ), + torch.zeros( + num_layers, 1, left_context_len, head_dim, dtype=torch.float32, device=device, + ), + torch.zeros( + num_layers, 1, left_context_len, value_dim, dtype=torch.float32, device=device, + ), + torch.zeros( + num_layers, 1, left_context_len, value_dim, dtype=torch.float32, device=device, + ), + torch.zeros( + num_layers, 1, encoder_dim, left_cache_len, dtype=torch.float32, device=device, + ), + torch.zeros( + num_layers, 1, encoder_dim, left_cache_len, dtype=torch.float32, device=device, + ), + ] - feature_masks.append(feature_mask) + states.append(torch.zeros(1, dtype=torch.int32, device=device)) - return feature_masks + return states - def get_chunk_info(self) -> Tuple[int, int]: - """ - Returns chunk_size and left_context_chunks. - """ - if not self.causal: - return -1, -1 - - if torch.jit.is_scripting() or torch.jit.is_tracing(): - assert len(self.chunk_size) == 1, self.chunk_size - chunk_size = self.chunk_size[0] - else: - chunk_size = random.choice(self.chunk_size) - - if chunk_size == -1: - left_context_chunks = -1 - else: - if torch.jit.is_scripting() or torch.jit.is_tracing(): - assert len(self.left_context_frames) == 1, self.left_context_frames - left_context_frames = self.left_context_frames[0] - else: - left_context_frames = random.choice(self.left_context_frames) - # Note: in Python, -1 // n == -1 for n > 0 - left_context_chunks = left_context_frames // chunk_size - if left_context_chunks == 0: - left_context_chunks = 1 - - return chunk_size, left_context_chunks - def forward( - self, - x: Tensor, - x_lens: Tensor, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tuple[Tensor, Tensor]: - """ - Args: - x: - The input tensor. Its shape is (seq_len, batch_size, feature_dim). - x_lens: - A tensor of shape (batch_size,) containing the number of frames in - `x` before padding. - src_key_padding_mask: - The mask for padding, of shape (batch_size, seq_len); True means - masked position. May be None. - Returns: - Return a tuple containing 2 tensors: - - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim)) - - lengths, a tensor of shape (batch_size,) containing the number - of frames in `embeddings` before padding. 
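The flattened state list produced by get_init_states above is easy to mis-handle, so here is a minimal usage sketch. It is not part of the patch: the `from zipformer import ...` path and every configuration value below are hypothetical, and the only point is the layout of the returned list (one subsampling cache, six cache tensors per Zipformer2Encoder stack, and a final processed-length counter).

```python
import torch

# Hypothetical import path and configuration; real values come from the training setup.
from zipformer import get_init_states

num_encoder_layers = [2, 2, 3, 4, 3, 2]

states = get_init_states(
    input_dim=80,
    num_encoder_layers=num_encoder_layers,
    downsample_left_pad_frames=[64, 32, 16, 8, 16, 32],
    encoder_dims=[192, 256, 384, 512, 384, 256],
    query_dims=[128, 128, 128, 128, 128, 128],
    value_dims=[192, 192, 192, 192, 192, 192],
    head_dims=[144, 192, 288, 384, 288, 192],
    convolution_left_pad_frames=[15, 15, 15, 15, 15, 15],
    device=torch.device("cpu"),
)

# One subsampling cache, six cache tensors per stack, one processed-length counter.
assert len(states) == 1 + 6 * len(num_encoder_layers) + 1

# For example, the first stack's cached attention keys follow the documented shape
# (num_layers_1, 1, left_context_len_1, query_dim_1).
assert states[1].shape == (num_encoder_layers[0], 1, 64, 128)
assert int(states[-1]) == 0  # the processed length starts at zero
```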
- """ - outputs = [] - if torch.jit.is_scripting() or torch.jit.is_tracing(): - feature_masks = [1.0] * len(self.encoder_dim) - else: - feature_masks = self.get_feature_masks(x) - - chunk_size, left_context_chunks = self.get_chunk_info() - - if torch.jit.is_scripting() or torch.jit.is_tracing(): - # Not support exporting a model for simulating streaming decoding - attn_mask = None - else: - attn_mask = self._get_attn_mask(x, chunk_size, left_context_chunks) - - for i, module in enumerate(self.encoders): - ds = self.downsampling_factor[i] - x = convert_num_channels(x, self.encoder_dim[i]) - - x = module( - x, - chunk_size=chunk_size, - feature_mask=feature_masks[i], - src_key_padding_mask=( - None - if src_key_padding_mask is None - else src_key_padding_mask[..., ::ds] - ), - attn_mask=attn_mask, - ) - outputs.append(x) - - # if the last output has the largest dimension, x will be unchanged, - # it will be the same as outputs[-1]. Otherwise it will be concatenated - # from different pieces of 'outputs', taking each dimension from the - # most recent output that has it present. - x = self._get_full_dim_output(outputs) - x = self.downsample_output(x) - # class Downsample has this rounding behavior.. - assert self.output_downsampling_factor == 2, self.output_downsampling_factor - if torch.jit.is_scripting() or torch.jit.is_tracing(): - lengths = (x_lens + 1) // 2 - else: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - lengths = (x_lens + 1) // 2 - - return x, lengths - - def _get_attn_mask( - self, x: Tensor, chunk_size: int, left_context_chunks: int - ) -> Optional[Tensor]: - """ - Return None if chunk_size == -1, else return attention mask of shape - (seq_len, seq_len), interpreted as (tgt_seq_len, src_seq_len). True - means a masked position. - Args: - x: embeddings after self.encoder_embed(), of shape (seq_len, batch_size, embed_dim). 
- chunk_size: chunk size, must divide - """ - if chunk_size <= 0: - return None - assert all(chunk_size % d == 0 for d in self.downsampling_factor) - if left_context_chunks >= 0: - num_encoders = len(self.encoder_dim) - assert all( - chunk_size * left_context_chunks - >= (self.cnn_module_kernel[i] // 2) * self.downsampling_factor[i] - for i in range(num_encoders) - ) - else: - left_context_chunks = 1000000 - - seq_len = x.shape[0] - - # t is frame index, shape (seq_len,) - t = torch.arange(seq_len, dtype=torch.int32, device=x.device) - # c is chunk index for each frame, shape (seq_len,) - if torch.jit.is_scripting() or torch.jit.is_tracing(): - c = t // chunk_size - else: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - c = t // chunk_size - src_c = c - tgt_c = c.unsqueeze(-1) - - attn_mask = torch.logical_or(src_c > tgt_c, src_c < tgt_c - left_context_chunks) - if __name__ == "__main__": - logging.info(f"attn_mask = {attn_mask}") - return attn_mask - - def _get_full_dim_output(self, outputs: List[Tensor]): - num_encoders = len(self.encoder_dim) - assert len(outputs) == num_encoders - output_dim = max(self.encoder_dim) - output_pieces = [outputs[-1]] - cur_dim = self.encoder_dim[-1] - for i in range(num_encoders - 2, -1, -1): - d = self.encoder_dim[i] - if d > cur_dim: - this_output = outputs[i] - output_pieces.append(this_output[..., cur_dim:d]) - cur_dim = d - assert cur_dim == output_dim - return torch.cat(output_pieces, dim=-1) - - def streaming_forward( - self, - x: Tensor, - x_lens: Tensor, - states: List[Tensor], - src_key_padding_mask: Tensor, - ) -> Tuple[Tensor, Tensor, List[Tensor]]: - """ - Args: - x: - The input tensor. Its shape is (seq_len, batch_size, feature_dim). - x_lens: - A tensor of shape (batch_size,) containing the number of frames in - `x` before padding. - states: list of cached tensors of all encoder layers. For layer-i, - states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, - cached_conv1, cached_conv2). - src_key_padding_mask: - The mask for padding, of shape (batch_size, seq_len); True means - masked position. May be None. - Returns: - Return a tuple containing 2 tensors: - - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim)) - - lengths, a tensor of shape (batch_size,) containing the number - of frames in `embeddings` before padding. - - updated states - """ - outputs = [] - new_states = [] - layer_offset = 0 - - for i, module in enumerate(self.encoders): - num_layers = module.num_layers - ds = self.downsampling_factor[i] - x = convert_num_channels(x, self.encoder_dim[i]) - - x, new_layer_states = module.streaming_forward( - x, - states=states[layer_offset * 6 : (layer_offset + num_layers) * 6], - left_context_len=self.left_context_frames[0] // ds, - src_key_padding_mask=src_key_padding_mask[..., ::ds], - ) - layer_offset += num_layers - outputs.append(x) - new_states += new_layer_states - - # if the last output has the largest dimension, x will be unchanged, - # it will be the same as outputs[-1]. Otherwise it will be concatenated - # from different pieces of 'outputs', taking each dimension from the - # most recent output that has it present. - x = self._get_full_dim_output(outputs) - x = self.downsample_output(x) - # class Downsample has this rounding behavior.. 
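As a side note on the removed _get_attn_mask above: its chunk-causal rule can be replayed in a few standalone lines. The sketch below is not taken from the patch and uses toy sizes; True marks a position a frame is not allowed to attend to.

```python
# Toy reconstruction of the chunk-causal rule from the removed _get_attn_mask:
# a frame may attend to frames in its own chunk and in at most
# `left_context_chunks` preceding chunks; True means "masked".
import torch

seq_len, chunk_size, left_context_chunks = 8, 2, 1

t = torch.arange(seq_len, dtype=torch.int32)  # frame index
c = t // chunk_size                           # chunk index of each frame
src_c = c
tgt_c = c.unsqueeze(-1)

attn_mask = torch.logical_or(src_c > tgt_c, src_c < tgt_c - left_context_chunks)
print(attn_mask.int())
```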
- assert self.output_downsampling_factor == 2 - if torch.jit.is_scripting() or torch.jit.is_tracing(): - lengths = (x_lens + 1) // 2 - else: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - lengths = (x_lens + 1) // 2 - - return x, lengths, new_states - - @torch.jit.export - def get_init_states( - self, - batch_size: int = 1, - device: torch.device = torch.device("cpu"), - ) -> List[Tensor]: - """Get initial states. - - A list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] - is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). - """ - states = [] - for i, module in enumerate(self.encoders): - num_layers = module.num_layers - embed_dim = self.encoder_dim[i] - ds = self.downsampling_factor[i] - num_heads = self.num_heads[i] - key_dim = self.query_head_dim[i] * num_heads - value_dim = self.value_head_dim[i] * num_heads - downsample_left = self.left_context_frames[0] // ds - nonlin_attn_head_dim = 3 * embed_dim // 4 - conv_left_pad = self.cnn_module_kernel[i] // 2 - for layer in range(num_layers): - cached_key = torch.zeros(downsample_left, batch_size, key_dim).to( - device - ) - cached_nonlin_attn = torch.zeros( - 1, batch_size, downsample_left, nonlin_attn_head_dim - ).to(device) - cached_val1 = torch.zeros(downsample_left, batch_size, value_dim).to( - device - ) - cached_val2 = torch.zeros(downsample_left, batch_size, value_dim).to( - device - ) - cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad).to( - device - ) - cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).to( - device - ) - states += [ - cached_key, - cached_nonlin_attn, - cached_val1, - cached_val2, - cached_conv1, - cached_conv2, - ] - - return states - - -def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat: - return ScheduledFloat((0.0, x), (20000.0, ratio * x), default=x) - - -def _balancer_schedule(min_prob: float): - return ScheduledFloat((0.0, 0.4), (8000.0, min_prob)) - - -class Zipformer2EncoderLayer(nn.Module): +class Zipformer2EncoderLayer(torch.nn.Module): """ - Args: - embed_dim: the number of expected features in the input (required). - nhead: the number of heads in the multiheadattention models (required). - feedforward_dim: the dimension of the feedforward network model (required). - dropout: the dropout value (default=0.1). - cnn_module_kernel (int): Kernel size of convolution module (default=31). - - Examples:: - >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8) - >>> src = torch.rand(10, 32, 512) - >>> pos_emb = torch.rand(32, 19, 512) - >>> out = encoder_layer(src, pos_emb) + Zipformer2EncoderLayer module, the basic block of Zipformer2Encoder encoder stack. 
""" + # pylint: disable=too-many-instance-attributes def __init__( self, @@ -572,390 +775,182 @@ def __init__( pos_head_dim: int, value_head_dim: int, feedforward_dim: int, - dropout: FloatLike = 0.1, - cnn_module_kernel: int = 31, - causal: bool = False, - attention_skip_rate: FloatLike = ScheduledFloat( - (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0 - ), - conv_skip_rate: FloatLike = ScheduledFloat( - (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0 - ), - const_attention_rate: FloatLike = ScheduledFloat( - (0.0, 0.25), (4000.0, 0.025), default=0 - ), - ff2_skip_rate: FloatLike = ScheduledFloat( - (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0) - ), - ff3_skip_rate: FloatLike = ScheduledFloat( - (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0) - ), - bypass_skip_rate: FloatLike = ScheduledFloat( - (0.0, 0.5), (4000.0, 0.02), default=0 - ), + cnn_module_kernel: int, + left_context_len: int, + right_context_len: int, + device: torch.device, ) -> None: - super(Zipformer2EncoderLayer, self).__init__() - self.embed_dim = embed_dim - - # self.bypass implements layer skipping as well as bypass; see its default values. - self.bypass = BypassModule( - embed_dim, skip_rate=bypass_skip_rate, straight_through_rate=0 - ) - # bypass_mid is bypass used in the middle of the layer. - self.bypass_mid = BypassModule(embed_dim, straight_through_rate=0) + """ + Zipformer2EncoderLayer initialization. + + Parameters + ---------- + embed_dim : int + The input and output embedding dimension. The number of channels is the same for input + and output of this module. + pos_dim : int + The dimension of the relative positional embedding. + num_heads : int + The number of heads for attention weights and self-attention. + query_head_dim : int + The dimension of the query and key per attention head in attention weights. + pos_head_dim: int + The dimension of the projected positional encoding + per attention head in attention weights. + value_head_dim : int + The dimension of the value per attention head in self-attention. + feedforward_dim : int + The hidden dimension of the feedforward modules. + cnn_module_kernel : int + The kernel size of the convolution modules. + left_context_len : int + The module left context number of subsampled frames. + right_context_len : int + The module right context number of subsampled frames. + Used to update attention and convolution left caches. + device : torch.device + The device used to store the layer weights. Should be + either torch.device("cpu") or torch.device("cuda"). + """ + # pylint: disable=too-many-arguments - # skip probability for dynamic modules (meaning: anything but feedforward). - self.attention_skip_rate = copy.deepcopy(attention_skip_rate) - # an additional skip probability that applies to ConvModule to stop it from - # contributing too much early on. - self.conv_skip_rate = copy.deepcopy(conv_skip_rate) + super().__init__() - # ff2_skip_rate is to prevent the ff2 module from having output that's too big - # compared to its residual. - self.ff2_skip_rate = copy.deepcopy(ff2_skip_rate) - self.ff3_skip_rate = copy.deepcopy(ff3_skip_rate) + self.left_context_len = left_context_len - self.const_attention_rate = copy.deepcopy(const_attention_rate) + # self.bypass implements the whole layer skipping. + self.bypass = BypassModule(embed_dim, device) + # bypass_mid is bypass used in the middle of the layer. 
+ self.bypass_mid = BypassModule(embed_dim, device) self.self_attn_weights = RelPositionMultiheadAttentionWeights( - embed_dim, - pos_dim=pos_dim, - num_heads=num_heads, - query_head_dim=query_head_dim, - pos_head_dim=pos_head_dim, - dropout=0.0, + embed_dim, pos_dim, num_heads, query_head_dim, pos_head_dim, right_context_len, device, ) - self.self_attn1 = SelfAttention(embed_dim, num_heads, value_head_dim) - - self.self_attn2 = SelfAttention(embed_dim, num_heads, value_head_dim) - - self.feed_forward1 = FeedforwardModule( - embed_dim, (feedforward_dim * 3) // 4, dropout + self.self_attn1 = SelfAttention( + embed_dim, num_heads, value_head_dim, right_context_len, device, ) - - self.feed_forward2 = FeedforwardModule(embed_dim, feedforward_dim, dropout) - - self.feed_forward3 = FeedforwardModule( - embed_dim, (feedforward_dim * 5) // 4, dropout + self.self_attn2 = SelfAttention( + embed_dim, num_heads, value_head_dim, right_context_len, device, ) self.nonlin_attention = NonlinAttention( - embed_dim, hidden_channels=3 * embed_dim // 4 + embed_dim, 3 * embed_dim // 4, right_context_len, device, ) + self.feed_forward1 = FeedforwardModule(embed_dim, (feedforward_dim * 3) // 4, device) + self.feed_forward2 = FeedforwardModule(embed_dim, feedforward_dim, device) + self.feed_forward3 = FeedforwardModule(embed_dim, (feedforward_dim * 5) // 4, device) + self.conv_module1 = ConvolutionModule( - embed_dim, cnn_module_kernel, causal=causal + embed_dim, cnn_module_kernel, right_context_len, device, ) - self.conv_module2 = ConvolutionModule( - embed_dim, cnn_module_kernel, causal=causal - ) - - # TODO: remove it - self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5)) - - self.norm = BiasNorm(embed_dim) - - self.balancer1 = Balancer( - embed_dim, - channel_dim=-1, - min_positive=0.45, - max_positive=0.55, - min_abs=0.2, - max_abs=4.0, - ) - - # balancer for output of NonlinAttentionModule - self.balancer_na = Balancer( - embed_dim, - channel_dim=-1, - min_positive=0.3, - max_positive=0.7, - min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)), - prob=0.05, # out of concern for memory usage - ) - - # balancer for output of feedforward2, prevent it from staying too - # small. give this a very small probability, even at the start of - # training, it's to fix a rare problem and it's OK to fix it slowly. - self.balancer_ff2 = Balancer( - embed_dim, - channel_dim=-1, - min_positive=0.3, - max_positive=0.7, - min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.1), default=0.0), - max_abs=2.0, - prob=0.05, - ) - - self.balancer_ff3 = Balancer( - embed_dim, - channel_dim=-1, - min_positive=0.3, - max_positive=0.7, - min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.2), default=0.0), - max_abs=4.0, - prob=0.05, + embed_dim, cnn_module_kernel, right_context_len, device, ) - self.whiten = Whiten( - num_groups=1, - whitening_limit=_whitening_schedule(4.0, ratio=3.0), - prob=(0.025, 0.25), - grad_scale=0.01, - ) - - self.balancer2 = Balancer( - embed_dim, - channel_dim=-1, - min_positive=0.45, - max_positive=0.55, - min_abs=0.1, - max_abs=4.0, - ) - - def get_sequence_dropout_mask( - self, x: Tensor, dropout_rate: float - ) -> Optional[Tensor]: - if ( - dropout_rate == 0.0 - or not self.training - or torch.jit.is_scripting() - or torch.jit.is_tracing() - ): - return None - batch_size = x.shape[1] - mask = (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(x.dtype) - return mask - - def sequence_dropout(self, x: Tensor, dropout_rate: float) -> Tensor: - """ - Apply sequence-level dropout to x. 
- x shape: (seq_len, batch_size, embed_dim) - """ - dropout_mask = self.get_sequence_dropout_mask(x, dropout_rate) - if dropout_mask is None: - return x - else: - return x * dropout_mask + self.norm = BiasNorm(embed_dim, device) def forward( self, - src: Tensor, - pos_emb: Tensor, - chunk_size: int = -1, - attn_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: + src: torch.Tensor, + pos_emb: torch.Tensor, + left_cached_key: torch.Tensor, + left_cached_nonlin_attn: torch.Tensor, + left_cached_val_1: torch.Tensor, + left_cached_val_2: torch.Tensor, + left_cached_conv_1: torch.Tensor, + left_cached_conv_2: torch.Tensor, + src_key_padding_mask: torch.Tensor, + ) -> tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + ]: """ - Pass the input through the encoder layer. - Args: - src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). - pos_emb: (1, 2*seq_len-1, pos_emb_dim) or (batch_size, 2*seq_len-1, pos_emb_dim) - chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. - feature_mask: something that broadcasts with src, that we'll multiply `src` - by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) - attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), - interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). - True means masked position. May be None. - src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means - masked position. May be None. - - Returns: - A tensor which has the same shape as src + Does a forward pass of the Zipformer2EncoderLayer module. Returns an output tensor with the + same shape as input, and updated left caches for multiple attention and convolution + mudules. + + Parameters + ---------- + src : torch.Tensor[torch.float32] + The input float tensor of shape (1, seq_len, embed_dim). The module input. + pos_emb : torch.Tensor[torch.float32] + A positional embedding tensor + of shape (1, left_context_len + 2 * seq_len - 1, pos_dim). + left_cached_key : torch.Tensor[torch.float32] + A cached attention key tensor of the left context + of shape (1, left_context_len, query_dim). + left_cached_nonlin_attn : torch.Tensor[torch.float32] + A left context cached attention tensor for the non-linear attention module + of shape (1, left_context_len, head_dim). + left_cached_val_1 : torch.Tensor[torch.float32] + A cached left context tensor for the first self-attention module + of shape (1, left_context_len, value_dim). + left_cached_val_2 : torch.Tensor[torch.float32] + A cached left context for the second self-attention module + of shape (1, left_context_len, value_dim). + left_cached_conv_1 : torch.Tensor[torch.float32] + A cached left context tensor for the first convolution module + of shape (1, embed_dim, left_cache_len). + left_cached_conv_2 : torch.Tensor[torch.float32] + A cached left context tensor for the second convolution module + of shape (1, embed_dim, left_cache_len). + src_key_padding_mask : torch.Tensor[torch.bool] + A boolean tensor of shape (1, seq_len_2). Positions that are True in this mask will be + ignored as sources in the attention weighting and convolution modules. 
+ + Returns + ------- + tuple[ + torch.Tensor[torch.float32], + torch.Tensor[torch.float32], + torch.Tensor[torch.float32], + torch.Tensor[torch.float32], + torch.Tensor[torch.float32], + torch.Tensor[torch.float32], + torch.Tensor[torch.float32], + ] + A tuple of seven float tensors: + - The module output of shape (1, seq_len, embed_dim). + A tensor with the same shape as input. + - The updated left context cached attention key tensor + of shape (1, left_context_len, query_dim). + - The updated left context cached attention tensor for the non-linear attention module + of shape (1, left_context_len, head_dim). + - The updated cached left context for the first self-attention module + of shape (1, left_context_len, value_dim). + - The updated cached left context for the second self-attention module + of shape (1, left_context_len, value_dim). + - The updated cached left context for the first convolution module + of shape (1, embed_dim, left_cache_len). + - The updated cached left context for the second convolution module + of shape (1, embed_dim, left_cache_len). """ - src_orig = src - - # dropout rate for non-feedforward submodules - if torch.jit.is_scripting() or torch.jit.is_tracing(): - attention_skip_rate = 0.0 - else: - attention_skip_rate = ( - float(self.attention_skip_rate) if self.training else 0.0 - ) - - # attn_weights: (num_heads, batch_size, seq_len, seq_len) - attn_weights = self.self_attn_weights( - src, - pos_emb=pos_emb, - attn_mask=attn_mask, - key_padding_mask=src_key_padding_mask, - ) - - src = src + self.feed_forward1(src) - - self_attn_dropout_mask = self.get_sequence_dropout_mask( - src, attention_skip_rate - ) - - selected_attn_weights = attn_weights[0:1] - if torch.jit.is_scripting() or torch.jit.is_tracing(): - pass - elif self.training and random.random() < float(self.const_attention_rate): - # Make attention weights constant. The intention is to - # encourage these modules to do something similar to an - # averaging-over-time operation. - # only need the mask, can just use the 1st one and expand later - selected_attn_weights = selected_attn_weights[0:1] - selected_attn_weights = (selected_attn_weights > 0.0).to( - selected_attn_weights.dtype - ) - selected_attn_weights = selected_attn_weights * ( - 1.0 / selected_attn_weights.sum(dim=-1, keepdim=True) - ) - na = self.balancer_na(self.nonlin_attention(src, selected_attn_weights)) - - src = src + ( - na if self_attn_dropout_mask is None else na * self_attn_dropout_mask - ) - - self_attn = self.self_attn1(src, attn_weights) - - src = src + ( - self_attn - if self_attn_dropout_mask is None - else self_attn * self_attn_dropout_mask - ) - - if torch.jit.is_scripting() or torch.jit.is_tracing(): - conv_skip_rate = 0.0 - else: - conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0 - src = src + self.sequence_dropout( - self.conv_module1( - src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask - ), - conv_skip_rate, - ) - - if torch.jit.is_scripting() or torch.jit.is_tracing(): - ff2_skip_rate = 0.0 - else: - ff2_skip_rate = float(self.ff2_skip_rate) if self.training else 0.0 - src = src + self.sequence_dropout( - self.balancer_ff2(self.feed_forward2(src)), ff2_skip_rate - ) - - # bypass in the middle of the layer. 
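To make the cache shapes listed in the parameter and return documentation above more concrete, here is a shape-only sketch. It is not from the patch: all sizes are hypothetical stand-ins (query_dim, value_dim, head_dim and left_cache_len in particular are placeholders whose real values come from the surrounding modules), and no layer is actually called. It only shows how the documented shapes and the seq_len_2 = left_context_len + seq_len mask layout fit together, consistent with the mask slicing used later in the forward body.

```python
# Shape-only sketch (all sizes hypothetical) of the streaming inputs documented above.
import torch

seq_len, left_context_len = 16, 64          # current chunk and cached left context
embed_dim, pos_dim = 384, 48                # assumed model dimensions
query_dim, value_dim, head_dim = 128, 48, 288
left_cache_len = 15                         # assumed convolution left-cache length

src = torch.zeros(1, seq_len, embed_dim)
pos_emb = torch.zeros(1, left_context_len + 2 * seq_len - 1, pos_dim)
left_cached_key = torch.zeros(1, left_context_len, query_dim)
left_cached_nonlin_attn = torch.zeros(1, left_context_len, head_dim)
left_cached_val_1 = torch.zeros(1, left_context_len, value_dim)
left_cached_val_2 = torch.zeros(1, left_context_len, value_dim)
left_cached_conv_1 = torch.zeros(1, embed_dim, left_cache_len)
left_cached_conv_2 = torch.zeros(1, embed_dim, left_cache_len)

# The padding mask covers the cached left context plus the current chunk;
# the convolution modules later see only the current-chunk slice.
src_key_padding_mask = torch.zeros(1, left_context_len + seq_len, dtype=torch.bool)
conv_mask = src_key_padding_mask[:, left_context_len:]
assert conv_mask.shape == (1, seq_len)
```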
- src = self.bypass_mid(src_orig, src) - - self_attn = self.self_attn2(src, attn_weights) - - src = src + ( - self_attn - if self_attn_dropout_mask is None - else self_attn * self_attn_dropout_mask - ) - - if torch.jit.is_scripting() or torch.jit.is_tracing(): - conv_skip_rate = 0.0 - else: - conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0 - src = src + self.sequence_dropout( - self.conv_module2( - src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask - ), - conv_skip_rate, - ) - - if torch.jit.is_scripting() or torch.jit.is_tracing(): - ff3_skip_rate = 0.0 - else: - ff3_skip_rate = float(self.ff3_skip_rate) if self.training else 0.0 - src = src + self.sequence_dropout( - self.balancer_ff3(self.feed_forward3(src)), ff3_skip_rate - ) - - src = self.balancer1(src) - src = self.norm(src) - - src = self.bypass(src_orig, src) - - src = self.balancer2(src) - src = self.whiten(src) - - return src - - def streaming_forward( - self, - src: Tensor, - pos_emb: Tensor, - cached_key: Tensor, - cached_nonlin_attn: Tensor, - cached_val1: Tensor, - cached_val2: Tensor, - cached_conv1: Tensor, - cached_conv2: Tensor, - left_context_len: int, - src_key_padding_mask: Tensor, - ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: - """Pass the input through the encoder layer in streaming forward mode. - - Args: - src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). - pos_emb: (1, left_context_len+2*seq_len-1, pos_emb_dim) or - (batch_size, left_context_len+2*seq_len-1, pos_emb_dim) - cached_key: cached attention key tensor of left context, - of shape (left_context_len, batch_size, key_dim) - cached_nonlin_attn: left context for nonlin_attention module, a Tensor of shape - (num_heads, batch_size, left_context_len, head_dim) - cached_val1: cached left context for the first attention module, - of shape (left_context_len, batch_size, value_dim) - cached_val2: cached left context for the second attention module, - of shape (left_context_len, batch_size, value_dim) - cached_conv1: cached left context for the first convolution module, - of shape (batch_size, channels, left_pad) - cached_conv2: cached left context for the second convolution module, - of shape (batch_size, channels, left_pad) - left_context_len: number of left context frames. - src_key_padding_mask: the mask for padding, of shape - (batch_size, left_context_len + seq_len); True means masked position. - May be None. 
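One more detail that the streaming forward body below depends on: the attention-weights tensor is laid out as (1, num_heads, seq_len, seq_len_2), and only the first head's weights are handed to the non-linear attention module via attn_weights[:, 0]. The tiny sketch below (hypothetical sizes, not from the patch) just shows that slicing.

```python
# Toy illustration (hypothetical sizes) of the attention-weights layout:
# (1, num_heads, seq_len, left_context_len + seq_len), with the first head's
# weights selected for the non-linear attention module.
import torch

num_heads, seq_len, left_context_len = 4, 16, 64
attn_weights = torch.zeros(1, num_heads, seq_len, left_context_len + seq_len)

first_head_weights = attn_weights[:, 0]
assert first_head_weights.shape == (1, seq_len, left_context_len + seq_len)
```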
- - Returns: - - x, with the same shape as src - - updated cached_key - - updated cached_nonlin_attn - - updated cached_val1 - - updated cached_val2 - - updated cached_conv1 - - updated cached_conv2 - """ src_orig = src - # attn_weights: (num_heads, batch_size, seq_len, seq_len) - attn_weights, cached_key = self.self_attn_weights.streaming_forward( - src, - pos_emb=pos_emb, - cached_key=cached_key, - left_context_len=left_context_len, - key_padding_mask=src_key_padding_mask, + # attn_weights: (1, num_heads, seq_len, seq_len_2) + attn_weights, left_cached_key = self.self_attn_weights( + src, pos_emb, left_cached_key, src_key_padding_mask, ) - src = src + self.feed_forward1(src) - na, cached_nonlin_attn = self.nonlin_attention.streaming_forward( - src, - attn_weights[0:1], - cached_x=cached_nonlin_attn, - left_context_len=left_context_len, + na, left_cached_nonlin_attn = self.nonlin_attention( + src, attn_weights[:, 0], left_cached_nonlin_attn, ) src = src + na - self_attn, cached_val1 = self.self_attn1.streaming_forward( - src, - attn_weights=attn_weights, - cached_val=cached_val1, - left_context_len=left_context_len, - ) + self_attn, left_cached_val_1 = self.self_attn1(src, attn_weights, left_cached_val_1) src = src + self_attn - src_conv, cached_conv1 = self.conv_module1.streaming_forward( - src, - cache=cached_conv1, - src_key_padding_mask=src_key_padding_mask[:, left_context_len:], + src_conv, left_cached_conv_1 = self.conv_module1( + src, left_cached_conv_1, src_key_padding_mask[:, self.left_context_len:], ) src = src + src_conv @@ -964,571 +959,536 @@ def streaming_forward( # bypass in the middle of the layer. src = self.bypass_mid(src_orig, src) - self_attn, cached_val2 = self.self_attn2.streaming_forward( - src, - attn_weights=attn_weights, - cached_val=cached_val2, - left_context_len=left_context_len, - ) + self_attn, left_cached_val_2 = self.self_attn2(src, attn_weights, left_cached_val_2) src = src + self_attn - src_conv, cached_conv2 = self.conv_module2.streaming_forward( - src, - cache=cached_conv2, - src_key_padding_mask=src_key_padding_mask[:, left_context_len:], + src_conv, left_cached_conv_2 = self.conv_module2( + src, left_cached_conv_2, src_key_padding_mask[:, self.left_context_len:], ) src = src + src_conv src = src + self.feed_forward3(src) src = self.norm(src) - src = self.bypass(src_orig, src) return ( src, - cached_key, - cached_nonlin_attn, - cached_val1, - cached_val2, - cached_conv1, - cached_conv2, + left_cached_key, + left_cached_nonlin_attn, + left_cached_val_1, + left_cached_val_2, + left_cached_conv_1, + left_cached_conv_2, ) -class Zipformer2Encoder(nn.Module): - r"""Zipformer2Encoder is a stack of N encoder layers - - Args: - encoder_layer: an instance of the Zipformer2EncoderLayer() class (required). - num_layers: the number of sub-encoder-layers in the encoder (required). - pos_dim: the dimension for the relative positional encoding - - Examples:: - >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8) - >>> zipformer_encoder = Zipformer2Encoder(encoder_layer, num_layers=6) - >>> src = torch.rand(10, 32, 512) - >>> out = zipformer_encoder(src) +class Zipformer2Encoder(torch.nn.Module): + """ + Zipformer2Encoder is a stack of Zipformer2EncoderLayer modules. 
""" def __init__( self, - encoder_layer: nn.Module, + encoder_layer: torch.nn.Module, num_layers: int, + embed_dim: int, pos_dim: int, - dropout: float, - warmup_begin: float, - warmup_end: float, - initial_layerdrop_rate: float = 0.5, - final_layerdrop_rate: float = 0.05, + pos_max_len: int, + downsample: int, + device: torch.device, ) -> None: + """ + Zipformer2Encoder initialization. + + Parameters + ---------- + encoder_layer : torch.nn.Module + An instance of the Zipformer2EncoderLayer class. + num_layers : int + The number of encoder Zipformer2EncoderLayer modules in the stack. + embed_dim : int + The input and output embedding dimension. The embedding dimension is the same for + input and output of this module. + pos_dim : int + The dimension for the relative positional embedding. + downsample : int + The downsampling factor of the module, the input will be downsampled in the beginning + and upsampled back at the end. + device : torch.device + The device used to store the layer weights. Should be + either torch.device("cpu") or torch.device("cuda"). + """ + super().__init__() + + self.num_layers = num_layers + self.downsample = SimpleDownsample(downsample, device) self.encoder_pos = CompactRelPositionalEncoding( - pos_dim, dropout_rate=0.15, length_factor=1.0 + pos_dim, pos_max_len, encoder_layer.left_context_len, device, ) - self.layers = nn.ModuleList( - [copy.deepcopy(encoder_layer) for i in range(num_layers)] + self.layers = torch.nn.ModuleList( + [copy.deepcopy(encoder_layer) for _ in range(num_layers)], ) - self.num_layers = num_layers - - assert 0 <= warmup_begin <= warmup_end, (warmup_begin, warmup_end) - - delta = (1.0 / num_layers) * (warmup_end - warmup_begin) - cur_begin = warmup_begin # interpreted as a training batch index - for i in range(num_layers): - cur_end = cur_begin + delta - self.layers[i].bypass.skip_rate = ScheduledFloat( - (cur_begin, initial_layerdrop_rate), - (cur_end, final_layerdrop_rate), - default=0.0, - ) - cur_begin = cur_end + self.upsample = SimpleUpsample(downsample) + self.out_combiner = BypassModule(embed_dim, device) def forward( self, - src: Tensor, - chunk_size: int = -1, - feature_mask: Union[Tensor, float] = 1.0, - attn_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - r"""Pass the input through the encoder layers in turn. - - Args: - src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). - chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. - feature_mask: something that broadcasts with src, that we'll multiply `src` - by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) - attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), - interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). - True means masked position. May be None. - src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means - masked position. May be None. - - Returns: a Tensor with the same shape as src. 
+        src: torch.Tensor,
+        left_cached_keys: torch.Tensor,
+        left_cached_nonlin_attentions: torch.Tensor,
+        left_cached_values_1: torch.Tensor,
+        left_cached_values_2: torch.Tensor,
+        left_cached_convolutions_1: torch.Tensor,
+        left_cached_convolutions_2: torch.Tensor,
+        src_key_padding_mask: torch.Tensor,
+    ) -> tuple[
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+    ]:
+        """
+        Does a forward pass of the Zipformer2Encoder module. Returns an output tensor with the same
+        shape as input, and updated left caches for multiple attention and convolution modules.
+
+        Parameters
+        ----------
+        src : torch.Tensor[torch.float32]
+            The input float tensor of shape (1, seq_len, embed_dim). The module input.
+        left_cached_keys : torch.Tensor[torch.float32]
+            A cached attention key tensor of the left context for each Zipformer2EncoderLayer.
+            A tensor is of shape (num_layers, 1, left_context_len, query_dim).
+        left_cached_nonlin_attentions : torch.Tensor[torch.float32]
+            A left context cached attention tensor for the non-linear attention module of each
+            Zipformer2EncoderLayer. A tensor is
+            of shape (num_layers, 1, left_context_len, head_dim).
+        left_cached_values_1 : torch.Tensor[torch.float32]
+            A cached left context tensor for the first self-attention module of each
+            Zipformer2EncoderLayer. A tensor is
+            of shape (num_layers, 1, left_context_len, value_dim).
+        left_cached_values_2 : torch.Tensor[torch.float32]
+            A cached left context tensor for the second self-attention module of each
+            Zipformer2EncoderLayer. A tensor is
+            of shape (num_layers, 1, left_context_len, value_dim).
+        left_cached_convolutions_1 : torch.Tensor[torch.float32]
+            A cached left context tensor for the first convolution module of each
+            Zipformer2EncoderLayer. A tensor is
+            of shape (num_layers, 1, embed_dim, left_cache_len).
+        left_cached_convolutions_2 : torch.Tensor[torch.float32]
+            A cached left context tensor for the second convolution module of each
+            Zipformer2EncoderLayer. A tensor is
+            of shape (num_layers, 1, embed_dim, left_cache_len).
+        src_key_padding_mask : torch.Tensor[torch.bool]
+            A boolean tensor of shape (1, seq_len_2). Positions that are True in this mask will be
+            ignored as sources in the attention weighting and convolution modules.
+
+        Returns
+        -------
+        tuple[
+            torch.Tensor[torch.float32],
+            torch.Tensor[torch.float32],
+            torch.Tensor[torch.float32],
+            torch.Tensor[torch.float32],
+            torch.Tensor[torch.float32],
+            torch.Tensor[torch.float32],
+            torch.Tensor[torch.float32],
+        ]
+            A tuple of seven float tensors:
+            - The module output of shape (1, seq_len, embed_dim).
+              A tensor with the same shape as input.
+            - The updated cached attention key tensor of the left context for each
+              Zipformer2EncoderLayer. A tensor is
+              of shape (num_layers, 1, left_context_len, query_dim).
+            - The updated left context cached attention tensor for the non-linear attention module
+              of each Zipformer2EncoderLayer. A tensor is
+              of shape (num_layers, 1, left_context_len, head_dim).
+            - The updated cached left context tensor for the first self-attention module of each
+              Zipformer2EncoderLayer. A tensor is
+              of shape (num_layers, 1, left_context_len, value_dim).
+            - The updated cached left context tensor for the second self-attention module of each
+              Zipformer2EncoderLayer. A tensor is
+              of shape (num_layers, 1, left_context_len, value_dim).
+            - The updated cached left context tensor for the first convolution module of each
+              Zipformer2EncoderLayer.
A tensor is + of shape (num_layers, 1, embed_dim, left_cache_len). + - The updated cached left context tensor for the second convolution module of each + Zipformer2EncoderLayer. A tensor is + of shape (num_layers, 1, embed_dim, left_cache_len). """ + # pylint: disable=too-many-locals + + src_orig = src + src = self.downsample(src) pos_emb = self.encoder_pos(src) - output = src - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - output = output * feature_mask + new_left_cached_keys = torch.empty( + left_cached_keys.shape, dtype=torch.float32, device=left_cached_keys.device, + ) + new_left_cached_nonlin_attentions = torch.empty( + left_cached_nonlin_attentions.shape, + dtype=torch.float32, + device=left_cached_nonlin_attentions.device, + ) + new_left_cached_values_1 = torch.empty( + left_cached_values_1.shape, dtype=torch.float32, device=left_cached_values_1.device, + ) + new_left_cached_values_2 = torch.empty( + left_cached_values_2.shape, dtype=torch.float32, device=left_cached_values_2.device, + ) + new_left_cached_convolutions_1 = torch.empty( + left_cached_convolutions_1.shape, + dtype=torch.float32, + device=left_cached_convolutions_1.device, + ) + new_left_cached_convolutions_2 = torch.empty( + left_cached_convolutions_2.shape, + dtype=torch.float32, + device=left_cached_convolutions_2.device, + ) for i, mod in enumerate(self.layers): - output = mod( - output, + ( + src, + new_left_cached_keys[i], + new_left_cached_nonlin_attentions[i], + new_left_cached_values_1[i], + new_left_cached_values_2[i], + new_left_cached_convolutions_1[i], + new_left_cached_convolutions_2[i], + ) = mod( + src, pos_emb, - chunk_size=chunk_size, - attn_mask=attn_mask, - src_key_padding_mask=src_key_padding_mask, + left_cached_keys[i], + left_cached_nonlin_attentions[i], + left_cached_values_1[i], + left_cached_values_2[i], + left_cached_convolutions_1[i], + left_cached_convolutions_2[i], + src_key_padding_mask, ) - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - output = output * feature_mask + src = self.upsample(src) - return output + # Remove any extra frames that are not a multiple of downsample_factor + src = src[:, : src_orig.size(1)] + src = self.out_combiner(src_orig, src) - def streaming_forward( - self, - src: Tensor, - states: List[Tensor], - left_context_len: int, - src_key_padding_mask: Tensor, - ) -> Tuple[Tensor, List[Tensor]]: - r"""Pass the input through the encoder layers in turn. - - Args: - src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). - states: list of cached tensors of N encoder layers. For layer-i, states[i*6:(i+1)*6] is - (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). - left_context_len: Number of left context frames. - src_key_padding_mask: the mask for padding, of shape - (batch_size, left_context_len + seq_len); True means masked position. - May be None. - - Returns: - - output, a Tensor with the same shape as src. 
- - updated states - """ - pos_emb = self.encoder_pos(src, left_context_len) - output = src - - new_states = [] - for i, mod in enumerate(self.layers): - ( - cached_key, - cached_nonlin_attn, - cached_val1, - cached_val2, - cached_conv1, - cached_conv2, - ) = states[i * 6 : (i + 1) * 6] - ( - output, - new_cached_key, - new_cached_nonlin_attn, - new_cached_val1, - new_cached_val2, - new_cached_conv1, - new_cached_conv2, - ) = mod.streaming_forward( - output, - pos_emb, - cached_key=cached_key, - cached_nonlin_attn=cached_nonlin_attn, - cached_val1=cached_val1, - cached_val2=cached_val2, - cached_conv1=cached_conv1, - cached_conv2=cached_conv2, - left_context_len=left_context_len, - src_key_padding_mask=src_key_padding_mask, - ) - new_states += [ - new_cached_key, - new_cached_nonlin_attn, - new_cached_val1, - new_cached_val2, - new_cached_conv1, - new_cached_conv2, - ] - - return output, new_states + return ( + src, + new_left_cached_keys, + new_left_cached_nonlin_attentions, + new_left_cached_values_1, + new_left_cached_values_2, + new_left_cached_convolutions_1, + new_left_cached_convolutions_2, + ) -class BypassModule(nn.Module): +class BypassModule(torch.nn.Module): """ - An nn.Module that implements a learnable bypass scale, and also randomized per-sequence - layer-skipping. The bypass is limited during early stages of training to be close to - "straight-through", i.e. to not do the bypass operation much initially, in order to - force all the modules to learn something. + A bypass module that implements a learnable bypass scale for each input channel. """ - def __init__( - self, - embed_dim: int, - skip_rate: FloatLike = 0.0, - straight_through_rate: FloatLike = 0.0, - scale_min: FloatLike = ScheduledFloat((0.0, 0.9), (20000.0, 0.2), default=0), - scale_max: FloatLike = 1.0, - ): - super().__init__() - self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5)) - self.skip_rate = copy.deepcopy(skip_rate) - self.straight_through_rate = copy.deepcopy(straight_through_rate) - self.scale_min = copy.deepcopy(scale_min) - self.scale_max = copy.deepcopy(scale_max) - - def _get_bypass_scale(self, batch_size: int): - # returns bypass-scale of shape (num_channels,), - # or (batch_size, num_channels,). This is actually the - # scale on the non-residual term, so 0 corresponds to bypassing - # this module. - if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training: - return self.bypass_scale - else: - ans = limit_param_value( - self.bypass_scale, min=float(self.scale_min), max=float(self.scale_max) - ) - skip_rate = float(self.skip_rate) - if skip_rate != 0.0: - mask = torch.rand((batch_size, 1), device=ans.device) > skip_rate - ans = ans * mask - # now ans is of shape (batch_size, num_channels), and is zero for sequences - # on which we have randomly chosen to do layer-skipping. 
- straight_through_rate = float(self.straight_through_rate) - if straight_through_rate != 0.0: - mask = ( - torch.rand((batch_size, 1), device=ans.device) - < straight_through_rate - ) - ans = torch.maximum(ans, mask.to(ans.dtype)) - return ans - - def forward(self, src_orig: Tensor, src: Tensor): - """ - Args: src_orig and src are both of shape (seq_len, batch_size, num_channels) - Returns: something with the same shape as src and src_orig + def __init__(self, num_channels: int, device: torch.device) -> None: """ - bypass_scale = self._get_bypass_scale(src.shape[1]) - return src_orig + (src - src_orig) * bypass_scale - - -class DownsampledZipformer2Encoder(nn.Module): - r""" - DownsampledZipformer2Encoder is a zipformer encoder evaluated at a reduced frame rate, - after convolutional downsampling, and then upsampled again at the output, and combined - with the origin input, so that the output has the same shape as the input. - """ - - def __init__( - self, - encoder: nn.Module, - dim: int, - downsample: int, - dropout: FloatLike, - causal: bool, - ): - super(DownsampledZipformer2Encoder, self).__init__() - self.downsample_factor = downsample - self.downsample = SimpleDownsample(dim, downsample, dropout, causal) - self.num_layers = encoder.num_layers - self.encoder = encoder - self.upsample = SimpleUpsample(dim, downsample) - self.out_combiner = BypassModule(dim, straight_through_rate=0) - - def forward( - self, - src: Tensor, - chunk_size: int = -1, - feature_mask: Union[Tensor, float] = 1.0, - attn_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - r"""Downsample, go through encoder, upsample. - - Args: - src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). - feature_mask: something that broadcasts with src, that we'll multiply `src` - by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) - attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), - interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). - True means masked position. May be None. - src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means - masked position. May be None. - - Returns: a Tensor with the same shape as src. + BypassModule initialization. + + Parameters + ---------- + num_channels : int + The number of input channels, corresponds to the number of learnable bypass scales. + device : torch.device + The device used to store the layer weights. Should be + either torch.device("cpu") or torch.device("cuda"). """ - src_orig = src - src = self.downsample(src) - ds = self.downsample_factor - if attn_mask is not None: - attn_mask = attn_mask[::ds, ::ds] - src = self.encoder( - src, - chunk_size=chunk_size // ds, - feature_mask=feature_mask, - attn_mask=attn_mask, - src_key_padding_mask=src_key_padding_mask, + super().__init__() + self.bypass_scale = torch.nn.Parameter( + torch.ones(num_channels, dtype=torch.float32, device=device), ) - src = self.upsample(src) - # remove any extra frames that are not a multiple of downsample_factor - src = src[: src_orig.shape[0]] - - return self.out_combiner(src_orig, src) - def streaming_forward( - self, - src: Tensor, - states: List[Tensor], - left_context_len: int, - src_key_padding_mask: Tensor, - ) -> Tuple[Tensor, List[Tensor]]: - r"""Downsample, go through encoder, upsample, in streaming forward mode. 
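The new BypassModule keeps one learnable scale per channel and blends its two inputs with it. The snippet below is a small numeric check, not taken from the patch, of the algebraic identity its forward pass relies on: x_early + (x_later - x_early) * s equals (1 - s) * x_early + s * x_later.

```python
# Numeric check (toy tensors) of the channel-wise bypass blend used by the new
# BypassModule: the rearranged form equals the usual convex combination.
import torch

torch.manual_seed(0)
x_early = torch.randn(1, 5, 4)          # (batch, seq_len, num_channels)
x_later = torch.randn(1, 5, 4)
scale = torch.rand(4)                   # one learnable scale per channel

out_rearranged = x_early + (x_later - x_early) * scale
out_convex = (1.0 - scale) * x_early + scale * x_later
assert torch.allclose(out_rearranged, out_convex, atol=1e-6)
```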
- - Args: - src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). - states: list of cached tensors of N encoder layers. For layer-i, states[i*6:(i+1)*6] is - (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). - left_context_len: Number of left context frames. - src_key_padding_mask: the mask for padding, of shape (batch_size, left_context_len+seq_len); - True means masked position. May be None. - - Returns: - - output, a Tensor with the same shape as src. - - updated states + def forward(self, x_early: torch.Tensor, x_later: torch.Tensor) -> torch.Tensor: + """ + Does a forward pass of the BypassModule module. + + Parameters + ---------- + x_early : torch.Tensor[torch.float32] + The input float tensor of shape (1, seq_len, num_channels). + The module input that will be propagated with (1 - self.bypass_scale) weight. + x_later : torch.Tensor[torch.float32] + An input float tensor of shape (1, seq_len, num_channels). + The module input that will be propagated with self.bypass_scale weight. + + Returns + ------- + torch.Tensor[torch.float32] + A float tensor of shape (1, seq_len, num_channels). The shape is the same for x_early + and x_later. The output of the module is x_early bypassed and added to x_later. """ - src_orig = src - src = self.downsample(src) - - src, new_states = self.encoder.streaming_forward( - src, - states=states, - left_context_len=left_context_len, - src_key_padding_mask=src_key_padding_mask, - ) - src = self.upsample(src) - # remove any extra frames that are not a multiple of downsample_factor - src = src[: src_orig.shape[0]] - return self.out_combiner(src_orig, src), new_states + # It's just a slightly more efficient implementation of + # (1.0 - self.bypass_scale) * x_early + self.bypass_scale * x_later + return x_early + (x_later - x_early) * self.bypass_scale class SimpleDownsample(torch.nn.Module): """ - Does downsampling with attention, by weighted sum, and a projection.. + A downsample layer, does downsampling by weighted sum aggregation. """ - def __init__( - self, channels: int, downsample: int, dropout: FloatLike, causal: bool - ): - super(SimpleDownsample, self).__init__() - - self.causal = causal - self.bias = nn.Parameter(torch.zeros(downsample)) - - self.name = None # will be set from training code - self.dropout = copy.deepcopy(dropout) - - self.downsample = downsample - - def forward(self, src: Tensor) -> Tensor: + def __init__(self, downsample: int, device: torch.device) -> None: """ - x: (seq_len, batch_size, in_channels) - Returns a tensor of shape - ( (seq_len+downsample-1)//downsample, batch_size, channels) + SimpleDownsample initialization. + + Parameters + ---------- + downsample : int + The module downsampling factor. + device : torch.device + The device used to store the layer weights. + Either torch.device("cpu") or torch.device("cuda"). """ - (seq_len, batch_size, in_channels) = src.shape - ds = self.downsample - d_seq_len = (seq_len + ds - 1) // ds - # Pad to an exact multiple of self.downsample - # right-pad src, repeating the last element. - pad = d_seq_len * ds - seq_len + super().__init__() + self.weights = torch.nn.Parameter( + torch.zeros(downsample, 1, dtype=torch.float32, device=device), + ) - if self.causal and torch.jit.is_tracing(): - assert ( - pad == 0 - ), f"pad should be zero for exporting streaming models. Given {pad}" + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Does a forward pass of the SimpleDownsample module. 
+ + Parameters + ---------- + x : torch.Tensor[torch.float32] + The input float tensor of shape (1, seq_len, num_channels). + The module input that will be downsampled. + + Returns + ------- + torch.Tensor[torch.float32] + A float tensor of shape + (1, (seq_len + downsample - 1) // downsample, num_channels). + The downsampled output of the module. + """ - # If we are exporting a streaming model, then we skip the if statement - if not self.causal or not torch.jit.is_tracing(): - src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2]) - src = torch.cat((src, src_extra), dim=0) + downsample = self.weights.size(0) + if downsample == 1: + return x - assert src.shape[0] == d_seq_len * ds, (src.shape, d_seq_len, ds) + batch_size, seq_len, in_channels = x.size() # batch_size is 1 + downsampled_seq_len = (seq_len + downsample - 1) // downsample - src = src.reshape(d_seq_len, ds, batch_size, in_channels) + # Pad to an exact multiple of downsample. Right-pad x, repeating the last element. + pad_frames = downsampled_seq_len * downsample - seq_len + if pad_frames > 0: + pad = x[:, seq_len - 1:, :].expand(batch_size, pad_frames, in_channels) + x = torch.cat((x, pad), dim=1) - weights = self.bias.softmax(dim=0) - # weights: (downsample, 1, 1) - weights = weights.unsqueeze(-1).unsqueeze(-1) + # (1, seq_len, in_channels) -> (1, seq_len // downsample, downsample, in_channels) + x = x.reshape(batch_size, downsampled_seq_len, downsample, in_channels) - # ans1 is the first `in_channels` channels of the output - ans = (src * weights).sum(dim=1) + x = torch.sum(x * self.weights, dim=2) - return ans + return x class SimpleUpsample(torch.nn.Module): """ - A very simple form of upsampling that mostly just repeats the input, but - also adds a position-specific bias. + An upsample layer, does upsampling by repeating the input frames. """ - def __init__(self, num_channels: int, upsample: int): - super(SimpleUpsample, self).__init__() + def __init__(self, upsample: int) -> None: + """ + SimpleUpsample initialization. + + Parameters + ---------- + upsample : int + The module upsampling factor. + """ + + super().__init__() self.upsample = upsample - def forward(self, src: Tensor) -> Tensor: + def forward(self, x: torch.Tensor) -> torch.Tensor: """ - x: (seq_len, batch_size, num_channels) - Returns a tensor of shape - ( (seq_len*upsample), batch_size, num_channels) + Does a forward pass of the SimpleUpsample module. + + Parameters + ---------- + x : torch.Tensor[torch.float32] + The input float tensor of shape (1, seq_len, num_channels). + The module input that will be upsampled. + + Returns + ------- + torch.Tensor[torch.float32] + A float tensor of shape (1, seq_len * upsample, num_channels). + The upsampled output of the module. """ - upsample = self.upsample - (seq_len, batch_size, num_channels) = src.shape - src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels) - src = src.reshape(seq_len * upsample, batch_size, num_channels) - return src + + if self.upsample == 1: + return x + + x = torch.repeat_interleave(x, self.upsample, dim=1) + + return x class CompactRelPositionalEncoding(torch.nn.Module): """ - Relative positional encoding module. This version is "compact" meaning it is able to encode - the important information about the relative position in a relatively small number of dimensions. - The goal is to make it so that small differences between large relative offsets (e.g. 1000 vs. 1001) - make very little difference to the embedding. 
-    Such differences were potentially important
-    when encoding absolute position, but not important when encoding relative position because there
-    is now no need to compare two large offsets with each other.
-
-    Our embedding works by projecting the interval [-infinity,infinity] to a finite interval
-    using the atan() function, before doing the Fourier transform of that fixed interval. The
-    atan() function would compress the "long tails" too small,
-    making it hard to distinguish between different magnitudes of large offsets, so we use a logarithmic
-    function to compress large offsets to a smaller range before applying atan().
-    Scalings are chosen in such a way that the embedding can clearly distinguish individual offsets as long
-    as they are quite close to the origin, e.g. abs(offset) <= about sqrt(embedding_dim)
-
-
-    Args:
-        embed_dim: Embedding dimension.
-        dropout_rate: Dropout rate.
-        max_len: Maximum input length: just a heuristic for initialization.
-        length_factor: a heuristic scale (should be >= 1.0) which, if larger, gives
-            less weight to small differences of offset near the origin.
+    Relative positional encoding module. This version is "compact" meaning it is able to encode the
+    important information about the relative positions in a relatively small number of dimensions.
+    The goal is to make it so that small differences between large relative offsets
+    (e.g. 1000 vs. 1001) make very little difference to the embedding. Such differences were
+    potentially important when encoding absolute position, but not important when encoding relative
+    position because there is now no need to compare two large offsets with each other.
+
+    This implementation works by projecting the interval [-infinity, infinity] to a finite interval
+    using the torch.atan() function before doing the Fourier transform of that fixed interval.
+    The torch.atan() function would compress the "long tails" too small, making it hard to
+    distinguish between different magnitudes of large offsets. To mitigate this, a logarithmic
+    function is used to compress large offsets to a smaller range before applying torch.atan().
+    Scalings are chosen in such a way that the embedding can clearly distinguish individual offsets
+    as long as they are quite close to the origin, e.g. abs(offset) <= sqrt(embedding_dim).
     """

     def __init__(
-        self,
-        embed_dim: int,
-        dropout_rate: FloatLike,
-        max_len: int = 1000,
-        length_factor: float = 1.0,
+        self, embed_dim: int, max_length: int, left_context_len: int, device: torch.device,
     ) -> None:
-        """Construct a CompactRelPositionalEncoding object."""
-        super(CompactRelPositionalEncoding, self).__init__()
+        """
+        CompactRelPositionalEncoding initialization.
+
+        Parameters
+        ----------
+        embed_dim : int
+            The positional embedding dimension.
+        max_length : int
+            The maximum length of the input that this module will be able to handle after
+            initialization without positional embeddings expansion. In case of longer input the
+            positional embeddings will be re-computed to accommodate the longer length.
+        left_context_len : int
+            Length of cached left context.
+        device : torch.device
+            The device used to store the layer positional embeddings.
+            Should be either torch.device("cpu") or torch.device("cuda").
+        """
+
+        super().__init__()
+
+        if embed_dim % 2 != 0:
+            raise ValueError(
+                'Embedding dimension for CompactRelPositionalEncoding '
+                f'should be an even number, but got {embed_dim}.',
+            )
+
         self.embed_dim = embed_dim
-        assert embed_dim % 2 == 0, embed_dim
-        self.dropout = Dropout2(dropout_rate)
-        self.pe = None
-        assert length_factor >= 1.0, length_factor
-        self.length_factor = length_factor
-        self.extend_pe(torch.tensor(0.0).expand(max_len))
-
-    def extend_pe(self, x: Tensor, left_context_len: int = 0) -> None:
-        """Reset the positional encodings."""
-        T = x.size(0) + left_context_len
-
-        if self.pe is not None:
-            # self.pe contains both positive and negative parts
-            # the length of self.pe is 2 * input_len - 1
-            if self.pe.size(0) >= T * 2 - 1:
-                self.pe = self.pe.to(dtype=x.dtype, device=x.device)
-                return
-
-        # if T == 4, x would contain [ -3, -2, 1, 0, 1, 2, 3 ]
-        x = torch.arange(-(T - 1), T, device=x.device).to(torch.float32).unsqueeze(1)
-
-        freqs = 1 + torch.arange(self.embed_dim // 2, device=x.device)
-
-        # `compression_length` this is arbitrary/heuristic, if it is larger we have more resolution
-        # for small time offsets but less resolution for large time offsets.
-        compression_length = self.embed_dim**0.5
-        # x_compressed, like X, goes from -infinity to infinity as T goes from -infinity to infinity;
-        # but it does so more slowly than T for large absolute values of T.
-        # The formula is chosen so that d(x_compressed )/dx is 1 around x == 0, which
-        # is important.
-        x_compressed = (
-            compression_length
-            * x.sign()
-            * ((x.abs() + compression_length).log() - math.log(compression_length))
-        )
+        self.left_context_len = left_context_len
+        self.pos_emb = self.create_pos_emb(max_length, device)

-        # if self.length_factor == 1.0, then length_scale is chosen so that the
-        # FFT can exactly separate points close to the origin (T == 0). So this
-        # part of the formulation is not really heuristic.
-        # But empirically, for ASR at least, length_factor > 1.0 seems to work better.
-        length_scale = self.length_factor * self.embed_dim / (2.0 * math.pi)
+    def create_pos_emb(self, max_length: int, device: torch.device) -> torch.Tensor:
+        """
+        Creates relative positional embeddings based on the maximum length.
+        This method is used to create positional embeddings with a
+        sufficiently long temporal axis during module initialization.
+        We want it to be big enough to avoid getting input x that is longer
+        than self.pos_emb during inference. On the other hand, we want
+        to initialize it with the smallest maximum length possible to consume
+        less memory.
+
+        Parameters
+        ----------
+        max_length : int
+            The maximum length of the input that can be handled by this layer. Increasing this
+            allows processing of longer inputs (along the temporal dimension), but also
+            increases the memory consumption.
+        device : torch.device
+            The device used to store the positional embeddings.
+            Should be either torch.device("cpu") or torch.device("cuda").
+
+        Returns
+        -------
+        torch.Tensor[torch.float32]
+            A float tensor of shape (2 * max_length - 1, embed_dim).
+            Relative positional embeddings.
+ """ - # note for machine implementations: if atan is not available, we can use: - # x.sign() * ((1 / (x.abs() + 1)) - 1) * (-math.pi/2) - # check on wolframalpha.com: plot(sign(x) * (1 / ( abs(x) + 1) - 1 ) * -pi/2 , atan(x)) - x_atan = (x_compressed / length_scale).atan() # results between -pi and pi + # if max_length == 4, the x would contain [-3, -2, -1, 0, 1, 2, 3] + x = torch.arange(-max_length + 1, max_length, dtype=torch.float32, device=device) - cosines = (x_atan * freqs).cos() - sines = (x_atan * freqs).sin() + # Compression length is an arbitrary heuristic, if it is larger we have more resolution for + # small time offsets but less resolution for large time offsets. + compression_length = self.embed_dim**0.5 + + # Compressing x within the next line of code, similarly to uncompressed x, it goes from + # -infinity to infinity as the sequence length goes from -infinity to infinity, but it does + # so more slowly than sequence length for the large absolute values of sequence length. + # The formula is chosen so that d(x_compressed) / dx is equal to 1 around x == 0, + # which is important. + x = compression_length * torch.sign(x) * torch.log(torch.abs(x) / compression_length + 1.0) - pe = torch.zeros(x.shape[0], self.embed_dim, device=x.device) - pe[:, 0::2] = cosines - pe[:, 1::2] = sines - pe[:, -1] = 1.0 # for bias. + # results between -pi and pi + x = torch.atan(2.0 * torch.pi * x / self.embed_dim) - self.pe = pe.to(dtype=x.dtype) + freqs = torch.arange(1, self.embed_dim // 2 + 1, dtype=torch.float32, device=device) + x = x.unsqueeze(1) * freqs - def forward(self, x: Tensor, left_context_len: int = 0) -> Tensor: - """Create positional encoding. + pos_emb = torch.zeros(x.size(0), self.embed_dim, dtype=torch.float32, device=device) + pos_emb[:, 0::2] = torch.cos(x) + pos_emb[:, 1::2] = torch.sin(x) + pos_emb[:, self.embed_dim - 1] = 1.0 # for bias. - Args: - x (Tensor): Input tensor (time, batch, `*`). - left_context_len: (int): Length of cached left context. + return pos_emb - Returns: - positional embedding, of shape (batch, left_context_len + 2*time-1, `*`). + def forward(self, x: torch.Tensor) -> torch.Tensor: """ - self.extend_pe(x, left_context_len) - x_size_left = x.size(0) + left_context_len - # length of positive side: x.size(0) + left_context_len - # length of negative side: x.size(0) - pos_emb = self.pe[ - self.pe.size(0) // 2 - - x_size_left - + 1 : self.pe.size(0) // 2 # noqa E203 - + x.size(0), - :, - ] - pos_emb = pos_emb.unsqueeze(0) - return self.dropout(pos_emb) + Does a forward pass of the CompactRelPositionalEncoding module. + Returns a relative positional embeddings based on the input x temporal dimension. + + Parameters + ---------- + x : torch.Tensor[torch.float32] + An input float tensor of shape (1, seq_len, embed_dim). + The module input. It's shape will be used to construct relative positional embeddings. + + Returns + ------- + torch.Tensor[torch.float32] + A float tensor of shape (1, self.left_context_len + 2 * seq_len - 1, embed_dim). + Relative positional embeddings. + """ + + if self.pos_emb.size(0) < 2 * (x.size(1) + self.left_context_len) - 1: + self.pos_emb = self.create_pos_emb(x.size(1) + self.left_context_len, x.device) + # Length of negative side: x.size(1) + self.left_context_len. + # Length of positive side: x.size(1). 
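# A minimal sketch, with an arbitrary embed_dim, of how the offset compression in
# create_pos_emb() behaves: near zero the mapping is close to the identity, while large
# offsets grow only logarithmically, so after atan() offsets such as 1000 and 1001 become
# nearly indistinguishable.
import torch

embed_dim = 192
compression_length = embed_dim ** 0.5
offsets = torch.tensor([0.0, 1.0, 2.0, 10.0, 1000.0, 1001.0])

compressed = (
    compression_length
    * torch.sign(offsets)
    * torch.log(torch.abs(offsets) / compression_length + 1.0)
)
angles = torch.atan(2.0 * torch.pi * compressed / embed_dim)

print(compressed)  # roughly [0.0, 0.97, 1.87, 7.5, 59.5, 59.5]
print(angles)      # bounded in (-pi/2, pi/2); the last two entries are almost equal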
+ pos_emb = self.pos_emb[ + self.pos_emb.size(0) // 2 - x.size(1) - self.left_context_len + 1: + self.pos_emb.size(0) // 2 + x.size(1) + ].unsqueeze(0).repeat(x.size(0), 1, 1) -class RelPositionMultiheadAttentionWeights(nn.Module): - r"""Module that computes multi-head attention weights with relative position encoding. - Various other modules consume the resulting attention weights: see, for example, the - SimpleAttention module which allows you to compute conventional attention. + # (1, left_context_len + 2 * seq_len - 1, embed_dim), + # i. e. (batch_size, pos_len, embed_dim). + return pos_emb - This is a quite heavily modified from: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context", - we have to write up the differences. +class RelPositionMultiheadAttentionWeights(torch.nn.Module): + """ + Module that computes multi-head attention weights with relative position encoding. + Various other modules consume the resulting attention weights: see, for example, + the SelfAttention module which allows you to compute conventional self-attention. - Args: - embed_dim: number of channels at the input to this module, e.g. 256 - pos_dim: dimension of the positional encoding vectors, e.g. 128. - num_heads: number of heads to compute weights for, e.g. 8 - query_head_dim: dimension of the query (and key), per head. e.g. 24. - pos_head_dim: dimension of the projected positional encoding per head, e.g. 4. - dropout: dropout probability for attn_output_weights. Default: 0.0. - pos_emb_skip_rate: probability for skipping the pos_emb part of the scores on - any given call to forward(), in training time. + This is a quite heavily modified from: + "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context". """ def __init__( @@ -1538,362 +1498,160 @@ def __init__( num_heads: int, query_head_dim: int, pos_head_dim: int, - dropout: float = 0.0, - pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.0)), + right_context: int, + device: torch.device, ) -> None: + """ + RelPositionMultiheadAttentionWeights initialization. + + Parameters + ---------- + embed_dim : int + The embedding dimension. The number of channels at the input to this module. + pos_dim : int + A dimension of the positional embeddings. + num_heads : int + The number of attention heads to compute weights. + query_head_dim : int + The dimension of the query and key per head. + pos_head_dim : int + The dimension of the projected positional encoding per head. + right_context : int + The module look ahead future context, used to update left + cached attention key correctly. + device : torch.device + The device used to store the layer positional embeddings. Should be + either torch.device("cpu") or torch.device("cuda"). + """ + super().__init__() + self.embed_dim = embed_dim self.num_heads = num_heads self.query_head_dim = query_head_dim self.pos_head_dim = pos_head_dim - self.dropout = dropout - self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate) - self.name = None # will be overwritten in training code; for diagnostics. - - key_head_dim = query_head_dim - in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads - - # the initial_scale is supposed to take over the "scaling" factor of - # head_dim ** -0.5 that has been used in previous forms of attention, - # dividing it between the query and key. Note: this module is intended - # to be used with the ScaledAdam optimizer; with most other optimizers, - # it would be necessary to apply the scaling factor in the forward function. 
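# A minimal sketch, with toy sizes, of the positional-embedding slice taken above: pos_emb has
# 2 * max_length - 1 rows centred on offset zero, and the slice keeps one row per relative
# offset from -(seq_len + left_context_len - 1) up to +(seq_len - 1).
import torch

max_length, seq_len, left_context_len, embed_dim = 32, 5, 3, 8
pos_emb = torch.randn(2 * max_length - 1, embed_dim)

centre = pos_emb.size(0) // 2
sliced = pos_emb[centre - seq_len - left_context_len + 1 : centre + seq_len]
assert sliced.size(0) == left_context_len + 2 * seq_len - 1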
- self.in_proj = ScaledLinear( - embed_dim, in_proj_dim, bias=True, initial_scale=query_head_dim**-0.25 - ) + self.right_context = right_context - self.whiten_keys = Whiten( - num_groups=num_heads, - whitening_limit=_whitening_schedule(3.0), - prob=(0.025, 0.25), - grad_scale=0.025, - ) + in_proj_dim = (2 * query_head_dim + pos_head_dim) * num_heads + self.in_proj = torch.nn.Linear(embed_dim, in_proj_dim, device=device) - # add a balancer for the keys that runs with very small probability, and - # tries to enforce that all dimensions have mean around zero. The - # weights produced by this module are invariant to adding a constant to - # the keys, so the derivative of the bias is mathematically zero; but - # due to how Adam/ScaledAdam work, it can learn a fairly large nonzero - # bias because the small numerical roundoff tends to have a non-random - # sign. This module is intended to prevent that. Use a very small - # probability; that should be sufficient to fix the problem. - self.balance_keys = Balancer( - key_head_dim * num_heads, - channel_dim=-1, - min_positive=0.4, - max_positive=0.6, - min_abs=0.0, - max_abs=100.0, - prob=0.025, + # Linear transformation for positional encoding. + self.linear_pos = torch.nn.Linear( + pos_dim, num_heads * pos_head_dim, bias=False, device=device, ) - # linear transformation for positional encoding. - self.linear_pos = ScaledLinear( - pos_dim, num_heads * pos_head_dim, bias=False, initial_scale=0.05 - ) - - # the following are for diagnostics only, see --print-diagnostics option - self.copy_pos_query = Identity() - self.copy_query = Identity() - def forward( - self, - x: Tensor, - pos_emb: Tensor, - key_padding_mask: Optional[Tensor] = None, - attn_mask: Optional[Tensor] = None, - ) -> Tensor: - r""" - Args: - x: input of shape (seq_len, batch_size, embed_dim) - pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, pos_dim) - key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that - are True in this mask will be ignored as sources in the attention weighting. - attn_mask: mask of shape (seq_len, seq_len) or (batch_size, seq_len, seq_len), - interpreted as ([batch_size,] tgt_seq_len, src_seq_len) - saying which positions are allowed to attend to which other positions. - Returns: - a tensor of attention weights, of shape (hum_heads, batch_size, seq_len, seq_len) - interpreted as (hum_heads, batch_size, tgt_seq_len, src_seq_len). + self, x: torch.Tensor, + pos_emb: torch.Tensor, + left_cached_key: torch.Tensor, + key_padding_mask: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: """ - x = self.in_proj(x) - query_head_dim = self.query_head_dim - pos_head_dim = self.pos_head_dim - num_heads = self.num_heads - - seq_len, batch_size, _ = x.shape - - query_dim = query_head_dim * num_heads - - # self-attention - q = x[..., 0:query_dim] - k = x[..., query_dim : 2 * query_dim] - # p is the position-encoding query - p = x[..., 2 * query_dim :] - assert p.shape[-1] == num_heads * pos_head_dim, ( - p.shape[-1], - num_heads, - pos_head_dim, - ) - - q = self.copy_query(q) # for diagnostics only, does nothing. - k = self.whiten_keys(self.balance_keys(k)) # does nothing in the forward pass. - p = self.copy_pos_query(p) # for diagnostics only, does nothing. - - q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) - p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim) - k = k.reshape(seq_len, batch_size, num_heads, query_head_dim) - - # time1 refers to target, time2 refers to source. 
- q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim) - p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim) - k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2) - - attn_scores = torch.matmul(q, k) - - use_pos_scores = False - if torch.jit.is_scripting() or torch.jit.is_tracing(): - # We can't put random.random() in the same line - use_pos_scores = True - elif not self.training or random.random() >= float(self.pos_emb_skip_rate): - use_pos_scores = True - - if use_pos_scores: - pos_emb = self.linear_pos(pos_emb) - seq_len2 = 2 * seq_len - 1 - pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute( - 2, 0, 3, 1 - ) - # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2) - - # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2) - # [where seq_len2 represents relative position.] - pos_scores = torch.matmul(p, pos_emb) - # the following .as_strided() expression converts the last axis of pos_scores from relative - # to absolute position. I don't know whether I might have got the time-offsets backwards or - # not, but let this code define which way round it is supposed to be. - if torch.jit.is_tracing(): - (num_heads, batch_size, time1, n) = pos_scores.shape - rows = torch.arange(start=time1 - 1, end=-1, step=-1) - cols = torch.arange(seq_len) - rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) - indexes = rows + cols - pos_scores = pos_scores.reshape(-1, n) - pos_scores = torch.gather(pos_scores, dim=1, index=indexes) - pos_scores = pos_scores.reshape(num_heads, batch_size, time1, seq_len) - else: - pos_scores = pos_scores.as_strided( - (num_heads, batch_size, seq_len, seq_len), - ( - pos_scores.stride(0), - pos_scores.stride(1), - pos_scores.stride(2) - pos_scores.stride(3), - pos_scores.stride(3), - ), - storage_offset=pos_scores.stride(3) * (seq_len - 1), - ) - - attn_scores = attn_scores + pos_scores - - if torch.jit.is_scripting() or torch.jit.is_tracing(): - pass - elif self.training and random.random() < 0.1: - # This is a harder way of limiting the attention scores to not be - # too large. It incurs a penalty if any of them has an absolute - # value greater than 50.0. this should be outside the normal range - # of the attention scores. We use this mechanism instead of, say, - # something added to the loss function involving the entropy, - # because once the entropy gets very small gradients through the - # softmax can become very small, and we'd get zero derivatives. The - # choices of 1.0e-04 as the scale on the penalty makes this - # mechanism vulnerable to the absolute scale of the loss function, - # but we view this as a failsafe to avoid "implausible" parameter - # values rather than a regularization method that should be active - # under normal circumstances. - attn_scores = penalize_abs_values_gt( - attn_scores, limit=25.0, penalty=1.0e-04, name=self.name - ) - - assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len) - - if attn_mask is not None: - assert attn_mask.dtype == torch.bool - # use -1000 to avoid nan's where attn_mask and key_padding_mask make - # all scores zero. It's important that this be large enough that exp(-1000) - # is exactly zero, for reasons related to const_attention_rate, it - # compares the final weights with zero. 
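# A minimal sketch, with toy shapes, of the padding-mask trick discussed above: masked source
# positions get a large negative score (-1000) before the softmax, so they end up with
# (numerically) zero attention weight.
import torch

scores = torch.randn(1, 2, 3, 5)  # (batch, num_heads, tgt_seq_len, src_seq_len)
key_padding_mask = torch.tensor([[False, False, False, True, True]])

scores = scores.masked_fill(key_padding_mask.unsqueeze(1).unsqueeze(2), -1000.0)
weights = torch.softmax(scores, dim=-1)
assert torch.all(weights[..., 3:] < 1.0e-6)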
- attn_scores = attn_scores.masked_fill(attn_mask, -1000) - - if key_padding_mask is not None: - assert key_padding_mask.shape == ( - batch_size, - seq_len, - ), key_padding_mask.shape - attn_scores = attn_scores.masked_fill( - key_padding_mask.unsqueeze(1), - -1000, - ) - - # We use our own version of softmax, defined in scaling.py, which should - # save a little of the memory used in backprop by, if we are in - # automatic mixed precision mode (amp / autocast), by only storing the - # half-precision output for backprop purposes. - attn_weights = softmax(attn_scores, dim=-1) - - if torch.jit.is_scripting() or torch.jit.is_tracing(): - pass - elif random.random() < 0.001 and not self.training: - self._print_attn_entropy(attn_weights) - - attn_weights = nn.functional.dropout( - attn_weights, p=self.dropout, training=self.training - ) - - return attn_weights - - def streaming_forward( - self, - x: Tensor, - pos_emb: Tensor, - cached_key: Tensor, - left_context_len: int, - key_padding_mask: Tensor, - ) -> Tuple[Tensor, Tensor]: - r""" - Args: - x: input of shape (seq_len, batch_size, embed_dim) - pos_emb: Positional embedding tensor, of shape (1, left_context_len+2*seq_len-1, pos_dim) - cached_key: cached attention key tensor of left context, - of shape (left_context_len, batch_size, key_dim) - left_context_len: number of left context frames. - key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that - are True in this mask will be ignored as sources in the attention weighting. - - Returns: - - attention weights, of shape (hum_heads, batch_size, seq_len, seq_len2), - interpreted as (hum_heads, batch_size, tgt_seq_len, src_seq_len). - - updated cached attention key tensor of left context. + Does a forward pass of the RelPositionMultiheadAttentionWeights module. + Returns attention weights and updated cached attention key tensor of the left context. + + Parameters + ---------- + x : torch.Tensor[torch.float32] + The input float tensor of shape (1, seq_len, embed_dim). The module input. + pos_emb : torch.Tensor[torch.float32] + A positional embedding tensor + of shape (1, left_context_len + 2 * seq_len - 1, pos_dim). + left_cached_key : torch.Tensor[torch.float32] + A cached attention key tensor of the left context + of shape (1, left_context_len, key_dim). + key_padding_mask : torch.Tensor[torch.bool] + A boolean tensor of shape (1, seq_len_2). Positions that are True in this mask will be + ignored as sources in the attention weighting. + + Returns + ------- + tuple[torch.Tensor[torch.float32], torch.Tensor[torch.float32]] + A tuple of two float tensors: + - attention weights, of shape (1, hum_heads, seq_len, seq_len_2) + interpreted as (1, hum_heads, tgt_seq_len, src_seq_len). + - updated cached attention key tensor of the left context + of shape (1, left_context_len, key_dim). """ - x = self.in_proj(x) - query_head_dim = self.query_head_dim - pos_head_dim = self.pos_head_dim - num_heads = self.num_heads + # pylint: disable=too-many-locals - seq_len, batch_size, _ = x.shape + batch_size = x.size(0) # batch size is 1 + seq_len = x.size(1) + x = self.in_proj(x) - query_dim = query_head_dim * num_heads + query_dim = self.query_head_dim * self.num_heads - # self-attention - q = x[..., 0:query_dim] - k = x[..., query_dim : 2 * query_dim] - # p is the position-encoding query - p = x[..., 2 * query_dim :] - assert p.shape[-1] == num_heads * pos_head_dim + # Self-attention. + q = x[:, :, :query_dim] + k = x[:, :, query_dim: 2 * query_dim] + # p is the position-encoding query. 
+ p = x[:, :, 2 * query_dim:] - # Pad cached left contexts - assert cached_key.shape[0] == left_context_len, ( - cached_key.shape[0], - left_context_len, - ) - k = torch.cat([cached_key, k], dim=0) + # Pad key with cached left context. + k = torch.cat((left_cached_key, k), dim=1) # Update cached left contexts - cached_key = k[-left_context_len:, ...] - - # The length of key - k_len = k.shape[0] + seq_len_2 = k.size(1) # left_context_len + seq_len + left_cached_key = k[ + :, + seq_len_2 - self.right_context - left_cached_key.size(1): + seq_len_2 - self.right_context, + ] - q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) - p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim) - k = k.reshape(k_len, batch_size, num_heads, query_head_dim) + q = q.reshape(batch_size, seq_len, self.num_heads, self.query_head_dim) + p = p.reshape(batch_size, seq_len, self.num_heads, self.pos_head_dim) + k = k.reshape(batch_size, seq_len_2, self.num_heads, self.query_head_dim) - # time1 refers to target, time2 refers to source. - q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim) - p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim) - k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2) + # seq_len refers to target, seq_len_2 refers to source. + q = q.permute(0, 2, 1, 3) # (1, hum_heads, seq_len, query_head_dim) + p = p.permute(0, 2, 1, 3) # (1, hum_heads, seq_len, pos_head_dim) + k = k.permute(0, 2, 3, 1) # (1, hum_heads, key_head_dim, seq_len_2) - attn_scores = torch.matmul(q, k) + attn_scores = torch.matmul(q, k) # (1, hum_heads, seq_len, seq_len_2) + pos_len = pos_emb.size(1) # left_context_len + 2 * seq_len - 1 + # (1, pos_len, num_heads * pos_head_dim) pos_emb = self.linear_pos(pos_emb) - seq_len2 = 2 * seq_len - 1 + left_context_len - pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute( - 2, 0, 3, 1 - ) - # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2) + pos_emb = pos_emb.reshape( + batch_size, pos_len, self.num_heads, self.pos_head_dim, + ).permute(0, 2, 3, 1) # (1, hum_heads, pos_head_dim, pos_len) - # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2) - # [where seq_len2 represents relative position.] + # (1, hum_heads, seq_len, pos_head_dim) x (1, hum_heads, pos_head_dim, pos_len) -> + # -> (1, hum_heads, seq_len, pos_len) where pos_len represents relative position. pos_scores = torch.matmul(p, pos_emb) - if torch.jit.is_tracing(): - (num_heads, batch_size, time1, n) = pos_scores.shape - rows = torch.arange(start=time1 - 1, end=-1, step=-1) - cols = torch.arange(k_len) - rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) - indexes = rows + cols - pos_scores = pos_scores.reshape(-1, n) - pos_scores = torch.gather(pos_scores, dim=1, index=indexes) - pos_scores = pos_scores.reshape(num_heads, batch_size, time1, k_len) - # the following .as_strided() expression converts the last axis of pos_scores from relative - # to absolute position. I don't know whether I might have got the time-offsets backwards or - # not, but let this code define which way round it is supposed to be. 
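# A minimal sketch, with toy sizes, of how the left key cache is rolled forward above: the
# updated cache keeps left_context_len frames that end right_context frames before the end of
# the concatenated key sequence, presumably because the trailing right_context frames are
# look-ahead rather than finalized left context.
import torch

left_context_len, seq_len, right_context, key_dim = 4, 6, 2, 8
left_cached_key = torch.zeros(1, left_context_len, key_dim)
k_new = torch.randn(1, seq_len, key_dim)

k = torch.cat((left_cached_key, k_new), dim=1)      # (1, left_context_len + seq_len, key_dim)
end = k.size(1) - right_context
updated_cache = k[:, end - left_context_len : end]  # (1, left_context_len, key_dim)
assert updated_cache.shape == left_cached_key.shape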
- else: - pos_scores = pos_scores.as_strided( - (num_heads, batch_size, seq_len, k_len), - ( - pos_scores.stride(0), - pos_scores.stride(1), - pos_scores.stride(2) - pos_scores.stride(3), - pos_scores.stride(3), - ), - storage_offset=pos_scores.stride(3) * (seq_len - 1), - ) - - attn_scores = attn_scores + pos_scores - - assert attn_scores.shape == ( - num_heads, - batch_size, - seq_len, - k_len, - ), attn_scores.shape - - if key_padding_mask is not None: - assert key_padding_mask.shape == (batch_size, k_len), key_padding_mask.shape - attn_scores = attn_scores.masked_fill( - key_padding_mask.unsqueeze(1), - -1000, - ) + # Now we need to perform the relative shift of the pos_scores, to do that we need to add + # a column of zeros to the left side of the last dimension and perform the relative shift. + pos_scores_pad = torch.zeros( + pos_scores.size(0), pos_scores.size(1), pos_scores.size(2), 1, + dtype=torch.float32, + device=pos_scores.device, + ) + # (1, hum_heads, seq_len, pos_len + 1) + pos_scores = torch.cat((pos_scores_pad, pos_scores), dim=3) + pos_scores = pos_scores.reshape( + batch_size, self.num_heads, pos_len + 1, seq_len, + ) # (1, hum_heads, pos_len + 1, seq_len) + # Now drop the extra row that had been added over padding and reshape. + pos_scores = pos_scores[:, :, 1:].reshape( + batch_size, self.num_heads, seq_len, pos_len, + ) # (1, hum_heads, seq_len, pos_len) - attn_weights = attn_scores.softmax(dim=-1) + # (1, hum_heads, seq_len, seq_len_2) + attn_scores = attn_scores + pos_scores[:, :, :, : attn_scores.size(3)] - return attn_weights, cached_key + # (1, seq_len_2) -> (1, 1, 1, seq_len_2) to make it broadcastable to attn_scores shape. + key_padding_mask = key_padding_mask.unsqueeze(1).unsqueeze(2) - def _print_attn_entropy(self, attn_weights: Tensor): - # attn_weights: (num_heads, batch_size, seq_len, seq_len) - (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape + attn_scores = torch.masked_fill(attn_scores, key_padding_mask, -1000.0) + attn_weights = torch.softmax(attn_scores, dim=3) - with torch.no_grad(): - with torch.cuda.amp.autocast(enabled=False): - attn_weights = attn_weights.to(torch.float32) - attn_weights_entropy = ( - -((attn_weights + 1.0e-20).log() * attn_weights) - .sum(dim=-1) - .mean(dim=(1, 2)) - ) - logging.info( - f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}" - ) + return attn_weights, left_cached_key -class SelfAttention(nn.Module): +class SelfAttention(torch.nn.Module): """ - The simplest possible attention module. This one works with already-computed attention - weights, e.g. as computed by RelPositionMultiheadAttentionWeights. - - Args: - embed_dim: the input and output embedding dimension - num_heads: the number of attention heads - value_head_dim: the value dimension per head + The simplest possible attention module. This one works with pre-computed attention weights, + e.g. as computed by RelPositionMultiheadAttentionWeights. 
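# A minimal sketch, with toy sizes, checking that the zero-pad + reshape "relative shift"
# above selects pos_scores[..., i, seq_len - 1 - i + j], i.e. the same elements the removed
# as_strided()-based code picked out.
import torch

num_heads, seq_len, left_context_len = 2, 4, 3
pos_len = left_context_len + 2 * seq_len - 1
seq_len_2 = left_context_len + seq_len
pos_scores = torch.randn(1, num_heads, seq_len, pos_len)

pad = torch.zeros(1, num_heads, seq_len, 1)
shifted = torch.cat((pad, pos_scores), dim=3)
shifted = shifted.reshape(1, num_heads, pos_len + 1, seq_len)
shifted = shifted[:, :, 1:].reshape(1, num_heads, seq_len, pos_len)[:, :, :, :seq_len_2]

expected = torch.stack(
    [
        pos_scores[0, :, i, seq_len - 1 - i : seq_len - 1 - i + seq_len_2]
        for i in range(seq_len)
    ],
    dim=1,
).unsqueeze(0)
assert torch.equal(shifted, expected)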
""" def __init__( @@ -1901,528 +1659,314 @@ def __init__( embed_dim: int, num_heads: int, value_head_dim: int, + right_context: int, + device: torch.device, ) -> None: - super().__init__() - self.in_proj = nn.Linear(embed_dim, num_heads * value_head_dim, bias=True) - - self.out_proj = ScaledLinear( - num_heads * value_head_dim, embed_dim, bias=True, initial_scale=0.05 - ) - - self.whiten = Whiten( - num_groups=1, - whitening_limit=_whitening_schedule(7.5, ratio=3.0), - prob=(0.025, 0.25), - grad_scale=0.01, - ) - - def forward( - self, - x: Tensor, - attn_weights: Tensor, - ) -> Tensor: """ - Args: - x: input tensor, of shape (seq_len, batch_size, embed_dim) - attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len), - with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect - attn_weights.sum(dim=-1) == 1. - Returns: - a tensor with the same shape as x. + SelfAttention initialization. + + Parameters + ---------- + embed_dim : int + The input and output embedding dimension. The number of channels is the same for input + and output of this module. + num_heads : int + The number of attention heads. + value_head_dim : int + The dimension of the value per head. + right_context : int + The module look ahead future context, used to update left cached + attention value correctly. + device : torch.device + The device used to store the layer positional embeddings. + Either torch.device("cpu") or torch.device("cuda"). """ - (seq_len, batch_size, embed_dim) = x.shape - num_heads = attn_weights.shape[0] - assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len) - - x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim) - x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) - # now x: (num_heads, batch_size, seq_len, value_head_dim) - value_head_dim = x.shape[-1] - # todo: see whether there is benefit in overriding matmul - x = torch.matmul(attn_weights, x) - # v: (num_heads, batch_size, seq_len, value_head_dim) - - x = ( - x.permute(2, 1, 0, 3) - .contiguous() - .view(seq_len, batch_size, num_heads * value_head_dim) - ) - - # returned value is of shape (seq_len, batch_size, embed_dim), like the input. - x = self.out_proj(x) - x = self.whiten(x) + super().__init__() - return x + self.in_proj = torch.nn.Linear(embed_dim, num_heads * value_head_dim, device=device) + self.out_proj = torch.nn.Linear(num_heads * value_head_dim, embed_dim, device=device) + self.right_context = right_context - def streaming_forward( - self, - x: Tensor, - attn_weights: Tensor, - cached_val: Tensor, - left_context_len: int, - ) -> Tuple[Tensor, Tensor]: + def forward( + self, x: torch.Tensor, attn_weights: torch.Tensor, left_cached_val: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: """ - Args: - x: input tensor, of shape (seq_len, batch_size, embed_dim) - attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len), - with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect - attn_weights.sum(dim=-1) == 1. - cached_val: cached attention value tensor of left context, - of shape (left_context_len, batch_size, value_dim) - left_context_len: number of left context frames. - - Returns: - - attention weighted output, a tensor with the same shape as x. - - updated cached attention value tensor of left context. + Does a forward pass of the SelfAttention module. Returns attention weighted input tensor + and updated cached attention value tensor of the left context. 
+ + Parameters + ---------- + x : torch.Tensor[torch.float32] + The input float tensor of shape (1, seq_len, embed_dim). The module input. + attn_weights : torch.Tensor[torch.float32] + The tensor of shape (1, num_heads, seq_len, seq_len_2), with (seq_len, seq_len_2) + being interpreted as (tgt_seq_len, src_seq_len). Expect attn_weights.sum(dim=3) == 1.0. + left_cached_val : torch.Tensor[torch.float32] + The cached attention value tensor of the left context + of shape (1, left_context_len, value_dim). + + Returns + ------- + tuple[torch.Tensor[torch.float32], torch.Tensor[torch.float32]] + A tuple of two float tensors: + - attention weighted output of shape (1, seq_len, embed_dim). + A tensor with the same shape as input x. + - updated cached attention value tensor of the left context + of shape (1, left_context_len, value_dim). """ - (seq_len, batch_size, embed_dim) = x.shape - num_heads = attn_weights.shape[0] - seq_len2 = seq_len + left_context_len - assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len2) - x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim) + batch_size = x.size(0) # batch size is 1 + num_heads = attn_weights.size(1) - # Pad cached left contexts - assert cached_val.shape[0] == left_context_len, ( - cached_val.shape[0], - left_context_len, - ) - x = torch.cat([cached_val, x], dim=0) + x = self.in_proj(x) # (1, seq_len, num_heads * value_head_dim) + + x = torch.cat((left_cached_val, x), dim=1) # Update cached left contexts - cached_val = x[-left_context_len:, ...] + left_cached_val = x[ + :, + x.size(1) - self.right_context - left_cached_val.size(1): + x.size(1) - self.right_context, + ] - x = x.reshape(seq_len2, batch_size, num_heads, -1).permute(2, 1, 0, 3) - # now x: (num_heads, batch_size, seq_len, value_head_dim) - value_head_dim = x.shape[-1] + x = x.reshape(batch_size, x.size(1), num_heads, x.size(2) // num_heads).permute(0, 2, 1, 3) - # todo: see whether there is benefit in overriding matmul + # (1, num_heads, seq_len, seq_len_2) x (1, num_heads, seq_len_2, value_head_dim) -> + # -> (1, num_heads, seq_len, value_head_dim) x = torch.matmul(attn_weights, x) - # v: (num_heads, batch_size, seq_len, value_head_dim) - x = ( - x.permute(2, 1, 0, 3) - .contiguous() - .view(seq_len, batch_size, num_heads * value_head_dim) - ) + # (1, num_heads, seq_len, value_head_dim) -> (1, seq_len, num_heads, value_head_dim) + x = x.permute(0, 2, 1, 3) + x = x.reshape(batch_size, x.size(1), num_heads * x.size(3)) - # returned value is of shape (seq_len, batch_size, embed_dim), like the input. + # returned value is of shape (1, seq_len, embed_dim), like the input. x = self.out_proj(x) - return x, cached_val + return x, left_cached_val -class FeedforwardModule(nn.Module): - """Feedforward module in Zipformer2 model.""" +class FeedforwardModule(torch.nn.Module): + """ + Feedforward module in Zipformer2 encoder. + """ - def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike): - super(FeedforwardModule, self).__init__() - self.in_proj = nn.Linear(embed_dim, feedforward_dim) + def __init__(self, embed_dim: int, feedforward_dim: int, device: torch.device) -> None: + """ + FeedforwardModule initialization. + + Parameters + ---------- + embed_dim : int + The input and output embedding dimension. The number of channels is the same for input + and output of this module. + feedforward_dim : int + The module hidden dimension. + device : torch.device + The device used to store the layer weights. 
should be + either torch.device("cpu") or torch.device("cuda"). + """ - self.hidden_balancer = Balancer( - feedforward_dim, - channel_dim=-1, - min_positive=0.3, - max_positive=1.0, - min_abs=0.75, - max_abs=5.0, - ) + super().__init__() - # shared_dim=0 means we share the dropout mask along the time axis - self.out_proj = ActivationDropoutAndLinear( - feedforward_dim, - embed_dim, - activation="SwooshL", - dropout_p=dropout, - dropout_shared_dim=0, - bias=True, - initial_scale=0.1, - ) + self.in_proj = torch.nn.Linear(embed_dim, feedforward_dim, device=device) + self.activation = SwooshL() + self.out_proj = torch.nn.Linear(feedforward_dim, embed_dim, device=device) - self.out_whiten = Whiten( - num_groups=1, - whitening_limit=_whitening_schedule(7.5), - prob=(0.025, 0.25), - grad_scale=0.01, - ) + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Does a forward pass of the FeedforwardModule module. + + Parameters + ---------- + x : torch.Tensor[torch.float32] + A float tensor of shape (1, seq_len, embed_dim). The module input. + + Returns + ------- + torch.Tensor[torch.float32] + A float tensor of shape (1, seq_len, embed_dim). + The module output has the same shape as input. + """ - def forward(self, x: Tensor): x = self.in_proj(x) - x = self.hidden_balancer(x) - # out_proj contains SwooshL activation, then dropout, then linear. + x = self.activation(x) x = self.out_proj(x) - x = self.out_whiten(x) - return x + return x -class NonlinAttention(nn.Module): - """This is like the ConvolutionModule, but refactored so that we use multiplication by attention weights (borrowed - from the attention module) in place of actual convolution. We also took out the second nonlinearity, the - one after the attention mechanism. - Args: - channels (int): The number of channels of conv layers. +class NonlinAttention(torch.nn.Module): + """ + This is like the ConvolutionModule, but refactored so that we use multiplication by attention + weights (borrowed from the RelPositionMultiheadAttentionWeights module) instead of actual + convolution. We also took out the second nonlinearity, the one after the attention mechanism. """ def __init__( - self, - channels: int, - hidden_channels: int, + self, embed_dim: int, att_dim: int, right_context: int, device: torch.device, ) -> None: - super().__init__() - - self.hidden_channels = hidden_channels - - self.in_proj = nn.Linear(channels, hidden_channels * 3, bias=True) - - # balancer that goes before the sigmoid. Have quite a large min_abs value, at 2.0, - # because we noticed that well-trained instances of this module have abs-value before the sigmoid - # starting from about 3, and poorly-trained instances of the module have smaller abs values - # before the sigmoid. - self.balancer = Balancer( - hidden_channels, - channel_dim=-1, - min_positive=ScheduledFloat((0.0, 0.25), (20000.0, 0.05)), - max_positive=ScheduledFloat((0.0, 0.75), (20000.0, 0.95)), - min_abs=0.5, - max_abs=5.0, - ) - self.tanh = nn.Tanh() - - self.identity1 = Identity() # for diagnostics. - self.identity2 = Identity() # for diagnostics. - self.identity3 = Identity() # for diagnostics. - - self.out_proj = ScaledLinear( - hidden_channels, channels, bias=True, initial_scale=0.05 - ) + """ + NonlinAttention initialization. + + Parameters + ---------- + embed_dim : int + The input and output embedding dimension. The number of channels is the same for input + and output of this module. + att_dim : int + The attention output dimension of this module. 
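# The SwooshL and SwooshR activations used above (SwooshL in FeedforwardModule, SwooshR in
# ConvolutionModule) are defined in scaling.py. A sketch of their closed forms, quoted from
# the Zipformer paper from memory; verify the exact constants against scaling.py before
# relying on them.
import torch

def swoosh_l(x: torch.Tensor) -> torch.Tensor:
    return torch.log1p(torch.exp(x - 4.0)) - 0.08 * x - 0.035

def swoosh_r(x: torch.Tensor) -> torch.Tensor:
    return torch.log1p(torch.exp(x - 1.0)) - 0.08 * x - 0.313261687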
+ right_context : int + The module look ahead future context, used to update left cache + correctly. + device : torch.device + The device used to store the positional embeddings. + Should be either torch.device("cpu") or torch.device("cuda"). + """ - self.whiten1 = Whiten( - num_groups=1, - whitening_limit=_whitening_schedule(5.0), - prob=(0.025, 0.25), - grad_scale=0.01, - ) + super().__init__() - self.whiten2 = Whiten( - num_groups=1, - whitening_limit=_whitening_schedule(5.0, ratio=3.0), - prob=(0.025, 0.25), - grad_scale=0.01, - ) + self.in_proj = torch.nn.Linear(embed_dim, att_dim * 3, device=device) + self.out_proj = torch.nn.Linear(att_dim, embed_dim, device=device) + self.right_context = right_context def forward( - self, - x: Tensor, - attn_weights: Tensor, - ) -> Tensor: - """. - Args: - x: a Tensor of shape (seq_len, batch_size, num_channels) - attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) - Returns: - a Tensor with the same shape as x + self, x: torch.Tensor, attn_weights: torch.Tensor, left_cached_x: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: """ - x = self.in_proj(x) - - (seq_len, batch_size, _) = x.shape - hidden_channels = self.hidden_channels - - s, x, y = x.chunk(3, dim=2) - - # s will go through tanh. - - s = self.balancer(s) - s = self.tanh(s) - - s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels) - x = self.whiten1(x) - x = x * s - x = self.identity1(x) # diagnostics only, it's the identity. - - (seq_len, batch_size, embed_dim) = x.shape - num_heads = attn_weights.shape[0] - assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len) - - x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) - # now x: (num_heads, batch_size, seq_len, head_dim) - x = torch.matmul(attn_weights, x) - # now x: (num_heads, batch_size, seq_len, head_dim) - x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1) - - y = self.identity2(y) - x = x * y - x = self.identity3(x) - - x = self.out_proj(x) - x = self.whiten2(x) - return x - - def streaming_forward( - self, - x: Tensor, - attn_weights: Tensor, - cached_x: Tensor, - left_context_len: int, - ) -> Tuple[Tensor, Tensor]: - """. - Args: - x: a Tensor of shape (seq_len, batch_size, num_channels) - attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) - cached_x: left context, a Tensor of shape - (num_heads, batch_size, left_context_len, head_dim) - left_context_len: number of left context frames. - Returns: - - a Tensor with the same shape as x - - updated left context with same shape as cached_x + Does a forward pass of the NonlinAttention module. Returns attention weighted input tensor + and updated attention input tensor cache of the left context. + + Parameters + ---------- + x : torch.Tensor[torch.float32] + An input float tensor of shape (1, seq_len, embed_dim). + attn_weights : torch.Tensor[torch.float32] + A tensor of shape (1, seq_len, seq_len_2), that corresponds to a single attention head + with (seq_len, seq_len_2) being interpreted as (tgt_seq_len, src_seq_len). + Expected attn_weights.sum(dim=2) == 1.0. + Note: the first dimension here corresponds to a batch size. + left_cached_x : torch.Tensor[torch.float32] + A cached attention tensor of the left context of shape (1, left_context_len, att_dim). + + Returns + ------- + tuple[torch.Tensor[torch.float32], torch.Tensor[torch.float32]] + A tuple of two float tensors: + - attention weighted output of shape (1, seq_len, embed_dim). + A tensor with the same shape as input x. 
+ - updated cached attention tensor of the left context + of shape (1, left_context_len, att_dim). """ - x = self.in_proj(x) - (seq_len, batch_size, _) = x.shape - hidden_channels = self.hidden_channels + x = self.in_proj(x) s, x, y = x.chunk(3, dim=2) - # s will go through tanh. - s = self.tanh(s) + x = x * torch.tanh(s) - s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels) - x = x * s - - (seq_len, batch_size, embed_dim) = x.shape - num_heads = attn_weights.shape[0] - assert attn_weights.shape == ( - num_heads, - batch_size, - seq_len, - left_context_len + seq_len, - ) - - x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) - # now x: (num_heads, batch_size, seq_len, head_dim) - - # Pad cached tensor - assert cached_x.shape[2] == left_context_len, ( - cached_x.shape[2], - left_context_len, - ) - x_pad = torch.cat([cached_x, x], dim=2) + x = torch.cat((left_cached_x, x), dim=1) # Update cached tensor - cached_x = x_pad[:, :, -left_context_len:, :] - - x = torch.matmul(attn_weights, x_pad) - # now x: (num_heads, batch_size, seq_len, head_dim) - x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1) + left_cached_x = x[ + :, + x.size(1) - self.right_context - left_cached_x.size(1): + x.size(1) - self.right_context, + ] + # (1, seq_len, seq_len_2) x (1, seq_len_2, att_dim) -> (1, seq_len, att_dim) + x = torch.matmul(attn_weights, x) x = x * y x = self.out_proj(x) - return x, cached_x - -class ConvolutionModule(nn.Module): - """ConvolutionModule in Zipformer2 model. - Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/zipformer/convolution.py + return x, left_cached_x - Args: - channels (int): The number of channels of conv layers. - kernel_size (int): Kernerl size of conv layers. - bias (bool): Whether to use bias in conv layers (default=True). +class ConvolutionModule(torch.nn.Module): + """ + ConvolutionModule in Zipformer2 encoder. """ def __init__( - self, - channels: int, - kernel_size: int, - causal: bool, + self, embed_dim: int, kernel_size: int, right_context: int, device: torch.device, ) -> None: - """Construct a ConvolutionModule object.""" - super(ConvolutionModule, self).__init__() - # kernerl_size should be a odd number for 'SAME' padding - assert (kernel_size - 1) % 2 == 0 - - bottleneck_dim = channels - self.causal = causal - - self.in_proj = nn.Linear( - channels, - 2 * bottleneck_dim, - ) - # the gradients on in_proj are a little noisy, likely to do with the - # sigmoid in glu. - - # after in_proj we put x through a gated linear unit (nn.functional.glu). - # For most layers the normal rms value of channels of x seems to be in the range 1 to 4, - # but sometimes, for some reason, for layer 0 the rms ends up being very large, - # between 50 and 100 for different channels. This will cause very peaky and - # sparse derivatives for the sigmoid gating function, which will tend to make - # the loss function not learn effectively. (for most layers the average absolute values - # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion, - # at the output of pointwise_conv1.output is around 0.35 to 0.45 for different - # layers, which likely breaks down as 0.5 for the "linear" half and - # 0.2 to 0.3 for the part that goes into the sigmoid. The idea is that if we - # constrain the rms values to a reasonable range via a constraint of max_abs=10.0, - # it will be in a better position to start learning something, i.e. to latch onto - # the correct range. 
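# A minimal sketch, with toy sizes and the cache handling omitted, of the three-way split used
# by NonlinAttention above: in_proj produces 3 * att_dim channels that are split into a tanh
# gate s, the value stream x, and a second multiplicative gate y applied after the
# attention-weighted sum.
import torch

seq_len, embed_dim, att_dim = 6, 16, 8
in_proj = torch.nn.Linear(embed_dim, 3 * att_dim)
out_proj = torch.nn.Linear(att_dim, embed_dim)
attn_weights = torch.softmax(torch.randn(1, seq_len, seq_len), dim=2)  # one attention head

h = in_proj(torch.randn(1, seq_len, embed_dim))
s, x, y = h.chunk(3, dim=2)
x = torch.matmul(attn_weights, x * torch.tanh(s)) * y
print(out_proj(x).shape)  # torch.Size([1, 6, 16])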
- self.balancer1 = Balancer( - bottleneck_dim, - channel_dim=-1, - min_positive=ScheduledFloat((0.0, 0.05), (8000.0, 0.025)), - max_positive=1.0, - min_abs=1.5, - max_abs=ScheduledFloat((0.0, 5.0), (8000.0, 10.0), default=1.0), - ) - - self.activation1 = Identity() # for diagnostics - - self.sigmoid = nn.Sigmoid() - - self.activation2 = Identity() # for diagnostics + """ + ConvolutionModule initialization. + + Parameters + ---------- + embed_dim : int + The input and output embedding dimension, also the number of channels of convolution + modules. The embedding dmension is the same for input and output of this module. + kernel_size : int + The kernel size of the depthwise convolution module. + right_context : int + The module look ahead future context, used to update + causal depthwise convolution left cache correctly. + device : torch.device + The device used to store the layer weights. Should be + either torch.device("cpu") or torch.device("cuda"). + """ - assert kernel_size % 2 == 1 + super().__init__() - self.depthwise_conv = ( - ChunkCausalDepthwiseConv1d(channels=bottleneck_dim, kernel_size=kernel_size) - if causal - else nn.Conv1d( - in_channels=bottleneck_dim, - out_channels=bottleneck_dim, - groups=bottleneck_dim, - kernel_size=kernel_size, - padding=kernel_size // 2, + if kernel_size % 2 == 0: + raise ValueError( + 'ConvolutionModule kernerl size should be ' + f'an odd number but got {kernel_size} instead.', ) - ) - - self.balancer2 = Balancer( - bottleneck_dim, - channel_dim=1, - min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)), - max_positive=1.0, - min_abs=ScheduledFloat((0.0, 0.2), (20000.0, 0.5)), - max_abs=10.0, - ) - self.whiten = Whiten( - num_groups=1, - whitening_limit=_whitening_schedule(7.5), - prob=(0.025, 0.25), - grad_scale=0.01, + self.in_proj = torch.nn.Linear(embed_dim, 2 * embed_dim, device=device) + self.depthwise_conv = ChunkCausalDepthwiseConv1d( + embed_dim, kernel_size, right_context, device, ) - self.out_proj = ActivationDropoutAndLinear( - bottleneck_dim, - channels, - activation="SwooshR", - dropout_p=0.0, - initial_scale=0.05, - ) + self.activation = SwooshR() + self.out_proj = torch.nn.Linear(embed_dim, embed_dim, device=device) def forward( - self, - x: Tensor, - src_key_padding_mask: Optional[Tensor] = None, - chunk_size: int = -1, - ) -> Tensor: - """Compute convolution module. - - Args: - x: Input tensor (#time, batch, channels). - src_key_padding_mask: the mask for the src keys per batch (optional): - (batch, #time), contains True in masked positions. - - Returns: - Tensor: Output tensor (#time, batch, channels). - + self, x: torch.Tensor, left_cache: torch.Tensor, src_key_padding_mask: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: """ - - x = self.in_proj(x) # (time, batch, 2*channels) - - x, s = x.chunk(2, dim=2) - s = self.balancer1(s) - s = self.sigmoid(s) - x = self.activation1(x) # identity. - x = x * s - x = self.activation2(x) # identity - - # (time, batch, channels) - - # exchange the temporal dimension and the feature dimension - x = x.permute(1, 2, 0) # (#batch, channels, time). 
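# A minimal sketch, with toy shapes, showing that the split-and-sigmoid gating used in
# ConvolutionModule.forward() (in both the removed and the added version) is exactly a gated
# linear unit (GLU) along the channel dimension.
import torch

x = torch.randn(1, 5, 2 * 16)  # output of in_proj: (1, seq_len, 2 * embed_dim)
a, s = x.chunk(2, dim=2)
assert torch.allclose(a * torch.sigmoid(s), torch.nn.functional.glu(x, dim=2))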
- - if src_key_padding_mask is not None: - x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) - - if ( - not torch.jit.is_scripting() - and not torch.jit.is_tracing() - and chunk_size >= 0 - ): - # Not support exporting a model for simulated streaming decoding - assert ( - self.causal - ), "Must initialize model with causal=True if you use chunk_size" - x = self.depthwise_conv(x, chunk_size=chunk_size) - else: - x = self.depthwise_conv(x) - - x = self.balancer2(x) - x = x.permute(2, 0, 1) # (time, batch, channels) - - x = self.whiten(x) # (time, batch, channels) - x = self.out_proj(x) # (time, batch, channels) - - return x - - def streaming_forward( - self, - x: Tensor, - cache: Tensor, - src_key_padding_mask: Tensor, - ) -> Tuple[Tensor, Tensor]: - """Compute convolution module in streaming forward mode. - - Args: - x: Input tensor (#time, batch, channels). - cache: cached left context for depthwise_conv of shape - (#batch, channels, left_pad) - src_key_padding_mask: the mask for the src keys per batch (optional): - (batch, #time), contains True in masked positions. - - Returns: - - Output tensor (#time, batch, channels). - - Updated cache (#batch, channels, left_pad) + Does a forward pass of the ConvolutionModule module. Returns processed tensor of the same + shape as input and updated cached convolution tensor of the left context. + + Parameters + ---------- + x : torch.Tensor[torch.float32] + The input float tensor of shape (1, seq_len, embed_dim). The module input. + left_cache : torch.Tensor[torch.float32] + A cached convolution tensor of the left context + of shape (1, embed_dim, left_cache_len). + src_key_padding_mask : torch.Tensor[torch.bool] + The mask for the source keys of shape (1, seq_len), + contains True in masked positions that will be ignored. + + Returns + ------- + tuple[torch.Tensor[torch.float32], torch.Tensor[torch.float32]] + A tuple of two float tensors: + - module output of shape (1, seq_len, embed_dim). + A tensor with the same shape as input x. + - updated cached convolution tensor of the left context + of shape (1, embed_dim, left_cache_len). """ - x = self.in_proj(x) # (time, batch, 2*channels) + x = self.in_proj(x) # (1, seq_len, 2 * embed_dim) x, s = x.chunk(2, dim=2) - s = self.sigmoid(s) - x = x * s - # (time, batch, channels) + x = x * torch.sigmoid(s) # (1, seq_len, embed_dim) - # exchange the temporal dimension and the feature dimension - x = x.permute(1, 2, 0) # (#batch, channels, time). + x = torch.masked_fill(x, src_key_padding_mask.unsqueeze(2), 0.0) - if src_key_padding_mask is not None: - x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) + # exchange the temporal dimension and the feature dimension for depthwise convolution. + x = x.permute(0, 2, 1) # (1, embed_dim, seq_len). + x, left_cache = self.depthwise_conv(x, left_cache) + x = x.permute(0, 2, 1) # (1, seq_len, embed_dim) - x, cache = self.depthwise_conv.streaming_forward(x, cache=cache) - - x = x.permute(2, 0, 1) # (time, batch, channels) - - x = self.out_proj(x) # (time, batch, channels) - - return x, cache - - -class ScalarMultiply(nn.Module): - def __init__(self, scale: float): - super().__init__() - self.scale = scale + x = self.activation(x) + x = self.out_proj(x) # (1, seq_len, embed_dim) - def forward(self, x): - return x * self.scale + return x, left_cache def _test_zipformer_main(causal: bool = False):