I am new to PyTorch and am trying to implement a VAE for the MNIST data. When I train the model, it appears to force mu and logvar to zero (or something very close to zero) regardless of the input. It looks as if the MSE part of the loss function is effectively being ignored, but I don't understand why.
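For context, the loss I am computing is the usual ELBO-style objective: a per-sample MSE reconstruction term plus the closed-form KL divergence between the approximate posterior $\mathcal{N}(\mu, \sigma^2)$ and the standard normal prior. With logvar $= \log\sigma^2$, the KL term is

$$\mathrm{KL}\big(\mathcal{N}(\mu, \sigma^2)\,\|\,\mathcal{N}(0, I)\big) = \tfrac{1}{2}\sum_j\big(\exp(\mathrm{logvar}_j) + \mu_j^2 - 1 - \mathrm{logvar}_j\big),$$

which is what vae_loss below is meant to implement.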
Here's the complete code I am using:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch import nn
import torch.nn.functional as F
from torch import optim
batch_size = 32
loc_data = 'MNIST'
transformations = transforms.ToTensor()
mnist_train = datasets.MNIST(loc_data, train=True, download=True, transform = transformations)
mnist_test = datasets.MNIST(loc_data, train=False, download=True, transform = transformations)
train_loader = DataLoader(mnist_train,
                          batch_size=batch_size,
                          drop_last=True,
                          shuffle=True)
test_loader = DataLoader(mnist_test,
                         batch_size=batch_size,
                         drop_last=True,
                         shuffle=True)
class Encoder(nn.Module):
    def __init__(self, latent_dim=10):
        super(Encoder, self).__init__()
        self.latent_dim = latent_dim
        # The last layer outputs 2*latent_dim values: latent_dim for mu and latent_dim for logvar
        self._encoder = nn.Sequential(nn.Linear(in_features=28*28, out_features=512),
                                      nn.ReLU(),
                                      nn.Linear(in_features=512, out_features=2*latent_dim))

    def forward(self, x):
        x = torch.reshape(self._encoder(x), (-1, 2, self.latent_dim))
        mu, logvar = x[:, 0, :], x[:, 1, :]
        return mu, logvar
class Decoder(nn.Module):
    def __init__(self, latent_dim=10):
        super(Decoder, self).__init__()
        self.latent_dim = latent_dim
        self._decoder = nn.Sequential(nn.Linear(in_features=latent_dim, out_features=512),
                                      nn.ReLU(),
                                      nn.Linear(in_features=512, out_features=28*28),
                                      nn.Sigmoid())

    def forward(self, x):
        return self._decoder(x)
def sample(mu, logvar):
    # Reparameterization trick: z = mu + sigma * eps, with eps ~ N(0, I)
    z = torch.randn_like(mu)
    return mu + torch.mul(torch.exp(0.5 * logvar), z)

def vae_loss(x, x_hat, mu, logvar):
    # Reconstruction term: MSE summed over pixels, averaged over the batch
    mse = (x - x_hat).pow(2).sum() / (x.shape[0] * 1.0)
    # Closed-form KL divergence between N(mu, sigma^2) and the standard normal prior
    KL_loss = 0.5 * torch.sum(-1 + torch.pow(mu, 2) - logvar + torch.exp(logvar))
    return torch.add(mse, KL_loss)
def train(encoder, decoder, train_loader, optimizer, num_epochs=10):
    encoder.train()
    decoder.train()
    for ii in range(num_epochs):
        print("Epoch {}".format(ii))
        for jj, (x, y) in enumerate(train_loader):
            x = torch.reshape(x, (-1, 28*28))
            x = x.to(device)
            _mu, _logvar = encoder(x)
            _z = sample(_mu, _logvar)
            x_hat = decoder(_z)  # .reshape((-1, 28, 28))
            optimizer.zero_grad()
            loss = vae_loss(x, x_hat, _mu, _logvar)
            loss.backward()
            optimizer.step()
            if jj % 100 == 0:
                print(loss)
    return loss
latent_dim = 20
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
encoder = Encoder(latent_dim).to(device)
decoder = Decoder(latent_dim).to(device)
params = list(encoder.parameters()) + list(decoder.parameters())
optimizer = optim.Adam(params, lr=1e-2)
train(encoder, decoder, train_loader, optimizer, num_epochs=1)
When I probe mu or logvar for some test data, the result is almost identically zero.
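In case it matters, this is roughly how I check the encoder outputs (a minimal sketch of my probing code, not part of the listing above):

# Rough sketch of how I inspect mu/logvar on a test batch
encoder.eval()
with torch.no_grad():
    x_test, _ = next(iter(test_loader))
    x_test = torch.reshape(x_test, (-1, 28*28)).to(device)
    mu, logvar = encoder(x_test)
    print(mu.abs().max(), logvar.abs().max())  # both print values very close to zero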