AIGC笔记--VAE模型的搭建

发布时间:2024年01月16日

目录

1--VAE模型

2--代码实例


1--VAE模型

简单介绍:

? ? ? ? 通过一个 encoder 将图片映射到标准分布(均值和方差),从映射的标准分布中随机采样一个样本,通过 decoder 重构图片;

? ? ? ? 计算源图片和重构图片之间的损失,具体损失函数的推导可以参考:变分自编码器(VAE)

2--代码实例

简单的VAE模型搭建:

????????Encoder 返回映射标准分布的均值和方差,从标准分布中随机采样,利用Decoder重构图片;

class VAE(nn.Module):
    def __init__(self, input_dim = 784, h_dim = 400, z_dim = 20):
        super(VAE, self).__init__()
        self.fc1 = nn.Linear(input_dim, h_dim) # 28*28 → 784
        self.fc21 = nn.Linear(h_dim, z_dim) # 均值
        self.fc22 = nn.Linear(h_dim, z_dim) # 标准差
        self.fc3 = nn.Linear(z_dim, h_dim)
        self.fc4 = nn.Linear(h_dim, input_dim) # 784 → 28*28
        self.input_dim = input_dim

    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1) # 均值、标准差

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        # z = mu + eps*std
        return mu + eps*std

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        # sigmoid: 0-1 之间,后边会用到 BCE loss 计算重构 loss(reconstruction loss)
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        mu, logvar = self.encode(x.view(-1, self.input_dim))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

简单的损失计算:

x_reconst, mu, log_var = self.model(x)
# Compute reconstruction loss and kl divergence
reconst_loss = F.binary_cross_entropy(x_reconst, x, size_average=False)
kl_div = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
# Backprop and optimize
loss = reconst_loss + kl_div

完整可运行代码参考:VAE代码实例

文章来源:https://blog.csdn.net/weixin_43863869/article/details/135613984
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。