Pytorch基础:nn模块与网络创建——卷积模块的使用

发布时间:2024年01月12日

1. 卷积模块

  • 卷积是卷积神经网络中最常用的操作之一。在计算机视觉任务中经常会使用普通卷积、分组卷积、深度可分离卷积、空洞卷积、转置卷积(反卷积)和3D卷积。

2. nn.Conv2d

  • PyTorch 的卷积模块是nn.Conv2d,构造函数如下:
nn.Conv2d(
	in_channels,  //输入图像/特征图的通道数。
	out_channels, //输出图像/特征图的通道数。
	kernel_size,  //卷积核的尺寸。
	stride=1,     //卷积的步长。
	padding=0, 	  //0的数目。
	dilation=1,   //空洞率。代表卷积核的核元素之间的空间距离,默认值为1时为普通卷积,大于1时代表空洞卷积。
	groups=1,     //分组数。
	bias=True,    //bool类型。默认为True代表使用卷积计算的偏差项,False代表不使用卷积计算的偏差项。
	padding_mode='zeros',//0模式。常见的补0模式包括zeros/reflect/replicate/circular.
	device=None,  //
	dtype=None    //
	)
  • nn.Conv2d 卷积输入数据的格式是(batch_size,C,H,W)
    • batch_size:在一次迭代中输入网络的数据量;
    • C:输入数据的通道数;
    • H,W:代表输入数据的长(Height)和宽(Width)。

3. 特征图的输入和输出尺寸的关系表

H o u t = H i n + 2 × P H ? D H × ( K H ? 1 ) ? 1 S H + 1 H_{out}=\frac{H_{in}+2\times PH-DH\times (KH-1)-1}{SH}+1 Hout?=SHHin?+2×PH?DH×(KH?1)?1?+1

W o u t = W i n + 2 × P W ? D W × ( K W ? 1 ) ? 1 S W + 1 W_{out}=\frac{W_{in}+2\times PW-DW\times (KW-1)-1}{SW}+1 Wout?=SWWin?+2×PW?DW×(KW?1)?1?+1

H i n / H o u t 代表输入输出特征图长度, W i n / W o u t 代表输入输出特征图宽度。 K H 代表卷积核的长度、 K W 代表卷积核的宽度。 P H 代表长度方向补 0 的数目, P W 代表宽度方向补 0 的数目。 S H 代表长度方向卷积步长, S W 代表宽度方向卷积步长。 D H 代表长度方向空洞率, D W 代表宽度方向空洞率。 C i n / C o u t 代表输入 / 输出特征图的通道数 \begin{aligned} &H_{in}/H_{out} 代表输入输出特征图长度,W_{in}/W_{out} 代表输入输出特征图宽度。\\ &KH 代表卷积核的长度、KW代表卷积核的宽度。\\ &PH 代表长度方向补0的数目,PW 代表宽度方向补0的数目。\\ &SH 代表长度方向卷积步长,SW 代表宽度方向卷积步长。\\ &DH 代表长度方向空洞率,DW 代表宽度方向空洞率。\\ &C_{in}/C_{out} 代表输入/输出特征图的通道数\\ \end{aligned} ?Hin?/Hout?代表输入输出特征图长度,Win?/Wout?代表输入输出特征图宽度。KH代表卷积核的长度、KW代表卷积核的宽度。PH代表长度方向补0的数目,PW代表宽度方向补0的数目。SH代表长度方向卷积步长,SW代表宽度方向卷积步长。DH代表长度方向空洞率,DW代表宽度方向空洞率。Cin?/Cout?代表输入/输出特征图的通道数?

4. 卷积后输出特征图维度和输入相同

import torch
import torch.nn as nn 
data_in=torch.randn(size=(1,3,400,400))
k_size=[1,3,5,7]
stride=[1,1,1,1]
pad=[0,1,2,3]
# 输入和输出通道均设为3
ch_in=3
ch_out=3

for i in range(len(k_size)):
    conv2d=nn.Conv2d(ch_in,ch_out,k_size[i],stride[i],pad[i])
    print(conv2d(data_in).size())
# 输出结果
torch.Size([1, 3, 400, 400])
torch.Size([1, 3, 400, 400])
torch.Size([1, 3, 400, 400])
torch.Size([1, 3, 400, 400]) 
文章来源:https://blog.csdn.net/weixin_38566632/article/details/135537472
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。