Variational Auto Encoder

What is an Encoder, Decoder?

Before machine learning algorithms process data, all non-matrix data is converted into a low-dimensional vector (called a latent vector). For example, text is converted into an embedding space. This is a crucial step: if we just used one-hot vectors, the vector dimension would be very large and the computational cost would explode.

An encoder is an algorithm that converts real data (text, images) into a low-dimensional representation. A decoder is an algorithm that converts the low-dimensional representation back into real data. In deep learning, we often use a neural network such as a CNN for the encoder/decoder.
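
For intuition, here is a minimal encoder/decoder sketch in PyTorch. The layer sizes, the 784-dimensional input (a flattened 28×28 image), and the 16-dimensional latent vector are illustrative assumptions, not taken from the post's repo.

```python
import torch
import torch.nn as nn

# Minimal MLP autoencoder: names and sizes are illustrative.
class Encoder(nn.Module):
    def __init__(self, input_dim=784, latent_dim=16):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, 256), nn.ReLU(),
            nn.Linear(256, latent_dim),
        )

    def forward(self, x):
        return self.net(x)  # latent vector z

class Decoder(nn.Module):
    def __init__(self, latent_dim=16, output_dim=784):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim, 256), nn.ReLU(),
            nn.Linear(256, output_dim), nn.Sigmoid(),
        )

    def forward(self, z):
        return self.net(z)  # reconstruction of x

x = torch.rand(8, 784)   # a batch of flattened 28x28 images
z = Encoder()(x)         # encode to latent vectors
x_hat = Decoder()(z)     # decode back to image space
```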

Limitation of previous Auto Encoder

An autoencoder outputs a latent vector, called $z$, for an input $x$. The decoder is trained to convert $z$ back into $x$.

However, an autoencoder outputs the same latent vector whenever the input is the same, which means the decoder cannot produce varied generative outputs for a single input. Also, the decoder is not trained to produce valid outputs when the latent vector changes slightly.

This limitation motivates the VAE.

Variational Auto Encoder

A Variational Auto Encoder (VAE) produces a distribution over the latent vector for each input.

The term "distribution of the latent vector" is hard to understand at first. Let's say the latent vector is 3-dimensional, $z = (p, q, r)$. A distribution over the latent vector means that each of $p, q, r$ is a distribution rather than a deterministic value, e.g. $p \sim N(0, 1)$, $q \sim N(1, 2)$, $r \sim N(-1, 1)$.

However, the neural networks (CNN, RNN) used in the encoder cannot directly produce a distribution for a given input. Instead, we structure the encoder to produce the mean $\mu$ and standard deviation $\sigma$, and we set the latent vector to follow the Gaussian distribution $z \sim N(\mu, \sigma)$.
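
Concretely, the encoder can end in two heads, one for $\mu$ and one for the log-variance $\log \sigma^2$ (predicting the log-variance keeps $\sigma$ positive). A hedged sketch, reusing the illustrative sizes from above:

```python
import torch
import torch.nn as nn

# Sketch of a VAE encoder with two output heads: mu and log-variance.
class VAEEncoder(nn.Module):
    def __init__(self, input_dim=784, latent_dim=16):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(input_dim, 256), nn.ReLU())
        self.mu_head = nn.Linear(256, latent_dim)       # mean of q(Z|X)
        self.logvar_head = nn.Linear(256, latent_dim)   # log-variance of q(Z|X)

    def forward(self, x):
        h = self.backbone(x)
        return self.mu_head(h), self.logvar_head(h)

mu, logvar = VAEEncoder()(torch.rand(8, 784))
sigma = torch.exp(0.5 * logvar)   # standard deviation recovered from log-variance
```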

How Does a VAE Work?

Let's think about a specific use case of a VAE.

Defining a problem

We want to build a VAE as follows: input $X$: a picture taken by a camera; output $Z$: the angle of the camera, the focus of the lens, and the type of subject in the picture.

The important fact is that we can observe $X$, but not $Z$. We want to know the distribution of $Z$ given $X$ (i.e. $p(Z|X)$).

Applying Bayes' theorem

By Bayes' theorem, we can write the following equation.

$$p(Z|X) = \frac{p(X|Z) \cdot p(Z)}{p(X)}$$

Can we just calculate it? No, we cannot.

Tractable terms and Intractable term

The term $p(Z)$ is the distribution of the latent vector, also called the prior distribution of the latent vector. It is tractable because we usually model $Z$ with a tractable distribution such as a Gaussian.

The term $p(X|Z)$ is the distribution of the decoder output given the latent vector $Z$. It is tractable because it is a simple forward pass of the decoder neural network.

The term $p(X)$, however, is intractable. As the decoder neural network gets more complicated (which makes $p(X|Z)$ complex), it becomes impossible to integrate over the entire latent space.

$$p(X) = \int p(X|Z) \cdot p(Z) \, dZ$$
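
To see why this is hard in practice, here is a hedged sketch of the naive Monte-Carlo estimate $p(X) \approx \frac{1}{N}\sum_i p(X|z_i)$ with $z_i \sim p(Z)$. The `decoder_likelihood` function is hypothetical; the point is that for a high-dimensional latent space and a complex decoder, almost all prior samples explain $X$ poorly, so the estimate would need an astronomically large $N$ to be accurate.

```python
import torch

def naive_marginal_likelihood(x, decoder_likelihood, latent_dim=16, n_samples=10_000):
    """Naive Monte-Carlo estimate of p(X) = E_{Z~p(Z)}[p(X|Z)].

    `decoder_likelihood(x, z)` is a hypothetical function returning p(X=x | Z=z)
    for each row of z. In a real VAE, most z drawn from the prior N(0, I)
    assign near-zero likelihood to x, so this estimator has huge variance.
    """
    z = torch.randn(n_samples, latent_dim)   # samples from the prior p(Z) = N(0, I)
    return decoder_likelihood(x, z).mean()   # average of p(X|Z) over prior samples
```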

Bayes' theorem alone is not enough.

A new way to calculate $p(Z|X)$?

The method used in VAE is to approximate $p(Z|X) \approx q(Z|X)$, where $q$ is a well-known distribution such as a Gaussian.

We are going to use the KL-divergence to find the $q(Z|X)$ that best matches $p(Z|X)$.

Please look at the KL-Divergence post if you are not familiar with KL-divergence.

Approximating the $q(Z|X)$ that minimizes the KL-Divergence

Finding the best-fitting $q(Z|X)$ is formulated as follows:

$$q^*(Z|X) = \underset{q(Z|X)}{\mathrm{argmin}} \; KL(q(Z|X), p(Z|X))$$

Let's unwrap the KL-divergence to take a deeper look.

$$
\begin{aligned}
KL(q(Z|X), p(Z|X)) &= \int_{\text{latent space}} q(Z|X) \cdot \log\left(\frac{q(Z|X)}{p(Z|X)}\right) dZ \\
&= \int_{\text{latent space}} q(Z|X) \cdot \log\left(\frac{q(Z|X) \cdot p(X)}{p(X, Z)}\right) dZ \\
&= \int_{\text{latent space}} q(Z|X) \cdot \log\left(\frac{q(Z|X)}{p(X, Z)}\right) dZ + \int_{\text{latent space}} q(Z|X) \cdot \log(p(X)) \, dZ \\
&= E_{Z \sim q(Z|X)}\left[\log\left(\frac{q(Z|X)}{p(X,Z)}\right)\right] + E_{Z \sim q(Z|X)}[\log(p(X))] \quad \ldots (1) \\
&= -L(q) + \log(p(X)) \quad \ldots (2) \\
&= -L(q) + \text{evidence}
\end{aligned}
$$

The left term in (1) can be expressed as $-L(q)$ because $p(X, Z)$ is known: $p(X|Z)$ and $p(Z)$ are tractable, and $p(X, Z) = p(X|Z) \cdot p(Z)$.

The right term in (1) does not depend on $Z$, so the expectation reduces to $\log(p(X))$, the right term in (2). This is also referred to as the evidence.

Analyzing deeper

Since a probability lies between 0 and 1, the evidence $\log(p(X))$ is always less than or equal to 0.

Also, the KL-divergence is always greater than or equal to 0.

This results in $L(q) \le \text{evidence}$, which means that $L(q)$ is the evidence lower bound, a.k.a. the ELBO.

$$KLD = -ELBO + \text{evidence}$$
$$ELBO = L(q) = E_{Z \sim q(Z|X)}\left[\log\left(\frac{p(X,Z)}{q(Z|X)}\right)\right]$$
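
For intuition on how this connects to the training loss, the ELBO can be split further (a standard decomposition, not spelled out in the derivation above): the first term is a reconstruction term, and the second is a KL regularizer pulling $q(Z|X)$ toward the prior $p(Z)$.

$$
\begin{aligned}
L(q) &= E_{Z \sim q(Z|X)}\left[\log\frac{p(X,Z)}{q(Z|X)}\right]
      = E_{Z \sim q(Z|X)}\left[\log\frac{p(X|Z) \cdot p(Z)}{q(Z|X)}\right] \\
     &= \underbrace{E_{Z \sim q(Z|X)}\big[\log p(X|Z)\big]}_{\text{reconstruction}}
      - \underbrace{KL\big(q(Z|X),\, p(Z)\big)}_{\text{regularization toward the prior}}
\end{aligned}
$$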

Applying it to Variational Inference

$$q^*(Z|X) = \underset{q(Z|X)}{\mathrm{argmin}} \; KL(q(Z|X), p(Z|X)) = \underset{q(Z|X)}{\mathrm{argmax}} \; L(q)$$

If we find the optimal $q^*(Z|X)$, this means we can encode the input $X$ into the latent distribution $q^*(Z|X)$, which is an approximation of $p(Z|X)$.

The encoder learns to output the mean and standard deviation for the input image. From the encoder outputs $\mu$ and $\sigma$, the latent space is modeled as a Gaussian distribution with parameters $\mu$ and $\sigma$.
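
In code, maximizing the ELBO is usually done by minimizing its negative: a reconstruction loss plus the closed-form KL between $N(\mu, \sigma)$ and the standard-normal prior. A hedged sketch (the reconstruction loss choice and weighting are illustrative; the linked repo may differ):

```python
import torch
import torch.nn.functional as F

def vae_loss(x, x_hat, mu, logvar):
    """Negative ELBO: reconstruction term + KL(q(Z|X) || N(0, I)).

    Uses binary cross-entropy for the reconstruction term, a common choice
    for MNIST-style pixels in [0, 1]; other likelihoods are possible.
    """
    recon = F.binary_cross_entropy(x_hat, x, reduction="sum")
    # Closed-form KL divergence between N(mu, sigma^2) and N(0, 1), summed over dims.
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon + kl
```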

Critical problem in VAE

There is a critical problem in VAE: to get a latent vector, we sample from the distribution $q^*(Z|X)$. However, sampling from a distribution is not differentiable, because it is just picking a random sample from the distribution.

Reparameterization Trick

Reparameterization is a method that expresses the randomness using one extra parameter.

Using reparameterization, we can move the random sampling step to an auxiliary noise node $\epsilon$ that the gradient does not need to flow through: $z = \mu + \sigma \cdot \epsilon$ with $\epsilon \sim N(0, 1)$.
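
A minimal sketch of the trick in PyTorch; gradients flow to `mu` and `logvar`, while the randomness lives entirely in `eps`:

```python
import torch

def reparameterize(mu, logvar):
    """Sample z ~ N(mu, sigma^2) in a differentiable way: z = mu + sigma * eps."""
    sigma = torch.exp(0.5 * logvar)   # standard deviation from log-variance
    eps = torch.randn_like(sigma)     # eps ~ N(0, I), no gradient needed through it
    return mu + sigma * eps           # gradients flow through mu and sigma
```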

Implementation details

You can train and run a VAE model that generates various MNIST pictures at https://github.com/jinho-choi123/VAE-pytorch/tree/main.
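
Putting the pieces above together, here is a hedged sketch of one training step. It reuses the illustrative `VAEEncoder`, `Decoder`, `reparameterize`, and `vae_loss` sketches from earlier sections, not the repo's actual code.

```python
import torch

# Assumes VAEEncoder, Decoder, reparameterize, and vae_loss from the sketches above.
encoder, decoder = VAEEncoder(), Decoder()
optimizer = torch.optim.Adam(
    list(encoder.parameters()) + list(decoder.parameters()), lr=1e-3
)

x = torch.rand(8, 784)                  # stand-in for a batch of flattened MNIST images
mu, logvar = encoder(x)                 # parameters of q(Z|X)
z = reparameterize(mu, logvar)          # differentiable sample from q(Z|X)
x_hat = decoder(z)                      # reconstruction
loss = vae_loss(x, x_hat, mu, logvar)   # negative ELBO

optimizer.zero_grad()
loss.backward()
optimizer.step()
```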