?数据预处理
下列代码为train.py中常见的一些数据处理方法
train_transform = transforms.Compose([
? ? transforms.Resize((224, 224)),
? ? transforms.RandomVerticalFlip(),
? ? # 随机旋转,-45度到45度之间随机选
? ? transforms.RandomRotation(45),
? ? # 从中心开始裁剪
? ? transforms.CenterCrop(224),
? ? # 随机水平翻转 选择概率值为 p=0.5
? ? transforms.RandomHorizontalFlip(p=0.5),
? ? # 随机垂直翻转
? ? transforms.RandomVerticalFlip(p=0.5),
? ? # 参数:亮度、对比度、饱和度、色相
? ? transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),
? ? # 转为3通道灰度图 R=G=B 概率设定0.025
? ? transforms.RandomGrayscale(p=0.025),
? ? transforms.ToTensor(),
? ? transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
1.transforms.Resize
在下述示例中,我们首先使用PIL库的Image.open方法读取了一张图片。然后,我们使用transforms.Resize(500)定义了一个resize操作,将图片的短边缩放至500像素,同时保持长宽比不变。最后,我们将resize操作应用到图片上,得到了resize后的图片resized_image。最后一步是使用show方法显示resize后的图片。
from PIL import Image
from torchvision import transforms
# 读取图片
image = Image.open('image.jpg')
# 定义transforms.Resize操作
resize = transforms.Resize(500)
# 对图片进行resize操作
resized_image = resize(image)
# 显示resize后的图片
resized_image.show()
将resize = transforms.Resize(500) 改为?resize = transforms.Resize((500,500))
?
2.?transforms.RandomVerticalFlip
transforms.RandomVerticalFlip
是一个PyTorch中的数据预处理方法,用于垂直翻转图像。它可以根据给定的概率p,以p的概率对图像进行垂直翻转,以1-p的概率保持原始图像不变。
以下是一个使用transforms.RandomVerticalFlip
的示例代码:
import torch
from torchvision import transforms
from PIL import Image
# 加载图像
image = Image.open('image.jpg')
# 定义数据预处理方法
transform = transforms.Compose([
transforms.RandomVerticalFlip(p=0.5)
])
# 对图像进行数据预处理
transformed_image = transform(image)
# 显示预处理后的图像
transformed_image.show()
在上面的示例中,我们首先加载了一张图像,然后定义了一个transforms.Compose
对象,其中包含了transforms.RandomVerticalFlip
方法。接下来,我们将图像传递给transform
对象,它会根据给定的概率p对图像进行垂直翻转。最后,我们显示了预处理后的图像。?
3.?transforms.RandomHorizontalFlip?
transforms.RandomHorizontalFlip是torchvision.transforms中的一个类,用于对图像进行随机水平翻转的操作。它可以将图像水平翻转,即左右翻转。这个操作可以增加数据的多样性,提高模型的泛化能力。
下面是一个使用transforms.RandomHorizontalFlip的示例代码:
import torchvision.transforms as transforms
from PIL import Image
# 创建一个RandomHorizontalFlip对象
transform = transforms.RandomHorizontalFlip(p=0.5)
# 加载图像
image = Image.open('image.jpg')
# 对图像进行水平翻转
flipped_image = transform(image)
# 显示原始图像和翻转后的图像
image.show()
flipped_image.show()
在上面的代码中,我们首先导入了transforms模块和Image模块。然后,我们创建了一个RandomHorizontalFlip对象,并设置了翻转的概率为0.5。接下来,我们加载了一张图像,并使用transform对图像进行水平翻转操作。最后,我们分别显示了原始图像和翻转后的图像。?
4.?transforms.RandomRotation?随机旋转
transforms.RandomRotation是PyTorch中的一个图像变换操作,用于对图像进行随机旋转。它可以将图像按照一定的角度范围进行随机旋转,增加数据的多样性和鲁棒性。
以下是transforms.RandomRotation的使用示例:
import torchvision.transforms as transforms
from PIL import Image
# 创建一个RandomRotation对象,设置旋转角度范围为±30度
random_rotation = transforms.RandomRotation(30)
# 加载图像
image = Image.open('image.jpg')
# 对图像进行随机旋转
rotated_image = random_rotation(image)
# 显示旋转后的图像
rotated_image.show()
在上述示例中,我们首先导入了transforms模块和Image类。然后,我们创建了一个RandomRotation对象,并设置旋转角度范围为±30度。接下来,我们加载了一张图像,并使用random_rotation对图像进行随机旋转。最后,我们显示了旋转后的图像。
5.?transforms.CenterCrop 中心裁剪
transforms.CenterCrop
是PyTorch中的一个图像变换函数,用于对图像进行中心裁剪。它可以根据给定的尺寸对图像进行裁剪,并将图像的中心部分保留下来。
以下是一个使用transforms.CenterCrop
的示例代码:
import torchvision.transforms as transforms
from PIL import Image
# 加载图像
image = Image.open('image.jpg')
# 定义裁剪尺寸
crop_size = 224
# 创建CenterCrop变换对象
center_crop = transforms.CenterCrop(crop_size)
# 对图像进行中心裁剪
cropped_image = center_crop(image)
# 显示裁剪后的图像
cropped_image.show()
在上面的示例中,我们首先导入了transforms
模块和Image
类。然后,我们加载了一张图像,并定义了裁剪尺寸为224。接下来,我们创建了一个CenterCrop
变换对象,并将裁剪尺寸作为参数传递给它。最后,我们使用center_crop
对象对图像进行中心裁剪,并显示裁剪后的图像。
6.? ?transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1)
transforms.ColorJitter
是PyTorch中的一个图像变换类,它可以改变图像的亮度、对比度、饱和度和色调等属性。在你提供的例子中,brightness=0.2
表示将图像的亮度随机变化为原图亮度的80%(1-0.2)到120%(1+0.2)之间。同样地,contrast=0.1
、saturation=0.1
和hue=0.1
分别表示对比度、饱和度和色调的变化范围。
以下是一个示例代码,展示了如何使用transforms.ColorJitter
来改变图像的亮度、对比度、饱和度和色调属性:
import torch
from torchvision import transforms
# 创建一个ColorJitter对象,设置亮度、对比度、饱和度和色调的变化范围
jitter = transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1)
# 加载图像
image = Image.open('image.jpg')
# 对图像进行变换
transformed_image = jitter(image)
# 显示变换后的图像
transformed_image.show()
这段代码会加载一张名为image.jpg
的图像,并使用transforms.ColorJitter
对图像进行亮度、对比度、饱和度和色调的变换。最后,显示变换后的图像。
可以明显看到这些人的脸更红了
7.transforms.RandomGrayscale
transforms.RandomGrayscale是一个用于随机将图像转换为灰度图像的操作。它可以在图像上随机选择一些像素,并将它们转换为灰度值,而其他像素保持不变。
下面是一个使用transforms.RandomGrayscale的示例代码:
import torchvision.transforms as transforms
from PIL import Image
# 加载图像
image = Image.open("image.jpg")
# 定义transforms
transform = transforms.Compose([
transforms.RandomGrayscale(p=0.5),
])
# 对图像进行转换
transformed_image = transform(image)
# 显示转换后的图像
transformed_image.show()
在上面的示例中,我们首先加载了一张图像,然后定义了一个transforms.Compose对象,其中包含了transforms.RandomGrayscale操作和transforms.ToTensor操作。然后,我们将图像应用于这个transforms对象,得到了转换后的图像。最后,我们使用show()方法显示了转换后的图像。
8.transforms.ToTensor()
transforms.ToTensor()是PyTorch中的一个图像转换函数,它将PIL图像或NumPy数组转换为张量(Tensor)。这个函数的作用是将图像数据从范围[0, 255]转换为范围[0.0, 1.0]的浮点数张量,并且将通道顺序从H×W×C转换为C×H×W。
下面是一个使用transforms.ToTensor()的示例:
import torch
from torchvision import transforms
# 假设有一张PIL图像img
img = Image.open('image.jpg')
# 创建一个transforms对象,将图像转换为张量
transform = transforms.ToTensor()
# 使用transforms对象对图像进行转换
tensor_img = transform(img)
print('tensor_img',tensor_img.shape)
print(tensor_img)
输出结果将是一个形状为[C, H, W]的张量,其中C是通道数,H是图像的高度,W是图像的宽度。
tensor_img torch.Size([3, 375, 500])
tensor([[[1.0000, 1.0000, 1.0000, ..., 1.0000, 1.0000, 1.0000],
[1.0000, 1.0000, 1.0000, ..., 1.0000, 1.0000, 1.0000],
[1.0000, 1.0000, 1.0000, ..., 1.0000, 1.0000, 1.0000],
...,
[0.5333, 0.4431, 0.4667, ..., 0.4392, 0.4431, 0.4627],
[0.4510, 0.6275, 0.4549, ..., 0.4510, 0.4196, 0.4235],
[0.5804, 0.4196, 0.3961, ..., 0.4588, 0.4275, 0.4157]],
[[1.0000, 1.0000, 1.0000, ..., 1.0000, 1.0000, 1.0000],
[1.0000, 1.0000, 1.0000, ..., 1.0000, 1.0000, 1.0000],
[1.0000, 1.0000, 1.0000, ..., 1.0000, 1.0000, 1.0000],
...,
[0.4941, 0.4078, 0.4392, ..., 0.4392, 0.4314, 0.4510],
[0.4000, 0.5882, 0.4275, ..., 0.4392, 0.4039, 0.4078],
[0.5176, 0.3804, 0.3843, ..., 0.4471, 0.4118, 0.3882]],
[[1.0000, 1.0000, 1.0000, ..., 1.0000, 1.0000, 1.0000],
[1.0000, 1.0000, 1.0000, ..., 1.0000, 1.0000, 1.0000],
[1.0000, 1.0000, 1.0000, ..., 1.0000, 1.0000, 1.0000],
...,
[0.4471, 0.3725, 0.4078, ..., 0.4078, 0.3961, 0.4157],
[0.3373, 0.5490, 0.3961, ..., 0.4039, 0.3686, 0.3725],
[0.4549, 0.3412, 0.3569, ..., 0.4196, 0.3765, 0.3569]]])
9.transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])是PyTorch中的一个数据预处理操作,用于将图像数据进行归一化处理。具体来说,它将输入的图像数据减去均值[0.5, 0.5, 0.5],然后再除以标准差[0.5, 0.5, 0.5],从而使得处理后的图像数据的均值为0,标准差为1。
这个操作的目的是为了使得模型更容易收敛,因为经过归一化处理后的数据符合标准正态分布。在归一化之前,图像数据的像素值通常是在[0, 1]的范围内,归一化之后,像素值会变成在[-1, 1]的范围内。
需要注意的是,如果使用的是transforms.Normalize(channel_mean, channel_std),其中channel_mean和channel_std是根据数据计算得到的均值和标准差,那么归一化之后的数据的均值会变成0,标准差会变成1。
以下是一个示例代码,演示了如何使用transforms.Normalize进行图像数据的归一化处理:
import torch
import torchvision.transforms as transforms
# 假设img是一个图像数据
img = ...
# 定义归一化操作
normalize = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
# 对图像数据进行归一化处理
normalized_img = normalize(img)
# 打印归一化后的图像数据
print(normalized_img)
输出
normalized_img torch.Size([3, 375, 500])
tensor([[[ 1.0000, 1.0000, 1.0000, ..., 1.0000, 1.0000, 1.0000],
[ 1.0000, 1.0000, 1.0000, ..., 1.0000, 1.0000, 1.0000],
[ 1.0000, 1.0000, 1.0000, ..., 1.0000, 1.0000, 1.0000],
...,
[ 0.0667, -0.1137, -0.0667, ..., -0.1216, -0.1137, -0.0745],
[-0.0980, 0.2549, -0.0902, ..., -0.0980, -0.1608, -0.1529],
[ 0.1608, -0.1608, -0.2078, ..., -0.0824, -0.1451, -0.1686]],
[[ 1.0000, 1.0000, 1.0000, ..., 1.0000, 1.0000, 1.0000],
[ 1.0000, 1.0000, 1.0000, ..., 1.0000, 1.0000, 1.0000],
[ 1.0000, 1.0000, 1.0000, ..., 1.0000, 1.0000, 1.0000],
...,
[-0.0118, -0.1843, -0.1216, ..., -0.1216, -0.1373, -0.0980],
[-0.2000, 0.1765, -0.1451, ..., -0.1216, -0.1922, -0.1843],
[ 0.0353, -0.2392, -0.2314, ..., -0.1059, -0.1765, -0.2235]],
[[ 1.0000, 1.0000, 1.0000, ..., 1.0000, 1.0000, 1.0000],
[ 1.0000, 1.0000, 1.0000, ..., 1.0000, 1.0000, 1.0000],
[ 1.0000, 1.0000, 1.0000, ..., 1.0000, 1.0000, 1.0000],
...,
[-0.1059, -0.2549, -0.1843, ..., -0.1843, -0.2078, -0.1686],
[-0.3255, 0.0980, -0.2078, ..., -0.1922, -0.2627, -0.2549],
[-0.0902, -0.3176, -0.2863, ..., -0.1608, -0.2471, -0.2863]]])