import torch
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch import nn, optim
from torch.utils.data import DataLoader
from IPython.display import clear_output
from torchvision import datasets, transforms
# CIFAR-10 training split (downloaded to ./data if missing), each image
# converted with transforms.ToTensor(), then restricted to its first 20
# samples (indices 0..19).
cifar = torch.utils.data.Subset(
    datasets.CIFAR10(
        root="./data",
        train=True,
        download=True,
        transform=transforms.ToTensor(),
    ),
    range(20),
)
# One image per batch, with the order reshuffled on every pass.
dataloader = DataLoader(cifar, batch_size=1, shuffle=True)
# [notebook output] Files already downloaded and verified
class VariationalAutoEncoder(nn.Module):
    """Fully connected variational autoencoder for flattened inputs.

    The encoder maps a flattened input of size ``x_dim`` through a stack of
    LeakyReLU layers. Two linear heads then produce the mean (``mu``) and the
    log-variance (stored in the ``sigma`` head) of a diagonal Gaussian over a
    ``latent_dim``-dimensional latent code. The decoder maps a latent code back
    to ``x_dim`` values squashed into (0, 1) by a final sigmoid.

    Args:
        x_dim: Size of the flattened input.
        hidden_dim: Width of every hidden layer.
        latent_dim: Size of the latent code.
    """

    def __init__(self, x_dim, hidden_dim, latent_dim=20):
        super().__init__()
        self.x_dim = x_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.encoder = nn.Sequential(
            nn.Linear(x_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.2),
        )
        self.mu = nn.Linear(hidden_dim, latent_dim)
        # Despite its name, this head predicts the log-variance log(sigma^2);
        # it is unconstrained, and exponentiating it yields a positive variance.
        self.sigma = nn.Linear(hidden_dim, latent_dim)
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, x_dim),
            nn.Sigmoid(),
        )

    def reparameterize(self, mu, sigma):
        """Sample ``z = mu + sigma * eps`` with ``eps ~ N(0, I)``.

        Args:
            mu: Mean of the Gaussian.
            sigma: Standard deviation (not the log-variance), same shape as ``mu``.

        Returns:
            A differentiable sample with the same shape as ``mu``.
        """
        epsilon = torch.randn_like(sigma)
        return mu + sigma * epsilon

    def forward(self, x):
        """Encode ``x``, sample a latent code and decode it.

        Args:
            x: Flattened inputs of shape (batch, x_dim).

        Returns:
            Tuple ``(mu, sigma, reconstructed)``. Here ``sigma`` is the
            log-variance of the posterior, and ``reconstructed`` has shape
            (batch, x_dim) with values in (0, 1).
        """
        hidden = self.encoder(x)
        mu = self.mu(hidden)
        log_var = self.sigma(hidden)
        # The standard deviation is exp(0.5 * log_var). Note that
        # exp(log_var) * 0.5 would be sigma^2 / 2, not sigma.
        std = torch.exp(0.5 * log_var)
        z = self.reparameterize(mu, std)
        reconstructed = self.decoder(z)
        return mu, log_var, reconstructed
class LossFunction(nn.Module):
    """Negative ELBO of the VAE, summed over every element of the batch.

    The evidence lower bound is
    L(θ,φ;x) = -D_KL(q_φ(z|x)||p(z)) + E_q_φ(z|x)[log p_θ(x|z)].
    This module returns the quantity to *minimize*: the closed-form Gaussian
    KL divergence plus the summed binary cross-entropy between the
    reconstruction and the input.
    """

    def forward(self, x, mu, sigma, reconstructed):
        """Return ``KL + BCE`` as a scalar tensor.

        Args:
            x: Target values, same shape as ``reconstructed``.
            mu: Posterior means.
            sigma: Posterior log-variances (``sigma.exp()`` is used as the variance).
            reconstructed: Decoder outputs; binary_cross_entropy requires them in [0, 1].
        """
        # KL(N(mu, exp(sigma)) || N(0, I)), summed over all elements.
        kl_term = -0.5 * (1 + sigma - mu ** 2 - sigma.exp()).sum()
        recon_term = nn.functional.binary_cross_entropy(
            reconstructed, x, reduction="sum"
        )
        return kl_term + recon_term
# Training configuration.
epochs = 500
device = "cuda" if torch.cuda.is_available() else "cpu"
vae = VariationalAutoEncoder(x_dim=32*32*3, hidden_dim=32*32, latent_dim=320).to(device)
optimizer = optim.Adam(vae.parameters(), lr=1e-4)
loss_function = LossFunction()

# One optimizer step per batch; the mean batch loss is printed after each epoch
# so training progress is visible (the loss was previously never reported).
vae.train()
for epoch in range(1, epochs + 1):
    clear_output()
    print(f"{epoch}/{epochs}")
    running_loss = 0.0
    n_batches = 0
    for batch, _ in tqdm(dataloader):  # class labels are not used by the VAE
        batch = batch.to(device)
        # Flatten each image to a vector of length C*H*W for the MLP encoder.
        x = batch.reshape(batch.shape[0], -1)
        optimizer.zero_grad()
        mu, sigma, reconstructed = vae(x)
        loss = loss_function(x, mu, sigma, reconstructed)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        n_batches += 1
    # max(..., 1) guards against an empty dataloader.
    print(f"mean batch loss: {running_loss / max(n_batches, 1):.4f}")
# [notebook output] 500/500
# [notebook output] 100%|██████████| 20/20 [00:00<00:00, 111.03it/s]
def visualize(vae, dataloader, device, num_images=10):
    """Plot originals (top row) above their VAE reconstructions (bottom row).

    Takes the first image of each of the first ``num_images`` batches yielded
    by ``dataloader``. The model runs in eval mode without gradients and is
    restored to its previous train/eval mode afterwards, even on error.

    Args:
        vae: Model whose forward returns ``(mu, sigma, reconstructed)`` with a
            flattened ``reconstructed``.
        dataloader: Iterable of ``(images, labels)`` batches. Images are assumed
            to be (B, C, H, W) tensors with C == 3 for imshow; confirm for other data.
        device: Device on which to run the model.
        num_images: Number of columns (batches) to plot; nothing is drawn if <= 0.
    """
    if num_images <= 0:
        return
    was_training = vae.training
    vae.eval()
    try:
        # squeeze=False keeps `axes` 2-D even when num_images == 1.
        _, axes = plt.subplots(
            2, num_images, figsize=(2 * num_images, 4), squeeze=False
        )
        # Hide every axis up front so that columns left empty by a short
        # dataloader do not show blank frames.
        for ax in axes.flat:
            ax.axis("off")
        with torch.no_grad():
            for i, (batch, _) in enumerate(dataloader):
                if i >= num_images:
                    break
                x = batch.reshape(batch.shape[0], -1).to(device)
                _, _, reconstructed = vae(x)
                # Restore the input's image shape instead of hard-coding it.
                reconstructed = reconstructed.reshape(batch.shape).cpu()
                axes[0, i].imshow(batch[0].cpu().permute(1, 2, 0).clamp(0, 1))
                axes[1, i].imshow(reconstructed[0].permute(1, 2, 0).clamp(0, 1))
        plt.tight_layout()
        plt.show()
    finally:
        vae.train(was_training)
# Compare a few training images with their reconstructions.
visualize(vae, dataloader, device=device)
def generate_samples(vae, n_samples, device, image_shape=(3, 32, 32)):
    """Decode latent vectors drawn from N(0, I) and plot them in a single row.

    The model runs in eval mode without gradients and is restored to its
    previous train/eval mode afterwards, even on error.

    Args:
        vae: Model exposing ``latent_dim`` and a ``decoder`` that maps latent
            vectors to flattened images.
        n_samples: Number of images to generate; nothing is drawn if <= 0.
        device: Device on which to sample and decode.
        image_shape: (C, H, W) shape each decoded vector is reshaped to.
    """
    if n_samples <= 0:
        return
    was_training = vae.training
    vae.eval()
    try:
        with torch.no_grad():
            z = torch.randn(n_samples, vae.latent_dim, device=device)
            generated = vae.decoder(z).reshape(-1, *image_shape).cpu()
        # squeeze=False keeps `axes` indexable even when n_samples == 1.
        _, axes = plt.subplots(
            1, n_samples, figsize=(2 * n_samples, 2), squeeze=False
        )
        for ax, image in zip(axes[0], generated):
            ax.imshow(image.permute(1, 2, 0).clamp(0, 1))
            ax.axis("off")
        plt.tight_layout()
        plt.show()
    finally:
        vae.train(was_training)
# Draw ten images from the prior.
generate_samples(vae, n_samples=10, device=device)
def interpolate_two_images(vae, dataloader, device=None, n_steps=10):
    """Decode a straight line in latent space between two dataset images.

    The first two images yielded by ``dataloader`` are encoded, and one
    latent sample is drawn for each. Then ``n_steps`` evenly spaced points on
    the segment from the first sample to the second (endpoints included) are
    decoded and plotted in a single row. The model is restored to its
    previous train/eval mode afterwards, even on error.

    Args:
        vae: Model whose forward returns ``(mu, log_var, reconstructed)`` and
            which exposes ``reparameterize(mu, std)`` and ``decoder``.
        dataloader: Iterable of ``(images, labels)`` batches of any batch size.
        device: Device to run on; defaults to the device of the model's
            parameters.
        n_steps: Number of interpolation points (must be >= 1).

    Raises:
        ValueError: If ``n_steps < 1`` or the dataloader yields fewer than
            two images.
    """
    if n_steps < 1:
        raise ValueError("n_steps must be at least 1")
    if device is None:
        device = next(vae.parameters()).device

    # Split batches into individual images so batch_size > 1 also works.
    images = []
    for batch, _ in dataloader:
        images.extend(batch)
        if len(images) >= 2:
            break
    if len(images) < 2:
        raise ValueError("dataloader must yield at least two images")
    image_shape = images[0].shape
    x1 = images[0].reshape(1, -1).to(device)
    x2 = images[1].reshape(1, -1).to(device)

    was_training = vae.training
    vae.eval()
    try:
        plt.figure(figsize=(2 * n_steps, 2))
        with torch.no_grad():
            mu1, log_var1, _ = vae(x1)
            mu2, log_var2, _ = vae(x2)
            # The second output is a log-variance (the loss treats it so), so
            # the standard deviation is exp(0.5 * log_var), not exp(log_var) * 0.5.
            z1 = vae.reparameterize(mu1, torch.exp(0.5 * log_var1))
            z2 = vae.reparameterize(mu2, torch.exp(0.5 * log_var2))
            for i in range(n_steps):
                # A single step would otherwise divide by zero.
                alpha = i / (n_steps - 1) if n_steps > 1 else 0.0
                z = z1 + (z2 - z1) * alpha
                generated = vae.decoder(z).reshape(-1, *image_shape)
                plt.subplot(1, n_steps, i + 1)
                plt.imshow(generated[0].cpu().permute(1, 2, 0).clamp(0, 1))
                plt.axis("off")
        plt.tight_layout()
        plt.show()
    finally:
        vae.train(was_training)
# Pass the device explicitly: the function's original default was "cuda",
# which fails on CPU-only machines.
interpolate_two_images(vae, dataloader, device=device)