近年来,基于Transformer和CNN的视觉基础模型取得巨大成功。有许多研究进一步地将Transformer结构与CNN架构结合,设计出了更为高效的hybrid CNN-Transformer Network,但它们的精度仍然不尽如意。本文介绍了一种新的基础模型SMT(Scale-Aware Modulation Transformer),它以更低的参数量(params)和计算量(flops)取得了大幅性能的提升。
SMT的总体框架如上图所示。整个网络包括四个阶段,每个阶段的下采样率为{4, 8, 16, 32}。我们并非和FocalNet一样构建一个无注意力机制的网络,而是首先在前两个阶段采用文章提出的尺度感知调制(SAM),然后在倒数第二个阶段中依次堆叠一个SAM Block和一个多头自注意力(MSA) Block,简称MIX Block,以建模从捕捉局部到全局依赖关系的转变。对于最后一个阶段,我们仅使用MSA块来有效地捕捉长距离依赖关系。
class Attention(nn.Module):
def __init__(self, dim, ca_num_heads=4, expand_ratio=2):
super().__init__()
self.dim = dim # 这个是输入的通道数
self.ca_num_heads = ca_num_heads # MHMC中分类头(head)的数量,也就是卷积核的数量
self.act = nn.GELU() # 后面SAA中卷积后使用的激活函数
self.split_groups=self.dim//ca_num_heads # 每一个分类头中有多少个通道,也就是SAA中形成的组的数量
self.v = nn.Linear(dim, dim, bias=qkv_bias) # SAM图中左边的Linear
self.s = nn.Linear(dim, dim, bias=qkv_bias) # SAM图中右边的Linear
# 为每一个分类头定义一个不同大小的卷积核,(3,5,7,9,……),分别命名为local_conv_1,local_conv_1,……
# 注意这里使用的是分组卷积,保证了每个通道之间的独立性
for i in range(self.ca_num_heads):
local_conv = nn.Conv2d(dim//self.ca_num_heads, dim//self.ca_num_heads, kernel_size=(3+i*2), padding=(1+i), stride=1, groups=dim//self.ca_num_heads)
setattr(self, f"local_conv_{i + 1}", local_conv)
self.proj0 = nn.Conv2d(dim, dim*expand_ratio, kernel_size=1, padding=0, stride=1, groups=self.split_groups) # 这个SAA中进行组中特征融合的第一个1x1卷积层
self.bn = nn.BatchNorm2d(dim*expand_ratio) # 正则化层
self.proj1 = nn.Conv2d(dim*expand_ratio, dim, kernel_size=1, padding=0, stride=1) # SAA中进行组间特征融合的第二个1x1卷积层
def forward(self, x, H, W):
B, N, C = x.shape
v = self.v(x) # 图a中左边的Linear
s = self.s(x).reshape(B, H, W, self.ca_num_heads, C//self.ca_num_heads).permute(3, 0, 4, 1, 2) # 图a中右边的Linear,并将形状变为[ca_num_heads,B, C//ca_num_heads, H, W]
# MHMC中一个head,一个head的进行卷积
for i in range(self.ca_num_heads):
local_conv = getattr(self, f"local_conv_{i + 1}")
s_i = s[i] # 取出一个head中的特征
s_i = local_conv(s_i) # 进行分组卷积,结果是[B, C//numheads, H, W],这里的C//numheads就是下面的self.split_groups
s_i = s_i.reshape(B, self.split_groups, -1, H, W) # 将卷积结果中一个通道编为一组,共self.split_groups组,那个-1实际上就是1,表示通道数为1
if i == 0:
s_out = s_i
else:
s_out = torch.cat([s_out,s_i],2) # 然后按通道维度将结果合并,实现了每个head中抽出一个通道组成一组
s_out = s_out.reshape(B, C, H, W) # 变换形状准备进行最后的特征聚合
# SAA中的特征融合,源代码中就一行,为了便于理解我将他拆开了
s_out = self.proj0(s_out) # 组内聚合
s_out = self.act(self.bn(s_out)) # 正则化,激活函数
s_out = self.proj1(s_out).reshape(B, C, N).permute(0, 2, 1) # 组间聚合,形成调制器modulator
x = s_out * v # 对v进行调制
return x
如下图所示,我们可视化出SAA前和SAA后的特征图,可以观察到SAA模块加强了语义相关的低频信号,并准确地聚焦于目标物体最重要的部分。与聚合之前的卷积映射相比,SAA模块展示了更好的能力来捕捉和表示视觉识别任务的关键特征。