Vision Transformers (ViT)

Since their introduction in the 2017 paper Attention is All You Need [1], transformers have become the dominant approach in natural language processing (NLP). In 2021, the paper An Image is Worth 16×16 Words [2] successfully adapted transformers to computer vision tasks. Since then, many transformer-based architectures for computer vision have been proposed.

In this article we examine the Vision Transformer (ViT) in the form in which it was presented in [2]. The article includes open-source ViT code as well as conceptual explanations of its components. The ViT implementation discussed here is written with the PyTorch package.

What are vision transformers?

As described in Attention is All You Need [1], transformers are a machine learning architecture that uses the self-attention mechanism as its core learning component. Transformers quickly became the state-of-the-art method for sequential data processing problems, such as translating from one language to another.

In An Image is Worth 16×16 Words [2], the transformer proposed in Attention is All You Need [1] was successfully modified to solve image classification problems, leading to the creation of the Vision Transformer (ViT). ViT is based on the same self-attention mechanism as the transformer from [1]. However, unlike the original transformer architecture for NLP, which includes an encoder and a decoder, ViT uses only an encoder. The output of the encoder is passed to an output layer responsible for the final prediction.

The disadvantage of ViT, as shown in [2], is its dependence on large datasets to achieve optimal performance. The best models are pre-trained on the proprietary JFT-300M dataset. Models pre-trained on the smaller, open-source ImageNet-21k dataset perform on par with state-of-the-art convolutional models such as ResNet.

Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet [3] attempts to remove the need for pre-training by proposing a method to convert the input image into tokens. In this article we focus on the ViT architecture and its implementation as proposed in [2].

Meet the model

This article follows the model structure described in An Image is Worth 16×16 Words [2]. However, the code for that paper is not publicly available. The code from the later Tokens-to-Token ViT [3] paper is available on GitHub. The Tokens-to-Token ViT (T2T-ViT) model adds a Tokens-to-Token (T2T) module to the standard ViT framework. The code in this article is based on the ViT components from the Tokens-to-Token ViT [3] GitHub repository. Modifications made to that implementation include, but are not limited to, support for non-square input images and the removal of dropout layers.

Below is a diagram of the ViT model.

ViT model diagram (image by author)

Image tokenization

The first step of ViT is to generate tokens from the input image. Transformers operate on a sequence of tokens; in natural language processing (NLP), each token typically represents a word. In the case of computer vision, tokenization involves breaking up an image into fixed-size patches.

ViT converts an image into tokens such that each token represents a local region, or patch, of the image. The paper describes converting an image of height H, width W, and channels C into N tokens with patch size P:

N = \frac{HW}{P^2}

Each token has a length of P²·C, where P² is the number of pixels in a patch and C is the number of channels.
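
For example, the 60×100 single-channel image used below, with P = 20, produces N = (60·100)/20² = 15 tokens, each of length 20²·1 = 400.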

Let's walk through an example of patch tokenization using the pixel art Mountain at Dusk by Luis Zuno (@ansimuz) [4]. The original artwork has been cropped and converted to a single-channel image, so each pixel is represented by a number between zero and one. Single-channel images are usually displayed in grayscale, but we will display them with a purple colormap for easier viewing.

Note that patch tokenization is not included in the code provided with [3]; all code in this section was written by the author.

import os
import numpy as np
import matplotlib.pyplot as plt

mountains = np.load(os.path.join(figure_path, 'mountains.npy'))

H = mountains.shape[0]
W = mountains.shape[1]
print('Mountain at Dusk is H =', H, 'and W =', W, 'pixels.')
print('\n')

fig = plt.figure(figsize=(10,6))
plt.imshow(mountains, cmap='Purples_r')
plt.xticks(np.arange(-0.5, W+1, 10), labels=np.arange(0, W+1, 10))
plt.yticks(np.arange(-0.5, H+1, 10), labels=np.arange(0, H+1, 10))
plt.clim([0,1])
cbar_ax = fig.add_axes([0.95, .11, 0.05, 0.77])
plt.clim([0, 1])
plt.colorbar(cax=cbar_ax);
#plt.savefig(os.path.join(figure_path, 'mountains.png'))
Mountain at Dusk is H = 60 and W = 100 pixels.
Code output (image by author)

This image has H=60 and W=100. We'll set P=20 since it divides H and W without a remainder.


P = 20
N = int((H*W)/(P**2))
print('There will be', N, 'patches, each', P, 'by', str(P)+'.')
print('\n')

fig = plt.figure(figsize=(10,6))
plt.imshow(mountains, cmap='Purples_r')
plt.hlines(np.arange(P, H, P)-0.5, -0.5, W-0.5, color="w")
plt.vlines(np.arange(P, W, P)-0.5, -0.5, H-0.5, color="w")
plt.xticks(np.arange(-0.5, W+1, 10), labels=np.arange(0, W+1, 10))
plt.yticks(np.arange(-0.5, H+1, 10), labels=np.arange(0, H+1, 10))
x_text = np.tile(np.arange(9.5, W, P), 3)
y_text = np.repeat(np.arange(9.5, H, P), 5)
for i in range(1, N+1):
    plt.text(x_text[i-1], y_text[i-1], str(i), color="w", fontsize="xx-large", ha="center")
plt.text(x_text[2], y_text[2], str(3), color="k", fontsize="xx-large", ha="center");
#plt.savefig(os.path.join(figure_path, 'mountain_patches.png'), bbox_inches="tight")
There will be 15 patches, each 20 by 20.
Code output (image by author)

By converting these patches into one-dimensional vectors, we see the resulting tokens. Let's look at patch number 12 as an example since it contains four different shades.

print('Each patch will make a token of length', str(P**2)+'.')
print('\n')

patch12 = mountains[40:60, 20:40]
token12 = patch12.reshape(1, P**2)

fig = plt.figure(figsize=(10,1))
plt.imshow(token12, aspect=10, cmap='Purples_r')
plt.clim([0,1])
plt.xticks(np.arange(-0.5, 401, 50), labels=np.arange(0, 401, 50))
plt.yticks([]);
#plt.savefig(os.path.join(figure_path, 'mountain_token12.png'), bbox_inches="tight")
Each patch will make a token of length 400.
Code output (image by author)

Once tokens have been extracted from an image, a linear projection is typically applied to change their length. This projection is implemented as a trainable linear layer. The new token length is variously called the latent dimension, the channel dimension, or simply the token length. After this projection, the tokens no longer correspond visually to patches of the original image.
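
As a quick illustration of this projection (the shapes mirror the Mountain at Dusk example; the layer is untrained, so the output values are meaningless), and importing torch here for all the examples that follow:

import torch
import torch.nn as nn

tokens = torch.rand(1, 15, 400)       # (batch, number of tokens, P*P*C)
projection = nn.Linear(400, 768)      # trainable linear layer mapping each token to the latent dimension
projected_tokens = projection(tokens)
print(projected_tokens.shape)         # torch.Size([1, 15, 768])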

Now that the concept of patch tokenization is clear, we can look at its implementation in code.

class Patch_Tokenization(nn.Module):
    def __init__(self,
                img_size: tuple[int, int, int]=(1, 400, 100),
                patch_size: int=50,
                token_len: int=768):

        """ Patch Tokenization Module
            Args:
                img_size (tuple[int, int, int]): size of input (channels, height, width)
                patch_size (int): the side length of a square patch
                token_len (int): desired length of an output token
        """
        super().__init__()

        ## Defining Parameters
        self.img_size = img_size
        C, H, W = self.img_size
        self.patch_size = patch_size
        self.token_len = token_len
        assert H % self.patch_size == 0, 'Height of image must be evenly divisible by patch size.'
        assert W % self.patch_size == 0, 'Width of image must be evenly divisible by patch size.'
        self.num_tokens = (H // self.patch_size) * (W // self.patch_size)

        ## Defining Layers
        self.split = nn.Unfold(kernel_size=self.patch_size, stride=self.patch_size, padding=0)
        self.project = nn.Linear((self.patch_size**2)*C, token_len)

    def forward(self, x):
        ## Split the image into patches and flatten each patch into a token
        x = self.split(x).transpose(1, 2)
        ## Project each token to the desired token length
        x = self.project(x)
        return x

Notice the two assert statements, which check that the image dimensions are evenly divisible by the patch size. The actual splitting into patches is done by a torch.nn.Unfold layer.
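
To see what torch.nn.Unfold does on its own, here is a small standalone example with a toy 4×6 single-channel image and 2×2 patches (these shapes are illustrative and not from the article):

toy_img = torch.arange(24, dtype=torch.float32).reshape(1, 1, 4, 6)   # (batch, channels, height, width)
unfold = nn.Unfold(kernel_size=2, stride=2, padding=0)
patches = unfold(toy_img)
print(patches.shape)                  # torch.Size([1, 4, 6]): C*2*2 = 4 values per patch, 6 patches
print(patches.transpose(1, 2).shape)  # torch.Size([1, 6, 4]): one row per patch, ready for nn.Linear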

We'll run this code on the cropped, single-channel version of Mountain at Dusk. We should see the number of tokens and the initial token length described above. We will use token_len = 768 as the projected length, which corresponds to the ViT base variant.

The first line in the code block below converts the Mountain at Dusk image from a NumPy array to a Torch tensor. We also apply unsqueeze twice to add a channel dimension and a batch dimension. As above, we have one channel, and since we have a single image, the batch size is 1.

x = torch.from_numpy(mountains).unsqueeze(0).unsqueeze(0).to(torch.float32)
token_len = 768
print('Input dimensions are\n\tbatchsize:', x.shape[0], '\n\tnumber of input channels:', x.shape[1], '\n\timage size:', (x.shape[2], x.shape[3]))

# Define the Module
patch_tokens = Patch_Tokenization(img_size=(x.shape[1], x.shape[2], x.shape[3]),
                                    patch_size = P,
                                    token_len = token_len)
Input dimensions are
   batchsize: 1 
   number of input channels: 1 
   image size: (60, 100)

Now let's divide the image into tokens.

x = patch_tokens.split(x).transpose(1, 2)
print('After patch tokenization, dimensions are\n\tbatchsize:', x.shape[0], '\n\tnumber of tokens:', x.shape[1], '\n\ttoken length:', x.shape[2])
After patch tokenization, dimensions are
   batchsize: 1 
   number of tokens: 15 
   token length: 400

As in the example above, there are N = 15 tokens, each of length 400. Finally, we project the tokens to a length of token_len.

x = patch_tokens.project(x)
print('After projection, dimensions are\n\tbatchsize:', x.shape[0], '\n\tnumber of tokens:', x.shape[1], '\n\ttoken length:', x.shape[2])
After projection, dimensions are
   batchsize: 1 
   number of tokens: 15 
   token length: 768

Now that we have the tokens, we are ready to move on to working with ViT.

Token processing

We will refer to the next two steps of ViT, prior to the encoding blocks, as “token processing.” The token processing components, which prepare the data for the encoding blocks, are shown in the ViT diagram below.

Token processing components in a ViT diagram (image by author)

The first step is to prepend an empty token, called the prediction token, to the image tokens. This token will be used at the output of the encoding blocks to make the prediction. It starts empty (all zeros) so that it can accumulate information from the other image tokens.

We will start with 175 tokens. Each token has length 768, which corresponds to the ViT base variant. We use a batch size of 13 because it is a prime number and will not be confused with any of the other parameters.

# Define an Input
num_tokens = 175
token_len = 768
batch = 13
x = torch.rand(batch, num_tokens, token_len)
print('Input dimensions are\n\tbatchsize:', x.shape[0], '\n\tnumber of tokens:', x.shape[1], '\n\ttoken length:', x.shape[2])

# Append a Prediction Token
pred_token = torch.zeros(1, 1, token_len).expand(batch, -1, -1)
print('Prediction Token dimensions are\n\tbatchsize:', pred_token.shape[0], '\n\tnumber of tokens:', pred_token.shape[1], '\n\ttoken length:', pred_token.shape[2])

x = torch.cat((pred_token, x), dim=1)
print('Dimensions with Prediction Token are\n\tbatchsize:', x.shape[0], '\n\tnumber of tokens:', x.shape[1], '\n\ttoken length:', x.shape[2])
Input dimensions are
   batchsize: 13 
   number of tokens: 175 
   token length: 768
Prediction Token dimensions are
   batchsize: 13 
   number of tokens: 1 
   token length: 768
Dimensions with Prediction Token are
   batchsize: 13 
   number of tokens: 176 
   token length: 768

Now we add a positional encoding to our tokens. The positional encoding allows the transformer to understand the order of the image tokens. Note that the positional encoding is added to the tokens rather than concatenated, which preserves the token length. The implementation details and the various positional encoding options are a larger topic for another time; here we use a fixed sinusoidal encoding.
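
The sinusoid table built by get_sinusoid_encoding below follows the standard formulation from [1], where pos is the token index, i indexes the embedding dimension, and d is the token length:

PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d}}\right), \qquad PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d}}\right)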

def get_sinusoid_encoding(num_tokens, token_len):
    """ Make Sinusoid Encoding Table

        Args:
            num_tokens (int): number of tokens
            token_len (int): length of a token
            
        Returns:
            (torch.FloatTensor) sinusoidal position encoding table
    """

    def get_position_angle_vec(i):
        return [i / np.power(10000, 2 * (j // 2) / token_len) for j in range(token_len)]

    sinusoid_table = np.array([get_position_angle_vec(i) for i in range(num_tokens)])
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) 

    return torch.FloatTensor(sinusoid_table).unsqueeze(0)

PE = get_sinusoid_encoding(num_tokens+1, token_len)
print('Position embedding dimensions are\n\tnumber of tokens:', PE.shape[1], '\n\ttoken length:', PE.shape[2])

x = x + PE
print('Dimensions with Position Embedding are\n\tbatchsize:', x.shape[0], '\n\tnumber of tokens:', x.shape[1], '\n\ttoken length:', x.shape[2])
Position embedding dimensions are
   number of tokens: 176 
   token length: 768
Dimensions with Position Embedding are
   batchsize: 13 
   number of tokens: 176 
   token length: 768

Our tokens are now ready to be passed into the encoding blocks, where most of the model's learning takes place.

Encoding block

Encoding blocks are the components where the model actually learns, processing the image tokens with the attention mechanism and a neural network. The number of encoding blocks is a user-defined hyperparameter. The encoding block diagram is shown below.

Encoding block diagram (image by author)

The code for the encoding block is given below.

from typing import Union

## Type alias for "float or None" used in the annotations below
NoneFloat = Union[None, float]

class Encoding(nn.Module):

    def __init__(self,
       dim: int,
       num_heads: int=1,
       hidden_chan_mul: float=4.,
       qkv_bias: bool=False,
       qk_scale: NoneFloat=None,
       act_layer=nn.GELU, 
       norm_layer=nn.LayerNorm):
        
        """ Encoding Block

            Args:
                dim (int): size of a single token
                num_heads(int): number of attention heads in MSA
                hidden_chan_mul (float): multiplier to determine the number of hidden channels (features) in the NeuralNet component
                qkv_bias (bool): determines if the qkv layer learns an additive bias
                qk_scale (NoneFloat): value to scale the queries and keys by; 
                                    if None, queries and keys are scaled by ``head_dim ** -0.5``
                act_layer(nn.modules.activation): torch neural network layer class to use as activation
                norm_layer(nn.modules.normalization): torch neural network layer class to use as normalization
        """

        super().__init__()

        ## Define Layers
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim=dim,
                            chan=dim,
                            num_heads=num_heads,
                            qkv_bias=qkv_bias,
                            qk_scale=qk_scale)
        self.norm2 = norm_layer(dim)
        self.neuralnet = NeuralNet(in_chan=dim,
                                hidden_chan=int(dim*hidden_chan_mul),
                                out_chan=dim,
                                act_layer=act_layer)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.neuralnet(self.norm2(x))
        return x

The num_heads, qkv_bias, and qk_scale parameters configure the attention module (Attention). A detailed look at the attention mechanism for vision transformers is left for another time.
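
The Attention module itself is not reproduced in this article. For reference, below is a minimal sketch of a multi-head self-attention module with the constructor signature assumed above (dim, chan, num_heads, qkv_bias, qk_scale); it follows the standard formulation from [1], with details such as dropout omitted, and is not the exact implementation from [3]:

class Attention(nn.Module):
    def __init__(self, dim, chan, num_heads=1, qkv_bias=False, qk_scale=None):
        super().__init__()
        self.num_heads = num_heads
        self.chan = chan
        head_dim = chan // num_heads
        ## Queries and keys are scaled by head_dim ** -0.5 unless qk_scale is given
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, chan * 3, bias=qkv_bias)
        self.proj = nn.Linear(chan, chan)

    def forward(self, x):
        B, N, _ = x.shape
        ## Compute queries, keys, and values for every head
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.chan // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        ## Scaled dot-product attention
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        ## Recombine the heads and project back to the token length
        x = (attn @ v).transpose(1, 2).reshape(B, N, self.chan)
        x = self.proj(x)
        return x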

The hidden_chan_mul parameter defines the size of the hidden layer of the neural network module, and act_layer specifies the activation function, which can be any layer from torch.nn.modules.activation. We will look at the neural network module in more detail later in the article.

The norm_layer specifies the normalization type and can be selected from any layer in torch.nn.modules.normalization.

Now we'll walk through each blue block in the diagram along with its accompanying code. We will use 176 tokens of length 768. The batch size will be 13, since it is a prime number and will not be confused with other parameters. We use 4 attention heads because 4 divides the token length (768) evenly, giving each head a dimension of 768/4 = 192. The attention head dimension is computed automatically and does not appear directly in the encoding block.

# Define an Input
num_tokens = 176
token_len = 768
batch = 13
heads = 4
x = torch.rand(batch, num_tokens, token_len)
print('Input dimensions are\n\tbatchsize:', x.shape[0], '\n\tnumber of tokens:', x.shape[1], '\n\ttoken length:', x.shape[2])

# Define the Module
E = Encoding(dim=token_len, num_heads=heads, hidden_chan_mul=1.5, qkv_bias=False, qk_scale=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm)
E.eval();
Input dimensions are
   batchsize: 13
   number of tokens: 176
   token length: 768

Now we pass the input through the first normalization layer and the attention module. The attention module in the encoding block is designed not to change the token length, which is achieved with a linear projection after the attention mechanism. After the attention module, we apply the first split (skip) connection.

y = E.norm1(x)
print('After norm, dimensions are\n\tbatchsize:', y.shape[0], '\n\tnumber of tokens:', y.shape[1], '\n\ttoken size:', y.shape[2])
y = E.attn(y)
print('After attention, dimensions are\n\tbatchsize:', y.shape[0], '\n\tnumber of tokens:', y.shape[1], '\n\ttoken size:', y.shape[2])
y = y + x
print('After split connection, dimensions are\n\tbatchsize:', y.shape[0], '\n\tnumber of tokens:', y.shape[1], '\n\ttoken size:', y.shape[2])
After norm, dimensions are
   batchsize: 13
   number of tokens: 176
   token size: 768
After attention, dimensions are
   batchsize: 13
   number of tokens: 176
   token size: 768
After split connection, dimensions are
   batchsize: 13
   number of tokens: 176
   token size: 768

Next, we pass through the second normalization layer and the neural network module, and finish with the second split (skip) connection.

z = E.norm2(y)
print('After norm, dimensions are\n\tbatchsize:', z.shape[0], '\n\tnumber of tokens:', z.shape[1], '\n\ttoken size:', z.shape[2])
z = E.neuralnet(z)
print('After neural net, dimensions are\n\tbatchsize:', z.shape[0], '\n\tnumber of tokens:', z.shape[1], '\n\ttoken size:', z.shape[2])
z = z + y
print('After split connection, dimensions are\n\tbatchsize:', z.shape[0], '\n\tnumber of tokens:', z.shape[1], '\n\ttoken size:', z.shape[2])
After norm, dimensions are
   batchsize: 13
   number of tokens: 176
   token size: 768
After neural net, dimensions are
   batchsize: 13
   number of tokens: 176
   token size: 768
After split connection, dimensions are
   batchsize: 13
   number of tokens: 176
   token size: 768

This completes the processing in one encoding block. Since the shape of the tokens is unchanged, the model can pass the tokens through any number of encoding blocks, as set by the depth hyperparameter.
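
As a small illustration of this stacking (using the Encoding class above; depth and the other hyperparameter values here are illustrative only):

depth = 2
blocks = nn.ModuleList([Encoding(dim=768, num_heads=4, hidden_chan_mul=1.5) for _ in range(depth)])

x = torch.rand(13, 176, 768)
for blk in blocks:
    x = blk(x)
print(x.shape)   # torch.Size([13, 176, 768]): the shape is preserved through every block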

Neural network module

The neural network (NN) module is a subcomponent of the encoding block. The NN module is very simple: a fully connected layer, an activation layer, and another fully connected layer. The activation layer can be any layer from torch.nn.modules.activation and is passed to the module as an input argument. The NN module can be configured to keep the same shape at the input and output, although the hidden layer may change the dimensionality of the data. We won't go through this code in detail, since neural networks are a common topic in machine learning and not the focus of this article. The code for the NN module is given below.

class NeuralNet(nn.Module):
    def __init__(self,
       in_chan: int,
       hidden_chan: NoneFloat=None,
       out_chan: NoneFloat=None,
       act_layer = nn.GELU):
        """ Neural Network Module

            Args:
                in_chan (int): number of channels (features) at input
                hidden_chan (NoneFloat): number of channels (features) in the hidden layer;
                                        if None, number of channels in hidden layer is the same as the number of input channels
                out_chan (NoneFloat): number of channels (features) at output;
                                        if None, number of output channels is same as the number of input channels
                act_layer(nn.modules.activation): torch neural network layer class to use as activation
        """

        super().__init__()

        ## Define Number of Channels
        hidden_chan = hidden_chan or in_chan
        out_chan = out_chan or in_chan

        ## Define Layers
        self.fc1 = nn.Linear(in_chan, hidden_chan)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_chan, out_chan)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x
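
A quick shape check of the NN module, with the same illustrative dimensions used in the encoding block example above:

net = NeuralNet(in_chan=768, hidden_chan=int(768 * 1.5), out_chan=768, act_layer=nn.GELU)
x = torch.rand(13, 176, 768)
print(net(x).shape)   # torch.Size([13, 176, 768]): input and output shapes match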

Prediction processing

After passing through the encoding blocks, the final step for the model is to make a prediction. The “prediction processing” component of the ViT diagram is shown below.

Prediction processing components in a ViT diagram (image by author)

We will look at each step of this process, continuing with 176 tokens of length 768. We use a batch size of 1 to illustrate how a single prediction is made; with a batch size greater than 1, the model would compute multiple predictions in parallel.

# Define an Input
num_tokens = 176
token_len = 768
batch = 1
x = torch.rand(batch, num_tokens, token_len)
print('Input dimensions are\n\tbatchsize:', x.shape[0], '\n\tnumber of tokens:', x.shape[1], '\n\ttoken length:', x.shape[2])
Input dimensions are
   batchsize: 1
   number of tokens: 176
   token length: 768

First, all tokens go through a normalization layer.

norm = nn.LayerNorm(token_len)
x = norm(x)
print('After norm, dimensions are\n\tbatchsize:', x.shape[0], '\n\tnumber of tokens:', x.shape[1], '\n\ttoken size:', x.shape[2])
After norm, dimensions are
   batchsize: 1
   number of tokens: 176
   token size: 768

We then separate the prediction token from the rest of the tokens. Over the course of the encoding blocks, the prediction token has accumulated information from the other tokens and is no longer all zeros. Only this prediction token is used to make the final prediction.

pred_token = x[:, 0]
print('Length of prediction token:', pred_token.shape[-1])
Length of prediction token: 768

Finally, the prediction token is passed through the head to make the prediction. The head is usually some kind of neural network, and its structure varies from model to model. In An Image is Worth 16×16 Words [2], an MLP (multilayer perceptron) with one hidden layer is used during pre-training and a single linear layer is used during fine-tuning. In the Tokens-to-Token ViT model [3], a single linear layer is used as the head. This example uses a single linear layer.

Note that the shape of the head's output depends on the learning task. For classification, it is usually a vector with length equal to the number of classes, as in one-hot encoding. For regression, it can be any number of predicted values. This example uses an output of size 1, representing a single predicted value for a regression task.
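
For example, a classification head for a hypothetical 10-class problem could be a single linear layer producing one logit per class (an illustrative sketch, not from the article); the regression head used below keeps a single output instead:

num_classes = 10
cls_head = nn.Linear(token_len, num_classes)
logits = cls_head(pred_token)     # shape (1, 10): one logit per class
probs = logits.softmax(dim=-1)    # class probabilities
print(probs.shape)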

head = nn.Linear(token_len, 1)
pred = head(pred_token)
print('Length of prediction:', (pred.shape[0], pred.shape[1]))
print('Prediction:', float(pred))
Length of prediction: (1, 1)
Prediction: -0.5474240779876709

And that's all! The model made a prediction!

Full code

To build the complete ViT module, we use the patch tokenization module defined above and the ViT Backbone module. The ViT Backbone, defined below, contains the token processing components, the encoding blocks, and the prediction processing components.

class ViT_Backbone(nn.Module):
    def __init__(self,
                preds: int=1,
                token_len: int=768,
                num_tokens: int=176,
                num_heads: int=1,
                Encoding_hidden_chan_mul: float=4.,
                depth: int=12,
                qkv_bias=False,
                qk_scale=None,
                act_layer=nn.GELU,
                norm_layer=nn.LayerNorm):

        """ VisTransformer Backbone
            Args:
                preds (int): number of predictions to output
                token_len (int): length of a token
                num_tokens (int): number of image tokens (the prediction token is added internally)
                num_heads (int): number of attention heads in MSA
                Encoding_hidden_chan_mul (float): multiplier to determine the number of hidden channels (features) in the NeuralNet component of the Encoding Module
                depth (int): number of encoding blocks in the model
                qkv_bias (bool): determines if the qkv layer learns an additive bias
                qk_scale (NoneFloat): value to scale the queries and keys by;
                                    if None, queries and keys are scaled by ``head_dim ** -0.5``
                act_layer (nn.modules.activation): torch neural network layer class to use as activation
                norm_layer (nn.modules.normalization): torch neural network layer class to use as normalization
        """

        super().__init__()

        ## Defining Parameters
        self.token_len = token_len
        self.num_tokens = num_tokens
        self.num_heads = num_heads
        self.Encoding_hidden_chan_mul = Encoding_hidden_chan_mul
        self.depth = depth

        ## Defining Token Processing Components
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.token_len))
        self.pos_embed = nn.Parameter(data=get_sinusoid_encoding(num_tokens=self.num_tokens+1, token_len=self.token_len), requires_grad=False)

        ## Defining Encoding blocks
        self.blocks = nn.ModuleList([Encoding(dim = self.token_len, 
                                               num_heads = self.num_heads,
                                               hidden_chan_mul = self.Encoding_hidden_chan_mul,
                                               qkv_bias = qkv_bias,
                                               qk_scale = qk_scale,
                                               act_layer = act_layer,
                                               norm_layer = norm_layer)
             for i in range(self.depth)])

        ## Defining Prediction Processing
        self.norm = norm_layer(self.token_len)
        self.head = nn.Linear(self.token_len, preds)

        ## Make the class token sampled from a truncated normal distribution
        timm.layers.trunc_normal_(self.cls_token, std=.02)

    def forward(self, x):
        ## Assumes x is already tokenized

        ## Get Batch Size
        B = x.shape[0]
        ## Concatenate Class Token
        x = torch.cat((self.cls_token.expand(B, -1, -1), x), dim=1)
        ## Add Positional Embedding
        x = x + self.pos_embed
        ## Run Through Encoding Blocks
        for blk in self.blocks:
            x = blk(x)
        ## Take Norm
        x = self.norm(x)
        ## Make Prediction on Class Token
        x = self.head(x[:, 0])
        return x

From the ViT Backbone module, we can define the complete ViT model.

class ViT_Model(nn.Module):
    def __init__(self,
                img_size: tuple[int, int, int]=(1, 400, 100),
                patch_size: int=50,
                token_len: int=768,
                preds: int=1,
                num_heads: int=1,
                Encoding_hidden_chan_mul: float=4.,
                depth: int=12,
                qkv_bias=False,
                qk_scale=None,
                act_layer=nn.GELU,
                norm_layer=nn.LayerNorm):

        """ VisTransformer Model

            Args:
                img_size (tuple[int, int, int]): size of input (channels, height, width)
                patch_size (int): the side length of a square patch
                token_len (int): desired length of an output token
                preds (int): number of predictions to output
                num_heads (int): number of attention heads in MSA
                Encoding_hidden_chan_mul (float): multiplier to determine the number of hidden channels (features) in the NeuralNet component of the Encoding Module
                depth (int): number of encoding blocks in the model
                qkv_bias (bool): determines if the qkv layer learns an additive bias
                qk_scale (NoneFloat): value to scale the queries and keys by;
                                    if None, queries and keys are scaled by ``head_dim ** -0.5``
                act_layer (nn.modules.activation): torch neural network layer class to use as activation
                norm_layer (nn.modules.normalization): torch neural network layer class to use as normalization
        """
        super().__init__()

        ## Defining Parameters
        self.img_size = img_size
        C, H, W = self.img_size
        self.patch_size = patch_size
        self.token_len = token_len
        self.num_heads = num_heads
        self.Encoding_hidden_chan_mul = Encoding_hidden_chan_mul
        self.depth = depth

        ## Defining Patch Embedding Module
        self.patch_tokens = Patch_Tokenization(img_size,
                                                patch_size,
                                                token_len)

        ## Defining ViT Backbone
        self.backbone = ViT_Backbone(preds=preds,
                                    token_len=self.token_len,
                                    num_tokens=self.patch_tokens.num_tokens,
                                    num_heads=self.num_heads,
                                    Encoding_hidden_chan_mul=self.Encoding_hidden_chan_mul,
                                    depth=self.depth,
                                    qkv_bias=qkv_bias,
                                    qk_scale=qk_scale,
                                    act_layer=act_layer,
                                    norm_layer=norm_layer)
        ## Initialize the Weights
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """ Initialize the weights of the linear layers & the layernorms
        """
        ## For Linear Layers
        if isinstance(m, nn.Linear):
            ## Weights are initialized from a truncated normal distribution
            timm.layers.trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                ## If bias is present, bias is initialized at zero
                nn.init.constant_(m.bias, 0)
        ## For Layernorm Layers
        elif isinstance(m, nn.LayerNorm):
            ## Weights are initialized at one
            nn.init.constant_(m.weight, 1.0)
            ## Bias is initialized at zero
            nn.init.constant_(m.bias, 0)

    @torch.jit.ignore  ## Tell PyTorch not to compile this as TorchScript
    def no_weight_decay(self):
        """ Used in Optimizer to ignore weight decay in the class token
        """
        return {'cls_token'}

    def forward(self, x):
        x = self.patch_tokens(x)
        x = self.backbone(x)
        return x

The img_size, patch_size, and token_len parameters define the image size, patch size, and token length, respectively, for the patch tokenization module.

The num_heads, Encoding_hidden_chan_mul, qkv_bias, qk_scale, and act_layer parameters configure the encoding blocks. The activation layer (act_layer) can be any layer from torch.nn.modules.activation. The depth parameter determines the number of encoding blocks in the model.

The norm_layer parameter specifies the normalization used both inside and outside of the encoding blocks. It can be any layer from torch.nn.modules.normalization.

The _init_weights method is taken from the T2T-ViT code [3]. It can be removed so that all trainable weights and biases are initialized randomly. In the current implementation, the weights of the linear layers are initialized from a truncated normal distribution, the biases of the linear layers are initialized to zero, the weights of the normalization layers are initialized to one, and the biases of the normalization layers are initialized to zero.
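
As a quick sanity check (not part of the original article), here is a minimal forward pass through the full model, assuming torch and timm are imported, the classes above are defined, and an Attention module such as the sketch shown earlier is available. The hyperparameter values are illustrative only:

model = ViT_Model(img_size=(1, 60, 100),
                  patch_size=20,
                  token_len=768,
                  preds=1,
                  num_heads=4,
                  Encoding_hidden_chan_mul=1.5,
                  depth=2)

x = torch.rand(13, 1, 60, 100)   # a batch of 13 single-channel 60x100 images
pred = model(x)
print(pred.shape)                # torch.Size([13, 1]): one regression prediction per image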

Conclusion

Now that you have a deeper understanding of how ViT works, you can start training your own models! Below is a list of resources where you can download code for ViT models. Some of them allow more modification of the model than others.

  • GitHub repository for this series of articles

  • GitHub repository for the article An Image is Worth 16×16 Words – contains pre-trained models and fine-tuning code, but does not contain model definitions

  • ViT, as implemented in PyTorch Image Models (timm): timm.create_model('vit_base_patch16_224', pretrained=True)

  • The vit-pytorch package by Phil Wang

