Talk is cheap,show me the code!? 哈哈,先上几段常用的代码,以语义分割的DRIVE数据集加载为例:
DRIVE数据集的目录结构如下,下载链接DRIVE,如果官网下不了,到Kaggle官网可以下到:
1. 定义DriveDataset类,每行代码都加了注释,其中collate_fn()看不懂没关系:
import os
from PIL import Image
import numpy as np
from torch.utils.data import Dataset
class DriveDataset(Dataset): # 继承Dataset类
def __init__(self, root: str, train: bool, transforms=None):
super(DriveDataset, self).__init__()
self.flag = "training" if train else "test" # 根据train这个布尔类型确定需要处理的是训练集还是测试集
data_root = os.path.join(root, "DRIVE", self.flag) # 得到数据集根目录
assert os.path.exists(data_root), f"path '{data_root}' does not exists." # 判断路径是否存在
self.transforms = transforms # 初始化图像变换操作
img_names = [i for i in os.listdir(os.path.join(data_root, "images")) if i.endswith(".tif")] # 遍历图像文件夹获取每个图像的文件名
self.img_list = [os.path.join(data_root, "images", i) for i in img_names] # 获取图像路径
self.manual = [os.path.join(data_root, "1st_manual", i.split("_")[0] + "_manual1.gif") # 获取手动标签的路径
for i in img_names]
# 检查手动标签文件是否存在
for i in self.manual:
if os.path.exists(i) is False:
raise FileNotFoundError(f"file {i} does not exists.")
# 获取分割的ROI区域掩码
self.roi_mask = [os.path.join(data_root, "mask", i.split("_")[0] + f"_{self.flag}_mask.gif")
for i in img_names]
# check files
for i in self.roi_mask:
if os.path.exists(i) is False:
raise FileNotFoundError(f"file {i} does not exists.")
def __getitem__(self, idx):
img = Image.open(self.img_list[idx]).convert('RGB') # 加载图像,并转换为RGB模式
manual = Image.open(self.manual[idx]).convert('L') # 加载手动标注图像,并转换为灰度模式
manual = np.array(manual) / 255 # 进行归一化操作
roi_mask = Image.open(self.roi_mask[idx]).convert('L') # 加载ROI图像,并转换为灰度模式
roi_mask = 255 - np.array(roi_mask) # 对图像数组取反,使用这个方法将背景和前景颜色反转,白色是255,黑色是0,反转后ROI变成了内黑外白
mask = np.clip(manual + roi_mask, a_min=0, a_max=255) # 将手动标注图像和反转后的ROI图像相加,使用np.clip()将像素值控制在0-255范围,
# 这里转回PIL的原因是,transforms中是对PIL数据进行处理
mask = Image.fromarray(mask)
if self.transforms is not None:
img, mask = self.transforms(img, mask)
return img, mask
# 获取图像数据集长度
def __len__(self):
return len(self.img_list)
# 用于将批量的图像和标签数据合并为一个批张量。
@staticmethod
def collate_fn(batch):
images, targets = list(zip(*batch)) # 将批量数据拆分为图像和标签两个列表
batched_imgs = cat_list(images, fill_value=0) # 使用 cat_list() 函数将图像和标签列表合并成张量。用于将列表中的 PIL 图像数据堆叠成张量,
batched_targets = cat_list(targets, fill_value=255)
return batched_imgs, batched_targets
def cat_list(images, fill_value=0):
max_size = tuple(max(s) for s in zip(*[img.shape for img in images])) # 找到图像中最大的形状,以元组形式返回给max_size
batch_shape = (len(images),) + max_size # 计算出堆叠后的张量形状,包括批量大小和图像大小两个维度
batched_imgs = images[0].new(*batch_shape).fill_(fill_value) # 创建一个新的空白张量 batched_imgs,其形状与 batch_shape 相同,并将其填充为指定的填充值 fill_value
for img, pad_img in zip(images, batched_imgs): # 使用 zip() 函数将输入列表中的每个图像与其对应的空白张量进行拼接,以得到一个完整的张量。
pad_img[..., :img.shape[-2], :img.shape[-1]].copy_(img) # 将每个图像按照其实际大小插入到空白张量的左上角,以保持图像的相对位置不变。
return batched_imgs
2. 构建训练集和验证集对象。调用上述自定义的DriveDataset数据集类,通过传入不同的参数来区分训练集和验证集。arg.data_path表示数据集所在的路径,transforms参数则是表示对数据进行预处理的操作,包括图像增强和归一化等,mean和std是规范化处理时用到的均值和标准差。get_transform这个类下面会介绍。train=True表示构建训练集对象,train=False表示构建验证集对象。
train_dataset = DriveDataset(args.data_path,
train=True,
transforms=get_transform(train=True, mean=mean, std=std))
val_dataset = DriveDataset(args.data_path,
train=False,
transforms=get_transform(train=False, mean=mean, std=std))
3.这是定义图像预处理方式,包括训练集和测试集的图像和标签的预处理方式,每行代码的具体作用注释有介绍。
# 定义训练集图像的预处理方式
class SegmentationPresetTrain:
def __init__(self, base_size, crop_size, hflip_prob=0.5, vflip_prob=0.5,
mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
min_size = int(0.5 * base_size)
max_size = int(1.2 * base_size)
trans = [T.RandomResize(min_size, max_size)] # 对图像的短边(长和宽中最短的)进行随机缩放以适应不同图像输入尺寸,缩放范围为【min_size, max_size】
if hflip_prob > 0:
trans.append(T.RandomHorizontalFlip(hflip_prob)) # 加入水平翻转
if vflip_prob > 0:
trans.append(T.RandomVerticalFlip(vflip_prob)) # 加入垂直翻转
trans.extend([
T.RandomCrop(crop_size), # 对图像进行随机裁剪
T.ToTensor(), # 将数组矩阵转换为tensor类型,规范化到【0,1】范围
T.Normalize(mean=mean, std=std), # 加入图像归一化,并定义均值和标准差,RGB三通道的
])
# trans是一个列表类型,包含各种了变换,将这些变换组成一个compose变换,注意transforms.Compose()函数需要接收一个列表类型
self.transforms = T.Compose(trans)
# 使用__call__()函数来调用transforms变换
def __call__(self, img, target):
return self.transforms(img, target) # target是指标签图像,img是指待分割图像
# 定义验证集的图像预处理组合类,比较简单,只有张量化和规范化两个操作,这里规范化使用的是ImageNet推荐的参数,注意这种做法是针对彩色图像
class SegmentationPresetEval:
def __init__(self, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
self.transforms = T.Compose([
T.ToTensor(),
T.Normalize(mean=mean, std=std),
])
def __call__(self, img, target):
return self.transforms(img, target)
# 定义一个函数根据数据集的类型来调用对应的数据集处理类
def get_transform(train, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
base_size = 565
crop_size = 480
# 检查train是否为True
if train:
return SegmentationPresetTrain(base_size, crop_size, mean=mean, std=std)
else:
return SegmentationPresetEval(mean=mean, std=std)
4. 定义训练时所使用的线程数目,这里如果时windows系统训练出错的话,建议把num_workers直接设置为0就可以解决。定义训练数据加载器train_loader,用于批量加载训练数据。在代码中,使用torch.utils.data.DataLoader类来创建数据加载器,构造函数的参数包括:
train_dataset
:训练数据集,应该是一个符合 PyTorch Dataset 接口的对象。batch_size
:每个批次的样本数量。num_workers
:用于数据加载的线程数。shuffle
:是否在每个时期(epoch)重新洗牌数据。一般只在训练集用pin_memory
:是否将数据加载到固定的内存区域,可以加速数据传输。collate_fn
:用于将样本列表转换为批次张量的函数。
num_workers = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # 如果batch_size>1, 线程数num_workers取min(cpu核数,batch_size)
train_loader = torch.utils.data.DataLoader(train_dataset,
batch_size=batch_size,
num_workers=num_workers,
shuffle=True,
pin_memory=True,
collate_fn=train_dataset.collate_fn)
val_loader = torch.utils.data.DataLoader(val_dataset,
batch_size=1,
num_workers=num_workers,
pin_memory=True,
collate_fn=val_dataset.collate_fn)
至此数据集加载器创建完成!