Metamorph VAE

by multi-dimensional interpolation between two embeddings in the latent space of a Variational Autoencoder

Input: image1.png, image2.png.

Output: multiple frames morphing image1 into image2

Steps:

  1. train a VAE on Tiny ImageNet

  2. compute the embedding vector of each image

  3. interpolate between the two vectors in latent space

  4. decode each intermediate vector λ·z_dog + (1-λ)·z_bird into an image

You can launch the Colab notebook to replicate my results!

Below, I present the content step by step, starting from the simplest:

  1. MNIST autoencoder

  2. CIFAR variational autoencoder (VAE)

  3. Metamorph (Tiny ImageNet)

Simple things first: MNIST autoencoder

Let's start with the simplest thing: training an autoencoder on 28x28 MNIST images.

For the architecture, let's choose a few linear layers:
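A minimal PyTorch sketch of such an autoencoder (the layer sizes and latent dimension below are my assumptions, not necessarily the notebook's):

```python
import torch
import torch.nn as nn

class MnistAutoencoder(nn.Module):
    def __init__(self, latent_dim=32):
        super().__init__()
        # Encoder: flatten the 28x28 image and squeeze it into latent_dim values
        self.encoder = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 256), nn.ReLU(),
            nn.Linear(256, latent_dim),
        )
        # Decoder mirrors the encoder; Sigmoid maps pixels back into [0, 1]
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 256), nn.ReLU(),
            nn.Linear(256, 28 * 28), nn.Sigmoid(),
            nn.Unflatten(1, (1, 28, 28)),
        )

    def forward(self, x):
        return self.decoder(self.encoder(x))
```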

Here is what we get before training (noise):

Input -> Autoencoder -> Decoded

Here is what we get after training:


MNIST reconstruction: input vs decoded images

Next, let's check how the VAE stores MNIST classes inside its latent space:

Latent space of MNIST classes
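Such a plot can be produced roughly as follows; this sketch assumes a trained `model` whose `encoder` maps images to a 2-dimensional latent space, and a `test_loader` over MNIST:

```python
import torch
import matplotlib.pyplot as plt

# Encode the whole test set and color each latent point by its digit label
zs, ys = [], []
model.eval()
with torch.no_grad():
    for images, labels in test_loader:
        zs.append(model.encoder(images))
        ys.append(labels)
zs = torch.cat(zs).numpy()
ys = torch.cat(ys).numpy()

plt.scatter(zs[:, 0], zs[:, 1], c=ys, cmap="tab10", s=2)
plt.colorbar(label="digit")
plt.title("Latent space of MNIST classes")
plt.show()
```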

The variational autoencoder learned the so-called distributional hypothesis:

  • MNIST classes are grouped into clusters inside the latent space.

We will use this property to metamorph from one class to another.

Variational autoencoder (VAE) on CIFAR

Now we train a VAE on CIFAR:

CIFAR reconstruction: input vs decoded images

VAE design (pseudo-architecture; a runnable sketch follows the list):

  • input RGB (3x32x32)

  • Conv2d(c=64) + ReLU

  • Conv2d(c=128) + ReLU

  • Conv2d(c=256) + ReLU

  • Flatten + Linear

  • Reparameterization: sample z ~ N(mu, sigma)

  • Linear + Unflatten

  • ConvTranspose2d(c=256) + ReLU

  • ConvTranspose2d(c=128) + ReLU

  • ConvTranspose2d(c=64) + Sigmoid

  • output RGB (3x32x32)
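A PyTorch sketch of this design (kernel sizes, strides and the latent dimension are my assumptions):

```python
import torch
import torch.nn as nn

class CifarVAE(nn.Module):
    def __init__(self, latent_dim=128):
        super().__init__()
        # Encoder: 3x32x32 -> 256x4x4 via three stride-2 convolutions
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, 4, stride=2, padding=1), nn.ReLU(),     # 16x16
            nn.Conv2d(64, 128, 4, stride=2, padding=1), nn.ReLU(),   # 8x8
            nn.Conv2d(128, 256, 4, stride=2, padding=1), nn.ReLU(),  # 4x4
            nn.Flatten(),
        )
        self.fc_mu = nn.Linear(256 * 4 * 4, latent_dim)
        self.fc_logvar = nn.Linear(256 * 4 * 4, latent_dim)
        # Decoder mirrors the encoder with transposed convolutions
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 256 * 4 * 4),
            nn.Unflatten(1, (256, 4, 4)),
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1), nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1), nn.ReLU(),
            nn.ConvTranspose2d(64, 3, 4, stride=2, padding=1), nn.Sigmoid(),
        )

    def reparameterize(self, mu, logvar):
        # Sample z ~ N(mu, sigma^2) in a way that keeps gradients flowing
        std = torch.exp(0.5 * logvar)
        return mu + std * torch.randn_like(std)

    def forward(self, x):
        h = self.encoder(x)
        mu, logvar = self.fc_mu(h), self.fc_logvar(h)
        z = self.reparameterize(mu, logvar)
        return self.decoder(z), mu, logvar
```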

The model was trained from scratch on CIFAR.
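For reference, the standard VAE objective combines a reconstruction term with a KL penalty. A minimal training sketch, assuming `model` is the VAE above and `train_loader` yields CIFAR batches (the learning rate is an assumption):

```python
import torch
import torch.nn.functional as F

def vae_loss(recon, x, mu, logvar):
    # Reconstruction term: how well the decoded image matches the input
    recon_loss = F.binary_cross_entropy(recon, x, reduction="sum")
    # KL term: pushes the posterior N(mu, sigma^2) towards the prior N(0, I)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + kl

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
for x, _ in train_loader:
    recon, mu, logvar = model(x)
    loss = vae_loss(recon, x, mu, logvar)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```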

Results:

1. Passable accuracy

87% human accuracy (my own classification of 100 decoded samples)

2. White background

I was also surprised by how well the VAE preserves white backgrounds...

...a feature that wasn't explicitly required.

Metamorph

Finally, we train a VAE on the higher-resolution Tiny ImageNet dataset.


Using PCA, we can reduce the latent vectors to 3 dimensions and plot them as a scatterplot.
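A sketch with scikit-learn, assuming `latents` is an (N, latent_dim) NumPy array of encoded vectors and `labels` holds their class ids:

```python
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

# Project the latent vectors onto their 3 main principal components
pca = PCA(n_components=3)
z3 = pca.fit_transform(latents)

fig = plt.figure()
ax = fig.add_subplot(projection="3d")
ax.scatter(z3[:, 0], z3[:, 1], z3[:, 2], c=labels, cmap="tab20", s=2)
ax.set_title("VAE latent space (PCA to 3D)")
plt.show()
```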


The VAE was able to learn the distributional hypothesis again.

VAE latent space and some decoded vectors (dataset=TinyImageNet)

Multi-dim interpolation

The final step is to interpolate between the inputs' embeddings, creating a path λ·z_dog + (1-λ)·z_bird along which we decode the intermediate vectors into images.

What is multi-dim interpolation, and why do we use it?

Since we chose a normal distribution for our VAE, each data cluster is roughly a multi-dimensional "bubble" (think of a sphere in 3D).

In high dimensions, a standard Gaussian concentrates almost all of its mass in a thin shell around a sphere of radius √d, so in practice it is indistinguishable from a uniform distribution on that sphere.

Thus, instead of straight-line interpolation in the multi-dimensional latent space, we prefer spherical interpolation, which passes along the sphere surface where each class actually lives.

Multi-dim interpolation inside the latent space of a pre-trained VAE
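A sketch of such spherical interpolation (slerp); here `model` is the trained VAE, `z_dog` and `z_bird` are the encoded latent vectors of the two input images, and the frame count is illustrative:

```python
import torch

def slerp(z1, z2, lam):
    # Spherical interpolation: follow the great-circle arc between z1 and z2
    # instead of the straight chord, staying on the Gaussian "shell".
    z1n, z2n = z1 / z1.norm(), z2 / z2.norm()
    omega = torch.acos(torch.clamp(torch.dot(z1n, z2n), -1.0, 1.0))
    if omega.abs() < 1e-6:  # nearly parallel vectors: fall back to lerp
        return (1 - lam) * z1 + lam * z2
    return (torch.sin((1 - lam) * omega) * z1 + torch.sin(lam * omega) * z2) / torch.sin(omega)

# Decode the intermediate vectors into the metamorphosis frames
frames = []
with torch.no_grad():
    for lam in torch.linspace(0, 1, steps=16):
        z = slerp(z_dog, z_bird, lam)
        frames.append(model.decoder(z.unsqueeze(0)))
```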