Building Vision Transformers with PyTorch

Sept 2, 2024

Hey there! Ever wondered how those super smart AI models understand images? Like, how can a computer look at a picture of a cat and be like, "Yo, that's a cat!"? Well, my friend, today we're building one from scratch using PyTorch!

You might've heard of Transformers in the context of Natural Language Processing (NLP). They're the tech behind those chatbots and language models that spit out some pretty impressive text. But here's the twist: what if we used that same technology on images? Enter Vision Transformers.

If you want the original source, here's the Vision Transformer research paper: "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" (https://arxiv.org/abs/2010.11929).

Instead of reading words like in NLP, we're gonna slice up images into tiny patches, treat them like words, and feed them into a Transformer. By the end, we'll have a model that can look at an image and tell us what's in it. Cool, right?

The Core Structure: Breaking It Down

Our Vision Transformer is like a layered cake, with each layer doing something special. Let's go through each layer and see how it all comes together:

Vision Transformer structure
Figure 1: From the ViT Paper showcasing the different inputs, outputs, layers and blocks.

1. Patch + Position Embedding: The Foundation

First up, we need to break our image into patches. Think of each patch as a small piece of the puzzle. But the Transformer needs to know the order of these pieces, so we add some position info to each patch—just like numbering puzzle pieces. This helps the model understand the layout of the image.

Patch and Position Embedding
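
Before we write any real layers, here's a quick back-of-the-envelope sketch of the patch math, using the same numbers we'll use later in the full model (224x224 RGB images, 16x16 patches):

import torch

# Patch math sketch: how many "words" does a 224x224 image turn into?
img_size, patch_size, channels = 224, 16, 3
num_patches = (img_size // patch_size) ** 2      # 14 * 14 = 196 patches
patch_dim = channels * patch_size * patch_size   # 3 * 16 * 16 = 768 values per patch

# One learnable position embedding per patch, plus one extra slot for the class token
# (that extra token is how ViT gets a single "summary" vector for classification).
position_embedding = torch.randn(1, num_patches + 1, patch_dim)
print(num_patches, patch_dim, position_embedding.shape)  # 196 768 torch.Size([1, 197, 768])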

2. Linear Projection: Turning Patches into Embeddings

Next, we convert these patches into embeddings. What's an embedding, you ask? It's just a vector of numbers representing each patch, one that the model can tweak as it learns. Imagine the patches going through a cool filter that makes them smarter!

Linear Projection
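
Here's a minimal sketch (not the final class yet, just the trick) of how a single strided nn.Conv2d can do "cut into patches + linear projection" in one shot. This is exactly the trick the PatchEmbedding class below uses:

import torch
import torch.nn as nn

# A Conv2d with kernel_size == stride == patch_size slices the image into
# non-overlapping patches AND linearly projects each one to 768 dimensions.
patcher = nn.Conv2d(in_channels=3, out_channels=768, kernel_size=16, stride=16)

image = torch.randn(1, 3, 224, 224)           # (batch, channels, height, width)
patches = patcher(image)                      # -> (1, 768, 14, 14)
tokens = patches.flatten(2).permute(0, 2, 1)  # -> (1, 196, 768): one embedding per patch
print(tokens.shape)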

3. Norm: Keeping Things Regular

Layer Normalization, or "LayerNorm", is like that friend who keeps you grounded. It normalizes the values inside each token embedding so no single feature gets too carried away, which keeps the numbers in a sensible range and makes training the deep stack of layers much more stable.
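
Here's a tiny sketch of what LayerNorm actually does to our token embeddings (the exaggerated scale and shift are just for illustration):

import torch
import torch.nn as nn

layer_norm = nn.LayerNorm(normalized_shape=768)

tokens = torch.randn(1, 197, 768) * 5 + 3     # deliberately badly-scaled "embeddings"
normed = layer_norm(tokens)

# Each token is normalized across its 768 features: roughly zero mean, unit variance,
# then scaled/shifted by LayerNorm's learnable parameters.
print(normed.mean(dim=-1)[0, :3], normed.std(dim=-1)[0, :3])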

4. Multi-Head Attention: The Gossip Layer

Now we get to the juicy part—Multi-Head Self-Attention (MSA). This layer is like a bunch of friends in a group chat. They're all talking to each other, sharing gossip (information), and making sure everyone's in the loop. Each head in this layer looks at the image from a different angle, ensuring that no detail is missed.
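
We'll let PyTorch's built-in Transformer Encoder handle attention for us later, but here's a minimal sketch of multi-head self-attention on its own, using nn.MultiheadAttention with the same sizes we'll use below (768-dim embeddings, 12 heads):

import torch
import torch.nn as nn

attention = nn.MultiheadAttention(embed_dim=768, num_heads=12, batch_first=True)

tokens = torch.randn(1, 197, 768)  # 196 patch tokens + 1 class token

# Self-attention: queries, keys and values all come from the same sequence,
# so every patch gets to "look at" every other patch.
attended, attn_weights = attention(query=tokens, key=tokens, value=tokens)
print(attended.shape, attn_weights.shape)  # torch.Size([1, 197, 768]) torch.Size([1, 197, 197])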

5. MLP: The Brain Power

After all that chatting, we need to do some serious thinking. Enter the Multilayer Perceptron (MLP), which is just a fancy way of saying "a bunch of neurons working together." This layer processes the information, learns patterns, and makes decisions. It's the brain of our Transformer.
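
In the ViT paper the MLP block is just two linear layers with a GELU non-linearity and some dropout. A minimal sketch with the ViT-Base sizes we'll use below (768 in, 3072 hidden):

import torch
import torch.nn as nn

# Transformer MLP block: expand to 3072, non-linearity, project back down to 768.
mlp = nn.Sequential(
    nn.Linear(in_features=768, out_features=3072),
    nn.GELU(),
    nn.Dropout(p=0.1),
    nn.Linear(in_features=3072, out_features=768),
    nn.Dropout(p=0.1),
)

tokens = torch.randn(1, 197, 768)
print(mlp(tokens).shape)  # torch.Size([1, 197, 768]) -- same shape in, same shape out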

6. Transformer Encoder: Stack 'Em Up

The Transformer Encoder is where things get stacked—literally. It's a combination of all the layers above, repeated multiple times. Think of it like stacking multiple burgers on top of each other to make a giant, delicious burger tower. Each layer adds more flavor (or in our case, more learning).

Transformer Encoder
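
We won't hand-roll the encoder in this post, because PyTorch ships one. Here's a tiny preview of how a single nn.TransformerEncoderLayer gets stacked 12 times, which is exactly the pattern the full model below uses:

import torch
import torch.nn as nn

encoder_layer = nn.TransformerEncoderLayer(d_model=768,
                                           nhead=12,
                                           dim_feedforward=3072,
                                           activation="gelu",
                                           batch_first=True,
                                           norm_first=True)  # ViT normalizes *before* attention/MLP

# Stack the same layer architecture 12 times (ViT-Base depth).
encoder = nn.TransformerEncoder(encoder_layer, num_layers=12)

tokens = torch.randn(1, 197, 768)
print(encoder(tokens).shape)  # torch.Size([1, 197, 768])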

7. MLP Head: The Final Answer

Finally, after all that hard work, we need to get the answer. The MLP Head is the final layer that takes all the learned features and says, "Okay, based on everything I've seen, this image is a dog!" Or a cat. Or whatever's in the image.
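
The head itself is tiny. A quick sketch, assuming 1000 output classes like we'll use below:

import torch
import torch.nn as nn

# Classification head: normalize the class token, then map it to class logits.
mlp_head = nn.Sequential(
    nn.LayerNorm(normalized_shape=768),
    nn.Linear(in_features=768, out_features=1000),
)

class_token_output = torch.randn(1, 768)  # the 0th token coming out of the encoder stack
logits = mlp_head(class_token_output)
print(logits.shape)  # torch.Size([1, 1000])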

Let's Code This Bad Boy

Here's how you can implement a Vision Transformer from scratch using PyTorch:


import torch 
import torch.nn as nn

class PatchEmbedding(nn.Module):
    """Turns a 2D input image into a 1D sequence learnable embedding vector.

    Args:
        in_channels (int): Number of color channels for the input images. Defaults to 3.
        patch_size (int): Size of patches to convert input image into. Defaults to 16.
        embedding_dim (int): Size of embedding to turn image into. Defaults to 768.
    """
    # 2. Initialize the class with appropriate variables
    def __init__(self, 
                    in_channels:int=3,
                    patch_size:int=16,
                    embedding_dim:int=768):
        super().__init__()

        self.patch_size = patch_size

        # 3. Create a layer to turn an image into patches
        self.patcher = nn.Conv2d(in_channels=in_channels,
                                    out_channels=embedding_dim,
                                    kernel_size=patch_size,
                                    stride=patch_size,
                                    padding=0)

        # 4. Create a layer to flatten the patch feature maps into a single dimension
        self.flatten = nn.Flatten(start_dim=2, # only flatten the feature map dimensions into a single vector
                                    end_dim=3)

    # 5. Define the forward method 
    def forward(self, x):
        # Create assertion to check that inputs are the correct shape
        image_resolution = x.shape[-1]
        assert image_resolution % self.patch_size == 0, f"Input image size must be divisible by patch size, image shape: {image_resolution}, patch size: {self.patch_size}"

        # Perform the forward pass
        x_patched = self.patcher(x)
        x_flattened = self.flatten(x_patched) 
        # 6. Rearrange the output to (batch_size, number_of_patches, embedding_dim)
        return x_flattened.permute(0, 2, 1)  # [batch, embedding_dim, N] -> [batch, N, embedding_dim]





# Putting it all together: the full ViT model
class ViT(nn.Module):
    def __init__(self,
                 img_size=224,               # training resolution used in the ViT paper
                 in_channels=3,
                 patch_size=16,
                 embedding_dim=768,          # ViT-Base hidden size
                 dropout=0.1,
                 mlp_size=3072,              # ViT-Base MLP size
                 num_transformers_layers=12, # ViT-Base depth
                 num_heads=12,               # ViT-Base attention heads
                 num_classes=1000):
        super().__init__()

        assert img_size % patch_size == 0, "Image size should be divisible by patch size"

        #create patch embedding
        self.patch_embedding = PatchEmbedding(in_channels=in_channels,
                                                patch_size=patch_size,
                                                embedding_dim=embedding_dim)

        #create class token 
        self.class_token = nn.Parameter(torch.randn(1,1,embedding_dim),
                                        requires_grad=True)


        #create positional embedding
        num_patches = (img_size * img_size) // patch_size ** 2  # N = HW / P^2 patches
        self.positional_embedding = nn.Parameter(torch.randn(1, num_patches + 1 ,embedding_dim))


        #create patch + positional embedding dropout 
        self.embedding_dropout = nn.Dropout(p=dropout)


        # create stacked Transformer Encoder layers
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer=nn.TransformerEncoderLayer(d_model=embedding_dim,
                                                     nhead=num_heads,
                                                     dim_feedforward=mlp_size,
                                                     activation='gelu',
                                                     batch_first=True,
                                                     norm_first=True),
            num_layers=num_transformers_layers)


        # create MLP head (the classifier)
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(normalized_shape=embedding_dim),
            nn.Linear(in_features=embedding_dim,
                        out_features=num_classes)
        )

    def forward(self, x):

        #batch size
        batch_size = x.shape[0]

        #patch embedding
        x =  self.patch_embedding(x)


        # Create class token embedding and expand it to match the batch size (equation 1)
        class_token = self.class_token.expand(batch_size, -1, -1) # "-1" means to infer the dimension 

        # Concat class embedding and patch embedding (equation 1)
        x = torch.cat((class_token, x), dim=1)

        # Add positional embedding to patch embedding with class token
        x = self.positional_embedding + x

        # dropout on patch + positional embedding
        x = self.embedding_dropout(x)

        # Pass embedding through Transformer Encoder stack 
        x = self.transformer_encoder(x)

        # Pass the class token (0th index) through the MLP head to get class logits
        x = self.mlp_head(x[:,0])

        return x
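
And here's a quick sanity check that everything plugs together. The weights are random, so the prediction is meaningless until you actually train the model, but the shapes should line up:

# Quick sanity check with a random "image"
vit = ViT(img_size=224, patch_size=16, num_classes=1000)

dummy_image = torch.randn(1, 3, 224, 224)   # (batch, channels, height, width)
logits = vit(dummy_image)

print(logits.shape)                         # torch.Size([1, 1000]) -- one logit per class
print(logits.argmax(dim=-1))                # index of the (currently random) predicted class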
            

Wrapping Up

Vision Transformers are powerful tools, capable of understanding images in ways that were previously unimaginable. By splitting images into patches and treating them like words in a sentence, we can leverage the power of Transformers in a whole new way.
You can find the full implementation on GitHub.

References

  1. Daniel Bourke, learnpytorch.io.
  2. PyTorch documentation.