Vector Quantized VAE (VQ-VAE)
Many previous VAEs use a continuous latent space. This paper presents a powerful generative model, VQ-VAE, that has a discrete latent space.
The encoder takes an image $x$ and outputs a grid of vectors $z_e(x)$.
Each of these vectors is then compared to the codebook (a.k.a. the embedding space) $\{e_1, \dots, e_K\}$: we find the nearest embedding vector and replace the encoder output with it. This gives the latent matrix $z_q(x)$:

$$z_q(x) = e_k, \quad k = \arg\min_j \| z_e(x) - e_j \|_2$$

Finally, the decoder takes the quantized latent $z_q(x)$ and converts it back into an image.
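The nearest-neighbor lookup above can be sketched as follows. This is a minimal NumPy illustration, not the paper's implementation; the codebook size, embedding dimension, and latent grid size are example values.

```python
import numpy as np

rng = np.random.default_rng(0)
K, D = 512, 64                      # codebook size and embedding dim (example values)
codebook = rng.normal(size=(K, D))  # the embedding space e_1 .. e_K
z_e = rng.normal(size=(8 * 8, D))   # encoder output: one D-dim vector per spatial position

# squared Euclidean distance from every encoder vector to every codebook entry
dists = ((z_e[:, None, :] - codebook[None, :, :]) ** 2).sum(-1)
indices = dists.argmin(axis=1)      # index of the nearest embedding per position
z_q = codebook[indices]             # quantized latents z_q fed to the decoder
print(z_q.shape)                    # one codebook vector per spatial position
```

Each row of `z_q` is exactly one of the codebook rows, which is what makes the latent space discrete.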
There are three trainable parts: the encoder parameters, the codebook, and the decoder parameters.
The objective of training is to minimize the following loss function:

$$L = \log p(x \mid z_q(x)) + \| \mathrm{sg}[z_e(x)] - e \|_2^2 + \beta \| z_e(x) - \mathrm{sg}[e] \|_2^2$$

Here $\mathrm{sg}$ stands for the stop-gradient operator, which has zero partial derivative: it treats its argument as a constant during backpropagation.
The first term is the reconstruction loss, which optimizes the decoder and the encoder.
The second term trains the codebook, pulling each embedding toward the encoder outputs assigned to it.
The last term (the commitment loss) trains the encoder to stay close to the embeddings.
You may think it is odd that the second and third terms both push the codebook and the encoder toward each other. The issue is that the encoder's parameters can change quickly while the codebook's training cannot keep up; we add the last term to make sure the encoder and the codebook move together.
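The three loss terms can be sketched numerically as below. This toy NumPy version only shows the values being computed; in a real autograd framework the stop-gradient would be a `detach()`-style call, and the distinction between the second and third terms lies in which side receives gradients, not in their numerical value. All shapes and the data are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(1)
x = rng.normal(size=(32, 32))                # toy input "image"
x_hat = x + 0.1 * rng.normal(size=(32, 32))  # toy decoder reconstruction
z_e = rng.normal(size=(64,))                 # encoder output at one position
e = rng.normal(size=(64,))                   # its matched codebook vector
beta = 0.25                                  # commitment weight used in the paper

recon = ((x - x_hat) ** 2).mean()            # term 1: trains encoder + decoder
codebook_loss = ((z_e - e) ** 2).sum()       # term 2: sg[z_e], gradient flows to e only
commitment = beta * ((z_e - e) ** 2).sum()   # term 3: sg[e], gradient flows to z_e only
loss = recon + codebook_loss + commitment
print(loss)
```

Note that terms 2 and 3 share the same squared distance; only the stop-gradient placement (and the weight $\beta$) differs.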
Why don't we have a KLD term in the objective loss function? This is because we set the prior of the latent space to be the uniform distribution, $p(z = k) = 1/K$.
Since the posterior categorical distribution is one-hot, the KLD between the prior and the posterior reduces to a constant:

$$D_{\mathrm{KL}}(q \,\|\, p) = \sum_k q(z = k \mid x) \log \frac{q(z = k \mid x)}{1/K} = 1 \cdot \log \frac{1}{1/K} = \log K$$

Because this term is constant with respect to the parameters, we can ignore the KLD in the loss objective.
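A quick numerical check of this claim, with an example codebook size (the index of the one-hot entry is arbitrary):

```python
import numpy as np

K = 512                                   # codebook size (example value)
prior = np.full(K, 1.0 / K)               # uniform prior over codes
posterior = np.zeros(K)
posterior[7] = 1.0                        # deterministic one-hot posterior

# KL(q || p) summed over the support of q (using the convention 0 log 0 = 0)
mask = posterior > 0
kl = np.sum(posterior[mask] * np.log(posterior[mask] / prior[mask]))
print(kl, np.log(K))                      # the KL equals the constant log K
```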
After training the VQ-VAE, each image is represented by a grid of latent codes. A PixelCNN then learns the relationship between all the latent vectors; in other words, PixelCNN learns the prior of the latent space. This is important for producing realistic samples.
PixelCNN is an autoregressive generator over the latent vectors: it learns to produce latent codes that decode into realistic images.
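Autoregressive sampling of a latent grid can be sketched as below. The `prior_logits` function here is a hypothetical stand-in for a trained PixelCNN (it returns random logits rather than learned ones); the grid and codebook sizes are example values.

```python
import numpy as np

rng = np.random.default_rng(2)
K, H, W = 512, 8, 8                  # codebook size and latent grid (example values)

def prior_logits(grid, i, j):
    """Stand-in for a trained PixelCNN: in a real model this returns logits
    for position (i, j) conditioned on already-sampled codes above/left."""
    return rng.normal(size=K)        # untrained placeholder (assumption)

grid = np.zeros((H, W), dtype=int)
for i in range(H):                   # raster-scan order, one code at a time
    for j in range(W):
        logits = prior_logits(grid, i, j)
        p = np.exp(logits - logits.max())
        p /= p.sum()                 # softmax over the K codebook indices
        grid[i, j] = rng.choice(K, p=p)

# `grid` now holds latent indices; the VQ-VAE decoder maps them to an image
print(grid.shape)
```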
Recall that the posterior categorical distribution is one-hot, as given:

$$q(z = k \mid x) = \begin{cases} 1 & k = \arg\min_j \| z_e(x) - e_j \|_2 \\ 0 & \text{otherwise} \end{cases}$$
You can train and run a VQ-VAE model on CIFAR-10 in .