Vector Quantized VAE (VQ-VAE)


Purpose

Most previous VAEs learn a continuous latent space.

This paper presents a powerful generative model that instead learns a discrete latent space.

Core

  1. The encoder takes an image and outputs a grid of vectors $z_e(x)$.

  2. Each vector is compared against the codebook (a.k.a. the embedding space) and replaced by its nearest embedding vector (see the sketch after this list). The quantized result $z_q(x)$ is the latent representation, and the assignment defines the posterior $q(z|x)$.

  3. Finally, the decoder takes the quantized latent vectors and reconstructs the image.
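
As a rough sketch of step 2, the nearest-neighbour lookup might look like the following PyTorch snippet. The shapes and the `quantize` helper are illustrative, not taken from the paper or the linked repo:

```python
import torch

def quantize(z_e: torch.Tensor, codebook: torch.Tensor):
    """Replace each encoder output vector with its nearest codebook entry.

    z_e:      (N, D) flattened encoder outputs z_e(x)
    codebook: (K, D) embedding vectors e_1, ..., e_K
    """
    # Pairwise L2 distances between encoder vectors and embeddings: (N, K)
    dist = torch.cdist(z_e, codebook)
    indices = dist.argmin(dim=1)  # index of the nearest embedding per vector
    z_q = codebook[indices]       # (N, D) quantized latents z_q(x)
    return z_q, indices
```

In a real implementation the codebook usually lives in an `nn.Embedding` so that the loss described below can update it.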

Training VQ-VAE

There are three trainable parts: the encoder parameters, the codebook, and the decoder parameters.

The objective of training is to minimize the following loss function, where $sg[\cdot]$ stands for the stop-gradient operator, which has zero partial derivatives:

$$L = -\log p(x|z_q(x)) + \|sg[z_e(x)] - e\|_2^2 + \beta \|z_e(x) - sg[e]\|_2^2$$

The first term is the reconstruction loss, which optimizes the encoder and the decoder.

The second term trains the codebook, pulling each embedding vector toward the encoder outputs assigned to it.

The last term, the commitment loss, trains the encoder output to stay close to its chosen embedding.

The second and last terms may look redundant, since both pull the codebook and the encoder outputs toward each other. However, the encoder's parameters can change much faster than the codebook can follow, so the embeddings would never settle.

We add the last term to make sure the encoder and the codebook move together.
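
As a minimal sketch in PyTorch, assuming an MSE reconstruction term (a Gaussian decoder) and illustrative variable names, the loss can be written with `.detach()` playing the role of $sg[\cdot]$:

```python
import torch
import torch.nn.functional as F

def vq_vae_loss(x, x_recon, z_e, z_q, beta=0.25):
    """VQ-VAE loss; .detach() implements the stop-gradient operator sg[.]."""
    recon_loss = F.mse_loss(x_recon, x)            # -log p(x|z_q(x)) for a Gaussian decoder
    codebook_loss = F.mse_loss(z_q, z_e.detach())  # ||sg[z_e(x)] - e||^2, updates the codebook
    commit_loss = F.mse_loss(z_e, z_q.detach())    # ||z_e(x) - sg[e]||^2, updates the encoder
    return recon_loss + codebook_loss + beta * commit_loss
```

The two `.detach()` calls decide which module each term trains: the codebook term sees no gradient through the encoder, and the commitment term sees no gradient through the codebook.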

Why don't we have a KLD term in the objective loss function? This is because we set the prior of the latent space $z$ to the uniform distribution, $p(z) = \frac{1}{K}$.

Because the posterior $q(z|x)$ is deterministic (one-hot, see Implementation details), only one term of the KLD between the posterior and the prior survives:

$$KL(q(z|x) \,\|\, p(z)) = \sum_z q(z|x) \log \frac{q(z|x)}{p(z)} = 1 \cdot \log \frac{1}{1/K} = \log K$$

Since this is a constant with respect to all trainable parameters, we can safely drop the KLD from the loss objective.

Adding PixelCNN into VQ-VAE for image generation

After training the VQ-VAE, each image is encoded as a grid of discrete latent codes. PixelCNN learns the relationship between these codes; in other words, PixelCNN learns the prior of the latent space. This is what makes sampling realistic images possible.

PixelCNN is an autoregressive generator over the latent codes: it learns to produce code grids whose decodings are realistic images.
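
A minimal sketch of this sampling loop, assuming a hypothetical `pixelcnn` module that maps a code grid to per-position logits (the linked repo's actual API may differ):

```python
import torch

@torch.no_grad()
def sample_latents(pixelcnn, height, width, num_codes, device="cpu"):
    """Autoregressively sample a (height, width) grid of discrete latent indices.

    Assumes pixelcnn(codes) returns logits of shape (B, num_codes, H, W);
    this interface is illustrative, not the repo's actual API.
    """
    codes = torch.zeros(1, height, width, dtype=torch.long, device=device)
    for i in range(height):
        for j in range(width):
            logits = pixelcnn(codes)                   # (1, num_codes, H, W)
            probs = logits[0, :, i, j].softmax(dim=0)  # p(code_ij | previous codes)
            codes[0, i, j] = torch.multinomial(probs, 1)
    return codes  # map through the codebook and the decoder to get an image
```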

Implementation details

The posterior categorical distribution $q(z|x)$ is deterministic: it puts probability 1 on the codebook index nearest to the encoder output, i.e. $q(z = k|x) = 1$ for $k = \arg\min_j \|z_e(x) - e_j\|_2$, and 0 otherwise.
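
The nearest-neighbour lookup itself has no gradient, so implementations copy the gradient from the decoder input straight to the encoder output (the straight-through estimator). A one-line sketch:

```python
import torch

def straight_through(z_e: torch.Tensor, z_q: torch.Tensor) -> torch.Tensor:
    """Forward pass returns z_q; backward pass copies gradients onto z_e,
    skipping the non-differentiable nearest-neighbour lookup."""
    return z_e + (z_q - z_e).detach()
```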

You can train and run a VQ-VAE model for CIFAR-10 using https://github.com/jinho-choi123/VQVAE-pytorch.

Paper: Neural Discrete Representation Learning, https://arxiv.org/abs/1711.00937