Article type: study notes 📖
Video tutorial: U2-Net Source Code Walkthrough (PyTorch) - 4. Custom Dataset Loading
Main content: a detailed walkthrough of the my_dataset.py file, based on the U2-Net source code (PyTorch) provided in the video tutorial.
├── src: code for building the network
├── train_utils: training and validation code
├── my_dataset.py: custom dataset loading code
├── predict.py: simple prediction (inference) script
├── train.py: single-GPU or CPU training script
├── train_multi_GPU.py: multi-GPU parallel training script
├── validation.py: standalone model validation script
├── transforms.py: data preprocessing code
└── requirements.txt: project dependencies
The U2-Net network architecture diagram from the original paper (figure omitted).
【Note】In the Encoder stage, the feature map is downsampled by a factor of 2 with MaxPool after each block; in the Decoder stage, it is upsampled by a factor of 2 with bilinear interpolation after each block. The core building block of U2-Net is the ReSidual U-block (RSU), which comes in two variants: blocks that perform down/upsampling internally and blocks that do not.
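To make the down/upsampling pattern above concrete, here is a minimal toy sketch. It is not the actual U2-Net implementation; the plain conv layers only stand in for the real RSU blocks. It just shows the macro pattern: each encoder block followed by a stride-2 max pool, each decoder block followed by a 2x bilinear upsample.

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyEncoderDecoder(nn.Module):
    # Toy illustration of the macro pattern only: every encoder block is followed
    # by a 2x max pool, every decoder block by a 2x bilinear upsample. The plain
    # conv layers below are NOT the real ReSidual U-blocks (RSU) of U2-Net.
    def __init__(self, ch: int = 16):
        super().__init__()
        self.enc1 = nn.Conv2d(3, ch, kernel_size=3, padding=1)
        self.enc2 = nn.Conv2d(ch, ch, kernel_size=3, padding=1)
        self.dec1 = nn.Conv2d(ch, ch, kernel_size=3, padding=1)
        self.out = nn.Conv2d(ch, 1, kernel_size=1)

    def forward(self, x):
        x = F.relu(self.enc1(x))
        x = F.max_pool2d(x, kernel_size=2, stride=2)  # Encoder: downsample by 2 after the block
        x = F.relu(self.enc2(x))
        x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)  # Decoder: upsample by 2
        x = F.relu(self.dec1(x))
        return torch.sigmoid(self.out(x))  # per-pixel saliency probability

print(TinyEncoderDecoder()(torch.randn(1, 3, 64, 64)).shape)  # torch.Size([1, 1, 64, 64])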
Directory structure of the DUTS dataset after extraction (a small sanity-check script follows the tree below):
├── DUTS-TR
│      ├── DUTS-TR-Image: all images of the training set
│      └── DUTS-TR-Mask: the GT labels (masks) for the training images
│
└── DUTS-TE
       ├── DUTS-TE-Image: all images of the test (validation) set
       └── DUTS-TE-Mask: the GT labels (masks) for the test (validation) images
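Before writing the Dataset class, it is worth checking that every image has a matching mask. Below is a quick sanity-check sketch, assuming the two DUTS-* folders sit directly under the current working directory (the root path is an assumption; adjust it to your setup).

import os

root = "./"  # assumed layout: DUTS-TR/ and DUTS-TE/ sit directly under this directory
for split in ("DUTS-TR", "DUTS-TE"):
    image_dir = os.path.join(root, split, f"{split}-Image")
    mask_dir = os.path.join(root, split, f"{split}-Mask")
    images = {os.path.splitext(n)[0] for n in os.listdir(image_dir) if n.endswith(".jpg")}
    masks = {os.path.splitext(n)[0] for n in os.listdir(mask_dir) if n.endswith(".png")}
    print(f"{split}: {len(images)} images, {len(masks)} masks, "
          f"{len(images - masks)} images without a mask")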
The DUTSDataset class inherits from the data.Dataset base class. Its __init__ method takes the following parameters: root (the root directory of the dataset), train (whether to load the training split DUTS-TR or the test split DUTS-TE), and transforms (an optional preprocessing callable applied to both the image and its mask). The complete my_dataset.py is listed below:
import os

import cv2
import torch.utils.data as data


class DUTSDataset(data.Dataset):
    def __init__(self, root: str, train: bool = True, transforms=None):
        assert os.path.exists(root), f"path '{root}' does not exist."
        if train:
            # training split
            self.image_root = os.path.join(root, "DUTS-TR", "DUTS-TR-Image")
            self.mask_root = os.path.join(root, "DUTS-TR", "DUTS-TR-Mask")
        else:
            # test (validation) split
            self.image_root = os.path.join(root, "DUTS-TE", "DUTS-TE-Image")
            self.mask_root = os.path.join(root, "DUTS-TE", "DUTS-TE-Mask")
        assert os.path.exists(self.image_root), f"path '{self.image_root}' does not exist."
        assert os.path.exists(self.mask_root), f"path '{self.mask_root}' does not exist."

        image_names = [p for p in os.listdir(self.image_root) if p.endswith(".jpg")]
        mask_names = [p for p in os.listdir(self.mask_root) if p.endswith(".png")]
        assert len(image_names) > 0, f"didn't find any images in {self.image_root}."

        # check that every image has a corresponding mask (same stem, .png suffix)
        re_mask_names = []
        for p in image_names:
            mask_name = p.replace(".jpg", ".png")
            assert mask_name in mask_names, f"{p} has no corresponding mask."
            re_mask_names.append(mask_name)
        mask_names = re_mask_names

        self.images_path = [os.path.join(self.image_root, n) for n in image_names]
        self.masks_path = [os.path.join(self.mask_root, n) for n in mask_names]

        self.transforms = transforms

    def __getitem__(self, idx):
        image_path = self.images_path[idx]
        mask_path = self.masks_path[idx]
        image = cv2.imread(image_path, flags=cv2.IMREAD_COLOR)
        assert image is not None, f"failed to read image: {image_path}"
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # BGR -> RGB
        h, w, _ = image.shape

        target = cv2.imread(mask_path, flags=cv2.IMREAD_GRAYSCALE)
        assert target is not None, f"failed to read mask: {mask_path}"
        # sanity check: the mask must have the same spatial size as the image
        assert target.shape == (h, w), f"mask size {target.shape} does not match image size {(h, w)}"

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def __len__(self):
        return len(self.images_path)

    @staticmethod
    def collate_fn(batch):
        # batch is a list of (image, target) tuples; pad every sample up to the
        # largest spatial size in the batch so they can be stacked into one tensor
        images, targets = list(zip(*batch))
        batched_imgs = cat_list(images, fill_value=0)
        batched_targets = cat_list(targets, fill_value=0)
        return batched_imgs, batched_targets


def cat_list(images, fill_value=0):
    # largest size along every dimension within the batch
    max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))
    batch_shape = (len(images),) + max_size
    # allocate the batch tensor and copy each sample into its top-left corner
    batched_imgs = images[0].new(*batch_shape).fill_(fill_value)
    for img, pad_img in zip(images, batched_imgs):
        pad_img[..., :img.shape[-2], :img.shape[-1]].copy_(img)
    return batched_imgs


if __name__ == '__main__':
    train_dataset = DUTSDataset("./", train=True)
    print(len(train_dataset))

    val_dataset = DUTSDataset("./", train=False)
    print(len(val_dataset))

    i, t = train_dataset[0]
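Because the DUTS images come in different resolutions, DUTSDataset.collate_fn pads every sample in a batch (via cat_list) up to the largest height and width before stacking. Below is a minimal usage sketch of plugging the dataset into a DataLoader; the to_tensor helper is a hypothetical stand-in for the real preprocessing in transforms.py, and the batch size is only illustrative.

import torch
from torch.utils.data import DataLoader

def to_tensor(image, target):
    # hypothetical stand-in for transforms.py: HWC uint8 arrays -> float tensors
    image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0  # HWC -> CHW, scaled to [0, 1]
    target = torch.from_numpy(target).float() / 255.0                 # HW mask in [0, 1]
    return image, target

train_dataset = DUTSDataset("./", train=True, transforms=to_tensor)
train_loader = DataLoader(train_dataset,
                          batch_size=4,  # illustrative value
                          shuffle=True,
                          collate_fn=DUTSDataset.collate_fn)

images, targets = next(iter(train_loader))
print(images.shape, targets.shape)  # e.g. torch.Size([4, 3, H, W]) and torch.Size([4, H, W]) after padding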