It is widely accepted that normalizing image data during training helps gradient descent find a better optimum. So what exactly does PyTorch's transforms.Normalize do? That is worth understanding.
Consider the following formula, where x is drawn from a set of data C, mean is the mean of that data, and std is its standard deviation:
x = (x - mean) / std
In short, Normalize simply updates the input data according to this formula.
Consider this piece of code:
import numpy as np
List1 = np.array([1, 2, 3, 4])
mean = np.mean(List1)  # mean of the data
std = np.std(List1)    # (population) standard deviation
List2 = (List1 - mean) / std
>>> List1
array([1, 2, 3, 4])
>>> List2
array([-1.34164079, -0.4472136 , 0.4472136 , 1.34164079])
After this standardization, List1 becomes List2.
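As a quick sanity check (reusing List2 from above), the standardized data should have mean 0 and standard deviation 1, up to floating-point rounding:
print(np.mean(List2), np.std(List2))  # approximately 0.0 and 1.0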
So how does Normalize work on image data, concretely?
Suppose we have four images. Using the data-loading approach from the earlier article, load them:
import os
from PIL import Image
import numpy as np
from torchvision import transforms
import torch

path = "E:\\3-10\\dogandcats\\source"
IMG = []
filenames = os.listdir(path)
for filename in filenames:
    img = Image.open(os.path.join(path, filename))
    img = img.resize((28, 28))  # resize each image to 28x28 pixels
    img = np.array(img)  # convert the PIL image to a numpy array (H, W, C)
    img = torch.tensor(img, dtype=torch.float32)  # numpy -> float tensor (mean/std need floats)
    img = img.permute(2, 0, 1)  # reorder H, W, C to C, H, W
    IMG.append(img)  # collect the image tensors
IMGEND = torch.stack(IMG, dim=0)  # stack the tensors into one batch
>>> IMGEND.size()
torch.Size([4, 3, 28, 28])
The four images have been successfully loaded and converted into a single tensor.
Get the mean of the r, g, b channels:
>>> mean=torch.mean(IMGEND,dim=(0,2,3),keepdim=True)
>>> mean
tensor([[[[160.8753]],
[[149.3600]],
[[126.5810]]]])
Get the standard deviation of the r, g, b channels:
>>> std=torch.std(IMGEND,dim=(0,2,3),keepdim=True)
>>> std
tensor([[[[61.7317]],
[[65.0915]],
[[84.2025]]]])
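Both calls reduce over dims (0, 2, 3), i.e. over every pixel of every image for each channel. As a minimal sketch of the equivalent computation (assuming the IMGEND tensor built above), we can flatten each channel by hand and get the same numbers:
flat = IMGEND.permute(1, 0, 2, 3).reshape(3, -1)  # one row of N*H*W pixel values per channel
print(flat.mean(dim=1))  # the same three channel means as above
print(flat.std(dim=1))   # the same three channel standard deviations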
Now normalize:
process = transforms.Normalize([160.8753, 149.3600, 126.5810], [61.7317, 65.0915, 84.2025])
>>> dataend1=process(IMGEND)
>>> dataend1
tensor([[[[-1.3587, -0.9213, -0.7269, ..., -0.3382, -0.3868, -0.4516],
[-1.4397, -0.8727, -0.6135, ..., -0.1114, -0.1762, -0.2248],
[-1.8771, -1.3587, -0.9375, ..., 0.1640, 0.0830, -0.1438],
...,
[-1.9095, -1.8285, -1.8123, ..., -2.1687, -2.2497, -2.2173],
[-1.9419, -1.8609, -1.8123, ..., -2.3469, -2.4117, -2.2983],
[-1.9257, -1.8447, -1.8447, ..., -2.3307, -2.3307, -2.2821]],
[[-1.0502, -0.4357, -0.0055, ..., 0.4246, 0.3785, 0.3325],
[-1.1424, -0.4203, 0.0406, ..., 0.5783, 0.5475, 0.5168],
[-1.6340, -1.0656, -0.4664, ..., 0.7626, 0.7319, 0.5629],
...,
[-1.6340, -1.5572, -1.5418, ..., -1.6955, -1.7723, -1.7876],
[-1.6801, -1.5879, -1.5418, ..., -1.9413, -2.0027, -1.8491],
[-1.6186, -1.5726, -1.5726, ..., -1.9259, -1.8952, -1.8645]],
[[-0.4938, 0.0881, 0.5988, ..., 1.0026, 0.9788, 0.9313],
[-0.5888, 0.0762, 0.6107, ..., 1.0738, 1.0501, 1.0382],
[-1.0758, -0.5532, 0.1000, ..., 1.1807, 1.1570, 1.0501],
...,
[-0.9926, -0.9332, -0.9451, ..., -1.3845, -1.4083, -1.3845],
[-1.0401, -0.9689, -0.9570, ..., -1.4320, -1.4439, -1.3370],
[-0.9926, -0.9451, -0.9570, ..., -1.4320, -1.4558, -1.3964]]],
[[[-1.6827, -1.8609, -1.9095, ..., -0.4192, -0.4840, -0.5002],
[-1.6989, -1.8285, -1.8933, ..., -0.3868, -0.4678, -0.4516],
[-1.6989, -1.7961, -2.0877, ..., -0.3868, -0.4192, -0.4516],
...,
[ 0.7634, 0.8606, 0.8768, ..., 0.9254, 0.9092, 0.9092],
[ 0.8120, 0.8930, 0.8930, ..., 0.9416, 0.8930, 0.8930],
[ 0.8282, 0.9092, 0.9254, ..., 0.9254, 0.8930, 0.8930]],
[[-1.9413, -2.0334, -1.9720, ..., -1.6340, -1.6340, -1.6340],
[-1.9413, -2.0181, -1.9720, ..., -1.5879, -1.5572, -1.5572],
[-1.9413, -1.9874, -2.0488, ..., -1.5726, -1.5265, -1.5265],
...,
[ 0.5936, 0.7473, 0.7473, ..., 0.8702, 0.8394, 0.8241],
[ 0.6397, 0.7780, 0.7780, ..., 0.8702, 0.7780, 0.8087],
[ 0.7319, 0.8241, 0.8241, ..., 0.8394, 0.8087, 0.7933]],
[[-1.3608, -1.3845, -1.3370, ..., -1.2539, -1.2301, -1.2301],
[-1.3608, -1.3845, -1.3252, ..., -1.2183, -1.2064, -1.2064],
[-1.3608, -1.3727, -1.3964, ..., -1.2064, -1.1826, -1.1826],
...,
[ 0.5988, 0.7532, 0.7532, ..., 0.8719, 0.8363, 0.8363],
[ 0.6700, 0.7888, 0.8007, ..., 0.8719, 0.7650, 0.8126],
[ 0.7532, 0.8244, 0.8363, ..., 0.8482, 0.8126, 0.8007]]],
[[[ 0.6986, 0.8282, 0.7796, ..., 0.1640, 0.0830, 0.1316],
[ 0.3908, 0.5204, 0.5852, ..., 0.1964, 0.2774, 0.2126],
[ 0.4070, 0.4880, 0.6014, ..., 0.0182, 0.3746, 0.2612],
...,
[-0.3706, -0.6135, -0.4030, ..., -0.2248, -0.2572, -0.2086],
[-0.4516, -0.6783, -1.0185, ..., -0.3220, -0.3868, -0.4030],
[-0.5973, -0.5973, -1.0347, ..., -0.3868, -0.4678, -0.5649]],
[[ 0.6551, 0.7780, 0.6551, ..., -0.2360, -0.2513, 0.1020],
[ 0.2249, 0.3478, 0.3939, ..., -0.1899, -0.1438, 0.0252],
[ 0.2096, 0.2864, 0.3785, ..., -0.3282, -0.0363, -0.0055],
...,
[-0.1592, -0.5586, -0.6661, ..., -0.0055, -0.0363, -0.0055],
[-0.2360, -0.5740, -1.1424, ..., -0.0977, -0.1284, -0.1438],
[-0.3896, -0.4203, -0.9888, ..., -0.1899, -0.2206, -0.2974]],
[[-0.2088, -0.0782, -0.2919, ..., -0.5770, -0.5413, 0.0050],
[-0.7670, -0.6720, -0.6720, ..., -0.6363, -0.5532, -0.1257],
[-0.8145, -0.7432, -0.7195, ..., -0.6720, -0.5770, -0.2919],
...,
[-1.4202, -1.3845, -1.0282, ..., -1.0282, -1.0045, -0.9689],
[-1.4320, -1.4202, -1.2658, ..., -1.0758, -1.0401, -1.0282],
[-1.4320, -1.4202, -1.4202, ..., -1.0758, -1.0995, -1.0995]]],
[[[ 0.7958, 0.7958, 0.8120, ..., 0.7958, 0.7958, 0.7958],
[ 0.7958, 0.8120, 0.8120, ..., 0.8120, 0.8120, 0.7958],
[ 0.8120, 0.8120, 0.8120, ..., 0.8120, 0.8120, 0.7958],
...,
[ 0.7958, 0.7958, 0.7958, ..., 0.8120, 0.7958, 0.7796],
[ 0.8444, 0.8444, 0.8606, ..., 0.8930, 0.8930, 0.8768],
[ 0.8606, 0.8606, 0.8606, ..., 0.8930, 0.8930, 0.8930]],
[[ 0.9623, 0.9623, 0.9777, ..., 0.9623, 0.9623, 0.9623],
[ 0.9777, 0.9777, 0.9777, ..., 0.9777, 0.9777, 0.9623],
[ 0.9777, 0.9777, 0.9777, ..., 0.9777, 0.9777, 0.9623],
...,
[ 0.9623, 0.9623, 0.9623, ..., 0.9623, 0.9623, 0.9470],
[ 1.0084, 1.0084, 1.0238, ..., 1.0545, 1.0545, 1.0392],
[ 1.0238, 1.0238, 1.0238, ..., 1.0545, 1.0545, 1.0545]],
[[ 1.2638, 1.2638, 1.2757, ..., 1.2638, 1.2638, 1.2638],
[ 1.2638, 1.2757, 1.2876, ..., 1.2757, 1.2757, 1.2638],
[ 1.2995, 1.2995, 1.2995, ..., 1.2757, 1.2757, 1.2638],
...,
[ 1.2876, 1.2876, 1.2757, ..., 1.2638, 1.2638, 1.2638],
[ 1.3232, 1.3232, 1.3114, ..., 1.3351, 1.3351, 1.3232],
[ 1.3351, 1.3351, 1.3114, ..., 1.3351, 1.3351, 1.3351]]]])
Now compute the same transformation directly from the formula:
>>> enddata=(IMGEND-mean)/std
>>> enddata
tensor([[[[-1.3587, -0.9213, -0.7269, ..., -0.3382, -0.3868, -0.4516],
[-1.4397, -0.8727, -0.6135, ..., -0.1114, -0.1762, -0.2248],
[-1.8771, -1.3587, -0.9375, ..., 0.1640, 0.0830, -0.1438],
...,
[-1.9095, -1.8285, -1.8123, ..., -2.1687, -2.2497, -2.2173],
[-1.9419, -1.8609, -1.8123, ..., -2.3469, -2.4117, -2.2983],
[-1.9257, -1.8447, -1.8447, ..., -2.3307, -2.3307, -2.2821]],
[[-1.0502, -0.4357, -0.0055, ..., 0.4246, 0.3785, 0.3325],
[-1.1424, -0.4203, 0.0406, ..., 0.5783, 0.5475, 0.5168],
[-1.6340, -1.0656, -0.4664, ..., 0.7626, 0.7319, 0.5629],
...,
[-1.6340, -1.5572, -1.5418, ..., -1.6955, -1.7723, -1.7876],
[-1.6801, -1.5879, -1.5418, ..., -1.9413, -2.0027, -1.8491],
[-1.6186, -1.5726, -1.5726, ..., -1.9259, -1.8952, -1.8645]],
[[-0.4938, 0.0881, 0.5988, ..., 1.0026, 0.9788, 0.9313],
[-0.5888, 0.0762, 0.6107, ..., 1.0738, 1.0501, 1.0382],
[-1.0758, -0.5532, 0.1000, ..., 1.1807, 1.1570, 1.0501],
...,
[-0.9926, -0.9332, -0.9451, ..., -1.3845, -1.4083, -1.3845],
[-1.0401, -0.9689, -0.9570, ..., -1.4320, -1.4439, -1.3370],
[-0.9926, -0.9451, -0.9570, ..., -1.4320, -1.4558, -1.3964]]],
[[[-1.6827, -1.8609, -1.9095, ..., -0.4192, -0.4840, -0.5002],
[-1.6989, -1.8285, -1.8933, ..., -0.3868, -0.4678, -0.4516],
[-1.6989, -1.7961, -2.0877, ..., -0.3868, -0.4192, -0.4516],
...,
[ 0.7634, 0.8606, 0.8768, ..., 0.9254, 0.9092, 0.9092],
[ 0.8120, 0.8930, 0.8930, ..., 0.9416, 0.8930, 0.8930],
[ 0.8282, 0.9092, 0.9254, ..., 0.9254, 0.8930, 0.8930]],
[[-1.9413, -2.0334, -1.9720, ..., -1.6340, -1.6340, -1.6340],
[-1.9413, -2.0181, -1.9720, ..., -1.5879, -1.5572, -1.5572],
[-1.9413, -1.9874, -2.0488, ..., -1.5726, -1.5265, -1.5265],
...,
[ 0.5936, 0.7473, 0.7473, ..., 0.8702, 0.8394, 0.8241],
[ 0.6397, 0.7780, 0.7780, ..., 0.8702, 0.7780, 0.8087],
[ 0.7319, 0.8241, 0.8241, ..., 0.8394, 0.8087, 0.7933]],
[[-1.3608, -1.3845, -1.3370, ..., -1.2539, -1.2301, -1.2301],
[-1.3608, -1.3845, -1.3252, ..., -1.2183, -1.2064, -1.2064],
[-1.3608, -1.3727, -1.3964, ..., -1.2064, -1.1826, -1.1826],
...,
[ 0.5988, 0.7532, 0.7532, ..., 0.8719, 0.8363, 0.8363],
[ 0.6700, 0.7888, 0.8007, ..., 0.8719, 0.7650, 0.8126],
[ 0.7532, 0.8244, 0.8363, ..., 0.8482, 0.8126, 0.8007]]],
[[[ 0.6986, 0.8282, 0.7796, ..., 0.1640, 0.0830, 0.1316],
[ 0.3908, 0.5204, 0.5852, ..., 0.1964, 0.2774, 0.2126],
[ 0.4070, 0.4880, 0.6014, ..., 0.0182, 0.3746, 0.2612],
...,
[-0.3706, -0.6135, -0.4030, ..., -0.2248, -0.2572, -0.2086],
[-0.4516, -0.6783, -1.0185, ..., -0.3220, -0.3868, -0.4030],
[-0.5973, -0.5973, -1.0347, ..., -0.3868, -0.4678, -0.5650]],
[[ 0.6551, 0.7780, 0.6551, ..., -0.2360, -0.2513, 0.1020],
[ 0.2249, 0.3478, 0.3939, ..., -0.1899, -0.1438, 0.0252],
[ 0.2096, 0.2864, 0.3785, ..., -0.3282, -0.0363, -0.0055],
...,
[-0.1592, -0.5586, -0.6661, ..., -0.0055, -0.0363, -0.0055],
[-0.2360, -0.5740, -1.1424, ..., -0.0977, -0.1284, -0.1438],
[-0.3896, -0.4203, -0.9888, ..., -0.1899, -0.2206, -0.2974]],
[[-0.2088, -0.0782, -0.2919, ..., -0.5770, -0.5413, 0.0050],
[-0.7670, -0.6720, -0.6720, ..., -0.6363, -0.5532, -0.1257],
[-0.8145, -0.7432, -0.7195, ..., -0.6720, -0.5770, -0.2919],
...,
[-1.4202, -1.3845, -1.0282, ..., -1.0282, -1.0045, -0.9689],
[-1.4320, -1.4202, -1.2658, ..., -1.0758, -1.0401, -1.0282],
[-1.4320, -1.4202, -1.4202, ..., -1.0758, -1.0995, -1.0995]]],
[[[ 0.7958, 0.7958, 0.8120, ..., 0.7958, 0.7958, 0.7958],
[ 0.7958, 0.8120, 0.8120, ..., 0.8120, 0.8120, 0.7958],
[ 0.8120, 0.8120, 0.8120, ..., 0.8120, 0.8120, 0.7958],
...,
[ 0.7958, 0.7958, 0.7958, ..., 0.8120, 0.7958, 0.7796],
[ 0.8444, 0.8444, 0.8606, ..., 0.8930, 0.8930, 0.8768],
[ 0.8606, 0.8606, 0.8606, ..., 0.8930, 0.8930, 0.8930]],
[[ 0.9623, 0.9623, 0.9777, ..., 0.9623, 0.9623, 0.9623],
[ 0.9777, 0.9777, 0.9777, ..., 0.9777, 0.9777, 0.9623],
[ 0.9777, 0.9777, 0.9777, ..., 0.9777, 0.9777, 0.9623],
...,
[ 0.9623, 0.9623, 0.9623, ..., 0.9623, 0.9623, 0.9470],
[ 1.0084, 1.0084, 1.0238, ..., 1.0545, 1.0545, 1.0392],
[ 1.0238, 1.0238, 1.0238, ..., 1.0545, 1.0545, 1.0545]],
[[ 1.2638, 1.2638, 1.2757, ..., 1.2638, 1.2638, 1.2638],
[ 1.2638, 1.2757, 1.2876, ..., 1.2757, 1.2757, 1.2638],
[ 1.2995, 1.2995, 1.2995, ..., 1.2757, 1.2757, 1.2638],
...,
[ 1.2876, 1.2876, 1.2757, ..., 1.2638, 1.2638, 1.2638],
[ 1.3232, 1.3232, 1.3114, ..., 1.3351, 1.3351, 1.3232],
[ 1.3351, 1.3351, 1.3114, ..., 1.3351, 1.3351, 1.3351]]]])
Clearly, the two results agree (up to tiny rounding differences, since we typed rounded mean/std values into Normalize), which shows that transforms.Normalize does exactly this formulaic transformation of its input.
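We can also confirm the match programmatically; a small tolerance absorbs the rounding in the constants passed to Normalize:
print(torch.allclose(dataend1, enddata, atol=1e-3))  # True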
Moreover, when the mean and std passed to transforms.Normalize are the actual mean and std of the data being transformed, the resulting data is guaranteed to have mean 0 and standard deviation 1.
When they are not, the new data's mean need not be 0 and its standard deviation need not be 1; the data has simply been transformed according to the formula.
Like this:
>>> process=transforms.Normalize([0.5, 0.6, 0.4],[0.36, 0.45, 0.45])
>>> data=process(inputdata)
Here [0.5, 0.6, 0.4] and [0.36, 0.45, 0.45] are not the mean and std of inputdata; they were chosen arbitrarily just to transform the original data, so the resulting mean is naturally not guaranteed to be 0, nor the standard deviation 1.
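To make this concrete, here is a small sketch; inputdata can be any float image batch scaled to [0, 1], for example the IMGEND batch from above divided by 255:
inputdata = IMGEND / 255.0
process = transforms.Normalize([0.5, 0.6, 0.4], [0.36, 0.45, 0.45])
data = process(inputdata)
print(data.mean(dim=(0, 2, 3)))  # generally not zero
print(data.std(dim=(0, 2, 3)))   # generally not one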
Of course, when preprocessing images you will often see these two lines appear together:
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
Here transforms.ToTensor() converts the input (a PIL image or uint8 numpy array) into a tensor, changes the shape from H, W, C to C, H, W, and divides every value by 255, scaling the data into [0, 1].
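A minimal check of that behavior on a hypothetical 2x2 uint8 image:
arr = np.array([[[255, 0, 0], [0, 255, 0]],
                [[0, 0, 255], [255, 255, 255]]], dtype=np.uint8)  # shape H, W, C = 2, 2, 3
t = transforms.ToTensor()(Image.fromarray(arr))
print(t.shape)  # torch.Size([3, 2, 2]), i.e. C, H, W
print(t.min().item(), t.max().item())  # 0.0 1.0 -- values scaled into [0, 1]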
Applying the formula x = (x - mean) / std to those endpoints gives:
(0 - 0.5) / 0.5 = -1
(1 - 0.5) / 0.5 = 1
So the new values fall within the range [-1, 1], but note that the new mean is not necessarily 0, and the new standard deviation is not necessarily 1.
That is because [0.5, 0.5, 0.5] and [0.5, 0.5, 0.5] are not necessarily the actual mean and standard deviation of the original data.
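A final sketch putting the two transforms together (the uniform gray image below is hypothetical): the output stays inside [-1, 1], while its mean is clearly not 0:
img = Image.fromarray(np.full((2, 2, 3), 200, dtype=np.uint8))  # every pixel value is 200
out = transform(img)  # ToTensor then Normalize(0.5, 0.5), as composed above
print(out.min().item(), out.max().item())  # both about 0.5686, inside [-1, 1]
print(out.mean().item())  # about 0.5686, not 0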