[DL] VAE

Author

김보람

Published

November 27, 2023

!pip install torcheval

import torch
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from mpl_toolkits.axes_grid1 import ImageGrid
from torchvision.utils import save_image, make_grid

is_cuda = torch.cuda.is_available()
print(is_cuda)
device = torch.device('cuda' if is_cuda else 'cpu')
print('Current cuda device is', device)
Collecting torcheval
  Downloading torcheval-0.0.7-py3-none-any.whl (179 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 179.2/179.2 kB 3.7 MB/s eta 0:00:00 0:00:01
Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torcheval) (4.5.0)
Installing collected packages: torcheval
Successfully installed torcheval-0.0.7
True
Current cuda device is cuda
!nvcc --version
print("Torch version:{}".format(torch.__version__))
print("cuda version: {}".format(torch.version.cuda))
print("cudnn version:{}".format(torch.backends.cudnn.version()))
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2022 NVIDIA Corporation
Built on Wed_Sep_21_10:33:58_PDT_2022
Cuda compilation tools, release 11.8, V11.8.89
Build cuda_11.8.r11.8/compiler.31833905_0
Torch version:2.1.0+cu118
cuda version: 11.8
cudnn version:8700
transform = transforms.Compose([transforms.ToTensor()])

# download the MNIST datasets
path = '~/datasets'
train_dataset = MNIST(path, transform=transform, train=True, download=True)
test_dataset  = MNIST(path, transform=transform, train=False, download=True)

# create train and test dataloaders
batch_size = 100
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /root/datasets/MNIST/raw/train-images-idx3-ubyte.gz
Extracting /root/datasets/MNIST/raw/train-images-idx3-ubyte.gz to /root/datasets/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /root/datasets/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting /root/datasets/MNIST/raw/train-labels-idx1-ubyte.gz to /root/datasets/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /root/datasets/MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting /root/datasets/MNIST/raw/t10k-images-idx3-ubyte.gz to /root/datasets/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /root/datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting /root/datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz to /root/datasets/MNIST/raw
100%|██████████| 9912422/9912422 [00:00<00:00, 136550644.38it/s]
100%|██████████| 28881/28881 [00:00<00:00, 119936330.52it/s]
100%|██████████| 1648877/1648877 [00:00<00:00, 204267696.39it/s]
100%|██████████| 4542/4542 [00:00<00:00, 2259849.20it/s]
images, labels = next(iter(train_loader))

num_samples = 25
sample_images = [images[i, 0] for i in range(num_samples)]  # first 25 images of the batch

fig = plt.figure(figsize=(5, 5))
grid = ImageGrid(fig, 111, nrows_ncols=(5, 5), axes_pad=0.1)

for ax, im in zip(grid, sample_images):
    ax.imshow(im, cmap='gray')
    ax.axis('off')

plt.show()

class VAE(nn.Module):

    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=200, device=device):
        super(VAE, self).__init__()

        # encoder
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, latent_dim),
            nn.LeakyReLU(0.2)
            )

        # latent mean and log-variance (the latent space itself is 2-D so it can be
        # visualized directly; latent_dim acts as a second hidden size)
        self.mean_layer = nn.Linear(latent_dim, 2)
        self.logvar_layer = nn.Linear(latent_dim, 2)

        # decoder
        self.decoder = nn.Sequential(
            nn.Linear(2, latent_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(latent_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid()
            )

    def encode(self, x):
        x = self.encoder(x)
        mean, logvar = self.mean_layer(x), self.logvar_layer(x)
        return mean, logvar

    def reparameterization(self, mean, logvar):
        # reparameterization trick: z = mean + std * eps, with std = exp(0.5 * logvar)
        epsilon = torch.randn_like(logvar).to(device)
        z = mean + torch.exp(0.5 * logvar) * epsilon
        return z

    def decode(self, x):
        return self.decoder(x)

    def forward(self, x):
        mean, log_var = self.encode(x)
        z = self.reparameterization(mean, log_var)
        x_hat = self.decode(z)
        return x_hat, mean, log_var
model = VAE().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
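A quick sanity check of the wiring: one forward pass on a dummy batch should return a 784-dimensional reconstruction together with a 2-dimensional mean and log-variance per sample.

with torch.no_grad():
    dummy = torch.rand(4, 784).to(device)              # 4 fake flattened 28x28 images
    x_hat, mean, log_var = model(dummy)
print(x_hat.shape, mean.shape, log_var.shape)          # expected: [4, 784], [4, 2], [4, 2]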
def loss_function(x, x_hat, mean, log_var):
    reproduction_loss = nn.functional.binary_cross_entropy(x_hat, x, reduction='sum')
    KLD = - 0.5 * torch.sum(1+ log_var - mean.pow(2) - log_var.exp())

    return reproduction_loss + KLD
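The loss is the negative ELBO: a sum-reduced binary cross-entropy reconstruction term plus the closed-form KL divergence between the approximate posterior N(mean, exp(log_var)) and the standard normal prior. One way to sanity-check the closed form (a small sketch, relying only on torch.distributions, not on anything else in the notebook) is to compare it against kl_divergence:

from torch.distributions import Normal, kl_divergence

m, lv = torch.randn(3, 2), torch.randn(3, 2)            # toy mean and log-variance
kld_closed = -0.5 * torch.sum(1 + lv - m.pow(2) - lv.exp())
kld_dist = kl_divergence(Normal(m, (0.5 * lv).exp()), Normal(0.0, 1.0)).sum()
print(torch.allclose(kld_closed, kld_dist, atol=1e-5))   # expected: True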
x_dim = 784

def train(model, optimizer, epochs, device):
    model.train()
    for epoch in range(epochs):
        overall_loss = 0
        for batch_idx, (x, _) in enumerate(train_loader):
            x = x.view(x.size(0), x_dim).to(device)

            optimizer.zero_grad()

            x_hat, mean, log_var = model(x)
            loss = loss_function(x, x_hat, mean, log_var)

            overall_loss += loss.item()

            loss.backward()
            optimizer.step()

        print("\tEpoch", epoch + 1, "\tAverage Loss: ", overall_loss/(batch_idx*batch_size))
    return overall_loss

train(model, optimizer, epochs=50, device=device)
    Epoch 1     Average Loss:  175.29515954324916
    Epoch 2     Average Loss:  157.1823385981845
    Epoch 3     Average Loss:  152.55317100766902
    Epoch 4     Average Loss:  149.62154792492697
    Epoch 5     Average Loss:  147.50248990831074
    Epoch 6     Average Loss:  145.86534660632304
    Epoch 7     Average Loss:  144.4231314560726
    Epoch 8     Average Loss:  143.31780508203778
    Epoch 9     Average Loss:  142.30168281771702
    Epoch 10    Average Loss:  141.49097752438962
    Epoch 11    Average Loss:  140.83598308378546
    Epoch 12    Average Loss:  140.17502820455968
    Epoch 13    Average Loss:  139.8000655063126
    Epoch 14    Average Loss:  139.30528065982367
    Epoch 15    Average Loss:  138.87654218619573
    Epoch 16    Average Loss:  138.48964087280885
    Epoch 17    Average Loss:  138.1048076383817
    Epoch 18    Average Loss:  137.74657612948664
    Epoch 19    Average Loss:  137.45947488979027
    Epoch 20    Average Loss:  137.23071929778797
    Epoch 21    Average Loss:  136.92694877204195
    Epoch 22    Average Loss:  136.66642131416944
    Epoch 23    Average Loss:  136.39156210872287
    Epoch 24    Average Loss:  136.2546355905676
    Epoch 25    Average Loss:  136.01997696355906
    Epoch 26    Average Loss:  135.8481364611592
    Epoch 27    Average Loss:  135.6641933985027
    Epoch 28    Average Loss:  135.42658317247495
    Epoch 29    Average Loss:  135.3036438660789
    Epoch 30    Average Loss:  135.2169284745409
    Epoch 31    Average Loss:  134.92463263968594
    Epoch 32    Average Loss:  134.92840655650042
    Epoch 33    Average Loss:  134.6508390122861
    Epoch 34    Average Loss:  134.68385046040277
    Epoch 35    Average Loss:  134.47515039714628
    Epoch 36    Average Loss:  134.2281339178579
    Epoch 37    Average Loss:  134.29903941464943
    Epoch 38    Average Loss:  134.0389246074186
    Epoch 39    Average Loss:  133.95940814443344
    Epoch 40    Average Loss:  133.74494802535474
    Epoch 41    Average Loss:  133.6946707076899
    Epoch 42    Average Loss:  133.62954862922578
    Epoch 43    Average Loss:  133.38309982783807
    Epoch 44    Average Loss:  133.46671858696786
    Epoch 45    Average Loss:  133.28396186026188
    Epoch 46    Average Loss:  133.149587479784
    Epoch 47    Average Loss:  133.28342907123852
    Epoch 48    Average Loss:  132.93941725792988
    Epoch 49    Average Loss:  132.92503653550187
    Epoch 50    Average Loss:  132.90003576925085
7960712.142578125
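The test_loader defined earlier is not used during training; a minimal evaluation sketch, assuming the same flattening and loss as the training loop, reports the average loss on the held-out test set:

model.eval()
test_loss = 0.0
with torch.no_grad():
    for x, _ in test_loader:
        x = x.view(x.size(0), x_dim).to(device)
        x_hat, mean, log_var = model(x)
        test_loss += loss_function(x, x_hat, mean, log_var).item()
print("Test average loss:", test_loss / len(test_loader.dataset))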
def generate_digit(z1, z2):
    # decode a single hand-picked point (z1, z2) from the 2-D latent space
    z_sample = torch.tensor([[z1, z2]], dtype=torch.float).to(device)
    print(z_sample)
    x_decoded = model.decode(z_sample)
    digit = x_decoded.detach().cpu().reshape(28, 28)  # reshape vector to 2d array
    plt.imshow(digit, cmap='gray')
    plt.axis('off')
    plt.show()

generate_digit(0.7,-1.0)
tensor([[ 0.7000, -1.0000]], device='cuda:0')
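generate_digit decodes one hand-picked latent point. Because the prior over z is a standard normal, latent codes can also be sampled directly; the sketch below (assuming the make_grid helper imported at the top) decodes a small grid of random samples:

model.eval()
with torch.no_grad():
    z = torch.randn(16, 2).to(device)                       # 16 samples from the N(0, I) prior
    samples = model.decode(z).cpu().reshape(-1, 1, 28, 28)

plt.imshow(make_grid(samples, nrow=4).permute(1, 2, 0), cmap='gray')
plt.axis('off')
plt.show()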

def plot_latent_space(model, scale=1.0, n=25, digit_size=28, figsize=15):
    # display an n x n 2D manifold of decoded digits
    figure = np.zeros((digit_size * n, digit_size * n))

    # construct a grid
    grid_x = np.linspace(-scale, scale, n)
    grid_y = np.linspace(-scale, scale, n)[::-1]

    for i, yi in enumerate(grid_y):
        for j, xi in enumerate(grid_x):
            z_sample = torch.tensor([[xi, yi]], dtype=torch.float).to(device)
            x_decoded = model.decode(z_sample)
            digit = x_decoded[0].detach().cpu().reshape(digit_size, digit_size)
            figure[i * digit_size : (i + 1) * digit_size, j * digit_size : (j + 1) * digit_size,] = digit

    plt.figure(figsize=(figsize, figsize))
    plt.title('VAE Latent Space Visualization')
    start_range = digit_size // 2
    end_range = n * digit_size + start_range
    pixel_range = np.arange(start_range, end_range, digit_size)
    sample_range_x = np.round(grid_x, 1)
    sample_range_y = np.round(grid_y, 1)
    plt.xticks(pixel_range, sample_range_x)
    plt.yticks(pixel_range, sample_range_y)
    plt.xlabel("mean, z [0]")
    plt.ylabel("var, z [1]")
    plt.imshow(figure, cmap="Greys_r")
    plt.show()

plot_latent_space(model)
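As a closing check, a short sketch that reconstructs a few test images next to their inputs shows how much detail survives the 2-D latent bottleneck:

model.eval()
x, _ = next(iter(test_loader))
with torch.no_grad():
    x_hat, _, _ = model(x.view(x.size(0), x_dim).to(device))

fig, axes = plt.subplots(2, 8, figsize=(8, 2))
for k in range(8):
    axes[0, k].imshow(x[k, 0], cmap='gray')                          # original
    axes[1, k].imshow(x_hat[k].cpu().reshape(28, 28), cmap='gray')   # reconstruction
    axes[0, k].axis('off')
    axes[1, k].axis('off')
plt.show()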