modified zipformer #1774

Closed · wants to merge 1 commit
225 changes: 136 additions & 89 deletions egs/librispeech/ASR/zipformer/decoder.py
@@ -20,115 +20,162 @@
from scaling import Balancer


class Decoder(nn.Module):
"""This class modifies the stateless decoder from the following paper:

RNN-transducer with stateless prediction network
https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419

It removes the recurrent connection from the decoder, i.e., the prediction
network. Different from the above paper, it adds an extra Conv1d
right after the embedding layer.
class Decoder(torch.nn.Module):
"""
This class modifies the stateless decoder from the following paper:
RNN-transducer with stateless prediction network
https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419

TODO: Implement https://arxiv.org/pdf/2109.07513.pdf
It removes the recurrent connection from the decoder, i.e., the prediction network.
Different from the above paper, it adds an extra Conv1d right after the embedding layer.
"""

def __init__(
self,
vocab_size: int,
decoder_dim: int,
blank_id: int,
context_size: int,
):
self, vocab_size: int, decoder_dim: int, context_size: int, device: torch.device,
) -> None:
"""
Args:
vocab_size:
Number of tokens of the modeling unit including blank.
decoder_dim:
Dimension of the input embedding, and of the decoder output.
blank_id:
The ID of the blank symbol.
context_size:
Number of previous words to use to predict the next word.
Decoder initialization.

Parameters
----------
vocab_size : int
The number of tokens, or modeling units, including blank.
decoder_dim : int
The dimension of the decoder embeddings and of the decoder output.
context_size : int
The number of previous words used to predict the next word.
1 means bigram; 2 means trigram. n means (n+1)-gram.
device : torch.device
The device used to store the layer weights. Should be
either torch.device("cpu") or torch.device("cuda").
"""
super().__init__()

self.embedding = nn.Embedding(
num_embeddings=vocab_size,
embedding_dim=decoder_dim,
)
# the balancers are to avoid any drift in the magnitude of the
# embeddings, which would interact badly with parameter averaging.
self.balancer = Balancer(
decoder_dim,
channel_dim=-1,
min_positive=0.0,
max_positive=1.0,
min_abs=0.5,
max_abs=1.0,
prob=0.05,
)
super().__init__()

self.blank_id = blank_id
self.embedding = torch.nn.Embedding(vocab_size, decoder_dim)

assert context_size >= 1, context_size
if context_size < 1:
raise ValueError(
'RNN-T decoder context size should be an integer greater '
f'than or equal to 1, but got {context_size}.',
)
self.context_size = context_size
self.vocab_size = vocab_size

if context_size > 1:
self.conv = nn.Conv1d(
in_channels=decoder_dim,
out_channels=decoder_dim,
kernel_size=context_size,
padding=0,
groups=decoder_dim // 4, # group size == 4
bias=False,
)
self.balancer2 = Balancer(
decoder_dim,
channel_dim=-1,
min_positive=0.0,
max_positive=1.0,
min_abs=0.5,
max_abs=1.0,
prob=0.05,
)
else:
# To avoid `RuntimeError: Module 'Decoder' has no attribute 'conv'`
# when running inference with torch.jit.script and context_size == 1
self.conv = nn.Identity()
self.balancer2 = nn.Identity()
self.conv = torch.nn.Conv1d(
decoder_dim,
decoder_dim,
context_size,
groups=decoder_dim // 4,
bias=False,
device=device,
)
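
For intuition, here is a small, self-contained shape check for the grouped Conv1d defined above: with kernel_size equal to context_size and no padding, a (N, decoder_dim, context_size) input collapses to a single output frame, and groups = decoder_dim // 4 leaves each filter mixing only four channels. The batch size and decoder_dim below are illustrative assumptions, not values from this PR.

import torch

decoder_dim, context_size = 512, 2
conv = torch.nn.Conv1d(
    decoder_dim,
    decoder_dim,
    kernel_size=context_size,
    groups=decoder_dim // 4,  # 128 groups -> 4 channels per group
    bias=False,
)
x = torch.randn(8, decoder_dim, context_size)  # (N, decoder_dim, context_size)
assert conv(x).shape == (8, decoder_dim, 1)  # one output frame per hypothesis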

def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
def forward(self, y: torch.Tensor) -> torch.Tensor:
"""
Args:
y:
A 2-D tensor of shape (N, U).
need_pad:
True to left pad the input. Should be True during training.
False to not pad the input. Should be False during inference.
Returns:
Return a tensor of shape (N, U, decoder_dim).
Does a forward pass of the stateless Decoder module and returns the decoder output tensor.

Parameters
----------
y : torch.Tensor[torch.int32]
The input integer tensor of shape (N, context_size).
The module input, containing the last context_size decoded token indexes.

Returns
-------
torch.Tensor[torch.float32]
An output float tensor of shape (N, 1, decoder_dim).
"""
y = y.to(torch.int64)
# The clamp() below is a temporary fix for a mismatch at utterance
# start: we use negative ids in beam_search.py.
embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1)

embedding_out = self.balancer(embedding_out)
# The clamp() below is a fix for a mismatch at utterance start:
# we use negative ids in RNN-T decoding.
embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(2)

if self.context_size > 1:
embedding_out = embedding_out.permute(0, 2, 1)
if need_pad is True:
embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0))
else:
# During inference time, there is no need to do extra padding
# as we only need one output
assert embedding_out.size(-1) == self.context_size
embedding_out = self.conv(embedding_out)
embedding_out = embedding_out.permute(0, 2, 1)
embedding_out = F.relu(embedding_out)
embedding_out = self.balancer2(embedding_out)
embedding_out = torch.nn.functional.relu(embedding_out)

return embedding_out
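
The clamp()-and-mask idiom above merits a short illustration: at utterance start the decoding code feeds negative token ids as context padding, so the embedding is looked up on clamp(min=0) and then zeroed by the mask rather than indexing with a negative id. A minimal sketch; the vocabulary size and embedding dimension are assumptions for the example only.

import torch

embedding = torch.nn.Embedding(500, 16)
y = torch.tensor([[-1, 25]])  # (N=1, context_size=2): a padded slot, then a real token
out = embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(2)
assert torch.all(out[0, 0] == 0)  # the padded slot contributes nothing
assert torch.any(out[0, 1] != 0)  # the real token keeps its embedding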


class DecoderModule(torch.nn.Module):
"""
A helper module that combines the decoder, decoder projection, and joiner for inference.
"""

def __init__(
self,
vocab_size: int,
decoder_dim: int,
joiner_dim: int,
context_size: int,
beam: int,
device: torch.device,
) -> None:
"""
DecoderModule initialization.

Parameters
----------
vocab_size : int
The number of tokens, or modeling units, including blank.
decoder_dim : int
The dimension of the decoder embeddings and of the decoder output.
joiner_dim : int
Input joiner dimension.
context_size : int
The number of previous words used to predict the next word.
1 means bigram; 2 means trigram. n means (n+1)-gram.
beam : int
The decoder beam width, i.e. the number of hypotheses kept at each decoding step.
device : torch.device
The device used to store the layer weights. Should be
either torch.device("cpu") or torch.device("cuda").
"""

super().__init__()

self.decoder = Decoder(vocab_size, decoder_dim, context_size, device)
self.decoder_proj = torch.nn.Linear(decoder_dim, joiner_dim, device=device)
self.joiner = Joiner(joiner_dim, vocab_size, device)

self.vocab_size = vocab_size
self.beam = beam

def forward(
self, decoder_input: torch.Tensor, encoder_out: torch.Tensor, hyps_log_prob: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Runs one beam-search expansion step: a forward pass of the decoder, the decoder
projection, and the joiner, followed by top-k selection over the extended hypotheses.

Parameters
----------
decoder_input : torch.Tensor[torch.int32]
The input integer tensor of shape (num_hyps, context_size).
The module input that corresponds to the last context_size decoded token indexes.
encoder_out : torch.Tensor[torch.float32]
An output tensor from the encoder after projection of shape (num_hyps, joiner_dim).
hyps_log_prob : torch.Tensor[torch.float32]
Hypothesis log-probabilities of shape (num_hyps, 1).

Returns
-------
tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
The log probabilities of the top beam extended hypotheses, the probabilities
of the corresponding top tokens, and the int32 hypothesis and token indexes
that identify each surviving candidate; all four tensors have shape (beam,).
"""

decoder_out = self.decoder(decoder_input)
decoder_out = self.decoder_proj(decoder_out)

logits = self.joiner(encoder_out, decoder_out[:, 0, :])

tokens_log_prob = torch.log_softmax(logits, dim=1)
log_probs = (tokens_log_prob + hyps_log_prob).reshape(-1)

hyps_topk_log_prob, topk_indexes = log_probs.topk(self.beam)
topk_hyp_indexes = torch.floor_divide(topk_indexes, self.vocab_size).to(torch.int32)
topk_token_indexes = torch.remainder(topk_indexes, self.vocab_size).to(torch.int32)
tokens_topk_prob = torch.exp(tokens_log_prob.reshape(-1)[topk_indexes])

return hyps_topk_log_prob, tokens_topk_prob, topk_hyp_indexes, topk_token_indexes
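
The index bookkeeping at the end of this forward pass flattens the (num_hyps, vocab_size) score matrix, takes a single global topk, and then recovers which hypothesis and which token produced each surviving candidate via floor_divide and remainder. A self-contained sketch of just that arithmetic; all sizes are illustrative.

import torch

num_hyps, vocab_size, beam = 3, 7, 4
log_probs = torch.randn(num_hyps, vocab_size).reshape(-1)
topk_log_prob, topk_indexes = log_probs.topk(beam)
topk_hyp_indexes = torch.floor_divide(topk_indexes, vocab_size)  # row: hypothesis extended
topk_token_indexes = torch.remainder(topk_indexes, vocab_size)   # column: extending token
# each (hypothesis, token) pair round-trips to its flat index
assert torch.equal(topk_hyp_indexes * vocab_size + topk_token_indexes, topk_indexes)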
73 changes: 33 additions & 40 deletions egs/librispeech/ASR/zipformer/joiner.py
@@ -19,49 +19,42 @@
from scaling import ScaledLinear


class Joiner(nn.Module):
def __init__(
self,
encoder_dim: int,
decoder_dim: int,
joiner_dim: int,
vocab_size: int,
):
super().__init__()

self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim, initial_scale=0.25)
self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim, initial_scale=0.25)
self.output_linear = nn.Linear(joiner_dim, vocab_size)
class Joiner(torch.nn.Module):

def forward(
self,
encoder_out: torch.Tensor,
decoder_out: torch.Tensor,
project_input: bool = True,
) -> torch.Tensor:
def __init__(self, joiner_dim: int, vocab_size: int, device: torch.device) -> None:
"""
Args:
encoder_out:
Output from the encoder. Its shape is (N, T, s_range, C).
decoder_out:
Output from the decoder. Its shape is (N, T, s_range, C).
project_input:
If true, apply input projections encoder_proj and decoder_proj.
If this is false, it is the user's responsibility to do this
manually.
Returns:
Return a tensor of shape (N, T, s_range, C).
Joiner initialization.

Parameters
----------
joiner_dim : int
Input joiner dimension.
vocab_size : int
Output joiner dimension: the vocabulary size, i.e. the number of BPE tokens of the model.
device : torch.device
The device used to store the layer weights. Should be
either torch.device("cpu") or torch.device("cuda").
"""
assert encoder_out.ndim == decoder_out.ndim, (
encoder_out.shape,
decoder_out.shape,
)

if project_input:
logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
else:
logit = encoder_out + decoder_out
super().__init__()

self.output_linear = torch.nn.Linear(joiner_dim, vocab_size, device=device)

logit = self.output_linear(torch.tanh(logit))
def forward(self, encoder_out: torch.Tensor, decoder_out: torch.Tensor) -> torch.Tensor:
"""
Does a forward pass of the Joiner module and returns the output tensor after a simple additive joining.

Parameters
----------
encoder_out : torch.Tensor[torch.float32]
An output tensor from the encoder after projection of shape (N, joiner_dim).
decoder_out : torch.Tensor[torch.float32]
An output tensor from the decoder after projection of shape (N, joiner_dim).

Returns
-------
torch.Tensor[torch.float32]
A float tensor of unnormalized token logits of shape (N, vocab_size).
"""

return logit
return self.output_linear(torch.tanh(encoder_out + decoder_out))
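
A short usage sketch of the simplified joiner: both inputs are assumed to have been projected to joiner_dim already (by the encoder projection and decoder_proj elsewhere in this PR), so the forward pass reduces to add, tanh, and a final vocabulary projection. The dimensions are illustrative.

import torch

joiner_dim, vocab_size = 512, 500
output_linear = torch.nn.Linear(joiner_dim, vocab_size)

encoder_out = torch.randn(8, joiner_dim)  # already-projected encoder frames
decoder_out = torch.randn(8, joiner_dim)  # already-projected decoder output

# mirrors Joiner.forward above; the result is unnormalized logits,
# which DecoderModule.forward later normalizes with log_softmax
logits = output_linear(torch.tanh(encoder_out + decoder_out))
assert logits.shape == (8, vocab_size)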