下文如果有看不懂,最好直接看这个霹雳吧啦Wzd的B站讲解,可以结合着霹雳吧啦Wzd的博客,讲解和博客都很详细,我只是补充我自己难以理解的部分。
提供一下官方源码的链接
下面举例的window_size=[Wh, Ww]=[3, 3]
if self.shift_size > 0:
# calculate attention mask for SW-MSA
H, W = self.input_resolution
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size) # [nW, Mh*Mw]
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) # [nW, 1, Mh*Mw] - [nW, Mh*Mw, 1]
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
else:
attn_mask = None
self.register_buffer("attn_mask", attn_mask)
在SwinTransformerBlock的forward中调用attn模块之前,还需要先将一副特征图进行滚动(滚动代码如下),滚动的方式是将下图1上方shift_size=Mw//2=3//2=1
行和左方shift_size=Mh//2=3//2=1
列移动到最下方和最右方,滚动的结果如图1右所示。
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
图
1
图1
图1然后使用下方的代码将相同的区域设置为同一个符号,注意由于window_size=[3, 3],只需要关注同一个window_size的尺寸的信息会不会乱窜。结果如图2所示,下方特征图的大小为
9
×
9
9\times9
9×9,在划分的时候只需要关注最下方和最右方的
3
×
3
3\times3
3×3的window窗口的信息不要乱窜。因此,可以看到来自上方和左方的平移的shift_size=1
个像素和下方和右方的(window_size-shift_size)=2
个像素被设置成不同的区域。从下图2左可以看出,以右下角的
3
×
3
3\times3
3×3的window窗口举例,1和(8, 9)和(16, 17)都不是相连的像素,因此,需要将其设置为不同的区域。
H, W = self.input_resolution
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
图
2
图2
图2window_partition(img_mask, self.window_size)
就是直接将一幅图划分成nW个
M
h
×
M
w
{Mh}\times{Mw}
Mh×Mw大小的窗口,如图3所示。
图
3
图3
图3在经过window_partition操作之后(这个之前的B站讲解得很详细,在这之后稍微有点难以理解,记录一下),图片内容如下图4,下图4右是mask_windows.view(-1, self.window_size*self.window_size)
的操作之后得到的mask_windows。
图
4
图4
图4以上图4最后一个行向量[4, 4, 5, 4, 4, 5, 7, 7, 8]
为例,在经过mask_windows.unsqueeze(1)
是下图5左,mask_windows.unsqueeze(2)
下图5右(B站讲解的window_size=[h, w]=[3, 3],因此,下面都是经过广播机制,都会是复制Mh*Mw=9次)
图
5
图5
图5最终得到的结果如下图6右所示
图
6
图6
图6然后使用attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)
,将上图6右中为0的地方值赋值为0,不为0的地方赋值为-100。
然后在forword过程中直接与atten相加
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
图 7 图7 图7怎样实现的:理论细节可以参考雳吧啦Wzd的博客
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
中相加的就是attn和图6右就是一一对应关系呢?首先
Q
K
T
QK^T
QKT是
Q
Q
Q矩阵相乘
K
K
K的转置,
Q
Q
Q的shape为[Wh*Ww, C]
,
K
T
K^T
KT的shape为[C, Wh*Ww]
,矩阵相乘之后的shape为[Wh*Ww, Wh*Ww]
。现假设Wh=Ww=3,也就是下图8所示(对应的图4中的行向量[4, 4, 5, 4, 4, 5, 7, 7, 8]
)
图
8
图8
图8假设C=1,
Q
Q
Q([9, 1])的和
K
T
K^T
KT(1, 9)的像素之间的关系就是图8右侧,然后矩阵相乘,得到下图9左侧,然后相减,得到图9右,为0的部分代表是同一区域的,加上0,值不变,不为0的空白的代表不是同一区域的,加上-100,经过softmax就置0了。
图
9
图9
图9可以看到这个图9与图6的位置相同。
这个要涉及到nn.Linear的操作,这个知乎文章介绍的很详细,假设input的shape为[Wh*Ww, C]=[9, 1],其中的区域关系如图8所示,可以从知乎文章中了解到nn.Linear就相当于矩阵相乘。
图
10
图10
图10从上图10可以知道,@符号前就是input(像素的位置信息),后面是nn.Linear的权重矩阵值,在经过矩阵相乘后,可以看到, q, k, v = qkv[0], qkv[1], qkv[2]
的像素的位置信息没有改变,所以才可以有图8中的像素位置对应关系。