[2205.14135] FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
中文:FlashAttention:一种具有 IO 感知,且兼具快速、内存高效的新型注意力算法
科研团队:斯坦福大学计算机系+纽约州立大学布法罗分校
发表时间:20220527
结合背景123,可以发现更大的模型在更长的上下文背景上还存在着一定的挑战。
背景4:计算读写开销:论文GPU内不同存储系统的速度举例如下:
在此背景之下,有人提出一些近似自注意力的方法,旨在减少注意力计算和内存需求。
缺点:尽管这些方法可以将计算降低到线性或接近线性,但它们过于关注降低每秒所执行的浮点运算次数(FLops),换句话说更倾向于单纯降低计算复杂度。忽略来自内存访问(IO)的开销。不能实现更高且更有实用价值的计算加速范式。
(更详细的推导过程和描述可以参考前文)
Attention机制
其核心为计算输入向量的相关程度
,例如在翻译过程中,不同的英文对中文的依赖程度不同,Attention机制
通常可以进行如下描述
?Attention? ( Q , K , V ) ∈ R m × d v = softmax ? ( Q K T d k ) V \text { Attention }(Q, K, V) \in\mathbb R^{m \times d_{v}}=\operatorname{softmax}\left(\frac{Q K^{T}}{\sqrt{d_{k}}}\right) V ?Attention?(Q,K,V)∈Rm×dv?=softmax(dk??QKT?)V
为什么可以计算得到不同输入向量之间的得分
矩阵乘法
假设共有十个输入向量,每个向量的长度为512,也即为 m = 10 m=10 m=10, d k = 512 d_k=512 dk?=512
Q = ( q 1 [ 0 ] ? q 1 [ d k ] ? ? ? q 10 [ 0 ] ? q 10 [ 511 ] ) = ( q 1 ? ? q 10 ? ) Q=\left(\begin{array}{ccc} q_{1}[0] & \cdots & q_{1}[d_k] \\ \vdots & \cdots & \vdots \\ q_{10}[0] & \cdots & q_{10}[511] \end{array}\right) = \left(\begin{array}{c}\vec{q_{1}}\\\vdots\\ \vec{q_{10}} \end{array}\right) Q= ?q1?[0]?q10?[0]?????q1?[dk?]?q10?[511]? ?= ?q1???q10??? ?
K = ( k 1 [ 0 ] ? k 1 [ 511 ] ? ? ? k 10 [ 0 ] ? k 10 [ 511 ] ) = ( k 1 ? ? k 10 ? ) K=\left(\begin{array}{ccc}k_{1}[0] & \cdots & k_{1}[511] \\\vdots & \cdots & \vdots \\k_{10}[0] & \cdots & k_{10}[511]\end{array}\right) = \left(\begin{array}{c}\vec{k_{1}}\\\vdots\\ \vec{k_{10}} \end{array}\right) K= ?k1?[0]?k10?[0]?????k1?[511]?k10?[511]? ?= ?k1???k10??? ?
相乘结果如下
Q
?
K
T
∈
R
m
×
m
=
(
q
1
?
?
q
10
?
)
?
(
k
1
?
T
?
k
10
?
T
)
(
q
1
?
?
k
1
?
T
?
q
1
?
?
k
10
?
T
?
?
?
q
10
?
?
k
1
?
T
?
q
10
?
?
k
10
?
T
)
=
(
s
1
?
1
?
s
1
?
10
?
?
?
s
10
?
1
?
s
10
?
10
)
Q \cdot K^T \in \mathbf{R}^{m\times m}= \left(\begin{array}{c}\vec{q_{1}}\\\vdots\\ \vec{q_{10}} \end{array}\right) \cdot \left(\vec{k_{1}}^T\cdots \vec{k_{10}}^T\right) \left(\begin{array}{ccc} \vec{q_{1}}\cdot\vec{k_{1}}^T & \cdots & \vec{q_{1}}\cdot\vec{k_{10}}^T \\\vdots & \cdots & \vdots \\\vec{q_{10}}\cdot\vec{k_{1}}^T& \cdots & \vec{q_{10}}\cdot\vec{k_{10}}^T\end{array}\right) =\left(\begin{array}{ccc}s_{1-1} & \cdots & s_{1-10} \\\vdots & \cdots & \vdots \\s_{10-1} & \cdots & s_{10-10}\end{array}\right)
Q?KT∈Rm×m=
?q1???q10???
??(k1??T?k10??T)
?q1???k1??T?q10???k1??T?????q1???k10??T?q10???k10??T?
?=
?s1?1??s10?1??????s1?10??s10?10??
?
矩阵 S S S中的每一个元素通过分别来自于 Q \mathbf{Q} Q 和 K \mathbf{K} K的两个向量的点乘得到的,通过最原始的矩阵定义,可以得知两个向量的点乘意味着一个向量在另一个向量的投影,也可以李继伟表示向量 q i ? \vec{q_{i}} qi??, k j ? \vec{k_j} kj??的相似程度
除以一个数字 d k \sqrt{d_{k}} dk??的意义是:
softmax函数
被推动到非常平缓的区域。将得到scaling后的相似度进行Softmax
操作,假定Scaling之后相似度矩阵为
(
s
1
?
1
′
?
s
1
?
m
′
?
?
?
s
m
?
1
′
?
s
m
?
m
′
)
=
(
s
1
?
1
/
d
k
?
s
1
?
m
/
d
k
?
?
?
s
m
?
1
/
d
k
?
s
m
?
m
/
d
k
)
\left(\begin{array}{ccc}s'_{1-1} & \cdots & s'_{1-m} \\\vdots & \cdots & \vdots \\ s'_{m-1} & \cdots & s'_{m-m}\end{array}\right) = \left(\begin{array}{ccc}s_{1-1}/\sqrt{d_{k}} & \cdots & s_{1-m}/\sqrt{d_{k}} \\\vdots & \cdots & \vdots \\s_{m-1}/\sqrt{d_{k}} & \cdots & s_{m-m}/\sqrt{d_{k}}\end{array}\right)
?s1?1′??sm?1′??????s1?m′??sm?m′??
?=
?s1?1?/dk???sm?1?/dk???????s1?m?/dk???sm?m?/dk???
?
进行归一化
(
s
1
?
1
′
′
?
s
1
?
m
′
′
?
?
?
s
m
?
1
′
′
?
s
m
?
m
′
)
=
(
e
s
1
?
1
′
∑
i
=
1
m
e
s
1
?
i
′
?
e
s
1
?
m
′
∑
i
=
1
m
e
s
1
?
i
′
?
?
?
e
s
m
?
1
′
∑
i
=
1
m
e
s
m
?
i
′
?
e
s
m
?
m
′
∑
i
=
1
m
e
s
m
?
i
′
)
\left(\begin{array}{ccc}s''_{1-1} & \cdots & s''_{1-m} \\\vdots & \cdots & \vdots \\ s''_{m-1} & \cdots & s'_{m-m}\end{array}\right) = \left(\begin{array}{ccc}\frac{e^{s'_{1-1}}} {\sum_{i=1}^{m} e^{s'_{1-i}} } & \cdots & \frac{e^{s'_{1-m}}} {\sum_{i=1}^{m} e^{s'_{1-i}} } \\\vdots & \cdots & \vdots \\ \frac{e^{s'_{m-1}}} {\sum_{i=1}^{m} e^{s'_{m-i}} } & \cdots & \frac{e^{s'_{m-m}}} {\sum_{i=1}^{m} e^{s'_{m-i}} } \end{array}\right)
?s1?1′′??sm?1′′??????s1?m′′??sm?m′??
?=
?∑i=1m?es1?i′?es1?1′???∑i=1m?esm?i′?esm?1′???????∑i=1m?es1?i′?es1?m′???∑i=1m?esm?i′?esm?m′???
?
如此实现一横行的加权和为1,不同的 v i v_i vi? 向量获得的加权综合为1
针对计算出来的权重
α
i
\alpha_{i}
αi?,通过权重对
V
V
V中所有的values
进行加权求和计算,得到Attention向量
(
s
1
?
1
′
?
s
1
?
m
′
?
?
?
s
m
?
1
′
?
s
m
?
m
′
)
(
v
1
?
?
v
m
?
)
\left(\begin{array}{ccc}s'_{1-1} & \cdots & s'_{1-m} \\\vdots & \cdots & \vdots \\ s'_{m-1} & \cdots & s'_{m-m}\end{array}\right)\left(\begin{array}{c}\vec{v_{1}}\\\vdots\\ \vec{v_{m}} \end{array}\right)
?s1?1′??sm?1′??????s1?m′??sm?m′??
?
?v1???vm???
?
#########Standard Attention Implementation
Require: Matrices Q, K, V ∈ R^{N×d} in HBM.
1: Load Q, K by blocks from HBM, compute S = QK^{T}, write S to HBM.
2: Read S from HBM, compute P = softmax(S), write P to HBM.
3: Load P and V by blocks from HBM, compute O = PV, write O to HBM.
4: Return O.
attention抽象
为对value
的每个表示(token
)进行加权,而加权的weight
就是 attention weight
,而 attention weight
就是根据 query
和 key
计算得到,其意义为:为了用 value
求出 query
的结果, 根据 query
和 key
来决定注意力应该放在value
的哪部分。在标准注意力实现中,注意力的性能主要受限于内存带宽,是内存受限的。频繁地从HBM中读写 R N × N \mathbb{R}^{N \times N} RN×N的矩阵是影响性能的主要瓶颈。稀疏近似和低秩近似等近似注意力方法虽然减少了计算量FLOPs,但对于内存受限的操作,运行时间的瓶颈是从HBM中读写数据的耗时,减少计算量并不能有效地减少运行时间(wall-clock time)。针对内存受限的标准注意力,Flash Attention是IO感知的,目标是避免频繁地从HBM中读写数据。
从GPU显存分级来看,SRAM的读写速度比HBM高一个数量级,但内存大小要小很多。通过kernel融合的方式,将多个操作融合为一个操作,利用高速的SRAM进行计算,可以减少读写HBM的次数,从而有效减少内存受限操作的运行时间。但SRAM的内存大小有限,不可能一次性计算完整的注意力,因此必须进行分块计算,使得分块计算需要的内存不超过SRAM的大小。
问题一:为什么要进行分块计算呢?
内存受限 --> 减少HBM读写次数 --> kernel融合 --> 满足SRAM的内存大小 --> 分块计算
因此分块大小block_size不能太大,否则会导致存储内容踢出。
问题二:分块计算的难点是什么呢?
注意力机制的计算过程是“矩阵乘法 --> scale --> mask --> softmax --> dropout --> 矩阵乘法”,矩阵乘法和逐点操作(scale,mask,dropout)的分块计算是容易实现的,难点在于softmax的分块计算。由于计算softmax的归一化因子(分母)时,需要获取到完整的输入数据,进行分块计算的难度比较大。论文中也是重点对softmax的分块计算进行了阐述。
tiling的主要思想是分块计算注意力。分块计算的难点在于softmax的分块计算,softmax与矩阵 K K K 的列是耦合的,通过引入了两个额外的统计量 m ( x ) m(x) m(x), l ( x ) l(x) l(x)来进行解耦,实现了分块计算。需要注意的是,可以利用GPU多线程同时并行计算多个block的softmax。为了充分利用硬件性能,多个block的计算不是串行(sequential)的, 而是并行的。
对向量 A = [ 1 , 2 , 3 , 4 ] A = [1,2,3,4] A=[1,2,3,4] 计算Softmax,分成两块 A 1 = [ 1 , 2 ] A_1 = [1,2] A1?=[1,2] 和 A 2 = [ 3 , 4 ] A_2 = [3,4] A2?=[3,4] 进行计算。 计算block1和block2:
block1
m
1
=
m
a
x
(
[
1
,
2
]
)
=
2
f
1
=
[
e
1
?
m
1
,
e
2
?
m
1
]
=
[
e
?
1
,
e
0
]
l
1
=
∑
f
1
=
e
?
1
+
e
0
o
1
=
f
1
l
1
=
[
e
?
1
,
e
0
]
e
?
1
+
e
0
=
[
e
?
1
e
?
1
+
e
0
,
e
0
e
?
1
+
e
0
]
m_1 = max([1,2]) = 2\\ f_1 = [e^{1-m_1},e^{2-m_1}] = [e^{-1},e^0]\\ l_1 = \sum f_1 = e^{-1} + e^0\\ o_1 = \frac{f_1}{l_1} = \frac{[e^{-1},e^0]}{e^{-1} + e^0} = \left[ \frac{e^{-1}}{e^{-1} + e^0}, \frac{e^0}{e^{-1} + e^0}\right]
m1?=max([1,2])=2f1?=[e1?m1?,e2?m1?]=[e?1,e0]l1?=∑f1?=e?1+e0o1?=l1?f1??=e?1+e0[e?1,e0]?=[e?1+e0e?1?,e?1+e0e0?]
block2
m
2
=
m
a
x
(
[
3
,
4
]
)
=
4
f
2
=
[
e
3
?
m
2
,
e
4
?
m
2
]
=
[
e
?
1
,
e
0
]
l
2
=
∑
f
2
=
e
?
1
+
e
0
o
2
=
f
2
l
2
=
[
e
?
1
,
e
0
]
e
?
1
+
e
0
=
[
e
?
1
e
?
1
+
e
0
,
e
0
e
?
1
+
e
0
]
m_2 = max([3,4]) = 4\\ f_2 = [e^{3-m_2},e^{4-m_2}] = [e^{-1},e^0]\\ l_2 = \sum f_2 = e^{-1} + e^0\\ o_2 = \frac{f_2}{l_2} = \frac{[e^{-1},e^0]}{e^{-1} + e^0} = \left[ \frac{e^{-1}}{e^{-1} + e^0}, \frac{e^0}{e^{-1} + e^0}\right]
m2?=max([3,4])=4f2?=[e3?m2?,e4?m2?]=[e?1,e0]l2?=∑f2?=e?1+e0o2?=l2?f2??=e?1+e0[e?1,e0]?=[e?1+e0e?1?,e?1+e0e0?]
合并得到完整的softmax结果:
m
=
m
a
x
(
m
a
x
1
,
m
a
x
2
)
=
4
f
=
[
e
m
1
?
m
f
1
,
e
m
2
?
m
?
f
2
]
=
[
e
?
3
,
e
?
2
,
e
?
1
,
e
0
]
l
=
e
m
1
?
m
l
1
,
e
m
2
?
m
?
l
2
=
e
?
3
+
e
?
2
+
e
?
1
+
e
0
o
=
f
l
=
[
e
?
1
,
e
0
]
e
?
1
+
e
0
=
[
e
?
1
e
?
1
+
e
0
,
e
0
e
?
1
+
e
0
]
m = max(max_1,max_2) = 4\\ f = \left[e^{m_1-m}f_1,e^{m_2-m}*f_2\right] = \left[e^{-3},e^{-2},e^{-1},e^0\right]\\ l = e^{m_1-m}l_1,e^{m_2-m}*l_2 = e^{-3}+e^{-2}+e^{-1}+e^0\\ o = \frac{f}{l} = \frac{[e^{-1},e^0]}{e^{-1} + e^0} = \left[ \frac{e^{-1}}{e^{-1} + e^0}, \frac{e^0}{e^{-1} + e^0}\right]
m=max(max1?,max2?)=4f=[em1??mf1?,em2??m?f2?]=[e?3,e?2,e?1,e0]l=em1??ml1?,em2??m?l2?=e?3+e?2+e?1+e0o=lf?=e?1+e0[e?1,e0]?=[e?1+e0e?1?,e?1+e0e0?]
备注:这是在在忽略mask和dropout的情况下,简化分析Flash Attention算法的前向计算过程
在Flash Attention的前向计算算法中可以看出,FlashAttention实现在不访问整个输入的情况下计算softmax,实现IO的较大缩减,标准Attention算法由于要计算softmax,而softmax都是按行来计算的,即在和 V \mathbf{V} V做矩阵乘之前,需要让 Q \mathbf{Q} Q、 K \mathbf{K} K 的各个分块完成整一行分块的计算得到Softmax的结果后,再和矩阵 V \mathbf{V} V分块做矩阵乘。而在Flash Attention中,将输入分割成块,并在输入块上进行多次传递,从而以增量方式执行softmax缩减。
将前文的前向计算抽象成如下模型,便于后文的引用
S
=
τ
Q
K
?
∈
R
N
×
N
S
masked?
=
M
A
S
K
(
S
)
∈
R
N
×
N
P
=
softmax
?
(
S
masked?
)
∈
R
N
×
N
P
dropped?
=
dropout
?
(
P
,
p
drop?
)
∈
R
N
×
N
O
=
P
dropped?
V
∈
R
N
×
d
\begin{gathered} S=\tau Q K^{\top} \in \mathbb{R}^{N \times N} \\ S^{\text {masked }}=M A S K(S) \in \mathbb{R}^{N \times N} \\ P=\operatorname{softmax}\left(S^{\text {masked }}\right) \in \mathbb{R}^{N \times N} \\ P^{\text {dropped }}=\operatorname{dropout}\left(P, p_{\text {drop }}\right) \in \mathbb{R}^{N \times N} \\ O=P^{\text {dropped }} V \in \mathbb{R}^{N \times d} \end{gathered}
S=τQK?∈RN×NSmasked?=MASK(S)∈RN×NP=softmax(Smasked?)∈RN×NPdropped?=dropout(P,pdrop??)∈RN×NO=Pdropped?V∈RN×d?
在标准注意力实现中,后向传递计算
Q
\mathbf{Q}
Q,
K
\mathbf{K}
K,
V
\mathbf{V}
V的梯度时,需要用到中间矩阵
S
∈
R
N
×
N
\mathbf{S}\in\mathbb{R}^{N\times N}
S∈RN×N,
P
∈
R
N
×
N
\mathbf{P}\in\mathbb{R}^{N\times N}
P∈RN×N。Flash Attention没有保存这两个矩阵,而是保存了两个统计量
m
(
x
)
m(x)
m(x),
l
(
x
)
l(x)
l(x),在后向传递时进行重计算。
在反向传递过程中, 需要计算损失函数 ? \phi ? 对 O \mathbf{O} O, Q \mathbf{Q} Q, K \mathbf{K} K, V \mathbf{V} V 的梯度。在给定 d O ∈ R N × d d \mathbf{O} \in \mathbb{R}^{N \times d} dO∈RN×d 的情况下, 计算梯度 d Q ∈ R N × d d\mathbf{Q}\in \mathbb{R}^{N \times d} dQ∈RN×d, d K ∈ R N × d d\mathbf{K}\in \mathbb{R}^{N \times d} dK∈RN×d, d V ∈ R N × d d\mathbf{V} \in \mathbb{R}^{N \times d} dV∈RN×d 。其中, d O d\mathbf{O} dO, d Q d\mathbf{Q} dQ, d K d\mathbf{K} dK, d V d\mathbf{V} dV 分别表示为 ? ? ? O \frac{\partial \phi}{\partial \mathbf{O}} ?O???, ? ? ? Q \frac{\partial \phi}{\partial \mathbf{Q}} ?Q???, ? ? ? K \frac{\partial \phi}{\partial \mathbf{K}} ?K???, ? ? ? V \frac{\partial \phi}{\partial \mathbf{V}} ?V???
梯度
d
V
d\mathbf{V}
dV 是容易计算的。由
O
=
P
V
\mathbf{O}=\mathbf{P} \mathbf{V}
O=PV,基于矩阵求导算法和链式法则, 得到矩阵形式的梯度
d
V
=
P
?
d
O
d\mathbf{V}=\mathbf{P}^{\top} d \mathbf{O}
dV=P?dO 。在元素形式上,有:
d
v
j
=
∑
i
P
i
j
d
o
i
=
∑
i
e
(
q
i
?
k
j
)
L
i
d
o
i
d \mathbf{v}_j=\sum_i \mathbf{P}_{i j} d \mathbf{o}_i=\sum_i \frac{e^{(\mathbf{q}_i^{\top} k_j)}}{L_i} d \mathbf{o}_i
dvj?=i∑?Pij?doi?=i∑?Li?e(qi??kj?)?doi?
之前已经计算好
L
i
L_i
Li?,就可以通过反复累加的方式计算得到
d
v
j
d \mathbf{v}_j
dvj? 。
梯度
d
Q
d\mathbf{Q}
dQ,
K
\mathbf{K}
K 的计算是略微复杂的。首先要计算
d
P
d\mathbf{P}
dP,
d
S
d\mathbf{S}
dS 。由
O
=
P
V
\mathbf{O}=\mathbf{P} \mathbf{V}
O=PV,得到矩阵形式的梯度
d
P
=
d
O
V
?
d\mathbf{P}=d\mathbf{O} \mathbf{V}^{\top}
dP=dOV? 。在元素形式上,有:
d
P
i
j
=
d
o
i
?
v
j
d \mathbf{P}_{i j}=d \mathbf{o}_i^{\top} \mathbf{v}_j
dPij?=doi??vj?
有
P
i
:
=
softmax
?
(
S
i
:
)
\mathbf{P}_{i:}=\operatorname{softmax}\left(\mathbf{S}_{i:}\right)
Pi:?=softmax(Si:?) (表示
i
i
i的一整行)。基于
y
=
softmax
?
(
x
)
y=\operatorname{softmax}(x)
y=softmax(x) 的雅各比矩阵为
diag
?
(
y
)
?
y
y
?
\operatorname{diag}(y)-y y^{\top}
diag(y)?yy? 。可以得到:
d
S
i
:
=
(
diag
?
(
P
i
:
)
?
P
i
:
P
i
:
?
)
d
P
i
:
=
P
i
:
°
d
P
i
:
?
(
P
i
:
?
d
P
i
:
)
P
i
:
d \mathbf{S}_{i:}=\left(\operatorname{diag}\left(\mathbf{P}_{i:}\right)-\mathbf{P}_{i:} P_{i:}^{\top}\right) d \mathbf{P}_{i:}=\mathbf{P}_{i:} \circ d \mathbf{P}_{i:}-\left(P_{i:}^{\top} d \mathbf{P}_{i:}\right) \mathbf{P}_{i:}
dSi:?=(diag(Pi:?)?Pi:?Pi:??)dPi:?=Pi:?°dPi:??(Pi:??dPi:?)Pi:?
其中 ° \circ ° 表示逐点相乘。
可以定义:
D
i
=
P
i
:
?
d
P
i
:
=
∑
j
e
q
i
?
k
j
L
i
d
o
i
?
v
j
=
d
o
i
?
∑
j
e
q
i
?
k
j
L
i
v
j
=
d
o
i
?
o
i
D_i=P_{i:}^{\top} d P_{i:}=\sum_j \frac{e^{q_i^{\top} k_j}}{L_i} d o_i^{\top} v_j=d o_i^{\top} \sum_j \frac{e^{q_i^{\top} k_j}}{L_i} v_j=d o_i^{\top} o_i
Di?=Pi:??dPi:?=j∑?Li?eqi??kj??doi??vj?=doi??j∑?Li?eqi??kj??vj?=doi??oi?
将该定义代回到上式中, 可以得到:
d
S
i
:
=
P
i
:
°
d
P
i
:
?
D
i
P
i
:
d S_{i:}=P_{i:} \circ d P_{i:}-D_i P_{i:}
dSi:?=Pi:?°dPi:??Di?Pi:?
因此,梯度
d
S
d\mathbf{S}
dS 可以表示为以下形式:
d
S
i
j
=
P
i
j
d
P
i
j
?
D
i
P
i
j
=
P
i
j
(
d
P
i
j
?
D
i
)
d \mathbf{S}_{i j}=\mathbf{P}_{i j} d \mathbf{P}_{i j}-\mathbf{D}_i \mathbf{P}_{i j}=\mathbf{P}_{i j}\left(d \mathbf{P}_{i j}-\mathbf{D}_i\right)
dSij?=Pij?dPij??Di?Pij?=Pij?(dPij??Di?)
在计算得到
d
P
i
j
d \mathbf{P}_{i j}
dPij?,
d
S
i
j
d \mathbf{S}_{i j}
dSij? 后, 可以计算
d
Q
d\mathbf{Q}
dQ,
d
K
d\mathbf{K}
dK 。有前向计算公式
S
i
j
=
q
i
?
k
j
\mathbf{S}_{i j}=\mathbf{q}_i^{\top} \mathbf{k}_j
Sij?=qi??kj?, 可以得到:
d
q
i
=
∑
j
d
S
i
j
k
j
=
∑
j
P
i
j
(
d
P
i
j
?
D
i
)
k
j
=
∑
j
e
(
q
i
?
k
j
)
L
i
(
d
o
i
?
v
j
?
D
i
)
k
j
d
k
j
=
∑
i
d
S
i
j
q
i
=
∑
i
P
i
j
(
d
P
i
j
?
D
i
)
q
i
=
∑
i
e
(
q
i
?
k
j
)
L
i
(
d
o
i
?
v
j
?
D
i
)
q
i
\begin{gathered} d \mathbf{q}_i=\sum_j d \mathbf{S}_{i j} \mathbf{k}_j=\sum_j \mathbf{P}_{i j}\left(d \mathbf{P}_{i j}-\mathbf{D}_i\right) \mathbf{k}_j=\sum_j \frac{e^{(\mathbf{q}_i^{\top} \mathbf{k}_j)}}{\mathbf{L}_i}\left(d \mathbf{o}_i^{\top} \mathbf{v}_j-\mathbf{D}_i\right) \mathbf{k}_j \\ d \mathbf{k}_j=\sum_i d \mathbf{S}_{i j} \mathbf{q}_i=\sum_i \mathbf{P}_{i j}\left(d \mathbf{P}_{i j}-\mathbf{D}_i\right) \mathbf{q}_i=\sum_i \frac{e^{(\mathbf{q}_i^{\top} \mathbf{k}_j)}}{\mathbf{L}_i}\left(d \mathbf{o}_i^{\top} \mathbf{v}_j-\mathbf{D}_i\right) \mathbf{q}_i \end{gathered}
dqi?=j∑?dSij?kj?=j∑?Pij?(dPij??Di?)kj?=j∑?Li?e(qi??kj?)?(doi??vj??Di?)kj?dkj?=i∑?dSij?qi?=i∑?Pij?(dPij??Di?)qi?=i∑?Li?e(qi??kj?)?(doi??vj??Di?)qi??
与前向计算类似,在计算得到 L i \mathbf{L}_i Li? 后, 就可以通过反复累加的方式计算得到 d q i d \mathbf{q}_i dqi?, d k j d \mathbf{k}_j dkj?, d v j d \mathbf{v}_j dvj? 。避免了实例化矩阵 P \mathbf{P} P, S \mathbf{S} S,节省了显存,后向传递的显存复杂度为 O ( N ) O(N) O(N) 。
对比标准Attention算法的实现过程中,其需要将计算中的 S \mathbf{S} S、 P \mathbf{P} P写入到HBM中,而这些中间矩阵的大小与输入的序列长度有关且为二次型;
Flash Attention算法中,其并没有将 S \mathbf{S} S、 P \mathbf{P} P写入HBM中去,而是通过分块写入到HBM中去,存储前向传递的 softmax 归一化因子,在后向传播中快速重新计算片上注意力,这比从HBM中读取中间注意力矩阵的标准方法更快。即使由于重新计算导致 FLOPS 增加,但其运行速度更快并且使用更少的内存(序列长度线性),主要是因为大大减少了 HBM 访问量。
Flash Attention实现了不使用中间注意力矩阵,通过存储归一化因子来减少HBM内存的消耗。
[2205.14135] FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness