MNIST, CIFAR10 Classifier using ViT

Summary

I built a ViT (Vision Transformer) classifier for the MNIST and CIFAR10 datasets.

The Vision Transformer is an attempt to apply the Transformer model to computer vision tasks. Before ViT, CV tasks were mainly handled by CNNs. ViT shows that the Transformer can be applied to CV as well, achieving state-of-the-art results and surpassing CNN performance.

ViT uses only the encoder part of the Transformer, so the architecture is very simple. If you already understand the Transformer model, ViT is a piece of cake!

If you want to take a deeper look at ViT, please read the original paper, "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" [6].

Implementation

Please check the GitHub repository for the detailed implementation: https://github.com/jinho-choi123/ViT-Visual-Transformer-
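
For reference, here is a minimal sketch of how such a classifier can be put together in PyTorch. This is not the code from the repository; the module and hyperparameter names are illustrative, and the encoder here is the stock nn.TransformerEncoder rather than a hand-written one.

import torch
import torch.nn as nn

class ViTClassifier(nn.Module):
    """Minimal ViT: patch embedding + class token + Transformer encoder + linear head."""
    def __init__(self, image_size=32, patch_size=4, in_channels=3,
                 d_model=256, n_layers=6, n_heads=8, n_classes=10):
        super().__init__()
        n_patches = (image_size // patch_size) ** 2
        # Patch embedding as a strided convolution: one token per (patch_size x patch_size) patch.
        self.patch_embed = nn.Conv2d(in_channels, d_model,
                                     kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
        # Learnable positional encoding: one vector per patch plus one for the class token.
        self.pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, d_model))
        layer = nn.TransformerEncoderLayer(d_model, n_heads,
                                           dim_feedforward=4 * d_model, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)
        self.head = nn.Linear(d_model, n_classes)

    def forward(self, x):                          # x: (B, C, H, W)
        x = self.patch_embed(x)                    # (B, d_model, H/P, W/P)
        x = x.flatten(2).transpose(1, 2)           # (B, n_patches, d_model)
        cls = self.cls_token.expand(x.size(0), -1, -1)
        x = torch.cat([cls, x], dim=1) + self.pos_embed
        x = self.encoder(x)
        return self.head(x[:, 0])                  # classify from the class token

logits = ViTClassifier()(torch.randn(2, 3, 32, 32))   # shape: (2, 10)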

Learnable Positional Encoding

In NLP, we usually use static positional encoding methods such as sinusoidal encoding. The Vision Transformer instead uses a learnable positional encoding, which makes the positional encoding itself a trainable parameter.

Because of the Transformer's input format, ViT flattens the 2D grid of patches into a 1D sequence, so the spatial arrangement of the patches is no longer explicit. The learnable positional encoding is what allows the model to recover this positional information during training.
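
As a minimal sketch (the class name here is mine, not necessarily the one in the repository), the learnable positional encoding is just a trainable tensor added to the token embeddings:

import torch
import torch.nn as nn

class LearnablePositionalEmbedding(nn.Module):
    """Adds a trainable positional embedding to the (class token + patches) sequence."""
    def __init__(self, seq_len: int, d_model: int, dropout: float = 0.1):
        super().__init__()
        # One learnable vector per position, updated by backprop like any other weight.
        self.pos_embed = nn.Parameter(torch.randn(1, seq_len, d_model) * 0.02)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:   # x: (B, seq_len, d_model)
        return self.dropout(x + self.pos_embed)

# For 224x224 images with 16x16 patches: 196 patch tokens + 1 class token = 197 positions.
pos = LearnablePositionalEmbedding(seq_len=197, d_model=768)
print(pos(torch.randn(128, 197, 768)).shape)   # torch.Size([128, 197, 768])

This is also where the 151,296 parameters of the PositionalEmbedding layer in the model summary further below come from: 197 × 768 = 151,296.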

Is the model smaller than the original Transformer?

At first I thought that, since ViT only uses the encoder part, it should be smaller than the original Transformer. During training, I found out that this was totally wrong.

For the Food101 dataset, I resized the images to (224, 224) and used a patch size of (16, 16), so the sequence length of the ViT is (224 ÷ 16)² = 196. (The original paper uses a patch size of (4, 4).) For this reason, ViT requires much more VRAM for training than I expected.
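
A rough back-of-the-envelope check (using batch size 128 and d_model 768, the shapes from the run below; the activation counts are approximate) shows why the activations dominate GPU memory:

image_size, patch_size = 224, 16
n_patches = (image_size // patch_size) ** 2      # 196 patch tokens
seq_len = n_patches + 1                          # +1 class token -> 197

batch, d_model = 128, 768
floats = batch * seq_len * d_model               # one activation tensor of one encoder block
print(floats * 4 / 2**20)                        # ~73.9 MB in fp32

# Each encoder block keeps several tensors of this size for the backward pass,
# plus attention score matrices of shape (batch, heads, 197, 197), so a 12-layer
# encoder easily reaches the ~14 GB forward/backward size reported by torchinfo below.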

Make a Checklist

For this project, I was stuck for a couple of days: I was trying to train the model, but it wasn't learning efficiently. After a few days, I found out that I was building a single-layer encoder instead of an n_layers encoder. I found this bug when I visualized the model summary using the torchinfo.summary function.

INFO:transformer_log:=========================================================================================================
Layer (type:depth-idx)                                  Output Shape              Param #
=========================================================================================================
ViT                                                     [128, 101]                --
├─Encoder: 1-1                                          [128, 197, 768]           --
│    └─TransformerEmbedding: 2-1                        [128, 197, 768]           768
│    │    └─PatchEmbedding: 3-1                         [128, 196, 768]           590,592
│    │    └─PositionalEmbedding: 3-2                    [1, 197, 768]             151,296
│    │    └─Dropout: 3-3                                [128, 197, 768]           --
│    └─ModuleList: 2-2                                  --                        --
│    │    └─EncoderBlock: 3-4                           [128, 197, 768]           3,151,616
│    │    └─EncoderBlock: 3-5                           [128, 197, 768]           3,151,616
│    │    └─EncoderBlock: 3-6                           [128, 197, 768]           3,151,616
│    │    └─EncoderBlock: 3-7                           [128, 197, 768]           3,151,616
│    │    └─EncoderBlock: 3-8                           [128, 197, 768]           3,151,616
│    │    └─EncoderBlock: 3-9                           [128, 197, 768]           3,151,616
│    │    └─EncoderBlock: 3-10                          [128, 197, 768]           3,151,616
│    │    └─EncoderBlock: 3-11                          [128, 197, 768]           3,151,616
│    │    └─EncoderBlock: 3-12                          [128, 197, 768]           3,151,616
│    │    └─EncoderBlock: 3-13                          [128, 197, 768]           3,151,616
│    │    └─EncoderBlock: 3-14                          [128, 197, 768]           3,151,616
│    │    └─EncoderBlock: 3-15                          [128, 197, 768]           3,151,616
│    └─LayerNorm: 2-3                                   [128, 197, 768]           1,536
├─MLPHead: 1-2                                          [128, 101]                --
│    └─Linear: 2-4                                      [128, 1024]               787,456
│    └─GELU: 2-5                                        [128, 1024]               --
│    └─Linear: 2-6                                      [128, 101]                103,525
│    └─Dropout: 2-7                                     [128, 101]                --
=========================================================================================================
Total params: 39,454,565
Trainable params: 39,454,565
Non-trainable params: 0
Total mult-adds (G): 19.77
=========================================================================================================
Input size (MB): 77.07
Forward/backward pass size (MB): 14564.72
Params size (MB): 157.82
Estimated Total Size (MB): 14799.61
=========================================================================================================
INFO:transformer_log:model parameter #: 39454565
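
The bug was roughly of the following form (a hypothetical reconstruction, not the exact code from the repository): the encoder reused a single EncoderBlock instead of stacking n_layers independent blocks, and printing the summary with torchinfo makes the difference obvious.

import torch.nn as nn
from torchinfo import summary

d_model, n_heads, n_layers = 768, 12, 12

class EncoderBlock(nn.Module):
    """Pre-norm Transformer encoder block: self-attention + MLP, both with residuals."""
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(nn.Linear(d_model, 4 * d_model), nn.GELU(),
                                 nn.Linear(4 * d_model, d_model))

    def forward(self, x):
        h = self.norm1(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]
        return x + self.mlp(self.norm2(x))

# Buggy: only ONE block is ever created, no matter what n_layers says.
buggy_encoder = EncoderBlock(d_model, n_heads)

# Fixed: n_layers independent blocks stacked in a ModuleList.
fixed_encoder = nn.ModuleList([EncoderBlock(d_model, n_heads) for _ in range(n_layers)])

# torchinfo.summary shows the repeated EncoderBlock rows only for the fixed version.
summary(nn.Sequential(*fixed_encoder), input_size=(128, 197, d_model), depth=2)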

There were some other silly mistakes I made during this project, so I decided to make a checklist for myself to follow when implementing an ML model.

Checklist learned from the ViT project

  1. Always print out the model summary (using torchinfo.summary) after implementing the model.

  2. Call scheduler.step() once per epoch, not per step (see the training-loop sketch after this list).

  3. Log every configuration parameter to the log file.

  4. Keep the data close to the machine. In Colab, do not load data directly from Google Drive; download it into the runtime (the "/content" directory).

  5. Print the test accuracy for each epoch.

  6. Start with small parameters. Don't scale the model up before you know your implementation is correct; a large model takes too much time to train.
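
Below is a minimal training-loop skeleton that follows items 2, 3, and 5 of the checklist. All names here (the logger name, the config keys, the StepLR schedule) are placeholders, not the actual code from the repository.

import logging
import torch

logger = logging.getLogger("transformer_log")

def train(model, train_loader, test_loader, config, device="cuda"):
    logger.info("config: %s", config)   # item 3: log every configuration parameter

    optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"])
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
    criterion = torch.nn.CrossEntropyLoss()
    model.to(device)

    for epoch in range(config["epochs"]):
        model.train()
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()
        scheduler.step()   # item 2: step the scheduler once per epoch, not per batch

        # item 5: report test accuracy every epoch
        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(device), labels.to(device)
                correct += (model(images).argmax(dim=1) == labels).sum().item()
                total += labels.size(0)
        logger.info("epoch %d: test accuracy %.2f%%", epoch, 100.0 * correct / total)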

References

[1] https://medium.com/@brianpulfer/vision-transformers-from-scratch-pytorch-a-step-by-step-guide-96c3313c2e0c
[2] https://kimbg.tistory.com/31
[3] https://csm-kr.tistory.com/54
[4] https://github.com/lucidrains/vit-pytorch
[5] https://github.com/kamrulhasanrony/Vision-Transformer-based-Food-Classification/tree/master
[6] An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale (arXiv.org)