Pytorch数据集对象:Dataset和DataLoader

发布时间:2023年12月27日

简介

在深度学习中,需要反复对数据集进行卷积操作,甚至可以说,数据集是深度学习的唯一操作对象,为此,Pytorch为数据集提供一个专门的抽象类,显然是合情合理的。

静态直观地看,数据集就是数据的集合;而动态抽向地看,数据集则是一种索引方式。torch.utils.data.Dataset就是这样一个类,用以构建索引到样本的映射。

在Python中,索引方式有两种,一种是迭代式,另一种则是字典式,这两种方式在Pytorch中均有实现,但实际使用中,后者更加灵活,也更加常见,下面具体实现一个。

字典式数据集

众所周知,字典索引的关键是魔法函数__getitem__,故而字典式数据集务必重载这个魔法函数。

下面就实现一个字典式数据集

from torch.utils.data import Dataset
from PIL import Image
import os
import numpy as np
class ImgData(Dataset):
    def __init__(self, path):
        self.path = path

    def __getitem__(self, index):
        path = os.path.join(self.path, f"{index}.png")
        img = Image.open(path)
        return np.array(img)

    def __len__(self):
        return len(os.listdir(self.path))

下面测试一下

path = 'image'  # 此为一个图像的目录
imgData = ImgData(path)
x = imgData[0]
print(x.shape)
# (512, 512)

可见,索引到一个 512 × 512 512\times512 512×512的文件。

这个案例虽然短小,但却表明了Dataset的必要性,当数据集非常大的时候,没办法将所有的数据写在内存中,故而需要重新实现__getitem__,而不是将所有数据写入字典并索引。

DataLoader

在Pytorch中,Dataset做好的数据集,将提交给DataLoader,并迭代产生训练数据,最终提交给具体的模型。

一般来说,DataLoader并不需要继承,其参数如下,其中dataset即刚刚讲过的Dataset类型。

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None)

其他参数说明如下

  • batch_size 单个batch所包含的样本数。
  • shuffle 布尔类型,每一个epoch的batch样本是相同还是随机。
  • sampler和batch_sampler 均为Sampler类型,由于本文主要介绍Pytorch数据集的工作原理,故而暂且不表。
  • num_workers 进程数。
  • collate_fn 回调函数,合并样本列表以形成小批量。
  • pin_memory 布尔类型,如果为True,数据加载器在返回前将张量复制到CUDA固定内存中。
  • drop_last 布尔类型,当数据集大小不能被batch_size整除时发挥作用。为True,则删除最后一个不完整的批处理;为False,则减小最后一个batch的尺寸。
  • timeout 为正数,表明等待从worker进程中收集一个batch等待的时间,若超时,就不收集这个内容了。
  • worker_init_fn 回调函数,表示每个worker初始化函数

下面以以刚刚建立的imgData为例,来演示DataLoader的用法

from torch.utils.data import DataLoader
imgLoader = DataLoader(imgData, 4)
imgL = iter(imgLoader)
im = next(imgL)
print(im.shape)
# torch.Size([4, 512, 512])

从上面的示例可知,DataLoader内置了迭代器魔法函数,且每次迭代,返回的数据维度是 4 × 512 × 512 4\times512\times512 4×512×512,后面的 512 × 512 512\times512 512×512正是图像的尺寸,而前面的 4 4 4则是batch_size。

在实际使用中,DataLoader往往被用于for循环中,例如下面这样

for im, label in data_loader:
    pass
文章来源:https://blog.csdn.net/m0_37816922/article/details/135119591
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。