上一篇文章中我们说到了GAN的数学解释
min
?
G
max
?
D
V
(
D
,
G
)
=
E
x
~
p
data
(
x
)
[
log
?
D
(
x
)
]
+
E
z
~
p
z
(
z
)
[
log
?
(
1
?
D
(
G
(
z
)
)
)
]
=
?
log
?
4
+
2
J
S
D
(
p
data
∥
p
g
)
≥
?
log
?
4
,
where?
[
p
d
a
t
a
=
p
g
]
\min_{G} \max_{D} V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_{z}(z)}[\log(1 - D(G(z)))]\\ = -\log 4 + 2JSD(p_{\text{data}} \parallel p_g)\\ \geq -\log 4, \quad \text{where } [p_{data} = p_g]
Gmin?Dmax?V(D,G)=Ex~pdata?(x)?[logD(x)]+Ez~pz?(z)?[log(1?D(G(z)))]=?log4+2JSD(pdata?∥pg?)≥?log4,where?[pdata?=pg?]
下面补充一下GAN的伪代码
Pseudo Code
上面这段伪码先训练鉴别器(Discriminator, D)k次,再训练生成器(Generator, G)。
对于每次训练迭代:
D
,进行 k 次更新:
D
(即最大化 log(D(x)) 和 log(1-D(G(z)))) 。G
:
G
(即最小化 log(1-D(G(z))))。Q1: 为什么是上升和下降(ascending & descending)?
A1: 这是因为在GAN的框架中,鉴别器D
的目标是最大化它正确分类真实和生成样本的能力,这可以通过最大化 log(D(x)) 和 log(1-D(G(z))) 来实现,这被称为上升(ascending)。相反,生成器G
的目标是最小化鉴别器D
正确识别其生成的样本的能力,这可以通过最小化 log(1-D(G(z))) 来实现,这被称为下降(descending)。
Q2: 生成器
G
的更新有什么问题吗?
A2: 当 D(G(z)) 趋向于0(即在开始时,鉴别器D
很容易识别出生成的样本是假的),log(1-D(G(z))) 的梯度会消失,这会使得生成器G
的训练变得非常缓慢。为了解决这个问题,可以使用 -log(D(G(z))) 来代替 log(1-D(G(z))),因为前者在 D(G(z)) 小的时候梯度较大,有利于生成器G
的训练。这实际上就是将假图的label置为了1。
WGAN是一种改进的GAN,旨在解决原始GAN训练中的一些问题,如梯度消失和模式崩溃。
还记得我们上一篇文章最后提到的那个还有点严重的问题吗?2017年的一篇论文彻底说破了这个问题的原因。
下面我们一步一步来说
如果鉴别器 ( D ) 达到最优 ( D* ),那么:
定义1(Transversality):③ 两个流形的切空间之和等于目标空间的切空间(即分布相交的情况)。
定义2(Perfectly Align):④ 两个流形的切空间之和不等于目标空间的切空间(即分布相切的情况)。
引理2:⑤ 如果两个分布的支撑集完美对齐,则它们的交集的概率为1。
引理3:⑥ 如果 M 和 Pr 的交集是 M 和 Pg 的交集的低维流形,那么这个交集在 M 和 Pg 中的测度为0。
定理2:⑦ 如果 Pr 和 Pg 的支撑集相交,且交集的测度为0,那么鉴别器也能达到最优。
定理3:⑧ 如果 Pr 和 Pg 这两个分布是高维空间中的低维流形,并且它们的支撑集在 M 和 Pg 中有交集,且这个交集的测度为0,那么Jensen-Shannon Divergence(JSD)就会等于 log2。
至此,我们发现在最优判别器D*的情况下,JSD的值应该会是log2。
min
?
G
max
?
D
V
(
D
,
G
)
=
E
x
~
p
data
(
x
)
[
log
?
D
(
x
)
]
+
E
z
~
p
z
(
z
)
[
log
?
(
1
?
D
(
G
(
z
)
)
)
]
=
?
log
?
4
+
2
J
S
D
(
p
data
∥
p
g
)
≥
?
log
?
4
,
where?
[
p
d
a
t
a
=
p
g
]
\begin{align} \min_{G} \max_{D} V(D, G) &= \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_{z}(z)}[\log(1 - D(G(z)))]\\ &= -\log 4 + 2JSD(p_{\text{data}} \parallel p_g)\\ &\geq -\log 4, \quad \text{where } [p_{data} = p_g] \end{align}
Gmin?Dmax?V(D,G)?=Ex~pdata?(x)?[logD(x)]+Ez~pz?(z)?[log(1?D(G(z)))]=?log4+2JSD(pdata?∥pg?)≥?log4,where?[pdata?=pg?]??
然而,还记得我们最开始的想法吗?**我们希望通过p_data = p_g,使得JSD = 0,从而取到minmax函数的最小值-log4。**现在的事实却是——minmax永远取不到最小值-log4而只能取到0,因为如果需要最优判别器完美分别两个分布,此时的JSD会等于log2而不是0。
重新梳理一下:
问题的核心在于,当鉴别器
D
达到其最优状态D*
时,是否可以使生成器的分布 P_g 完全匹配真实数据的分布P_data。However,在积分中带入0我们发现
对于生成器的目标函数 f(G) ,如果鉴别器是完美的(即 P_data = P_g ),那么 JS散度(JSD)应该是0,并且 f(G) 应该是 -log4 。然而,如果 P_data 和 P_g 完全不相交(即完全不重叠),JSD会是 log2 ,导致 f(G) = 0 。
这意味着,**如果鉴别器达到最优,我们不会有任何梯度(即“没有损失”)**来指导生成器的训练。换句话说,最优的鉴别器将无法提供有关如何改进生成器的有用信息,因为它会对所有的生成样本都给出相同的响应,导致生成器无法从鉴别器的反馈中学习(随着D越来越好,G会越来越差)。
这个矛盾从原理上解释了为什么我们之前用原始GAN生成的质量这么差,因为在一段时间后它根本就没有学任何东西。
这里可能有人会有疑问:D*是一个理想状态,为什么能说明整个学习的过程质量差呢?
其实这里作者还补充了以下两种情况的数学推导
当采用原始公式
当采用愚弄法(将假图的label置为了1)
总之,原始GAN的两种方法都存在问题。
俗话说不破不立,Arjovsky 不光将GAN的“旧世界”给“破”了,还在他的第二篇论文中“立”了一个“新世界”——WGAN(Wasserstein Distance)。
“Wasserstein”其实是德语里的一个复合词,由“Wasser”(水)和“Stein”(石头)组成。
从结果反推过程,原始GAN的损失函数由于JS距离而出现了问题。因此,Wasserstein GAN通过使用不同的损失函数(即Wasserstein距离)解决了JSD=log2的问题,从而即使在最优鉴别器D*的情况下也能够提供有效的梯度给生成器。
我们先来看看Wasserstein Distance的公式
W
(
P
r
,
P
g
)
=
inf
?
γ
∈
Π
(
P
r
,
P
g
)
E
(
x
,
y
)
~
γ
[
∥
x
?
y
∥
]
W(P_r, P_g) = \inf_{\gamma \in \Pi(P_r, P_g)} \mathbb{E}_{(x,y)\sim \gamma} [\|x - y\|]
W(Pr?,Pg?)=γ∈Π(Pr?,Pg?)inf?E(x,y)~γ?[∥x?y∥]
inf
表示下确界,寻找使得期望值最小的联合分布γ
。
E_{(x,y)~γ}
表示对于联合分布γ
下的随机变量(x, y)
的期望值。
[|x - y|]
是x
和y
之间的范数,通常是欧几里得距离。
我们还是拆成几部分来看
① 其中,
E
(
x
,
y
)
~
γ
[
∥
x
?
y
∥
]
=
∫
y
∫
x
γ
(
x
,
y
)
∥
x
?
y
∥
?
d
x
?
d
y
=
∑
x
,
y
∥
x
?
y
∥
γ
(
x
,
y
)
\mathbb{E}_{(x,y)\sim \gamma} [\|x - y\|] = \int_{y}\int_{x} \gamma(x,y) \|x - y\| \,dx\,dy = \sum_{x,y} \|x - y\| \gamma(x,y)
E(x,y)~γ?[∥x?y∥]=∫y?∫x?γ(x,y)∥x?y∥dxdy=x,y∑?∥x?y∥γ(x,y)
我们应该时刻牢记Wasserstein Distance的目标——算出Pr和Pg之间的距离,所以很容易理解公式中的(x, y)
可以分别对应为左下角图片的Pr和Pg(每根柱子就是x1,x2,x3,…,y1,y2,y3…)
这样一来||x - y||
其实就是一个距离矩阵(每根柱子与另一分布的每个柱子作差)
接下来看到第3张图(也是最重要的一张),它揭示了“我需要做出什么样的努力才能使Pg变成Pr,或者Pr变成Pg”,即什么是γ(x,y)
。
明确几个概念:
在这张热力图中,亮度较高的点表示在联合分布Π
中具有较高的概率密度,这意味着在最优运输问题中,从真实分布 Pr 到生成分布 Pg 的特定值转移概率质量的可能性较大。
比如 Pr 中第1行的亮点位于 Pg 的第2和第4列,表明这部分概率质量主要被转移到了这两个位置。原因可能是:
相反,如果某些点很暗,甚至是黑色,这意味着在这些 (x, y)
对上几乎没有概率质量被转移。原因反之同上。
在热力图中,每个亮点代表了一个 (x, y)
对,而 |x - y|
就是这对之间的距离或成本。γ(x, y)
是这对的联合概率,所以||x - y||·γ(x,y)
实际上代表了从 Pr 到 Pg 的概率质量转移的“强度”或“成本”。
因此,整个公式就是计算在所有可能的 (x, y)
对上,将质量从 x
移动到 y
的期望成本。
总结一下:|x - y|
是分布移动的“距离”,γ(x,y)
是分布移动的“量”,两者相乘就是我们移动的“工作量”。我们的目的是让这个“工作量”尽可能小(如果你看懂了上面的概念就会知道,Wasserstein的转移方法可能会有无穷多种,具体如何找最小是一个找全局最优解的过程,下面会进行数学推导)。
②③在上面也已经顺带解释过了——Π
包含了所有可能的 (x, y)
对,我们的目标是找到一个全局最优解
展示的是Wasserstein距离 W(Pr, Pg)
的几种等价定义,这些定义来自于数学中的对偶性原理。这里涉及的是Kantorovich-Rubinstein对偶性,它允许我们从最优运输问题的原问题(primal problem)转化为其对偶问题(dual problem)。
原问题:最初的问题是寻找最小化运输成本的运输计划,即找到使得期望运输成本最小的联合分布 γ
,这表达为
i
n
f
γ
∈
Π
(
P
r
,
P
g
)
E
(
x
,
y
)
~
γ
[
∥
x
?
y
∥
]
inf_{\gamma \in \Pi(P_r, P_g)} E_{(x,y)\sim\gamma} [\|x - y\|]
infγ∈Π(Pr?,Pg?)?E(x,y)~γ?[∥x?y∥]
对偶问题:通过对偶性原理,我们可以将这个最小化问题转换为一个最大化问题。对偶问题寻找的是满足1-Lipschitz条件的函数集合上的一个最大值,这个条件意味着这些函数的梯度(或者在离散情况下的差分)被限制在1以内。这些函数 f
被称为1-Lipschitz函数,因为它们的斜率(在任意两点之间的斜率)被限制在±1的范围内。
Lipschitz条件:Lipschitz条件是数学中对函数斜率的一种约束,具体来说,如果一个函数是K-Lipschitz的,那么对于所有的 x1
和 x2
,有 |f(x1) - f(x2)| <= K * |x1 - x2|
。在Wasserstein距离的背景下,我们通常关注1-Lipschitz函数,即 K=1
。
这种从最小化到最大化的转变反映了我们可以从寻找实际的运输计划(计算成本)转变为寻找一个函数,该函数能够“衡量”两个分布之间的差异。这个函数在所有可能的情况下给出的期望值差异是最大的,而且这个函数满足Lipschitz条件。因此,在数学上,寻找最优运输计划的最小值问题转化为了寻找衡量分布之间差异的函数的最大值问题。这两个问题在数学上是等价的。
W
(
P
r
,
P
g
)
=
inf
?
γ
∈
Π
(
P
r
,
P
g
)
E
(
x
,
y
)
~
γ
[
∥
x
?
y
∥
]
=
sup
?
∥
f
∥
L
≤
1
E
x
~
P
r
[
f
(
x
)
]
?
E
x
~
P
g
[
f
(
x
)
]
=
max
?
w
∈
W
E
x
~
P
r
[
f
w
(
x
)
]
?
E
z
~
P
z
[
f
w
(
g
θ
(
z
)
)
]
\begin{align} W(P_r, P_g) &= \inf_{\gamma \in \Pi(P_r, P_g)} \mathbb{E}_{(x,y)\sim\gamma} [\|x - y\|] \\ &= \sup_{\|f\|_L\leq 1} \mathbb{E}_{x\sim P_r}[f(x)] - \mathbb{E}_{x\sim P_g}[f(x)] \\ &= \max_{w\in W} \mathbb{E}_{x\sim P_r} [f_w(x)] - \mathbb{E}_{z\sim P_z} [f_w(g_\theta(z))] \end{align}
W(Pr?,Pg?)?=γ∈Π(Pr?,Pg?)inf?E(x,y)~γ?[∥x?y∥]=∥f∥L?≤1sup?Ex~Pr??[f(x)]?Ex~Pg??[f(x)]=w∈Wmax?Ex~Pr??[fw?(x)]?Ez~Pz??[fw?(gθ?(z))]??
最后一个等式是在生成对抗网络(GAN)中的对应。这里 fw
通是判别器网络,gθ
是生成器网络,z
是来自先验分布 Pz
的噪声变量。在GAN的背景下,实际上就是在最大化这个期望值差异相当于训练判别器以最大程度地区分真实数据分布 Pr
和生成数据分布 Pg
。
至此,来看一下WGAN的伪代码
两个要点:
将真图和假图分别送入D,然后会得到一个值,相减做一个L1-Loss即可,然后我们回传这个loss去训练D。
为了满足Lipschitz条件(才能满足Kantorovich-Rubinstein对偶性),这里直接对w
进行了硬截断。
# dataset: mnist
import argparse
import os
import numpy as np
import math
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
import torch.nn as nn
import torch
from generator import Generator
from discriminator import Discriminator
os.makedirs("images_wgan", exist_ok=True)
parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.00005, help="learning rate")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--n_critic", type=int, default=5, help="number of training steps for discriminator per iter")
parser.add_argument("--clip_value", type=float, default=0.01, help="lower and upper clip value for disc. weights")
parser.add_argument("--sample_interval", type=int, default=400, help="interval betwen image samples")
opt = parser.parse_args()
print(opt)
img_shape = (opt.channels, opt.img_size, opt.img_size)
cuda = True if torch.cuda.is_available() else False
# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()
if cuda:
generator.cuda()
discriminator.cuda()
# Configure data loader
os.makedirs("./data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
datasets.MNIST(
"./data/mnist",
train=True,
download=True,
transform=transforms.Compose(
[transforms.Resize(opt.img_size),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5])] # [] means channel, 0.5,0.5 means mean & std
# => img = (img - mean) / 0.5 per channel
),
),
batch_size=opt.batch_size,
shuffle=True,
)
# Optimizers
optimizer_G = torch.optim.RMSprop(generator.parameters(), lr=opt.lr)
optimizer_D = torch.optim.RMSprop(discriminator.parameters(), lr=opt.lr)
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
# ----------
# Training
# ----------
batches_done = 0
for epoch in range(opt.n_epochs):
for i, (imgs, _) in enumerate(dataloader):
# Configure input
real_imgs = imgs.type(Tensor)
# ---------------------
# Train Discriminator
# ---------------------
optimizer_D.zero_grad()
# Sample noise as generator input
z = Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim)))
# Generate a batch of images
fake_imgs = generator(z).detach()
# Adversarial loss: 原始GAN使用的是交叉熵损失函数,而WGAN使用的是Wasserstein损失
loss_D = -torch.mean(discriminator(real_imgs)) + torch.mean(discriminator(fake_imgs)) # 真实图片的判别器输出的负均值+生成图片的判别器输出的均值
loss_D.backward()
optimizer_D.step()
# Clip weights of discriminator: 为了满足Lipschitz约束(函数的梯度必须小于等于1),需要对判别器的权重进行剪辑,以确保它们位于一个固定的很小的区间内
for p in discriminator.parameters():
p.data.clamp_(-opt.clip_value, opt.clip_value)
# Train the generator every n_critic iterations
if i % opt.n_critic == 0:
# -----------------
# Train Generator
# -----------------
optimizer_G.zero_grad()
# Generate a batch of images
gen_imgs = generator(z)
# Adversarial loss
loss_G = -torch.mean(discriminator(gen_imgs))
loss_G.backward()
optimizer_G.step()
print(
"[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
% (
epoch, opt.n_epochs, batches_done % len(dataloader), len(dataloader), loss_D.item(), loss_G.item())
)
if batches_done % opt.sample_interval == 0:
save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)
batches_done += 1
训练策略(Training Strategy):
opt.n_critic
控制,即每训练 n_critic
次判别器之后,才训练一次生成器。最后生成结果
可以看到每个数字都不一样了,即便是生成错误的也是不一样的错误。对比一下原始GAN的生成结果
而且WGAN还顺便解决了生成图像存在噪声(noise)这一问题。
以上就是WGAN作者Arjovsky 的两篇论文,对于原始GAN的先“破”后“立”。但是你以为这就完了吗?NoNo,Arjovsky 认为他在第二篇文章中“立”的新世界并不完美,于是便又发表了第三篇文章——WGAN III。欲知后事如何,请听下回分解~