---
layout: post
title: "A round-up of linear transformers"
tags:
mathjax: true
---
Transformers are ubiquitous in deep learning today. First proposed in the famous "Attention is all you need" paper by Vaswani et al. for the task of neural machine translation, they soon gained popularity in NLP and formed the backbone of strong pre-trained language models like BERT and GPT. Since then, they have been adopted extensively in speech tasks (see my [other post]({% post_url 2020-01-08-transformer-asr %}) on the challenges of using transformers in ASR), and more recently in computer vision, with the introduction of the ViT model.
The workhorse of the transformer architecture is the multi-head self-attention (MHSA) layer. Here, "self-attention" is a way of routing information in a sequence using the same sequence as the guiding mechanism (hence the "self"), and when this process is carried out several times in parallel, i.e., with multiple "heads", it is called MHSA. I will not go into details about the transformer in this post --- it has already been covered in much visual detail by Jay Alammar and annotated with code by Sasha Rush. Yannic Kilcher has also covered the paper in his Papers Explained series. If you are not familiar with the model or with self-attention in general, I would suggest that you check out those links before reading further.
Self-attention is simply a method to transform an input sequence using signals from the same sequence. Suppose we have an input sequence $X \in \mathbb{R}^{n \times d}$, where $n$ is the sequence length and $d$ is the embedding dimension. Self-attention first computes queries, keys, and values as $Q = XW_Q$, $K = XW_K$, and $V = XW_V$, and then transforms the sequence as

$$ V^{\prime} = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V. $$

Here, the softmax operation is applied per-row of the $n \times n$ matrix $\frac{QK^T}{\sqrt{d}}$. Computing and storing this matrix is what makes self-attention quadratic, i.e., $O(n^2)$ in both time and memory, which becomes a bottleneck for long sequences.
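To make the quadratic cost concrete, here is a minimal single-head attention sketch in PyTorch (ignoring batching, masking, dropout, and multiple heads); the $n \times n$ `scores` matrix is exactly what all the methods below try to avoid materializing.

```python
import torch

def softmax_attention(Q, K, V):
    """Standard (quadratic) self-attention for a single head.

    Q, K, V: tensors of shape (n, d), where n is the sequence length.
    Returns a tensor of shape (n, d).
    """
    d = Q.shape[-1]
    # The n x n attention matrix is the source of the O(n^2) cost.
    scores = Q @ K.transpose(-2, -1) / d ** 0.5   # (n, n)
    P = torch.softmax(scores, dim=-1)             # row-wise softmax
    return P @ V                                  # (n, d)

n, d = 1024, 64
Q, K, V = (torch.randn(n, d) for _ in range(3))
out = softmax_attention(Q, K, V)   # shape (1024, 64)
```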
There has been a long line of research on making transformers "efficient" --- too long, in fact, to be covered in one blog post. This paper provides a great review of these methods. In this post, I will focus on methods which make the self-attention mechanism linear, i.e., they reduce the complexity from $O(n^2)$ to $O(n)$ in the sequence length. These fall into three broad categories:
- Methods based on low-rank approximation
- Methods based on local-global attention
- Methods using softmax as a kernel
In the remaining part of this post, I will discuss papers falling under each of these categories. Please note that this is my attempt at understanding how these different "efficient transformers" relate to each other. I may be wrong about some methods --- in which case, please feel free to correct me in the comments. A special shout-out to Yannic Kilcher's YT videos which helped me understand the details in several of the papers mentioned below.
In the case of multi-head self-attention, the embedding dimensionality $d$ (or $d/h$ per head, for $h$ heads) is usually much smaller than the sequence length $n$, so the $n \times n$ matrix $QK^T$ has rank at most $d \ll n$. Empirically, the attention matrix obtained after the softmax is also approximately low rank, and the methods in this category exploit exactly this observation.
Linformer was perhaps the first paper which used the above observation to linearize self-attention. Suppose we denote the $n \times n$ attention matrix by $P = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)$. Since $P$ is approximately low rank, we could in principle compute its SVD and keep only the top few singular values, but that would require performing an expensive SVD operation on the low-rank matrix $P$ in every attention layer, which defeats the purpose.
The authors then used the Johnson–Lindenstrauss (JL) lemma to claim that there exists a low-rank matrix which approximates $P$ well and can be obtained with random projections, without ever computing an SVD. Informally, the JL lemma says the following:
If we use a random projection matrix to project a set of points onto a lower dimension, the pairwise distances are approximately preserved.
To put theory into practice, the authors projected the key and value matrices $K, V \in \mathbb{R}^{n \times d}$ along the sequence dimension using learned projection matrices $E, F \in \mathbb{R}^{k \times n}$, with $k \ll n$. The attention computation then becomes

$$ V^{\prime} = \text{softmax}\left(\frac{Q(EK)^T}{\sqrt{d}}\right) FV, $$

where the attention matrix is now only $n \times k$, so the overall cost is $O(nk)$ instead of $O(n^2)$.
The only remaining question is how to choose an appropriate projected dimension $k$. The theory says that $k$ can be chosen independent of the sequence length $n$ (on the order of $d/\epsilon^2$ for an approximation error $\epsilon$), and in practice the authors found that small values such as 128 or 256 work well even for long sequences.
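Here is a minimal sketch of the Linformer idea (not the authors' implementation; the module and parameter names are mine). The learned projections `E` and `F` map the sequence dimension from $n$ down to a fixed $k$, which also means the module is tied to a particular maximum sequence length:

```python
import torch
import torch.nn as nn

class LinformerAttention(nn.Module):
    """Single-head Linformer-style attention (illustrative sketch).

    Keys and values of shape (n, d) are projected along the sequence
    dimension to (k, d) before attention, so the attention matrix is
    only (n, k) instead of (n, n).
    """
    def __init__(self, n, d, k=128):
        super().__init__()
        self.E = nn.Linear(n, k, bias=False)  # projects keys:   (n, d) -> (k, d)
        self.F = nn.Linear(n, k, bias=False)  # projects values: (n, d) -> (k, d)
        self.d = d

    def forward(self, Q, K, V):
        K_proj = self.E(K.transpose(0, 1)).transpose(0, 1)   # (k, d)
        V_proj = self.F(V.transpose(0, 1)).transpose(0, 1)   # (k, d)
        scores = Q @ K_proj.transpose(0, 1) / self.d ** 0.5  # (n, k)
        P = torch.softmax(scores, dim=-1)
        return P @ V_proj                                     # (n, d)

n, d = 1024, 64
attn = LinformerAttention(n, d, k=128)
out = attn(torch.randn(n, d), torch.randn(n, d), torch.randn(n, d))
```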
This paper also uses the low-rank observation, but instead of using JL projections, it uses the Nystrom method for approximating $P = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)$. The standard Nystrom method approximates a matrix from a small subset of its rows and columns, but it cannot be applied to $P$ directly: because of the row-wise softmax operation (elements of $P$ depend on entire rows of $QK^T$ through the normalization), computing even a few columns of $P$ would require materializing the full $n \times n$ matrix.

The workaround suggested in the paper is to perform the approximation inside the softmax. First, $m$ "landmark" queries $\tilde{Q}$ and keys $\tilde{K}$ are computed as segment means of $Q$ and $K$, and then the softmax operation is applied on the smaller matrices built from these landmarks. In summary, suppose

$$ F = \text{softmax}\left(\frac{Q\tilde{K}^T}{\sqrt{d}}\right), \quad A = \text{softmax}\left(\frac{\tilde{Q}\tilde{K}^T}{\sqrt{d}}\right), \quad B = \text{softmax}\left(\frac{\tilde{Q}K^T}{\sqrt{d}}\right). $$

The attention matrix is then approximated as $\hat{P} = F A^{+} B$, where $A^{+}$ is the Moore–Penrose pseudoinverse of $A$ (itself computed with a cheap iterative approximation in the paper), and the output $\hat{P}V$ can be computed in $O(nm)$ time and memory.
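A rough sketch of this approximation in PyTorch, assuming landmarks are taken as segment means and using `torch.linalg.pinv` in place of the paper's iterative pseudoinverse; function and variable names are mine:

```python
import torch

def nystrom_attention(Q, K, V, num_landmarks=32):
    """Nystrom-style approximation of softmax attention (illustrative).

    Q, K, V: (n, d) tensors, with n divisible by num_landmarks.
    """
    n, d = Q.shape
    scale = d ** 0.5
    # Landmarks: segment means of the queries and keys.
    Q_tilde = Q.reshape(num_landmarks, n // num_landmarks, d).mean(dim=1)  # (m, d)
    K_tilde = K.reshape(num_landmarks, n // num_landmarks, d).mean(dim=1)  # (m, d)

    F = torch.softmax(Q @ K_tilde.T / scale, dim=-1)        # (n, m)
    A = torch.softmax(Q_tilde @ K_tilde.T / scale, dim=-1)  # (m, m)
    B = torch.softmax(Q_tilde @ K.T / scale, dim=-1)        # (m, n)

    # P_hat = F @ pinv(A) @ B approximates softmax(QK^T / sqrt(d)).
    return F @ torch.linalg.pinv(A) @ (B @ V)               # (n, d)
```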
The second class of methods "sparsifies" attention computation by restricting how many tokens in the sequence each token attends to. Often, such a selection is made using knowledge of the task at hand, i.e., these methods inject some inductive biases into the attention modeling.
The idea behind the Longformer can most easily be understood from the following figure, taken from the paper:
Figure (a) shows the self-attention pattern in the standard transformer, where every token attends to every other token. If we restrict each token to only attend to a window of $w$ tokens around it, we get the sliding-window pattern in (b), which costs only $O(nw)$. Pattern (c) is a dilated version of the sliding window, which enlarges the receptive field without extra computation, similar to dilated convolutions. Finally, pattern (d) adds "global attention" on a few task-specific tokens which attend to, and are attended by, the entire sequence: the [CLS] token is used for global attention in classification tasks, while for QA, all the question tokens receive global attention.
An important detail in the Longformer paper is the implementation of this attention pattern. The authors provide a custom CUDA kernel for the "banded" matrix multiplication, since it cannot be implemented efficiently using existing functions in PyTorch or TensorFlow. Their implementation is available here.
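For illustration, here is a straightforward (and still $O(n^2)$-memory) way to emulate the sliding-window pattern with masking; the whole point of the custom kernel is to avoid ever building this full mask and score matrix. Function names are mine.

```python
import torch

def sliding_window_attention(Q, K, V, window=2):
    """Masked attention where token i attends only to tokens within
    `window` positions on either side. Illustrative only: a real
    implementation (e.g. Longformer's CUDA kernel) never materializes
    the full n x n matrix.
    """
    n, d = Q.shape
    idx = torch.arange(n)
    mask = (idx[None, :] - idx[:, None]).abs() <= window   # (n, n) boolean band
    scores = Q @ K.T / d ** 0.5
    scores = scores.masked_fill(~mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ V
```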
The core idea of BigBird is very similar to the Longformer, and is shown in the figure below, taken from the paper:
Similar to the Longformer, BigBird uses a windowed attention and a selective global attention. Additionally, it uses a "random attention", where each token in the sequence attends to a few randomly selected tokens (in addition to the global tokens and those in its window). More importantly, the authors show that this attention pattern has the same expressivity as standard full self-attention, both theoretically and empirically. In particular, they show two main things:
- Sparse attention patterns with some "global" tokens are universal approximators, similar to full attention. For this, they use the idea of a "star graph" (as opposed to a complete graph formed by full attention). The idea is that any information routing can be done through the center node, which is the global token.
- Sparse attention is also Turing complete. However, in some cases it may require a polynomial number of layers, each of which is only linear in cost, which somewhat defeats the purpose of linear self-attention.
Overall, the random attention is what makes BigBird different from the Longformer, but it seems these random tokens are not really required for the theoretical guarantees. Moreover, the authors did not use them in their BigBird-ETC experiments either (see Table 8 in the Appendix). One neat trick in the paper is the use of matrix rolling for efficient attention computation, explained in detail in Appendix D.
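For illustration, here is how a BigBird-style pattern (window + a few global tokens + a few random tokens per row) could be built as a boolean mask; as with the Longformer sketch above, this is only for visualizing the pattern, since the actual implementation works block-wise and never builds the full matrix.

```python
import torch

def bigbird_mask(n, window=2, num_global=2, num_random=2, seed=0):
    """Boolean (n, n) mask for BigBird-style sparse attention (illustrative)."""
    g = torch.Generator().manual_seed(seed)
    idx = torch.arange(n)
    mask = (idx[None, :] - idx[:, None]).abs() <= window   # sliding window
    mask[:num_global, :] = True                             # global tokens attend everywhere
    mask[:, :num_global] = True                             # ...and are attended by everyone
    for i in range(n):                                       # random attention per row
        mask[i, torch.randint(n, (num_random,), generator=g)] = True
    return mask
```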
This paper combines a short-term attention and a long-range attention. The short-term attention is simply the sliding-window pattern that we have already seen in the Longformer and BigBird. The long-range attention is similar to the low-rank projection idea used in Linformer, but with a small change: in Linformer, the key and value matrices are projected with fixed, learned $k \times n$ matrices, which ties the model to a particular sequence length. The Long-Short Transformer instead computes the projection dynamically from the input itself, so the same model can handle sequences of different lengths and the projection can adapt to the content of the sequence. The two forms of attention are then combined within each attention head.
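Here is a rough sketch of the dynamic-projection idea as I read it (not the authors' exact formulation; `W_p`, the pooling dimension `r`, and the normalization choice are my assumptions): the projection weights are computed from the keys themselves rather than being fixed per position.

```python
import torch
import torch.nn as nn

class DynamicProjectionAttention(nn.Module):
    """Long-range attention with an input-dependent low-rank projection
    (illustrative sketch, not the paper's exact formulation)."""
    def __init__(self, d, r=64):
        super().__init__()
        self.W_p = nn.Linear(d, r, bias=False)  # per-token projection weights (assumed name)
        self.d = d

    def forward(self, Q, K, V):
        # P: (n, r) weights, normalized over the sequence dimension,
        # used to pool K and V down to r "summary" tokens.
        P = torch.softmax(self.W_p(K), dim=0)               # (n, r)
        K_bar = P.transpose(0, 1) @ K                        # (r, d)
        V_bar = P.transpose(0, 1) @ V                        # (r, d)
        scores = Q @ K_bar.transpose(0, 1) / self.d ** 0.5   # (n, r)
        return torch.softmax(scores, dim=-1) @ V_bar         # (n, d)
```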
Both the categories we have seen previously used some prior inductive biases about the model. The low-rank approximation methods relied on the empirical observation that the self-attention matrix is approximately low rank, while the local-global attention was based on the idea that only a few tokens (often defined based on the task) need to attend globally to all tokens. In contrast, kernel-based approximations do not usually involve any such priors, and as a result, are more mathematically robust. To understand this category of linear transformers, let us take another look at self-attention.
$$ V^{\prime} = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V $$

In the above equation, the softmax normalizes each row, so the output for a single query $Q_i$ can equivalently be written as

$$ V_i^{\prime} = \frac{\sum_{j=1}^{n} \text{sim}(Q_i, K_j)\, V_j}{\sum_{j=1}^{n} \text{sim}(Q_i, K_j)}, $$

where $ \text{sim}(Q_i, K_j) = \text{exp}\left(\frac{Q_i K_j^T}{\sqrt{d}}\right).$ Here sim is just a similarity function between query $Q_i$ and key $K_j$; in the standard transformer it happens to be the exponential of their scaled dot product, but nothing stops us from using other similarity functions.

Since the exponential is a kernel function, we can (at least in principle) write it as an inner product of feature maps, i.e., $\text{sim}(Q_i, K_j) = \phi(Q_i)^T \phi(K_j)$ for some mapping $\phi$. Using the above decomposition, we can rewrite $ V_i^{\prime}$ as

$$ V_i^{\prime} = \frac{\sum_{j=1}^{n} \phi(Q_i)^T \phi(K_j)\, V_j}{\sum_{j=1}^{n} \phi(Q_i)^T \phi(K_j)}. $$

Now, we can take $\phi(Q_i)^T$ out of both sums:

$$ V_i^{\prime} = \frac{\phi(Q_i)^T \left(\sum_{j=1}^{n} \phi(K_j) V_j^T\right)}{\phi(Q_i)^T \left(\sum_{j=1}^{n} \phi(K_j)\right)}. $$

The expressions in the parentheses can be computed once and reused for all queries $Q_i$, so the whole computation becomes linear in the sequence length. The catch is that the feature map corresponding to the softmax (exponential) kernel is infinite-dimensional, and so we cannot compute it exactly! The papers in this section use the above idea and try to approximate $\phi$, or replace it outright, with a finite-dimensional feature map.
In this paper, the authors (somewhat arbitrarily) selected $\phi(x) = \text{elu}(x) + 1$ as the feature map. Note that this does not approximate the softmax kernel; it simply replaces it with a different, always-positive similarity function that admits the linear-time computation above, and the authors show that it performs comparably to softmax attention on several tasks while being much faster on long sequences.
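A minimal non-causal linear attention sketch with the $\text{elu}(x) + 1$ feature map (function names are mine; batching, heads, and masking are omitted):

```python
import torch
import torch.nn.functional as F

def elu_feature_map(x):
    # elu(x) + 1 is strictly positive, which keeps the normalizer well defined.
    return F.elu(x) + 1

def linear_attention(Q, K, V, eps=1e-6):
    """O(n) attention: compute sum_j phi(K_j) V_j^T and sum_j phi(K_j)
    once, then reuse them for every query."""
    Qf, Kf = elu_feature_map(Q), elu_feature_map(K)          # (n, d)
    KV = Kf.transpose(0, 1) @ V                              # (d, d)
    Z = Kf.sum(dim=0)                                        # (d,)
    return (Qf @ KV) / (Qf @ Z).unsqueeze(-1).clamp(min=eps) # (n, d)
```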
The other important part of the paper (which gives it the name "Transformers are RNNs") shows an equivalence between autoregressive linear transformers and RNNs. For problems requiring autoregressive computation, a causal mask is employed so that each position only attends to earlier positions. With linear attention, the sums in the numerator and denominator then only run over $j \le i$, so the model simply maintains an internal state that it updates and passes forward at each step, which makes it equivalent to an RNN.
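To see the RNN connection, here is a sketch of the causal case (not the authors' optimized CUDA implementation): the model keeps a running $d \times d$ state $S_t = \sum_{j \le t} \phi(K_j)V_j^T$ and a normalizer $z_t = \sum_{j \le t} \phi(K_j)$, updating both as each token arrives.

```python
import torch

def causal_linear_attention(Q, K, V,
                            phi=lambda x: torch.nn.functional.elu(x) + 1):
    """Recurrent formulation of causal linear attention (illustrative)."""
    n, d = Q.shape
    Qf, Kf = phi(Q), phi(K)
    S = torch.zeros(d, d)      # running sum of phi(K_j) V_j^T
    z = torch.zeros(d)         # running sum of phi(K_j)
    out = torch.zeros_like(V)
    for t in range(n):
        S = S + torch.outer(Kf[t], V[t])
        z = z + Kf[t]
        out[t] = (Qf[t] @ S) / (Qf[t] @ z + 1e-6)
    return out
```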
In any case, this paper was perhaps the first to make the kernel interpretation of softmax attention explicit, and it paved the way for more rigorous approximations using random features, which we will see in the next two papers. The authors also provide a fast implementation (with custom CUDA kernels for the gradient computation) and an autoregressive generation demo which runs in the browser, demonstrating the capabilities of their model.
Performers use something called fast attention via positive orthogonal random features, abbreviated as FAVOR+, a method which (the authors claim) can be used for general-purpose scalable kernel approximation. FAVOR+ is based on the idea of random Fourier features, first made popular in this award-winning paper from Rahimi and Recht. The authors consider mapping functions of the following general form:

$$ \phi(x) = \frac{h(x)}{\sqrt{m}} \left( f_1(w_1^T x), \ldots, f_1(w_m^T x), \ldots, f_l(w_1^T x), \ldots, f_l(w_m^T x) \right). $$

Here, $w_1, \ldots, w_m$ are random vectors drawn i.i.d. from some distribution (an isotropic Gaussian in this case), and $h$ and $f_1, \ldots, f_l$ are deterministic functions. Choosing $h(x) = \text{exp}\left(\frac{\|x\|^2}{2}\right)$, $l = 2$, $f_1 = \sin$, $f_2 = \cos$ recovers the classical trigonometric random features, which give an unbiased estimate of the softmax kernel $\text{exp}(x^T y)$.
There is a caveat in the above approximation. While it is unbiased, its variance is quite high, especially when the actual kernel value is close to 0. This is because the softmax kernel is always positive, while the trigonometric features may be positive or negative. Since the self-attention matrix is usually sparse in practice (most entries are near zero after the softmax), using this approximation results in very high variance empirically. To solve this problem, the authors suggest a slightly different choice, using exp as the mapping function: $h(x) = \text{exp}\left(-\frac{\|x\|^2}{2}\right)$, $l = 1$, $f_1 = \text{exp}$, which is still an unbiased estimator of the softmax kernel but with strictly positive features (the "positive" in FAVOR+). They also experiment with replacing the exp function with ReLU, as a generalized (non-softmax) attention kernel, and get better results on some tasks. Finally, they show that if we choose the random vectors $w_1, \ldots, w_m$ to be exactly orthogonal, the variance of the estimator is reduced further (the "orthogonal" in FAVOR+).

Using the above mapping function, the queries and keys are mapped row-wise to $Q^{\prime} = \phi(Q)$ and $K^{\prime} = \phi(K)$, and the attention output is computed as $\hat{V}^{\prime} = \hat{D}^{-1} Q^{\prime} \left( (K^{\prime})^T V \right)$, where $\hat{D}$ is the diagonal normalizer, so the whole computation is linear in the sequence length.
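Here is a sketch of the positive random features as I understand the construction (without the orthogonality trick; function names and the scaling convention are mine): $\phi(x) = \frac{1}{\sqrt{m}} \text{exp}\left(Wx - \frac{\|x\|^2}{2}\right)$ with Gaussian rows of $W$, so that $\mathbb{E}\left[\phi(x)^T \phi(y)\right] = \text{exp}(x^T y)$.

```python
import torch

def positive_random_features(x, W):
    """phi(x) for the softmax kernel, using m positive random features.

    x: (n, d) queries or keys.
    W: (m, d) matrix with rows drawn from N(0, I).
    """
    m = W.shape[0]
    return torch.exp(x @ W.T - (x ** 2).sum(-1, keepdim=True) / 2) / m ** 0.5

def performer_attention(Q, K, V, m=256):
    d = Q.shape[-1]
    W = torch.randn(m, d)
    # Fold the 1/sqrt(d) temperature into the inputs before the feature map.
    Qf = positive_random_features(Q / d ** 0.25, W)   # (n, m)
    Kf = positive_random_features(K / d ** 0.25, W)   # (n, m)
    KV = Kf.T @ V                                     # (m, d)
    z = Kf.sum(dim=0)                                 # (m,)
    return (Qf @ KV) / (Qf @ z).unsqueeze(-1)         # (n, d)
```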
This paper was published concurrently with the Performer paper at ICLR 2021, and proposes the same idea of approximating the softmax kernel using random features. A similar extension of Rahimi and Recht's work is used to compute a random feature map $\phi$ (built from sinusoidal features of the queries and keys), which is then used to linearize the attention computation exactly as described above.
Additionally, this paper also proposes a method for learning with a recency bias, since softmax attention does not explicitly model distance or locality (hence the importance of positional encodings in transformers). In the "Transformers are RNNs" paper, we saw how autoregressive linear transformers are equivalent to RNNs. Inspired by this equivalence, the authors further add a learned gating mechanism to the recurrent state update, which biases the model to rely more on recent tokens.
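A rough sketch of the gating idea as I understand it (the exact parameterization in the paper may differ; the scalar gate, `w_g`, and the elu-based stand-in for the random feature map are my own simplifications): the recurrent state from the earlier causal sketch is decayed by an input-dependent gate before each update, so older tokens are gradually forgotten.

```python
import torch

def gated_recurrent_attention(Q, K, V, X, w_g,
                              phi=lambda x: torch.nn.functional.elu(x) + 1):
    """Recency-biased linear attention (illustrative sketch).

    X: (n, d) inputs used to compute a scalar gate per step.
    w_g: (d,) gate parameters (hypothetical name).
    """
    n, d = Q.shape
    S = torch.zeros(d, d)
    z = torch.zeros(d)
    out = torch.zeros_like(V)
    for t in range(n):
        g = torch.sigmoid(X[t] @ w_g)                       # scalar gate in (0, 1)
        S = g * S + (1 - g) * torch.outer(phi(K[t]), V[t])  # decay old state, add new
        z = g * z + (1 - g) * phi(K[t])
        out[t] = (phi(Q[t]) @ S) / (phi(Q[t]) @ z + 1e-6)
    return out
```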
Here is a tabular summary of all the papers covered in this post:
| Method Name | Concept used | Approximation method | Implementations |
|---|---|---|---|
| Linformer | Low-rank approximation | JL projection matrices | Original (Fairseq), PyTorch |
| Nystromformer | Low-rank approximation | Nystrom approximation | Original, PyTorch |
| Longformer | Local-global attention | Window + task-specific global attention | Original, HuggingFace |
| Big Bird | Local-global attention | Window + global + block-wise random | Original (TensorFlow), HuggingFace |
| Long-short transformer | Local-global attention | Window + dynamic JL projection matrix | PyTorch |
| Fast transformer | Softmax kernel | $\text{elu}(x) + 1$ feature map | Original, Reimplementation |
| Performer | Softmax kernel | FAVOR+ | Original (TF), PyTorch, HuggingFace |
| Random Features Attention | Softmax kernel | Random features + gating | |
I hope this summary is useful for keeping track of all the research happening in the field of efficient transformers.