目录
前言
Preparation
一、VOCSegmentation 类
1、__init__ 函数
2、__getitem__ 函数
3、collate_fn 函数
二、cat_list 函数
前言
文章性质:学习笔记 📖
视频教程:FCN源码解析(Pytorch)- 3 自定义读取数据集
主要内容:根据 视频教程 中提供的 FCN 源代码(PyTorch),对 my_dataset.py 文件进行具体讲解。
Preparation
FCN 源码:https://github.com/WZMIAOMIAO/deep-learning-for-image-processing/tree/master/pytorch_segmentation/fcn
【补充】除此之外,需要大家去 PASCAL VOC 官网下载数据集,下载后得到 VOCdevkit 文件夹:
?
一、VOCSegmentation 类
?VOCSegmentation 类继承自 data.Dataset 父类。
1、__init__ 函数
这个类使用 __init__ 初始化数据集对象,传入?VOCdevkit 根目录路径 voc_root、数据集年份 year、数据预处理操作 transforms、训练文件列表的文件名 txt_name 等参数,year 默认为 2012 ,txt_name 默认为 train.txt 。
关于 txt_name 取值的具体说明:
- 当 txt_name 取值为 train.txt 时,读取的是 训练集 数据,可根据 train.txt 中保存的图片名称去寻找对应的图片数据和标签数据。
- 当 txt_name 取值为 val.txt 时,读取的是 验证集 数据,可根据 val.txt 中保存的图片名称去寻找对应的图片数据和标签数据。
?
【代码解析】对 VOCSegmentation 类代码的具体解析(结合上图):
- ?我们这里的 VOC 数据集只支持 2007 和 2012 的,如果年份不是 2007 和 2012 则引发 AssertionError 异常。
- ?通过拼接 voc_root 、VOCdevkit 和 VOC2012 得到 root 路径,然后判断路径是否存在。
- ?通过拼接?root 路径和固定的目录名称得到 图片目录 image_dir 和 分割标签目录 mask_dir 。
- ?通过拼接 root 路径、固定的目录名称和 txt_name 得到 train.txt 文件的目录或者 val.txt 文件的目录,然后判断路径是和否存在。
- ?遍历 txt_path 路径指定的 txt 文件,读取所有的非空行,并通过 strip 方法将行首行尾的空格去掉。
- ?构建出 file_names ,内含训练集或者验证集的图片名称。
- ?构建出图片文件的路径 os.path.join(image_dir,?x + ".jpg")?和标签文件的路径?os.path.join(mask_dir,?x + ".png") 。
?
2、__getitem__ 函数
这个类使用 __getitem__ 方法实现根据索引获取数据的功能:
?
【代码解析】根据索引 index 打开对应的图片文件(转?RGB 格式)和标签文件,并进行一系列预处理操作,然后返回处理后的图片和标签。
3、collate_fn 函数
这个类使用 collate_fn 方法在加载数据时,将一个批次的数据进行整理和处理:
?
【代码解析】对 collate_fn 函数代码的具体解析(结合上图):
- ?将图片?images 和标签 targets 分别打包成两个列表。
- ?使用 cat_list 函数将图片列表 images 进行拼接,得到一个批次的图像数据,若图像尺寸不足最大尺寸,则用 0 进行填充。
- ?使用 cat_list 函数将标签列表 targets 进行拼接,得到一个批次的标签数据,若标签尺寸不足最大尺寸,则用 255 进行填充。
【补充】针对上面的第一步,我们可以通过断点调试的方式查看打包前后的区别:
?
?
二、cat_list 函数
可使用?cat_list 方法将一个批次的图像数据进行拼接,相关代码截图如下:
?
Step1 计算这个 batch 图像数据中的通道数 channel 、高度 h 、宽度 w 的最大值:
- ?通过遍历 images 图像列表获取图像数据。
- ?使用 zip(*[img.shape for img in images]) 将所有图像的对应维度取出并打包成元组。
- ?使用 max 函数求得各维度的最大值,得到 max_size 元组。
Step2 将 不同大小的 images 图片 打包成 一个 Tensor ,然后输入到网络当中进行运算
- ?构建批次图像数据的形状,即 batch_shape ,经断点调试可知为 [4, 3, 480, 480],这四个维度分别是图像数量、通道数、高度和宽度。
- ?创建与 batch_shape 形状相同的新张量 batched_imgs ,并使用 fill_ 方法将其元素填充为指定的 fill_value 。
- ?使用 pad_img[..., :img.shape[-2], :img.shape[-1].copy_(img)] 将原始图像 images 的内容复制到对应位置的批次图像 batched_imgs 上。
【补充】我们可以通过断点调试的方式查看 cat_list 函数的返回结果:
?