DEV Community

sangjun_park


Thinking About Multi-Head Attention and Why We Need It

Transformer Model

Before explaining the concept of Multi-Head Attention, we first need to understand the Transformer model, the architecture in which Multi-Head Attention is used.

Researchers at Google developed the Transformer model and proposed it in the 2017 paper "Attention Is All You Need". The input text is first split into units called tokens, each mapped to an integer ID, and each token is then converted into a vector by looking it up in a word embedding table.
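
To make the lookup concrete, here is a minimal sketch in Python. The five-word vocabulary, the whitespace tokenizer, and the random embedding table are hypothetical stand-ins: real models use learned subword tokenizers and an embedding table whose values are learned during training.

```python
import random

# Hypothetical toy vocabulary; real models use learned subword tokenizers.
vocab = {"the": 0, "cat": 1, "sat": 2, "on": 3, "mat": 4}
d_model = 4  # embedding size (512 in the original Transformer)

# Embedding table: one row (vector) per token ID.
# Random here for illustration; in a real model these values are learned.
random.seed(0)
embedding_table = [[random.uniform(-1, 1) for _ in range(d_model)] for _ in vocab]

def tokenize(text):
    """Split on whitespace and map each lowercased word to its token ID."""
    return [vocab[word] for word in text.lower().split()]

def embed(token_ids):
    """Look up the embedding vector for each token ID."""
    return [embedding_table[i] for i in token_ids]

token_ids = tokenize("The cat sat on the mat")
vectors = embed(token_ids)
print(token_ids)                      # [0, 1, 2, 3, 0, 4]
print(len(vectors), len(vectors[0]))  # 6 4
```

Both occurrences of "the" map to the same ID and therefore to the same vector; the embedding alone carries no information about position or context.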

In NLP, a word embedding is a representation of a word. Typically, the representation is a real-valued vector that encodes the meaning of the word in such a way that words that are closer in the vector space are expected to be similar in meaning.

The goal of word embedding is to capture the semantic meaning of words.
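
As a rough illustration of "closer in the vector space", the sketch below measures closeness with cosine similarity, one common choice. The three-dimensional vectors are made up by hand for this example and are not taken from any real model.

```python
import math

# Hypothetical hand-picked 3-dimensional embeddings, for illustration only;
# real embeddings are learned and have hundreds of dimensions.
embeddings = {
    "cat": [0.9, 0.8, 0.1],
    "dog": [0.8, 0.9, 0.2],
    "car": [0.1, 0.2, 0.9],
}

def cosine_similarity(a, b):
    """Cosine of the angle between vectors a and b (1.0 = same direction)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

print(cosine_similarity(embeddings["cat"], embeddings["dog"]))
print(cosine_similarity(embeddings["cat"], embeddings["car"]))
```

With these toy values, "cat" and "dog" point in nearly the same direction, while "cat" and "car" do not.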

Before Transformer models, the common NLP architectures were recurrent networks such as RNNs, LSTMs, and GRUs. However, these models had several limitations:

  1. Sequential Processing: RNNs process tokens in sequence, which makes them slow and difficult to parallelize.

  2. Long-Range Dependencies: RNNs struggle to capture dependencies between distant tokens due to the vanishing gradient problem.

  3. Fixed-Size Context: RNNs compress everything they have read so far into a fixed-size hidden state, which limits how much information from long or variable-length contexts they can retain.
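
The first two limitations can be seen in a toy recurrent network. This is a minimal sketch, assuming a one-dimensional hidden state and hand-picked weights `w_x` and `w_h` (hypothetical values): each step needs the previous hidden state, so the loop cannot run in parallel across positions, and the contribution of the first input shrinks at every step. The shrinking signal is only an intuition for, not a demonstration of, the vanishing gradient problem.

```python
import math

def rnn_hidden_states(inputs, w_x=0.5, w_h=0.9):
    """Toy scalar RNN: h_t = tanh(w_x * x_t + w_h * h_{t-1}).

    w_x and w_h are hypothetical fixed weights for illustration.
    Each step depends on the previous hidden state h, so the steps
    must run one after another, not in parallel.
    """
    h = 0.0
    states = []
    for x in inputs:
        h = math.tanh(w_x * x + w_h * h)
        states.append(h)
    return states

# Only the first token carries a signal; watch it fade step by step.
states = rnn_hidden_states([1.0, 0.0, 0.0, 0.0, 0.0])
print([round(s, 3) for s in states])
```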

The Transformer model addresses these limitations with a self-attention mechanism and a fully parallelizable architecture.

Self-Attention mechanism

Concept

  • Attention: In the context of neural networks, it refers to the ability to focus on specific parts of the input sequence when producing each element of the output. It allows the model to weigh the importance of the different words (or tokens) in a sequence relative to each other. In the context of sequence-to-sequence models, attention typically involves focusing on different parts of the encoder's output when generating each word in the decoder's output.

  • Self-Attention: a specific type of attention in which each word in the sequence attends to every other word, including itself. This helps the model understand the relationships between words in the context of the entire sequence.

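The key difference between attention and self-attention lies in where the elements come from. In encoder-decoder attention, words in the decoder attend to the encoder's output, which is a different sequence; in self-attention, every word attends to the words of its own sequence. For example, in the sentence "The cat sat on the mat", self-attention allows the word "cat" to focus on "sat" and "mat" to understand its context better, rather than considering only adjacent words.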

For each word in the input sequence, the self-attention mechanism computes three vectors:

  • Query (Q): Represents the word we are currently focusing on. In the self-attention mechanism, a Query vector is created for each word of the input sequence.
  • Key (K): Represents the words we are comparing against. Each Key is compared with the Query vector to calculate how relevant its word is to the word being focused on (the Query).
  • Value (V): Contains the actual information of each word. The Value vectors are weighted by the attention weights and summed to produce the final output.

These vectors are not the input word embeddings themselves; instead, they are created by multiplying the embeddings by learnable weight matrices.

Specifically, we multiply the input embedding matrix X by each weight matrix (W_Q, W_K, W_V) to obtain Q = X W_Q, K = X W_K, and V = X W_V.
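
In code, the projections are three matrix multiplications. This minimal sketch uses plain Python lists and a small `matmul` helper; the 3×4 input `X` and the 4×2 weight matrices are toy values chosen for illustration, whereas in a trained model `W_Q`, `W_K`, and `W_V` are learned.

```python
def matmul(a, b):
    """Multiply an (n x m) matrix by an (m x p) matrix, both stored as lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

# X: 3 tokens, each a 4-dimensional embedding (toy values).
X = [
    [1.0, 0.0, 1.0, 0.0],
    [0.0, 2.0, 0.0, 2.0],
    [1.0, 1.0, 1.0, 1.0],
]

# Hypothetical 4 x 2 weight matrices; learned by gradient descent in a real model.
W_Q = [[1, 0], [0, 1], [1, 0], [0, 1]]
W_K = [[0, 1], [1, 0], [0, 1], [1, 0]]
W_V = [[1, 1], [0, 0], [1, 1], [0, 0]]

Q = matmul(X, W_Q)  # one Query row per token
K = matmul(X, W_K)  # one Key row per token
V = matmul(X, W_V)  # one Value row per token
print(Q)  # [[2.0, 0.0], [0.0, 4.0], [2.0, 2.0]]
```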

Next, we calculate the attention scores from the Query and Key vectors.

The attention score for each pair of words is calculated by taking the dot product of the Query vector for one word with the Key vector of another word.

$$\text{score}_{ij} = Q_i \cdot K_j$$

  • Q_i: the Query vector for the i-th word.
  • K_j: the Key vector for the j-th word.
  • dot product: measures the similarity between the Query and Key vectors, determining how strongly one word is related to another.
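
Computing every pairwise score is a loop over queries and keys. Here is a minimal sketch with toy two-dimensional vectors (hypothetical values); row i of `scores` holds word i's scores against every word j:

```python
def dot(u, v):
    """Dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))

def attention_scores(Q, K):
    """scores[i][j] = Q_i . K_j: how strongly word i's query matches word j's key."""
    return [[dot(q, k) for k in K] for q in Q]

# Toy 2-dimensional queries and keys for a 3-word sequence.
Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
scores = attention_scores(Q, K)
print(scores)  # [[1.0, 0.0, 0.5], [0.0, 1.0, 0.5], [1.0, 1.0, 1.0]]
```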

To prevent the dot products from growing too large, which would push the softmax function into regions with extremely small gradients, the scores are divided by the square root of the dimension of the Key vectors, d_k.

$$\text{scaled score}_{ij} = \frac{Q_i \cdot K_j}{\sqrt{d_k}}$$
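
Why the square root of d_k? Assuming the components of Q_i and K_j are independent with mean 0 and variance 1 (the assumption behind the argument in the original paper), their dot product has variance d_k, so wider heads produce larger raw scores. The sketch below checks this empirically with random vectors and shows that dividing by sqrt(d_k), which divides the variance by d_k, brings it back to about 1.

```python
import random
import statistics

def dot_product_variance(d_k, samples=5000, seed=0):
    """Empirical variance of q . k when q and k have independent N(0, 1) entries."""
    rng = random.Random(seed)
    dots = []
    for _ in range(samples):
        q = [rng.gauss(0, 1) for _ in range(d_k)]
        k = [rng.gauss(0, 1) for _ in range(d_k)]
        dots.append(sum(a * b for a, b in zip(q, k)))
    return statistics.pvariance(dots)

for d_k in (4, 64):
    raw = dot_product_variance(d_k)
    # Dividing each score by sqrt(d_k) divides its variance by d_k.
    scaled = raw / d_k
    print(d_k, round(raw, 1), round(scaled, 2))
```

The raw variance grows roughly in proportion to d_k, while the scaled variance stays near 1 for both sizes.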

Finally, the scaled attention scores are passed through a softmax function, which converts each word's scores into a probability distribution: the weights are positive and sum to 1. Softmax normalizes the scores so that they can be interpreted as attention weights.

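$$\alpha_{ij} = \frac{\exp\left(Q_i \cdot K_j / \sqrt{d_k}\right)}{\sum_{m=1}^{n} \exp\left(Q_i \cdot K_m / \sqrt{d_k}\right)}$$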
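
Here is a minimal softmax sketch for one word's row of scaled scores (toy values). Subtracting the maximum score first is a standard numerical-stability trick rather than part of the definition; it leaves the result unchanged because softmax is unaffected by adding the same constant to every score.

```python
import math

def softmax(scores):
    """Turn a list of scores into positive weights that sum to 1."""
    m = max(scores)  # subtract the max to avoid overflow in exp
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# Scaled scores of one query against three keys (toy values).
weights = softmax([2.0, 1.0, 0.1])
print([round(w, 3) for w in weights])  # [0.659, 0.242, 0.099]
```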

Generate Output

After calculating the attention weights, the next step is to generate the output of the self-attention mechanism: a weighted sum of the Value vectors.

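$$\text{output}_i = \sum_{j=1}^{n} \alpha_{ij} V_j, \qquad \alpha_{ij} = \mathrm{softmax}_j\!\left(\frac{Q_i \cdot K_j}{\sqrt{d_k}}\right)$$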

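The summation runs over j from 1 to n, where n is the number of words: each Value vector V_j is multiplied by the attention weight α_ij (how much word i attends to word j), and these weighted vectors are added together to form the output for word i.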
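
Putting the steps together, here is a minimal single-head sketch of scaled dot-product self-attention over plain Python lists. It assumes Q, K, and V have already been produced by the projections, and it omits details such as masking and batching.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def self_attention(Q, K, V):
    """Scaled dot-product attention for lists of row vectors.

    output_i = sum_j alpha_ij * V_j, with alpha_i = softmax_j(Q_i . K_j / sqrt(d_k)).
    """
    d_k = len(K[0])
    outputs = []
    for q in Q:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in K]
        alpha = softmax(scores)
        # Weighted sum of the Value vectors, one output dimension at a time.
        out = [sum(alpha[j] * V[j][c] for j in range(len(V))) for c in range(len(V[0]))]
        outputs.append(out)
    return outputs

Q = [[1.0, 0.0], [0.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0]]
V = [[10.0, 0.0], [0.0, 10.0]]
out = self_attention(Q, K, V)
print(out)
```

Because each row of attention weights sums to 1 and each toy Value vector here sums to 10, every output row also sums to 10; the first word draws more from the first Value vector because its query matches the first key best.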

Multi-Head Attention Extension

In Multi-Head Attention, each head performs its own self-attention calculation independently, using its own learned projection matrices. The outputs of all heads are then concatenated and linearly transformed to produce the final output.

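The formula for Multi-Head Attention is:

$$\text{MultiHead}(X) = \text{Concat}(\text{head}_1, \dots, \text{head}_h)\, W_O, \qquad \text{head}_i = \text{Attention}\left(X W_Q^{(i)},\; X W_K^{(i)},\; X W_V^{(i)}\right)$$

where each head i has its own projection matrices W_Q^(i), W_K^(i), W_V^(i), and W_O is a learned output projection matrix.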

Let's imagine that there are h heads. Each head produces an output, which we can denote as:

  • head_1, head_2, ..., head_h

Each of these outputs (from head_1 to head_h) is a vector of the same dimension, d_k. Concatenation takes these h vectors and joins them end to end into a single, longer vector of length h × d_k. Each head uses dimension d_k because d_k is obtained by dividing d_model by h.

Here d_model is the dimensionality of the input vectors to the model. For example, if each word in the input sequence is represented by a 512-dimensional vector, then d_model = 512.

h is the number of attention heads. For example, with 8 heads, h = 8. The dimension of each head is then 512 / 8 = 64, and we call this d_k. Concatenating the 8 head outputs gives 8 × 64 = 512 values per word, so the result has the same size as the input, d_model.
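
The dimension arithmetic can be checked directly. A minimal sketch using the numbers above; the head outputs are placeholder vectors, not real attention results:

```python
d_model = 512  # size of each input/output vector
h = 8          # number of attention heads
assert d_model % h == 0, "d_model must be divisible by h"
d_k = d_model // h
print(d_k)  # 64

# Placeholder head outputs: head i is a vector of d_k copies of float(i).
head_outputs = [[float(i)] * d_k for i in range(h)]
# Concatenating h vectors of size d_k restores a vector of size d_model.
concat = [x for head in head_outputs for x in head]
print(len(concat))  # 512
```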

Each head independently processes the input sequence through its own linear projections and the attention mechanism.

[Figure: the outputs of the heads are concatenated and then passed through a final linear transformation.]

This figure illustrates the concatenation and the linear transformation well.

After the final linear transformation, the output has dimension d_model again and is ready for the subsequent layers of the Transformer model.
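
The whole multi-head pipeline can be sketched end to end in plain Python. This is an illustration under stated assumptions, not a reference implementation: the weights are random numbers standing in for learned parameters (`random_matrix` is a hypothetical helper), there are no bias terms, masking, or batching, and d_model = 16 with h = 4 is chosen only to keep the example small.

```python
import math
import random

def matmul(a, b):
    """(n x m) @ (m x p) for matrices stored as lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attention(Q, K, V):
    """Scaled dot-product attention for a single head."""
    d_k = len(K[0])
    out = []
    for q in Q:
        alpha = softmax([sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in K])
        out.append([sum(alpha[j] * V[j][c] for j in range(len(V))) for c in range(len(V[0]))])
    return out

def random_matrix(rows, cols, rng):
    """Hypothetical helper: random weights standing in for learned parameters."""
    return [[rng.gauss(0, 0.1) for _ in range(cols)] for _ in range(rows)]

def multi_head_attention(X, h, rng):
    d_model = len(X[0])
    d_k = d_model // h
    heads = []
    for _ in range(h):
        # Each head has its own projection matrices W_Q, W_K, W_V (d_model x d_k).
        W_Q, W_K, W_V = (random_matrix(d_model, d_k, rng) for _ in range(3))
        heads.append(attention(matmul(X, W_Q), matmul(X, W_K), matmul(X, W_V)))
    # Concatenate the head outputs per token: n x (h * d_k) = n x d_model.
    concat = [[x for head in heads for x in head[i]] for i in range(len(X))]
    # Output projection W_O maps the concatenation back to d_model.
    W_O = random_matrix(h * d_k, d_model, rng)
    return matmul(concat, W_O)

rng = random.Random(0)
X = random_matrix(5, 16, rng)  # 5 tokens, d_model = 16 (toy input)
out = multi_head_attention(X, h=4, rng=rng)
print(len(out), len(out[0]))  # 5 16
```

Whatever the number of heads, the result has one d_model-sized vector per token, which is what allows it to be passed to the next layer.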

Why do we need Multi-Head Attention?

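The primary reason for using Multi-Head Attention is that each head can learn to focus on a different kind of relationship between words. Because every head has its own projections, the heads attend to information from different representation subspaces in parallel; with a single head, these different patterns would be averaged into one set of attention weights. As a result, Multi-Head Attention can capture more complex patterns, tends to improve model performance, and provides richer, more context-aware representations of the input data.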
