Attention
As you know from this week’s video lectures, attention is one of the cornerstones of the Transformer architecture and the current generation of large language models. At the same time, for many it is one of the most challenging technical concepts in this course. This became apparent in the quiz results, so I dedicated part of Monday’s on-campus session to discussing it in detail.
Example from the quiz
The quiz for lecture 2.3 featured the following question:
Consider the following values for the example on slide 6 ff.:
\(s = [0.4685, 0.9785]\), \(h_1 = [0.5539, 0.7239]\), \(h_2 = [0.4111, 0.3878]\), \(h_3 = [0.2376, 0.1264]\)
Assuming that the attention score is computed using the dot product, what is \(v\)?
Let us code up this example in PyTorch to get the answer. We first import torch and torch.nn.functional (as F), which we will use throughout:
import torch
import torch.nn.functional as F
# decoder hidden state (s)
quiz_s = torch.tensor([0.4685, 0.9785])
# encoder hidden states (h1, h2, h3)
quiz_h1 = torch.tensor([0.5539, 0.7239])
quiz_h2 = torch.tensor([0.4111, 0.3878])
quiz_h3 = torch.tensor([0.2376, 0.1264])
First, we compute the scores using the dot product:
quiz_score1 = (quiz_s * quiz_h1).sum(-1)
quiz_score2 = (quiz_s * quiz_h2).sum(-1)
quiz_score3 = (quiz_s * quiz_h3).sum(-1)
quiz_score1, quiz_score2, quiz_score3
(tensor(0.9678), tensor(0.5721), tensor(0.2350))
Next, we normalise the scores using the softmax function:
quiz_alphas = F.softmax(torch.stack([quiz_score1, quiz_score2, quiz_score3]), dim=-1)
quiz_alphas
tensor([0.4643, 0.3126, 0.2231])
Finally, we compute the alpha-weighted sum:
quiz_v = quiz_alphas[0] * quiz_h1 + quiz_alphas[1] * quiz_h2 + quiz_alphas[2] * quiz_h3
quiz_v
tensor([0.4387, 0.4855])
Thus, the correct answer to the quiz question was \(v = [0.4387, 0.4855]\).
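As a sanity check, the same computation can be written more compactly in matrix form. The names quiz_H and quiz_v_check below are ours, introduced only for this sketch:
# Stack the encoder hidden states into a 3x2 matrix
quiz_H = torch.stack([quiz_h1, quiz_h2, quiz_h3])
# Scores as one matrix-vector product, then softmax and the alpha-weighted sum
quiz_v_check = F.softmax(quiz_H @ quiz_s, dim=-1) @ quiz_H
quiz_v_check
This should again give \([0.4387, 0.4855]\) (up to rounding).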
Dictionary example
One way to understand the more general characterisation of attention as a query–key–value mechanism is to try using this mechanism to code up a dictionary data structure.
Simulating dictionaries
A standard dictionary is a mapping from keys to values. To retrieve a value from the dictionary, we query it with a key. Here is an example where we map integer keys to tensor-like lists of integers:
# Set up a dictionary with keys and values
dict_example = {0: [2, 3], 1: [5, 7], 2: [3, 2]}
# Query the dictionary
dict_example[1]
[5, 7]
To simulate this in PyTorch, we set up keys and queries as one-hot vectors. The values are tensor versions of the original lists.
# Set up the keys as one-hot vectors
dict_k1 = torch.tensor([1., 0., 0.])
dict_k2 = torch.tensor([0., 1., 0.])
dict_k3 = torch.tensor([0., 0., 1.])
# Set up the values (tensor versions of the original lists)
dict_v1 = torch.tensor([2., 3.])
dict_v2 = torch.tensor([5., 7.])
dict_v3 = torch.tensor([3., 2.])
# Set up the query (one-hot vector)
dict_q = torch.tensor([0., 1., 0.])
The idea now is to retrieve the value \([5, 7]\) from the dictionary using attention. Let us see how this works.
Using attention to retrieve from a dictionary
First, we compute the scores using the dot product:
dict_score1 = (dict_q * dict_k1).sum(-1)
dict_score2 = (dict_q * dict_k2).sum(-1)
dict_score3 = (dict_q * dict_k3).sum(-1)
dict_score1, dict_score2, dict_score3
(tensor(0.), tensor(1.), tensor(0.))
Next, we put these scores into a vector \(\alpha\):
dict_alphas = torch.stack([dict_score1, dict_score2, dict_score3])
dict_alphas
tensor([0., 1., 0.])
In the actual attention mechanism, we would have to send this vector through the softmax function in order to get normalised scores, i.e., scores that sum up to one. We return to this issue in a minute. For now, let us take the last step of the attention mechanism and compute the \(\alpha\)-weighted sum:
dict_v = dict_alphas[0] * dict_v1 + dict_alphas[1] * dict_v2 + dict_alphas[2] * dict_v3
dict_v
tensor([5., 7.])
And, hey presto, there is the value we wanted to query!
Putting the softmax back in
Now, back to the softmax. Using it directly does not really work:
dict_alphas = F.softmax(torch.stack([dict_score1, dict_score2, dict_score3]), dim=-1)
dict_v = dict_alphas[0] * dict_v1 + dict_alphas[1] * dict_v2 + dict_alphas[2] * dict_v3
dict_v
tensor([3.9403, 5.0925])
The problem is that the softmax does not give us a one-hot vector as in our initial attempt, but a smoothed-out distribution of weights. As a consequence, the retrieved value is a blend of all three values in the dictionary.
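For reference, we can inspect the smoothed weights and confirm that they sum to one:
dict_alphas, dict_alphas.sum()
The weights should be roughly \([0.2119, 0.5761, 0.2119]\), which indeed sums to one but is far from one-hot.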
However, we can simulate the desired behaviour by applying a low temperature to the softmax:
# Softmax with temperature 0.01
dict_alphas = F.softmax(torch.stack([dict_score1, dict_score2, dict_score3]) / 0.01, dim=-1)
dict_v = dict_alphas[0] * dict_v1 + dict_alphas[1] * dict_v2 + dict_alphas[2] * dict_v3
dict_v
tensor([5., 7.])
Summary
So, what we have seen here is that a standard dictionary data structure can be simulated by a restricted form of the attention mechanism in which the queries and keys are one-hot vectors. In such a mechanism, the score between queries and keys is either \(1\) (if query and key are identical) or \(0\) (if they are not).
In general attention, where the queries and keys are not necessarily one-hot vectors, the (normalised) scores can take any value between \(0\) and \(1\). As a consequence, the value retrieved with a query will be a blend of all values in the dictionary, weighted by their corresponding scores, as the sketch below illustrates.
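Here is a small sketch with a query of our own (soft_q, not part of the lecture example) that matches the first two keys equally well:
# A query that is halfway between the first and the second key
soft_q = torch.tensor([0.5, 0.5, 0.])
soft_score1 = (soft_q * dict_k1).sum(-1)
soft_score2 = (soft_q * dict_k2).sum(-1)
soft_score3 = (soft_q * dict_k3).sum(-1)
soft_alphas = F.softmax(torch.stack([soft_score1, soft_score2, soft_score3]), dim=-1)
soft_v = soft_alphas[0] * dict_v1 + soft_alphas[1] * dict_v2 + soft_alphas[2] * dict_v3
soft_v
The weights for the first two values are equal and larger than the weight for the third, so the result is close to the average of \([2, 3]\) and \([5, 7]\), with a smaller contribution from \([3, 2]\).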
Attention function
We can package the attention computations from the previous sections in the following function:
def attention(Q, K, V, temp=0.01):
    # shape of Q: [batch_size, num_of_queries, query_dim]
    # shape of K: [batch_size, num_of_keys, query_dim]
    # shape of V: [batch_size, num_of_keys, value_dim]
    # Compute the attention scores
    scores = Q @ K.transpose(-1, -2)
    # shape of scores: [batch_size, num_of_queries, num_of_keys]
    # Normalise the attention scores
    alphas = F.softmax(scores / temp, dim=-1)
    # shape of alphas: [batch_size, num_of_queries, num_of_keys]
    # The output is the alpha-weighted sum of the values
    result = alphas @ V
    # shape of result: [batch_size, num_of_queries, value_dim]
    return result
Example from the quiz
Let’s verify that this function actually does what it is supposed to do. We start with the quiz example. Because the attention function requires a batch dimension, we have to unsqueeze all tensors at dimension 0. The attention function also supports a whole sequence of queries per batch item, so our previous query needs an additional unsqueeze:
Q_quiz = quiz_s.unsqueeze(0).unsqueeze(0)
K_quiz = torch.vstack([quiz_h1, quiz_h2, quiz_h3]).unsqueeze(0)
V_quiz = torch.vstack([quiz_h1, quiz_h2, quiz_h3]).unsqueeze(0)
Q_quiz, K_quiz, V_quiz
attention(Q_quiz, K_quiz, V_quiz)
tensor([[[0.5539, 0.7239]]])
Note that this is not the quiz answer \(v = [0.4387, 0.4855]\) but simply \(h_1\): with the default temperature of 0.01, the softmax is almost one-hot, so the function essentially returns the value with the highest score.
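To reproduce the softly weighted quiz answer, we can instead call the function with a temperature of 1, which leaves the scores unsharpened:
attention(Q_quiz, K_quiz, V_quiz, temp=1.0)
This should again give \(v = [0.4387, 0.4855]\), matching the manual computation above.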
Dictionary example
Now for the dictionary example:
Q_dict = dict_q.unsqueeze(0).unsqueeze(0)
K_dict = torch.vstack([dict_k1, dict_k2, dict_k3]).unsqueeze(0)
V_dict = torch.vstack([dict_v1, dict_v2, dict_v3]).unsqueeze(0)
Q_dict, K_dict, V_dict
attention(Q_dict, K_dict, V_dict)
tensor([[[5., 7.]]])
Looks good!
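As a final sketch (our own extension), the function also handles several queries per batch item. Querying the dictionary with all three keys at once should retrieve the full value matrix:
# All three one-hot keys as queries, shape [1, 3, 3]
Q_all = torch.vstack([dict_k1, dict_k2, dict_k3]).unsqueeze(0)
attention(Q_all, K_dict, V_dict)
The expected result is the stacked values \([[2, 3], [5, 7], [3, 2]]\).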