MNIST数据集包含了6万张手写数字([1,28,28]尺寸),以特殊格式存储。本文首先将MNIST数据集另存为png格式,然后再读取png格式图片,开展后续训练
import torch
from torch.utils.data import Dataset
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from tqdm import tqdm
from torchvision import models, transforms
from torchvision.utils import save_image
from PIL import Image
#将MNIST数据集转换为图片
tf = transforms.Compose([transforms.ToTensor()]) # mnist is already normalised 0 to 1
datasetMNIST = MNIST("./data", train=True, download=True, transform=tf)
pbar = tqdm(datasetMNIST)
for index, (img,cl) in enumerate(pbar):
save_image(img, f"./data/MNIST_PNG/x/{index}.png")
# 以写入模式打开文件
with open(f"./data/MNIST_PNG/c/{index}.txt", "w", encoding="utf-8") as file:
# 将字符串写入文件
file.write(f"{cl}")
注意:MNIST源数据存放在./data
文件下,如果没有数据也没关系,代码会自动从网上下载。另存为png的数据放在了./data/MNIST_PNG/
文件下。子文件夹x
存放6万张图片,子文件夹c
存放6万个文本文件,每个文本文件内有一行字符串,说明该对应的手写数字是几(标签)。
class MyMNISTDataset(Dataset):
def __init__(self, data):
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
x = self.data[idx][0] #图像
y = self.data[idx][1] #标签
return x, y
def load_data(dataNum=60000):
data = []
pbar = tqdm(range(dataNum))
for i in pbar:
# 指定图片路径
image_path = f'./data/MNIST_PNG/x/{i}.png'
cond_path=f'./data/MNIST_PNG/c/{i}.txt'
# 定义图像预处理
preprocess = transforms.Compose([
transforms.Grayscale(num_output_channels=1), # 将图像转换为灰度图像(单通道)
transforms.ToTensor()
])
# 使用预处理加载图像
image_tensor = preprocess(Image.open(image_path))
# 加载条件文档(tag)
with open(cond_path, 'r') as file:
line = file.readline()
number = int(line) # 将字符串转换为整数,图像的类别
data.append((image_tensor, number))
return data
data=load_data(60000)
# 创建数据集实例
dataset = MyMNISTDataset(data)
# 创建数据加载器
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
pbar = tqdm(dataloader)
for index, (img,cond) in enumerate(pbar):
#这里对每一批进行训练...
print(f"Batch {index}: img = {img.shape}, cond = {cond}")
load_data
函数用于读取数据文件,返回一个data张量。data张量又被用于构造MyMNISTDataset
类的对象dataset
,dataset
对象又被DataLoader函数转换为dataloader
。
dataloader
事实上按照batch将数据集进行了分割,4张图片一组进行训练。上述代码的输出如下:
......
Batch 7847: img = torch.Size([4, 1, 28, 28]), cond = tensor([0, 1, 5, 2])
Batch 7848: img = torch.Size([4, 1, 28, 28]), cond = tensor([2, 2, 6, 0])
Batch 7849: img = torch.Size([4, 1, 28, 28]), cond = tensor([4, 3, 0, 9])
Batch 7850: img = torch.Size([4, 1, 28, 28]), cond = tensor([6, 2, 9, 5])
Batch 7851: img = torch.Size([4, 1, 28, 28]), cond = tensor([7, 2, 4, 4])
Batch 7852: img = torch.Size([4, 1, 28, 28]), cond = tensor([1, 4, 2, 6])
Batch 7853: img = torch.Size([4, 1, 28, 28]), cond = tensor([2, 5, 3, 5])
Batch 7854: img = torch.Size([4, 1, 28, 28]), cond = tensor([7, 1, 0, 1])
Batch 7855: img = torch.Size([4, 1, 28, 28]), cond = tensor([9, 8, 9, 7])
Batch 7856: img = torch.Size([4, 1, 28, 28]), cond = tensor([4, 6, 6, 7])
Batch 7857: img = torch.Size([4, 1, 28, 28]), cond = tensor([7, 4, 1, 6])
Batch 7858: img = torch.Size([4, 1, 28, 28]), cond = tensor([5, 4, 6, 5])
Batch 7859: img = torch.Size([4, 1, 28, 28]), cond = tensor([6, 3, 1, 9])
Batch 7860: img = torch.Size([4, 1, 28, 28]), cond = tensor([5, 5, 8, 6])
Batch 7861: img = torch.Size([4, 1, 28, 28]), cond = tensor([0, 4, 8, 9])
Batch 7862: img = torch.Size([4, 1, 28, 28]), cond = tensor([2, 3, 5, 8])
Batch 7863: img = torch.Size([4, 1, 28, 28]), cond = tensor([8, 0, 0, 6])
......