图像分割实战-系列教程10:U2NET显著性检测实战2

发布时间:2024年01月05日

在这里插入图片描述

🍁🍁🍁图像分割实战-系列教程 总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Pycharm中进行
本篇文章配套的代码资源已经上传

U2NET显著性检测实战1
U2NET显著性检测实战2

5、残差Unet模块

在这里插入图片描述

class RSU7(nn.Module):#UNet07DRES(nn.Module):

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU7,self).__init__()

        self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)

        self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
        self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool5 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=1)

        self.rebnconv7 = REBNCONV(mid_ch,mid_ch,dirate=2)

        self.rebnconv6d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)

    def forward(self,x):
        hx = x
        hxin = self.rebnconvin(hx)
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)
        hx = self.pool5(hx5)

        hx6 = self.rebnconv6(hx)

        hx7 = self.rebnconv7(hx6)

        hx6d =  self.rebnconv6d(torch.cat((hx7,hx6),1))
        hx6dup = _upsample_like(hx6d,hx5)

        hx5d =  self.rebnconv5d(torch.cat((hx6dup,hx5),1))
        hx5dup = _upsample_like(hx5d,hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
        hx4dup = _upsample_like(hx4d,hx3)
        hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
        hx3dup = _upsample_like(hx3d,hx2)
        hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
        hx2dup = _upsample_like(hx2d,hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
        return hx1d + hxin

这里以RSU7举例,U2Net就是每一个backbone都是一个带残差连接的Unet,这些backbone再以Unet的形式进行连接,在连接的过程中有特征凭借和上采样下采样操作

  1. 首先这里的RSU7也是一个一个相同的小组件组成,每一个小组件都是一个REBNCONV
  2. 一个REBNCONV就是,一个卷积Conv、批归一化BatchNormalization、Relu激活函数的三连
  3. 在构造函数中,就是定义了1个REBNCONV+5个(REBNCONV+Maxpooling)+8个REBNCONV
  4. 其中第一个REBNCONV是为了进行残差连接,将输入的长、宽、通道数转化为和输出一直的维度,代表原始输入的x
  5. 在最后的输出会再加上这个x,表示残差连接操作
  6. 编码器就是5个(REBNCONV+Maxpooling)和最后一个没有Maxpooling的REBNCONV,即hx1到hx6
  7. hx7是中间结果
  8. 解码器就是5(对应位置进行拼接的REBNCONV+上采样)和最后一个没有上采样的REBNCONV,即hx6d到hx1d
  9. hx1d再加上前面提到的x就是最后的输出

6、上采样操作与REBNCONV

def _upsample_like(src,tar):

    src = F.upsample(src,size=tar.shape[2:],mode='bilinear')

    return src

使用双线性插值进行上采样操作

class REBNCONV(nn.Module):
    def __init__(self,in_ch=3,out_ch=3,dirate=1):
        super(REBNCONV,self).__init__()
        self.conv_s1 = nn.Conv2d(in_ch,out_ch,3,padding=1*dirate,dilation=1*dirate)
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)
    def forward(self,x):
        hx = x
        xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
        return xout

定义二维卷积、二维池化、Relu,然后进行对应的前向传播

7、各个残差Unet比较

在 U2-Net 中,RSU7, RSU6, RSU5, RSU4, 和 RSU4F 是用于构造网络不同层级的模块。它们共同构成了 U2-Net 的多层次特征提取体系

  1. RSU7 (Residual U-Block 7):

    • RSU7 是最深层的模块,具有最大的感受野,用于网络的最初阶段,用于从输入图像中提取基础和全局特征。在 U2NET 架构中,RSU7 作为第一个阶段使用。
  2. RSU6, RSU5, RSU4:

    • 这些模块是 U2-Net 架构中的中间层。RSU6, RSU5, RSU4 的主要区别在于它们的深度和感受野的大小。每个模块都比前一个模块浅一点,感受野也稍小。这些层用于提取越来越具体的特征,随着网络的深入,这些特征越来越侧重于局部细节。
  3. RSU4F (Residual U-Block 4-Full):

    • RSU4F 是一种特殊的 RSU 模块,它不使用最大池化层来减少特征图的尺寸,而是使用不同膨胀率的卷积来增大感受野(即空洞卷积),RSU4F 用于网络的深层,用于捕捉更细粒度的特征。

在 U2-Net 的结构中,这些 RSU 模块按照从 RSU7RSU4F 的顺序排列。

在编码器阶段,随着层级的增加,模块逐渐变得更浅,专注于更细节的特征提取。

在解码器阶段,这些模块的输出与对应编码器阶段的输出进行融合,通过上采样逐步恢复图像的空间维度,同时保持了特征的丰富性。

总结来说,RSU7RSU4F 的不同主要在于它们的深度(层数)和膨胀率,这影响了它们的感受野大小和特征提取的具体性。

U2NET显著性检测实战1
U2NET显著性检测实战2

文章来源:https://blog.csdn.net/weixin_50592077/article/details/135398229
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。