x
0
~
q
(
x
0
)
x_0 \sim q(x_0)
x0?~q(x0?)是真实数据分布,扩散模型学习一个分布
p
θ
(
x
0
)
p_\theta(x_0)
pθ?(x0?)去逼近真实数据分布。
p
θ
(
x
0
)
:
=
∫
p
θ
(
x
0
:
T
)
d
x
1
:
T
(1)
p_\theta(x_0) := \int p_\theta(x_{0:T})dx_{1:T} \tag{1}
pθ?(x0?):=∫pθ?(x0:T?)dx1:T?(1)
x
1
,
.
.
.
,
x
T
x_1,...,x_T
x1?,...,xT?是和数据
x
0
~
q
(
x
0
)
x_0 \sim q(x_0)
x0?~q(x0?)相同维度的隐变量。联合概率分布
p
θ
(
x
0
:
T
)
p_\theta(x_{0:T})
pθ?(x0:T?)称为reverse process,逆过程,去噪过程。被定义为从
p
(
x
T
)
=
N
(
x
T
;
0
,
I
)
p(x_T)=N(x_T;\bold0,\bold I)
p(xT?)=N(xT?;0,I)开始的马尔可夫链,转移矩阵为高斯分布。
p
θ
(
x
0
:
T
)
:
=
p
(
x
T
)
∏
t
=
1
T
p
θ
(
x
t
?
1
∣
x
t
)
(2)
p_\theta(x_{0:T}) :=p(x_T)\prod_{t=1}^T p_\theta(x_{t-1}|x_t) \tag{2}
pθ?(x0:T?):=p(xT?)t=1∏T?pθ?(xt?1?∣xt?)(2)
p
θ
(
x
t
?
1
∣
x
t
)
:
=
N
(
x
t
?
1
;
μ
θ
(
x
t
,
t
)
,
Σ
θ
(
x
t
,
t
)
)
(3)
p_\theta(x_{t-1}|x_t) :=N(x_{t-1}; \mu_\theta(x_t, t), \Sigma_\theta(x_t, t)) \tag{3}
pθ?(xt?1?∣xt?):=N(xt?1?;μθ?(xt?,t),Σθ?(xt?,t))(3)
均值和方差是
x
t
,
t
x_t, t
xt?,t的函数,标准高斯分布有了均值和方差,就可以从
x
t
x_t
xt?中采样出
x
t
?
1
x_{t-1}
xt?1?。
diffusion模型不同于其他隐变量模型的地方在于,近似后验分布
q
(
x
1
:
T
∣
x
0
)
q(x_{1:T}|x_0)
q(x1:T?∣x0?),一般也被称为前向过程或者diffusion过程,是一个马尔可夫链。可以根据方差调度值
β
1
,
.
.
.
,
β
T
\beta_1,..., \beta_T
β1?,...,βT?逐步对数据
x
0
x_0
x0?加噪声。
q
(
x
1
:
T
∣
x
0
)
:
=
∏
t
=
1
T
q
(
x
t
∣
x
t
?
1
)
(4)
q(x_{1:T}|x_0) := \prod_{t=1}^Tq(x_t|x_{t-1}) \tag{4}
q(x1:T?∣x0?):=t=1∏T?q(xt?∣xt?1?)(4)
q
(
x
t
∣
x
t
?
1
)
:
=
N
(
x
t
;
,
1
?
β
t
x
t
?
1
,
β
t
I
)
(5)
q(x_t|x_{t-1}) := N(x_t;, \sqrt{1-\beta_t}x_{t-1}, \beta_t\bold I) \tag{5}
q(xt?∣xt?1?):=N(xt?;,1?βt??xt?1?,βt?I)(5)
我们定义:
a
t
:
=
1
?
β
t
,
a
ˉ
t
:
=
∏
s
=
1
t
α
s
(6)
a_t := 1 - \beta_t, \quad \bar{a}_t := \prod_{s=1}^{t} \alpha_s \tag{6}
at?:=1?βt?,aˉt?:=s=1∏t?αs?(6)
x
t
=
α
t
x
t
?
1
+
1
?
α
t
?
t
,
?
t
~
N
(
0
,
I
)
(7)
x_t = \sqrt{\alpha_t}x_{t-1} + \sqrt{1-\alpha_t}\epsilon_t, \quad \epsilon_t \sim N(\bold0, \bold I) \tag{7}
xt?=αt??xt?1?+1?αt???t?,?t?~N(0,I)(7)
x
t
?
1
=
α
t
?
1
x
t
?
2
+
1
?
α
t
?
1
?
t
?
1
,
?
t
?
1
~
N
(
0
,
I
)
(8)
x_{t-1} = \sqrt{\alpha_{t-1}}x_{t-2} + \sqrt{1-\alpha_{t-1}}\epsilon_{t-1}, \quad \epsilon_{t-1} \sim N(\bold0, \bold I) \tag{8}
xt?1?=αt?1??xt?2?+1?αt?1???t?1?,?t?1?~N(0,I)(8)
x
t
=
α
t
(
α
t
?
1
x
t
?
2
+
1
?
α
t
?
1
?
t
?
1
)
+
1
?
α
t
?
t
=
α
t
α
t
?
1
x
t
?
2
+
α
t
?
α
t
α
t
?
1
?
t
?
1
+
1
?
α
t
?
t
=
N
(
x
t
;
α
t
α
t
?
1
x
t
?
2
,
1
?
α
t
α
t
?
1
I
)
=
α
t
α
t
?
1
x
t
?
2
+
1
?
α
t
α
t
?
1
?
~
t
.
.
.
.
.
.
=
N
(
x
t
;
α
ˉ
t
x
0
,
1
?
α
ˉ
t
I
)
(9)
\begin{aligned} x_t &= \sqrt{\alpha_t}(\sqrt{\alpha_{t-1}}x_{t-2} + \sqrt{1-\alpha_{t-1}}\epsilon_{t-1}) + \sqrt{1-\alpha_t}\epsilon_t \\ &= \sqrt{\alpha_t\alpha_{t-1}}x_{t-2} + \sqrt{\alpha_t - \alpha_t\alpha_{t-1}}\epsilon_{t-1} + \sqrt{1-\alpha_t}\epsilon_t \\ &=N(x_t; \sqrt{\alpha_t\alpha_{t-1}}x_{t-2}, \sqrt{1-\alpha_t\alpha_{t-1}}\bold I) \\ &=\sqrt{\alpha_t\alpha_{t-1}}x_{t-2} + \sqrt{1-\alpha_t\alpha_{t-1}}\tilde{\epsilon}_t \\ & ...... \\ &= N(x_t; \sqrt{\bar{\alpha}_t}x_0, \sqrt{1 - \bar{\alpha}_t} \bold I) \tag{9} \end{aligned}
xt??=αt??(αt?1??xt?2?+1?αt?1???t?1?)+1?αt???t?=αt?αt?1??xt?2?+αt??αt?αt?1???t?1?+1?αt???t?=N(xt?;αt?αt?1??xt?2?,1?αt?αt?1??I)=αt?αt?1??xt?2?+1?αt?αt?1???~t?......=N(xt?;αˉt??x0?,1?αˉt??I)?(9)
这个性质很重要,意味着可以不需要迭代过程,直接获得任意时间t的加噪数据。正常来说T都比较大,DDPM设为1000,
a
t
=
1
?
β
t
∈
[
0
,
1
]
a_t = 1 - \beta_t \in [0, 1]
at?=1?βt?∈[0,1], 根据极限可知,随着t越来越大,最终加噪后的数据分布趋近于各向同性的标准高斯分布。也为reverse process从一个标准高斯分布采样开始逐步去噪得到最终sample的过程,两相契合。
forward process是加噪过程,也是训练过程,从数据集中采样
x
0
~
q
(
x
0
)
x_0 \sim q(x_0)
x0?~q(x0?),随机选取timestep t, 根据式(9)得到
x
t
x_t
xt?,
x
t
x_t
xt?和
t
t
t做为网络输入,估算后验分布
q
(
x
t
?
1
∣
x
t
,
x
0
)
q(x_{t-1}|x_t, x_0)
q(xt?1?∣xt?,x0?),假设后验分布为高斯分布,则估算的就是高斯分布的均值和方差,式(11)和(12)就是网络学习时,均值和方差的gt。DDPM这篇工作假设方差是预定义好的,不需要网络学习。只需要学习均值即可。
q
(
x
t
?
1
∣
x
t
,
x
0
)
=
N
(
x
t
?
1
;
μ
~
t
(
x
t
,
x
0
)
,
β
t
~
I
)
(10)
q(x_{t-1}|x_t, x_0) = N(x_{t-1}; \tilde{\mu}_t(x_t, x_0), \tilde{\beta_t}\bold I) \tag{10}
q(xt?1?∣xt?,x0?)=N(xt?1?;μ~?t?(xt?,x0?),βt?~?I)(10)
where
μ
~
t
(
x
t
,
x
0
)
:
=
α
ˉ
t
?
1
β
t
1
?
α
ˉ
t
x
0
+
α
t
(
1
?
α
ˉ
t
?
1
)
1
?
α
ˉ
t
x
t
(11)
\tilde{\mu}_t(x_t, x_0) :=\frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1-\bar{\alpha}_t}x_0 + \frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})} {1-\bar{\alpha}_t} x_t \tag{11}
μ~?t?(xt?,x0?):=1?αˉt?αˉt?1??βt??x0?+1?αˉt?αt??(1?αˉt?1?)?xt?(11)
and
β
~
t
:
=
1
?
α
ˉ
t
?
1
1
?
α
ˉ
t
β
t
(12)
\tilde{\beta}_t := \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t} \beta_t \tag{12}
β~?t?:=1?αˉt?1?αˉt?1??βt?(12)
(这个地方有空再来推导吧)
网络收敛后,就可以从
x
T
~
N
(
0
,
I
)
x_T\sim N(\bold 0, \bold I)
xT?~N(0,I)采样开始。逐步去噪,得到最终的样本。
网络学习和输出的是t时刻的噪声。根据下式得到均值:
μ
θ
(
x
t
,
t
)
=
1
α
t
(
x
t
?
β
t
1
?
α
ˉ
t
?
θ
(
x
t
,
t
)
)
\mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha}_t}(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon_\theta(x_t, t))
μθ?(xt?,t)=α?t?1?(xt??1?αˉt??βt???θ?(xt?,t))
采样
x
t
?
1
~
p
θ
(
x
t
?
1
∣
x
t
)
x_{t-1}\sim p_\theta(x_{t-1}|x_t)
xt?1?~pθ?(xt?1?∣xt?)可以通过
x
t
?
1
=
μ
θ
(
x
t
,
t
)
+
σ
z
x_{t-1}=\mu_\theta(x_t, t) + \sigma z
xt?1?=μθ?(xt?,t)+σz得到,
z
~
N
(
0
,
I
)
z\sim N(\bold 0, \bold I)
z~N(0,I)。
DDPM的优点就不说了,缺点主要有两个,推理过程步长太长,过于耗时。
β
\beta
β的设计导致加噪到T时刻,信噪比SNR不为0,加噪对原始数据分布破坏的不彻底,得到的不是真实的高斯分布噪声,原始数据分布中的一些低频信息泄露,导致文生图任务中,即便强prompt引导,生成的图片亮度也是围绕到0周围,无法产生过亮或者过暗的图片。
解决DDPM的步长问题。
进一步解决DDPM的步长问题。
解决常规 β \beta β调度策略无法产生zero SNR的问题。