如何基于PyTorch框架自定义数据集类获取数据

发布时间:2023年12月30日

在PyTorch框架中,可以通过自定义数据集类来加载和处理数据

要自定义数据集类,需要继承 PyTorch提供的 torch.utils.data.Dataset类,并实现两个主要方法:__len__ __getitem__

下面是一个示例,展示如何基于PyTorch框架来自定义数据集类以获取数据:

import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        item = self.data[index]
        # 在这里对数据进行预处理、转换等操作
        # 返回一个样本(通常是一个字典)
        return item

# 创建数据集实例
data = [...]  # 数据列表,包含训练样本
dataset = CustomDataset(data)

# 创建数据加载器
batch_size = 32
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

# 遍历数据加载器获取数据批次
for batch in dataloader:
    # 处理每个批次的数据
    inputs = batch['input']
    labels = batch['label']
    # 在这里进行模型训练、推理等操作

在此示例中,定义了一个名为 CustomDataset 的自定义数据集类,该类继承自torch.utils.data. Dataset

__init__方法是构造函数,传入数据列表 data 并将其保存为类的属性 self.data

__len__方法返回数据集的长度,即样本数量

__getitem__方法通过索引获取单个样本

然后,创建了一个数据集实例 dataset,并使用 torch.utils.data.DataLoader 创建了一个数据加载器 dataloader

通过遍历数据加载器可以获取每个批次 输入数据inputs 以及 标签数据labels,进行模型训练、推理等操作

注意:根据具体的应用需求,可以在__getitem__方法中对数据进行预处理、转换等操作,并将处理后的样本作为字典或其他形式返回, 这样,在训练过程中可以方便地获取输入数据和标签数据 ,并进行相应的操作

下面再来看一个例子,该例通过在 __getitem__方法中对数据进行预处理,并最终返回一个包含图片数据、对应的标签数据以及图像文件名的字典

class BipedDataset(Dataset):  # 定义了一个名为BipedDataset的类,它继承自PyTorch的Dataset类,用于自定义数据集
    '''
    用于构建一个自定义数据集,可以在训练神经网络时使用
    它提供了加载图像、预处理数据等功能,以便用于深度学习模型的训练
    '''
    def __init__(self,
                 data_root,  
                 img_height,
                 img_width,
                 mean_bgr,  # 图像的均值(以BGR通道顺序表示)
                 train_mode='train',  # 训练模式,可以是 'train' 或 'test' 之一,默认为 'train
                 crop_img=False,
                 arg=None
                 ):
        '''
        这是类的构造函数,用于初始化对象的属性
        它接受许多参数,包括数据根目录 data_root、图像高度 img_height、图像宽度 img_width、均值 mean_bgr、训练模式 train_mode 等
        '''
        self.data_root = data_root
        self.train_mode = train_mode
        self.img_height = img_height
        self.img_width = img_width
        self.mean_bgr = mean_bgr
        self.crop_img = crop_img
        self.arg = arg

        self.data_index = self._build_index()

    def _build_index(self):  # 用于构建数据索引
        data_root = os.path.abspath(self.data_root)
        sample_indices = []  # 用于存储图像和标签的文件路径对
        # 构建图像和标签的文件路径,其中 images_path 和 labels_path 分别指向数据集中图像和标签的存储路径
        # 使用两层循环遍历图像目录中的所有文件,构建图像和标签的文件路径,并将其添加到 sample_indices 列表中

        images_path = os.path.join(data_root,'edges\\imgs',self.train_mode)
      
        labels_path = os.path.join(data_root,'edges\\labels',self.train_mode)
            

        for file_name_ext in os.listdir(images_path):

            file_name = os.path.splitext(file_name_ext)[0]
            sample_indices.append(
                    ( os.path.join(images_path, file_name + '.tif'),
                      os.path.join(labels_path, file_name + '.tif'), )
                   )

        return sample_indices  # 返回构建好的图像和标签的文件路径对列表

    def __len__(self):  # 返回数据集的长度,即样本的数量
        return len(self.data_index)

    def __getitem__(self, idx):  # 用于获取指定索引处的数据样本,它接受一个索引 idx 作为参数
        # get data sample
        '''
        首先,根据索引获取图像路径和标签路径
        然后,使用OpenCV加载图像和标签
        接下来,调用self.transform方法进行数据变换
        最后,返回一个包含图像、对应标签以及图像文件名的字典
        '''

        image_path, label_path = self.data_index[idx]

        # load data
        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
        image, label = self.transform(img=image, gt=label)  # transform方法:用于对图像和标签进行预处理
        img_name = os.path.basename(image_path)
        file_name = os.path.splitext(img_name)[0] + ".png"
        return dict(images=image, labels=label, file_names=file_name)

    def transform(self, img, gt):
        # 将标签转换为浮点型数组,并将其归一化到 [0, 1] 的范围内
        gt = np.array(gt, dtype=np.float32)
        if len(gt.shape) == 3:
            gt = gt[:, :, 0]
        gt /= 255.  
        
        # 将图像转换为浮点型数组,并减去均值 self.mean_bgr
        img = np.array(img, dtype=np.float32)
        img -= self.mean_bgr
        i_h, i_w, _ = img.shape  # 获取图像的高度、宽度和通道数

        # 根据设定的裁剪大小 crop_size 对图像进行裁剪或缩放
        crop_size = self.img_height if self.img_height == self.img_width else None  
        
        # 对于裁剪过程,它会在图像中随机选择一个位置来裁剪
        if i_w > crop_size and i_h > crop_size:
            i = random.randint(0, i_h - crop_size)
            j = random.randint(0, i_w - crop_size)
            img = img[i:i + crop_size, j:j + crop_size]
            gt = gt[i:i + crop_size, j:j + crop_size]
        else:  #  如果图像的尺寸小于 crop_size,则会使用双线性插值进行缩放
            # New addidings
            img = cv2.resize(img, dsize=(crop_size, crop_size))
            gt = cv2.resize(gt, dsize=(crop_size, crop_size))

        # 对标签gt进行一些额外的处理,然后将图像img和标签gt转换为PyTorch的张量形式
        gt[gt > 0.1] += 0.2  
        gt = np.clip(gt, 0., 1.)

        img = img.transpose((2, 0, 1))
        img = torch.from_numpy(img.copy()).float()
        gt = torch.from_numpy(np.array([gt])).float()
        return img, gt

在此处就定义完成了一个数据集类 BipedDataset

如何使用自定义的 BipedDataset 类来对数据进行加载呢?下面以加载验证集数据为例来进行说明

首先,对这个类进行实例化得到实例化后的数据集对象 dataset_val

dataset_val = BipedDataset(args.input_dir,
                                 img_width =args.img_width,
                                 img_height =args.img_height,
                                 mean_bgr =args.mean_pixel_values,
                                 train_mode ='test',
                                 arg =args
                                 )

其次,将该对象传入DataLoader中创建验证集数据加载器 dataloader_val

dataloader_val = DataLoader(dataset_val,
                                batch_size=1,
                                shuffle=False,
                                num_workers=args.workers)

然后,将数据集加载器 dataloader_val 作为参数传入进行验证过程的函数 validate_one_epoch 中

 val_precision,val_recall,val_IoU = validate_one_epoch(epoch,
                               dataloader_val,
                               model,
                               device,
                               img_test_dir,
                               arg=args)
def validate_one_epoch(epoch, dataloader, model, device, output_dir, arg=None):
    precision = 0.0
    recall = 0.0
    IoU = 0.0
    model.eval()  
    with torch.no_grad():
        for _, sample_batched in enumerate(dataloader):
   
            images = sample_batched['images'].to(device)
            labels = sample_batched['labels'].to(device)
            file_names = sample_batched['file_names']   
            preds = model(images)
            labels = normalize_image(labels)
            preds = normalize_image(preds)
            precision += calculate_precision(preds, labels)
            recall += calculate_recall(preds, labels)
            IoU += calculate_iou(preds, labels)
         
            save_image_batch_to_disk(preds, output_dir, file_names,arg=arg)

    precision = precision / len(dataloader)
    recall = recall / len(dataloader)
    IoU = IoU / len(dataloader)
    print(time.ctime(), '[Val_Epoch]: {0} Precision:{1}  Recall:{2}  IoU:{3} '.format(epoch, precision, recall, IoU))
    print(f"第{epoch}次迭代的验证精确度为{precision},验证召回率为{recall},验证交并比为{IoU}")
    return precision, recall, IoU

最后,我们可以看到将 dataloader_val验证集数据加载器 传入 函数validate_one_epoch 中,通过遍历 dataloader 中的数据,可以通过 自定义类BipedDataset 返回的包含三个元素的字典来获取图像数据、对应的标签数据以及图像文件名,如下图所示

 images = sample_batched['images'].to(device)
 labels = sample_batched['labels'].to(device)
 file_names = sample_batched['file_names']   
           

综上所述, 就是关于如何基于PyTorch深度学习框架自定义数据集来获取数据的详细步骤了,如果你觉得有用,麻烦点赞关注一下哈,谢谢!

文章来源:https://blog.csdn.net/Kelly_Ai_Bai/article/details/135305637
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。