Recurrent neural networks strike back

Recurrent neural networks (RNNs), along with their successors such as LSTM and GRU, were once the main tools for working with sequential data. However, in recent years they have been almost completely displaced by transformers (following the rise of "Attention Is All You Need"), which have come to dominate fields from natural language processing to computer vision. In the paper "Were RNNs All We Needed?", the authors Leo Feng, Frederick Tung, Mohamed Osama Ahmed, Yoshua Bengio, and Hossein Hajimirsadeghi revisit the potential of RNNs by adapting them for parallel computation. Let's take a closer look at what they managed to achieve.

Why stick with RNN?

To begin with, it is worth noting that recurrent networks have an important advantage: their memory requirements are linear in the sequence length during training and constant during inference. Transformers, in contrast, have quadratic memory complexity during training and linear complexity during inference, which becomes especially noticeable on long sequences. This makes RNNs more resource-efficient for problems involving long sequences.

However, the main disadvantage of classical RNNs has been the lack of training parallelism. The backpropagation through time (BPTT) algorithm runs sequentially, which makes training on long sequences very slow. It is this limitation that gave the advantage to transformers, which can be trained in parallel and, despite their higher compute requirements, have significantly accelerated training.

What is the key change?

In recent years, several attempts to address this limitation of RNNs have emerged, such as LRU, Griffin, RWKV, Mamba, and others. What unites them is the use of the parallel prefix scan algorithm, sketched below.
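To make the idea concrete, here is a minimal NumPy sketch (my own illustration, not code from any of these papers) of a parallel prefix scan for the linear recurrence h_t = a_t * h_{t-1} + b_t, the form to which these models reduce their state updates. The combine operator composes two affine steps; because it is associative, all prefixes can be computed in O(log T) parallel steps (a Hillis-Steele scan is used here for brevity):

```python
import numpy as np

def combine(left, right):
    """Compose two affine steps h -> a*h + b (left step applied first)."""
    a1, b1 = left
    a2, b2 = right
    return a2 * a1, a2 * b1 + b2

def parallel_scan(a, b, h0=0.0):
    """Inclusive prefix scan for h_t = a_t * h_{t-1} + b_t.

    Each iteration is element-wise over the whole sequence and therefore
    parallelizable; only O(log T) such iterations are needed.
    """
    a, b = a.copy(), b.copy()
    T = len(a)
    shift = 1
    while shift < T:
        # each position absorbs the partial result `shift` steps back
        a_prev, b_prev = a[:-shift].copy(), b[:-shift].copy()
        a[shift:], b[shift:] = combine((a_prev, b_prev), (a[shift:], b[shift:]))
        shift *= 2
    return a * h0 + b  # h_t for t = 1..T

# sanity check against the plain sequential recurrence
rng = np.random.default_rng(0)
a, b = rng.uniform(0, 1, 8), rng.normal(size=8)
h, hs = 0.0, []
for t in range(8):
    h = a[t] * h + b[t]
    hs.append(h)
print(np.allclose(parallel_scan(a, b), hs))  # True
```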

By removing the dependence of the input, forget, and update gates on the previous hidden state, LSTMs and GRUs no longer need to be trained via BPTT: they can be trained efficiently with the scan above. Building on this, the authors simplified the LSTM and GRU architectures by dropping the restrictions on the range of their outputs (for example, getting rid of the tanh activation function) and making the scale of the outputs time-independent. These changes led to "minimal" versions (minLSTM and minGRU) that use significantly fewer parameters than the traditional variants and can be trained in parallel.

minLSTM and minGRU architectures

First, let's write out estimates of the number of parameters in the original models. If d_x is the input dimension and d_h the size of the hidden state, then p_{LSTM} = O(4 d_h (d_x + d_h)) and p_{GRU} = O(3 d_h (d_x + d_h)). Recall that GRU's saving comes from using two gates instead of LSTM's three, and from updating the hidden state directly, whereas LSTM maintains two states: the cell state and the hidden state.
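To make these estimates concrete, here is a small Python helper (my own illustration; it counts only the weight matrices of the gate and candidate projections and ignores biases) comparing the four architectures at an example size d_x = 64, d_h = 128:

```python
def param_counts(d_x, d_h):
    """Rough parameter counts for the gate/candidate projections only."""
    return {
        "LSTM":    4 * d_h * (d_x + d_h),  # input, forget, output gates + candidate
        "GRU":     3 * d_h * (d_x + d_h),  # update, reset gates + candidate
        "minLSTM": 3 * d_h * d_x,          # gates and candidate depend on x_t only
        "minGRU":  2 * d_h * d_x,
    }

print(param_counts(d_x=64, d_h=128))
# {'LSTM': 98304, 'GRU': 73728, 'minLSTM': 24576, 'minGRU': 16384}
```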

Fig. 1. minGRU diagram (page 4)

They removed the dependence of the update gate (z_t) and of the candidate hidden state on the previous hidden value, dropped the reset gate entirely, and removed the nonlinear tanh activation, which is no longer needed once the dependence on the previous hidden state is gone. This brings the parameter count down to O(2 d_h d_x).
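Written out as code, minGRU is just two linear layers and a convex combination. The sketch below (an illustration, not the authors' reference implementation) follows the equations z_t = sigmoid(Linear(x_t)), htilde_t = Linear(x_t), h_t = (1 - z_t) * h_{t-1} + z_t * htilde_t. It uses a sequential loop for readability, but since a_t = 1 - z_t and b_t = z_t * htilde_t, the same recurrence can be evaluated with the parallel scan sketched earlier:

```python
import torch
import torch.nn as nn

class MinGRU(nn.Module):
    """Sketch of minGRU: gates depend only on x_t, no reset gate, no tanh."""
    def __init__(self, d_x, d_h):
        super().__init__()
        self.linear_z = nn.Linear(d_x, d_h)  # update gate
        self.linear_h = nn.Linear(d_x, d_h)  # candidate hidden state

    def forward(self, x, h0=None):
        # x: (batch, seq_len, d_x)
        z = torch.sigmoid(self.linear_z(x))
        h_tilde = self.linear_h(x)  # no tanh
        h = torch.zeros_like(z[:, 0]) if h0 is None else h0
        outs = []
        # h_t = (1 - z_t) * h_{t-1} + z_t * htilde_t is a linear recurrence,
        # so a parallel scan with a_t = 1 - z_t, b_t = z_t * htilde_t
        # produces the same sequence of states.
        for t in range(x.shape[1]):
            h = (1 - z[:, t]) * h + z[:, t] * h_tilde[:, t]
            outs.append(h)
        return torch.stack(outs, dim=1)

# usage
model = MinGRU(d_x=32, d_h=64)
print(model(torch.randn(8, 100, 32)).shape)  # torch.Size([8, 100, 64])
```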

Fig. 2. minLSTM diagram (page 5)

In the case of the LSTM, the changes remove the dependence of the input and forget gates on the previous hidden state (h_{t-1}). Next, the two gates are normalized so that the scale of the LSTM cell state becomes time-independent. Making the scale of the hidden state time-independent also lets us drop the output gate, which scales the hidden state. Without the output gate, the normalized hidden state equals the cell state, so keeping both the hidden state and the cell state is redundant, and the cell state is removed as well.

As a result, the "minimal" version of LSTM uses fewer parameters, O(3 d_h d_x), than the original architecture.
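The corresponding minLSTM sketch (again my own illustration, not the paper's code) makes the gate normalization explicit: f'_t = f_t / (f_t + i_t), i'_t = i_t / (f_t + i_t), and a single state h_t = f'_t * h_{t-1} + i'_t * htilde_t:

```python
import torch
import torch.nn as nn

class MinLSTM(nn.Module):
    """Sketch of minLSTM: x_t-only gates, normalized f/i, no output gate or cell state."""
    def __init__(self, d_x, d_h):
        super().__init__()
        self.linear_f = nn.Linear(d_x, d_h)  # forget gate
        self.linear_i = nn.Linear(d_x, d_h)  # input gate
        self.linear_h = nn.Linear(d_x, d_h)  # candidate state

    def forward(self, x, h0=None):
        # x: (batch, seq_len, d_x)
        f = torch.sigmoid(self.linear_f(x))
        i = torch.sigmoid(self.linear_i(x))
        # normalize so the two gates sum to 1 -> time-independent state scale
        f_prime = f / (f + i)
        i_prime = i / (f + i)
        h_tilde = self.linear_h(x)  # no tanh
        h = torch.zeros_like(f[:, 0]) if h0 is None else h0
        outs = []
        # again a linear recurrence: a_t = f'_t, b_t = i'_t * htilde_t,
        # so it is trainable with the same parallel scan
        for t in range(x.shape[1]):
            h = f_prime[:, t] * h + i_prime[:, t] * h_tilde[:, t]
            outs.append(h)
        return torch.stack(outs, dim=1)

# usage
model = MinLSTM(d_x=32, d_h=64)
print(model(torch.randn(8, 100, 32)).shape)  # torch.Size([8, 100, 64])
```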

A brief comparison with other models

The minimal versions of LSTM and GRU show impressive results: on sequences of 512 elements, they are 235 and 175 times faster than the original LSTM and GRU, respectively. However, this speedup comes at the cost of higher memory requirements: minGRU needs 88% more memory than the classic GRU (for comparison, Mamba uses 56% more than GRU).

The minLSTM and minGRU models demonstrate competitive results on several tasks. For example, they handled the Selective Copying task (which the authors took from the Mamba paper), while baselines from that paper such as S4 and H3 solved it only partially.

Fig. 3. Results of the LM task (page 9)

When tested on a language modeling task with nanoGPT (character-level modeling of the works of Shakespeare), minLSTM and minGRU also performed very well, reaching the minimum of the loss function faster than the transformer. minGRU and minLSTM reach the optimum in 575 and 625 training steps, respectively, while the transformer needs roughly 2,000 or more steps. Mamba reaches a slightly worse loss, but trains very quickly, in only 400 steps.

Thank you for your attention. Perhaps a modest RNN renaissance awaits us.
