在深度学习中,需要反复对数据集进行卷积操作,甚至可以说,数据集是深度学习的唯一操作对象,为此,Pytorch为数据集提供一个专门的抽象类,显然是合情合理的。
静态直观地看,数据集就是数据的集合;而动态抽向地看,数据集则是一种索引方式。torch.utils.data.Dataset就是这样一个类,用以构建索引到样本的映射。
在Python中,索引方式有两种,一种是迭代式,另一种则是字典式,这两种方式在Pytorch中均有实现,但实际使用中,后者更加灵活,也更加常见,下面具体实现一个。
众所周知,字典索引的关键是魔法函数__getitem__,故而字典式数据集务必重载这个魔法函数。
下面就实现一个字典式数据集
from torch.utils.data import Dataset
from PIL import Image
import os
import numpy as np
class ImgData(Dataset):
def __init__(self, path):
self.path = path
def __getitem__(self, index):
path = os.path.join(self.path, f"{index}.png")
img = Image.open(path)
return np.array(img)
def __len__(self):
return len(os.listdir(self.path))
下面测试一下
path = 'image' # 此为一个图像的目录
imgData = ImgData(path)
x = imgData[0]
print(x.shape)
# (512, 512)
可见,索引到一个 512 × 512 512\times512 512×512的文件。
这个案例虽然短小,但却表明了Dataset的必要性,当数据集非常大的时候,没办法将所有的数据写在内存中,故而需要重新实现__getitem__,而不是将所有数据写入字典并索引。
在Pytorch中,Dataset做好的数据集,将提交给DataLoader,并迭代产生训练数据,最终提交给具体的模型。
一般来说,DataLoader并不需要继承,其参数如下,其中dataset即刚刚讲过的Dataset类型。
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None)
其他参数说明如下
下面以以刚刚建立的imgData为例,来演示DataLoader的用法
from torch.utils.data import DataLoader
imgLoader = DataLoader(imgData, 4)
imgL = iter(imgLoader)
im = next(imgL)
print(im.shape)
# torch.Size([4, 512, 512])
从上面的示例可知,DataLoader内置了迭代器魔法函数,且每次迭代,返回的数据维度是 4 × 512 × 512 4\times512\times512 4×512×512,后面的 512 × 512 512\times512 512×512正是图像的尺寸,而前面的 4 4 4则是batch_size。
在实际使用中,DataLoader往往被用于for循环中,例如下面这样
for im, label in data_loader:
pass