变分自编码器VAE实现MNIST数据集生成by Pytorch

参考:
原文:Auto-Encoding Variational Bayes

Recap自编码器:

自编码器中,需要输入一个原始图片,原始图片经过编码之后得到一个隐向量,隐向量解码产生原图片对应的图片。在这种情况下,只能生成原图片对应的图片而无法任意生成新的图片,因为隐向量都是原始图片确定的。

变分自编码器VAE

引入变分自编码器(Variational autoencoder)可以在遵循某一分布下随机产生一些隐向量来生成与原始图片不相同的图片,而不需要预先给定原始图片。为了达到这个目的,需要在编码过程增加限制,使得生成的隐向量能够粗略地遵循标准正态分布。
实际情况下,需要在模型的准确率与隐向量服从标准正态分布之间做一个权衡。模型的准确率就是指解码器生成的图片与原图片的相似程度;隐向量分布采用KL散度来衡量与标准正态分布之间的误差。两部分误差之和作为总体的误差来优化。

这里VAE使用了重参数化这个技巧来KL散度的计算。编码器不再是生成一个隐向量,而是生成正态分布的均值和标准差(若是多维正态分布,会有多个均值和标准差),然后根据这两个统计量下的分布抽样生成隐含向量。因为我们想要使得隐含向量服从标准正态分布,即均值为0,标准差为1,通过优化KL散度来使得分布逼近标准正态分布。
在这里插入图片描述
同理,解码器阶段,根据给定的隐变量 z z z来生成多元正态分布的均值 μ x 1 , μ x 2 \mu_{x1},\mu_{x2} μx1,μx2标准差 σ z 1 , σ z 2 \sigma_{z1},\sigma_{z2} σz1,σz2,根据该分布抽样生成数值 x 1 , x 2 x_{1},x_{2} x1,x2
在这里插入图片描述
将编码器和解码器综合在一起:
在这里插入图片描述
设编码器的概率分布为 q ϕ ( z ∣ x ) q_{\phi }(z|x) qϕ(zx),解码器的概率分布为 p θ ( x ∣ z ) p_{\theta}(x|z) pθ(xz)

误差推导:
在这里插入图片描述在这里插入图片描述
这里需要用到正态分布之间的KL散度,直接给出公式,推导见参考文献:
单元正态分布:
K L ( μ 1 , μ 2 , σ 1 , σ 2 ) = log ⁡ σ 2 σ 1 + σ 1 2 + ( μ 1 − μ 2 ) 2 2 σ 2 − 1 2 KL(\mu_{1},\mu_{2},\sigma_{1},\sigma_{2})=\log{\frac{\sigma_{2}}{\sigma_{1}}}+\frac{\sigma_{1}^{2}+(\mu_{1}-\mu_{2})^{2}}{2\sigma^{2}}-\frac{1}{2} KL(μ1,μ2,σ1,σ2)=logσ1σ2+2σ2σ12+(μ1μ2)221
n n n元正态分布:
K L ( μ 1 , μ 2 , Σ 1 , Σ 2 ) = 1 2 [ log ⁡ det ⁡ Σ 2 det ⁡ Σ 1 − n + t r ( Σ 2 − 1 Σ 1 ) + ( μ 2 − μ 1 ) T Σ 2 − 1 ( μ 2 − μ 1 ) ] KL(\bm{\mu}_{1},\bm{\mu}_{2},\Sigma_{1},\Sigma_{2})=\frac{1}{2}[\log{\frac{\det{\Sigma_{2}}}{\det{\Sigma_{1}}}}-n+tr(\Sigma_{2}^{-1}\Sigma_{1})+(\bm{\mu}_{2}-\bm{\mu}_{1})^{T}\Sigma_{2}^{-1}(\bm{\mu}_{2}-\bm{\mu}_{1})] KL(μ1,μ2,Σ1,Σ2)=21[logdetΣ1detΣ2n+tr(Σ21Σ1)+(μ2μ1)TΣ21(μ2μ1)]
由此可得到:
− D K L ( q ( z ∣ x i ) ∣ ∣ p ( z ) ) = 1 2 ∑ j = 1 J ( 1 + log ⁡ ( ( σ z j ( i ) ) 2 ) − ( μ z j ( i ) ) 2 − ( σ z j ( i ) ) 2 ) -D_{KL}(q(z|x^{i})||p(z))=\frac{1}{2}\sum_{j=1}^{J}(1+\log{((\sigma_{zj}^{(i)}})^{2})-(\mu_{zj}^{(i)})^{2}-(\sigma_{zj}^{(i)})^{2}) DKL(q(zxi)p(z))=21j=1J(1+log((σzj(i))2)(μzj(i))2(σzj(i))2)
通过从分布 q ( z ∣ x ( i ) ) q(z|x^{(i)}) q(zx(i))抽样来近似 E q ( z ∣ x ( i ) ) \mathbb{E}_{q(z|x^{(i)})} Eq(zx(i))。抽样 L L L次,得到 z ( i , l ) , l = 1 , 2 , . . . , L z^{(i,l)},l=1,2,...,L z(i,l),l=1,2,...,L, L L L通常非常小,通常取1
E q ( z ∣ x ( i ) ) ( log ⁡ ( p ( x ( i ) ∣ z ) ) ) = 1 L ∑ l = 1 L log ⁡ p ( x ( i ) ∣ z ( i , l ) ) = 1 L ∑ l = 1 L ∑ j = 1 D 1 2 log ⁡ σ x j 2 + ( x j i − μ x j ) 2 σ x j 2 \mathbb{E}_{q(z|x^{(i)})}(\log{(p(x^{(i)}|z))})=\frac{1}{L}\sum_{l=1}^{L}\log{p(x^{(i)}|z^{(i,l)})}=\frac{1}{L}\sum_{l=1}^{L}\sum_{j=1}^{D}\frac{1}{2}\log{\sigma_{xj}^{2}}+\frac{(x^{i}_{j}-\mu_{xj})}{2\sigma_{xj}^{2}} Eq(zx(i))(log(p(x(i)z)))=L1l=1Llogp(x(i)z(i,l))=L1l=1Lj=1D21logσxj2+2σxj2(xjiμxj)
其中 D D D代表样本 x ( i ) x^{(i)} x(i)的维度,每个数 x j ( i ) x^{(i)}_{j} xj(i)都对应一个正态分布 N ( μ x j , σ x j 2 ) \mathcal{N}(\mu_{xj},\sigma^{2}_{xj}) N(μxj,σxj2)

Pytorch实现MNIST数据集生成

在本实例中,生成器最后的输出不是均值和方差,而是图片向量。所以重构误差看做为生成图片和原始图片的误差。在这里使用的是binary cross entropy,即BCE误差,因为图片中的值都是(0,1)。当然也可以使用平方误差。
B C E = − ∑ i = 1 n ∑ j = 1 d [ y j ( i ) log ⁡ x j ( i ) + ( 1 − y j ( i ) ) log ⁡ ( 1 − x j ( i ) ) ] BCE=-\sum_{i=1}^{n}\sum_{j=1}^{d}[y^{(i)}_{j}\log{x_{j}^{(i)}}+(1-y^{(i)}_{j})\log{(1-x_{j}^{(i)})}] BCE=i=1nj=1d[yj(i)logxj(i)+(1yj(i))log(1xj(i))]

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torchvision.utils import save_image


def loss_function(recon_x, x, mu, logvar):
    """
    :param recon_x: generated image
    :param x: original image
    :param mu: latent mean of z
    :param logvar: latent log variance of z
    """
    BCE_loss = nn.BCELoss(reduction='sum')
    reconstruction_loss = BCE_loss(recon_x, x)
    KL_divergence = -0.5 * torch.sum(1+logvar-torch.exp(logvar)-mu**2)
    #KLD_ele = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
    #KLD = torch.sum(KLD_ele).mul_(-0.5)
    print(reconstruction_loss, KL_divergence)

    return reconstruction_loss + KL_divergence


class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
        self.fc1 = nn.Linear(784, 400)
        self.fc2_mean = nn.Linear(400, 20)
        self.fc2_logvar = nn.Linear(400, 20)
        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 784)

    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc2_mean(h1), self.fc2_logvar(h1)

    def reparametrization(self, mu, logvar):
        # sigma = 0.5*exp(log(sigma^2))= 0.5*exp(log(var))
        std = 0.5 * torch.exp(logvar)
        # N(mu, std^2) = N(0, 1) * std + mu
        z = torch.randn(std.size()) * std + mu
        return z

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparametrization(mu, logvar)
        return self.decode(z), mu, logvar


transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)

testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False)

vae = VAE()
optimizer = torch.optim.Adam(vae.parameters(), lr=0.0003)

# Training
def train(epoch):
    vae.train()
    all_loss = 0.
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to('cpu'), targets.to('cpu')
        real_imgs = torch.flatten(inputs, start_dim=1)

        # Train Discriminator
        gen_imgs, mu, logvar = vae(real_imgs)
        loss = loss_function(gen_imgs, real_imgs, mu, logvar)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        all_loss += loss.item()
        print('Epoch {}, loss: {:.6f}'.format(epoch, all_loss/(batch_idx+1)))
        # Save generated images for every epoch
    fake_images = gen_imgs.view(-1, 1, 28, 28)
    save_image(fake_images, 'MNIST_FAKE/fake_images-{}.png'.format(epoch + 1))



for epoch in range(20):
    train(epoch)

torch.save(vae.state_dict(), './vae.pth')

运行上述代码20轮所产生的的图片:
在这里插入图片描述
VAE和自编码器有一样的缺点,根据均平方误差计算的图片会比较模糊,之后出现的对抗生成网络则解决了这个问题。

相关推荐
©️2020 CSDN 皮肤主题: 编程工作室 设计师:CSDN官方博客 返回首页