GigaChat training with context of hundreds of thousands of tokens

Remember the phrase “640 kilobytes of memory is enough for everyone”? The demands of humanity are always growing, and the industry needs to keep up.

This is the case with language models. Until recently, we were all amazed at what they were capable of. And now that is no longer enough for us: "okay, but can the model take into account what I said hundreds of messages ago in the dialogue?"

In the spring, at our I'ML conference, Evgeny Kosarev (SberDevices) talked about how they approached increasing the context when working on GigaChat. And now we are publishing a text transcript of his report. We also attach links to his video: YouTube, VK Video.

Report plan

GigaChat is a popular Russian-language model. It is used in dozens of Sber products – SaluteJazz (formerly SberJazz), SaluteBot, smart speakers, TVs, the voice assistant. Outside of Sber, GigaChat is also widely used: thousands of clients use the GigaChat API.

Before preparing the report, I asked GigaChat how to talk about a long language context at a conference. And AI suggested a report plan that I liked:

  • benefits of long context;

  • technologies that help to implement it;

  • development prospects.

LLaMa

Ten years ago the following definition was sufficient: a language model is an algorithm that can sensibly continue a text. That is how the simple T9 system could be described.

As time went on, our understanding of what counts as a good response changed dramatically. Now we want language models to solve math problems, write code, hold engaging conversations, and much, much more. And at the beginning of 2024 we arrived at the LLaMa-3 model.

At the time of the report, this was the strongest publicly available language model across a variety of benchmarks: MMLU and GPQA (general-knowledge benchmarks; MMLU spans 57 subject areas), HumanEval (assessing the ability to write code), and GSM-8K and MATH (testing the ability to solve mathematical problems at school level and harder).

The progress of the LLaMa family of models is impressive. In about a year, the result on difficult benchmarks improved twofold, and on some benchmarks even tenfold. Everything seems to be fine, but something is missing.

Context in LLM

Let me explain what context is and how language models understand it. A language model processes its context as tokens – units of text. Previously these were letters or whole words; now it is something more complex: the text is split into substrings, each substring is mapped to a number, and the language model operates on those numbers.
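To make this concrete, here is a tiny illustration using the open tiktoken tokenizer (GigaChat has its own tokenizer; this is just an example of the text → numbers → text round trip):

```python
# A minimal illustration of tokenization: text -> integer token ids -> text.
# tiktoken is used here only as an example tokenizer, not GigaChat's own.
import tiktoken

enc = tiktoken.get_encoding("cl100k_base")

text = "Language models operate on tokens, not characters."
token_ids = enc.encode(text)      # the list of integers the model actually sees
print(token_ids)
print(len(token_ids), "tokens")
print(enc.decode(token_ids))      # round-trips back to the original string
```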

Nowadays, most language models have a context of 8 thousand tokens. If translated into understandable numbers, this is four pages of A4 text. It turns out that we can operate with such a volume quite well and solve important problems. Is this a lot or a little?

Now that language models have become good, they are increasingly being used for tasks that require more. A business may ask a question related to a document of more than 20 pages (this could be an article or a legal document). For example, a document contains a legal act; what does it regulate? It's written somewhere in the text. Let me remind you that our model processes four A4 pages. What to do with the rest?

Solutions

Modification of the attention block. Almost all language models now have an attention block. Its task is to analyze how tokens relate to each other; through this block the model understands the context. One option is to simplify this block and come up with heuristics.

The CoLT5 heuristic works like this. Say 100 thousand tokens arrive, and the context window is 8 thousand tokens. We cannot process the entire context, but we can split tokens into important and unimportant ones: process all the important tokens in the context window, looking at the connections between them, and send the unimportant tokens to a lighter branch.

In the light branch it is proposed to compute connections not between all tokens, but only between adjacent ones (a window of three). Attention becomes quite cheap, and with two attention passes – the heavy branch plus the light branch – we can process the entire context.
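Here is a rough sketch of the idea in PyTorch – not the paper's actual code, just the routing scheme with a made-up linear router and a trivial three-token light branch:

```python
# A rough sketch of CoLT5-style conditional computation (not the paper's code):
# a learned router scores tokens, the top-k "important" tokens get full attention,
# the rest get a cheap local mix over their immediate neighbours.
import torch
import torch.nn.functional as F

def routed_attention(x, k_heavy, router, qkv_heavy):
    # x: [T, D] token representations for a single sequence
    T, D = x.shape
    scores = router(x).squeeze(-1)              # [T] importance scores (learned)
    heavy_idx = scores.topk(k_heavy).indices    # indices of "important" tokens

    # Heavy branch: full attention only among the selected tokens.
    h = x[heavy_idx]                            # [k_heavy, D]
    q, k, v = qkv_heavy(h).chunk(3, dim=-1)
    attn = F.softmax(q @ k.T / D ** 0.5, dim=-1) @ v

    # Light branch: each token is mixed only with its two neighbours
    # (a window of three, wrapping at the edges for simplicity) -
    # linear in T instead of quadratic.
    light = (x + torch.roll(x, 1, dims=0) + torch.roll(x, -1, dims=0)) / 3

    out = light.clone()
    out[heavy_idx] = out[heavy_idx] + attn      # combine the two branches
    return out

# Usage with random weights, just to show the shapes involved.
T, D = 1024, 64
x = torch.randn(T, D)
router = torch.nn.Linear(D, 1)
qkv_heavy = torch.nn.Linear(D, 3 * D)
print(routed_attention(x, k_heavy=128, router=router, qkv_heavy=qkv_heavy).shape)
```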

A fairly obvious flaw: some component has to decide how important each token is. Determining this from free-form text and building such a router is a rather difficult task, and many mistakes are made in judging importance.

Secondly, the light branch of attention still processes context poorly and loses a lot of information.

Sliding Window Attention. This modification of attention has existed for many years, and most recently it was popularized by the team behind the Mistral model. Before the arrival of LLaMa-3, that model was considered the strongest with 7 billion parameters. The authors argued as follows.

First, let's remember that we are training not only a large language model, but also a deep neural network. Let's take advantage of the fact that we have many layers and transfer information sequentially between layers:

If we have a token at position 20,000 and a context window of 4000, then it looks at the last 4000 tokens: the relationships of tokens 16,000 through 20,000 are analyzed. Moving up through the layers, we eventually connect the first token with the last. Everything seems fine: we have full attention within the window and, thanks to the depth of the architecture, an aggregation of information from the lower layers. The drawback is that information is lost as it passes through the layers, although it is not as obvious as in CoLT5. How do we know whether there really is a flaw? To find out, let's introduce a benchmark.
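Before moving on to the benchmark, here is a minimal sketch of what a sliding-window mask looks like (window size and sequence length are arbitrary):

```python
# A minimal sketch of a sliding-window causal mask: each token attends only to
# the previous `window` tokens within one layer; stacking layers lets information
# travel further (layer L can indirectly reach roughly L * window tokens back).
import torch

def sliding_window_mask(T, window):
    i = torch.arange(T).unsqueeze(1)     # query positions
    j = torch.arange(T).unsqueeze(0)     # key positions
    return (j <= i) & (j > i - window)   # causal AND within the window

mask = sliding_window_mask(T=8, window=4)
print(mask.int())
# The last row (token 7) attends only to positions 4..7, not to 0..3.
```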

PassKey. This is a needle-in-a-haystack test. There are many long-context benchmarks, but this one is among the simplest and most understandable, and everyone uses it when talking about long context.

Let's say our model understands a context of hundreds of thousands of tokens. We take some fact and hide it at a certain depth of the context. For example, take a context of 8000 tokens and somewhere around position 4000 place a fact such as a date of birth or a name – something that can be retrieved on request. The rest of the context is filled with random noise or long texts like "War and Peace." From the point of view of measuring quality, it is an elementary task.

We have a large context, a fact in the middle, at the beginning, or at the end of it, and our task is to retrieve it. If the model copes, it gets 1; if it fails, 0 – a binary response format. For each context length we average over all the trials and get the percentage of cases in which the model can retrieve the fact. This makes it clear whether it handles the context well or not.
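A toy version of such a benchmark is easy to write. In the sketch below tokens are approximated by words, and `ask_model` is a placeholder for whichever model API you are testing:

```python
# A toy version of the PassKey ("needle in a haystack") benchmark.
def build_passkey_prompt(total_tokens, depth_fraction, passkey,
                         filler="The grass is green. "):
    # Approximate tokens with words to keep the sketch simple.
    needle = f"The pass key is {passkey}. Remember it."
    words = (filler * (total_tokens // 4)).split()   # ~total_tokens filler words
    pos = int(len(words) * depth_fraction)
    words = words[:pos] + needle.split() + words[pos:]
    return " ".join(words) + "\nWhat is the pass key?"

def score(answer: str, passkey: str) -> int:
    return 1 if passkey in answer else 0             # binary: found the fact or not

prompt = build_passkey_prompt(total_tokens=8000, depth_fraction=0.5, passkey="71432")
# accuracy = mean(score(ask_model(p), key)) over many depths and context lengths
```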

For Mistral 7B with Sliding Window Attention, the results are sad.

Within the basic 4000-token context window it retrieves the fact 100% of the time. In a context of 8000 tokens – already only half the time. The authors trained the model with a context of 32,000 tokens and declared that it understood this context well. But there the information is lost in 80% of cases: the model does not understand long context. I assume it was precisely because of this that the authors retrained the model. The result was Mistral 7B v0.2 with an honest attention mechanism, that is, without any optimizations – and they honestly get 100% quality over the entire context.

Why do I think the authors retrained the model for this reason? Because we measured different versions of Mistral, and on the main quality benchmarks they are almost identical; only the understanding of context differs.

RAG, or retrieval-augmented generation. Another almost-good solution. We add a search engine to our language model. It can search for information both on the Internet and in documents. If we have a large document, the search engine reports that the relevant information is, say, in the fifth paragraph, which is usually small. We can add that passage to the context of the language model, and it will give us a nice, correct answer.

Advantages:

Flaws:

In general, training language models is expensive, so a model is not retrained every week or even every month. We train a model, it stores knowledge, and that knowledge becomes outdated over time. So that the answers do not become outdated, we add RAG.
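For completeness, here is a deliberately naive sketch of the RAG idea – a real system would use a proper search engine or an embedding index instead of the word-overlap scoring used here:

```python
# A minimal RAG sketch: chunk the document, retrieve the most relevant chunk
# with a trivial word-overlap score, and paste it into the model's prompt.
def retrieve(document: str, question: str, chunk_words: int = 200) -> str:
    words = document.split()
    chunks = [" ".join(words[i:i + chunk_words])
              for i in range(0, len(words), chunk_words)]
    q_words = set(question.lower().split())
    # Pick the chunk sharing the most words with the question.
    return max(chunks, key=lambda c: len(q_words & set(c.lower().split())))

def build_prompt(document: str, question: str) -> str:
    context = retrieve(document, question)
    return f"Use the excerpt below to answer.\n\nExcerpt:\n{context}\n\nQuestion: {question}"

# The resulting prompt fits into an 8K-token context window even when the
# original document is hundreds of pages long.
```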


All of the approaches above have drawbacks. None of them gives a genuinely large context in which the model understands all the facts and can operate on a large amount of data.

Therefore, you need to honestly increase the context window of the language model to 130,000 tokens or more. Why doesn't everyone do this? The problem is that it costs too much:

  1. Tensor activations do not fit on the GPU.

  2. Slowdown of calculations.

Let's try to optimize the calculations.

Computation Optimization

First, let's define what we are going to optimize. We have neural networks that consist of layers. Very approximately the work of the inner layer looks like this.

The input to the layer is a matrix M1 of shape [B, T, D], where B is the batch size, T is the number of tokens in the context, and D is the internal representation size. The latter depends on the model: the larger the model, the larger D. There is also a layer weight matrix M2 of shape [D, H]. Their matrix product is the activation – this is what is stored on the GPU and what may not fit on it. As we can see, even for relatively small matrices M1 and M2 the activation can turn out quite large. Something has to be done to be able to train large language models with a large context.
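A back-of-the-envelope calculation makes the point; the sizes below are illustrative (roughly a 7B-class model with a 128K-token context), not GigaChat's actual configuration:

```python
# Rough estimate of the activation produced by one linear layer: the [B, T, D]
# input multiplied by a [D, H] weight gives a [B, T, H] activation that has to
# live in GPU memory. Sizes are illustrative (LLaMa-7B-like dimensions).
def activation_bytes(B, T, H, bytes_per_value=2):   # 2 bytes for fp16/bf16
    return B * T * H * bytes_per_value

B, T, D, H = 1, 128_000, 4096, 11008                 # 128K-token context
print(f"weight:     {D * H * 2 / 2**30:.2f} GiB")    # the [D, H] matrix itself
print(f"activation: {activation_bytes(B, T, H) / 2**30:.2f} GiB")
# The weight is well under 0.1 GiB, but a single layer's activation is ~2.6 GiB.
```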

To perform optimizations, you need to select an architecture. In the report I will talk about the architectures of open language models such as LLaMa-1,2,3. These are decoder architectures, and they consist of two important blocks.

The first is the attention block, the second is the MLP block. Note that these are not the only things that get optimized. In distributed training, the embedding layer is always optimized, as is the distributed computation of the loss function. With the embedding layer everything is trivial: if we have many GPUs, we simply shard the weights of this layer across them – half of the layer lives on one GPU and the other half on another, they communicate during the forward pass, and the result is assembled.

With the loss function everything is also simple – we plug in a library and the computation just happens. There are no real alternatives here yet.

More interesting is what to do with the attention mechanism and the MLP. We need them to account for the relationships between tokens and to transform the model's knowledge. The MLP blocks themselves are quite large. If we calculate how long the attention mechanisms take and how long the MLP takes, we get the following.

The compute cost of the attention mechanism is quadratic both in the length of the context and in the internal representation size – one might say, in the size of the model. The memory of the attention mechanism is quadratic in the context: if we double the context, the memory for attention quadruples. For the MLP things are a little better – the dependence on context length is linear. Let's move on to the first optimizations.
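To make that difference concrete before diving in, here is the same arithmetic for the attention-score matrix versus the MLP activation as the context doubles (head count and hidden size are illustrative, fp16 values assumed):

```python
# How memory scales when the context length doubles: the T x T attention-score
# matrix grows quadratically in T, while MLP activations grow only linearly.
def attention_scores(B, heads, T, bytes_per_value=2):
    return B * heads * T * T * bytes_per_value

def mlp_activation(B, T, H, bytes_per_value=2):
    return B * T * H * bytes_per_value

for T in (8_000, 16_000, 32_000):
    print(T,
          f"attn scores: {attention_scores(1, 32, T) / 2**30:6.1f} GiB",
          f"MLP act: {mlp_activation(1, T, 11008) / 2**30:5.2f} GiB")
```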

Flash Attention

This optimization is familiar to anyone involved in ML. Large language models are usually trained in Python with the PyTorch framework. It is worth saying that PyTorch is just an interface: when we ask it to multiply matrices, it sends a command to the GPU, and the tensors are uploaded there.

The operation is performed on the GPU, the result is sent back to the processor, and that is how it all works. If we issue many operations in a row, the cost of launching each operation adds up. Therefore, it is proposed to write an efficient implementation of the attention mechanism directly for the GPU.

Firstly, as we see on the left, instead of many separate operation calls in PyTorch we can make just one: a single kernel written in CUDA runs faster than many different calls.

Secondly, we still want to compute attention efficiently on the GPU, so we cannot ignore how the GPU and the system as a whole are built. The memory hierarchy looks like a pyramid: the GPU has a small amount of fast memory (orange in the image), the green memory is larger and slightly slower, and the slowest of the three is CPU memory.

So we can compute the attention block as shown on the right in the image above. For the attention block we need the pairwise relationships of all tokens in the context; if the context size is T, that is a T² computation. The authors argue that we do not need to keep the entire matrix in memory and fully materialize it: we can compute it block by block, sending blocks off for computation and accumulating the result elsewhere, so memory is used more efficiently. In my setup I took a 7-billion-parameter model, and with flash-attention the computation speed increases by 50% and we shave about 7 GB of memory off the activations. That is pretty good.
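The core trick can be shown in plain PyTorch – a simplified, non-causal sketch of blockwise attention with a running softmax, not the actual CUDA kernel:

```python
# The core idea of FlashAttention in plain PyTorch (a sketch, not the kernel):
# process keys/values in blocks and keep running softmax statistics, so the full
# T x T score matrix is never materialized. Causal masking is omitted for brevity.
import torch

def blockwise_attention(q, k, v, block=1024):
    T, d = q.shape
    scale = d ** -0.5
    out = torch.zeros_like(q)
    row_max = torch.full((T, 1), float("-inf"))
    row_sum = torch.zeros(T, 1)
    for start in range(0, T, block):
        kb, vb = k[start:start + block], v[start:start + block]
        s = (q @ kb.T) * scale                          # [T, block] score block
        new_max = torch.maximum(row_max, s.max(dim=-1, keepdim=True).values)
        correction = torch.exp(row_max - new_max)       # rescale previous partials
        p = torch.exp(s - new_max)
        row_sum = row_sum * correction + p.sum(dim=-1, keepdim=True)
        out = out * correction + p @ vb
        row_max = new_max
    return out / row_sum

q, k, v = (torch.randn(4096, 64) for _ in range(3))
ref = torch.softmax((q @ k.T) * 64 ** -0.5, dim=-1) @ v
print((blockwise_attention(q, k, v) - ref).abs().max())  # tiny: matches full attention
```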

GQA

The idea is this. We use a multi-head attention mechanism, that is, the query, key and value matrices are divided into smaller sub-matrices (heads). The standard implementation is shown in the image on the left.

Each query head has a corresponding key head and value head; from their interaction we compute the attention matrix. One paper argues that we do not need a unique key and value head for every query head: we can group them and shrink the key and value matrices without much loss in quality. This optimization gives us a memory reduction of about a gigabyte and a speedup of about ten percent.
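A minimal sketch of the grouping, assuming 32 query heads sharing 8 key/value heads (the numbers are illustrative):

```python
# A sketch of grouped-query attention (GQA): several query heads share one
# key/value head, so the K and V projections (and the KV cache) shrink by the
# group factor.
import torch

B, T, d_head = 1, 16, 64
n_q_heads, n_kv_heads = 32, 8            # 4 query heads per KV head
group = n_q_heads // n_kv_heads

q = torch.randn(B, n_q_heads, T, d_head)
k = torch.randn(B, n_kv_heads, T, d_head)
v = torch.randn(B, n_kv_heads, T, d_head)

# Expand the KV heads so each group of query heads reuses the same K/V.
k = k.repeat_interleave(group, dim=1)    # [B, n_q_heads, T, d_head]
v = v.repeat_interleave(group, dim=1)

attn = torch.softmax(q @ k.transpose(-1, -2) / d_head ** 0.5, dim=-1) @ v
print(attn.shape)                        # torch.Size([1, 32, 16, 64])
```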

Attention is also very important for the inference of language models: besides training a model, you also need to serve it. The key-value cache is an optimization that avoids recomputation so that inference is fast, and the smaller it is, the more efficient the inference. So this optimization also affects the inference speed of the language model.

So, we were able to speed up attention. What's next? Two things come to mind here: tensor parallel and sequence parallel.

Tensor Parallel

The matrix product happens on GPUs. Suppose the product does not fit on one GPU. Then let's use two GPUs or, in general, n GPUs.

A matrix has two important dimensions: the internal representation dimension and the context dimension. We can split the matrices between two GPUs along one of them: GPU No. 0 stores and operates on the first half of the matrix, GPU No. 1 on the second half. Using communication between the GPUs, all activations and the stored memory become n times smaller, while the number of GPUs needed grows n times. This allows you to train larger models with a longer context, but on more cards.
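Here is a single-process simulation of the idea: the weight matrix is split column-wise between two hypothetical GPUs, and the final gather communication is written as a plain concatenation:

```python
# A single-process simulation of tensor parallelism: the [D, H] weight is split
# column-wise between two "GPUs", each computes its half of the activation, and
# a gather along the hidden dimension restores the full result.
import torch

B, T, D, H = 1, 8, 16, 32
x = torch.randn(B, T, D)
w = torch.randn(D, H)

w0, w1 = w.chunk(2, dim=1)        # each shard is [D, H/2], stored on its own GPU
y0 = x @ w0                       # activation on GPU 0: [B, T, H/2]
y1 = x @ w1                       # activation on GPU 1: [B, T, H/2]
y = torch.cat([y0, y1], dim=-1)   # the all-gather communication step

print(torch.allclose(y, x @ w))   # True: same result, half the activation per GPU
```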

What's the result? Tensor parallel is partitioning along the model dimension. It reduces memory by 50% and increases speed by 90%. Why not 100%? To pull this off, the GPUs have to communicate with each other: at the forward and backward stages they exchange intermediate results. So with an efficient implementation the speedup is roughly 90%.

Sequence parallel

I'll show you a scary picture again.

Sequence parallel has similar performance: memory is halved, the speed gain is again 90%. In both approaches we split a matrix in half. So which is better, why, and when should each be used?

Tensor vs Sequence Parallel

There is an article that was already two or three years old at the time of the report. It calculates how much memory each of the two approaches consumes. The authors arrived at the following formulas:

You can see that some terms are shared and others differ, and it is the differing ones that get reduced by a factor of n. So you have to choose between the two approaches depending on the situation. If we recompute the same formulas for the LLaMa architecture, we get the following.


Everything here is computed for B = 1 and can be recalculated for larger batch sizes. For the attention mechanism, sequence parallel becomes more profitable once the context is 8 times larger than the internal representation. For a seven-billion-parameter LLaMa that representation is about 4000, so for attention, sequence parallel wins starting from a context of roughly 32,000 tokens. For the MLP we can do the same calculation; there the factor is 16.

So if the context is 64,000 tokens or more, sequence parallel is guaranteed to be more profitable. Sometimes the two can be combined.
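As simple arithmetic, the rule of thumb above for a hidden size of about 4096 (a 7B-class model) gives:

```python
# The rule of thumb from the talk, written out as arithmetic: sequence parallel
# starts to win for attention once the context exceeds ~8x the hidden size,
# and for the MLP once it exceeds ~16x.
hidden = 4096
print("attention crossover:", 8 * hidden, "tokens")    # ~32K tokens
print("MLP crossover:      ", 16 * hidden, "tokens")   # ~64K tokens
```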

In summary: since we want to train on extremely large contexts, let's discuss how sequence parallel can be implemented.

SP implementation – all2all

A lazy implementation proposed in DeepSpeed. It barely changes the code: you take your usual language-model training code and do only two things. First, split the tokens before they enter the model. Why can we do this?

Language models, namely decoder layers, can architecturally handle arbitrary context.

From the point of view of the matrices and the computations, the model itself can accept a context of any size. So, without changing the code, we can simply feed the first half of the context to the model on one GPU and the second half to the model on another. The only change is that an all2all communication is inserted before and after the attention block.

Briefly, it does the following: if the tensor is sharded along the sequence dimension, then after the all2all communication the full sequence is assembled and the internal dimension (the attention heads) is sharded instead. So outside the attention block the model runs in sequence-parallel mode, and inside attention it effectively runs in tensor-parallel mode. This implementation is the simplest – why not try it?
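Here is a single-process simulation of what that all2all re-partitioning does to the tensors (plain tensor ops standing in for the actual DeepSpeed communication, with made-up sizes):

```python
# A single-process simulation of the all2all trick: outside attention each of the
# n "GPUs" holds a slice of the sequence with all heads; the all2all exchange
# turns that into the full sequence with a slice of the heads, which is exactly
# what the attention block needs.
import torch

n, T, heads, d = 4, 16, 8, 32                  # 4 ranks, 16 tokens, 8 heads
x = torch.randn(T, heads, d)

# Before attention: rank r holds tokens [r*T/n : (r+1)*T/n] for all heads.
seq_shards = list(x.chunk(n, dim=0))           # each [T/n, heads, d]

# all2all: every rank sends each head group to the rank that owns it.
head_shards = []
for r in range(n):                             # what rank r holds after all2all
    pieces = [shard.chunk(n, dim=1)[r] for shard in seq_shards]
    head_shards.append(torch.cat(pieces, dim=0))   # full T, heads/n heads

print(head_shards[0].shape)                    # torch.Size([16, 2, 32])
print(torch.allclose(torch.cat(head_shards, dim=1), x))  # True: nothing is lost
```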

Ring Attention

This implementation is more complicated. It allows training in extreme contexts and can be combined with the previous implementation. Let's see what happens.

Ring attention inherits from flash-attention (or maybe vice versa) the idea that there is no need to materialize the entire attention matrix. Let me remind you that the attention matrix is quadratic in the length of the context. We can evaluate attention block by block in a loop, and that is what is proposed here.

There is a query matrix and key/value matrices, painted in four colors. In four steps, passing the key/value blocks between GPUs, we can compute the complete attention block. This is quite good, because while the attention block is being computed, the GPUs are also communicating: computation and communication overlap. Recall the all2all communications – they make the model stall. The GPUs compute the MLP block or the previous blocks in parallel, stop, and talk; they compute attention, stop again, talk again, and continue computing. During communication the cluster is idle and the GPUs do nothing. That is bad.

There is no such thing in ring attention: while the GPUs are communicating, calculations are being performed, and the key/value block is handed on to the next GPU. What can you notice here?

Below are attention masks – black and gray squares. You will notice that some masks are completely gray, some completely black, and some triangular. A completely black mask means an idle GPU. Our attention is causal: the second token depends on the first, the third on the first and second, the fourth on the first three, but the fourth does not depend on the fifth – a token does not see the future. When a GPU ends up with a fully black mask, the entire result of its computation is thrown away and it works in vain. The authors of striped ring attention propose to fix this.
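To close out this part, here is a single-process sketch of the ring schedule itself – non-causal for brevity, so it does not show the idle-GPU problem, only the blockwise accumulation that makes the ring possible:

```python
# A single-process simulation of ring attention (non-causal): each "GPU" keeps
# its own block of queries and, over n steps, receives every key/value block
# from its ring neighbour, updating a running softmax - so the full attention
# matrix is never materialized, and on real hardware the block hand-off can
# overlap with computation.
import torch

def ring_attention_sim(q_blocks, k_blocks, v_blocks):
    n = len(q_blocks)
    outs = []
    for rank in range(n):
        q = q_blocks[rank]
        d = q.shape[-1]
        out = torch.zeros_like(q)
        row_max = torch.full((q.shape[0], 1), float("-inf"))
        row_sum = torch.zeros(q.shape[0], 1)
        for step in range(n):                          # KV block travelling the ring
            src = (rank + step) % n
            s = q @ k_blocks[src].T / d ** 0.5
            new_max = torch.maximum(row_max, s.max(-1, keepdim=True).values)
            corr = torch.exp(row_max - new_max)
            p = torch.exp(s - new_max)
            row_sum = row_sum * corr + p.sum(-1, keepdim=True)
            out = out * corr + p @ v_blocks[src]
            row_max = new_max
        outs.append(out / row_sum)
    return torch.cat(outs)

T, d, n = 4096, 64, 4
q, k, v = (torch.randn(T, d) for _ in range(3))
ref = torch.softmax(q @ k.T / d ** 0.5, dim=-1) @ v
out = ring_attention_sim(list(q.chunk(n)), list(k.chunk(n)), list(v.chunk(n)))
print((out - ref).abs().max())           # tiny: matches full attention
```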

Striped Ring Attention

A different way of slicing the matrices into query and key/value blocks is proposed.

The blocks come out slightly different, which makes the per-GPU masks triangular. These stripes mean that each GPU always computes either a bit more or a bit less than half of the attention scores. The load is distributed more evenly across GPUs, and we already get some speedup. But there is a nuance that the next optimization handles.

ZigZag Ring Attention

If striped ring attention uses the token ordering written in the first line of the image below, then the GPUs would communicate and compute attention as shown.

ZigZag groups two (or, in general, n) blocks of tokens together. This grouping leads to one of three kinds of attention blocks: either half-filled (horizontally or vertically), or upper or lower triangular, depending on how you look at it. That is, each GPU computes either half of a block or slightly less.
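One common way to write down the zigzag assignment is sketched below (with 2n blocks for n GPUs; a sketch of the idea, not GigaChat's actual code):

```python
# A sketch of the zigzag block assignment: with n GPUs the sequence is cut into
# 2n blocks, and GPU i gets blocks i and 2n-1-i. Early (mostly-masked, cheap)
# blocks are paired with late (expensive) ones, so causal attention work is
# balanced across GPUs.
def zigzag_blocks(n_gpus):
    return [(i, 2 * n_gpus - 1 - i) for i in range(n_gpus)]

print(zigzag_blocks(4))   # [(0, 7), (1, 6), (2, 5), (3, 4)]
```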

It seems that this implementation is better than the previous ones. Let's check it on benchmarks.

Comparison of implementations

Here is a table with benchmark comparisons on different GPUs.

ZigZag ring attention wins over everyone in terms of speed. You can see that the increase in ZigZag relative to the standard ring attention is quite strong, about one and a half times.

As a result, we have all2all and ring attention. Which is more profitable? To measure the effect, I took 64 GPUs and used the following setup.

This is TP=1, SP=4. I measured the token throughput per GPU and the maximum amount of memory required on a single GPU. What do we get?

Ring attention is optimal in terms of memory for all models. All2all eats memory – in fact, it was never designed to be memory-efficient. On small 7-billion-parameter models all2all is somewhat faster, but not by much. Most likely, in the ring attention case communication and computation do not fully overlap, and with a better-tuned setup the gap between ring attention and all2all could be closed. But if you don't want to bother, all2all is better for small models, and it is also easier to implement.

If we take large models of around 30 billion parameters and a context of 100,000 tokens, then all2all will not run at all – it does not fit on the GPU. And on a smaller context it is almost twice as slow and consumes one and a half times more memory.

So now we have the technology. Let's move on to training GigaChat.

GigaChat training

Let's summarize:

  • We have formulas for counting activations.

  • We know when tensor parallel is more profitable than sequence parallel.

  • We can combine both modes.

  • Effective ring attention is needed for extreme contexts.

Our results turned out like this. GigaChat Pro with 29 billion parameters is trained on a context of 128,000 tokens. GigaChat Light with 7 billion parameters is trained on a context of 1 million tokens.

To make it clearer, I went to the Internet, calculated approximately how many letters there are in a Russian word, how many words fit on an A4 page, and the results are as follows. GigaChat Pro can handle 64 pages at a time, and GigaChat Light can handle 512 A4 pages. As a result, the context is gigantic, you can even slip books into it. Why is this even necessary?

Prospects

The prospects are as follows.

  • Code, video and audio require a lot of context.

  • Solving business problems with additional information.

  • Personalization of language models.

  • Multi-agent interaction (that is, several smart devices).

That's all! You can ask questions about the report in Telegram: evgenijkkk.
