Large Language Models such as those in the GPT, BERT, and Llama families have shown impressive performance on a range of language tasks, from chat completion to machine translation and text summarization, to name a few. One reason for their success lies in their architecture, which is based on the Transformer.
In what follows, we will go through the basic components of a transformer to fully understand their functionality and purpose.
Image by Vaswani, Ashish, et al., "Attention is all you need" (2017), and Philip Lippe.

To this end, we will use the following example. Imagine that we have the following sentence
Write a poem about a man fishing on a river bank.
and that we would like to give it as input to the transformer. Our goal is to understand what will happen to this sentence as it goes through a transformer.
First, our sentence will be broken down into tokens (that could be words or parts of words) by a tokenizer, and may look like

Write, a, poem, about, a, man, fishing, on, a, river, bank
Next, assuming that we have a vocabulary of a fixed size, each token will be mapped to an integer, its token ID, which identifies the token's entry in the vocabulary.

Each of the token IDs will then be mapped to an embedding vector of dimension $d_{\text{model}}$, using a learned embedding matrix that holds one such vector per vocabulary entry.
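The snippet below is a rough sketch of these first two steps, not how any particular tokenizer actually works: it uses a toy whitespace tokenizer, a vocabulary built on the fly, and a randomly initialized embedding table, whereas real models use learned subword tokenizers (e.g., BPE) and learned embedding weights; the dimension d_model = 8 is only an illustrative value.

```python
import numpy as np

sentence = "Write a poem about a man fishing on a river bank"

# Toy whitespace tokenizer; real models use learned subword tokenizers (e.g., BPE).
tokens = sentence.split()

# Toy vocabulary built from the tokens themselves; real vocabularies are fixed in advance.
vocab = {tok: idx for idx, tok in enumerate(sorted(set(tokens)))}
token_ids = [vocab[tok] for tok in tokens]

# Embedding lookup: one vector of dimension d_model per vocabulary entry.
d_model = 8                                # illustrative; e.g., 512 in the original paper
rng = np.random.default_rng(0)
embedding_table = rng.normal(size=(len(vocab), d_model))
X = embedding_table[token_ids]             # shape: (sequence_length, d_model) = (11, 8)

print(tokens)      # ['Write', 'a', 'poem', ..., 'river', 'bank']
print(token_ids)   # one integer ID per token
```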
The learned embeddings above do not capture any information about the position of each token in the sentence. However, the position of each token or word in a sentence is important as it does affect the meaning of our sentence. For example, it is different to say
Now come here!
and to say
Come here now!
In the first case, we put emphasis on the action (come here), while in the second we put emphasis on when we want the action to happen (now). All this information would be lost if we had only knowledge about the representation of each token.
In order to keep this information, we would like to somehow add information about the position of the tokens to our embeddings. This can be done by a positional encoding function that we would ideally like to be
- Bounded — we do not want too large values during training
- Periodic — patterns that repeat are easier to learn
- Predictable — the model should be able to understand the position of each of the embeddings in sequences that it has not seen before.
The positional encoding function used in the original Transformer [2] is

$$
PE(k,i) =
\begin{cases}
\sin\left(k / 10000^{2j/d_{\text{model}}}\right) & \text{if } i = 2j \text{ is even}, \\
\cos\left(k / 10000^{2j/d_{\text{model}}}\right) & \text{if } i = 2j+1 \text{ is odd},
\end{cases}
$$

where $k$ is the position of the token in the sequence, $i$ indexes the embedding dimensions, and $d_{\text{model}}$ is the embedding dimension. It is bounded, it is periodic in $k$, and it follows a predictable pattern across dimensions.
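A minimal NumPy sketch of this encoding, following the formula above, could look as follows; the sequence length and dimension match the illustrative sizes used earlier.

```python
import numpy as np

def positional_encoding(seq_len: int, d_model: int) -> np.ndarray:
    """Sinusoidal positional encoding: sine for even dimensions, cosine for odd ones."""
    PE = np.zeros((seq_len, d_model))
    positions = np.arange(seq_len)[:, None]        # k = 0 .. seq_len - 1
    even_dims = np.arange(0, d_model, 2)[None, :]  # even dimension indices 2j
    angles = positions / (10000 ** (even_dims / d_model))
    PE[:, 0::2] = np.sin(angles)                   # even dimensions
    PE[:, 1::2] = np.cos(angles)                   # odd dimensions
    return PE

# Add the positional information to the token embeddings X from before.
X = np.random.default_rng(0).normal(size=(11, 8))  # stand-in for the token embeddings
X_pos = X + positional_encoding(*X.shape)
```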
Once we add the positional encoding information to our input embeddings, our sentence will have been through the following stages
Image by Bradney Smith.

Our sentence is now ready for the next component of the transformer, the self-attention mechanism, which is essentially the heart of the transformer.
The purpose of the self-attention mechanism is to capture contextual information between the tokens in our input sentence. It is not only the representation and the position of a token or word that give it meaning, but also the context, i.e., the other words in the sentence. For example, the word bank would have a very different meaning if our sentence were given in a different context, such as

Write a poem about a man applying for a loan in a bank

In the above sentence, the word loan gives us a very different context, giving bank the meaning of a financial institution, whereas in our sentence the words fishing and river give bank the meaning of the land alongside the river. In both sentences, though, words like a do not provide any context information.
What we would like to have is a mechanism that makes it apparent that, in our sentence, the word river gives more important context information about the word bank than the word a does; one could say that river is more relevant to the word bank than a is. The way the self-attention mechanism achieves this is by computing a measure of similarity between the token embeddings: tokens that are more relevant to each other have a higher similarity.
One could use several measures of similarity between embedding vectors, such as a distance, e.g., the Euclidean distance, or an angle-based measure, e.g., the cosine similarity. Both are relatively expensive to compute, while the latter also completely loses information about the magnitude of the embeddings and treats parallel embeddings as the same embedding.
The self-attention mechanism uses the dot product as a similarity metric, as it keeps information about both magnitude and angle and is also very cheap to compute. Back to our example: to find the similarity of the word bank with each of the words in our input sentence, one would compute the dot product of the embedding of bank with the embedding of every token. One problem with the dot product, though, is that it can become arbitrarily large, which is something we would like to avoid during training. Therefore, we would prefer to scale it down, i.e., we would prefer a scaled dot product. One way to do this is to normalize it by dividing the dot product by $\sqrt{d_{\text{model}}}$:

$$
S_{\text{bank}} = \frac{X_{10} \cdot X^{T}}{\sqrt{d_{\text{model}}}},
$$

where $X_{10}$ is the embedding of the token bank (the token at position 10, counting from zero), $X$ is the matrix whose rows are the embeddings of the input tokens, and each element $i$ of $S_{\text{bank}}$ measures the similarity between the word bank and the $i$-th word in the input sentence.
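As a sketch of this computation, assuming X holds one position-aware embedding per input token (here a random stand-in) and bank is the token at index 10:

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, d_model = 11, 8                   # illustrative sizes
X = rng.normal(size=(seq_len, d_model))    # stand-in for the position-aware embeddings

# Scaled dot-product scores of "bank" (row 10, zero-indexed) against every token.
S_bank = X[10] @ X.T / np.sqrt(d_model)    # shape: (seq_len,)
```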
The self-attention mechanism uses the scores in the vector $S_{\text{bank}}$ to construct attention weights for the token bank. To construct such weights, which should be positive and sum to one, we use the softmax function $\sigma$, which exponentiates each score and normalizes by the sum of the exponentials. Using the attention weights, we can transform the input embedding of the token bank into an embedding that carries the context information we discussed above, by computing the weighted sum of the input embeddings,

$$
y_{\text{bank}} = \sigma\left(S_{\text{bank}}\right) \cdot X.
$$

To summarize, the new embedding for the word bank is generated in three steps: we compute the scaled dot-product scores of bank against every input token, turn the scores into attention weights with the softmax, and take the weighted sum of the input embeddings with those weights.
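Putting the steps together for the single query token bank, a minimal sketch (again with a random stand-in for the embeddings) could be:

```python
import numpy as np

def softmax(z: np.ndarray) -> np.ndarray:
    """Numerically stable softmax: exponentiate and normalize to sum to one."""
    e = np.exp(z - z.max())
    return e / e.sum()

rng = np.random.default_rng(0)
X = rng.normal(size=(11, 8))                   # stand-in for the token embeddings
S_bank = X[10] @ X.T / np.sqrt(X.shape[1])     # scores of "bank" against every token

weights = softmax(S_bank)                      # attention weights, sum to 1
y_bank = weights @ X                           # context-aware embedding for "bank"
```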
In our example above, we could view the process of finding the context information for the word bank from the tokens in the input sentence as a process of retrieving information from a database given a query. By looking for context for the word bank, we essentially ''query'' our database, that is, the input sentence. For this reason, one could refer to the embedding of bank as the query, $Q_{\text{bank}}$.

When we compute the similarity score between our query and each input token, we essentially go through all the ''attributes'' or keys in our ''database'', i.e., the input sentence, to look for context. Therefore, in the dot-product computation, we could refer to the matrix of input embeddings as the keys, $K$.

Finally, once we compute the attention weights using the query and the keys, we use them to weight the input embeddings, which we could view abstractly as a way to ''select'' the actual entries or values in our database. As a result, in the weighted-sum computation, we could refer to the matrix of input embeddings as the values, $V$.
Overall, we could rewrite the new transformer embedding for the word bank as
$$
y_{\text{bank}} = \sigma\left(\frac{Q_{\text{bank}} \cdot K^{T}}{\sqrt{d_{\text{model}}}}\right) \cdot V
$$
The self-attention mechanism creates a new embedding, like the one we saw above for the token bank, for every token in the input sentence. In other words, every token takes its turn as the query, and its new embedding is computed as above. We can therefore write the self-attention mechanism for all the input tokens at once, using the full matrices of queries, keys, and values built from the input embeddings:

$$
Y = \sigma\left(\frac{Q \cdot K^{T}}{\sqrt{d_{\text{model}}}}\right) \cdot V,
$$

where each row of $Y$ is the new, context-aware embedding of the corresponding input token and the softmax is applied row-wise.
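In matrix form, and still without trainable parameters, so that queries, keys, and values are all just the input embeddings X, a sketch of the full self-attention computation could look like this:

```python
import numpy as np

def softmax_rows(Z: np.ndarray) -> np.ndarray:
    """Row-wise softmax: each row of weights sums to one."""
    E = np.exp(Z - Z.max(axis=-1, keepdims=True))
    return E / E.sum(axis=-1, keepdims=True)

def self_attention(X: np.ndarray) -> np.ndarray:
    """Parameter-free self-attention: every token attends to every token."""
    scores = X @ X.T / np.sqrt(X.shape[-1])   # (seq_len, seq_len) similarity matrix
    return softmax_rows(scores) @ X           # one new embedding per token

X = np.random.default_rng(0).normal(size=(11, 8))  # stand-in embeddings
Y = self_attention(X)                               # shape: (11, 8)
```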
The simple weighted sum above does not include any trainable parameters. Without trainable parameters, the performance of the model may still be good, but by allowing the model to learn more intricate patterns and hidden features from the training data, we observe much stronger model performance.
One way to introduce trainable parameters is to introduce a different weight matrix each time we use the input embeddings. Overall, the self-attention mechanism uses the input embeddings three times to compute the new embeddings: once as the query, once as the key, and once as the value. As a result, we can introduce a weight matrix for each of these roles, namely $W_Q$, $W_K$, and $W_V$, and compute the query, key, and value matrices as $Q = X \cdot W_Q$, $K = X \cdot W_K$, and $V = X \cdot W_V$.

The number of columns in each weight matrix is an architectural choice of the user. Choosing a number smaller than $d_{\text{model}}$ projects the queries, keys, and values into a lower-dimensional space, which reduces the computational cost of the attention mechanism.
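A sketch of a single attention head with these trainable projections is shown below; the function name attention_head and the choice of d_k = 4 columns per weight matrix are illustrative, not part of any particular library.

```python
import numpy as np

def softmax_rows(Z):
    E = np.exp(Z - Z.max(axis=-1, keepdims=True))
    return E / E.sum(axis=-1, keepdims=True)

def attention_head(X, W_Q, W_K, W_V):
    """Scaled dot-product attention with learned query/key/value projections."""
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V
    scores = Q @ K.T / np.sqrt(K.shape[-1])   # scaled by the key dimension
    return softmax_rows(scores) @ V

rng = np.random.default_rng(0)
seq_len, d_model, d_k = 11, 8, 4              # illustrative sizes, with d_k < d_model
X = rng.normal(size=(seq_len, d_model))
W_Q, W_K, W_V = (rng.normal(size=(d_model, d_k)) for _ in range(3))
Y = attention_head(X, W_Q, W_K, W_V)          # shape: (11, d_k)
```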
The self-attention mechanism we describe above uses a single triplet of weight matrices, namely $W_Q$, $W_K$, and $W_V$. The transformer, however, applies the attention mechanism several times in parallel, each time with its own weight-matrix triplet.
The advantage in doing so is that, in this way, we can capture more rich context information. Each weight matrix triplet can specialize in extracting context information, each time using different relations between words. For example, in our sentence,
Write a poem about a man fishing on a river bank.
one query, key, value triplet may specialize in context between words that are close to each other, such as Write and poem, while another one may specialize in context between words that are further apart, such as Write and fishing.
We refer to this extension of the self-attention mechanism as the multi-head attention mechanism, where we consider as a head each application of the attention mechanism with a single query, key, value matrix triplet. Assuming that we have $h$ heads, we compute the output of each head as above, concatenate the $h$ outputs, and pass the concatenation through an additional trainable weight matrix $W_O$,

$$
Y = \text{Concat}\left(\text{head}_1, \dots, \text{head}_h\right) \cdot W_O,
$$

where the new embedding of each token (each row of $Y$) now combines the context information extracted by all the heads.
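A sketch of multi-head attention under these assumptions could look as follows, with h = 2 heads of illustrative size; W_O is the additional trainable matrix that mixes the concatenated head outputs.

```python
import numpy as np

def softmax_rows(Z):
    E = np.exp(Z - Z.max(axis=-1, keepdims=True))
    return E / E.sum(axis=-1, keepdims=True)

def multi_head_attention(X, heads, W_O):
    """Run one attention head per (W_Q, W_K, W_V) triplet, concatenate, mix with W_O."""
    outputs = []
    for W_Q, W_K, W_V in heads:
        Q, K, V = X @ W_Q, X @ W_K, X @ W_V
        scores = Q @ K.T / np.sqrt(K.shape[-1])
        outputs.append(softmax_rows(scores) @ V)
    return np.concatenate(outputs, axis=-1) @ W_O   # back to d_model columns

rng = np.random.default_rng(0)
seq_len, d_model, n_heads = 11, 8, 2
d_k = d_model // n_heads                            # illustrative per-head size
X = rng.normal(size=(seq_len, d_model))
heads = [tuple(rng.normal(size=(d_model, d_k)) for _ in range(3)) for _ in range(n_heads)]
W_O = rng.normal(size=(n_heads * d_k, d_model))
Y = multi_head_attention(X, heads, W_O)             # shape: (11, 8)
```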
In the self-attention mechanism we describe above, our goal is to capture context information for each token in the input sentence from every token in the input sentence; hence the term ''self''. However, the transformer generates its output in an autoregressive manner, meaning that it generates each new output token based on the input and the previously generated output tokens. As a result, each new output token is influenced by context information not only from the input tokens but also from the previously generated output tokens.
For this reason, we need a mechanism that captures the context information from both the input and the previously generated output. The cross-attention mechanism achieves this by applying the attention mechanism to the transformer embeddings produced by self-attention on both the input tokens and the previously generated output tokens. More precisely, the cross-attention mechanism uses as queries the transformer embeddings of the previously generated output tokens, and as keys and values the transformer embeddings of the encoder, that is, essentially the transformer embeddings of the input tokens passed through a feed-forward layer.
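A sketch of cross-attention under these assumptions could be the following, where the decoder embeddings (three generated tokens) supply the queries and the encoder embeddings (our eleven input tokens) supply the keys and values; all names and sizes are illustrative.

```python
import numpy as np

def softmax_rows(Z):
    E = np.exp(Z - Z.max(axis=-1, keepdims=True))
    return E / E.sum(axis=-1, keepdims=True)

def cross_attention(decoder_X, encoder_X, W_Q, W_K, W_V):
    """Queries come from the decoder; keys and values come from the encoder output."""
    Q = decoder_X @ W_Q                      # previously generated output tokens
    K, V = encoder_X @ W_K, encoder_X @ W_V  # encoder embeddings of the input tokens
    scores = Q @ K.T / np.sqrt(K.shape[-1])
    return softmax_rows(scores) @ V

rng = np.random.default_rng(0)
d_model, d_k = 8, 4
encoder_X = rng.normal(size=(11, d_model))   # stand-in for the encoder output
decoder_X = rng.normal(size=(3, d_model))    # stand-in for 3 generated output tokens
W_Q, W_K, W_V = (rng.normal(size=(d_model, d_k)) for _ in range(3))
Y = cross_attention(decoder_X, encoder_X, W_Q, W_K, W_V)  # shape: (3, d_k)
```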
Image by Vaswani, Ashish, et al., "Attention is all you need" (2017), and Eleni Straitouri.

In the picture above, note that for the transformer embeddings of the previously generated output sequence, we use masked multi-head attention. What is this, and why do we need it?
The mechanism of masked attention (which extends to masked multi-head attention) simply adds a mask matrix $M$ to the scaled dot-product scores before the softmax is applied,

$$
Y = \sigma\left(\frac{Q \cdot K^{T}}{\sqrt{d_{\text{model}}}} + M\right) \cdot V,
$$

where $M$ has zeros at the positions a query is allowed to attend to and $-\infty$ at the positions after the query, so that the corresponding attention weights become zero.
The attention mask ensures that, during the self-attention mechanism, we can extract context only from tokens before the query and not after it. Since we work with the partial output token sequence, without the mask we would be able to ''look into the future'' and extract context from tokens that will be generated later than the query. Allowing access to information from ''future'' tokens would lead to data leakage and incorrect training.
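A sketch of masked self-attention (single head, no trainable projections, illustrative sizes) could be the following; the -inf entries above the diagonal become zero attention weights after the softmax.

```python
import numpy as np

def softmax_rows(Z):
    E = np.exp(Z - Z.max(axis=-1, keepdims=True))
    return E / E.sum(axis=-1, keepdims=True)

def masked_self_attention(X):
    """Each token may only attend to itself and to earlier tokens."""
    seq_len, d_model = X.shape
    scores = X @ X.T / np.sqrt(d_model)
    # Mask M: 0 on and below the diagonal, -inf above it (future positions).
    M = np.triu(np.full((seq_len, seq_len), -np.inf), k=1)
    return softmax_rows(scores + M) @ X

X = np.random.default_rng(0).normal(size=(3, 8))   # 3 generated tokens so far
Y = masked_self_attention(X)                        # shape: (3, 8)
```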
[1] J. Alammar, The Illustrated Transformer (2018), GitHub
[2] A. Vaswani et al., Attention Is All You Need (2017)
[3] B. Smith, Self-Attention Explained with Code (2024), Medium
[4] P. Lippe, Transformers and Multi-Head Attention (2022), University of Amsterdam