图像生成之pix2pix

发布时间:2024年01月19日

简要介绍

利用GAN做image translation的开山之作:Image-to-Image Translation with Conditional Adversarial Networks

? 自然语言处理领域有text2text,所以自然而然图像领域也有image2image。作者提出pix2pix,即利用CGAN实现一个解决各种image2image任务(语义分割,边缘检测、风格迁移等等)的通用解决方案和框架(前人工作中只应用于特定领域)。

CGAN的结构化损失

? 在此之前的单一领域的image2image中的都是逐像素分类或回归(per-pixel classification or regression),每个输出像素被认为有条件地独立于所有其他像素,是非结构化的。

? 这时就出现了一个问题,通用领域中的image2image任务怎么设计损失函数呢?作者在文中指出:CGAN的不同之处在于损失是可学习的,并且理论上可以惩罚输出和目标之间任何的可能不同的结构(CGAN的损失函数就很适合通用领域嘛!我个人理解是因为他的训练模式是相互博弈,所以损失就是基于全局如果不是像素的损失)。忘记了的可以去看一下这篇CGAN的讲解
L c G A N ( G , D ) = E x , y [ log ? D ( x , y ) ] + E x , z [ log ? ( 1 ? D ( x , G ( x , z ) ) ] \begin{aligned} \mathcal{L}_{c G A N}(G, D)= & \mathbb{E}_{x, y}[\log D(x, y)]+ \mathbb{E}_{x, z}[\log (1-D(x, G(x, z))] \end{aligned} LcGAN?(G,D)=?Ex,y?[logD(x,y)]+Ex,z?[log(1?D(x,G(x,z))]?
? 前人工作表明将GAN目标函数与更传统的损失函数(L1loss:MAE、L2loss:MSE)混合是相得益彰的。作者认为L1 loss 可以恢复图像的低频部分,而GAN loss可以恢复图像的高频部分(图像中的边缘等)。生成器的任务不仅是欺骗鉴别器,而且要在L2意义上接近真值输出。作者指出使用L1而不是L2,因为L1可以减少模糊。
L L 1 ( G ) = E x , y , z [ ∥ y ? G ( x , z ) ∥ 1 ] \mathcal{L}_{L 1}(G)=\mathbb{E}_{x, y, z}\left[\|y-G(x, z)\|_1\right] LL1?(G)=Ex,y,z?[y?G(x,z)1?]

G ? = arg ? min ? G max ? D L c G A N ( G , D ) + λ L L 1 ( G ) ? G^*=\arg \min _G \max _D \mathcal{L}_{c G A N}(G, D)+\lambda \mathcal{L}_{L 1}(G)\ G?=argGmin?Dmax?LcGAN?(G,D)+λLL1?(G)?

不同的损失导致不同的结果质量

? 图4 不同的损失导致不同的结果质量

生成器U-net与鉴别器conv PatchGAN

? PatchGAN学习图像特征的单位是patch而不是单个像素,也就是说把图像等分成patch,分别判断每个patch的真假,最后再取平均;生成器与判别器都是convolution-BatchNorm-ReLu 的结构。

? 对于许多image2image问题,在输入和输出之间存在大量共享的低级信息,并且希望通过网络直接传送这些信息。给生成器提供一种绕过上采样与下采样产生的信息瓶颈的方法:跳过连接(skip connections),遵循“U-Net”的一般形状。低级信息在图像生成中通常指图像的基本结构、边缘、纹理等底层特征,而跳过连接(skip connections)的作用是确保这些低级信息能够在生成过程中传递并被有效利用,以改善生成图像的质量。这种设计有助于解决传统上采样和下采样操作可能引入的信息瓶颈问题。

 Encoder-decoder模型、U-net模型结构

? 图5 Encoder-decoder模型、U-net模型结构图示

缺点

  1. 一对一映射结构: 作者指出,pix2pix 模型实际上学到的是输入和输出之间的一对一映射关系。具体而言,模型的任务是将输入(比如轮廓图)映射到对应的输出(真实图)。这种结构的局限性在于,它主要适用于对给定输入的重建,而不太能够处理与训练集数据有较大差异的情况。
  2. 对ground truth的重建: 文中指出,pix2pix 主要是对于 ground truth 的重建,即尽量使生成的图像与真实图像一致。这使得模型在处理与训练集中差异较大的输入时表现不佳,因为它缺乏对于更广泛数据变化的适应能力。
  3. 应用范围有限: 由于模型学到的是一对一映射,其应用范围受到限制。当输入数据与训练集中的数据相差较大时,生成的结果可能失去意义。因此,为了获得有意义的生成结果,需要确保训练集涵盖各种类型的数据,以便模型更好地泛化到不同情况。
  4. 颜色生成不满意: 通过示例,作者展示了在输入与训练集中不存在的轮廓图时,生成图像的颜色可能不令人满意。这说明 pix2pix 在处理未见过的数据时可能无法有效地保持颜色一致性,表现出一定的泛化不足。

改进

  • CycleGAN

模型代码

import torch.nn as nn
import torch.nn.functional as F
import torch

def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)

##############################
#           U-NET
##############################

class UNetDown(nn.Module):
    def __init__(self, in_size, out_size, normalize=True, dropout=0.0):
        super(UNetDown, self).__init__()
        layers = [nn.Conv2d(in_size, out_size, 4, 2, 1, bias=False)]
        if normalize:
            layers.append(nn.InstanceNorm2d(out_size))
        layers.append(nn.LeakyReLU(0.2))
        if dropout:
            layers.append(nn.Dropout(dropout))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

class UNetUp(nn.Module):
    def __init__(self, in_size, out_size, dropout=0.0):
        super(UNetUp, self).__init__()
        layers = [
            nn.ConvTranspose2d(in_size, out_size, 4, 2, 1, bias=False),
            nn.InstanceNorm2d(out_size),
            nn.ReLU(inplace=True),
        ]
        if dropout:
            layers.append(nn.Dropout(dropout))
        self.model = nn.Sequential(*layers)

    def forward(self, x, skip_input):
        x = self.model(x)
        x = torch.cat((x, skip_input), 1)
        return x

class GeneratorUNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=3):
        super(GeneratorUNet, self).__init__()

        self.down1 = UNetDown(in_channels, 64, normalize=False)
        self.down2 = UNetDown(64, 128)
        self.down3 = UNetDown(128, 256)
        self.down4 = UNetDown(256, 512, dropout=0.5)
        self.down5 = UNetDown(512, 512, dropout=0.5)
        self.down6 = UNetDown(512, 512, dropout=0.5)
        self.down7 = UNetDown(512, 512, dropout=0.5)
        self.down8 = UNetDown(512, 512, normalize=False, dropout=0.5)

        self.up1 = UNetUp(512, 512, dropout=0.5)
        self.up2 = UNetUp(1024, 512, dropout=0.5)
        self.up3 = UNetUp(1024, 512, dropout=0.5)
        self.up4 = UNetUp(1024, 512, dropout=0.5)
        self.up5 = UNetUp(1024, 256)
        self.up6 = UNetUp(512, 128)
        self.up7 = UNetUp(256, 64)

        self.final = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(128, out_channels, 4, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        # U-Net generator with skip connections from encoder to decoder
        d1 = self.down1(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)
        d5 = self.down5(d4)
        d6 = self.down6(d5)
        d7 = self.down7(d6)
        d8 = self.down8(d7)
        u1 = self.up1(d8, d7)
        u2 = self.up2(u1, d6)
        u3 = self.up3(u2, d5)
        u4 = self.up4(u3, d4)
        u5 = self.up5(u4, d3)
        u6 = self.up6(u5, d2)
        u7 = self.up7(u6, d1)
        return self.final(u7)

##############################
#        Discriminator
##############################

class Discriminator(nn.Module):
    def __init__(self, in_channels=3):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, normalization=True):
            """返回每个鉴别器块的下采样层"""
            layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
            if normalization:
                layers.append(nn.InstanceNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *discriminator_block(in_channels * 2, 64, normalization=False),
            *discriminator_block(64, 128),
            *discriminator_block(128, 256),
            *discriminator_block(256, 512),
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(512, 1, 4, padding=1, bias=False)
        )

    def forward(self, img_A, img_B):
        img_input = torch.cat((img_A, img_B), 1)
        return self.model(img_input)

文章来源:https://blog.csdn.net/Freeandeasy_roni/article/details/135671906
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。