# VQ-VAE

This notebook contains a Keras / TensorFlow implementation of the **VQ-VAE** model, which was introduced in Neural Discrete Representation Learning (van den Oord et al., NeurIPS 2017). This is a generative model based on Variational Autoencoders (**VAE**) that makes the latent space discrete using Vector Quantization (**VQ**) techniques. This implementation trains a VQ-VAE built from simple convolutional blocks (no autoregressive decoder), and a PixelCNN categorical prior as described in the paper. The current code was tested on MNIST. This project is also hosted as a Kaggle notebook.

**Pros (+):**

- Simple method and training objective
- “Proper” discrete latent space. This is a promising property for modeling data that is inherently discrete, e.g. text.

**Cons (-):**

- Loses the “easy latent sampling” property of VAEs: two-stage training is required to learn a fitting categorical prior.
- The training objective does not correspond to a bound on the log-likelihood anymore.

### Vector-Quantized Latent Space

The first step is to build the main VQ-VAE model. It consists of a standard encoder-decoder architecture with convolutional blocks. The main novelty lies in the intermediate *Vector Quantizer* layer (**VQ**) that takes care of building a *discrete* latent space.

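More specifically, the encoder, $f$, is a *fully-convolutional neural network* that maps an input image $x$ to latent codes $z_e(x) \in \mathbb{R}^{W \times H \times D}$, where $D$ is the dimension of the latent space, and $W \times H$ the size of the final feature map. Each of the $W \times H$ output vectors of the encoder is then mapped to the closest entry in a discrete *codebook* $\mathcal{C} = \{e_1, \dots, e_K\}$ of $K$ latent codes, where $e_k \in \mathbb{R}^D$ (`k` and `d` in the code below). In other words, the quantized code at each position is $z_q(x) = e_{k^\star}$ with $k^\star = \arg\min_k \| z_e(x) - e_k \|_2$.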

The Vector Quantization process is implemented as the following Keras layer:

```
import numpy as np
import tensorflow as tf
from tensorflow import keras as K

class VectorQuantizer(K.layers.Layer):
    """Maps encoder outputs z_e to the index of their nearest codebook entry."""

    def __init__(self, k, **kwargs):
        super(VectorQuantizer, self).__init__(**kwargs)
        self.k = k  # number of codebook entries

    def build(self, input_shape):
        self.d = int(input_shape[-1])  # dimension of each codebook entry
        rand_init = K.initializers.VarianceScaling(distribution="uniform")
        self.codebook = self.add_weight(name="codebook", shape=(self.k, self.d),
                                        initializer=rand_init, trainable=True)

    def call(self, inputs):
        # Map z_e of shape (b, w, h, d) to indices in the codebook, shape (b, w, h)
        lookup_ = tf.reshape(self.codebook, shape=(1, 1, 1, self.k, self.d))
        z_e = tf.expand_dims(inputs, -2)               # (b, w, h, 1, d)
        dist = tf.norm(z_e - lookup_, axis=-1)         # (b, w, h, k)
        k_index = tf.argmin(dist, axis=-1)
        return k_index

    def sample(self, k_index):
        # Map an indices array of shape (b, w, h) to the codebook vectors z_q, shape (b, w, h, d)
        lookup_ = tf.reshape(self.codebook, shape=(1, 1, 1, self.k, self.d))
        k_index_one_hot = tf.one_hot(k_index, self.k)  # (b, w, h, k)
        z_q = lookup_ * k_index_one_hot[..., None]     # (b, w, h, k, d)
        z_q = tf.reduce_sum(z_q, axis=-2)
        return z_q
```
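As a sanity check of the two methods, the NumPy sketch below reproduces them outside TensorFlow on toy shapes (the helper names `quantize` and `lookup_one_hot` and all sizes are illustrative, not taken from the notebook). It also shows that the one-hot product used in `sample` is the same as directly indexing the codebook:

```python
import numpy as np

def quantize(z_e, codebook):
    """Nearest-neighbour quantization, mirroring VectorQuantizer.call.

    z_e: array of shape (b, w, h, d); codebook: array of shape (k, d).
    Returns integer indices of shape (b, w, h).
    """
    # Broadcast (b, w, h, 1, d) against (1, 1, 1, k, d) -> distances of shape (b, w, h, k)
    dist = np.linalg.norm(z_e[..., None, :] - codebook[None, None, None], axis=-1)
    return np.argmin(dist, axis=-1)

def lookup_one_hot(k_index, codebook):
    """Index -> vector lookup as in VectorQuantizer.sample (one-hot mask, then sum)."""
    k = codebook.shape[0]
    one_hot = np.eye(k)[k_index]                          # (b, w, h, k)
    return (one_hot[..., None] * codebook).sum(axis=-2)  # (b, w, h, d)

rng = np.random.default_rng(0)
codebook = rng.normal(size=(8, 4))   # toy codebook: k = 8 entries of dimension d = 4
z_e = rng.normal(size=(2, 7, 7, 4))  # toy batch of encoder outputs
k_index = quantize(z_e, codebook)
z_q = lookup_one_hot(k_index, codebook)

# The one-hot lookup is the same as plain integer indexing (a gather)
print(k_index.shape, z_q.shape, np.allclose(z_q, codebook[k_index]))  # -> (2, 7, 7) (2, 7, 7, 4) True
```

In TensorFlow, `tf.gather(self.codebook, k_index)` should give the same `z_q` without materializing the intermediate `(b, w, h, k, d)` tensor, which may matter for large codebooks.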

The decoder, $g$, then takes the quantized codes $z_q(x)$ as inputs and generates the output image. Here we consider a simple architecture with transposed convolution blocks, mirroring the encoder architecture (both networks are defined below):

```
def encoder_pass(inputs, d, num_layers=[16, 32]):
    """Encoder: stride-2 convolutions, then a stride-1 convolution projecting to d channels (z_e)."""
    x = inputs
    for i, filters in enumerate(num_layers):
        x = K.layers.Conv2D(filters=filters, kernel_size=3, padding='same', activation='relu',
                            strides=(2, 2), name="conv{}".format(i + 1))(x)
    z_e = K.layers.Conv2D(filters=d, kernel_size=3, padding='same', activation=None,
                          strides=(1, 1), name='z_e')(x)
    return z_e

def decoder_pass(inputs, num_layers=[32, 16]):
    """Decoder: mirrors the encoder with stride-2 transposed convolutions."""
    y = inputs
    for i, filters in enumerate(num_layers):
        y = K.layers.Conv2DTranspose(filters=filters, kernel_size=4, strides=(2, 2), padding='same',
                                     activation='relu', name="convT{}".format(i + 1))(y)
    # Sigmoid output: pixel intensities in [0, 1]
    decoded = K.layers.Conv2DTranspose(filters=1, kernel_size=3, strides=(1, 1),
                                       padding='same', activation='sigmoid', name='output')(y)
    return decoded
```
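The stride-2 layers are what fix the resolution of the latent map. The small pure-Python sketch below (our own illustration; the helper names are hypothetical) traces the spatial size through the encoder and decoder using the output-length rules Keras applies with `padding="same"`: `ceil(n / stride)` for `Conv2D`, and `n * stride` for `Conv2DTranspose` without `output_padding`.

```python
import math

def conv_same_out(n, stride):
    # Conv2D with padding="same": output length is ceil(n / stride)
    return math.ceil(n / stride)

def conv_transpose_same_out(n, stride):
    # Conv2DTranspose with padding="same" and no output_padding: output length is n * stride
    return n * stride

def trace_shapes(input_size=28, num_layers=(16, 32)):
    """Return (latent size, output size) for a square input of side input_size."""
    size = input_size
    for _ in num_layers:                 # one stride-2 convolution per entry
        size = conv_same_out(size, 2)
    latent = conv_same_out(size, 1)      # final z_e convolution has stride 1
    out = latent
    for _ in num_layers:                 # mirrored stride-2 transposed convolutions
        out = conv_transpose_same_out(out, 2)
    out = conv_transpose_same_out(out, 1)
    return latent, out

print(trace_shapes())  # -> (7, 28)
```

For 28 x 28 MNIST images this gives a 7 x 7 latent map and a 28 x 28 output. An input side that is not divisible by 4, such as 30, would come back as 32 x 32, so the reconstruction would no longer match the input shape.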

Once these three building blocks are done, we can build the full *VQ-VAE*. One subtlety is how to estimate gradients through the Vector Quantizer: the transition from $z_e(x)$ to $z_q(x)$ is not differentiable because of the argmin, so gradients cannot be backpropagated through it. Instead, the authors propose to use a *straight-through estimator*, which directly copies the gradient received by $z_q(x)$ to $z_e(x)$.
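In terms of the stop-gradient operator $\operatorname{sg}[\cdot]$ (the identity in the forward pass, with zero gradient in the backward pass), the estimator can be written as follows, where $\tilde{z}(x)$ is simply our name for the decoder input. This is what the `straight_through_estimator` Lambda layer in `build_vqvae` computes, as `z_e + tf.stop_gradient(z_q - z_e)`:

```latex
\tilde{z}(x) = z_e(x) + \operatorname{sg}\left[ z_q(x) - z_e(x) \right]
\qquad \text{forward: } \tilde{z}(x) = z_q(x),
\qquad \text{backward: } \nabla_{z_e(x)} \mathcal{L} = \nabla_{\tilde{z}(x)} \mathcal{L}
```

Because the bracketed term is treated as a constant during backpropagation, the reconstruction gradient reaches the encoder unchanged, while the codebook receives no gradient through this path: it is only trained through the vector quantization losses.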

```
def build_vqvae(k, d, input_shape=(28, 28, 1), num_layers=[16, 32]):
    global SIZE
    ## Encoder
    encoder_inputs = K.layers.Input(shape=input_shape, name='encoder_inputs')
    z_e = encoder_pass(encoder_inputs, d, num_layers=num_layers)
    SIZE = int(z_e.get_shape()[1])  # spatial size of the latent feature map
    ## Vector Quantization
    vector_quantizer = VectorQuantizer(k, name="vector_quantizer")
    codebook_indices = vector_quantizer(z_e)
    encoder = K.Model(inputs=encoder_inputs, outputs=codebook_indices, name='encoder')
    ## Decoder
    decoder_inputs = K.layers.Input(shape=(SIZE, SIZE, d), name='decoder_inputs')
    decoded = decoder_pass(decoder_inputs, num_layers=num_layers[::-1])
    decoder = K.Model(inputs=decoder_inputs, outputs=decoded, name='decoder')
    ## VQVAE Model (training)
    sampling_layer = K.layers.Lambda(lambda x: vector_quantizer.sample(x), name="sample_from_codebook")
    z_q = sampling_layer(codebook_indices)
    # Stack z_e and z_q on a new last axis so that latent_loss can split them again
    codes = tf.stack([z_e, z_q], axis=-1)
    codes = K.layers.Lambda(lambda x: x, name='latent_codes')(codes)
    # Straight-through estimator: outputs z_q in the forward pass, copies gradients to z_e
    straight_through = K.layers.Lambda(lambda x: x[1] + tf.stop_gradient(x[0] - x[1]),
                                       name="straight_through_estimator")
    straight_through_zq = straight_through([z_q, z_e])
    reconstructed = decoder(straight_through_zq)
    vq_vae = K.Model(inputs=encoder_inputs, outputs=[reconstructed, codes], name='vq-vae')
    ## VQVAE model (inference): decode a map of codebook indices into an image
    codebook_indices = K.layers.Input(shape=(SIZE, SIZE), name='discrete_codes', dtype=tf.int32)
    z_q = sampling_layer(codebook_indices)
    generated = decoder(z_q)
    vq_vae_sampler = K.Model(inputs=codebook_indices, outputs=generated, name='vq-vae-sampler')
    ## Transition from codebook indices to model (for training the prior later)
    indices = K.layers.Input(shape=(SIZE, SIZE), name='codes_sampler_inputs', dtype='int32')
    z_q = sampling_layer(indices)
    codes_sampler = K.Model(inputs=indices, outputs=z_q, name="codes_sampler")
    ## Getter to easily access the codebook for visualization
    indices = K.layers.Input(shape=(), dtype='int32')
    vector_model = K.Model(inputs=indices, outputs=vector_quantizer.sample(indices[:, None, None]),
                           name='get_codebook')
    def get_vq_vae_codebook():
        codebook = vector_model.predict(np.arange(k))
        codebook = np.reshape(codebook, (k, d))
        return codebook
    return vq_vae, vq_vae_sampler, encoder, decoder, codes_sampler, get_vq_vae_codebook

# NUM_LATENT_K, NUM_LATENT_D, INPUT_SHAPE and VQVAE_LAYERS are hyperparameters
# assumed to be defined earlier in the notebook.
vq_vae, vq_vae_sampler, encoder, decoder, codes_sampler, get_vq_vae_codebook = build_vqvae(
    NUM_LATENT_K, NUM_LATENT_D, input_shape=INPUT_SHAPE, num_layers=VQVAE_LAYERS)
vq_vae.summary()
```

### Training the model

All that is left now is to train the model. The training objective contains the *reconstruction loss* (here, we use the mean squared error), the KL divergence term on the latent codebook (ignored, because it is constant when assuming a uniform prior during training), and two *vector quantization losses*: **(i)** a codebook loss, which pulls each codebook entry towards the encoder outputs it is matched to, and **(ii)** a commitment loss, weighted by $\beta$, which makes the encoder commit to its selected codebook entries so that its outputs do not grow arbitrarily relative to the codebook:

$$\mathcal{L} = \| x - \hat{x} \|_2^2 + \| \operatorname{sg}[z_e(x)] - z_q(x) \|_2^2 + \beta \, \| z_e(x) - \operatorname{sg}[z_q(x)] \|_2^2$$

where $\hat{x}$ is the reconstruction of $x$ decoded from $z_q(x)$, and $\operatorname{sg}$ denotes the stop-gradient operation: during the forward pass it is the identity, but during backpropagation no gradient flows through it.

```
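def mse_loss(ground_truth, predictions):
    """Reconstruction loss: mean squared error between images and reconstructions."""
    mse_loss = tf.reduce_mean((ground_truth - predictions)**2, name="mse_loss")
    return mse_loss

def latent_loss(dummy_ground_truth, outputs):
    """Vector quantization losses, computed from the stacked [z_e, z_q] `latent_codes` output."""
    global BETA  # weight of the commitment loss, assumed to be defined earlier in the notebook
    del dummy_ground_truth  # unused: this loss only depends on the model's own outputs
    z_e, z_q = tf.split(outputs, 2, axis=-1)
    # Codebook loss: moves the codebook entries towards the (frozen) encoder outputs
    vq_loss = tf.reduce_mean((tf.stop_gradient(z_e) - z_q)**2)
    # Commitment loss: keeps the encoder outputs close to their (frozen) codebook entries
    commit_loss = tf.reduce_mean((z_e - tf.stop_gradient(z_q))**2)
    latent_loss = tf.identity(vq_loss + BETA * commit_loss, name="latent_loss")
    return latent_loss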
```
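Because the stop-gradient operation is the identity in the forward pass, `vq_loss` and `commit_loss` always have the same value; they only differ in which parameters receive their gradient. The NumPy sketch below (forward values only; the function name and toy arrays are illustrative, and `beta=0.25` is the value reported in the paper, not necessarily the notebook's `BETA`) makes this explicit: the latent loss equals $(1 + \beta)$ times the mean squared distance between $z_e$ and $z_q$.

```python
import numpy as np

def vqvae_losses(x, x_rec, z_e, z_q, beta=0.25):
    """Forward-pass values of the VQ-VAE loss terms.

    Stop-gradient is the identity in the forward pass, so it is simply omitted here:
    it changes which parameters get gradients, not the loss values.
    """
    reconstruction = np.mean((x - x_rec) ** 2)  # mse_loss
    vq = np.mean((z_e - z_q) ** 2)              # codebook term: sg[z_e] vs z_q
    commit = np.mean((z_e - z_q) ** 2)          # commitment term: z_e vs sg[z_q]
    return reconstruction, vq + beta * commit

x = np.array([0.0, 1.0])
x_rec = np.array([0.5, 1.0])
z_e = np.array([[1.0, 2.0], [0.0, -1.0]])
z_q = np.array([[1.0, 1.0], [0.5, -1.0]])
rec, lat = vqvae_losses(x, x_rec, z_e, z_q)
print(rec, lat)  # -> 0.125 0.390625
```

Here both quantization terms equal 0.3125, so the latent loss is 1.25 x 0.3125 = 0.390625. Weighting them differently with $\beta$ only makes sense because their gradients go to different parameters (codebook versus encoder).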

We can now train the model on the MNIST dataset:

**Figure 1**: Training the VQ-VAE on the MNIST dataset

Once training is done, we can also visualize some of the results, such as reconstructions on the test set and the learned codebook entries (projected to 2D with *t-SNE*). In particular, we observe that reconstructions are close to perfect, which indicates that the model has learned a meaningful codebook, as well as how to map images to this discrete space.

**Figure 2**: Reconstructions on the test set. The image reads row-wise: each pair shows the original image (left) and its reconstruction (right), annotated with the mean squared error to the original.

### Learning a prior over the latent space

We have now learned an encoder-decoder architecture and a discrete latent codebook powerful enough to encode and reconstruct our dataset. However, the *uniform prior* assumption during training is not sufficient for generating good samples. In fact, due to the fully-convolutional architecture, each image is encoded with `SIZE` x `SIZE` latent vectors from the codebook (for instance, `SIZE = 7` for our current model).

However, there is no guarantee that the codes for our dataset are uniformly distributed over that space, as we assumed during training; rather, they have a specific structure that follows a certain *non-uniform categorical prior*. This can easily be seen by generating images from code maps sampled uniformly from the total latent space, i.e. `SIZE` x `SIZE` maps whose entries are drawn uniformly among the $K$ codebook indices.
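Concretely, sampling from the uniform prior amounts to drawing every index of a `SIZE` x `SIZE` map independently and uniformly. A NumPy sketch, with hypothetical helper names and a placeholder codebook size $K = 32$ (not necessarily the notebook's `NUM_LATENT_K`); the resulting int32 maps match the input expected by `vq_vae_sampler` in `build_vqvae`, e.g. `vq_vae_sampler.predict(codes)`:

```python
import numpy as np

def sample_uniform_codes(num_samples, size, k, seed=0):
    """Code maps under the uniform prior: each of the size x size positions
    independently takes one of the k codebook indices with probability 1/k."""
    rng = np.random.default_rng(seed)
    return rng.integers(0, k, size=(num_samples, size, size)).astype(np.int32)

def log10_num_code_maps(size, k):
    """log10 of the number of distinct code maps, k ** (size * size)."""
    return size * size * np.log10(k)

codes = sample_uniform_codes(16, size=7, k=32)  # k = 32 is only a placeholder value
print(codes.shape, codes.dtype)                 # -> (16, 7, 7) int32
print(log10_num_code_maps(7, 32))               # about 73.75, i.e. roughly 10**74 possible maps
```

The number of possible maps, $K^{\texttt{SIZE} \times \texttt{SIZE}}$, is so large that uniformly drawn maps are unlikely to resemble the structured code maps obtained by encoding actual digits.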

**Figure 3**: Generating samples from the uniform latent prior assumed during training

#### PixelCNN

To solve this problem and sample likely codes from the latent space, the authors propose to learn a *powerful categorical prior* over the latent codes of the training images using a **PixelCNN**. PixelCNN is a fully probabilistic autoregressive generative model that generates images (or here, feature maps) pixel by pixel, each pixel being conditioned on the previously generated ones. The main drawback of such models is that sampling is rather slow, since it requires one forward pass per generated pixel. However, since here we are only generating small `SIZE` x `SIZE` maps, the overhead is not too bad.
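The slow sampling comes from the autoregressive factorization $p(z) = \prod_i p(z_i \mid z_{<i})$ over the `SIZE` x `SIZE` positions in raster order: each position needs its own forward pass. A minimal NumPy sketch of this loop, assuming a hypothetical `predict_logits` function that stands in for the trained PixelCNN, and using the Gumbel-max trick to sample from the softmax of the logits:

```python
import numpy as np

def sample_autoregressive(predict_logits, num_samples, size, seed=0):
    """Raster-scan sampling of (num_samples, size, size) code maps from an autoregressive prior.

    predict_logits (hypothetical): maps int codes of shape (n, size, size) to logits of shape
    (n, size, size, k). Its masks must ensure that the logits at (i, j) only depend on codes
    at earlier raster positions, so filling `codes` in place is valid.
    """
    rng = np.random.default_rng(seed)
    codes = np.zeros((num_samples, size, size), dtype=np.int32)
    for i in range(size):
        for j in range(size):
            # One full forward pass per position; only the logits at (i, j) are used
            logits = predict_logits(codes)[:, i, j, :]
            # Gumbel-max trick: argmax(logits + Gumbel noise) is a sample from softmax(logits)
            codes[:, i, j] = np.argmax(logits + rng.gumbel(size=logits.shape), axis=-1)
    return codes

calls = []

def toy_prior(codes, k=4):
    """Toy prior that ignores its input and strongly favours index 2."""
    calls.append(1)
    logits = np.zeros(codes.shape + (k,))
    logits[..., 2] = 50.0
    return logits

samples = sample_autoregressive(toy_prior, num_samples=3, size=7)
print(samples.shape, len(calls))  # -> (3, 7, 7) 49
```

With `SIZE = 7` this is 49 forward passes per batch of samples; the sampled maps can then be decoded into images with `vq_vae_sampler`.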

Here we consider the architecture proposed in Conditional Image Generation with PixelCNN Decoders (van den Oord et al., NeurIPS 2016), which uses *gated* and *masked* convolutions to model the fact that each pixel only depends on the previously generated context. We implement the base building block of the architecture as the following Keras pipeline:

```
# References:
# https://github.com/anantzoid/Conditional-PixelCNN-decoder/blob/master/layers.py
# https://github.com/ritheshkumar95/pytorch-vqvae
import numpy as np
import tensorflow as tf
from tensorflow import keras as K

def gate(inputs):
    """Gated activations: tanh(x) * sigmoid(y), with x and y the two halves of the channels."""
    x, y = tf.split(inputs, 2, axis=-1)
    return tf.tanh(x) * tf.sigmoid(y)

class MaskedConv2D(K.layers.Layer):
    """Masked convolution (stride 1, valid padding)."""

    def __init__(self, kernel_size, out_dim, direction, mode, **kwargs):
        super(MaskedConv2D, self).__init__(**kwargs)
        self.direction = direction  # "h" (horizontal) or "v" (vertical)
        self.mode = mode            # mask type "a" (excludes the current position) or "b"
        self.kernel_size = kernel_size
        self.out_dim = out_dim

    def build(self, input_shape):
        filter_mid_y = self.kernel_size[0] // 2
        filter_mid_x = self.kernel_size[1] // 2
        in_dim = int(input_shape[-1])
        w_shape = [self.kernel_size[0], self.kernel_size[1], in_dim, self.out_dim]
        mask_filter = np.ones(w_shape, dtype=np.float32)
        # Build the mask
        if self.direction == "h":
            mask_filter[filter_mid_y + 1:, :, :, :] = 0.
            mask_filter[filter_mid_y, filter_mid_x + 1:, :, :] = 0.
        elif self.direction == "v":
            if self.mode == 'a':
                mask_filter[filter_mid_y:, :, :, :] = 0.
            elif self.mode == 'b':
                mask_filter[filter_mid_y + 1:, :, :, :] = 0.
        if self.mode == 'a':
            mask_filter[filter_mid_y, filter_mid_x, :, :] = 0.
        # Keep the mask as a constant and apply it in call(): multiplying here would
        # (in TF2) freeze a masked copy of the initial weights and block their gradients
        self.mask = tf.constant(mask_filter)
        self.W = self.add_weight(name="W_{}".format(self.direction), shape=w_shape, trainable=True)
        self.b = self.add_weight(name="v_b", shape=[self.out_dim], trainable=True)

    def call(self, inputs):
        return tf.nn.conv2d(inputs, self.W * self.mask, strides=1, padding="VALID") + self.b

def gated_masked_conv2d(v_stack_in, h_stack_in, out_dim, kernel, mask='b', residual=True, i=0):
    """Basic Gated-PixelCNN block.
    This improves over the original PixelCNN by avoiding "blind spots", i.e. context pixels
    missing from the receptive field. It works by having two parallel stacks, for the vertical
    and horizontal directions, each masked to only see the appropriate context pixels.
    """
    # Vertical stack: (kernel // 2 + 1) x kernel filter, masked to the context above
    kernel_size = (kernel // 2 + 1, kernel)
    padding = (kernel // 2, kernel // 2)
    v_stack = K.layers.ZeroPadding2D(padding=padding, name="v_pad_{}".format(i))(v_stack_in)
    v_stack = MaskedConv2D(kernel_size, out_dim * 2, "v", mask, name="v_masked_conv_{}".format(i))(v_stack)
    # Crop to the input height, which shifts the receptive field upwards
    v_stack = v_stack[:, :int(v_stack_in.get_shape()[-3]), :, :]
    v_stack_out = K.layers.Lambda(lambda inputs: gate(inputs), name="v_gate_{}".format(i))(v_stack)
    # Horizontal stack: 1 x (kernel // 2 + 1) filter over the current row, up to the current position
    kernel_size = (1, kernel // 2 + 1)
    padding = (0, kernel // 2)
    h_stack = K.layers.ZeroPadding2D(padding=padding, name="h_pad_{}".format(i))(h_stack_in)
    h_stack = MaskedConv2D(kernel_size, out_dim * 2, "h", mask, name="h_masked_conv_{}".format(i))(h_stack)
    # Crop to the input width, which shifts the receptive field to the left
    h_stack = h_stack[:, :, :int(h_stack_in.get_shape()[-2]), :]
    # Inject the vertical-stack context into the horizontal stack with a 1x1 convolution
    h_stack_1 = K.layers.Conv2D(filters=out_dim * 2, kernel_size=1, strides=(1, 1),
                                name="v_to_h_{}".format(i))(v_stack)
    h_stack_out = K.layers.Lambda(lambda inputs: gate(inputs), name="h_gate_{}".format(i))(h_stack + h_stack_1)
    h_stack_out = K.layers.Conv2D(filters=out_dim, kernel_size=1, strides=(1, 1),
                                  name="res_conv_{}".format(i))(h_stack_out)
    # With mask 'a' (first block), a residual connection would add the current position's
    # input back into its own prediction, so it should be disabled there
    if residual:
        h_stack_out += h_stack_in
    return v_stack_out, h_stack_out
```
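To make the masks concrete, the NumPy sketch below re-implements the spatial part of the mask logic of `MaskedConv2D.build` (the helper `spatial_mask` is our own, for illustration) for the kernel shapes used by `gated_masked_conv2d` with `kernel=3`. Because the block pads symmetrically and then crops the convolution output, for an output at row $i$, column $j$, the two rows of the vertical kernel cover input rows $i-1$ and $i$, and the two columns of the horizontal kernel cover input columns $j-1$ and $j$:

```python
import numpy as np

def spatial_mask(kernel_size, direction, mode):
    """2D (spatial) part of the mask built in MaskedConv2D.build."""
    mask = np.ones(kernel_size, dtype=np.float32)
    mid_y, mid_x = kernel_size[0] // 2, kernel_size[1] // 2
    if direction == "h":
        mask[mid_y + 1:, :] = 0.0
        mask[mid_y, mid_x + 1:] = 0.0
    elif direction == "v":
        if mode == "a":
            mask[mid_y:, :] = 0.0
        elif mode == "b":
            mask[mid_y + 1:, :] = 0.0
    if mode == "a":
        mask[mid_y, mid_x] = 0.0
    return mask

kernel = 3
for mode in ("a", "b"):
    print("vertical", mode, spatial_mask((kernel // 2 + 1, kernel), "v", mode).tolist())
    print("horizontal", mode, spatial_mask((1, kernel // 2 + 1), "h", mode).tolist())
```

So mask "a" lets the vertical stack see only the row above and the horizontal stack only the left neighbour, while mask "b" also includes the current row or position. Assuming the first block uses mask "a", the vertical features at row $i$ never depend on input row $i$ itself, which is what keeps the subsequent "b" blocks causal.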

#### Training the prior

In order to train the prior, we encode every training image to obtain its discrete representation as indices in the latent codebook. This is also a good opportunity to *visualize the discrete representations learned by the encoder*. Here we notice some interesting features: for instance, all the black/background pixels are mapped to the same codeword, and so are the dense/white pixels (e.g., the ones at the center of a digit).
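Encoding the training set with the `encoder` model returned by `build_vqvae` (e.g. `encoder.predict(x_train)`, with `x_train` assumed to hold the MNIST training images) gives one `SIZE` x `SIZE` map of indices per image. A simple complementary check of how far these codes are from the uniform prior is to count how often each index is used; a NumPy sketch with a hypothetical helper and toy data:

```python
import numpy as np

def code_usage(indices, k):
    """Empirical frequency of each of the k codebook indices over all images and positions.

    indices: int array of shape (n, size, size), e.g. the encoded training set.
    """
    counts = np.bincount(np.asarray(indices).ravel(), minlength=k)
    return counts / counts.sum()

# Toy stand-in for encoded images: mostly index 0 (think "background"), plus two other codes
toy_indices = np.zeros((2, 3, 3), dtype=np.int64)
toy_indices[0, 1, 1] = 3
toy_indices[1, 0, 2] = 1
freq = code_usage(toy_indices, k=4)
print(freq)
```

On real encodings, a very peaked histogram (for instance, one index dominating because of the background) is a direct sign of the non-uniform prior that the PixelCNN has to learn.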

**Figure 4**: Visualizing categorical codes learned by the VQ-VAE encoder

Equipped with the masked gated convolutions and the training set, we can finally build our PixelCNN architecture and train the prior. The full model simply consists of a stack of masked gated convolutions, followed by two fully-connected layers that output the final prediction. The training objective is a multi-class classification one: at every position of the `SIZE` x `SIZE` map, the prior predicts a categorical distribution over the $K$ codebook indices.
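Concretely, such an objective is a cross-entropy averaged over all positions of the code maps. In Keras this would typically be `K.losses.SparseCategoricalCrossentropy(from_logits=True)` (or `from_logits=False` if the last layer applies a softmax); the exact loss used in the notebook is not shown here. A NumPy sketch of the computation, with a hypothetical function name:

```python
import numpy as np

def per_position_cross_entropy(logits, targets):
    """Mean cross-entropy over all positions of the code maps.

    logits: float array (n, size, size, k) predicted by the prior;
    targets: int array (n, size, size) of codebook indices from the encoder.
    """
    # Numerically stable log-softmax over the k codebook classes
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Log-probability assigned to the target index at each position
    picked = np.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]
    return -picked.mean()

# A prior that predicts uniform logits pays log(k) nats per position
k = 4
uniform_logits = np.zeros((2, 7, 7, k))
targets = np.zeros((2, 7, 7), dtype=np.int64)
print(per_position_cross_entropy(uniform_logits, targets), np.log(k))
```

The uniform baseline of $\log K$ nats per position is a handy reference: a trained prior should reach a clearly lower value on held-out codes.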

**Figure 5**: Training the VQ-VAE PixelCNN prior

Once again, we can check the model's ability to *reconstruct* discrete latent codes obtained from the test set:

**Figure 6**: Reconstructing discrete latent codes from the test set with the PixelCNN prior

More importantly, let’s have a look at images *generated* by sampling from the prior. As expected, they look much better than samples drawn from the uniform distribution, which means that **(i)** the discrete codes for our image distribution lie in a specific subset of the latent space and **(ii)** the PixelCNN was able to properly model a prior probability distribution on that space.

**Figure 7**: Generating samples from the learned PixelCNN prior