swin transformer的SW-MSA中的masked理解

发布时间:2024年01月03日

swin transformer的SW-MHSA中的masked理解

下文如果有看不懂,最好直接看这个霹雳吧啦Wzd的B站讲解,可以结合着霹雳吧啦Wzd的博客,讲解和博客都很详细,我只是补充我自己难以理解的部分。
提供一下官方源码的链接

B站视频中的是create_mask代码(跟源码一样,只不过作为一个函数使用),官方源码中的是这一部分(直接在SwinTransformerBlock)

下面举例的window_size=[Wh, Ww]=[3, 3]

if self.shift_size > 0:
	# calculate attention mask for SW-MSA
	H, W = self.input_resolution
    img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
	h_slices = (slice(0, -self.window_size),
	             slice(-self.window_size, -self.shift_size),
	             slice(-self.shift_size, None))
	w_slices = (slice(0, -self.window_size),
	             slice(-self.window_size, -self.shift_size),
	             slice(-self.shift_size, None))
	cnt = 0
	for h in h_slices:
	    for w in w_slices:
	        img_mask[:, h, w, :] = cnt
	        cnt += 1
	
	mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
	mask_windows = mask_windows.view(-1, self.window_size * self.window_size)  # [nW, Mh*Mw]  
	attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # [nW, 1, Mh*Mw] - [nW, Mh*Mw, 1]
	attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
else:
    attn_mask = None

self.register_buffer("attn_mask", attn_mask)

在SwinTransformerBlock的forward中调用attn模块之前,还需要先将一副特征图进行滚动(滚动代码如下),滚动的方式是将下图1上方shift_size=Mw//2=3//2=1行和左方shift_size=Mh//2=3//2=1列移动到最下方和最右方,滚动的结果如图1右所示。

shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))

在这里插入图片描述 图 1 图1 1然后使用下方的代码将相同的区域设置为同一个符号,注意由于window_size=[3, 3],只需要关注同一个window_size的尺寸的信息会不会乱窜。结果如图2所示,下方特征图的大小为 9 × 9 9\times9 9×9,在划分的时候只需要关注最下方和最右方的 3 × 3 3\times3 3×3的window窗口的信息不要乱窜。因此,可以看到来自上方和左方的平移的shift_size=1个像素和下方和右方的(window_size-shift_size)=2个像素被设置成不同的区域。从下图2左可以看出,以右下角的 3 × 3 3\times3 3×3的window窗口举例,1和(8, 9)和(16, 17)都不是相连的像素,因此,需要将其设置为不同的区域。

H, W = self.input_resolution
img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
h_slices = (slice(0, -self.window_size),
            slice(-self.window_size, -self.shift_size),
            slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
            slice(-self.window_size, -self.shift_size),
            slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
  	for w in w_slices:
      	img_mask[:, h, w, :] = cnt
      	cnt += 1

在这里插入图片描述 图 2 图2 2window_partition(img_mask, self.window_size)就是直接将一幅图划分成nW个 M h × M w {Mh}\times{Mw} Mh×Mw大小的窗口,如图3所示。
在这里插入图片描述 图 3 图3 3在经过window_partition操作之后(这个之前的B站讲解得很详细,在这之后稍微有点难以理解,记录一下),图片内容如下图4,下图4右是mask_windows.view(-1, self.window_size*self.window_size)的操作之后得到的mask_windows。
图1 图 4 图4 4以上图4最后一个行向量[4, 4, 5, 4, 4, 5, 7, 7, 8]为例,在经过mask_windows.unsqueeze(1)是下图5左,mask_windows.unsqueeze(2)下图5右(B站讲解的window_size=[h, w]=[3, 3],因此,下面都是经过广播机制,都会是复制Mh*Mw=9次)
在这里插入图片描述 图 5 图5 5最终得到的结果如下图6右所示
在这里插入图片描述 图 6 图6 6然后使用attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0),将上图6右中为0的地方值赋值为0,不为0的地方赋值为-100。

然后在forword过程中直接与atten相加

if mask is not None:
    nW = mask.shape[0]
    attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
    attn = attn.view(-1, self.num_heads, N, N)
    attn = self.softmax(attn)
else:
    attn = self.softmax(attn)

重点解释雳吧啦Wzd的博客提到的防止shifted窗口合并后信息乱窜的问题

在这里插入图片描述 图 7 图7 7怎样实现的:理论细节可以参考雳吧啦Wzd的博客

怎样理解attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)中相加的就是attn和图6右就是一一对应关系呢?

首先 Q K T QK^T QKT Q Q Q矩阵相乘 K K K的转置, Q Q Q的shape为[Wh*Ww, C] K T K^T KT的shape为[C, Wh*Ww],矩阵相乘之后的shape为[Wh*Ww, Wh*Ww]。现假设Wh=Ww=3,也就是下图8所示(对应的图4中的行向量[4, 4, 5, 4, 4, 5, 7, 7, 8])
在这里插入图片描述 图 8 图8 8假设C=1, Q Q Q([9, 1])的和 K T K^T KT(1, 9)的像素之间的关系就是图8右侧,然后矩阵相乘,得到下图9左侧,然后相减,得到图9右,为0的部分代表是同一区域的,加上0,值不变,不为0的空白的代表不是同一区域的,加上-100,经过softmax就置0了。
在这里插入图片描述 图 9 图9 9可以看到这个图9与图6的位置相同。

注意,应该还有个疑问,就是在矩阵相乘以前,还有一个qkv=self.qkv(input)的操作,必须解释这个操作之后, Q Q Q K T K^T KT的关系如图8所示才行。

这个要涉及到nn.Linear的操作,这个知乎文章介绍的很详细,假设input的shape为[Wh*Ww, C]=[9, 1],其中的区域关系如图8所示,可以从知乎文章中了解到nn.Linear就相当于矩阵相乘。
在这里插入图片描述 图 10 图10 10从上图10可以知道,@符号前就是input(像素的位置信息),后面是nn.Linear的权重矩阵值,在经过矩阵相乘后,可以看到, q, k, v = qkv[0], qkv[1], qkv[2]的像素的位置信息没有改变,所以才可以有图8中的像素位置对应关系。

文章来源:https://blog.csdn.net/shilichangtin/article/details/135277745
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。