MNIST, CIFAR10 Classifier using ViT
I made a ViT (Vision Transformer) classifier for the MNIST and CIFAR10 datasets.
The Vision Transformer is an attempt to apply the Transformer model to computer vision tasks. Before ViT, CV tasks were handled mainly by CNNs; ViT showed that the Transformer can be applied to CV as well, achieving SOTA results that surpass CNN performance.
ViT uses only the encoder part of the Transformer, so the architecture is very simple. I'm pretty sure that if you understood the Transformer model, ViT will be a piece of cake!
If you want to take a deeper look at ViT, please read the original paper, "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale".
Please check the GitHub code to see the detailed implementation.
In NLP, we typically use static positional encoding methods such as sinusoidal encoding. The Vision Transformer, by contrast, uses learnable positional encoding, which sets the positional encoding as a learnable parameter.
Because of the Transformer's structure, ViT flattens the 2D patches into a 1D sequence, which discards the spatial layout of the image. Recovering that information has to be dealt with somewhere in ViT, and the learnable positional encoding takes care of it.
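As a rough sketch of how the patch embedding and the learnable positional embedding look in PyTorch (a minimal illustration with made-up default hyperparameters, not the exact code from my repo):

```python
import torch
import torch.nn as nn

class ViTEmbedding(nn.Module):
    """Patchify an image and add a learnable positional embedding.
    Hyperparameters are illustrative defaults, not values from my repo."""

    def __init__(self, img_size=224, patch_size=16, in_ch=3, dim=768):
        super().__init__()
        num_patches = (img_size // patch_size) ** 2
        # Patch splitting + linear projection in a single strided conv
        self.proj = nn.Conv2d(in_ch, dim, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        # Learnable positional embedding (instead of a fixed sinusoidal table)
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, dim))

    def forward(self, x):                 # x: (B, C, H, W)
        x = self.proj(x)                  # (B, dim, H/p, W/p)
        x = x.flatten(2).transpose(1, 2)  # 2D patch grid -> 1D sequence: (B, N, dim)
        cls = self.cls_token.expand(x.size(0), -1, -1)
        x = torch.cat([cls, x], dim=1)    # prepend the class token
        return x + self.pos_embed         # position information is learned end to end
```

The output sequence is then fed straight into a stack of standard Transformer encoder layers.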
Well, at first I thought that since ViT uses only the encoder part, it should be smaller than the original Transformer. But in the training phase I found out that this was totally wrong.
For the Food101 dataset, we preprocessed the images to (224, 224) and used a patch size of (16, 16), so the sequence length of the ViT is (224 / 16)^2 = 196 patches (197 tokens including the class token). With a smaller patch size such as (4, 4), the sequence would grow to (224 / 4)^2 = 3136 tokens, and self-attention memory scales quadratically with the sequence length. For this reason, ViT requires much more VRAM for training than I expected.
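To see how quickly the memory blows up as the patch size shrinks, here is a quick back-of-the-envelope calculation (my own illustration):

```python
def seq_len(img_size: int, patch_size: int) -> int:
    """Number of patch tokens, plus one class token."""
    return (img_size // patch_size) ** 2 + 1

for p in (16, 4):
    n = seq_len(224, p)
    # Each attention head materializes an (n x n) score matrix,
    # so memory grows quadratically with the sequence length.
    print(f"patch {p}x{p}: seq_len={n}, attention matrix {n}x{n} = {n * n:,} entries")

# patch 16x16: seq_len=197, attention matrix 197x197 = 38,809 entries
# patch 4x4:   seq_len=3137, attention matrix 3137x3137 = 9,840,769 entries
```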
For this project, I was stuck for a couple of days: I kept training the model, but it wasn't learning efficiently. After a few days I found out that I had built a single-layer encoder instead of an n_layers encoder. I caught this bug when I visualized the model summary using the torchinfo.summary function.
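The bug looked roughly like the pattern below (reconstructed for illustration; not the exact code from the repo). torchinfo.summary exposes it immediately, because a 1-layer encoder has only a fraction of the parameters a 12-layer one should have:

```python
import torch.nn as nn
from torchinfo import summary

n_layers, dim, seq_len = 12, 768, 197

# Buggy pattern: reusing one layer object (or forgetting to stack at all)
# yields a model with only a single layer's worth of weights:
# layers = [encoder_layer] * n_layers

# Fixed: construct a fresh encoder layer for every depth step.
layers = nn.Sequential(*[
    nn.TransformerEncoderLayer(d_model=dim, nhead=12, batch_first=True)
    for _ in range(n_layers)
])

# The parameter count in the summary makes any missing layers obvious.
summary(layers, input_size=(1, seq_len, dim))
```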
There were some other dumb mistakes I made during this project, so I decided to make a checklist for myself to use when implementing an ML model.
Checklist learned from the ViT project
Always print out the model summary (using torchinfo.summary) after implementing the model, as shown above.
Call scheduler.step() once per epoch, not once per step (see the training-loop sketch after this list).
Log every configuration parameter in the log file.
Put the data close to the machine. In Colab, do not read data directly from Google Drive; download it into the runtime (the "/content" directory) first.
Print the test accuracy (percentage of correct predictions) for each epoch.
Start with small parameters. Don't make the model large before you know your implementation is correct; a large model takes too much time to train.
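Several of these checklist items fit into a minimal training-loop skeleton. The sketch below assumes a standard PyTorch setup; all names (train, criterion, and so on) are illustrative placeholders rather than the repo's actual API:

```python
import logging
import torch

def train(model, train_loader, test_loader, optimizer, scheduler,
          criterion, n_epochs, device="cuda"):
    # Checklist: log every configuration parameter up front.
    logging.info("config: n_epochs=%d optimizer=%s scheduler=%s",
                 n_epochs, type(optimizer).__name__, type(scheduler).__name__)
    model.to(device)
    for epoch in range(n_epochs):
        model.train()
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            optimizer.step()
        scheduler.step()  # checklist: once per epoch, not per batch

        # Checklist: report test accuracy at every epoch.
        model.eval()
        correct = total = 0
        with torch.no_grad():
            for x, y in test_loader:
                x, y = x.to(device), y.to(device)
                correct += (model(x).argmax(dim=1) == y).sum().item()
                total += y.size(0)
        logging.info("epoch %d: test accuracy %.2f%%",
                     epoch, 100.0 * correct / total)
```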