How the Transformer Architecture Has Changed

More than 7 years have passed since the original transformer paper was released, and this architecture has turned the whole of DL upside down: having started in NLP, it is now used everywhere, including image generation. But is it still the same architecture? In this article I want to give a brief overview of the main changes used in current models like Mistral, Llama, and so on.

Positional Embeddings (PE)

  • Basic approach: to each token's input vector we add an absolute position vector; it can be trainable, or it can be some fixed function of the position.

  • Relative PE: at the attention stage, when we compute <q_i, k_j>, we additionally mix in an embedding of the difference i-j. The advantage of this approach is that it generalizes easily to sequence lengths that were not seen during training.

  • RoPE – the trendiest approach right now. At the attention stage we rotate the vectors q and k depending on the token's position. Roughly, if the position is t, we rotate by the angle t*alpha. What's cool is that the position is encoded by a rotation, with fewer computations than Relative PE, while relative information is preserved: if we add text before a pair of words but the number of words between them does not change, both vectors get rotated by the same extra angle, so the angle between them is preserved, which means the dot product (the thing attention cares about) does not change! In reality it's a bit more involved: we don't rotate the whole embedding at once, but split it into many small 2-coordinate vectors and rotate each one separately (a minimal sketch follows below). It speeds up training, improves metrics, and it's a beautiful idea – what more do you need?
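
A minimal NumPy sketch of the idea, applied to one head's query (or key) vectors. It uses the "rotate half" pairing found in Llama-style implementations; all names and sizes here are illustrative, not taken from any particular model.

# Minimal RoPE sketch (NumPy). head_dim is assumed to be even.
import numpy as np

def rope(x, positions, base=10000.0):
    """x: (seq_len, head_dim) query or key vectors for one head."""
    seq_len, dim = x.shape
    half = dim // 2
    # One frequency per 2-D pair: theta_i = base^(-2i/dim)
    freqs = base ** (-np.arange(half) * 2.0 / dim)        # (half,)
    angles = positions[:, None] * freqs[None, :]          # (seq_len, half)
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[:, :half], x[:, half:]                     # pair up coordinates
    # Standard 2-D rotation applied to every (x1_i, x2_i) pair
    return np.concatenate([x1 * cos - x2 * sin,
                           x1 * sin + x2 * cos], axis=-1)

# Shifting both positions by the same offset leaves the dot product intact:
q, k = np.random.randn(64), np.random.randn(64)
d1 = rope(q[None, :], np.array([3]))[0] @ rope(k[None, :], np.array([7]))[0]
d2 = rope(q[None, :], np.array([103]))[0] @ rope(k[None, :], np.array([107]))[0]
print(np.allclose(d1, d2))  # True: only the relative distance 4 matters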

Activation Function

A quick reminder about the architecture: in the transformer block, after attention, we have a feed-forward layer, i.e. linear-activation-linear. Initially the activation was good old ReLU; now it is SwiGLU. In general, GLU-like layers are about controlling the strength of the outgoing signal. Roughly: glu(x) = f1(x) * f2(x), where f1(x) is the signal, f2(x) is the signal strength (the gate), and the result is their element-wise product. Here f1 is a regular linear layer, while f2 can vary; in this case it is the silu function. Why this particular function and not another is unclear, but again, it improves the metrics.
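
Here is a sketch of what such a block could look like in PyTorch; the layer names and sizes are my own, not taken from any specific model.

# Sketch of a SwiGLU feed-forward block (PyTorch); names and sizes are illustrative.
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUFFN(nn.Module):
    def __init__(self, d_model: int, d_hidden: int):
        super().__init__()
        self.gate = nn.Linear(d_model, d_hidden, bias=False)  # f2: controls signal strength
        self.up   = nn.Linear(d_model, d_hidden, bias=False)  # f1: the signal itself
        self.down = nn.Linear(d_hidden, d_model, bias=False)  # project back to d_model

    def forward(self, x):
        # element-wise product of silu(gate) and the signal, then the second linear layer
        return self.down(F.silu(self.gate(x)) * self.up(x))

ffn = SwiGLUFFN(d_model=512, d_hidden=1376)
out = ffn(torch.randn(2, 16, 512))   # (batch, seq_len, d_model)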

Attention

The main part of the transformer. Several updates are in common use here, all of them in one way or another about speeding things up or saving memory:

  • Grouped Query Attention: in regular multi-head attention, each head has its own q, k and v vectors for a token. Here we divide the heads into groups, and within each group the k and v vectors of a token are shared. The point is fewer computations, with not much loss of quality.

  • Flash Attention: the point here is that the bottleneck is memory access, not arithmetic, so you can switch to an approach that is less efficient in terms of computation but more efficient in terms of memory and get a gain in speed. That is, it is purely about how the computation is organized; the essence and the result do not change, but the speedup turns out to be decent.

  • Sliding Window Attention: during attention, a token does not attend to all previous tokens, only to the last W. If we have k layers, then at the k-th layer element i can still receive information from the last W*k tokens. Again, the goal is to save memory so that very long sequences become workable.

  • KV-cache: here we are talking about saving compute during inference. We generate text autoregressively, that is, for each new token we run the model from the start. If we naively run it for “Masha”, “Masha walked”, “Masha walked down”, and so on, we are forced to recompute q, k, v in attention for all tokens of the text every time. But to predict the next token we do not actually need the previous q vectors, and the k and v vectors of all tokens except the last one are already computed and will not change. So the idea is to keep the k and v vectors in memory during generation (with a sliding window, not even all of them), and at each step compute only one q, k and v vector for the last token. This greatly speeds up inference, since it gets rid of redundant computation. A toy sketch combining this with GQA and the sliding window follows after this list.
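
To make the last few points concrete, here is a toy NumPy sketch of a single decoding step that combines GQA, a KV-cache, and an optional sliding window. All shapes, names, and simplifications (single batch, no output projection, no positional encoding) are my own; this is not any library's actual API.

# Toy decode step: GQA + KV-cache + optional sliding window (NumPy).
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def decode_step(x_new, Wq, Wk, Wv, cache, n_q_heads, n_kv_heads, window=None):
    """x_new: (d_model,) embedding of the newest token only.
    cache: dict with 'k', 'v' of shape (t, n_kv_heads, d_head) from earlier steps."""
    d_head = Wq.shape[1] // n_q_heads
    # Project ONLY the new token; earlier k/v come straight from the cache.
    q = (x_new @ Wq).reshape(n_q_heads, d_head)
    k = (x_new @ Wk).reshape(1, n_kv_heads, d_head)
    v = (x_new @ Wv).reshape(1, n_kv_heads, d_head)
    cache['k'] = np.concatenate([cache['k'], k], axis=0)
    cache['v'] = np.concatenate([cache['v'], v], axis=0)
    if window is not None:                      # sliding-window attention:
        cache['k'] = cache['k'][-window:]       # keep only the last W positions
        cache['v'] = cache['v'][-window:]
    group = n_q_heads // n_kv_heads             # GQA: several q heads share one kv head
    K = np.repeat(cache['k'], group, axis=1)    # (t, n_q_heads, d_head)
    V = np.repeat(cache['v'], group, axis=1)
    scores = np.einsum('hd,thd->ht', q, K) / np.sqrt(d_head)
    attn = softmax(scores)                      # (n_q_heads, t)
    return np.einsum('ht,thd->hd', attn, V).reshape(-1), cache

d_model, n_q, n_kv = 64, 8, 2
rng = np.random.default_rng(0)
Wq = rng.standard_normal((d_model, d_model))
Wk = rng.standard_normal((d_model, d_model // (n_q // n_kv)))
Wv = rng.standard_normal((d_model, d_model // (n_q // n_kv)))
cache = {'k': np.zeros((0, n_kv, d_model // n_q)),
         'v': np.zeros((0, n_kv, d_model // n_q))}
for _ in range(5):                              # pretend we generate 5 tokens
    out, cache = decode_step(rng.standard_normal(d_model), Wq, Wk, Wv,
                             cache, n_q, n_kv, window=4)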

Normalization

Basic approach

x = norm(x + attention(x))
x = norm(x + linear(x))

Current approach

x = x + attention(norm(x))
x = x + linear(norm(x))

Why? Simply better convergence. Another change: we used to use LayerNorm: subtract the mean, divide by the standard deviation, then multiply by a trainable scale and add a trainable bias. The authors of the RMSNorm paper said: well, subtracting the mean isn't really necessary, let's just divide by something (the root mean square) and then multiply by the trainable scale. It turned out there are fewer computations and fewer trainable parameters, while the quality does not suffer. So now everyone uses it.
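
For comparison, a minimal sketch of both normalizations as plain NumPy functions; the eps value and tensor shapes are illustrative.

# LayerNorm vs RMSNorm (NumPy). g and b are the trainable scale and bias.
import numpy as np

def layer_norm(x, g, b, eps=1e-6):
    mu = x.mean(axis=-1, keepdims=True)
    sigma = x.std(axis=-1, keepdims=True)
    return (x - mu) / (sigma + eps) * g + b

def rms_norm(x, g, eps=1e-6):
    rms = np.sqrt((x ** 2).mean(axis=-1, keepdims=True) + eps)
    return x / rms * g          # no mean subtraction, no bias

x = np.random.randn(4, 512)
g, b = np.ones(512), np.zeros(512)
print(layer_norm(x, g, b).shape, rms_norm(x, g).shape)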

Experts

Well, the FFN layer was not left out either. People thought: what if we had heads there too, like in attention, but in a different way?)
Let there be not one such FFN layer but n. Each such layer is an "expert". At the same time, each token will not pass through all the experts, only through k of them. How do we choose which k experts a specific token goes through? Say we have a sequence of M tokens of dimension D. We use the dumbest possible classifier: multiply the MxD sequence of tokens by a Dxn router matrix, get n numbers per token, and take the k largest; the indices they correspond to are the expert indices for this token. Then we apply softmax to these k numbers and get the weights for the experts. The final pipeline looks like this:

  • we classify each token and get its expert indices and their weights

  • we pass the token through those k experts (each expert is the same kind of SwiGLU layer we discussed above)

  • for each token we sum the outputs of the k experts with their weights

What's cool: it parallelizes easily, plus you can increase the number of experts n without changing k. As a result, the total number of model parameters (the sparse parameter count) grows and you can stuff more information in there, but the cost of computation does not change (the active parameter count stays the same): we still use only k experts per token. And it really works, the quality goes up, everything is great. A fairly easy-to-read article on the topic from Mistral is linked here. By the way, they use n=8 and k=2. A minimal sketch of the routing follows below.
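
A toy NumPy sketch of that routing, just to show the top-k selection plus softmax weighting; the plain linear "experts" and all sizes are stand-ins, since in a real model each expert would be a SwiGLU FFN like the one sketched earlier.

# Toy MoE routing sketch (NumPy): top-k expert selection + softmax weights.
import numpy as np

def moe_layer(tokens, router_w, experts, k=2):
    """tokens: (M, D); router_w: (D, n); experts: list of n callables (D,) -> (D,)."""
    logits = tokens @ router_w                        # (M, n): one score per expert
    topk = np.argsort(logits, axis=-1)[:, -k:]        # indices of the k best experts
    out = np.zeros_like(tokens)
    for i, tok in enumerate(tokens):
        sel = logits[i, topk[i]]
        w = np.exp(sel - sel.max()); w /= w.sum()     # softmax over the k selected scores
        for weight, e_idx in zip(w, topk[i]):
            out[i] += weight * experts[e_idx](tok)    # weighted sum of k expert outputs
    return out

D, n = 16, 8
rng = np.random.default_rng(0)
mats = [rng.standard_normal((D, D)) / np.sqrt(D) for _ in range(n)]
experts = [lambda t, W=W: t @ W for W in mats]        # stand-in for SwiGLU experts
y = moe_layer(rng.standard_normal((5, D)), rng.standard_normal((D, n)), experts, k=2)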
