Convolutional VAE in Flux
In this post, we’ll take a look at variational autoencoders and demonstrate an implementation using the FashionMNIST dataset and Flux. This content follows on from my previous post, where I introduce Flux and show how it compares with TensorFlow and PyTorch.
Formulating the Variational Autoencoder
Before looking at the implementation, I’ll present a short overview of autoencoders and the differentiating features of a variational autoencoder (VAE). For a more complete description of the background and the loss function derivation, have a look at the original VAE paper: Auto-Encoding Variational Bayes.
Autoencoders
An autoencoder is a type of neural network made up of two principal components, an encoder and a decoder. The role of the encoder is to extract learnt features from the input data, \(x\), and represent them in a constrained latent space, \(z\). Ideally, this latent space, sometimes called the bottleneck layer, is a compressed representation of the underlying characteristics of the data. The decoder then generates a reconstruction of the original image, \(\hat{x}\), which aims to closely resemble the input data. If the encoder and decoder are modelled using neural networks, we can train the autoencoder to minimise the reconstruction loss between \(x\) and \(\hat{x}\).
Making them Variational
The key difference between a vanilla autoencoder and a VAE is in the treatment of the latent space, \(z\). For VAEs we model the latent space as a probability distribution, \(q(z|x)\), which approximates the true posterior over \(z\) and is regularised towards some prior, \(p(z)\). Typically this prior is the standard Gaussian \(\mathcal{N}(0, I)\). We train the encoder to learn the mean and standard deviation of \(q(z | x)\), which we then use to generate samples to feed into the decoder network. Since we still want to train our VAE using backpropagation and gradient descent, we need a mechanism for removing the sampling operation from the backpropagation path whilst still obtaining samples of \(z\). To this end, we apply the reparameterisation trick and compute our samples as \(z = \mu + \sigma \odot \epsilon\), where \(\epsilon \sim \mathcal{N}(0, I)\) is drawn independently of the network parameters.
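To make the reparameterisation trick concrete, here is a minimal sketch in plain Julia. The function name reparameterise and the rng keyword are illustrative assumptions rather than names taken from the repository, and \(\mu\) and \(\sigma\) are passed in directly instead of being produced by an encoder.

```julia
using Random

# Draw z = μ + σ ⊙ ε with ε ~ N(0, I). All of the randomness lives in ε,
# so z is a deterministic, differentiable function of μ and σ.
function reparameterise(μ, σ; rng=Random.default_rng())
    ε = randn(rng, size(μ)...)   # noise that does not depend on any parameters
    return μ .+ σ .* ε
end

# Usage: many samples with mean 3 and standard deviation 2.
rng = MersenneTwister(1)
μ = fill(3.0, 100_000)
σ = fill(2.0, 100_000)
z = reparameterise(μ, σ; rng=rng)
```

Because \(\epsilon\) is sampled outside the computation that involves \(\mu\) and \(\sigma\), gradients can flow through both of them as usual.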
As before, we define our loss function such that we minimise the reconstruction loss between the original \(x\) and the reconstruction \(\hat{x}\). An additional KL loss term is included, which penalises the model when \(q(z | x)\) deviates from \(p(z)\). Minimising this loss is equivalent to maximising the Evidence Lower Bound (ELBO). If we parameterise our encoder and decoder with the parameters \(\phi\) and \(\theta\) respectively, the objective can be written as follows.
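In the notation of the original VAE paper, a standard form of this objective for a single datapoint \(x\) is:

\[
\mathcal{L}(\theta, \phi; x) = \mathbb{E}_{q_{\phi}(z|x)}\left[\log p_{\theta}(x|z)\right] - D_{KL}\left(q_{\phi}(z|x) \,\|\, p(z)\right)
\]

The first term is the negative reconstruction loss and the second is the KL penalty. In the β-VAE variant, the KL term is additionally scaled by a hyperparameter \(\beta\).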
Building the VAE in Flux
For the remainder of this post, we move away from the theory and step through an example implementation of a Convolutional VAE using Flux. The code snippets to follow are taken from my GitHub repository, so head over there if you simply want to jump to the complete scripts. Let’s get started with learning how to leverage dataloaders to easily import the FashionMNIST dataset.
Loading the FashionMNIST dataset
FashionMNIST is an incredibly popular benchmarking dataset, made up of low-resolution greyscale images of clothes and accessories, which operates as an easy drop-in replacement for the simpler original MNIST dataset. Each image shows one of 10 possible item types, which are used as the 10 labelled classes. The full dataset, established by Zalando, is made up of a training set of 60 000 images and a test set of 10 000 images, where all images are 28 by 28 pixels. For our demonstration, we zero-pad each image to 32 by 32 pixels so that we can apply a model architecture similar to the one documented in β-VAE: Learning Basic Visual Concepts with a Constrained Variational Framework and Understanding disentangling in β-VAE.
We can define a function to do just that, and return a Flux.Data.DataLoader object which will handle the batching and shuffling of our data.
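The padding step can be sketched in plain Julia as follows. This is an illustration under stated assumptions rather than the repository code: the helper name zero_pad is hypothetical, and the images are assumed to arrive as a height × width × count array.

```julia
# Zero-pad a stack of images (height × width × count) by `pad` pixels on
# every side, so that 28×28 images become 32×32 when pad = 2.
function zero_pad(images::AbstractArray{T,3}; pad::Int=2) where {T}
    h, w, n = size(images)
    padded = zeros(T, h + 2 * pad, w + 2 * pad, n)
    padded[(pad + 1):(pad + h), (pad + 1):(pad + w), :] .= images
    return padded
end

# Usage with stand-in data: five all-white 28×28 "images".
images = ones(Float32, 28, 28, 5)
padded = zero_pad(images)

# Convolutional layers in Flux expect width × height × channels × batch.
batch = reshape(padded, 32, 32, 1, :)
```

An array shaped like batch is what would then be handed to the data loader for batching and shuffling.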
Defining the model
Next, we have a method which defines a fairly typical Convolutional VAE architecture. The encoder is defined by chaining 3 convolutional layers, each with a kernel width of 4 and 32 filters. The output from these layers is then flattened before being pushed to two dense layers with 256 neurons each. This portion of the encoder, which we call encoder_features, has two separate fully connected layers branching off it to provide us with the networks which generate the mean (\(\mu\)) and logarithmic variance (\(\log \sigma^2\)) vectors. We use the log variance, rather than the variance, so that we can leave the encoder_logvar network unconstrained and not worry about how forcing the network to only produce positive values might affect the optimisation process.
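As a quick sanity check on the encoder shapes, the sketch below tracks the spatial size of a 32 by 32 input through the three convolutional layers. The stride of 2 and padding of 1 are assumptions on my part (the text above only fixes the kernel width and the number of filters), chosen because they give a final 4 × 4 × 32 feature map, which matches the reshape(x, (4, 4, 32, :)) used in the decoder. The helper names are hypothetical.

```julia
# Standard convolution arithmetic: output size along one spatial dimension.
conv_output_size(input, kernel; stride=1, pad=0) =
    fld(input + 2 * pad - kernel, stride) + 1

# Spatial size after each of the stacked convolutional layers.
# stride = 2 and pad = 1 are assumed values, not taken from the post.
function encoder_feature_sizes(input; layers=3, kernel=4, stride=2, pad=1)
    sizes = [input]
    for _ in 1:layers
        push!(sizes, conv_output_size(sizes[end], kernel; stride=stride, pad=pad))
    end
    return sizes
end

sizes = encoder_feature_sizes(32)   # spatial size before and after each layer
flattened = sizes[end]^2 * 32       # features entering the first dense layer
```

Under these assumptions the image is halved at every layer, and the flattened vector feeding the dense layers has 4 × 4 × 32 entries.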
The decoder is defined to look like the transpose of the encoder, where we expect the input from the latent space, \(z\), to have the same dimensionality as the \(\mu\) vector produced by the encoder. Something important to note in the decoder is that we have defined a custom layer, Reshape, rather than using the operation x -> reshape(x, (4, 4, 32, :)). This custom layer can be saved and loaded using the BSON package, whereas the built-in reshape operation caused problems when I tried to forward pass a model loaded from disk.
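One way such a layer could look is sketched below in plain Julia. It is a guess at the general pattern rather than the exact definition from the repository: a small struct that stores the target shape and is callable like any other layer. Being a named type instead of an anonymous function is presumably what makes it easier to serialise.

```julia
# A callable struct that stores the target shape and reshapes its input.
struct Reshape{T<:Tuple}
    shape::T
end

# Convenience constructor so the shape can be given as separate arguments.
Reshape(dims...) = Reshape(dims)

# Calling the layer applies the stored shape; a `:` entry is inferred.
(r::Reshape)(x) = reshape(x, r.shape)

# Usage: turn a 512 × batch matrix into 4 × 4 × 32 × batch feature maps.
layer = Reshape(4, 4, 32, :)
features = layer(zeros(Float32, 512, 3))
```

The layer has no trainable parameters, so inside a Chain it only changes the shape of the activations passing through it.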
The training loop
Flux facilitates custom training loops, which are great for allowing custom progress tracking and metric logging code. The train() function below takes in the three Chain components which make up our VAE, the dataloader described above, as well as some key training parameters. These include a weight decay regularisation parameter (\(\lambda\)), a hyperparameter which controls the relative importance of disentangling factors of variation (\(\beta\)), amongst others. For each batch, we calculate the loss as defined in vae_loss() and generate a pullback from which to calculate the gradients.
Calculating the loss
Before we can train our model, we need to define the loss function vae_loss(). The method takes in our mean and logvar encoders, the decoder, the batch of images to train on, \(x\), as well as the \(\beta\) and \(\lambda\) hyperparameters. First, \(x\) is fed through the encoder to generate our mean and log variance vectors. We then sample from \(q(z|x)\) using the reparameterisation trick, where we obtain the standard deviation through log manipulation, to obtain \(z\). The reconstructed image is then generated by pushing \(z\) through the decoder. The ELBO is calculated by subtracting the reverse KL divergence from the negative reconstruction loss. Finally, the function returns the sum of the negative ELBO and an \(L_{2}\) weight decay regularisation term. As mentioned above, we actually want to maximise the ELBO, but in the context of a code implementation, it is more intuitive to minimise the negative ELBO.
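The steps above can be sketched in plain Julia, without Flux, as follows. Several details are assumptions of mine rather than facts about the repository: the name vae_loss_sketch, the use of binary cross-entropy on decoder logits as the reconstruction loss, averaging over the batch, and passing the weight arrays explicitly for the \(L_{2}\) term. The encoders and decoder are treated as ordinary functions.

```julia
using Random

# Numerically stable binary cross-entropy on logits, summed over all elements.
logit_bce(logits, x) =
    sum(@. max(logits, 0) - logits * x + log1p(exp(-abs(logits))))

# Closed-form KL(q(z|x) || N(0, I)) for a diagonal Gaussian q, summed.
kl_to_standard_normal(μ, logvar) =
    -0.5 * sum(@. 1 + logvar - μ^2 - exp(logvar))

function vae_loss_sketch(encoder_μ, encoder_logvar, decoder, x, β, λ, weights;
                         rng=Random.default_rng())
    batch_size = size(x)[end]
    μ = encoder_μ(x)
    logvar = encoder_logvar(x)
    σ = exp.(0.5 .* logvar)                 # standard deviation from log variance
    z = μ .+ σ .* randn(rng, size(μ)...)    # reparameterisation trick
    logits = decoder(z)
    reconstruction_loss = logit_bce(logits, x) / batch_size
    kl = kl_to_standard_normal(μ, logvar) / batch_size
    elbo = -reconstruction_loss - β * kl
    weight_decay = λ * sum(w -> sum(abs2, w), weights)
    return -elbo + weight_decay
end

# Usage with stand-in "networks" that ignore their input.
x = fill(0.5, 4, 3)
loss = vae_loss_sketch(x -> zeros(2, size(x, 2)), x -> zeros(2, size(x, 2)),
                       z -> zeros(4, size(z, 2)), x, 1.0, 0.1, [ones(2, 2)])
```

In this toy call the KL term vanishes because \(q(z|x)\) already equals the prior, so the loss reduces to the reconstruction term plus the weight decay penalty.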
Evaluate a trained model
That is the main modelling done! For demonstration purposes, I trained the model for 10 epochs, using Flux’s Adam optimiser with a learning rate of 0.0001, and saved it to disk. Before we can have a look at some images, let’s define a test data loader (which is very similar to the training data loader) and a function to save our images to disk.
Additionally, we define a function to pass images through the VAE to reconstruct images from the unseen test set. A key thing to note is that we apply the sigmoid activation to the reconstructed images so that they are normalised appropriately.
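To illustrate why this step is needed, the small sketch below maps raw decoder outputs to valid pixel intensities. It defines its own logistic function so that it runs without Flux, and the example values are made up.

```julia
# The logistic (sigmoid) function maps any real number into (0, 1).
logistic(x) = 1 / (1 + exp(-x))

# Stand-in decoder outputs (logits) for a tiny 2 × 2 "image".
logits = [-4.0 0.0; 2.0 6.0]

# Element-wise squashing gives intensities that can be saved as an image.
pixels = logistic.(logits)
```

Without this squashing, the decoder outputs are unbounded and cannot be interpreted directly as greyscale values.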
Now there is nothing left to do but load the trained VAE from disk, and set up a loop where we reconstruct test set images to compare with the originals.
Show me some images!
The two side-by-side images below demonstrate the reconstruction ability of the VAE on unseen data after training for 10 epochs. The images on the left are taken from the test set, while the images on the right are generated by the model. We see that the model has learnt a good enough latent representation to reconstruct the original samples to a reasonable degree of accuracy. That being said, the reconstructed images are certainly blurrier than the corresponding original images. This is a common problem in VAEs which, due to the reverse KL term in their objective, exhibit zero-forcing properties and therefore suffer from over-dispersion.
Other VAEs
It is worth mentioning that there have been numerous variations on the VAE architecture. Some interesting examples include β-VAE, \(JS^{G_{\alpha}}\)-VAEs, and NVAE. Furthermore, if you are specifically interested in disentangling in VAEs, take a look at this work I was involved in, where we investigated and contrasted a number of disentangling VAE architectures.
To cite this post:
@article{kastanos20fluxvae,
title = "Convolutional VAE in Flux",
author = "Alexandros Kastanos",
journal = "alecokas.github.io",
year = "2020",
url = "http://alecokas.github.io/julia/flux/vae/2020/07/22/convolutional-vae-in-flux.html"
}
References
[1] Innes, Mike. “Flux: Elegant machine learning with Julia.” Journal of Open Source Software 3.25 (2018): 602.
[2] Kingma, Diederik P., and Max Welling. “Auto-encoding variational bayes.” arXiv preprint arXiv:1312.6114 (2013).
[3] Burgess, Christopher P., et al. “Understanding disentangling in \(\beta\)-VAE.” arXiv preprint arXiv:1804.03599 (2018).
[4] Higgins, Irina, et al. “beta-vae: Learning basic visual concepts with a constrained variational framework.” (2016).
[5] Zhang, Mingtian, et al. “Variational f-divergence minimization.” arXiv preprint arXiv:1907.11891 (2019).
[6] Deasy, Jacob, Nikola Simidjievski, and Pietro Liò. “Constraining Variational Inference with Geometric Jensen-Shannon Divergence.” arXiv preprint arXiv:2006.10599 (2020).
[7] Vahdat, Arash, and Jan Kautz. “NVAE: A Deep Hierarchical Variational Autoencoder.” arXiv preprint arXiv:2007.03898 (2020).