PyTorch的Dataset是一个抽象类,用于表示数据集。它提供了一些通用的方法,如__len__()
和__getitem__(),
分别用于获取数据集的大小和获取指定索引的数据样本。用户可以通过继承Dataset类并实现这些方法来自定义自己的数据集。
Dataset类:
TensorDataset类:
?
此外,在torchvision库中,针对视觉处理,提供了继承自Dataset的VisionDataset类作为机器视觉数据集的基础类,目前实现了VisionDataset类的子类有74个数据集(比如CIFAR*, MNIST)。
?
?
数据集(Dataset)使用场景如下:
PyTorch的DataLoader是一个用于加载数据的工具,它可以将数据集分批次地加载到内存中,并支持多线程并行处理。使用DataLoader可以方便地实现小批量训练、分布式训练和数据增强等操作。
参数说明:
dataset
:要加载的数据集对象,必须是torch.utils.data.Dataset
的子类。batch_size
:每个批次的大小,默认为1。shuffle
:是否在每个epoch开始时打乱数据顺序,默认为False。sampler
:用于指定从数据集中抽取样本的策略,可以是torch.utils.data.Sampler
或其子类的对象。batch_sampler
:与sampler
类似,但是用于指定从数据集中抽取批次的策略,可以是torch.utils.data.BatchSampler
或其子类的对象。num_workers
:用于数据加载的子进程数,默认为0,表示不使用多进程加载数据。collate_fn
:用于将多个样本组合成一个批次的函数,默认为torch.utils.data.dataloader.default_collate
。pin_memory
:是否将数据存储在固定内存中,默认为False。drop_last
:如果为True,则丢弃最后一个不完整的批次,默认为False。timeout
:从工作进程中获取数据的超时时间,默认为0,表示无限等待。worker_init_fn
:用于初始化工作进程的函数,默认为None。multiprocessing_context
:用于指定多进程上下文的类型,默认为None。?
DataLoader使用场景:
import torch
import torch.nn as nn
import numpy as np
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
# 定义神经网络模型
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(1, 64)
self.fc2 = nn.Linear(64, 64)
self.fc3 = nn.Linear(64, 1)
def forward(self, x):
x = torch.relu(self.fc1(x))
x = torch.relu(self.fc2(x))
x = self.fc3(x)
return x
model = Net()
print('model:', model)
# 创建数据集
x = torch.randn(1000, 1) # 生成1000个样本,每个样本有1个特征
print('x.size:', x.size())
noise = torch.from_numpy(np.random.normal(0, 0.1, (1000, 1))).float() # 生成1000 x 1个数值在0~0.1之间噪音
y = 3 * x + 2 + noise # 生成1000个标签 + 其中包含了噪音
print('y.size:', y.size())
dataset = TensorDataset(x, y) # 将数据和标签封装成TensorDataset对象
print('dataset:', dataset)
# 创建数据加载器
batch_size = 100 # 每个批次的大小
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True) # 创建DataLoader对象,用于批量加载数据
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01) # 随机梯度下降,学习率为0.01
# 训练模型
num_epochs = 1000
for epoch in range(0,num_epochs):
# 遍历数据集
for i, (inputs, labels) in enumerate(dataloader):
# 前向传播
outputs = model(inputs)
# 计算损失
loss = criterion(outputs, labels)
# 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step()
if (epoch+1) % 100 == 0:
print('Epoch [{}/{}], Step [{}/{}], Loss: {:.6f}'.format(epoch+1, num_epochs, i+1, len(dataloader), loss.item()))