import torch
import torch.nn.functional as F

# decoder hidden state (s)
quiz_s = torch.tensor([0.4685, 0.9785])

# encoder hidden states (h1, h2, h3)
quiz_h1 = torch.tensor([0.5539, 0.7239])
quiz_h2 = torch.tensor([0.4111, 0.3878])
quiz_h3 = torch.tensor([0.2376, 0.1264])
Attention
As you know from this week’s video lectures, attention is one of the cornerstones of the Transformer architecture and of the current generation of large language models. At the same time, for many of you it is one of the most challenging technical concepts in this course. This became apparent in the quiz results, so I dedicated parts of Monday’s on-campus session to discussing it in detail.
Example from the quiz
Quiz Q3.3 featured the following question:
Consider the following values for the example on slide 6 ff.:
\(s = [0.4685, 0.9785]\), \(h_1 = [0.5539, 0.7239]\), \(h_2 = [0.4111, 0.3878]\), \(h_3 = [0.2376, 0.1264]\)
Assuming that the attention score is computed using the dot product, what is \(v\)?
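Spelled out, the computation asked for here is the standard three-step attention recipe that the code below follows: score each encoder state against the decoder state with a dot product, normalise the scores with the softmax, and take the score-weighted sum of the encoder states:
\[
e_i = s \cdot h_i, \qquad (\alpha_1, \alpha_2, \alpha_3) = \mathrm{softmax}(e_1, e_2, e_3), \qquad v = \sum_{i=1}^{3} \alpha_i \, h_i .
\]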
Let us code up this example in PyTorch to get the answer:
First, we compute the scores using the dot product:
quiz_score1 = (quiz_s * quiz_h1).sum(-1)
quiz_score2 = (quiz_s * quiz_h2).sum(-1)
quiz_score3 = (quiz_s * quiz_h3).sum(-1)

quiz_score1, quiz_score2, quiz_score3
(tensor(0.9678), tensor(0.5721), tensor(0.2350))
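As a sanity check, the first score can be reproduced by hand:
\[
s \cdot h_1 = 0.4685 \cdot 0.5539 + 0.9785 \cdot 0.7239 \approx 0.2595 + 0.7083 = 0.9678 .
\]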
Next, we normalise the scores using the softmax function:
quiz_alphas = F.softmax(torch.tensor([quiz_score1, quiz_score2, quiz_score3]), dim=-1)

quiz_alphas
tensor([0.4643, 0.3126, 0.2231])
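Again by hand, the first weight comes out of the softmax as
\[
\alpha_1 = \frac{e^{0.9678}}{e^{0.9678} + e^{0.5721} + e^{0.2350}} \approx \frac{2.632}{2.632 + 1.772 + 1.265} \approx 0.4643,
\]
and analogously for the other two weights.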
Finally, we compute the alpha-weighted sum:
quiz_v = quiz_alphas[0] * quiz_h1 + quiz_alphas[1] * quiz_h2 + quiz_alphas[2] * quiz_h3

quiz_v
tensor([0.4387, 0.4855])
Thus, the correct answer to the quiz question was \(v = [0.4387, 0.4855]\).
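As a side note, the same computation can be written more compactly by stacking the encoder states into a matrix. This is just a sketch; the name \(H\) is introduced here for illustration and is not used elsewhere in these notes:

# Stack the encoder states into a single matrix of shape [3, 2]
H = torch.vstack([quiz_h1, quiz_h2, quiz_h3])
# Dot-product scores, softmax normalisation, and the alpha-weighted sum in one line
v = F.softmax(H @ quiz_s, dim=-1) @ H
v  # should again come out as tensor([0.4387, 0.4855])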
Dictionary example
One way to understand the more general characterisation of attention as a query–key–value mechanism is to try using this mechanism to code up a dictionary data structure.
Simulating dictionaries
A standard dictionary is a mapping from keys to values. To retrieve a value from the dictionary, we query it with a key. Here is an example where we map integer keys to tensor-like lists of integers:
# Set up a dictionary with keys and values
dict_example = {0: [2, 3], 1: [5, 7], 2: [3, 2]}

# Query the dictionary
dict_example[1]
[5, 7]
To simulate this in PyTorch, we set up keys and queries as one-hot vectors. The values are tensor versions of the original lists.
# Set up the keys as one-hot vectors
dict_k1 = torch.tensor([1., 0., 0.])
dict_k2 = torch.tensor([0., 1., 0.])
dict_k3 = torch.tensor([0., 0., 1.])

# Set up the values (tensor versions of the original lists)
dict_v1 = torch.tensor([2., 3.])
dict_v2 = torch.tensor([5., 7.])
dict_v3 = torch.tensor([3., 2.])

# Set up the query (one-hot vector)
dict_q = torch.tensor([0., 1., 0.])
The idea now is to retrieve the value \([5, 7]\) from the dictionary using attention. Let us see how this works.
Using attention to retrieve from a dictionary
First, we compute the scores using the dot product:
dict_score1 = (dict_q * dict_k1).sum(-1)
dict_score2 = (dict_q * dict_k2).sum(-1)
dict_score3 = (dict_q * dict_k3).sum(-1)

dict_score1, dict_score2, dict_score3
(tensor(0.), tensor(1.), tensor(0.))
Next, we put these scores into a vector \(\alpha\):
dict_alphas = torch.tensor([dict_score1, dict_score2, dict_score3])

dict_alphas
tensor([0., 1., 0.])
In the actual attention mechanism, we would have to send this vector through the softmax function in order to get normalised scores, i.e., scores that sum up to one. We return to this issue in a minute. For now, let us take the last step of the attention mechanism and compute the \(\alpha\)-weighted sum:
dict_v = dict_alphas[0] * dict_v1 + dict_alphas[1] * dict_v2 + dict_alphas[2] * dict_v3

dict_v
tensor([5., 7.])
And, hey presto, there is the value we wanted to query!
Putting the softmax back in
Now, back to the softmax. Using it directly does not really work:
dict_alphas = F.softmax(torch.tensor([dict_score1, dict_score2, dict_score3]), dim=-1)
dict_v = dict_alphas[0] * dict_v1 + dict_alphas[1] * dict_v2 + dict_alphas[2] * dict_v3

dict_v
tensor([3.9403, 5.0925])
The problem is that the softmax does not give us a one-hot vector as in our initial attempt, but a smoothed-out distribution of weights. As a consequence, the retrieved value is a blend of all three values in the dictionary.
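To make this concrete, we can inspect the weights that the softmax produces here; for the raw scores \([0, 1, 0]\) they work out to roughly \([0.21, 0.58, 0.21]\):

dict_alphas  # roughly tensor([0.2119, 0.5761, 0.2119]), a blend rather than a one-hot vector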
However, we can simulate the desired behaviour by applying a low temperature to the softmax:
# Softmax with temperature 0.01
dict_alphas = F.softmax(torch.tensor([dict_score1, dict_score2, dict_score3]) / 0.01, dim=-1)
dict_v = dict_alphas[0] * dict_v1 + dict_alphas[1] * dict_v2 + dict_alphas[2] * dict_v3

dict_v
tensor([5., 7.])
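The reason this works is how the temperature enters the softmax: each score is divided by the temperature \(T\) before exponentiation,
\[
\alpha_i = \frac{\exp(e_i / T)}{\sum_j \exp(e_j / T)},
\]
so as \(T \to 0\) the weights approach a one-hot vector that puts all of its mass on the highest-scoring key, which is exactly the hard dictionary lookup we wanted.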
Summary
So, what we have seen here is that a standard dictionary data structure can be simulated by a restricted form of the attention mechanism in which the queries and keys are one-hot vectors. In such a mechanism, the score between queries and keys is either \(1\) (if query and key are identical) or \(0\) (if they are not).
In general attention, where the queries and keys are not necessarily one-hot vectors, the (normalised) score can be anything between \(0\) and \(1\). As a consequence, the value retrieved with a query will be a blend of all values in the dictionary, weighted by their corresponding scores.
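In formulas, this general query–key–value mechanism is the same three-step recipe as before, now written with a query \(q\), keys \(k_i\), and values \(v_i\):
\[
e_i = q \cdot k_i, \qquad \alpha = \mathrm{softmax}(e_1, \dots, e_n), \qquad v = \sum_{i=1}^{n} \alpha_i \, v_i .
\]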
Attention function
We can package the attention computations from the previous sections in the following function:
def attention(Q, K, V, temp=0.01):
    # shape of Q: [batch_size, num_of_queries, query_dim]
    # shape of K: [batch_size, num_of_keys, query_dim]
    # shape of V: [batch_size, num_of_keys, value_dim]

    # Compute the attention scores
    scores = Q @ K.transpose(-1, -2)
    # shape of scores: [batch_size, num_of_queries, num_of_keys]

    # Normalise the attention scores
    alphas = F.softmax(scores / temp, dim=-1)
    # shape of alphas: [batch_size, num_of_queries, num_of_keys]

    # The output is the alpha-weighted sum of the values
    result = alphas @ V
    # shape of result: [batch_size, num_of_queries, value_dim]

    return result
Example from the quiz
Let’s verify that this function actually does what it is supposed to do. We start with the quiz example. Because the attention function expects a batch dimension, we have to unsqueeze all tensors at dimension 0. The function also supports a whole sequence of queries per batch item, so our single query needs an additional unsqueeze:
Q_quiz = quiz_s.unsqueeze(0).unsqueeze(0)
K_quiz = torch.vstack([quiz_h1, quiz_h2, quiz_h3]).unsqueeze(0)
V_quiz = torch.vstack([quiz_h1, quiz_h2, quiz_h3]).unsqueeze(0)

attention(Q_quiz, K_quiz, V_quiz)
tensor([[[0.5539, 0.7239]]])
Note that with the default temperature of 0.01, the softmax is close to a hard max, so the function essentially returns \(h_1\), the encoder state with the highest score, rather than the softly blended vector we computed for the quiz.
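To recover the quiz answer exactly, we can switch off the sharpening by passing a temperature of 1, i.e. the standard softmax (a quick check; the expected output is the value computed earlier):

attention(Q_quiz, K_quiz, V_quiz, temp=1.0)  # expected: tensor([[[0.4387, 0.4855]]])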
Dictionary example
Now for the dictionary example:
Q_dict = dict_q.unsqueeze(0).unsqueeze(0)
K_dict = torch.vstack([dict_k1, dict_k2, dict_k3]).unsqueeze(0)
V_dict = torch.vstack([dict_v1, dict_v2, dict_v3]).unsqueeze(0)

attention(Q_dict, K_dict, V_dict)
tensor([[[5., 7.]]])
Looks good!
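As a final, optional check (not part of the original quiz or lecture example), the batched function can also handle several queries at once. Here we query the dictionary for the values stored under keys 0 and 2 in a single call; the expected output follows from the same hard-lookup behaviour as above:

# Two one-hot queries in one batch item: ask for the values under keys 0 and 2
Q_multi = torch.vstack([dict_k1, dict_k3]).unsqueeze(0)
attention(Q_multi, K_dict, V_dict)  # expected: tensor([[[2., 3.], [3., 2.]]])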