Attention

If you could use some additional explanation of attention, this post is here to help.
Author: Marco Kuhlmann

Published: February 8, 2024

As you know from this week’s video lectures, attention is one of the cornerstones of the Transformer architecture and the current generation of large language models. At the same time, many find it one of the most challenging technical concepts in this course. This became apparent in the quiz results, so I dedicated part of Monday’s on-campus session to discussing it in detail.

Example from the quiz

Quiz Q3.3 featured the following question:

Consider the following values for the example in slide 6 ff.:

\(s = [0.4685, 0.9785]\), \(h_1 = [0.5539, 0.7239]\), \(h_2 = [0.4111, 0.3878]\), \(h_3 = [0.2376, 0.1264]\)

Assuming that the attention score is computed using the dot product, what is \(v\)?

Let us code up this example in PyTorch to get the answer:

import torch
import torch.nn.functional as F

# decoder hidden state (s)
quiz_s = torch.tensor([0.4685, 0.9785])

# encoder hidden states (h1, h2, h3)
quiz_h1 = torch.tensor([0.5539, 0.7239])
quiz_h2 = torch.tensor([0.4111, 0.3878])
quiz_h3 = torch.tensor([0.2376, 0.1264])

First, we compute the scores using the dot product:

quiz_score1 = (quiz_s * quiz_h1).sum(-1)
quiz_score2 = (quiz_s * quiz_h2).sum(-1)
quiz_score3 = (quiz_s * quiz_h3).sum(-1)

quiz_score1, quiz_score2, quiz_score3
(tensor(0.9678), tensor(0.5721), tensor(0.2350))

Next, we normalise the scores using the softmax function:

quiz_alphas = F.softmax(torch.tensor([quiz_score1, quiz_score2, quiz_score3]), dim=-1)

quiz_alphas
tensor([0.4643, 0.3126, 0.2231])

Finally, we compute the alpha-weighted sum:

quiz_v = quiz_alphas[0] * quiz_h1 + quiz_alphas[1] * quiz_h2 + quiz_alphas[2] * quiz_h3

quiz_v
tensor([0.4387, 0.4855])

Thus, the correct answer to the quiz question was \(v = [0.4387, 0.4855]\).
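The same three steps (scores, softmax, weighted sum) can also be written in vectorised form, which is closer to how attention is usually implemented. The following is a minimal sketch, assuming we stack the encoder hidden states into one matrix; the names `s`, `H`, `scores`, `alphas`, and `v` are introduced here for illustration only.

```python
import torch
import torch.nn.functional as F

# Decoder hidden state and stacked encoder hidden states (values from the quiz)
s = torch.tensor([0.4685, 0.9785])
H = torch.tensor([[0.5539, 0.7239],
                  [0.4111, 0.3878],
                  [0.2376, 0.1264]])

scores = H @ s                      # one dot product per encoder state, shape [3]
alphas = F.softmax(scores, dim=-1)  # normalised weights that sum to one, shape [3]
v = alphas @ H                      # alpha-weighted sum of the rows of H, shape [2]

print(v)
```

Up to rounding, this should give the same \(v\) as the step-by-step computation above.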

Dictionary example

One way to understand the more general characterisation of attention as a query–key–value mechanism is to try using this mechanism to code up a dictionary data structure.

Simulating dictionaries

A standard dictionary is a mapping from keys to values. To retrieve a value from the dictionary, we query it with a key. Here is an example where we map integer keys to tensor-like lists of integers:

# Set up a dictionary with keys and values
dict_example = {0: [2, 3], 1: [5, 7], 2: [3, 2]}

# Query the dictionary
dict_example[1]
[5, 7]

To simulate this in PyTorch, we set up keys and queries as one-hot vectors. The values are tensor versions of the original lists.

# Set up the keys as one-hot vectors
dict_k1 = torch.tensor([1., 0., 0.])
dict_k2 = torch.tensor([0., 1., 0.])
dict_k3 = torch.tensor([0., 0., 1.])

# Set up the values (tensor versions of the original lists)
dict_v1 = torch.tensor([2., 3.])
dict_v2 = torch.tensor([5., 7.])
dict_v3 = torch.tensor([3., 2.])

# Set up the query (one-hot vector)
dict_q = torch.tensor([0., 1., 0.])

The idea now is to retrieve the value \([5, 7]\) from the dictionary using attention. Let us see how this works.

Using attention to retrieve from a dictionary

First, we compute the scores using the dot product:

dict_score1 = (dict_q * dict_k1).sum(-1)
dict_score2 = (dict_q * dict_k2).sum(-1)
dict_score3 = (dict_q * dict_k3).sum(-1)

dict_score1, dict_score2, dict_score3
(tensor(0.), tensor(1.), tensor(0.))

Next, we put these scores into a vector \(\alpha\):

dict_alphas = torch.tensor([dict_score1, dict_score2, dict_score3])

dict_alphas
tensor([0., 1., 0.])

In the actual attention mechanism, we would have to send this vector through the softmax function in order to get normalised scores, i.e., scores that sum up to one. We return to this issue in a minute. For now, let us take the last step of the attention mechanism and compute the \(\alpha\)-weighted sum:

dict_v = dict_alphas[0] * dict_v1 + dict_alphas[1] * dict_v2 + dict_alphas[2] * dict_v3

dict_v
tensor([5., 7.])

And, hey presto, there is the value we wanted to query!

Putting the softmax back in

Now, back to the softmax. Using it directly does not really work:

dict_alphas = F.softmax(torch.tensor([dict_score1, dict_score2, dict_score3]), dim=-1)

dict_v = dict_alphas[0] * dict_v1 + dict_alphas[1] * dict_v2 + dict_alphas[2] * dict_v3

dict_v
tensor([3.9403, 5.0925])

The problem is that the softmax does not give us a one-hot vector as in our initial attempt, but a smoothed-out distribution of weights. As a consequence, the retrieved value is a blend of all three values in the dictionary.

However, we can simulate the desired behaviour by applying a low temperature to the softmax:

# Softmax with temperature 0.01
dict_alphas = F.softmax(torch.tensor([dict_score1, dict_score2, dict_score3]) / 0.01, dim=-1)

dict_v = dict_alphas[0] * dict_v1 + dict_alphas[1] * dict_v2 + dict_alphas[2] * dict_v3

dict_v
tensor([5., 7.])
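To see how the temperature moves the result from a blend towards an exact lookup, we can repeat the computation for a few temperatures. This is a small sketch under the assumptions of the dictionary example; the intermediate temperature 0.1 and the names `scores`, `values`, and `results` are chosen here for illustration.

```python
import torch
import torch.nn.functional as F

# Dot-product scores and values from the dictionary example
scores = torch.tensor([0., 1., 0.])
values = torch.tensor([[2., 3.],
                       [5., 7.],
                       [3., 2.]])

results = {}
for temp in [1.0, 0.1, 0.01]:
    # Dividing by a small temperature exaggerates the differences between scores
    alphas = F.softmax(scores / temp, dim=-1)
    retrieved = alphas @ values
    results[temp] = (alphas, retrieved)
    print(temp, alphas, retrieved)
```

At temperature 1.0 this reproduces the blended result from above. As the temperature drops, the weight on the matching key approaches one and the retrieved value approaches \([5, 7]\).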

Summary

So, what we have seen here is that a standard dictionary data structure can be simulated by a restricted form of the attention mechanism in which the queries and keys are one-hot vectors. In such a mechanism, the score between queries and keys is either \(1\) (if query and key are identical) or \(0\) (if they are not).

In general attention, where the queries and keys are not necessarily one-hot vectors, the (normalised) score can be anything in between \(1\) and \(0\). As a consequence, the value retrieved with a query will be a blend of all values in the dictionary, weighted with their corresponding scores.
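As an illustration of such a blend, here is a minimal sketch that reuses the keys and values of the dictionary example. It assumes a hypothetical query that is not one-hot but lies halfway between the first two keys; the query and the variable names are not part of the original example.

```python
import torch
import torch.nn.functional as F

keys = torch.eye(3)                      # the three one-hot keys as rows
values = torch.tensor([[2., 3.],
                       [5., 7.],
                       [3., 2.]])

# Hypothetical query: matches the first and second key equally well
query = torch.tensor([0.5, 0.5, 0.])

scores = keys @ query                    # dot product of the query with each key
alphas = F.softmax(scores / 0.01, dim=-1)  # same low temperature as before
blend = alphas @ values                  # alpha-weighted sum of the values

print(alphas, blend)
```

Even with the low temperature, the weights come out as approximately \([0.5, 0.5, 0]\), so the retrieved value is approximately \([3.5, 5.0]\), the average of the first two values, rather than a single dictionary entry.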

Attention function

We can package the attention computations from the previous sections in the following function:

def attention(Q, K, V, temp=0.01):
    # shape of Q: [batch_size, num_of_queries, query_dim]
    # shape of K: [batch_size, num_of_keys, query_dim]
    # shape of V: [batch_size, num_of_keys, value_dim]

    # Compute the attention scores
    scores = Q @ K.transpose(-1, -2)
    # shape of scores: [batch_size, num_of_queries, num_of_keys]

    # Normalise the attention scores
    alphas = F.softmax(scores / temp, dim=-1)
    # shape of alphas: [batch_size, num_of_queries, num_of_keys]

    # The output is the alpha-weighted sum of the values
    result = alphas @ V
    # shape of result: [batch_size, num_of_queries, value_dim]

    return result

Example from the quiz

Let’s verify that this function actually does what it is supposed to do. We start with the quiz example. Because the attention function requires a batch dimension, we have to unsqueeze all tensors at dimension 0. The attention function also supports a whole sequence of queries per batch item, so our previous query needs an additional unsqueeze:

Q_quiz = quiz_s.unsqueeze(0).unsqueeze(0)
K_quiz = torch.vstack([quiz_h1, quiz_h2, quiz_h3]).unsqueeze(0)
V_quiz = torch.vstack([quiz_h1, quiz_h2, quiz_h3]).unsqueeze(0)

Q_quiz, K_quiz, V_quiz

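attention(Q_quiz, K_quiz, V_quiz, temp=1.0)
tensor([[[0.4387, 0.4855]]])

This matches the answer we computed by hand. Note that we have to pass temp=1.0 here: with the default temperature of 0.01, the softmax becomes almost one-hot, and the function would simply return the encoder state with the highest score, \(h_1 = [0.5539, 0.7239]\).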

Dictionary example

Now for the dictionary example:

Q_dict = dict_q.unsqueeze(0).unsqueeze(0)
K_dict = torch.vstack([dict_k1, dict_k2, dict_k3]).unsqueeze(0)
V_dict = torch.vstack([dict_v1, dict_v2, dict_v3]).unsqueeze(0)

Q_dict, K_dict, V_dict

attention(Q_dict, K_dict, V_dict)
tensor([[[5., 7.]]])

Looks good!
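Because the function accepts a whole sequence of queries per batch item, several lookups can be done in one call. The following sketch repeats a compact version of the attention function so that it runs on its own; the two queries and the names `K`, `V`, `Q`, and `out` are assumptions made for this illustration.

```python
import torch
import torch.nn.functional as F

def attention(Q, K, V, temp=0.01):
    # Same computation as the function above, repeated so the sketch is self-contained
    alphas = F.softmax(Q @ K.transpose(-1, -2) / temp, dim=-1)
    return alphas @ V

# Keys and values of the dictionary example, with a batch dimension of size 1
K = torch.eye(3).unsqueeze(0)                       # shape [1, 3, 3]
V = torch.tensor([[2., 3.],
                  [5., 7.],
                  [3., 2.]]).unsqueeze(0)           # shape [1, 3, 2]

# Two one-hot queries: first look up the third key, then the first key
Q = torch.tensor([[0., 0., 1.],
                  [1., 0., 0.]]).unsqueeze(0)       # shape [1, 2, 3]

out = attention(Q, K, V)
print(out.shape)  # torch.Size([1, 2, 2])
print(out)
```

Each row of the output corresponds to one query, so the first row should be (approximately) \([3, 2]\) and the second \([2, 3]\).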