In [1]:
import torch
import matplotlib.pyplot as plt

from tqdm import tqdm
from torch import nn, optim
from torch.utils.data import DataLoader
from IPython.display import clear_output
from torchvision import datasets, transforms
In [2]:
# Only this many CIFAR-10 training images are used for training below.
N_TRAIN_IMAGES = 20

cifar_full = datasets.CIFAR10(
    root="./data",
    train=True,
    download=True,
    transform=transforms.ToTensor(),  # PIL image -> float tensor in [0, 1], channels first
)

# Keep the full dataset under its own name instead of rebinding it; ``cifar`` is the subset.
cifar = torch.utils.data.Subset(cifar_full, range(N_TRAIN_IMAGES))

dataloader = DataLoader(
    cifar,
    batch_size=1,
    shuffle=True,
)
Files already downloaded and verified
In [3]:
class VariationalAutoEncoder(nn.Module):
    """Fully-connected variational autoencoder over flattened inputs.

    The encoder maps ``x`` to the parameters of a diagonal Gaussian posterior
    q(z|x): a mean ``mu`` and a *log-variance* log σ². The log-variance is
    returned in the ``sigma`` slot of ``forward``'s output, because the KL
    term of ``LossFunction`` treats it as log σ². A latent ``z`` is sampled
    with the reparameterization trick and decoded back to ``x_dim`` values in
    (0, 1) by a final sigmoid.

    Args:
        x_dim: number of features in one flattened input sample.
        hidden_dim: width of every hidden layer in encoder and decoder.
        latent_dim: dimensionality of the latent space.
    """

    def __init__(self, x_dim, hidden_dim, latent_dim=20):
        super().__init__()

        self.x_dim = x_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim

        # Four Linear + LeakyReLU layers feeding both posterior heads.
        self.encoder = nn.Sequential(
            nn.Linear(x_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.2),
        )

        # Posterior heads. NOTE: ``self.sigma`` predicts the log-variance
        # log σ², not σ; the attribute name is kept so saved state_dicts load.
        self.mu = nn.Linear(hidden_dim, latent_dim)
        self.sigma = nn.Linear(hidden_dim, latent_dim)

        # Sigmoid output keeps reconstructions in (0, 1), as binary cross-entropy requires.
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, x_dim),
            nn.Sigmoid(),
        )

    def reparameterize(self, mu, sigma):
        """Return ``mu + sigma * eps`` with ``eps ~ N(0, I)``.

        ``sigma`` is the standard deviation (non-negative), not the log-variance.
        """
        epsilon = torch.randn_like(sigma)
        return mu + sigma * epsilon

    def forward(self, x):
        """Encode ``x``, sample a latent and decode it.

        Args:
            x: input tensor whose last dimension is ``x_dim``.

        Returns:
            ``(mu, log_var, reconstructed)``: posterior mean, posterior
            log-variance and the decoded sample (same trailing size as ``x``).
        """
        hidden = self.encoder(x)

        mu = self.mu(hidden)
        log_var = self.sigma(hidden)

        # σ = exp(0.5 * log σ²) is always positive. (``exp(log_var) * 0.5``
        # would be σ²/2 and contradict the KL term used in the loss.)
        std = torch.exp(0.5 * log_var)
        z = self.reparameterize(mu, std)

        reconstructed = self.decoder(z)

        return mu, log_var, reconstructed
In [4]:
class LossFunction(nn.Module):
    """Negative evidence lower bound (ELBO) of a Gaussian-latent VAE.

    The ELBO is L(θ,φ;x) = -D_KL(q_φ(z|x)||p(z)) + E_q_φ(z|x)[log p_θ(x|z)];
    this module returns its negative, KL + reconstruction error, so it can be
    minimized directly. Both terms are summed over every element of the batch.
    """

    def forward(self, x, mu, sigma, reconstructed):
        """Compute KL(q(z|x) || N(0, I)) + BCE(reconstructed, x).

        Args:
            x: target values in [0, 1].
            mu: posterior means.
            sigma: posterior log-variances log σ² (the closed-form KL below
                relies on this convention).
            reconstructed: decoder outputs in (0, 1), same shape as ``x``.

        Returns:
            A scalar tensor.
        """
        log_var = sigma
        kl_term = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp()).sum()
        bce_term = nn.functional.binary_cross_entropy(reconstructed, x, reduction="sum")
        return kl_term + bce_term
In [5]:
epochs = 500
seed = 0
device = "cuda" if torch.cuda.is_available() else "cpu"

# Model init, DataLoader shuffling and reparameterization noise all draw from
# torch's global RNG; seeding here makes a re-run reproducible.
torch.manual_seed(seed)

vae = VariationalAutoEncoder(x_dim=32*32*3, hidden_dim=32*32, latent_dim=320).to(device)
optimizer = optim.Adam(vae.parameters(), lr=1e-4)
loss_function = LossFunction()

for epoch in range(1, epochs + 1):
    clear_output(wait=True)  # wait=True avoids flicker when redrawing each epoch

    print(f"{epoch}/{epochs}")
    epoch_loss = 0.0
    n_seen = 0
    progress = tqdm(dataloader)
    for batch, _ in progress:
        batch = batch.to(device)
        x = batch.reshape(batch.shape[0], -1)  # flatten images to (batch, x_dim)

        optimizer.zero_grad()

        mu, sigma, reconstructed = vae(x)

        loss = loss_function(x, mu, sigma, reconstructed)
        loss.backward()
        optimizer.step()

        # The loss is summed over the batch, so divide by images seen for a per-image mean.
        epoch_loss += loss.item()
        n_seen += x.shape[0]
        progress.set_postfix(loss=epoch_loss / n_seen)
    
500/500
100%|██████████| 20/20 [00:00<00:00, 111.03it/s]
In [6]:
def visualize(vae, dataloader, device, num_images=10):
    """Plot originals (top row) against their VAE reconstructions (bottom row).

    The first image of each of the first ``num_images`` batches is shown.
    Reconstructions come from the full forward pass, so they include the
    random latent sample drawn in ``vae.forward``. The model's previous
    train/eval mode is restored afterwards, even if plotting fails.

    Args:
        vae: model whose forward returns ``(mu, sigma, reconstructed)``.
        dataloader: yields ``(images, labels)``; images are assumed to be
            channels-first (batch, C, H, W) in [0, 1] — TODO confirm for non-CIFAR data.
        device: device the model lives on.
        num_images: number of original/reconstruction columns to plot.
    """
    was_training = vae.training
    vae.eval()

    try:
        # squeeze=False keeps ``axes`` 2-D even when num_images == 1.
        _, axes = plt.subplots(2, num_images, figsize=(2 * num_images, 4), squeeze=False)
        for ax in axes.flat:
            ax.axis("off")  # columns the dataloader cannot fill stay blank

        for i, (batch, _) in enumerate(dataloader):
            if i >= num_images:
                break

            original = batch.cpu()

            with torch.no_grad():
                x = batch.reshape(batch.shape[0], -1).to(device)
                _, _, reconstructed = vae(x)
                # Restore the input's per-image shape instead of hard-coding (3, 32, 32).
                reconstructed = reconstructed.reshape(-1, *batch.shape[1:]).cpu()

            axes[0, i].imshow(original[0].permute(1, 2, 0).clamp(0, 1))
            axes[1, i].imshow(reconstructed[0].permute(1, 2, 0).clamp(0, 1))

        plt.tight_layout()
        plt.show()
    finally:
        vae.train(was_training)
    
visualize(vae=vae, dataloader=dataloader, device=device)  # originals vs. reconstructions
Variational Autoencoder — plot 1
In [7]:
def generate_samples(vae, n_samples, device, image_shape=(3, 32, 32)):
    """Decode ``n_samples`` latents drawn from the prior N(0, I) and plot them in one row.

    Args:
        vae: model exposing ``latent_dim`` and a ``decoder`` module.
        n_samples: number of images to generate and plot.
        device: device the model lives on.
        image_shape: channels-first shape to reshape each decoded vector into;
            the default matches 32x32 RGB images.

    The model's previous train/eval mode is restored afterwards.
    """
    was_training = vae.training
    vae.eval()

    try:
        with torch.no_grad():
            z = torch.randn(n_samples, vae.latent_dim).to(device)
            generated = vae.decoder(z).reshape(-1, *image_shape)

        # squeeze=False keeps ``axes`` indexable as a 2-D grid even for n_samples == 1.
        _, axes = plt.subplots(1, n_samples, figsize=(2 * n_samples, 2), squeeze=False)

        for ax, image in zip(axes[0], generated):
            ax.imshow(image.cpu().permute(1, 2, 0).clamp(0, 1))
            ax.axis("off")

        plt.tight_layout()
        plt.show()
    finally:
        vae.train(was_training)
    
generate_samples(vae, n_samples=10, device=device)  # decode 10 draws from the prior
Variational Autoencoder — plot 2
In [8]:
def interpolate_two_images(vae, dataloader, device=None, n_steps=10):
    """Decode a straight latent-space path between the first two images of ``dataloader``.

    Each image is encoded to its posterior mean ``mu`` rather than a random
    sample. The path is therefore deterministic and does not depend on how the
    model parameterizes its variance. ``n_steps`` evenly spaced points from
    the first mean to the second (both inclusive) are decoded and plotted in
    one row.

    Args:
        vae: model whose forward returns ``(mu, sigma, reconstructed)`` and
            which exposes a ``decoder`` module.
        dataloader: yields ``(images, labels)`` batches of any batch size;
            images are assumed channels-first in [0, 1].
        device: device to run on; defaults to the device of the model's parameters.
        n_steps: number of interpolation points (1 shows only the first image's decoding).

    Raises:
        ValueError: if the dataloader yields fewer than two images.
    """
    if device is None:
        device = next(vae.parameters()).device

    # Collect the first two individual images regardless of the loader's batch size.
    images = []
    for batch, _ in dataloader:
        images.extend(batch)
        if len(images) >= 2:
            break
    if len(images) < 2:
        raise ValueError("dataloader must provide at least two images")
    image_shape = images[0].shape

    was_training = vae.training
    vae.eval()

    try:
        with torch.no_grad():
            x = torch.stack(images[:2]).reshape(2, -1).to(device)
            mu, _, _ = vae(x)

            # linspace handles n_steps == 1 (no division by n_steps - 1).
            alphas = torch.linspace(0.0, 1.0, n_steps, device=mu.device).unsqueeze(1)
            path = mu[0] + (mu[1] - mu[0]) * alphas  # (n_steps, latent_dim)
            generated = vae.decoder(path).reshape(-1, *image_shape).cpu()

        _, axes = plt.subplots(1, n_steps, figsize=(2 * n_steps, 2), squeeze=False)
        for ax, image in zip(axes[0], generated):
            ax.imshow(image.permute(1, 2, 0).clamp(0, 1))
            ax.axis("off")

        plt.tight_layout()
        plt.show()
    finally:
        vae.train(was_training)
    
# Pass the model's device explicitly: relying on a hard-coded "cuda" default fails on CPU-only machines.
interpolate_two_images(vae, dataloader, device=device)
Variational Autoencoder — plot 3