pytorch 实现git地址
论文地址:Neural Discrete Representation Learning
encoder
将图片通过encoder得到图片点表征
如输入shape [32,3,32,32]
通过encoder后输出 [32,64,8,8] (其中64位输出维度)
量化码本
先随机构建一个码本,维度与encoder保持一致
这里定义512个离散特征,码本shap 为[512,64]
encoder 码本中向量最近查找
encoder输出shape [32,64,8,8], 经过维度变换 shape [3288,64]
在码本中找到最相近的向量,并替换为码本中相似向量
输出shape [3288,64],维度变换后,shape 为 [32,64,8,8]
decoder
将上述数据,喂给decoder,还原原始图片
loss
loss 包含两部分
a . encoder输出和码本向量接近
b. 重构loss,重构图片与原图片接近
encoder是常用的图片卷积神经网络
输入x shape [32,3,32,32]
输出 shape [32,128,8,8]
def __init__(self, in_dim, h_dim, n_res_layers, res_h_dim):
super(Encoder, self).__init__()
kernel = 4
stride = 2
self.conv_stack = nn.Sequential(
nn.Conv2d(in_dim, h_dim // 2, kernel_size=kernel,
stride=stride, padding=1),
nn.ReLU(),
nn.Conv2d(h_dim // 2, h_dim, kernel_size=kernel,
stride=stride, padding=1),
nn.ReLU(),
nn.Conv2d(h_dim, h_dim, kernel_size=kernel-1,
stride=stride-1, padding=1),
ResidualStack(
h_dim, h_dim, res_h_dim, n_res_layers)
)
def forward(self, x):
return self.conv_stack(x)
decoder层比较简单,与encoder层相反
输入x shape 【32,64,8,8】
输出shape [32,3,32,32]
class Decoder(nn.Module):
"""
This is the p_phi (x|z) network. Given a latent sample z p_phi
maps back to the original space z -> x.
Inputs:
- in_dim : the input dimension
- h_dim : the hidden layer dimension
- res_h_dim : the hidden dimension of the residual block
- n_res_layers : number of layers to stack
"""
def __init__(self, in_dim, h_dim, n_res_layers, res_h_dim):
super(Decoder, self).__init__()
kernel = 4
stride = 2
self.inverse_conv_stack = nn.Sequential(
nn.ConvTranspose2d(
in_dim, h_dim, kernel_size=kernel-1, stride=stride-1, padding=1),
ResidualStack(h_dim, h_dim, res_h_dim, n_res_layers),
nn.ConvTranspose2d(h_dim, h_dim // 2,
kernel_size=kernel, stride=stride, padding=1),
nn.ReLU(),
nn.ConvTranspose2d(h_dim//2, 3, kernel_size=kernel,
stride=stride, padding=1)
)
def forward(self, x):
return self.inverse_conv_stack(x)
损失函数为重构损失和embedding损失之和
for i in range(args.n_updates):
(x, _) = next(iter(training_loader))
x = x.to(device)
optimizer.zero_grad()
embedding_loss, x_hat, perplexity = model(x)
recon_loss = torch.mean((x_hat - x)**2) / x_train_var
loss = recon_loss + embedding_loss
loss.backward()
optimizer.step()