As Transformers, and especially their attention mechanism, are a regular part of my work, I've often wondered how such an innovative architecture was created.
If you're looking for an implementation, I highly recommend checking out FlashAttention [https://github.com/Dao-AILab/flash-attention]. It's my go-to, and far faster than anything we could whip up here using plain PyTorch or TensorFlow.
In this article, I want to focus on the thought process that might have led to the attention mechanism. Here's what we'll be exploring:
- Quick recap of the attention formulation
- Core idea and a step-by-step construction of the weighting mechanism
- Technical details
- Some possible extensions
I hope this walkthrough sparks some ideas for your own modifications!
Original Formulation
Transformers are the go-to architectures for sequences. The paper "Attention is All You Need" from 2017 became a milestone in neural network architectures. It might seem like a stroke of genius, but tracing the history of ideas can give us insight into how to arrive at such architectures ourselves.
This isn't a typical Transformer tutorial; I assume you have some familiarity with them. Instead, we'll follow the thought chain that likely led A. Vaswani et al. to their idea, and see how the final attention mechanism emerges:
Discoveries usually happen in the "adjacent possible" – the space of ideas made reachable by prior discoveries. Often, a single minor step creates a tipping point for impressive results.
The idea of attention in ML existed before the Transformer. For example, "Effective Approaches to Attention-based Neural Machine Translation" (published in 2015) used attention mechanisms to reweight hidden layers in sequence models. Squeeze-and-Excitation networks, which reweighted selected CNN hidden layers, were likewise considered attention-based architectures.
The core idea was to boost features more relevant to the output. Mathematically, this re-weighting can be expressed as:

$$L_i' = w_i \cdot L_i$$

where $i$ indexes layers $L_i$, and $w_i$ are the attention weights. It's also common to normalize the weights:

$$\sum_i w_i = 1$$
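The re-weighting above can be sketched in a few lines of NumPy. This is a toy illustration, not any particular architecture; the layer values and weights are made up:

```python
import numpy as np

# Toy example: re-weight three "layers" (rows) by attention weights.
layers = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # L_i, one row per layer
raw_w = np.array([0.2, 0.5, 0.3])                        # unnormalized w_i
w = raw_w / raw_w.sum()                                  # normalize: sum_i w_i = 1
reweighted = w[:, None] * layers                         # L_i' = w_i * L_i
```

The normalization keeps the overall magnitude of the combined features stable regardless of how many layers contribute.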
Application to Sequences
The Transformer authors were dealing with sequences of embedded text tokens $t_1, \dots, t_n$ (think of each token as a vector in $\mathbb{R}^d$ that represents some embedded information).
Before Transformers, machine learning models often had fixed input and output sizes. To process full sequences, we either fed tokens one by one, i.e.,

$$NN(t_i) = t_i'$$

or used recurrent architectures like RNNs or LSTMs, where a hidden representation tries to encode the prior sequence.
These architectures were notoriously hard to train, and performance wasn't optimal (although similar ideas re-emerged recently in architectures like RetNet, which claims to be a potential Transformer successor).
The authors sought a way to create an expressive sequence-to-sequence architecture that could handle arbitrary-length input tokens. The goal was for each output token to have knowledge of every prior token in the input sequence:

$$t_j' = NN(t_1, \dots, t_j)$$
A straightforward idea is to reweight tokens, the essence of attention:

$$t_j' = \sum_i w_{ij}\, t_i$$
However, this limits us to linear combinations of previous tokens. Even with non-linearities like ReLU, we restrict ourselves to linear half-planes, creating a bottleneck.
A simple fix is to introduce a linear layer $V: \mathbb{R}^d \to \mathbb{R}^d$ for more flexibility:

$$t_j' = \sum_i w_{ij}\, V(t_i)$$
The key question now is how to learn the weights $w_{ij}$. A naive approach, outputting them from a fixed layer $W: \mathbb{R}^{n \cdot d} \to \mathbb{R}^{n^2}$, fails for two reasons:
- Scalability: The weight count of this layer is $d \cdot n^3$ plus biases, which doesn't scale well
- Input Length Dependence: We'd reintroduce the dependence on $n$ , our original limitation in previous models
An alternative might be pairwise weighting, but this would require the layer to be invariant to input order, a difficult property to achieve with networks.
The solution (which might take some time to arrive at) is to split the weighting terms by indices:

$$w_{ij}' = Q(t_i) \cdot K(t_j)$$
These are the query and key terms. This solves our sequence length problem – now, our two layers $Q,K:\mathbb{R}^d \to \mathbb{R}$ only need to output a single value each.
Normalizing the weights with a softmax gives us a special case of the attention layer:

$$t_j' = \sum_i w_{ij}\, V(t_i)$$

where $w_{ij}' = Q(t_i) \cdot K(t_j)$ are normalized over index $i$ using the softmax function: $w_{ij} = \mathrm{softmax}_i(w_{1j}', \dots, w_{nj}')$.
Softmax is common in ML for turning raw scores into something resembling a probability distribution: non-negative values that sum to one.
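Putting the pieces so far together, here is a sketch of this scalar-score special case, where $Q$ and $K$ each map a token to a single value (random vectors stand in for the learned layers):

```python
import numpy as np

def softmax(x, axis=0):
    e = np.exp(x - x.max(axis=axis, keepdims=True))  # subtract max for stability
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
n, d = 5, 8
t = rng.normal(size=(n, d))        # tokens as rows
q = rng.normal(size=d)             # Q: R^d -> R, as a single "learned" vector
k = rng.normal(size=d)             # K: R^d -> R
Vm = rng.normal(size=(d, d))       # value projection

w_raw = np.outer(t @ q, t @ k)     # w'_ij = Q(t_i) * K(t_j)
w = softmax(w_raw, axis=0)         # normalize over index i
out = w.T @ (t @ Vm)               # row j: t_j' = sum_i w_ij * V(t_i)
```

Note that $Q$ and $K$ together cost only $2d$ parameters, independent of the sequence length $n$.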
Technical Details
We essentially learn three things for each token:
- Value: Embeds the token into a new linear space
- Query: Learns what information a token is looking for
- Key: Learns what information a token offers to others
Mapping to a single dimension often creates a performance bottleneck. Therefore, it makes sense to make queries and keys multidimensional, $Q, K: \mathbb{R}^d \to \mathbb{R}^{d_k}$. The simple multiplication now becomes a dot product:

$$w_{ij}' = Q(t_i) \cdot K(t_j)$$
This extension provides an intuitive way to measure alignment between vectors. If two vectors align, their dot product is maximized compared to any other vector with the same norm. Essentially, the query now offers multiple directions for relevance and allows 'choosing' the contribution in each dimension. The key selects the direction in which a token wants to aggregate information for the next layer of tokens.
Now we can assemble the original formulation using queries, keys, and values written as matrices:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(QK^T\right) V$$

All pairwise dot products are calculated in one step as $QK^T$.
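This matrix form translates almost line for line into NumPy. A minimal sketch (softmax taken over the key axis, per query row, as in the standard formulation):

```python
import numpy as np

def attention(Q, K, V):
    """softmax(Q K^T) V -- without the 1/sqrt(d_k) scaling yet."""
    scores = Q @ K.T                                   # (n, n) pairwise dot products
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)     # softmax over keys
    return weights @ V                                 # weighted sum of values

rng = np.random.default_rng(0)
n, d_k, d_v = 4, 16, 8
Q, K = rng.normal(size=(n, d_k)), rng.normal(size=(n, d_k))
V = rng.normal(size=(n, d_v))
out = attention(Q, K, V)                               # shape (n, d_v)
```

Each output row is a convex combination of the rows of `V`; with all-zero queries the weights become uniform and every output row is simply the mean of the values.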
You might notice the missing square root term. This was likely an experimental discovery. A useful rule of thumb is:
Neural networks perform best for inputs and outputs with a value range between -1 and 1.
The authors realized that if queries and keys have a standard normal distribution, their dot product has a variance of $d_k$.
Normalizing by the square root of the variance returns the distribution to a standard normal one:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right) V$$
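The variance claim is easy to check empirically. A quick sketch (the choice of $d_k = 64$ and the sample count are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)
d_k = 64
q = rng.standard_normal((100_000, d_k))   # standard-normal "queries"
k = rng.standard_normal((100_000, d_k))   # standard-normal "keys"

scores = (q * k).sum(axis=1)              # 100k sample dot products
scaled = scores / np.sqrt(d_k)            # the Transformer's scaling

# scores.var() lands near d_k; scaled.var() lands near 1.
```

Without the scaling, large dot products push the softmax into a saturated regime with near-one-hot weights and vanishing gradients, which is exactly the rule of thumb above in action.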
Possible Extensions
While the dot product is a natural similarity measure, if your data has spatial information, you might consider rotary attention (as used in Act3D: [https://arxiv.org/abs/2306.17817]), which incorporates spatial distances within a 3D point cloud.
Even non-standard metrics like Manhattan or Mahalanobis distances could be used for attention relationships.
Hopefully, this inspires you to explore your own extensions! Remember, complexity should always have a reason.
Let me know in the comments if you'd like a deeper dive into these topics :)