This notebook contains a Keras / TensorFlow implementation of the VQ-VAE model, introduced in Neural Discrete Representation Learning (van den Oord et al, NeurIPS 2017). It is a generative model based on Variational Autoencoders (VAEs) which makes the latent space discrete using Vector Quantization (VQ) techniques. This implementation trains a VQ-VAE built from simple convolutional blocks (no auto-regressive decoder), together with a PixelCNN categorical prior as described in the paper. The current code was tested on MNIST. This project is also hosted as a Kaggle notebook.
Pros (+)
- Simple method and training objective
- A “proper” discrete latent space, which is a promising property for modeling data that is inherently discrete, e.g. text.
Cons (-)
- Loses the “easy latent sampling” property of VAEs: a second training stage is required to learn a fitting categorical prior.
- The training objective does not correspond to a bound on the log-likelihood anymore.
Vector-Quantized Latent Space
The first step is to build the main VQ-VAE model. It consists of a standard encoder-decoder architecture with convolutional blocks. The main novelty lies in the intermediate Vector Quantizer layer (VQ) that takes care of building a discrete latent space.
More specifically, the encoder, \(f\), is a fully-convolutional neural network that maps input images to latent codes of size \((w, h, d)\), where \(d\) is the dimension of the latent space, and \(w \times h\) the size of the final feature map. The output of the encoder is then mapped to the closest entry in a discrete codebook of \(K\) latent codes, \(\mathcal E = \{e_0 \dots e_{K-1} \}\) where \(\forall i, e_i \in \mathbb{R}^d\).
\[\begin{align} &\textbf{input } x \tag{W x H x C}\\ z_e &= f(x) \tag{w x h x d}\\ z_q^{i, j} &= \arg\min_{e \in \mathcal E} \| z_e^{i, j} - e \|^2 \end{align}\]
The Vector Quantization process is implemented as the following Keras layer:
# Imports used throughout (aliases inferred from the code below)
import numpy as np
import tensorflow as tf
from tensorflow import keras as K
from tensorflow.keras import backend as Kb

class VectorQuantizer(K.layers.Layer):
    def __init__(self, k, **kwargs):
        super(VectorQuantizer, self).__init__(**kwargs)
        self.k = k

    def build(self, input_shape):
        self.d = int(input_shape[-1])
        rand_init = K.initializers.VarianceScaling(distribution="uniform")
        self.codebook = self.add_weight(shape=(self.k, self.d), initializer=rand_init, trainable=True)

    def call(self, inputs):
        # Map z_e of shape (b, w, h, d) to indices in the codebook
        lookup_ = tf.reshape(self.codebook, shape=(1, 1, 1, self.k, self.d))
        z_e = tf.expand_dims(inputs, -2)
        dist = tf.norm(z_e - lookup_, axis=-1)
        k_index = tf.argmin(dist, axis=-1)
        return k_index

    def sample(self, k_index):
        # Map an indices array of shape (b, w, h) to the actual codebook vectors z_q
        lookup_ = tf.reshape(self.codebook, shape=(1, 1, 1, self.k, self.d))
        k_index_one_hot = tf.one_hot(k_index, self.k)
        z_q = lookup_ * k_index_one_hot[..., None]
        z_q = tf.reduce_sum(z_q, axis=-2)
        return z_q
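As a quick sanity check of the shapes involved, the layer can be exercised on random inputs (a minimal sketch, not part of the original notebook; it assumes eager execution):

vq = VectorQuantizer(k=8)
z_e_test = tf.random.normal((2, 7, 7, 16))   # fake encoder output: batch of 2, 7x7 feature maps, d=16
indices = vq(z_e_test)                       # (2, 7, 7) integer indices into the codebook
z_q_test = vq.sample(indices)                # (2, 7, 7, 16) corresponding codebook vectors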
The decoder, \(g\), then takes the quantized codes \(z_q\) as inputs and generates the output image. Here we consider a simple architecture with transposed convolution blocks, mirroring the encoder architecture:
def encoder_pass(inputs, d, num_layers=[16, 32]):
    x = inputs
    for i, filters in enumerate(num_layers):
        x = K.layers.Conv2D(filters=filters, kernel_size=3, padding='SAME', activation='relu',
                            strides=(2, 2), name="conv{}".format(i + 1))(x)
    z_e = K.layers.Conv2D(filters=d, kernel_size=3, padding='SAME', activation=None,
                          strides=(1, 1), name='z_e')(x)
    return z_e

def decoder_pass(inputs, num_layers=[32, 16]):
    y = inputs
    for i, filters in enumerate(num_layers):
        y = K.layers.Conv2DTranspose(filters=filters, kernel_size=4, strides=(2, 2), padding="SAME",
                                     activation='relu', name="convT{}".format(i + 1))(y)
    decoded = K.layers.Conv2DTranspose(filters=1, kernel_size=3, strides=(1, 1),
                                       padding="SAME", activation='sigmoid', name='output')(y)
    return decoded
Once these three building blocks are in place, we can assemble the full VQ-VAE. One subtlety is how to estimate gradients through the Vector Quantizer: the transition from \(z_e\) to \(z_q\) does not let gradients backpropagate, because of the argmin operation. Instead, the authors propose to use a straight-through estimator, which directly copies the gradient received by \(z_q\) onto \(z_e\).
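Before wiring this into the model, here is a minimal standalone sketch of the straight-through trick (illustrative only, assuming TF 2.x eager execution): the forward value equals \(z_q\), while the gradient of the loss with respect to \(z_q\) is copied onto \(z_e\).

z_e_demo = tf.Variable([[1.0, 2.0]])    # pretend encoder output (trainable)
z_q_demo = tf.constant([[0.9, 2.2]])    # pretend quantized code (no gradient path)
with tf.GradientTape() as tape:
    st = z_e_demo + tf.stop_gradient(z_q_demo - z_e_demo)   # forward pass: exactly z_q_demo
    loss = tf.reduce_sum(st ** 2)
print(tape.gradient(loss, z_e_demo))    # [[1.8, 4.4]] = 2 * z_q_demo, copied onto z_e_demo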
def build_vqvae(k, d, input_shape=(28, 28, 1), num_layers=[16, 32]):
    global SIZE
    ## Encoder
    encoder_inputs = K.layers.Input(shape=input_shape, name='encoder_inputs')
    z_e = encoder_pass(encoder_inputs, d, num_layers=num_layers)
    SIZE = int(z_e.get_shape()[1])
    ## Vector Quantization
    vector_quantizer = VectorQuantizer(k, name="vector_quantizer")
    codebook_indices = vector_quantizer(z_e)
    encoder = K.Model(inputs=encoder_inputs, outputs=codebook_indices, name='encoder')
    ## Decoder
    decoder_inputs = K.layers.Input(shape=(SIZE, SIZE, d), name='decoder_inputs')
    decoded = decoder_pass(decoder_inputs, num_layers=num_layers[::-1])
    decoder = K.Model(inputs=decoder_inputs, outputs=decoded, name='decoder')
    ## VQVAE Model (training)
    sampling_layer = K.layers.Lambda(lambda x: vector_quantizer.sample(x), name="sample_from_codebook")
    z_q = sampling_layer(codebook_indices)
    codes = tf.stack([z_e, z_q], axis=-1)
    codes = K.layers.Lambda(lambda x: x, name='latent_codes')(codes)
    straight_through = K.layers.Lambda(lambda x: x[1] + tf.stop_gradient(x[0] - x[1]), name="straight_through_estimator")
    straight_through_zq = straight_through([z_q, z_e])
    reconstructed = decoder(straight_through_zq)
    vq_vae = K.Model(inputs=encoder_inputs, outputs=[reconstructed, codes], name='vq-vae')
    ## VQVAE model (inference)
    codebook_indices = K.layers.Input(shape=(SIZE, SIZE), name='discrete_codes', dtype=tf.int32)
    z_q = sampling_layer(codebook_indices)
    generated = decoder(z_q)
    vq_vae_sampler = K.Model(inputs=codebook_indices, outputs=generated, name='vq-vae-sampler')
    ## Transition from codebook indices to model (for training the prior later)
    indices = K.layers.Input(shape=(SIZE, SIZE), name='codes_sampler_inputs', dtype='int32')
    z_q = sampling_layer(indices)
    codes_sampler = K.Model(inputs=indices, outputs=z_q, name="codes_sampler")
    ## Getter to easily access the codebook for visualization
    indices = K.layers.Input(shape=(), dtype='int32')
    vector_model = K.Model(inputs=indices, outputs=vector_quantizer.sample(indices[:, None, None]), name='get_codebook')
    def get_vq_vae_codebook():
        codebook = vector_model.predict(np.arange(k))
        codebook = np.reshape(codebook, (k, d))
        return codebook
    return vq_vae, vq_vae_sampler, encoder, decoder, codes_sampler, get_vq_vae_codebook

vq_vae, vq_vae_sampler, encoder, decoder, codes_sampler, get_vq_vae_codebook = build_vqvae(
    NUM_LATENT_K, NUM_LATENT_D, input_shape=INPUT_SHAPE, num_layers=VQVAE_LAYERS)
vq_vae.summary()
Training the model
All that is left now is to train the model. The training objective contains the reconstruction loss (here, the mean squared error), the KL divergence term on the latent codebook (ignored, because it is constant under the uniform prior assumed during training), and two vector quantization losses which guarantee that (i) the codebook entries are pulled toward the encoder outputs they are matched to and (ii) the encoder commits to the codebook, i.e. its outputs stay close to their matched entries rather than growing arbitrarily.
\[\begin{align} \mathcal L_{\text{VQ-VAE}}(x) = - \mathbb{E}_{z \sim f(x)}\left[\log p(x | z)\right] + \|\bar{z_e} - z_q\|^2 + \beta \| z_e - \bar{z_q}\|^2 \end{align}\]where \(\bar{\cdot}\) denotes the stop-gradient operation (the identity during the forward pass, but no gradients flow through it during backpropagation) and \(\beta\) is the commitment cost (BETA in the code below).
def mse_loss(ground_truth, predictions):
    mse_loss = tf.reduce_mean((ground_truth - predictions)**2, name="mse_loss")
    return mse_loss

def latent_loss(dummy_ground_truth, outputs):
    global BETA
    del dummy_ground_truth
    z_e, z_q = tf.split(outputs, 2, axis=-1)
    vq_loss = tf.reduce_mean((tf.stop_gradient(z_e) - z_q)**2)
    commit_loss = tf.reduce_mean((z_e - tf.stop_gradient(z_q))**2)
    latent_loss = tf.identity(vq_loss + BETA * commit_loss, name="latent_loss")
    return latent_loss
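A minimal sketch of how these two losses could be attached to the two model outputs (x_train is assumed to be the MNIST training images scaled to [0, 1] with shape (N, 28, 28, 1); the batch size and epoch count are arbitrary; BETA = 0.25 follows the value used in the paper):

BETA = 0.25   # commitment cost
vq_vae.compile(optimizer='adam', loss=[mse_loss, latent_loss])
# The 'latent_codes' output needs a dummy target, since latent_loss ignores it.
dummy_targets = np.zeros((len(x_train), SIZE, SIZE, NUM_LATENT_D, 2), dtype=np.float32)
vq_vae.fit(x_train, [x_train, dummy_targets], batch_size=128, epochs=20)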
We can now train the model on the MNIST dataset:
Figure 1: Training the VQ-VAE on the MNIST dataset
Once training is done, we can also visualize some of the results, such as reconstructions on the test set and the learned codebook entries (projected to 2D with t-SNE). In particular, we observe that reconstructions are close to perfect, which indicates the model is able to learn a meaningful codebook, as well as how to map images to this discrete space.
Figure 2: Reconstructions on the test set. The figure reads row-wise: each pair shows the original image (left) and its reconstruction (right), annotated with the mean squared error to the original.
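For reference, the codebook projection mentioned above can be obtained in a few lines (a sketch, assuming scikit-learn is available):

from sklearn.manifold import TSNE
codebook = get_vq_vae_codebook()                              # (k, d) array of codebook vectors
codebook_2d = TSNE(n_components=2).fit_transform(codebook)    # 2D projection for plotting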
Learning a prior over the latent space
We have now learned an encoder-decoder architecture and a discrete latent codebook powerful enough to encode and reconstruct our dataset. However, the uniform prior assumed during training is not sufficient for generating good samples. In fact, due to the fully-convolutional architecture, each image is encoded as a SIZE x SIZE map of latent vectors from the codebook (for instance, SIZE = 7 for our current model). These codes are not guaranteed to be uniformly distributed over that space, as we assumed during training; instead they follow some specific, non-uniform categorical distribution. This can be seen easily by generating images from code feature maps sampled uniformly from the total latent space of size SIZE x SIZE x \(K\), as shown below.
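Concretely, such uniform samples could be generated as follows (a sketch; the batch size of 16 is arbitrary):

uniform_codes = np.random.randint(0, NUM_LATENT_K, size=(16, SIZE, SIZE)).astype(np.int32)  # uniform over the codebook
uniform_samples = vq_vae_sampler.predict(uniform_codes)                                     # decode the random code maps to images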
Figure 3: Generating samples from the uniform latent prior assumed during training
PixelCNN
To solve this problem and sample likely codes from the latent space, the authors propose to learn a powerful categorical prior over the latent codes of the training images using a PixelCNN. PixelCNN is a fully probabilistic autoregressive generative model that generates images (or here, feature maps) pixel by pixel, conditioned on the previously generated pixels. The main drawback of such models is that the sampling process is rather slow; however, since we are only generating small SIZE x SIZE maps here, the overhead remains acceptable.
Here we consider the architecture proposed in Conditional Image Generation with PixelCNN Decoders (van den Oord et al, NeurIPS 2016), which uses gated and masked convolutions to model the fact that each pixel depends only on the previously generated context. We implement the base building block of the architecture as the following Keras pipeline:
# References:
# https://github.com/anantzoid/Conditional-PixelCNN-decoder/blob/master/layers.py
# https://github.com/ritheshkumar95/pytorch-vqvae
def gate(inputs):
    """Gated activations"""
    x, y = tf.split(inputs, 2, axis=-1)
    return Kb.tanh(x) * Kb.sigmoid(y)
class MaskedConv2D(K.layers.Layer):
    """Masked convolution"""
    def __init__(self, kernel_size, out_dim, direction, mode, **kwargs):
        self.direction = direction     # Horizontal or vertical
        self.mode = mode               # Mask type "a" or "b"
        self.kernel_size = kernel_size
        self.out_dim = out_dim
        super(MaskedConv2D, self).__init__(**kwargs)

    def build(self, input_shape):
        filter_mid_y = self.kernel_size[0] // 2
        filter_mid_x = self.kernel_size[1] // 2
        in_dim = int(input_shape[-1])
        w_shape = [self.kernel_size[0], self.kernel_size[1], in_dim, self.out_dim]
        mask_filter = np.ones(w_shape, dtype=np.float32)
        # Build the mask
        if self.direction == "h":
            mask_filter[filter_mid_y + 1:, :, :, :] = 0.
            mask_filter[filter_mid_y, filter_mid_x + 1:, :, :] = 0.
        elif self.direction == "v":
            if self.mode == 'a':
                mask_filter[filter_mid_y:, :, :, :] = 0.
            elif self.mode == 'b':
                mask_filter[filter_mid_y + 1:, :, :, :] = 0.
        if self.mode == 'a':
            mask_filter[filter_mid_y, filter_mid_x, :, :] = 0.
        # Create convolution layer parameters with masked kernel
        self.W = mask_filter * self.add_weight("W_{}".format(self.direction), w_shape, trainable=True)
        self.b = self.add_weight("v_b", [self.out_dim,], trainable=True)

    def call(self, inputs):
        return K.backend.conv2d(inputs, self.W, strides=(1, 1)) + self.b
def gated_masked_conv2d(v_stack_in, h_stack_in, out_dim, kernel, mask='b', residual=True, i=0):
    """Basic Gated-PixelCNN block.
    This is an improvement over PixelRNN to avoid "blind spots", i.e. pixels missing from the
    field of view. It works by having two parallel stacks, for the vertical and the horizontal direction,
    each being masked to only see the appropriate context pixels.
    """
    kernel_size = (kernel // 2 + 1, kernel)
    padding = (kernel // 2, kernel // 2)
    v_stack = K.layers.ZeroPadding2D(padding=padding, name="v_pad_{}".format(i))(v_stack_in)
    v_stack = MaskedConv2D(kernel_size, out_dim * 2, "v", mask, name="v_masked_conv_{}".format(i))(v_stack)
    v_stack = v_stack[:, :int(v_stack_in.get_shape()[-3]), :, :]
    v_stack_out = K.layers.Lambda(lambda inputs: gate(inputs), name="v_gate_{}".format(i))(v_stack)
    kernel_size = (1, kernel // 2 + 1)
    padding = (0, kernel // 2)
    h_stack = K.layers.ZeroPadding2D(padding=padding, name="h_pad_{}".format(i))(h_stack_in)
    h_stack = MaskedConv2D(kernel_size, out_dim * 2, "h", mask, name="h_masked_conv_{}".format(i))(h_stack)
    h_stack = h_stack[:, :, :int(h_stack_in.get_shape()[-2]), :]
    h_stack_1 = K.layers.Conv2D(filters=out_dim * 2, kernel_size=1, strides=(1, 1), name="v_to_h_{}".format(i))(v_stack)
    h_stack_out = K.layers.Lambda(lambda inputs: gate(inputs), name="h_gate_{}".format(i))(h_stack + h_stack_1)
    h_stack_out = K.layers.Conv2D(filters=out_dim, kernel_size=1, strides=(1, 1), name="res_conv_{}".format(i))(h_stack_out)
    if residual:
        h_stack_out += h_stack_in
    return v_stack_out, h_stack_out
Training the prior
In order to train the prior, we encode every training image to obtain its discrete representation as indices in the latent codebook. This is also a good opportunity to visualize the discrete representations learned by the encoder. We can notice some interesting features: for instance, all black/background pixels are mapped to the same codeword, and the same holds for dense/white pixels (e.g., the ones at the center of a digit).
Figure 4: Visualizing categorical codes learned by the VQ-VAE encoder
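The encoding step itself is a one-liner with the encoder model built above (a sketch, assuming x_train is the MNIST training array of shape (N, 28, 28, 1)):

train_codes = encoder.predict(x_train).astype(np.int32)   # (N, SIZE, SIZE) maps of codebook indices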
Equipped with the masked gated convolutions and the encoded training set, we can finally build our PixelCNN architecture and train the prior, as sketched below. The full model simply consists of a stack of masked gated convolution blocks, followed by two fully-connected layers that output the final prediction. The training objective is a multi-class classification one, as the prior should output a map of codebook indices.
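The depth and feature widths below are assumptions chosen for illustration, not necessarily the values used for the figures; the wiring, however, only relies on the blocks and models defined above:

NUM_PIXELCNN_LAYERS = 12      # assumed depth
NUM_PIXELCNN_FEATURES = 32    # assumed width

pixelcnn_inputs = K.layers.Input(shape=(SIZE, SIZE), dtype='int32', name='pixelcnn_inputs')
z_q = codes_sampler(pixelcnn_inputs)                 # map indices to their codebook vectors
v_stack, h_stack = z_q, z_q
for i in range(NUM_PIXELCNN_LAYERS):
    mask = 'a' if i == 0 else 'b'                    # the first layer must not see the current pixel
    v_stack, h_stack = gated_masked_conv2d(v_stack, h_stack, NUM_PIXELCNN_FEATURES,
                                           kernel=3, mask=mask, residual=(i > 0), i=i + 1)
features = K.layers.Conv2D(filters=128, kernel_size=1, activation='relu', name='fc1')(h_stack)
logits = K.layers.Conv2D(filters=NUM_LATENT_K, kernel_size=1, name='fc2')(features)   # per-position logits over the K codes
pixelcnn_prior = K.Model(inputs=pixelcnn_inputs, outputs=logits, name='pixelcnn-prior')
pixelcnn_prior.compile(optimizer='adam',
                       loss=K.losses.SparseCategoricalCrossentropy(from_logits=True))
pixelcnn_prior.fit(train_codes, train_codes, batch_size=128, epochs=20)   # predict each code from its masked context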
Figure 5: Training the VQ-VAE PixelCNN prior
Once again, we can check the model's ability to reconstruct discrete latent codes obtained from the test set:
Figure 6: Reconstructing discrete latent codes from the test set with the PixelCNN prior
More importantly, let's have a look at images generated by sampling from the prior. As expected, they look much better than samples drawn from a uniform prior, which means that (i) the discrete codes of our image distribution lie in a specific subset of the latent space and (ii) the PixelCNN was able to properly model a prior probability distribution over that space.
Figure 7: Generating samples from the learned PixelCNN prior
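For completeness, a sampling loop along these lines can be used to draw code maps from the prior and decode them (a sketch, assuming the pixelcnn_prior model sketched above; positions are filled one at a time, each conditioned on the previously sampled ones):

def sample_from_prior(prior, batch_size=16):
    """Autoregressively sample (SIZE, SIZE) maps of codebook indices, one position at a time."""
    sampled = np.zeros((batch_size, SIZE, SIZE), dtype=np.int32)
    for i in range(SIZE):
        for j in range(SIZE):
            logits = prior.predict(sampled)[:, i, j, :]                      # logits at position (i, j)
            exp = np.exp((logits - logits.max(axis=-1, keepdims=True)).astype(np.float64))
            probs = exp / exp.sum(axis=-1, keepdims=True)                    # softmax over the K codes
            for b in range(batch_size):
                sampled[b, i, j] = np.random.choice(NUM_LATENT_K, p=probs[b])
    return sampled

sampled_codes = sample_from_prior(pixelcnn_prior)
generated_images = vq_vae_sampler.predict(sampled_codes)   # decode the sampled code maps to images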