Adam Optimizer

Summary

Most deep learning training uses the Adam optimizer. In this post, I would like to discuss what an optimizer is and why we use Adam.

If you want more detailed information about the Adam optimizer, please look at the original paper, "Adam: A Method for Stochastic Optimization" [4].

What is an Optimizer?

In deep learning, we optimize the model to lower the loss; this procedure is called training.

To optimize the model (i.e., minimize the loss), we use the gradient descent method: we compute the gradient of the loss with respect to each parameter and subtract a scaled version of it from the parameter value.

$$\theta_{t+1} = \theta_t - \lambda \frac{\partial L(X, \theta_t, \alpha_t, \beta_t, \dots)}{\partial \theta_t}$$
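As a minimal sketch of this update rule (plain NumPy; the quadratic loss and its gradient `grad_fn` are toy stand-ins for illustration, not from the post), one gradient-descent step looks like this:

```python
import numpy as np

def sgd_step(theta, grad_fn, lr=0.1):
    """One plain gradient-descent step: theta <- theta - lr * dL/dtheta."""
    g = grad_fn(theta)           # gradient of the loss w.r.t. the parameters
    return theta - lr * g        # step against the gradient

# Toy example: L(theta) = ||theta||^2, so the gradient is 2 * theta.
theta = np.array([3.0, -2.0])
for _ in range(100):
    theta = sgd_step(theta, grad_fn=lambda t: 2 * t)
print(theta)                     # approaches [0, 0]
```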

Why do we need an algorithm for the Optimizer?

The following GIF [2] compares various optimization algorithms in a local-minimum situation.

SGD cannot escape the local minimum. The momentum algorithm escapes only at the very end, while the other algorithms escape easily.

As you can see, a good optimization algorithm can improve both training speed and final model performance.

We are going to focus on the Adam algorithm, but first things first: let's look at the momentum and RMSprop algorithms.

Momentum Algorithm

The momentum algorithm keeps an exponential moving average of past gradients (the moment) and uses the previous moment value to compute the current one.

$$m_t = \alpha \cdot m_{t-1} + (1-\alpha) \cdot g_t$$

$$\theta_t = \theta_{t-1} - \lambda m_t$$

- $m_t$: moment value at timestep $t$
- $\alpha$: moment weight constant
- $g_t$: gradient computed at timestep $t$
- $\theta_t$: model parameters at timestep $t$
- $\lambda$: learning rate constant
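Here is a minimal NumPy sketch of the momentum update above, keeping the post's notation (`alpha` for the moment weight, `lam` for the learning rate); the toy gradient function is only an illustration:

```python
import numpy as np

def momentum_step(theta, m, grad_fn, alpha=0.9, lam=0.01):
    """m_t = alpha * m_{t-1} + (1 - alpha) * g_t ;  theta_t = theta_{t-1} - lam * m_t."""
    g = grad_fn(theta)
    m = alpha * m + (1 - alpha) * g      # exponential moving average of gradients
    theta = theta - lam * m              # step along the averaged direction
    return theta, m

theta = np.array([3.0, -2.0])
m = np.zeros_like(theta)                 # moment starts at zero
for _ in range(500):
    theta, m = momentum_step(theta, m, grad_fn=lambda t: 2 * t)
```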

RMSprop Algorithm

The key idea of RMSprop is the following:

If a parameter's gradient is large, that parameter reaches convergence early. On the other hand, if the gradient is small, its convergence is delayed.

So RMSprop scales the learning rate by the size of the gradient: parameters with large gradients get a smaller effective learning rate, and parameters with small gradients get a larger one.

To measure the size of the gradient, it simply keeps a running average of the squared gradient.

$$r_t = \beta r_{t-1} + (1-\beta)(\nabla_{\theta_t} L)^2$$

$$\theta_t = \theta_{t-1} - \frac{\alpha}{\sqrt{r_t} + \epsilon} \odot \nabla_{\theta_t} L$$
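The same kind of sketch for RMSprop, again with a toy gradient: `r` accumulates a running average of squared gradients, and the step size `alpha` is divided elementwise by `sqrt(r) + eps`:

```python
import numpy as np

def rmsprop_step(theta, r, grad_fn, alpha=0.001, beta=0.9, eps=1e-8):
    """r_t = beta * r_{t-1} + (1 - beta) * g_t**2 ;
    theta_t = theta_{t-1} - alpha / (sqrt(r_t) + eps) * g_t  (elementwise)."""
    g = grad_fn(theta)
    r = beta * r + (1 - beta) * g**2     # running average of squared gradients
    theta = theta - alpha / (np.sqrt(r) + eps) * g
    return theta, r

theta = np.array([3.0, -2.0])
r = np.zeros_like(theta)
for _ in range(1000):
    theta, r = rmsprop_step(theta, r, grad_fn=lambda t: 2 * t)
```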

Adam

The Adam algorithm is a mixture of momentum and RMSprop: it combines a momentum-style first moment of the gradient with RMSprop-style scaling by the squared gradient.
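Below is a minimal NumPy sketch of one Adam step, following the update rule from the paper [4]; the default hyperparameters (beta1 = 0.9, beta2 = 0.999, eps = 1e-8) are the paper's defaults, and the toy training loop is only an illustration:

```python
import numpy as np

def adam_step(theta, m, v, g, t, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update: momentum-style first moment + RMSprop-style second moment,
    both bias-corrected before the parameter update (t starts at 1)."""
    m = beta1 * m + (1 - beta1) * g          # first moment (momentum part)
    v = beta2 * v + (1 - beta2) * g**2       # second moment (RMSprop part)
    m_hat = m / (1 - beta1**t)               # bias correction
    v_hat = v / (1 - beta2**t)
    theta = theta - lr * m_hat / (np.sqrt(v_hat) + eps)
    return theta, m, v

theta = np.array([3.0, -2.0])
m, v = np.zeros_like(theta), np.zeros_like(theta)
for t in range(1, 1001):
    g = 2 * theta                            # toy gradient of ||theta||^2
    theta, m, v = adam_step(theta, m, v, g, t)
```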

As shown in the sketch above, there is a bias-correction step in the algorithm. Let's take a look at the reason.

Bias correction in Adam

Let's think of $m_t$, $v_t$, and $g_t$ as random variables drawn from probability distributions. For example, $g_1, g_2, \dots, g_t$ are sampled from a distribution $G$.

$$\text{for } i = 1, \dots, t:\quad m_i \sim M, \quad v_i \sim V, \quad g_i \sim G$$

As the timestep $t$ increases, the expected values of $m_t$ and $v_t$ converge:

$$E[m_t] \rightarrow E[g_t] \quad \text{and} \quad E[v_t] \rightarrow E[g_t^2]$$

However, there is a problem: when the timestep $t$ is small, $E[m_t]$ and $E[v_t]$ are heavily biased toward zero, because both are initialized to 0. Assuming the gradient distribution is roughly stationary ($E[g_k] \approx E[g_t]$), unrolling the recursion and summing the geometric series gives:

$$m_t = (1-\beta_1) \sum_{k=1}^t \beta_1^{t-k} g_k, \qquad E[m_t] = E[g_t] \cdot (1-\beta_1^t)$$

$$v_t = (1-\beta_2) \sum_{k=1}^t \beta_2^{t-k} g_k^2, \qquad E[v_t] = E[g_t^2] \cdot (1-\beta_2^t)$$

So Adam divides $m_t$ and $v_t$ by $1-\beta_1^t$ and $1-\beta_2^t$, respectively, to correct the bias.
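Written out, the bias-corrected estimates actually used in the parameter update are:

$$\hat{m}_t = \frac{m_t}{1-\beta_1^t}, \qquad \hat{v}_t = \frac{v_t}{1-\beta_2^t}$$

so that $E[\hat{m}_t] \approx E[g_t]$ and $E[\hat{v}_t] \approx E[g_t^2]$ even for small $t$.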

References

[1] RMSprop, Deepchecks ML Glossary. https://www.deepchecks.com/glossary/rmsprop/#:~:text=RMSprop%20(Root%20Mean%20Squared%20Propagation,and%20other%20Machine%20Learning%20techniques.

[2] Optimizer comparison animation (Long Valley). https://optimization.cbe.cornell.edu/index.php?title=File:1_-_2dKCQHh_-_Long_Valley.gif

[3] DL Optimization Algorithms: RMSProp, Adam (velog). https://velog.io/@cha-suyeon/DL-%EC%B5%9C%EC%A0%81%ED%99%94-%EC%95%8C%EA%B3%A0%EB%A6%AC%EC%A6%98-RMSProp-Adam#adam

[4] D. P. Kingma and J. Ba, "Adam: A Method for Stochastic Optimization." https://arxiv.org/abs/1412.6980