相比于单步生成的模型(例如 GANs, VAEs, normalizing flows
),扩散模型的迭代式生成过程需要 10 到 2000 步计算来采样,导致推理速度低,实时性应用受限.
本文的目的是创造高效、单步的生成,同时不牺牲迭代采样的优势。在数据到噪声的 PF-ODE 轨迹上,学习轨迹上任意点到轨迹起点的映射,对这些映射的建模成为 consistency model.
两种训练 consistency model的方法
consistency model
.consistency model
.在一些数据集上测试.
使用 p d a t a ( x ) p_{data}(\mathrm{x}) pdata?(x)表示数据分布,扩散模型使用如下随机微分公式对服从原分布的数据进行扩散:
d x t = μ ( , x t , t ) + σ ( t ) d w t \large \mathrm{dx}_t = \mu(\mathrm,{x}_t, t) + \sigma(t)\mathrm{dw}_t dxt?=μ(,xt?,t)+σ(t)dwt?
其中 t t t为时间步,范围是 0 0 0 到 T T T, μ ( ? , ? ) \mu(·,·) μ(?,?) 和 σ ( ? ) \sigma(·) σ(?)分别是布朗运动中的漂移系数和扩散系数, x t \mathbf{x}_t xt?服从分布 p t ( x ) p_{t}(\mathrm{x}) pt?(x), x 0 \mathrm{x}_0 x0?服从分布 p d a t a ( x ) p_{data}(\mathrm{x}) pdata?(x). 该方程的一个重要属性是,其存在一个 PF-ODE 方程:
d x t = [ μ ( x t , t ) ? 1 2 σ ( t ) 2 ? log ? p t ( x t ) ] d t \large\mathrm{dx}_t = \left[ \mu(\mathrm{x}_t, t)-\frac{1}{2}\sigma(t)^2 \nabla\log{p_t(\mathrm{x}_t)} \right]\mathrm{d}t dxt?=[μ(xt?,t)?21?σ(t)2?logpt?(xt?)]dt
其中
?
log
?
p
t
(
x
)
\nabla\log{p_t(\mathrm{x})}
?logpt?(x)是
p
t
(
x
)
p_t(\mathrm{x})
pt?(x)的 score function.
在 SDE 中,令漂移系数
μ
(
x
,
t
)
=
0
\mu(\mathrm{x}, t) = 0
μ(x,t)=0, 扩散系数
σ
(
t
)
=
2
t
\sigma(t) = \sqrt{2t}
σ(t)=2t?. 使用得分匹配的方式训练模型
s
?
(
x
,
t
)
≈
?
log
?
p
t
(
x
)
s_{\phi}(\mathrm{x},t) \approx \nabla\log{p_t(\mathrm{x})}
s??(x,t)≈?logpt?(x),代入 PF-ODE 方程,得到 empirical PF-ODE:
d x t d t = ? t s ? ( x t , t ) \large \frac{\mathrm{dx}_t}{\mathrm{d}t}=-ts_{\phi}(\mathrm{x}_t,t) dtdxt??=?ts??(xt?,t)
采样时,使用 x ^ T ~ N ( 0 , T 2 I ) \hat{\mathrm{x}}_T\sim\mathcal{N}(0, T^2I) x^T?~N(0,T2I)初始化,再使用 numerical ODE solver(例如 Euler, Heun)按时间步倒推出 x ^ 0 \hat{x}_0 x^0?. 为了防止数值不稳定,会在 t = ? t=\epsilon t=?是提前终止, ? \epsilon ?为一个正小数,同时将 x ^ ? \hat{\mathrm{x}}_{\epsilon} x^??作为结果.
扩散模型的瓶颈在于采样速度慢, ODE solver 利用得分模型 s ? ( x , t ) s_{\phi}(\mathrm{x},t) s??(x,t)迭代求解,消耗算力多. 目前存在一些更快的 ODE solver,但是仍然需要大于 10 10 10 步的采样. 也存在一些蒸馏方法,但是大多数方法需要从扩散模型中采集巨大的数据集,同样消耗算力多.
根据 PF-ODE 得到一条解路径 { x t } t ∈ [ ? , T ] \{\mathrm{x}_t\}_{t\in[\epsilon, T]} {xt?}t∈[?,T]?,将 consistency function 定义为:
f : ( x t , t ) ? x ? \large f:(\mathrm{x}_t, t) \mapsto \mathrm{x}_{\epsilon} f:(xt?,t)?x??
对于该路径上的任意点
(
x
t
,
t
)
(\mathrm{x}_t, t)
(xt?,t),其输出是一致的. 对于任意的
t
,
t
′
∈
[
?
,
T
]
t, t' \in [\epsilon, T]
t,t′∈[?,T],有
f
(
x
t
,
t
)
=
f
(
x
t
′
,
t
′
)
f(\mathrm{x}_t, t) =f(\mathrm{x}_{t'}, t')
f(xt?,t)=f(xt′?,t′)恒成立.
令 F θ ( x , t ) F_{\theta}(\mathrm{x}, t) Fθ?(x,t)表示任意形式的神经网络,使用 sikp connection 可以将模型表示为:
f θ ( x , t ) = c s k i p ( t ) x + c o u t ( t ) F θ ( x , t ) \large f_{\theta}(\mathrm{x}, t)=c_{skip}(t)\mathrm{x}+c_{out}(t)F_{\theta}(\mathrm{x},t) fθ?(x,t)=cskip?(t)x+cout?(t)Fθ?(x,t)
其中边界条件为
c
s
k
i
p
(
?
)
=
1
c_{skip}(\epsilon)=1
cskip?(?)=1,
c
o
u
t
(
?
)
=
0
c_{out}(\epsilon)=0
cout?(?)=0.
具体为:
c s k i p ( t ) = σ d a t a 2 ( t ? ? ) 2 + σ d a t a 2 \large c_{skip}(t)=\frac{\sigma_{data}^2}{(t-\epsilon)^2+\sigma_{data}^2} cskip?(t)=(t??)2+σdata2?σdata2??
c o u t ( t ) = σ d a t a ( t ? ? ) σ d a t a 2 + t 2 \large c_{out}(t)=\frac{\sigma_{data}(t-\epsilon)}{\sqrt{\sigma_{data}^2+t^2}} cout?(t)=σdata2?+t2?σdata?(t??)?
σ d a t a \sigma_{data} σdata?取值 0.5 0.5 0.5.
有了一个训练好的 consistency model f θ ( ? , ? ) f_{\theta}(·, ·) fθ?(?,?)之后,从高斯噪声 N ( 0 , T 2 I ) \mathcal{N}(0, T^2I) N(0,T2I)采样 x ^ T \hat{\mathrm{x}}_T x^T?,再代入模型一步推出 x ^ ? = f θ ( x T ^ , T ) \hat{\mathrm{x}}_{\epsilon}=f_{\theta}(\hat{\mathrm{x}_T}, T) x^??=fθ?(xT?^?,T).为了提高质量,也可以进行多步采样,算法如下:
作者的第一个方法是在预训练的得分模型 s ? ( x , t ) s_{\phi}(\mathrm{x},t) s??(x,t)上蒸馏.
首先考虑将 ? \epsilon ?到 T T T的时间离散化成 N ? 1 N-1 N?1 个间隔,也即 t 1 = ? < t 2 < t 3 < . . . < t N = T t_1=\epsilon<t_2<t_3<...<t_N=T t1?=?<t2?<t3?<...<tN?=T. 在实践中,使用如下公式:
t i = ( ? 1 / ρ + i ? 1 N ? 1 ( T 1 / ρ ? ? 1 / ρ ) ) ρ \large t_i=\left(\epsilon^{1/\rho} + \frac{i-1}{N-1}\left(T^{1/\rho}-\epsilon^{1/\rho}\right) \right)^{\rho} ti?=(?1/ρ+N?1i?1?(T1/ρ??1/ρ))ρ
其中 ρ = 7 \rho=7 ρ=7. 当 N N N充分大时,可以获得 x t n \mathrm{x}_{t_n} xtn??到 x t n + 1 \mathrm{x}_{t_{n+1}} xtn+1??的准确估计,于是 x ^ t n ? \hat{\mathrm{x}}_{t_n}^{\phi} x^tn???可以定义为:
x ^ t n ? = x t n + 1 + ( t n ? t n + 1 ) Φ ( x t n + 1 , t n + 1 ; ? ) \large \hat{\mathrm{x}}_{t_n}^{\phi}=\mathrm{x}_{t_{n+1}} + (t_n-t_{n+1})\Phi(\mathrm{x}_{t_{n+1}}, t_{n+1};\phi) x^tn???=xtn+1??+(tn??tn+1?)Φ(xtn+1??,tn+1?;?)
Φ ( . . . ; ? ) \Phi(...;\phi) Φ(...;?)为 one-step ODE solver(比如Euler).
从数据集中采样 x \mathrm{x} x,通过 SDE 加噪 N ( x , t n + 1 2 I ) \mathcal{N}(\mathrm{x}, t_{n+1}^2I) N(x,tn+12?I)得到 x t n + 1 \mathrm{x}_{t_{n+1}} xtn+1??, 然后使用 ODE solver 求解出 x ^ t n ? \hat{\mathrm{x}}_{t_n}^{\phi} x^tn???,通过最小化在 x ^ t n ? \hat{\mathrm{x}}_{t_n}^{\phi} x^tn??? 和 x t n + 1 \mathrm{x}_{t_{n+1}} xtn+1??计算结果的差距训练模型.
Definition 1
consistency distillation loss (CD)
表示为:
L C D N ( θ , θ ? ; ? ) = E [ λ ( t n ) d ( f θ ( x t n + 1 , t n + 1 ) , f θ ? ( x ^ t n ? , t n ) ] \large \mathcal{L}_{CD}^{N}(\theta, \theta^-;\phi)=\mathbb{E}\left[\lambda(t_n)d(f_{\theta}(\mathrm{x}_{t_{n+1}},t_{n+1}),f_{\theta^-}(\hat{\mathrm{x}}_{t_n}^{\phi}, t_n) \right] LCDN?(θ,θ?;?)=E[λ(tn?)d(fθ?(xtn+1??,tn+1?),fθ??(x^tn???,tn?)]
其中, λ ( ? ) ∈ R + \lambda(·)\in\mathbb{R}^+ λ(?)∈R+是正权重函数, θ ? \theta^- θ?是 θ \theta θ在优化过程中历史值的均值. d ( ? , ? ) d(·,·) d(?,?)是一个度量函数,满足当且仅当两个输入相等时为 0 0 0,其余情况大于 0 0 0.
作者考虑 d ( ? , ? ) d(·,·) d(?,?) 使用 l 1 l_1 l1? 以及 l 2 l_2 l2?,在实验中 λ ( t n ) ≡ 1 \lambda(t_n) \equiv1 λ(tn?)≡1表现较好. θ ? \theta^- θ?使用 EMA 更新,计算公式如下:
θ ? ← s t o p g a r d ( μ θ ? + ( 1 ? μ ) θ ) \large \theta^- \leftarrow \mathrm{stopgard}(\mu\theta^-+(1-\mu)\theta) θ?←stopgard(μθ?+(1?μ)θ)
其中
0
≤
μ
<
1
0\le\mu<1
0≤μ<1. 使用 EMA 可以使训练更稳定,同时能提高模型的表现.
模型训练算法如下:
consistency model
可以不依赖预训练扩散模型训练,使用如下无偏估计替换
?
log
?
p
t
(
x
)
\nabla\log{p_t(\mathrm{x})}
?logpt?(x):
? log ? p t ( x ) = ? E [ x t ? x t 2 ∣ x t ] \large \nabla\log{p_t(\mathrm{x})}=-\mathbb{E}\left[\left.\frac{\mathrm{x}_t-\mathrm{x}}{t^2}\right|\mathrm{x}_t \right] ?logpt?(x)=?E[t2xt??x? ?xt?]
consistency training loss (CT)
表示为:
L C D N ( θ , θ ? ) = E [ λ ( t n ) d ( f θ ( x + t n + 1 z , t n + 1 ) , f θ ? ( x + t n z , t n ) ] \large \mathcal{L}_{CD}^{N}(\theta, \theta^-)=\mathbb{E}\left[\lambda(t_n)d(f_{\theta}(\mathrm{x}+t_{n+1}\mathrm{z},t_{n+1}),f_{\theta^-}(\mathrm{x}+t_{n}\mathrm{z},t_{n}) \right] LCDN?(θ,θ?)=E[λ(tn?)d(fθ?(x+tn+1?z,tn+1?),fθ??(x+tn?z,tn?)]
其中 z ~ N ( 0 , I ) \mathrm{z}\sim\mathcal{N}(0,I) z~N(0,I). 损失函数的计算依赖于 f θ f_{\theta} fθ?和 f θ ? f_{\theta^-} fθ??,且与扩散模型的无关.
为了提升模型效果,使用 schedule function
N
(
?
)
N(·)
N(?)控制
N
N
N 增长. 直觉上,当
N
N
N 小的时候,使用 consistency distillation loss
模型在一开始收敛更快,同时方差小、偏差大. 反之,在训练结束时,应当使
N
N
N 大,这样方差大、偏差小。同时,使用 schedule function
μ
(
?
)
\mu(·)
μ(?)替换
μ
\mu
μ,让它随着
N
N
N 增长而变化.
N
(
?
)
N(·)
N(?)和
μ
(
?
)
\mu(·)
μ(?)具体为
N ( k ) = ? k K ( ( s 1 + 1 ) 2 ? s 0 2 ) + s 0 2 ? 1 ? + 1 \large N(k)= \left\lceil\sqrt{\frac{k}{K}((s_1+1)^2-s_0^2)+s_0^2}-1 \right\rceil+1 N(k)= ?Kk?((s1?+1)2?s02?)+s02???1 ?+1
μ ( k ) = exp ? ( s 0 log ? μ 0 N ( k ) ) \large \mu(k)=\exp\left(\frac{s_0\log{\mu_0}}{N(k)}\right) μ(k)=exp(N(k)s0?logμ0??)
K K K表示整体训练步数, s 0 s_0 s0?表示开始的离散化步数.
训练算法如下:
关于 CD ,作者分别使用
l
1
l_1
l1?,
l
2
l_2
l2?,
L
P
I
P
S
\mathrm{LPIPS}
LPIPS作为度量函数,使用一阶Euler和二阶Heun座位 ODE solver,
N
N
N取
{
9
,
12
,
18
,
36
,
50
,
60
,
80
,
120
}
\{9,12,18,36,50,60,80,120\}
{9,12,18,36,50,60,80,120},使用相应的预训练扩散模型做初始化. 使用 CT 训练的模型则随机初始化.
(a) 对比不同的度量函数在 CD
上的表现,其中 LPIPS 的效果最好.
(b, c) 对不不同 ODE solver 和
N
N
N在 CD
上的表现,使用 Heun 且
N
N
N取
18
18
18时效果最好.在取相同的
N
N
N时,二阶Heun的表现优于一阶Euler,因为高阶的 ODE solver 的估计误差更小. 当
N
N
N充分大时,模型对
N
N
N变得不敏感.
(d) 根据之前的结论,关于 CT
的实验使用 LPIPS 作为度量函数. 更小的
N
N
N收敛更快,但是采样结构更差;使用自适应的
N
(
?
)
N(·)
N(?)和
μ
(
?
)
\mu(·)
μ(?)效果最好.
对比 CD
和 progressive disillation(PD)
在不同数据集上的效果,CD
的表现普遍比 PD
好.
对比 CT
和其它生成模型,仅使用一步或两步生成.
Zero-Shot Image Editing