本文给大家带来的改进机制是利用Swin Transformer替换YOLOv5中的骨干网络其是一个开创性的视觉变换器模型,它通过使用位移窗口来构建分层的特征图,有效地适应了计算机视觉任务。与传统的变换器模型不同,Swin Transformer的自注意力计算仅限于局部窗口内,使得计算复杂度与图像大小成线性关系,而非二次方。这种设计不仅提高了模型的效率,还保持了强大的特征提取能力。Swin Transformer的创新在于其能够在不同层次上捕捉图像的细节和全局信息,使其成为各种视觉任务的强大通用骨干网络。亲测在小目标检测和大尺度目标检测的数据集上都有涨点效果。
推荐指数:???
涨点效果:???
训练结果对比图->??
目录
??四、手把手教你添加Swin Transformer网络结构
论文地址:官方论文地址
代码地址:官方代码地址
Swin Transformer是一个新的视觉变换器,能够作为通用的计算机视觉骨干网络。这个模型解决了将Transformer从语言处理领域适应到视觉任务中的挑战,主要是因为这两个领域之间存在差异,例如视觉实体的尺度变化大,以及图像中像素的高分辨率与文本中的单词相比。下图对比展示了Swin Transformer与Vision Transformer (ViT)的不同之处,清楚地展示了Swin Transformer在构建特征映射和处理计算复杂度方面的创新优势。
(a) Swin Transformer:提出的Swin Transformer通过在更深层次合并图像小块(灰色部分所示)来构建层次化的特征映射。在每个局部窗口(红色部分所示)内只计算自注意力,因此它对输入图像大小有线性的计算复杂度。它可以作为通用的骨干网络,用于图像分类和密集识别任务,如分割和检测。
(b) Vision Transformer (ViT):以前的视觉Transformer模型(如ViT)产生单一低分辨率的特征映射,并且由于全局自注意力的计算,其计算复杂度与输入图像大小呈二次方关系。
我们可以将Swin Transformer的基本原理分为以下几点:
1. 层次化特征映射:Swin Transformer通过合并图像的相邻小块(patches),在更深的Transformer层次中逐步构建层次化的特征映射。这样的层次化特征映射可以方便地利用密集预测的高级技术,如特征金字塔网络(Feature Pyramid Networks, FPN)或U-Net。
2. 局部自注意力计算:为了实现线性计算复杂性,Swin Transformer在非重叠的局部窗口内计算自注意力,这些窗口是通过划分图像来创建的。每个窗口内的小块数量是固定的,因此计算复杂性与图像大小成线性关系。
3. 移动窗口自注意力(Shifted Window based Self-Attention):标准的Transformer架构在全局范围内计算自注意力,即计算一个标记与所有其他标记之间的关系。这种全局计算导致与标记数量成二次方的计算复杂性,不适用于许多需要处理大规模高维数据的视觉问题。Swin Transformer通过一个基于移动窗口的多头自注意力(MSA)模块取代了传统的MSA模块。每个Swin Transformer块由一个基于移动窗口的MSA模块组成,然后是两层带有GELU非线性的MLP,之前是LayerNorm(LN)层,之后是残差连接。
4. 移动窗口分区:为了在连续的Swin Transformer块中引入跨窗口连接的同时保持非重叠窗口的有效计算,提出了一种移动窗口分区方法。这种方法在连续的块之间交替使用两种分区配置。第一个模块使用常规的窗口分区策略,然后下一个模块采用的窗口配置与前一层相比,通过移动窗口偏移了一定距离,从而实现窗口的交替。
下图详细展示了Swin Transformer的架构和两个连续Swin Transformer块的设计。图中的W-MSA和SW-MSA分别代表带有常规和移动窗口配置的多头自注意力模块。这两种类型的注意力模块交替使用,允许模型在保持局部计算的同时,也能够捕捉更广泛的上下文信息。
(a) 架构(Architecture):图展示了Swin Transformer的四个阶段。每个阶段都包含若干Swin Transformer块。输入图像首先通过“Patch Partition”被划分成小块,并通过“Linear Embedding”转换成向量序列。各个阶段通过“Patch Merging”操作降低特征图的分辨率,同时增加特征维数(例如,第一阶段输出的特征维数为C,第二阶段为2C,依此类推)。
(b) 两个连续的Swin Transformer块(Two Successive Swin Transformer Blocks):每个Swin Transformer块由多头自注意力模块(W-MSA和SW-MSA)和多层感知机(MLP)组成,其中W-MSA使用常规窗口配置,而SW-MSA使用移动窗口配置。每个块内部,先是LayerNorm(LN)层,然后是自注意力模块,再是另一个LayerNorm层,最后是MLP。块之间通过残差连接进行连接,这样的设计可以避免深层网络中的梯度消失问题,并允许信息在网络中更有效地流动。
层次化特征映射可以使Swin Transformer有效地处理不同分辨率的特征,并适用于各种视觉任务,如图像分类、对象检测和语义分割。这种层次化设计使Swin Transformer与以往基于Transformer的架构(这些架构产生单一分辨率的特征图并具有二次方复杂度)形成对比,后者不适合需要在像素级进行密集预测的视觉任务。Swin Transformer的层次化特征映射主要通过以下步骤实现:
1. 分块和线性嵌入:首先,输入图像被分割成小块(通常是4x4像素大小),每个小块被视为一个“标记”,其特征是原始像素RGB值的串联。然后,一个线性嵌入层被应用于这些原始值特征,将其投影到任意维度(表示为C)。这些步骤构成了所谓的“第1阶段”。
2. 分块合并:随着网络深入,通过合并层减少标记的数量,从而降低特征图的分辨率。例如,第一个合并层将每组2x2相邻小块的特征合并,并应用一个线性层到这些4C维度的串联特征上,这样做将标记的数量减少了4倍(分辨率降低了2倍),并将输出维度设为2C。这个过程在后续的“第2阶段”、“第3阶段”和“第4阶段”中重复,分别产生更低分辨率的输出。
3. 层次化特征图:通过在更深的Transformer层合并相邻小块,Swin Transformer构建了层次化的特征映射。这些层次化特征映射允许模型方便地使用密集预测的高级技术,例如特征金字塔网络(FPN)或U-Net。
4. 计算效率:Swin Transformer在非重叠的局部窗口内局部计算自注意力,从而实现了线性的计算复杂度。每个窗口中的小块数量是固定的,因此复杂度与图像大小成线性关系。
Swin Transformer的局部自注意力计算通过在小窗口内计算自注意力以及通过移动窗口在连续层之间引入跨窗口的信息流通,使得计算更加高效,同时保留了模型捕捉长距离依赖的能力。Swin Transformer中的局部自注意力计算我们可以通过以下方式实现:
1. 替代标准多头自注意力模块:Swin Transformer使用基于移动窗口的多头自注意力(MSA)模块替代了传统Transformer块中的标准多头自注意力模块,其他层保持不变。每个Swin Transformer块由一个基于移动窗口的MSA模块组成,后跟一个两层的MLP,中间包含GELU非线性激活函数。在每个MSA模块和MLP之前都会应用一个LayerNorm(LN)层,每个模块之后都会应用残差连接。
2. 在各个窗口内计算自注意力:在每一层中,采用常规的窗口分区方案,每个窗口内部独立计算自注意力。在下一层中,窗口分区会发生移动,形成新的窗口。新窗口中的自注意力计算会跨越之前层中窗口的边界,建立它们之间的连接。
3. 非重叠窗口中的自注意力:为了有效的建模,Swin Transformer在非重叠的局部窗口内计算自注意力。这些窗口被安排以均匀非重叠的方式分割图像。假设每个窗口包含M×M个小块,全局MSA模块和基于窗口的MSA模块的计算复杂度分别为二次方和线性,当M固定时(默认设为7)。
4. 循环位移和掩码机制:提出了一种通过循环位移来提高批量计算的效率的方法。通过这种位移,一个批次的窗口可能由几个在特征图中不相邻的子窗口组成,因此采用掩码机制限制在每个子窗口内计算自注意力。这种循环位移保持了批次窗口的数量与常规窗口分区相同。
5. 窗口间的位移:为了在连续层之间实现更高效的硬件实现,Swin Transformer提出在连续层之间位移窗口,这样的位移允许跨窗口的连接,同时维持计算的高效性。
6. 相对位置偏置:在计算自注意力时,Swin Transformer包括了相对位置偏置B,以增强模型对不同位置之间关系的学习能力。
移动窗口自注意力是Swin Transformer设计的核心元素,它通过在局部窗口内计算自注意力并在连续层之间引入窗口位移,以实现高效的计算和强大的建模能力。在Swin Transformer论文中,移动窗口自注意力(shifted window self-attention)的主要特点包括:
1. 替代多头自注意力模块:在Swin Transformer块中,标准的多头自注意力(MSA)模块被基于移动窗口的MSA模块替换。这种基于移动窗口的MSA模块后跟一个两层的MLP,中间有GELU非线性激活函数。每个MSA模块和MLP之前都会应用一个LayerNorm(LN)层,每个模块之后都会应用残差连接。
2. 移动窗口分区:在连续的Swin Transformer块中,窗口分区策略在每一层之间交替。在某一层中,采用常规窗口分区,而在下一层中,窗口分区会发生移动,从而形成新的窗口。这种移动窗口分区方法能够跨越前一层中窗口的边界,提供窗口间的连接。
3. 交替分区配置:移动窗口分区方法在连续的Swin Transformer块中交替使用两种分区配置。例如,第一个模块从左上角像素开始使用常规窗口分区策略,接着下一个模块采用的窗口配置将与前一层相比移动一定距离。
4. 移动窗口自注意力的计算:移动窗口自注意力计算的有效性不仅在图像分类、目标检测和语义分割任务中得到了验证,而且它的实现也被证明在所有MLP架构中有益。
5. 效率:相比于滑动窗口方法,移动窗口方法具有更低的延迟,但在建模能力上却相似。此外,移动窗口方法也有助于提高批量计算的效率。
6. 连续块的计算:在移动窗口分区方法中,连续的Swin Transformer块的计算方式如下:,然后是MLP层,之后是。这里,和 分别代表块l的(S)W-MSA模块和MLP模块的输出特征。
下面我给大家展示了所提出的Swin Transformer架构中用于计算自注意力的移动窗口方法。
在第l层(左侧),采用了常规窗口划分方案,并且在每个窗口内计算自注意力。在接下来的第l+1层(右侧),窗口划分被移动,结果在新的窗口中进行了自注意力计算。这些新窗口中的自注意力计算跨越了l层中之前窗口的边界,提供了它们之间的连接。这种移动窗口方法提高了效率,因为它限制了自注意力计算在非重叠的局部窗口内,同时允许窗口间的交叉连接。?
移动窗口分区是Swin Transformer中一项关键的创新,它通过在连续层之间交替窗口的分区方式,有效地促进了信息在窗口之间的流动,同时保持了处理高分辨率图像时的计算效率。下面我将通过图片解释如何使用循环位移来计算在移动窗口中的自注意力,以及如何高效地实施这一计算。
(1)窗口分区(Window partition):首先,图像被分成多个窗口。
(2)循环位移(Cyclic shift):接着,为了计算自注意力,窗口内的像素或特征会进行循环位移。这样可以将本来不相邻的像素或特征暂时性地排列到同一个窗口内,使得可以在局部窗口中计算原本跨窗口的自注意力。
(3)掩码多头自注意力(Masked MSA):在经过循环位移后,可以在这些临时形成的窗口上执行掩码多头自注意力操作,以此计算注意力得分和更新特征。
(4)逆循环位移(Reverse cyclic shift):完成自注意力计算后,特征会进行逆循环位移,恢复到它们原来在图像中的位置。
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
import numpy as np
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
class Mlp(nn.Module):
""" Multilayer perceptron."""
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
def window_partition(x, window_size):
"""
Args:
x: (B, H, W, C)
window_size (int): window size
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows
def window_reverse(windows, window_size, H, W):
"""
Args:
windows: (num_windows*B, window_size, window_size, C)
window_size (int): Window size
H (int): Height of image
W (int): Width of image
Returns:
x: (B, H, W, C)
"""
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
class WindowAttention(nn.Module):
""" Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both of shifted and non-shifted window.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
"""
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
trunc_normal_(self.relative_position_bias_table, std=.02)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, mask=None):
""" Forward function.
Args:
x: input features with shape of (num_windows*B, N, C)
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
"""
B_, N, C = x.shape
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class SwinTransformerBlock(nn.Module):
""" Swin Transformer Block.
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
window_size (int): Window size.
shift_size (int): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, dim, num_heads, window_size=7, shift_size=0,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
self.norm1 = norm_layer(dim)
self.attn = WindowAttention(
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
self.H = None
self.W = None
def forward(self, x, mask_matrix):
""" Forward function.
Args:
x: Input feature, tensor size (B, H*W, C).
H, W: Spatial resolution of the input feature.
mask_matrix: Attention mask for cyclic shift.
"""
B, L, C = x.shape
H, W = self.H, self.W
assert L == H * W, "input feature has wrong size"
shortcut = x
x = self.norm1(x)
x = x.view(B, H, W, C)
# pad feature maps to multiples of window size
pad_l = pad_t = 0
pad_r = (self.window_size - W % self.window_size) % self.window_size
pad_b = (self.window_size - H % self.window_size) % self.window_size
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
_, Hp, Wp, _ = x.shape
# cyclic shift
if self.shift_size > 0:
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
attn_mask = mask_matrix.type(x.dtype)
else:
shifted_x = x
attn_mask = None
# partition windows
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
# W-MSA/SW-MSA
attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
# merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
# reverse cyclic shift
if self.shift_size > 0:
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = shifted_x
if pad_r > 0 or pad_b > 0:
x = x[:, :H, :W, :].contiguous()
x = x.view(B, H * W, C)
# FFN
x = shortcut + self.drop_path(x)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class PatchMerging(nn.Module):
""" Patch Merging Layer
Args:
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, x, H, W):
""" Forward function.
Args:
x: Input feature, tensor size (B, H*W, C).
H, W: Spatial resolution of the input feature.
"""
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
x = x.view(B, H, W, C)
# padding
pad_input = (H % 2 == 1) or (W % 2 == 1)
if pad_input:
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
return x
class BasicLayer(nn.Module):
""" A basic Swin Transformer layer for one stage.
Args:
dim (int): Number of feature channels
depth (int): Depths of this stage.
num_heads (int): Number of attention head.
window_size (int): Local window size. Default: 7.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
"""
def __init__(self,
dim,
depth,
num_heads,
window_size=7,
mlp_ratio=4.,
qkv_bias=True,
qk_scale=None,
drop=0.,
attn_drop=0.,
drop_path=0.,
norm_layer=nn.LayerNorm,
downsample=None,
use_checkpoint=False):
super().__init__()
self.window_size = window_size
self.shift_size = window_size // 2
self.depth = depth
self.use_checkpoint = use_checkpoint
# build blocks
self.blocks = nn.ModuleList([
SwinTransformerBlock(
dim=dim,
num_heads=num_heads,
window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop,
attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer)
for i in range(depth)])
# patch merging layer
if downsample is not None:
self.downsample = downsample(dim=dim, norm_layer=norm_layer)
else:
self.downsample = None
def forward(self, x, H, W):
""" Forward function.
Args:
x: Input feature, tensor size (B, H*W, C).
H, W: Spatial resolution of the input feature.
"""
# calculate attention mask for SW-MSA
Hp = int(np.ceil(H / self.window_size)) * self.window_size
Wp = int(np.ceil(W / self.window_size)) * self.window_size
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
for blk in self.blocks:
blk.H, blk.W = H, W
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x, attn_mask)
else:
x = blk(x, attn_mask)
if self.downsample is not None:
x_down = self.downsample(x, H, W)
Wh, Ww = (H + 1) // 2, (W + 1) // 2
return x, H, W, x_down, Wh, Ww
else:
return x, H, W, x, H, W
class PatchEmbed(nn.Module):
""" Image to Patch Embedding
Args:
patch_size (int): Patch token size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
super().__init__()
patch_size = to_2tuple(patch_size)
self.patch_size = patch_size
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x):
"""Forward function."""
# padding
_, _, H, W = x.size()
if W % self.patch_size[1] != 0:
x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
if H % self.patch_size[0] != 0:
x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
x = self.proj(x) # B C Wh Ww
if self.norm is not None:
Wh, Ww = x.size(2), x.size(3)
x = x.flatten(2).transpose(1, 2)
x = self.norm(x)
x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
return x
class SwinTransformer(nn.Module):
""" Swin Transformer backbone.
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
https://arxiv.org/pdf/2103.14030
Args:
pretrain_img_size (int): Input image size for training the pretrained model,
used in absolute postion embedding. Default 224.
patch_size (int | tuple(int)): Patch size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
depths (tuple[int]): Depths of each Swin Transformer stage.
num_heads (tuple[int]): Number of attention head of each stage.
window_size (int): Window size. Default: 7.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
drop_rate (float): Dropout rate.
attn_drop_rate (float): Attention dropout rate. Default: 0.
drop_path_rate (float): Stochastic depth rate. Default: 0.2.
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
patch_norm (bool): If True, add normalization after patch embedding. Default: True.
out_indices (Sequence[int]): Output from which stages.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters.
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
"""
def __init__(self,
pretrain_img_size=224,
patch_size=4,
in_chans=3,
embed_dim=96,
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
window_size=7,
mlp_ratio=4.,
qkv_bias=True,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.2,
norm_layer=nn.LayerNorm,
ape=False,
patch_norm=True,
out_indices=(0, 1, 2, 3),
frozen_stages=-1,
use_checkpoint=False):
super().__init__()
self.pretrain_img_size = pretrain_img_size
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.ape = ape
self.patch_norm = patch_norm
self.out_indices = out_indices
self.frozen_stages = frozen_stages
# split image into non-overlapping patches
self.patch_embed = PatchEmbed(
patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None)
# absolute position embedding
if self.ape:
pretrain_img_size = to_2tuple(pretrain_img_size)
patch_size = to_2tuple(patch_size)
patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]]
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1]))
trunc_normal_(self.absolute_pos_embed, std=.02)
self.pos_drop = nn.Dropout(p=drop_rate)
# stochastic depth
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
# build layers
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
layer = BasicLayer(
dim=int(embed_dim * 2 ** i_layer),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
norm_layer=norm_layer,
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
use_checkpoint=use_checkpoint)
self.layers.append(layer)
num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
self.num_features = num_features
# add a norm layer for each output
for i_layer in out_indices:
layer = norm_layer(num_features[i_layer])
layer_name = f'norm{i_layer}'
self.add_module(layer_name, layer)
self.width_list = [i.size(1) for i in self.forward(torch.randn(1, 3, 640, 640))]
def forward(self, x):
"""Forward function."""
x = self.patch_embed(x)
Wh, Ww = x.size(2), x.size(3)
if self.ape:
# interpolate the position embedding to the corresponding size
absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
else:
x = x.flatten(2).transpose(1, 2)
x = self.pos_drop(x)
outs = []
for i in range(self.num_layers):
layer = self.layers[i]
x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
if i in self.out_indices:
norm_layer = getattr(self, f'norm{i}')
x_out = norm_layer(x_out)
out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
outs.append(out)
return outs
这个主干的网络结构添加起来算是所有的改进机制里最麻烦的了,因为有一些网略结构可以用yaml文件搭建出来,有一些网络结构其中的一些细节根本没有办法用yaml文件去搭建,用yaml文件去搭建会损失一些细节部分(而且一个网络结构设计很多细节的结构修改方式都不一样,一个一个去修改大家难免会出错),所以这里让网络直接返回整个网络,然后修改部分 yolo代码以后就都以这种形式添加了,以后我提出的网络模型基本上都会通过这种方式修改,我也会进行一些模型细节改进。创新出新的网络结构大家直接拿来用就可以的。下面开始添加教程->
(同时每一个后面都有代码,大家拿来复制粘贴替换即可,但是要看好了不要复制粘贴替换多了)
我们复制网络结构代码到“yolov5-master/models”目录下创建一个目录,我这里的名字是modules(如果将文件放在models下面随着改进机制越来越多不太好区分,所以创建一个文件目录将改进机制全部放在里面)?,然后创建一个py文件将代码复制粘贴到里面我这里起的名字是Swin Transformer。
?
然后我们在我们创建的目录里面创建一个初始化文件'__init__.py',然后在里面导入我们同级目录的所有改进机制
??
我们找到如下文件'models/yolo.py'在开头里面导入我们的模块,这里需要注意要将代码放在common导入的文件上面,否则有一些模块会使用我们modules里面的,可能同名导致报错,这里如果你使用多个我的改进机制填写一个即可,不用重复添加。
添加如下两行代码,根据行数找相似的代码进行添加
找到七百多行大概把具体看图片,按照图片来修改就行,添加红框内的部分,注意没有()只是函数名,我这里只添加了部分的版本,大家有兴趣这个ShuffleNetV1还有更多的版本可以添加,看我给的代码函数头即可。
elif m in {自行添加对应的模型即可,下面都是一样的}:
m = m()
c2 = m.width_list # 返回通道列表
backbone = True
下面的两个红框内都是需要改动的。?
if isinstance(c2, list):
m_ = m
m_.backbone = True
else:
m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module
t = str(m)[8:-2].replace('__main__.', '') # module type
np = sum(x.numel() for x in m_.parameters()) # number params
m_.i, m_.f, m_.type, m_.np = i + 4 if backbone else i, f, t, np # attach index, 'from' index, type
如下的也需要修改,全部按照我的来。
代码如下把原先的代码替换了即可。?
save.extend(x % (i + 4 if backbone else i) for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
layers.append(m_)
if i == 0:
ch = []
if isinstance(c2, list):
ch.extend(c2)
if len(c2) != 5:
ch.insert(0, 0)
else:
ch.append(c2)
修改七和前面的都不太一样,需要修改前向传播中的一个部分,?已经离开了parse_model方法了。
可以在图片中开代码行数,没有离开task.py文件都是同一个文件。 同时这个部分有好几个前向传播都很相似,大家不要看错了,是70多行左右的!!!,同时我后面提供了代码,大家直接复制粘贴即可,有时间我针对这里会出一个视频。
找到如下的代码,这里不太好找,我给大家上传一个原始的样子。
然后我们用后面的代码进行替换,替换完之后的样子如下->?
??
代码如下->
def _forward_once(self, x, profile=False, visualize=False):
y, dt = [], [] # outputs
for m in self.model:
if m.f != -1: # if not from previous layer
x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
if profile:
self._profile_one_layer(m, x, dt)
if hasattr(m, 'backbone'):
x = m(x)
if len(x) != 5: # 0 - 5
x.insert(0, None)
for index, i in enumerate(x):
if index in self.save:
y.append(i)
else:
y.append(None)
x = x[-1] # 最后一个输出传给下一层
else:
x = m(x) # run
y.append(x if m.i in self.save else None) # save output
if visualize:
feature_visualization(x, m.type, m.i, save_dir=visualize)
return x
到这里就完成了修改部分,但是这里面细节很多,大家千万要注意不要替换多余的代码,导致报错,也不要拉下任何一部,都会导致运行失败,而且报错很难排查!!!很难排查!!!?
复制如下yaml文件进行运行!!!?
# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license
# Parameters
nc: 80 # number of classes
depth_multiple: 0.33 # model depth multiple
width_multiple: 0.25 # layer channel multiple
anchors:
- [10,13, 16,30, 33,23] # P3/8
- [30,61, 62,45, 59,119] # P4/16
- [116,90, 156,198, 373,326] # P5/32
backbone:
# [from, number, module, args]
[[-1, 1, SwinTransformer, []], # 0-4-P1/
[-1, 1, SPPF, [1024, 5]], # 5
]
# YOLOv5 v6.0 head
head:
[[-1, 1, Conv, [512, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 3], 1, Concat, [1]], # cat backbone P4
[-1, 3, C3, [512, False]], # 9
[-1, 1, Conv, [256, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 2], 1, Concat, [1]], # cat backbone P3
[-1, 3, C3, [256, False]], # 13 (P3/8-small)
[-1, 1, Conv, [256, 3, 2]],
[[-1, 10], 1, Concat, [1]], # cat head P4
[-1, 3, C3, [512, False]], # 16 (P4/16-medium)
[-1, 1, Conv, [512, 3, 2]],
[[-1, 5], 1, Concat, [1]], # cat head P5
[-1, 3, C3, [1024, False]], # 19 (P5/32-large)
[[13, 16, 19], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
]
下面是成功运行的截图,已经完成了有1个epochs的训练,图片太大截不全第2个epochs了。?
?
??
?
到此本文的正式分享内容就结束了,在这里给大家推荐我的YOLOv5改进有效涨点专栏,本专栏目前为新开的平均质量分97分,后期我会根据各种最新的前沿顶会进行论文复现,也会对一些老的改进机制进行补充,目前本专栏免费阅读(暂时,大家尽早关注不迷路~),如果大家觉得本文帮助到你了,订阅本专栏,关注后续更多的更新~