框架解析:
三要素:查询(query),键(key),值(value)
通过query序列检索key,获取合适的value信息
假设有一个查询
q
∈
R
q
q \in \mathcal{R}^{q}
q∈Rq和
m
m
m个键值对
(
k
1
,
v
1
)
,
?
?
,
(
k
m
,
v
m
)
(k_{1},v_{1}),\cdots,(k_{m},v_{m})
(k1?,v1?),?,(km?,vm?),
k
∈
R
k
,
v
∈
R
v
k \in \mathcal{R}^{k},v\in \mathcal{R}^{ v}
k∈Rk,v∈Rv
注意力函数可表示为加权和的形式:
f
(
q
,
(
k
1
,
v
1
)
,
?
?
,
(
k
m
,
v
m
)
)
=
∑
i
=
1
m
α
(
q
,
k
i
)
v
i
∈
R
v
f(q,(k_{1},v_{1}),\cdots,(k_{m},v_{m}))=\sum_{i=1}^{m}\alpha(q,k_{i})v_{i}\in \mathcal{R}^{v}
f(q,(k1?,v1?),?,(km?,vm?))=i=1∑m?α(q,ki?)vi?∈Rv
其中
α
(
q
,
k
i
)
\alpha(q,k_{i})
α(q,ki?)是由注意力评分函数
a
a
a通过
s
o
f
t
m
a
x
softmax
softmax函数归一化得到
$
α
(
q
,
k
i
)
=
s
o
f
t
m
a
x
(
a
(
q
.
k
i
)
)
=
e
x
p
(
a
(
q
,
k
i
)
)
∑
j
m
e
x
p
(
a
(
q
,
k
j
)
)
\alpha(q,k_{i})=softmax(a(q.k_{i}))=\frac{exp(a(q,k_{i}))}{\sum_{j}^{m}exp(a(q,k_{j}))}
α(q,ki?)=softmax(a(q.ki?))=∑jm?exp(a(q,kj?))exp(a(q,ki?))?
a
a
a有以下几种形式:
除以 d k d_{k} dk?的原因
- 防止输入softmax的值过大,导致偏导数趋近于0,避免梯度消失
- 使得 q ? k q\cdot k q?k的值满足期望为0,方差为1的分布
当实际应用一个批量数据进行运算时,基于
n
n
n个查询和
m
m
m个键-值对计算注意力,其中查询,键长度为
d
d
d,值长度为
v
v
v,则
Q
∈
R
n
×
d
,
K
∈
R
m
×
d
,
V
∈
R
m
×
v
Q\in \mathcal{R}^{n\times d},K\in \mathcal{R}^{m\times d},V\in \mathcal{R}^{m\times v}
Q∈Rn×d,K∈Rm×d,V∈Rm×v的缩放点击注意力为:
s
o
f
t
m
a
x
(
Q
K
T
d
)
V
?
∈
?
R
n
×
v
softmax(\frac{QK^{T}}{\sqrt{d}})V\ \in\ \mathcal{R}^{n\times v}
softmax(d?QKT?)V?∈?Rn×v
区别:Dot Product Attention 和 Additive Attention两者在复杂度上是相似的。但是Additive Attention增加了三个可学习的矩阵,所以相比另外两个效果会更好,同时也增加了更多的模型参数,计算效率会较低。
查询、键、值均由同一个输入经过不同的“线性投影”变化得到,并采用缩放点积注意力得到最终输出
A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d ) V ? ∈ ? R n × v Attention(Q,K,V)=softmax(\frac{QK^{T}}{\sqrt{d}})V\ \in\ \mathcal{R}^{n\times v} Attention(Q,K,V)=softmax(d?QKT?)V?∈?Rn×v
作用:防止Transformer在训练时泄露后面的它不应该看到的信息,确保仅看到当前及以前得信息
更多可见:MultiHead-Attention和Masked-Attention的机制和原理
原理:在给定相同的查询、键、值时,使用**h个独立的"线性投影"**来变换q,k,v,然后并行得使用h个注意力机制,学习到不同的行为,然后将h个自注意力的输出拼接在一起,通过另一个可学习的线性投影进行变换,产生最终的输出,来捕捉序列内各种范围内的依赖关系(例如短距离依赖和长距离依赖)
其中,每个自注意力被称为一个头
import torch
from torch import nn
##### 使多个头可以进行并行计算,p_q = p_k = p_v = p_o/h,p_o=num_hiddens,
# 直接用nn.Linear(query_size,num_hiddens),num_hiddens=p_v*h,即多个线性变换结合在一起
# 假设输出维度为num_hiddens,同时h*p_v = num_hiddens
def transpose_qkv(X, num_heads): # 将组合起来的输入,变换为num_heads个输入
# 输入X的shape为(batch_size,查询或者“键值对”的个数,num_hiddens)
# 输出X的shape为(batch_size,查询或者“键值对”的个数,num_heads,num_hiddens/num_heads)
X = X.reshape(X.shape[0],X.shape[1], num_heads,-1)
# 输出X的形状:(batch_size,num_heads,查询或者“键-值”对的个数,
# num_hiddens/num_heads)
X = X.permute(0,2,1,3)
# 最终输出的形状:(batch_size*num_heads,查询或者“键-值”对的个数,
# num_hiddens/num_heads)
return X.reshape(-1,X.shape[2],X.shape[3])
def transpose_output(X,num_heads): # 将组合起来的输出,变换为num_heads个输出
"""逆转transpose_qkv函数的操作"""
X = X.reshape(-1,num_heads,X.shape[1].X.shape[2])
X = X.permute(0,2,1,3)
return X.reshape(X.shape[0], X.shape[1],-1)
class MultiHeadAttention(nn.Module):
def __init__(self,key_size,query_size,value_size,num_hiddens,num_heads,dropout,bias=False,**kwargs)
super(MultiHeadAttention,self).__init__(**kwargs)
self.num_heads = num_heads
self.attendtion = DotProductAttendtion(droupout)
self.W_q = nn.Linear(query_size,num_hiddens, bias)
self.W_k = nn.Linear(key_size,num_hiddens, bias)
self.W_v = nn.Linear(value_size,num_hiddens, bias)
self.W_o = nn.Linear(num_hiddens,num_hiddens, bias)
def forward(self,queries, keys, values, valid_lens):
# queries,keys,values的形状:
# (batch_size,查询或者“键-值”对的个数,num_hiddens)
# valid_lens 的形状:
# (batch_size,)或(batch_size,查询的个数)
# 经过变换后,输出的queries,keys,values 的形状:
# (batch_size*num_heads,查询或者“键-值”对的个数,
# num_hiddens/num_heads)
queries = transpose_qkv(self.W_q(queries),self.num_heads)
keys = transpose_qkv(self.W_k(keys),self.num_heads)
values = transpose_qkv(self.W_v(values),self.num_heads)
if valid_lens is not None:
# 按行重复num_heads遍
valid_lens = torch.repeat_interleave(
valid_lens, repeats=self.num_heads, dim=0)
output = self.attention(queries,keys, values,valid_lens)
output_concat = transpose_output(output, self.num_heads)
return self.W_o(output_concat)
作用:self-attention能够看到全局信息,忽略了顺序关系,为了使用序列的顺序信息,通过在输入表示中添加位置编码(positional encoding)来注入绝对的或相对的位置信息
假设输入 X ∈ R n × d X\in \mathcal{R}^{n\times d} X∈Rn×d表示一个序列中 n n n个词元的 d d d维嵌入表示。位置编码使用与输入 X X X相同形状的位置嵌入矩阵 P ∈ R n × d P\in \mathcal{R}^{n\times d} P∈Rn×d表示
固定位置编码:
P
i
,
2
j
=
s
i
n
(
i
1000
0
2
j
/
d
)
P_{i,2j}=sin(\frac{i}{10000^{2j/d}})
Pi,2j?=sin(100002j/di?)
P
i
,
2
j
+
1
=
c
o
s
(
i
1000
0
2
j
/
d
)
P_{i,2j+1}=cos(\frac{i}{10000^{2j/d}})
Pi,2j+1?=cos(100002j/di?)
即对于每个词元,奇数维度采用
c
o
s
cos
cos函数,偶数维度采用
s
i
n
sin
sin函数
包含以下两种信息
缺点:当词嵌入维度较大时,较大维度的位置编码值完全一致
因为神经网络的Block大部分都是矩阵运算,一个向量经过矩阵运算后值会越来越大,为了网络的稳定性,我们需要及时把值拉回正态分布。归一化的方式可以分为:
原因:神经网络的学习过程中,对于神经网络中间的每一层,其前面层的参数在学习中会不断改变,导致其输出也在不断改变,不利于这一层及后面层的学习,学习收敛速度会变慢,就会出现Internal Covariate Shift(内部协变量偏移). 随着网络的层数不断增大,这种误差就会不断积累,最终导致效果欠佳。
更多可见Batch normalization和Layer normalization
- 在模型能够收敛的情况下,网络越深,模型的准确率越低,同时,模型的准确率先达到饱和,此后迅速下降。称之为网络退化(Degradation),resnet能够有效训练出更深的网络模型(可以超过1000层),使得深网络的表现不差于浅网络,避免网络退化。
- 避免梯度消失/爆炸(主要通过归一化初始化和中间规归一化层来解决)
结构如下:
使数据可以跨层流动,残差模块的输出为:
H
(
x
)
=
F
(
x
)
+
x
H(x)=F(x)+x
H(x)=F(x)+x
其中,
F
(
x
)
F(x)
F(x)为残差函数,在网络深层的时候,在优化目标的约束下,模型通过学习使得逼近0(residule learning),让深层函数在学到东西的情况下,又不会发生网络退化的问题。