The world's first large dataset for contextual reinforcement learning

We recently had two articles accepted to ICML 2024, as well as a JAX environment with a large set of meta-learning tasks. We will definitely tell you about them a little later (subscribe!), but in this article I would like to talk about our recent preprint. In it, we presented and open-sourced a huge (by RL standards) and, so far, the only dataset for in-context RL. Collecting trajectories for 40k tasks and 130B transitions required 50,000 GPU hours. We did this work together with colleagues from the T-Bank AI Research lab.

The dataset is already usable, so let me tell you about it while we hope for the article's acceptance! Let's start a little from afar: I'll explain what in-context learning is, how it appeared in RL, and why we needed our own dataset.

Contextual and meta-learning

In fact, when I said that the field is new, I was being a little sly! Of course, the idea of learning meta-algorithms, i.e. algorithms that are themselves learning algorithms, appeared relatively long ago. Unlike multi-task learning, the goal is a model that not only works well on all training tasks simultaneously, but is also capable of learning new ones as effectively as possible in minimal time. Meta-learning exists both in the more traditional field of supervised learning and in reinforcement learning. For those interested: you can learn more about the brief history of meta-learning and its main methods in the blog post by Lilian Weng from OpenAI or in a recent large survey.

Schematic difference between multi-task and meta-learning. Source: meta-world.github.io

But by "new" I mean a way of obtaining meta-models that has become increasingly popular lately. I think anyone who has played with large language models like ChatGPT will be familiar with it, even without knowing the formal name.

An example of contextual learning. First, examples of solving the problem are given in the prompt, and then the model copes on its own.

You've probably heard, or maybe even tried it yourself, that ChatGPT can be taught a new task simply by feeding it a prompt with a number of examples and their solutions. This works for a wide range of tasks. And this is surprising, because no one specifically trained the model to solve your task: it figured out how to respond to new data from the context and the examples given to it. This phenomenon is called in-context learning (ICL) and is, in essence, an example of meta-learning too!

The phenomenon of in-context learning was first demonstrated in full with the release of GPT-3, after which the scientific community believed for some time that it was an exclusive property of huge language models built on the transformer architecture. However, it was later discovered that quite small transformers are capable of it, and even models that are not transformers at all; in general, the phenomenon is not unique to language and also occurs in other domains, for example in generative image models. For a long time, the main limitation was the context size due to the quadratic nature of the attention mechanism and, as a result, the high cost of long prompts. With the advent of million-token contexts, it became clear that in-context learning is often much easier and more efficient than the more familiar fine-tuning for a specific task.

Contextual reinforcement learning

Conceptually, in-context learning fits very well with RL (see the contextual Markov decision process). Considering how notoriously difficult it can be to train an agent from scratch even on a handful of tasks with reinforcement learning, the possibility of obtaining an agent capable of fast and effective adaptation to new tasks is extremely attractive. After all, isn't this our ultimate goal – to obtain an analogue of a foundation model for RL?

The potential of this approach was demonstrated by DeepMind's Ada, which came out last year. It turned out that with the proper scale of training and sufficient complexity and diversity of environments, you can get an agent that adapts to new non-trivial tasks as fast as a human, or even faster. I highly recommend watching the video presentation of the article!

Despite its success, Ada was trained using a more or less traditional RL algorithm, thus inheriting all the learning instabilities and scaling difficulties. Could it be simpler?

In fact, yes. After all, if you think about it, in reinforcement learning we have trajectories of the agent's interaction with the environment – essentially discrete sequences. Yes, they consist of three modalities (state, action, reward), but from a transformer's point of view they are not that different from language. You can feed them into a transformer and predict the next action.

However, the agent will not learn to maximize reward this way, only to imitate the policy in the dataset. Instead, you can condition training on the total reward of the current trajectory and, at test time, condition on the highest return you can hope for. This gives the Decision Transformer (DT), probably the most popular transformer-based method in RL.
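To make this concrete, here is a minimal sketch (plain NumPy, with illustrative array names that are not from the paper's code) of how a DT-style input sequence can be assembled: each timestep contributes a return-to-go token, a state token, and an action token, and at test time the return-to-go is simply set to a high target value.

```python
import numpy as np

def build_dt_sequence(states, actions, rewards):
    """Interleave (return-to-go, state, action) per timestep, DT-style."""
    # Return-to-go: sum of rewards from this timestep to the end of the episode.
    returns_to_go = np.cumsum(rewards[::-1])[::-1]
    tokens = []
    for rtg, s, a in zip(returns_to_go, states, actions):
        tokens += [("return", rtg), ("state", s), ("action", a)]
    return tokens

# Toy trajectory: 4 steps, reward only at the end.
states = np.random.randn(4, 3)
actions = np.array([0, 1, 1, 0])
rewards = np.array([0.0, 0.0, 0.0, 1.0])
sequence = build_dt_sequence(states, actions, rewards)
# At test time the true return is unknown, so the first token is set to a
# desired (high) target return and decremented by the rewards actually received.
```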

Visualization of the DT approach. Note that the context starts with the future total reward. Source: Lili Chen et al. / arXiv, 2021

But DT is poorly suited for in-context learning, and indeed it has not been observed to have such abilities. First, the range of total rewards may vary significantly from task to task, or we may not know it in advance at all. Second, we do not explicitly teach the transformer to maximize reward.

Let's imagine that we have complete training histories of some basic RL algorithms. The reward in them obviously grows as training progresses. We will train a transformer on this, feeding the trajectories as is, that is, without changing the order and without breaking the reward ordering. What do you think will happen? The following will happen.

With a small context size, we get regular behavioral cloning (BC), since the transformer simply tries to predict the dataset policy's action from a short history. However, if we expand the context enough to include the base algorithm's history of improvement, for example four or more episodes, then suddenly the transformer acquires in-context learning abilities after training! Namely: when tested on a new task, it starts maximizing reward based on its own context and solving previously unseen tasks (from a distribution similar to the training one). Moreover, it starts solving them better than the expert in the data.

Why does this happen? The transformer has no choice: in order to minimize the loss function, it has to "distill" the base algorithm that collected the data. The key point is that it distills the improvement algorithm, not some specific average policy. Thus, given the right data, we have an extremely simple and stable method for training a reinforcement learning meta-algorithm that inherits all the advances in efficient transformer training from the NLP field. The first to suggest training this way was an article from DeepMind, which called the method Algorithm Distillation (AD) and effectively started the in-context RL field.
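To illustrate the data format AD relies on, here is a rough sketch (plain NumPy, with a hypothetical storage layout) of sampling a long across-episode context from a stored training history. The key point reflected in the code is that the slice spans several consecutive episodes in their original order, so the base algorithm's improvement remains visible in the context.

```python
import numpy as np

def sample_ad_context(history, context_len, rng):
    """Sample a subsequence that spans several consecutive episodes.

    history: dict of arrays ("states", "actions", "rewards"), all of length N,
             stored in the original training order (illustrative layout).
    """
    n = len(history["rewards"])
    start = rng.integers(0, n - context_len)
    window = slice(start, start + context_len)
    # Order is preserved: earlier (weaker) episodes come before later (stronger)
    # ones, which is exactly what lets the transformer distill improvement.
    return {key: value[window] for key, value in history.items()}

rng = np.random.default_rng(0)
history = {
    "states": np.random.randn(10_000, 5),
    "actions": rng.integers(0, 4, size=10_000),
    "rewards": rng.random(10_000),
}
context = sample_ad_context(history, context_len=2048, rng=rng)
```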

Visualization of the AD approach. Unlike DT, the context is fed with states, actions, and raw rewards in a natural order. Source: Michael Laskin et al. / arXiv, 2022

A curious fact: further theoretical analysis showed that a transformer trained in this way can implement entire optimal RL algorithms within itself. Moreover, the algorithms are diverse, from simple Thompson Sampling and LinUCB to many variations of TD-learning, which underlies modern DQN-like algorithms. Whether transformers actually implement them, and which ones specifically, is an open question: the theoretical analysis only proves that they are capable of it. If you are interested in mechanistic interpretability, you can try your luck at decoding the algorithms a trained transformer has learned. For example, it was recently discovered that transformers implicitly learn hidden world models and plan actions. So the topic is extremely fascinating!

The problem of contextual learning in RL

As soon as we got interested in in-context RL, we quickly realized that the field lacked a complex and interesting benchmark, as well as large datasets. All existing works tested their methods on extremely simple environments and were forced to rebuild datasets from scratch each time, since there was no standardization. For us as researchers this quickly became frustrating: in each article the environments and datasets looked similar, but there were always subtle differences (which the authors forgot to mention) that ultimately affected the results. We decided we should fix this, and it seems we have succeeded!

Before spending time and resources on collecting a new dataset, we of course looked at the existing ones: maybe they could be adapted to our needs? It turned out that none of the datasets we knew of (especially the large ones) fit the requirements of in-context RL: they do not contain complete training histories of the underlying algorithms and do not have high task diversity. If we want to study in-context learning from all sides, for example scaling laws with respect to the number and complexity of unique tasks, and not just tokens as in NLP, we need a truly diverse and large dataset containing thousands of tasks.

Comparison of our dataset with existing ones. None of them allow studying contextual learning.

It was decided that we needed to build it ourselves. But where can we get an environment with a lot of unique and non-trivial tasks, fast enough to handle training thousands of RL agents from scratch? By a lucky coincidence, we recently released such an environment ourselves! We called it XLand-MiniGrid, because it is inspired by, and takes the best from, DeepMind's XLand and the minimalistic MiniGrid well known in the community. It is a grid world with a sparse reward and a set of goals and rules that can dynamically change the environment. Tasks can be generated procedurally in the form of binary trees.
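To give a feel for the structure (a toy representation, not the library's actual data types): a task can be thought of as a binary tree in which the root is the goal object and every other node is a rule that produces a new object from its children.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class RuleNode:
    produces: str                       # object this node yields; the root's object is the goal
    left: Optional["RuleNode"] = None   # ingredient sub-trees; None means a base object
    right: Optional["RuleNode"] = None

# Illustrative task: craft a "purple pyramid" from intermediate objects,
# which are in turn crafted from objects already present in the grid.
task = RuleNode(
    "purple pyramid",
    left=RuleNode("blue ball", RuleNode("red square"), RuleNode("green circle")),
    right=RuleNode("yellow key"),
)
```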

Visualization of the environment in XLand-MiniGrid. A short example of one rule transforming two specific objects into a new one is shown.

An example of one task in XLand-MiniGrid. At the root of the tree is the goal that needs to be achieved. Any other node sets the rules by which new objects needed for the goal can be obtained from existing ones.

Imagine that every time you log into Minecraft, the crafting tree is randomized and you are asked to obtain some complex item. You will have the exciting task of discovering this crafting tree yourself through experimentation and a good memory. Moreover, having solved one such task, you will not be able to generalize to a new one (after all, the new crafting tree is hidden from you), you will have to start all over again, that is, be adaptive – learn the process of learning and effectively exploring the environment.

Because the framework is written from scratch in JAX, it supports GPUs and TPUs, meaning we can run tens of thousands of parallel environments simultaneously, completely avoiding the CPU bottleneck inherent in most RL training pipelines.
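The pattern that makes this possible is standard JAX vectorization: the (pure) reset and step functions are jit-compiled and vmapped over a batch of environment states. A self-contained toy sketch with a stand-in step function (the real environment exposes its own reset/step API, see the repository):

```python
import jax
import jax.numpy as jnp

def toy_step(state, action):
    # Stand-in for an environment step: a pure function of (state, action).
    new_state = state + action
    reward = jnp.where(new_state.sum() > 10.0, 1.0, 0.0)
    return new_state, reward

# Vectorize over thousands of environments and compile once.
batched_step = jax.jit(jax.vmap(toy_step))

num_envs = 8192
states = jnp.zeros((num_envs, 4))
actions = jnp.ones((num_envs, 4))
states, rewards = batched_step(states, actions)  # one GPU/TPU call steps all envs
```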

Collecting the dataset

To collect the dataset, we chose trivial-1m and medium-1m, two of the task sets pre-generated in XLand-MiniGrid: the simplest one for a smaller dataset, so we could iterate faster, and a more complex one for the main dataset. From these sets we pre-selected a total of 40k tasks of varying complexity. As the base algorithm we took the most well-known and stable PPO, adding a GRU as a memory module. Each task was trained with multiple environments and GPUs in parallel, and during training we logged only the first 32 environments so as not to inflate the dataset size too much.

Unfortunately, the first attempt at collection failed. It turned out that most of the tasks are too difficult for vanilla PPO+RNN, even when trained on billions of transitions: the reward is too sparse and the chances of solving a task by accident are minimal. We tried various ways to improve the agent's exploration of the environment, but without much success.

In the end, we settled on collecting in two phases. Normally, the rules and the goal are hidden from the agent. We cheated a little and made them visible during the first phase, feeding them to the agent as an additional input (goal-conditioning) while training on many different tasks at once. After this training, the agent could solve new tasks reasonably well in zero-shot mode based on the task information fed to it. Then, for each specific task, we fine-tuned the agent while completely masking all information about the rules and goals. In essence, after pre-training and masking we obtained a policy averaged over all tasks, which significantly improved the agent's exploration of the environment and accelerated convergence on complex tasks.
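A minimal sketch of this trick (names and shapes are illustrative, not our actual code): the encoding of the rules and goal is concatenated to the observation in the first phase and replaced with zeros of the same shape in the second, so the network architecture stays identical across both phases.

```python
import numpy as np

def make_agent_input(obs, task_encoding, goal_visible):
    """Concatenate the observation with the task encoding (phase 1)
    or with a zero mask of the same shape (phase 2)."""
    if not goal_visible:
        task_encoding = np.zeros_like(task_encoding)
    return np.concatenate([obs, task_encoding], axis=-1)

obs = np.random.randn(32)    # illustrative observation features
task = np.random.randn(16)   # illustrative encoding of the rules and goal
phase1_input = make_agent_input(obs, task, goal_visible=True)   # goal-conditioned pre-training
phase2_input = make_agent_input(obs, task, goal_visible=False)  # per-task fine-tuning, goal masked
```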

The reward during training is averaged over a set of tasks. As can be seen, after pre-training the agent achieves a much better result on average.

The data collection took two weeks in total and 50,000 GPU hours for the trajectories of 40k tasks and 130B transitions. After collection, we filtered out tasks that the agent was unable to solve or that failed during collection for one reason or another. The data is stored in HDF5 format with compression enabled. We played around with the compression settings and chunk cache size a bit, compressing the dataset by a factor of 15 while slowing down sampling by only half. The final statistics can be seen in the table below:

Further technical details of the collection can be found in our article.
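To get a feel for the storage format, here is how such an HDF5 file could be opened with h5py; the file name and the group/field names below are placeholders, and the actual layout is described in the dataset repository. The chunk cache, which we tuned alongside compression, is controlled by the rdcc_nbytes argument.

```python
import h5py

# Placeholder path and keys; check the dataset repository for the real layout.
with h5py.File("xland_dataset.hdf5", "r", rdcc_nbytes=512 * 1024**2) as f:
    task_ids = list(f.keys())            # e.g. one group per task
    group = f[task_ids[0]]
    print(list(group.keys()))            # e.g. states / actions / rewards
    # Datasets are read lazily and decompressed on access:
    first_rewards = group["rewards"][:10_000]
```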

Dataset validation

Finally, after collection, we had to make sure that the resulting dataset met all the requirements we needed. We looked at two things.

First, we plotted the training histories, finding that the trajectories were naturally ordered by increasing total reward per episode, as they should be.
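The check itself is straightforward; here is a sketch (with hypothetical field names) that splits one task's history into episodes by its done flags, computes per-episode returns, and verifies that the smoothed return grows over the course of the history.

```python
import numpy as np

def episode_returns(rewards, dones):
    """Sum rewards between done flags to get one return per episode."""
    returns, acc = [], 0.0
    for reward, done in zip(rewards, dones):
        acc += reward
        if done:
            returns.append(acc)
            acc = 0.0
    return np.array(returns)

# Toy stand-in for one task's training history: rewards grow over time.
steps = 100_000
rewards = np.random.random(steps) * (np.arange(steps) / steps)
dones = np.zeros(steps, dtype=bool)
dones[249::250] = True  # fixed-length episodes for the toy example

returns = episode_returns(rewards, dones)
smoothed = np.convolve(returns, np.ones(20) / 20, mode="valid")
assert smoothed[-1] > smoothed[0]  # returns should increase across the history
```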

Second, we reproduced and trained Algorithm Distillation (AD) and Decision-Pretrained Transformer (DPT), two of the most popular methods for in-context RL, to demonstrate that the agent actually exhibits contextual learning capabilities when trained on our data. We found something interesting.

It turned out that DPT cannot work out of the box in environments where the Markov property does not hold, which was only briefly mentioned in the original paper. As for AD: after training, it was able to solve simple tasks and showed signs of in-context learning on slightly more complex ones, but it was far from ideal. Perhaps it needs more training, or maybe new, more effective ideas are needed?

Visualization of the AD test reward. On simple tasks, contextual learning quickly improves the reward (around 50 episodes). However, on more complex tasks, it is either not yet observed or quickly reaches a plateau far from the optimal reward.

Summary

As you can see, there is still a lot of work to do! At the moment, the dataset is unique in its niche, both in size and in its properties. We plan to test many more hypotheses with it and, I am sure, significantly improve the state of the art in in-context RL. We hope it will be useful to you, because it can also be applied in many other areas, for example offline RL.

In any case, I hope that, beyond the dataset itself, this informal and brief introduction to this new field has taught you something new and interesting. Stay tuned and subscribe to the AIRI channels and our team's channel. Be sure to star the dataset and the environment on GitHub. If you are interested in the topic, or in RL and scientific work in general, and would like to try yourself in this area, we are always happy to talk. Write to me or to our group leader, Vlad Kurenkov.
