What happens if you combine a transformer and a decision tree

Researchers have made significant progress in the convergence speed, accuracy, and interpretability of vision transformers. Details below. We share a translation of a post from the Google Research blog.


In visual recognition, the vision transformer (ViT) and its variants are attracting significant attention thanks to their excellent performance in mainstream visual tasks such as image classification, object detection, and video recognition.

The main idea of ViT is to use the power of self-attention layers to learn global relationships between image patches. However, as the image size grows, the number of connections between these patches grows quadratically.
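To make the quadratic growth concrete, here is a tiny sketch (illustrative helper names, not code from the paper) counting pairwise patch interactions in one global self-attention layer:

```python
# Sketch: cost of global self-attention over image patches.
# Each layer compares every patch with every other patch,
# so the number of interactions grows as N**2.

def num_patches(image_size: int, patch_size: int) -> int:
    """Patches per image, for a square image and square patches."""
    per_side = image_size // patch_size
    return per_side * per_side

def attention_pairs(n: int) -> int:
    """Pairwise interactions (including self) in one self-attention layer."""
    return n * n

# Doubling the image resolution quadruples the patch count
# and multiplies the attention cost by 16.
n_224 = num_patches(224, 16)   # 196 patches
n_448 = num_patches(448, 16)   # 784 patches
print(attention_pairs(n_448) // attention_pairs(n_224))  # 16
```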

Such a design is data-inefficient. The original ViT can learn visual representations efficiently, and even outperforms convolutional neural networks when hundreds of millions of images are available for training; but such data requirements are not always practical, and transformers still lag behind convolutional networks when less data is available.

Many researchers are looking for architecture changes that can learn visual representations efficiently, for example by adding convolutional layers or by building hierarchical structures with local self-attention.

The hierarchy principle is one of the main ideas in machine vision models: lower layers learn local structures in the high-dimensional pixel space, while higher layers learn abstract, high-level features in low-dimensional feature spaces.

To achieve such a hierarchy, existing ViT-based methods focus on various modifications inside the self-attention layers and often require a significant redesign of the architecture. Moreover, these approaches lack an interpretable design, so the inner workings of the trained models are difficult to explain.

To solve these problems, we propose rethinking the existing, structure-driven designs and present a new, orthogonal approach: a hierarchically grouped transformer. This approach greatly simplifies the design.

The main idea of the work is to separate feature learning from feature abstraction (pooling): grouped identical transformer layers encode visual information about image regions independently of each other, and the processed information is then aggregated.

This process is repeated hierarchically, producing a pyramidal structure. The resulting architecture achieves competitive accuracy on ImageNet and surpasses the original ViT on data-efficiency benchmarks.

We have shown that such a design can significantly improve data efficiency by speeding up convergence and providing interpretability benefits. Moreover, we have created GradCAT, a new technique for interpreting the decision-making process of trained models.

Architecture diagram

The general architecture is implemented by adding just a few lines of Python to the original ViT. The original architecture divides the input image into small rectangular patches, projects the pixels of each patch onto a vector of a fixed dimension, and then feeds the sequence of these vectors to a transformer made of many identical layers.
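The patchify-and-project step described above can be sketched in a few lines of NumPy; the shapes and names here are illustrative assumptions, not the released code:

```python
import numpy as np

rng = np.random.default_rng(0)

def patchify(image: np.ndarray, patch: int) -> np.ndarray:
    """Split an (H, W, C) image into flattened non-overlapping patches."""
    h, w, c = image.shape
    gh, gw = h // patch, w // patch
    x = image.reshape(gh, patch, gw, patch, c)
    x = x.transpose(0, 2, 1, 3, 4)          # (gh, gw, patch, patch, c)
    return x.reshape(gh * gw, patch * patch * c)

image = rng.standard_normal((32, 32, 3))
patches = patchify(image, patch=8)           # 4x4 grid of patches
proj = rng.standard_normal((8 * 8 * 3, 64))  # linear projection to d=64
tokens = patches @ proj                      # token sequence fed to the layers
print(tokens.shape)  # (16, 64)
```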

While every ViT layer processes the image as a whole, in the new method the grouped transformer layers process only one region (block) of the image, containing several spatially adjacent patches. This step is independent for each block, and this is where feature learning takes place.

After these computations, the layer performs block aggregation, combining spatially adjacent blocks. The features corresponding to four adjacent blocks are then fed to another transformer module with grouped layers, which processes the union of all four blocks.

This design naturally builds a pyramidal hierarchical network structure, where the lower layers focus on local features such as textures and the upper layers focus on global features such as object shape. Block aggregation reduces the dimensionality.
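A minimal NumPy sketch of one hierarchy step, with assumed shapes and an identity stub standing in for the grouped transformer layers:

```python
import numpy as np

# Sketch of the nested idea: split the token grid into blocks, process each
# block independently with shared (stub) layers, then downsample so that the
# next level covers four old blocks. Names and shapes are illustrative.
rng = np.random.default_rng(0)
d, grid = 64, 8                          # 8x8 token grid, feature dim 64
tokens = rng.standard_normal((grid, grid, d))

def blockify(x, block):
    """(G, G, d) grid -> (num_blocks, block*block, d) independent blocks."""
    g, _, d = x.shape
    gb = g // block
    x = x.reshape(gb, block, gb, block, d).transpose(0, 2, 1, 3, 4)
    return x.reshape(gb * gb, block * block, d)

def unblockify(x, grid, block):
    """Inverse of blockify: blocks back to the (G, G, d) grid."""
    gb = grid // block
    d = x.shape[-1]
    x = x.reshape(gb, gb, block, block, d).transpose(0, 2, 1, 3, 4)
    return x.reshape(grid, grid, d)

def downsample(x):
    """Halve each spatial side by 2x2 mean pooling (spatial size / 4)."""
    g, _, d = x.shape
    return x.reshape(g // 2, 2, g // 2, 2, d).mean(axis=(1, 3))

blocks = blockify(tokens, block=4)       # 4 blocks of 16 tokens each
# ... grouped identical transformer layers would run on each block here ...
merged = unblockify(blocks, grid, block=4)
coarse = downsample(merged)              # 4x4 grid for the next level
print(blocks.shape, coarse.shape)        # (4, 16, 64) (4, 4, 64)
```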

Visualization of image processing by the network. Given an input image, the network first divides it into blocks, each containing four image patches. The patches in each block are linearly projected to vectors and processed by grouped identical layers.

The block aggregation layer then combines information from the blocks and reduces the spatial size by a factor of four. At the top level only one block remains, and classification is performed on its output.
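The shrinking pyramid in the caption amounts to simple bookkeeping; a toy sketch with illustrative numbers:

```python
# Each aggregation step cuts the number of blocks by 4,
# until a single block remains for classification.

def levels_to_one_block(num_blocks: int) -> int:
    levels = 0
    while num_blocks > 1:
        num_blocks //= 4   # block aggregation: spatial size / 4
        levels += 1
    return levels

print(levels_to_one_block(16))  # 2 aggregation steps: 16 -> 4 -> 1
```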

Interpretability

This architecture processes information independently at each node, without overlaps. The design resembles a decision tree structure, which opens unique interpretation possibilities: each node of the tree holds independent information about an image block that is passed up to its parent node.

To understand the importance of each feature, you can track the flow of information through the nodes. The hierarchical structure preserves the spatial structure of the image across all layers of the network, which enables effective interpretation of the spatial feature maps.

We presented a method for interpreting a trained model on test images: GradCAT (gradient-based class-aware tree traversal). In the hierarchical structure, GradCAT tracks the feature importance of each block (tree node) from top to bottom.

The main idea is to find the most valuable traversal from a top-layer node down to the child nodes in the bottom layers that contribute the most to the classification result. Because each node processes information from a specific image region, the traversal maps easily back into image space for interpretation (shown by the lines and dots in the image below).
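A hedged sketch of the traversal idea on a hand-built toy tree; in the actual method the per-node scores come from class-aware gradients in the trained network, while the tree, scores, and names below are made up for illustration:

```python
# Greedy top-down traversal: at each node, descend into the child whose
# score (in GradCAT, a class-aware gradient x activation value) contributes
# most to the predicted class.

def gradcat_path(tree):
    """Return node names along the highest-scoring root-to-leaf path."""
    path = [tree["name"]]
    node = tree
    while node.get("children"):
        node = max(node["children"], key=lambda c: c["score"])
        path.append(node["name"])
    return path

toy_tree = {
    "name": "root", "score": 1.0, "children": [
        {"name": "top-left block", "score": 0.2, "children": []},
        {"name": "bottom-right block", "score": 0.7, "children": [
            {"name": "patch A", "score": 0.1, "children": []},
            {"name": "patch B", "score": 0.5, "children": []},
        ]},
    ],
}
print(gradcat_path(toy_tree))
# ['root', 'bottom-right block', 'patch B']
```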

Here is an example of the top-4 predictions and their associated interpretation results for the input image on the left (with four animals). As shown below, GradCAT highlights the decision path through the hierarchical structure and provides visual cues in local image regions.

Take the image on the left with four objects. The figure visualizes the interpretation results for the four predicted classes. The traversal finds the decision path along the tree and the corresponding image region that most influences the prediction (shown as a dotted line in the image).

The following figures visualize results on the ImageNet validation set and show how this approach allows some intuitive observations.

The lighter example in the upper left corner is particularly interesting, because the ground-truth class, lighter/matchbox, is determined by the lower-right block.

However, the most prominent visual cues (with the highest node values) are actually the red light in the upper left corner, which conceptually shares visual features with a laser pointer. This can be seen from the red lines, which indicate the image regions with the greatest impact on the prediction.

Thus, despite the misleading cues, the output prediction is correct. Also, the four child nodes of the wooden-spoon image below have similar feature importance (note the number of rendered nodes: the more there are, the more important). This is because the texture of the wooden table is similar to that of the spoon.

Visualization of GradCAT results on images from the ImageNet validation set

Unlike the original ViT, our hierarchical structure preserves the spatial relationships of the learned representations. The outputs of the upper layers are low-resolution feature maps of the input image, which lets the model easily perform attention-based interpretation by applying CAM to the learned representations (in the linked publication this is called a class activation map).
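A minimal CAM-style sketch under assumed shapes: because the top level yields a low-resolution spatial feature map, a class activation map is just that map weighted by one class's classifier weights (the dimensions and random values below are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
features = rng.standard_normal((4, 4, 64))   # top-level H x W x d feature map
class_weights = rng.standard_normal(64)      # classifier row for one class

cam = features @ class_weights               # (4, 4) spatial importance map
cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)  # normalize to [0, 1]
print(cam.shape)  # (4, 4)
```

The normalized map can then be upsampled and overlaid on the input image, warmer colors marking regions with higher class evidence.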

Visualization of CAM results on the ImageNet validation set. Warmer colors indicate more attention

Benefits in Convergence

With this design, feature learning happens independently in local regions and is abstracted inside the aggregation function. This, together with the simple implementation, makes the approach largely sufficient for visual recognition tasks beyond classification. The approach also greatly accelerates convergence, significantly reducing training time while reaching top accuracy.

We tested these achievements in two ways:

  1. We compared the accuracy of the ViT structure on ImageNet with different numbers of training epochs. The results on the left side of the figure below show much faster convergence than the original ViT: about a 20% improvement after 30 training epochs.

  2. We adapted the architecture to the task of unconditional image generation, since training ViT-based generators is challenging because of convergence-speed problems. Creating such a generator is straightforward: we transposed the new architecture, so the input becomes a vector representation and the output a full image with RGB channels.

We replaced block aggregation with a deaggregation component based on pixel shuffling [a rearrangement of pixels used in super-resolution models to implement efficient sub-pixel convolutions with a given stride].
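Pixel shuffling itself is a pure rearrangement of values, trading channels for spatial resolution; a small NumPy sketch, assuming a channel-last layout:

```python
import numpy as np

def pixel_shuffle(x: np.ndarray, r: int) -> np.ndarray:
    """Rearrange (H, W, C*r*r) into (H*r, W*r, C): channels -> sub-pixels."""
    h, w, crr = x.shape
    c = crr // (r * r)
    x = x.reshape(h, w, r, r, c)
    x = x.transpose(0, 2, 1, 3, 4)       # interleave the r x r sub-pixels
    return x.reshape(h * r, w * r, c)

x = np.arange(2 * 2 * 12, dtype=float).reshape(2, 2, 12)
y = pixel_shuffle(x, r=2)
print(y.shape)  # (4, 4, 3): spatial size x 4, channels / 4
```

Used as the deaggregation step, this upsamples the feature map at each level of the transposed hierarchy without any learned parameters of its own.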

Surprisingly, we found that our generator is easy to train: it shows both faster convergence and a better FID (Fréchet inception distance, computed with Inception v3) than SAGAN. FID measures how similar generated images are to real ones.

Left: accuracy of the standard ViT architecture on ImageNet for different total numbers of training epochs. Right: FID in image generation (lower is better) for a single training run of a thousand epochs

In both problems, our method shows the best convergence rate.

Conclusion

In this paper, we demonstrated the simple idea of separating feature learning from feature abstraction in a hierarchically grouped design, improving interpretability through a class-aware gradient-based tree-traversal method. This architecture accelerates convergence not only in classification problems.

The proposed idea is centered on the aggregation function and is therefore orthogonal to advanced self-attention architecture design [it does not affect that design]. We hope this research encourages architecture designers to explore even more interpretable and data-efficient visual recognition models, such as adapting this work to high-resolution image generation. We have also released the source code for the classification part of the work.
