目录
简单介绍:
通过一个 encoder 将图片映射到标准分布(均值和方差),从映射的标准分布中随机采样一个样本,通过 decoder 重构图片;
计算源图片和重构图片之间的损失,具体损失函数的推导可以参考:变分自编码器(VAE)
简单的VAE模型搭建:
Encoder 返回映射标准分布的均值和方差,从标准分布中随机采样,利用 Decoder 重构图片;
class VAE(nn.Module):
    """Variational Autoencoder for flattened images (default: 28*28 MNIST).

    The encoder maps an input image to the mean and log-variance of a
    diagonal Gaussian in latent space; a latent sample is drawn via the
    reparameterization trick and decoded back to pixel space.
    """

    def __init__(self, input_dim: int = 784, h_dim: int = 400, z_dim: int = 20):
        super(VAE, self).__init__()
        self.fc1 = nn.Linear(input_dim, h_dim)   # encoder hidden layer (28*28 -> 784 flattened input)
        self.fc21 = nn.Linear(h_dim, z_dim)      # latent mean (mu)
        self.fc22 = nn.Linear(h_dim, z_dim)      # latent log-variance (logvar), NOT the std itself
        self.fc3 = nn.Linear(z_dim, h_dim)       # decoder hidden layer
        self.fc4 = nn.Linear(h_dim, input_dim)   # reconstruction (784 -> 28*28)
        self.input_dim = input_dim

    def encode(self, x):
        """Return (mu, logvar) of the approximate posterior q(z|x)."""
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        """Sample z = mu + eps * std with eps ~ N(0, I).

        std = exp(0.5 * logvar) converts log-variance to standard deviation;
        the trick keeps the sampling step differentiable w.r.t. mu and logvar.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map a latent sample back to pixel space.

        sigmoid squashes outputs into [0, 1] so BCE can be used as the
        reconstruction loss against the (normalized) input pixels.
        """
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        """Encode, sample, decode. Returns (reconstruction, mu, logvar)."""
        mu, logvar = self.encode(x.view(-1, self.input_dim))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar
简单的损失计算:
x_reconst, mu, log_var = self.model(x)
# Reconstruction loss: pixel-wise BCE summed over the whole batch.
# NOTE: size_average=False is deprecated (removed in recent PyTorch);
# reduction='sum' is the equivalent modern form.
reconst_loss = F.binary_cross_entropy(x_reconst, x, reduction='sum')
# KL divergence between q(z|x) = N(mu, exp(log_var)) and the prior N(0, I),
# in closed form: -0.5 * sum(1 + log_var - mu^2 - exp(log_var))
kl_div = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
# Total VAE loss (negative ELBO); backprop and optimize on this.
loss = reconst_loss + kl_div
完整可运行代码参考:VAE代码实例