Pytorch中Dataset和dadaloader的理解

发布时间:2024年01月24日

不同的数据集在形式上千差万别,为了能够统一用于模型的训练,Pytorch框架下定义了一个dataset类和一个dataloader类。

dataset用于获取数据集中的样本,dataloader 用于抽取部分样本用于训练。比如说一个用于分割任务的图像数据集的结构如图1所示,一个样本由原图像和对应的mask组成。

图1 典型数据集的结构

为了获取数据集,典型的代码如下

from torch.utils.data import Dataset, DataLoader
import numpy as np
from PIL import Image
import os
from torchvision import transforms

# 定义数据集
train_data_dir = 'dataset/train'
train_GT_dir = 'dataset/train_GT'

class MyData(Dataset):
    def __init__(self, imgdir, maskdir,transform):
        self.imgdir = imgdir
        self.maskdir = maskdir
        self.transform = transform
        self.img_list = os.listdir(self.imgdir)
        self.mask_list= os.listdir(self.maskdir)
        self.img_list.sort()
        self.mask_list.sort()

    def __getitem__(self, idx):
        img_name = self.img_list[idx]
        mask_name =self.mask_list[idx]
        img_item_path = os.path.join(self.imgdir, img_name)
        mask_item_path =os.path.join(self.maskdir,mask_name)

        img =Image.open(img_item_path)
        mask =Image.open(mask_item_path)

        img = self.transform(img)
        mask = self.transform(mask)

        return img, mask

    def __len__(self):
        assert len(self.img_list) == len(self.mask_list)
        return len(self.img_list)

if __name__ == '__main__':
    transform = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])
    train_data_dir = 'dataset/train'
    train_GT_dir = 'dataset/train_GT'
    dataset = MyData(train_data_dir, train_GT_dir ,transform)
    dataloader = DataLoader(dataset, batch_size=4, num_workers=0)
    for step, (img,mask) in enumerate(dataloader):
        print(step)
        print(img.shape)
        print(mask.shape)
        if step>0:
            break

程序运行的结果如下:

返回了一个batch的img 和mask 的尺寸,说明数据集抽取成功了.

在建立数据集的过程中需用重写__getitem()__和__len()__方法即可。

文章来源:https://blog.csdn.net/qq_42982824/article/details/135827012
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。