利用GAN做image translation的开山之作:Image-to-Image Translation with Conditional Adversarial Networks
? 自然语言处理领域有text2text,所以自然而然图像领域也有image2image。作者提出pix2pix,即利用CGAN实现一个解决各种image2image任务(语义分割,边缘检测、风格迁移等等)的通用解决方案和框架(前人工作中只应用于特定领域)。
? 在此之前的单一领域的image2image中的都是逐像素分类或回归(per-pixel classification or regression),每个输出像素被认为有条件地独立于所有其他像素,是非结构化的。
? 这时就出现了一个问题,通用领域中的image2image任务怎么设计损失函数呢?作者在文中指出:CGAN的不同之处在于损失是可学习的,并且理论上可以惩罚输出和目标之间任何的可能不同的结构(CGAN的损失函数就很适合通用领域嘛!我个人理解是因为他的训练模式是相互博弈,所以损失就是基于全局如果不是像素的损失)。忘记了的可以去看一下这篇CGAN的讲解。
L
c
G
A
N
(
G
,
D
)
=
E
x
,
y
[
log
?
D
(
x
,
y
)
]
+
E
x
,
z
[
log
?
(
1
?
D
(
x
,
G
(
x
,
z
)
)
]
\begin{aligned} \mathcal{L}_{c G A N}(G, D)= & \mathbb{E}_{x, y}[\log D(x, y)]+ \mathbb{E}_{x, z}[\log (1-D(x, G(x, z))] \end{aligned}
LcGAN?(G,D)=?Ex,y?[logD(x,y)]+Ex,z?[log(1?D(x,G(x,z))]?
? 前人工作表明将GAN目标函数与更传统的损失函数(L1loss:MAE、L2loss:MSE)混合是相得益彰的。作者认为L1 loss 可以恢复图像的低频部分,而GAN loss可以恢复图像的高频部分(图像中的边缘等)。生成器的任务不仅是欺骗鉴别器,而且要在L2意义上接近真值输出。作者指出使用L1而不是L2,因为L1可以减少模糊。
L
L
1
(
G
)
=
E
x
,
y
,
z
[
∥
y
?
G
(
x
,
z
)
∥
1
]
\mathcal{L}_{L 1}(G)=\mathbb{E}_{x, y, z}\left[\|y-G(x, z)\|_1\right]
LL1?(G)=Ex,y,z?[∥y?G(x,z)∥1?]
G ? = arg ? min ? G max ? D L c G A N ( G , D ) + λ L L 1 ( G ) ? G^*=\arg \min _G \max _D \mathcal{L}_{c G A N}(G, D)+\lambda \mathcal{L}_{L 1}(G)\ G?=argGmin?Dmax?LcGAN?(G,D)+λLL1?(G)?
? 图4 不同的损失导致不同的结果质量
? PatchGAN学习图像特征的单位是patch而不是单个像素,也就是说把图像等分成patch,分别判断每个patch的真假,最后再取平均;生成器与判别器都是convolution-BatchNorm-ReLu 的结构。
? 对于许多image2image问题,在输入和输出之间存在大量共享的低级信息,并且希望通过网络直接传送这些信息。给生成器提供一种绕过上采样与下采样产生的信息瓶颈的方法:跳过连接(skip connections),遵循“U-Net”的一般形状。低级信息在图像生成中通常指图像的基本结构、边缘、纹理等底层特征,而跳过连接(skip connections)的作用是确保这些低级信息能够在生成过程中传递并被有效利用,以改善生成图像的质量。这种设计有助于解决传统上采样和下采样操作可能引入的信息瓶颈问题。
? 图5 Encoder-decoder模型、U-net模型结构图示
import torch.nn as nn
import torch.nn.functional as F
import torch
def weights_init_normal(m):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find("BatchNorm2d") != -1:
torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
torch.nn.init.constant_(m.bias.data, 0.0)
##############################
# U-NET
##############################
class UNetDown(nn.Module):
def __init__(self, in_size, out_size, normalize=True, dropout=0.0):
super(UNetDown, self).__init__()
layers = [nn.Conv2d(in_size, out_size, 4, 2, 1, bias=False)]
if normalize:
layers.append(nn.InstanceNorm2d(out_size))
layers.append(nn.LeakyReLU(0.2))
if dropout:
layers.append(nn.Dropout(dropout))
self.model = nn.Sequential(*layers)
def forward(self, x):
return self.model(x)
class UNetUp(nn.Module):
def __init__(self, in_size, out_size, dropout=0.0):
super(UNetUp, self).__init__()
layers = [
nn.ConvTranspose2d(in_size, out_size, 4, 2, 1, bias=False),
nn.InstanceNorm2d(out_size),
nn.ReLU(inplace=True),
]
if dropout:
layers.append(nn.Dropout(dropout))
self.model = nn.Sequential(*layers)
def forward(self, x, skip_input):
x = self.model(x)
x = torch.cat((x, skip_input), 1)
return x
class GeneratorUNet(nn.Module):
def __init__(self, in_channels=3, out_channels=3):
super(GeneratorUNet, self).__init__()
self.down1 = UNetDown(in_channels, 64, normalize=False)
self.down2 = UNetDown(64, 128)
self.down3 = UNetDown(128, 256)
self.down4 = UNetDown(256, 512, dropout=0.5)
self.down5 = UNetDown(512, 512, dropout=0.5)
self.down6 = UNetDown(512, 512, dropout=0.5)
self.down7 = UNetDown(512, 512, dropout=0.5)
self.down8 = UNetDown(512, 512, normalize=False, dropout=0.5)
self.up1 = UNetUp(512, 512, dropout=0.5)
self.up2 = UNetUp(1024, 512, dropout=0.5)
self.up3 = UNetUp(1024, 512, dropout=0.5)
self.up4 = UNetUp(1024, 512, dropout=0.5)
self.up5 = UNetUp(1024, 256)
self.up6 = UNetUp(512, 128)
self.up7 = UNetUp(256, 64)
self.final = nn.Sequential(
nn.Upsample(scale_factor=2),
nn.ZeroPad2d((1, 0, 1, 0)),
nn.Conv2d(128, out_channels, 4, padding=1),
nn.Tanh(),
)
def forward(self, x):
# U-Net generator with skip connections from encoder to decoder
d1 = self.down1(x)
d2 = self.down2(d1)
d3 = self.down3(d2)
d4 = self.down4(d3)
d5 = self.down5(d4)
d6 = self.down6(d5)
d7 = self.down7(d6)
d8 = self.down8(d7)
u1 = self.up1(d8, d7)
u2 = self.up2(u1, d6)
u3 = self.up3(u2, d5)
u4 = self.up4(u3, d4)
u5 = self.up5(u4, d3)
u6 = self.up6(u5, d2)
u7 = self.up7(u6, d1)
return self.final(u7)
##############################
# Discriminator
##############################
class Discriminator(nn.Module):
def __init__(self, in_channels=3):
super(Discriminator, self).__init__()
def discriminator_block(in_filters, out_filters, normalization=True):
"""返回每个鉴别器块的下采样层"""
layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
if normalization:
layers.append(nn.InstanceNorm2d(out_filters))
layers.append(nn.LeakyReLU(0.2, inplace=True))
return layers
self.model = nn.Sequential(
*discriminator_block(in_channels * 2, 64, normalization=False),
*discriminator_block(64, 128),
*discriminator_block(128, 256),
*discriminator_block(256, 512),
nn.ZeroPad2d((1, 0, 1, 0)),
nn.Conv2d(512, 1, 4, padding=1, bias=False)
)
def forward(self, img_A, img_B):
img_input = torch.cat((img_A, img_B), 1)
return self.model(img_input)